utils.wrappers.torch

Classes

TorchSupervisedLearningWrapper(model, ...[, ...])

Basic wrapper for torch classification models.

TorchWrapper(link_function[, feature_names, ...])

Wrapper for torch link functions.

class utils.wrappers.torch.TorchSupervisedLearningWrapper(model, optimizer, loss_function, task, n_classes=1, class_labels=None)

Bases: Wrapper

Basic wrapper for torch classification models.

Warning: This wrapper provides only very basic functionality and is intended only for basic supervised learning tasks solved with torch.

This wrapper turns any prediction function output into an iterable (list or np.ndarray) output.

__call__(x_i)
learn_one(x, y)
predict_one(x_i)
predict_proba_one(x_i)
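
The class description states that prediction outputs are turned into an iterable. The following is a minimal, torch-free sketch of that idea under stated assumptions: `to_iterable` is a hypothetical helper written for illustration and is not part of this module, and the wrapper's real conversion logic may differ.

```python
def to_iterable(output):
    """Hypothetical helper: return a prediction output as a list."""
    # Sequences are copied into a list; anything else is treated as a scalar.
    if isinstance(output, (list, tuple)):
        return list(output)
    return [output]


scalar_prediction = to_iterable(0.7)         # a single regression value
proba_prediction = to_iterable((0.2, 0.8))   # class probabilities as a tuple
```

With this convention, downstream code can always iterate over the result, whether the model produced one value or several.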
class utils.wrappers.torch.TorchWrapper(link_function, feature_names=None, device='cpu')

Bases: Wrapper

Wrapper for torch link functions.

This wrapper turns any torch output tensor into a dict output and allows for dict inputs.

Examples

Basic usage:

>>> torch_module: torch.nn.Module = torch_module
>>> module_device = 'cpu'
>>> model_function = TorchWrapper(torch_module, device=module_device)

For classifiers returning class_labels:

>>> def link_function_class(x):
...     return torch.argmax(torch.softmax(torch_module(x), dim=-1), dim=-1)
>>> model_function = TorchWrapper(link_function_class)

For classifiers returning probas:

>>> def link_function_probas(x):
...     return torch.softmax(torch_module(x), dim=-1)
>>> model_function = TorchWrapper(link_function_probas)

If the dict inputs may arrive with their keys in different orders:

>>> feature_orderings: list = ['feature_1', 'feature_2', 'feature_3']
>>> model_function = TorchWrapper(link_function_probas, feature_names=feature_orderings)
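
Why a fixed feature order matters can be shown without torch. The sketch below is an assumption about the purpose of `feature_names`, not the wrapper's actual implementation: the hypothetical helper `dict_to_row` reads values in a fixed order, so two dicts with the same content but different key order produce the same input row.

```python
def dict_to_row(x: dict, feature_names: list) -> list:
    """Hypothetical helper: return the values of x in the order of feature_names."""
    return [x[name] for name in feature_names]


feature_orderings = ['feature_1', 'feature_2', 'feature_3']
a = {'feature_1': 1.0, 'feature_2': 0.0, 'feature_3': 2.0}
b = {'feature_3': 2.0, 'feature_1': 1.0, 'feature_2': 0.0}  # same content, different key order

row_a = dict_to_row(a, feature_orderings)
row_b = dict_to_row(b, feature_orderings)

# Without an explicit ordering, the row follows the dict's insertion order.
naive_b = list(b.values())
```

Here `row_a` and `row_b` are identical, while `naive_b` places the values in a different order and would feed the model a permuted input.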
__call__(x)

Runs the torch model with the given input dictionary.

Parameters:
  • x (Union[list[dict], dict]) – Input features in the form of a dict (1d-input) mapping from feature names to feature values or a list of such dicts.

Returns:

The model output as a dictionary following river conventions.

Return type:

(Union[list[dict], dict])

Examples

Basic usage:

>>> def link_function_probas(x):
...     return torch.softmax(torch_module(x), dim=-1)
>>> model_function: typing.Callable = TorchWrapper(link_function_probas)
>>> input_dict = {'feature_1': 1, 'feature_2': 0}
>>> model_function(input_dict)
{0: 0.45, 1: 0.05, 2: 0.5}
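
The dict-in, dict-out behaviour described above can be sketched in plain Python. Everything here is an assumption made for illustration: `call_with_dicts` is a hypothetical stand-in for the wrapper's `__call__`, and `softmax` is a toy link function applied directly to the feature values in place of a real torch module.

```python
import math
from typing import Callable, Union


def softmax(row: list) -> list:
    """Toy link function: numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def call_with_dicts(link_function: Callable, x: Union[list, dict]):
    """Hypothetical stand-in: accept a dict or a list of dicts, return the same shape."""
    if isinstance(x, dict):
        output = link_function(list(x.values()))
        # Key each output entry by its index, e.g. {0: p0, 1: p1}.
        return dict(enumerate(output))
    # A list of dicts is handled one element at a time.
    return [call_with_dicts(link_function, x_i) for x_i in x]


single = call_with_dicts(softmax, {'feature_1': 1, 'feature_2': 0})
batch = call_with_dicts(
    softmax,
    [{'feature_1': 1, 'feature_2': 0}, {'feature_1': 0, 'feature_2': 0}],
)
```

A single dict yields a single output dict keyed by output index, and a list of dicts yields a list of such dicts, mirroring the `Union[list[dict], dict]` input and return types.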