# Source code for utils.wrappers.sklearn
import typing
import warnings
from ixai.utils.wrappers.base import Wrapper
class SklearnWrapper(Wrapper):
    """Wrapper for sklearn prediction functions.

    This wrapper lets an array-based sklearn prediction function be called with dict inputs: a single
    sample as a ``dict`` or a batch of samples as a ``list`` of dicts. Inputs are converted to arrays before
    prediction and the predictions are converted back with ``convert_arr_output_to_dict``.

    Examples:
        Basic usage:

        >>> from sklearn.ensemble import RandomForestClassifier
        >>> model = RandomForestClassifier()
        >>> model_function = SklearnWrapper(model.predict)

        For classifiers returning probas:

        >>> model_function = SklearnWrapper(model.predict_proba)

        If the dict inputs may arrive with their features in different orderings:

        >>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
        >>> model_function = SklearnWrapper(model.predict, feature_names=feature_orderings)

    Note:
        If the sklearn model is trained with access to the feature names (e.g. trained on a pandas DataFrame),
        it will usually raise a warning if unnamed feature values are provided (e.g. in the form of a
        np.ndarray). Since instantiating a pandas DataFrame for each input is computationally more expensive,
        this specific warning is manually suppressed in this wrapper.
    """

    def __init__(self, prediction_function: typing.Callable, feature_names: typing.Optional[list] = None):
        """Create the wrapper.

        Args:
            prediction_function: The sklearn prediction callable to wrap, e.g. ``model.predict`` or
                ``model.predict_proba``.
            feature_names: Optional list of feature names, forwarded unchanged to the ``Wrapper`` base class.
                Per the class example, this fixes the feature ordering used for dict inputs.
                Defaults to ``None``.
        """
        super().__init__(prediction_function, feature_names)

    def __call__(self, x: typing.Union[typing.List[dict], dict]) -> typing.Union[dict, typing.List[dict]]:
        """Run the wrapped prediction function on one or more dict samples.

        Args:
            x: Either a single sample as a ``dict``, or a ``list`` of samples (dicts).

        Returns:
            For a single ``dict`` input, the prediction converted by ``convert_arr_output_to_dict``.
            For a ``list`` input, a list holding one converted entry per element of the prediction
            function's output (presumably one per input sample; this depends on the base-class
            conversion helpers).
        """
        with warnings.catch_warnings():
            # Models fitted on named features warn when called with a plain array; building a DataFrame per
            # call is more expensive, so only that specific warning is silenced (scoped to this block).
            warnings.filterwarnings("ignore", message="X does not have valid feature names, but")
            if isinstance(x, dict):
                single_input = self.convert_1d_input_to_arr(x)
                return self.convert_arr_output_to_dict(self._prediction_function(single_input))
            batch_input = self.convert_2d_input_to_arr(x)
            y_predictions = self._prediction_function(batch_input)
            # Iterate the predictions directly (row by row) instead of indexing via range(len(...)).
            return [self.convert_arr_output_to_dict(y_pred) for y_pred in y_predictions]