utils.wrappers
This module gathers basic wrapper objects that turn common ML model architectures into callable functions.
Note: To keep the number of dependencies low, import the wrappers you need directly.
- class utils.wrappers.RiverMetricToLossFunction(river_metric, dict_input_metric=False)
Bases: object
Wrapper that transforms a river.metrics.base.Metric into a loss function.
This wrapper gives metrics that expect a single value as prediction (e.g. river.metrics.MAE or river.metrics.Accuracy) and metrics that expect a dictionary as prediction (e.g. river.metrics.CrossEntropy) the same loss-function interface.
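The general idea of scoring one prediction with a running metric can be sketched as follows. It assumes only the `update(y_true, y_pred)` / `get()` interface that river metrics expose; `ToyMAE`, `ToyCrossEntropy`, and `metric_to_loss` are hypothetical stand-ins (so the sketch runs without river) and do not reproduce the wrapper's actual implementation or its `dict_input_metric` handling.

```python
import copy
import math
from typing import Any, Callable, Dict


class ToyMAE:
    """Hypothetical single-value metric mimicking the update()/get() interface of river metrics."""

    def __init__(self) -> None:
        self._total = 0.0
        self._n = 0

    def update(self, y_true: float, y_pred: float) -> None:
        self._total += abs(y_true - y_pred)
        self._n += 1

    def get(self) -> float:
        return self._total / self._n if self._n else 0.0


class ToyCrossEntropy:
    """Hypothetical dict-input metric: y_pred maps each class label to its probability."""

    def __init__(self) -> None:
        self._total = 0.0
        self._n = 0

    def update(self, y_true: Any, y_pred: Dict[Any, float]) -> None:
        # Clip the probability to avoid log(0) for labels missing from y_pred.
        self._total += -math.log(max(y_pred.get(y_true, 0.0), 1e-15))
        self._n += 1

    def get(self) -> float:
        return self._total / self._n if self._n else 0.0


def metric_to_loss(metric: Any) -> Callable[[Any, Any], float]:
    """Hypothetical helper: turn a running metric into a per-sample loss function.

    Assumes `metric` is freshly initialised; otherwise its accumulated state is copied too.
    """

    def loss(y_true: Any, y_pred: Any) -> float:
        # Score each call on a fresh copy so that earlier calls do not leak into the result.
        sample_metric = copy.deepcopy(metric)
        sample_metric.update(y_true, y_pred)
        return sample_metric.get()

    return loss


mae_loss = metric_to_loss(ToyMAE())
ce_loss = metric_to_loss(ToyCrossEntropy())
print(mae_loss(3.0, 2.5))            # 0.5
print(mae_loss(1.0, 4.0))            # 3.0, unaffected by the previous call
print(ce_loss(1, {0: 0.5, 1: 0.5}))  # -log(0.5), about 0.693
```

Both metric styles end up behind the same `loss(y_true, y_pred)` call, which is the uniform interface the description refers to.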
- class utils.wrappers.RiverWrapper(prediction_function)
Bases: Wrapper
Wrapper for river prediction functions.
This wrapper turns the output of any prediction function into an iterable (list or np.ndarray).
Examples
Basic usage:
>>> from river.ensemble import AdaptiveRandomForestClassifier
>>> model = AdaptiveRandomForestClassifier()
>>> model_function = RiverWrapper(model.predict_one)
For classifiers returning probas:
>>> model_function = RiverWrapper(model.predict_proba_one)
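A minimal sketch of what turning a river prediction into an iterable can look like, assuming river's convention that `predict_one` returns a single label and `predict_proba_one` returns a dict of class probabilities. `IterableOutputWrapper` and `ToyRiverModel` are hypothetical, and keeping the dict values in insertion order is an assumption, not necessarily what `RiverWrapper` does.

```python
from typing import Any, Callable, Dict, List


class IterableOutputWrapper:
    """Hypothetical sketch of the idea behind RiverWrapper; not its actual implementation."""

    def __init__(self, prediction_function: Callable[[Dict[str, Any]], Any]) -> None:
        self.prediction_function = prediction_function

    def __call__(self, x: Dict[str, Any]) -> List[Any]:
        output = self.prediction_function(x)
        if isinstance(output, dict):
            # Assumption: probabilities are returned in the dict's insertion order.
            return list(output.values())
        # A single predicted value becomes a one-element list.
        return [output]


class ToyRiverModel:
    """Hypothetical model exposing river-style predict_one / predict_proba_one methods."""

    def predict_proba_one(self, x: Dict[str, float]) -> Dict[int, float]:
        p = min(max(x["feature_1"], 0.0), 1.0)
        return {0: 1.0 - p, 1: p}

    def predict_one(self, x: Dict[str, float]) -> int:
        probas = self.predict_proba_one(x)
        return max(probas, key=probas.get)


model = ToyRiverModel()
print(IterableOutputWrapper(model.predict_one)({"feature_1": 0.75}))        # [1]
print(IterableOutputWrapper(model.predict_proba_one)({"feature_1": 0.75}))  # [0.25, 0.75]
```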
- class utils.wrappers.SklearnWrapper(prediction_function, feature_names=None)
Bases: Wrapper
Wrapper for sklearn prediction functions.
This wrapper turns the output of any prediction function into an iterable (list or np.ndarray) and also accepts dict inputs.
Examples
Basic usage:
>>> from sklearn.ensemble import RandomForestClassifier
>>> model = RandomForestClassifier()
>>> model_function = SklearnWrapper(model.predict)
For classifiers returning probas:
>>> model_function = SklearnWrapper(model.predict_proba)
If the dict inputs may arrive with different feature orderings, pass the expected order explicitly:
>>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
>>> model_function = SklearnWrapper(model.predict, feature_names=feature_orderings)
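Why `feature_names` matters can be shown with a small sketch that converts a dict into the 2-D row an sklearn-style `predict` expects. `DictInputWrapper` and `toy_predict` are hypothetical (a fixed linear score stands in for a fitted model), and falling back to the dict's insertion order when no names are given is an assumption, not a statement about `SklearnWrapper` internals.

```python
from typing import Any, Callable, Dict, Optional, Sequence

import numpy as np


class DictInputWrapper:
    """Hypothetical sketch of dict-input handling in the spirit of SklearnWrapper; not its actual code."""

    def __init__(self, prediction_function: Callable[[np.ndarray], Any],
                 feature_names: Optional[Sequence[str]] = None) -> None:
        self.prediction_function = prediction_function
        self.feature_names = feature_names

    def __call__(self, x: Dict[str, float]) -> np.ndarray:
        # Without explicit feature_names, fall back to the dict's insertion order.
        names = self.feature_names if self.feature_names is not None else list(x)
        # sklearn estimators expect a 2-D array of shape (n_samples, n_features).
        row = np.array([[x[name] for name in names]], dtype=float)
        return np.asarray(self.prediction_function(row))


def toy_predict(X: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for model.predict: a fixed linear score per row."""
    return X @ np.array([1.0, 10.0, 100.0])


x = {"feature_3": 3.0, "feature_1": 1.0, "feature_2": 2.0}
ordered = DictInputWrapper(toy_predict, feature_names=["feature_1", "feature_2", "feature_3"])
unordered = DictInputWrapper(toy_predict)
print(ordered(x))    # [321.]
print(unordered(x))  # [213.]  (same values, wrong column order)
```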
Note
If the sklearn model was fitted with feature names (e.g. on a pandas DataFrame), it usually raises a warning when unnamed feature values are provided (e.g. as an np.ndarray). Since creating a pandas DataFrame for each input is computationally more expensive, this wrapper suppresses that specific warning.
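Suppressing one specific warning can be sketched with the standard-library `warnings` module. `fake_predict` and `predict_without_feature_name_warning` are hypothetical; the warning text is assumed to match scikit-learn's "X does not have valid feature names" `UserWarning`. The filter is scoped to a `catch_warnings` block and only matches messages starting with that text, so other warnings still surface.

```python
import warnings
from typing import Any, Callable, List


def fake_predict(X: List[List[float]]) -> List[int]:
    """Hypothetical stand-in for a fitted sklearn model's predict() receiving unnamed features."""
    warnings.warn(
        "X does not have valid feature names, but FakeModel was fitted with feature names",
        UserWarning,
    )
    return [0 for _ in X]


def predict_without_feature_name_warning(predict: Callable[[Any], Any], X: Any) -> Any:
    with warnings.catch_warnings():
        # Only this message is ignored; the regex is matched against the start of the message.
        warnings.filterwarnings(
            "ignore", message="X does not have valid feature names", category=UserWarning
        )
        return predict(X)


print(predict_without_feature_name_warning(fake_predict, [[1.0, 2.0]]))  # [0], no warning shown
```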
- class utils.wrappers.TorchSupervisedLearningWrapper(model, optimizer, loss_function, task, n_classes=1, class_labels=None)
Bases: Wrapper
Basic wrapper for torch classification models.
Warning: This wrapper provides only very basic functionality and is intended only for simple supervised learning tasks solved with torch.
This wrapper turns any prediction function output into an iterable (list or np.ndarray).
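The constructor takes a model, an optimizer, and a loss function. The sketch below shows, assuming learning proceeds one sample at a time, how such a trio typically combines into a single supervised update and an iterable prediction. `MinimalTorchLearner`, `learn_one`, and `predict_proba_one` are hypothetical names; the wrapper's real methods and its use of `task`, `n_classes`, and `class_labels` are not reproduced here.

```python
from typing import Dict, List

import torch
from torch import nn


class MinimalTorchLearner:
    """Hypothetical sketch of a basic supervised learning step; not TorchSupervisedLearningWrapper itself."""

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer, loss_function: nn.Module) -> None:
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function

    @staticmethod
    def _to_tensor(x: Dict[str, float]) -> torch.Tensor:
        # Assumption: feature order follows the dict's insertion order; batch of size one.
        return torch.tensor([list(x.values())], dtype=torch.float32)

    def learn_one(self, x: Dict[str, float], y: int) -> float:
        self.model.train()
        self.optimizer.zero_grad()
        logits = self.model(self._to_tensor(x))
        loss = self.loss_function(logits, torch.tensor([y]))  # class index target (int64)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def predict_proba_one(self, x: Dict[str, float]) -> List[float]:
        self.model.eval()
        with torch.no_grad():
            probas = torch.softmax(self.model(self._to_tensor(x)), dim=-1)[0]
        # Return a plain list, i.e. an iterable output.
        return probas.tolist()


torch.manual_seed(0)
model = nn.Linear(2, 2)
learner = MinimalTorchLearner(model, torch.optim.SGD(model.parameters(), lr=0.1), nn.CrossEntropyLoss())
x_sample = {"feature_1": 1.0, "feature_2": 0.0}
losses = [learner.learn_one(x_sample, 1) for _ in range(100)]
print(losses[0] > losses[-1], learner.predict_proba_one(x_sample))
```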
- class utils.wrappers.TorchWrapper(link_function, feature_names=None, device='cpu')
Bases: Wrapper
Wrapper for torch link functions.
This wrapper turns any torch output tensor into a dict output and accepts dict inputs.
Examples
Basic usage:
>>> torch_module: torch.nn.Module = torch_module
>>> module_device = 'cpu'
>>> model_function = TorchWrapper(torch_module, device=module_device)
For classifiers returning class labels:
>>> def link_function_class(x):
...     return torch.argmax(torch.softmax(torch_module(x), dim=-1), dim=-1)
>>> model_function = TorchWrapper(link_function_class)
For classifiers returning probas:
>>> def link_function_probas(x):
...     return torch.softmax(torch_module(x), dim=-1)
>>> model_function = TorchWrapper(link_function_probas)
If the dict inputs may arrive with different feature orderings, pass the expected order explicitly:
>>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
>>> model_function = TorchWrapper(link_function_probas, feature_names=feature_orderings)
- __call__(x)
Runs the torch model with the given input dictionary.
- Parameters:
x (dict): The input features as a dictionary.
- Returns:
The model output as a dictionary following river conventions.
- Return type:
dict
Basic usage:
>>> def link_function_probas(x):
...     return torch.softmax(torch_module(x), dim=-1)
>>> model_function: typing.Callable = TorchWrapper(link_function_probas)
>>> input_dict = {'feature_1': 1, 'feature_2': 0}
>>> model_function(input_dict)
{0: 0.45, 1: 0.05, 2: 0.5}
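The dict-in/dict-out behaviour of `__call__` can be sketched as: order the features, build a batch of size one, apply the link function, and map each output index to its value. `DictTorchWrapper` is hypothetical and the fixed-weight `torch.nn.Linear` module only makes the run deterministic; this is not the library's code. Because softmax is strictly increasing, taking the argmax over the raw logits (as `link_function_class` does here) yields the same label as the argmax over the softmax probabilities used in the examples above.

```python
from typing import Callable, Dict, Optional, Sequence

import torch


class DictTorchWrapper:
    """Hypothetical sketch of TorchWrapper-style dict in / dict out; not the library's actual code."""

    def __init__(self, link_function: Callable[[torch.Tensor], torch.Tensor],
                 feature_names: Optional[Sequence[str]] = None, device: str = "cpu") -> None:
        self.link_function = link_function
        self.feature_names = feature_names
        self.device = device

    def __call__(self, x: Dict[str, float]) -> Dict[int, float]:
        names = self.feature_names if self.feature_names is not None else list(x)
        # Batch of size one, with features in the configured order.
        inputs = torch.tensor([[float(x[n]) for n in names]], dtype=torch.float32, device=self.device)
        with torch.no_grad():
            output = self.link_function(inputs)[0]
        # river-style output: output index (e.g. class index) -> value.
        return dict(enumerate(output.reshape(-1).tolist()))


torch_module = torch.nn.Linear(2, 3)
with torch.no_grad():
    # Fixed weights so that the example is deterministic.
    torch_module.weight.copy_(torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]))
    torch_module.bias.zero_()


def link_function_probas(x: torch.Tensor) -> torch.Tensor:
    return torch.softmax(torch_module(x), dim=-1)


def link_function_class(x: torch.Tensor) -> torch.Tensor:
    # Same label as argmax over softmax, since softmax preserves the ordering of logits.
    return torch.argmax(torch_module(x), dim=-1)


x = {"feature_2": 0.0, "feature_1": 2.0}
names = ["feature_1", "feature_2"]
print(DictTorchWrapper(link_function_class, feature_names=names)(x))  # {0: 0}
print(DictTorchWrapper(link_function_class)(x))                       # {0: 1}, wrong feature order
print(DictTorchWrapper(link_function_probas, feature_names=names)(x))
```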
Modules
This module contains river model wrappers that turn the output of river models into lists or arrays.