utils.wrappers.sklearn

Classes

SklearnWrapper(prediction_function[, ...])

Wrapper for sklearn prediction functions.

class utils.wrappers.sklearn.SklearnWrapper(prediction_function, feature_names=None)

Bases: Wrapper

Wrapper for sklearn prediction functions.

This wrapper turns the output of any prediction function into an iterable (list or np.ndarray) and allows dict inputs.

Examples

Basic usage:

>>> from sklearn.ensemble import RandomForestClassifier
>>> from utils.wrappers.sklearn import SklearnWrapper
>>> model = RandomForestClassifier()
>>> model_function = SklearnWrapper(model.predict)

For classifiers that return class probabilities:

>>> model_function = SklearnWrapper(model.predict_proba)

If the keys of dict inputs may arrive in a different order:

>>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
>>> model_function = SklearnWrapper(model.predict, feature_names=feature_orderings)
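Why a fixed feature order matters for dict inputs can be illustrated with a minimal, standard-library sketch. It is not the wrapper's actual implementation: `rows_from_dicts` and `toy_predict` are hypothetical names introduced here, and the sketch only assumes that dict inputs are keyed by feature name.

```python
from typing import Dict, List, Sequence, Union

Row = Dict[str, float]


def rows_from_dicts(x: Union[Row, List[Row]], feature_names: Sequence[str]) -> List[List[float]]:
    """Arrange dict inputs into rows that follow a fixed feature order (hypothetical helper)."""
    samples = [x] if isinstance(x, dict) else list(x)
    return [[sample[name] for name in feature_names] for sample in samples]


def toy_predict(rows: List[List[float]]) -> List[float]:
    """Stand-in for model.predict: the result depends on column position."""
    weights = [1.0, 10.0, 100.0]
    return [sum(w * v for w, v in zip(weights, row)) for row in rows]


feature_orderings = ['feature_1', 'feature_2', 'feature_3']

# The same sample, with keys inserted in two different orders.
sample_a = {'feature_1': 1.0, 'feature_2': 2.0, 'feature_3': 3.0}
sample_b = {'feature_3': 3.0, 'feature_1': 1.0, 'feature_2': 2.0}

pred_a = toy_predict(rows_from_dicts(sample_a, feature_orderings))
pred_b = toy_predict(rows_from_dicts(sample_b, feature_orderings))

# Without an explicit order, the row follows the dict's insertion order.
naive_b = toy_predict([list(sample_b.values())])
```

With the explicit ordering, `pred_a` and `pred_b` agree, whereas `naive_b` differs because the values reach the stand-in model in the wrong columns.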

Note

If the sklearn model was trained with access to feature names (e.g. on a pandas DataFrame), it will usually raise a warning when unnamed feature values are provided (e.g. as a np.ndarray). Since instantiating a pandas DataFrame for each input is computationally more expensive, this wrapper manually suppresses that specific warning.
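The kind of targeted suppression described in this note can be sketched with the standard-library `warnings` module. This is an illustration under assumptions, not the wrapper's code: `noisy_predict` is a hypothetical stand-in for a prediction function, and both its warning text and the message pattern used in the filter are assumptions.

```python
import warnings
from typing import List


def noisy_predict(rows: List[List[float]]) -> List[int]:
    """Hypothetical stand-in for a predict function that warns on unnamed inputs."""
    warnings.warn(
        "X does not have valid feature names, but the model was fitted with feature names",
        UserWarning,
    )
    return [0 for _ in rows]


def quiet_predict(rows: List[List[float]]) -> List[int]:
    """Call noisy_predict while ignoring only the feature-name warning."""
    with warnings.catch_warnings():
        # The pattern is a regex matched against the start of the warning message.
        warnings.filterwarnings(
            "ignore",
            message="X does not have valid feature names",
            category=UserWarning,
        )
        return noisy_predict(rows)


# Record warnings to compare the unfiltered and the filtered call.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loud = noisy_predict([[1.0, 2.0]])
    n_loud = len(caught)
    quiet = quiet_predict([[1.0, 2.0]])
    n_after_quiet = len(caught)
```

Filtering by message and category inside `catch_warnings` keeps the suppression local: other warnings still surface, and the previous filter state is restored when the block exits.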

__call__(x)
Return type:

Union[dict, List[dict]]