Source code for ixai.utils.validators.model
import typing
import warnings
from ixai.utils.wrappers.base import Wrapper
from ixai.utils.wrappers import SklearnWrapper, TorchWrapper, RiverWrapper
def validate_model_function(model_function: typing.Any) -> typing.Callable:
    """Wrap a model function in a matching ixai wrapper where one can be detected.

    Detection proceeds in this order:

    1. A ``Wrapper`` instance is returned unchanged.
    2. If ``model_function`` exposes ``__self__`` (e.g. a bound method), the
       string form of the owner's type is inspected: a name containing
       ``'sklearn'`` yields a ``SklearnWrapper`` and one containing ``'river'``
       yields a ``RiverWrapper``.
    3. If ``model_function`` has no ``__self__`` and torch can be imported, a
       ``torch.nn.Module`` instance yields a ``TorchWrapper`` on device
       ``'cpu'``.
    4. Otherwise ``model_function`` is returned as-is.

    Args:
        model_function: The callable or model object to validate.

    Returns:
        The wrapped model function, or ``model_function`` itself when it is
        already a ``Wrapper`` or no wrapper could be selected automatically.

    Warns:
        UserWarning: When a torch module is wrapped with the default device,
            when torch is installed but ``model_function`` is not a torch
            module, and when ``model_function`` is returned without wrapping.
    """
    if isinstance(model_function, Wrapper):
        return model_function  # we assume the wrapper is applied correctly

    # Probe for ``__self__`` without a surrounding ``try`` block, so that an
    # AttributeError raised inside a wrapper constructor propagates to the
    # caller instead of being mistaken for "model_function has no __self__".
    no_owner = object()
    owner = getattr(model_function, '__self__', no_owner)

    if owner is not no_owner:
        owner_type_name = str(type(owner))
        if 'sklearn' in owner_type_name:
            return SklearnWrapper(prediction_function=model_function)
        if 'river' in owner_type_name:
            return RiverWrapper(prediction_function=model_function)
    else:
        # Guard only the import itself: an ImportError raised while building
        # the TorchWrapper must not be read as "torch is not installed".
        try:
            import torch
        except ImportError:  # torch is not installed -> it will not be a torch model function
            torch = None

        if torch is not None:
            if isinstance(model_function, torch.nn.Module):  # only shallow check
                warnings.warn(
                    "Torch Model Function provided. Default device 'cpu' is used. If your module is"
                    " not running on this device, apply the `ixai.utils.wrappers.TorchWrapper` "
                    "manually and specify the correct device.", UserWarning)
                return TorchWrapper(link_function=model_function, device='cpu')
            warnings.warn("Torch detected, but model function is not a torch module. If your model "
                          "function is based on torch, apply the `ixai.utils.wrappers.TorchWrapper`"
                          " manually. If not neglect this message.", UserWarning)

    warnings.warn("The provided model function cannot be automatically wrapped and will be used "
                  "directly. If you follow the internal logic of our model functions and designed "
                  "your model function accordingly, this warning can be neglected. If not please "
                  "investigate how the model functions need to be designed.", UserWarning)
    return model_function