utils.wrappers.torch
Classes
| Class | Description |
|---|---|
| TorchSupervisedLearningWrapper | Basic wrapper for torch classification models. |
| TorchWrapper | Wrapper for torch link functions. |
- class utils.wrappers.torch.TorchSupervisedLearningWrapper(model, optimizer, loss_function, task, n_classes=1, class_labels=None)
Bases: Wrapper
Basic wrapper for torch classification models.
Warning: This wrapper provides only very basic functionality and is intended only for basic supervised learning tasks solved with torch.
This wrapper turns any prediction function output into an iterable (list or np.ndarray) output.
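As an illustration of this output normalization, here is a minimal sketch of such a conversion step. The helper name `to_iterable` is hypothetical and not part of this module; the sketch relies only on the fact that both `torch.Tensor` and `np.ndarray` provide a `.tolist()` method, and it returns a plain list rather than an array.

```python
from typing import Any, List


def to_iterable(output: Any) -> List[Any]:
    """Hypothetical helper (not the library's API): normalize a prediction to a list."""
    # torch.Tensor and np.ndarray both provide .tolist(); for a 0-d tensor or
    # array, .tolist() returns a plain Python scalar instead of a list.
    if hasattr(output, "tolist"):
        output = output.tolist()
    # Lists and tuples become lists; anything else (e.g. a single label) is wrapped.
    if isinstance(output, (list, tuple)):
        return list(output)
    return [output]


print(to_iterable(1))           # [1]
print(to_iterable((0.2, 0.8)))  # [0.2, 0.8]
```

Wrapping scalars keeps downstream code uniform: callers can always iterate over the prediction, whether the model returned a single label or a vector of scores.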
- class utils.wrappers.torch.TorchWrapper(link_function, feature_names=None, device='cpu')
Bases: Wrapper
Wrapper for torch link functions.
This wrapper turns any torch output tensor into a dict output and accepts dict inputs.
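To make the dict-in/dict-out convention concrete, the following is a minimal, torch-free sketch of what such a wrapper does around a link function. The class `DictLinkWrapper` is hypothetical; a plain Python callable stands in for the torch link function so the sketch runs without torch, and keying the outputs by position (0, 1, 2, ...) is an assumption based on the `__call__` example output on this page.

```python
from typing import Callable, Dict, List, Optional, Sequence


class DictLinkWrapper:
    """Hypothetical dict-in/dict-out wrapper (not the library's TorchWrapper).

    With torch, the feature vector would become a tensor on the chosen device
    and the output tensor would be converted back with .tolist().
    """

    def __init__(self, link_function: Callable[[List[float]], Sequence[float]],
                 feature_names: Optional[List[str]] = None) -> None:
        self.link_function = link_function
        self.feature_names = feature_names

    def __call__(self, x: Dict[str, float]) -> Dict[int, float]:
        # Fix the feature order: feature_names if given, else the dict's insertion order.
        names = self.feature_names if self.feature_names is not None else list(x)
        vector = [float(x[name]) for name in names]
        output = self.link_function(vector)
        # Key each output entry by its position, e.g. the class index.
        return {i: float(value) for i, value in enumerate(output)}


wrapper = DictLinkWrapper(lambda v: [v[0] + v[1], v[0] - v[1]])
print(wrapper({'feature_1': 3.0, 'feature_2': 1.0}))  # {0: 4.0, 1: 2.0}
```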
Examples
Basic usage:
>>> torch_module: torch.nn.Module = torch_module
>>> module_device = 'cpu'
>>> model_function = TorchWrapper(torch_module, device=module_device)
For classifiers returning class_labels:
>>> def link_function_class(x):
...     return torch.argmax(torch.softmax(torch_module(x), dim=-1), dim=-1)
>>> model_function = TorchWrapper(link_function_class)
For classifiers returning probas:
>>> def link_function_probas(x):
...     return torch.softmax(torch_module(x), dim=-1)
>>> model_function = TorchWrapper(link_function_probas)
If the dict inputs may arrive with different feature orderings:
>>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
>>> model_function = TorchWrapper(link_function_probas, feature_names=feature_orderings)
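The reason a fixed ordering matters is that the same features can arrive in dicts with different insertion orders, while the model expects each feature in a fixed input position. The helper `dict_to_vector` below is hypothetical, not part of this module, and uses plain lists instead of tensors; it sketches how a fixed list of names pins down the order of the model input.

```python
from typing import Dict, List, Optional


def dict_to_vector(x: Dict[str, float], feature_names: Optional[List[str]] = None) -> List[float]:
    """Hypothetical helper: turn a feature dict into an ordered input vector."""
    # Without feature_names, the dict's insertion order decides the column order.
    names = feature_names if feature_names is not None else list(x)
    return [float(x[name]) for name in names]


feature_orderings = ['feature_1', 'feature_2', 'feature_3']
a = {'feature_1': 1.0, 'feature_2': 2.0, 'feature_3': 3.0}
b = {'feature_3': 3.0, 'feature_1': 1.0, 'feature_2': 2.0}  # same values, other order

print(dict_to_vector(a) == dict_to_vector(b))  # False: [1.0, 2.0, 3.0] vs [3.0, 1.0, 2.0]
print(dict_to_vector(a, feature_orderings) == dict_to_vector(b, feature_orderings))  # True
```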
- __call__(x)
Runs the torch model with the given input dictionary.
- Parameters:
x – The input dictionary mapping feature names to feature values.
- Returns:
The model output as a dictionary following river conventions.
- Return type:
dict
Examples
Basic usage:
>>> def link_function_probas(x):
...     return torch.softmax(torch_module(x), dim=-1)
>>> model_function: typing.Callable = TorchWrapper(link_function_probas)
>>> input_dict = {'feature_1': 1, 'feature_2': 0}
>>> model_function(input_dict)
{0: 0.45, 1: 0.05, 2: 0.5}
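The step from a probability vector to the dictionary shown in this example can be sketched in plain Python. This assumes, as the example output suggests, that class indices serve as the dictionary keys; `softmax` and `probas_to_dict` are illustrative stand-ins for `torch.softmax` and the wrapper's conversion step, not functions from this module.

```python
import math
from typing import Dict, List


def softmax(logits: List[float]) -> List[float]:
    """Plain-Python stand-in for torch.softmax over one sample's logits."""
    # Subtract the largest logit for numerical stability; outputs sum to 1.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def probas_to_dict(probas: List[float]) -> Dict[int, float]:
    """Hypothetical conversion: map each class index to its probability."""
    return {label: p for label, p in enumerate(probas)}


probas = probas_to_dict(softmax([0.0, 0.0, math.log(2.0)]))
print(probas)  # probabilities close to 0.25, 0.25 and 0.5 for classes 0, 1, 2
```

Keying by class index keeps the output compatible with dict-based consumers that look up a probability per class label.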