utils.wrappers.sklearn
Classes
SklearnWrapper | Wrapper for sklearn prediction functions.
class utils.wrappers.sklearn.SklearnWrapper(prediction_function, feature_names=None)
Bases: Wrapper
Wrapper for sklearn prediction functions.
This wrapper converts the output of any prediction function into an iterable (a list or np.ndarray) and also accepts dict inputs.
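The output normalization described above could look roughly like the following. This is a minimal sketch under stated assumptions, not the actual SklearnWrapper implementation: `to_iterable_output` is a hypothetical helper name, and the rules for which outputs are passed through or converted are assumptions chosen for illustration.

```python
from collections.abc import Iterable
from typing import Any, Callable, Union

import numpy as np


def to_iterable_output(
    prediction_function: Callable[[Any], Any],
) -> Callable[[Any], Union[list, np.ndarray]]:
    """Hypothetical helper: wrap a prediction function so its output is always a list or np.ndarray."""

    def wrapped(x: Any) -> Union[list, np.ndarray]:
        output = prediction_function(x)
        # Arrays pass through, but 0-d arrays are promoted to shape (1,) so they are iterable.
        if isinstance(output, np.ndarray):
            return np.atleast_1d(output)
        # Lists are already in the desired form.
        if isinstance(output, list):
            return output
        # Other iterables (tuples, generators, ...) are materialised as a list.
        if isinstance(output, Iterable) and not isinstance(output, (str, bytes)):
            return list(output)
        # Scalars (e.g. a plain float or np.float64) become a 1-element array.
        return np.atleast_1d(output)

    return wrapped
```

For example, `to_iterable_output(lambda x: 0.5)(None)` returns a one-element array, so downstream code can always iterate over the prediction.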
Examples
Basic usage:
>>> from sklearn.ensemble import RandomForestClassifier
>>> from utils.wrappers.sklearn import SklearnWrapper
>>> model = RandomForestClassifier()
>>> model_function = SklearnWrapper(model.predict)
For classifiers that return class probabilities:
>>> model_function = SklearnWrapper(model.predict_proba)
If dict inputs may list the features in a different order, pass the expected feature order via feature_names:
>>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
>>> model_function = SklearnWrapper(model.predict, feature_names=feature_orderings)
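One plausible way for a wrapper to use feature_names is to build each input row in that fixed order, whatever order the dict keys arrive in. The sketch below assumes a single dict per call and a missing-feature error; `dict_to_row` is a hypothetical helper, not part of utils.wrappers.sklearn.

```python
from typing import Dict, Optional, Sequence

import numpy as np


def dict_to_row(
    x: Dict[str, float], feature_names: Optional[Sequence[str]] = None
) -> np.ndarray:
    """Hypothetical helper: turn one dict input into a 2D array with a single row.

    If feature_names is given, values are placed in that order regardless of the
    dict's own key order; otherwise the dict's insertion order is used.
    """
    order = list(feature_names) if feature_names is not None else list(x.keys())
    missing = [name for name in order if name not in x]
    if missing:
        raise KeyError(f"missing features: {missing}")
    # sklearn estimators expect X with shape (n_samples, n_features).
    return np.array([[x[name] for name in order]])
```

With `feature_orderings` as in the example above, `{'feature_3': 3.0, 'feature_1': 1.0, 'feature_2': 2.0}` becomes the row `[[1.0, 2.0, 3.0]]`, which could then be passed to a method such as `model.predict`.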
Note
If the sklearn model was fitted with feature names (e.g. on a pandas DataFrame), it will usually emit a warning when it receives feature values without names (e.g. as a np.ndarray). Because constructing a pandas DataFrame for every input is computationally more expensive, this wrapper suppresses that specific warning.
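Suppressing only that one warning, rather than all warnings, can be done with the standard library's `warnings.catch_warnings` and a message-based filter. This is a sketch, not the wrapper's actual code: it assumes the warning is a `UserWarning` whose text starts with "X does not have valid feature names", as in recent scikit-learn versions, and `call_without_feature_name_warning` is a hypothetical helper name.

```python
import warnings
from typing import Any, Callable


def call_without_feature_name_warning(
    prediction_function: Callable[[Any], Any], x: Any
) -> Any:
    """Hypothetical helper: call prediction_function(x) while ignoring only the
    'X does not have valid feature names' UserWarning (assumed sklearn wording)."""
    with warnings.catch_warnings():
        # `message` is a regex matched against the start of the warning text;
        # the filter is removed again when the context manager exits, and
        # all other warnings keep their normal behaviour.
        warnings.filterwarnings(
            "ignore",
            message="X does not have valid feature names",
            category=UserWarning,
        )
        return prediction_function(x)
```

Scoping the filter to a `catch_warnings` block keeps the suppression local to the prediction call, so unrelated warnings raised elsewhere in a program are unaffected.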