# Source code for storage.tree_storage

"""
This module contains the TreeStorage and the MeanVarRegressor leaf regressor.
"""
import typing
from typing import Dict, Any, Optional, List, Union, Tuple
import random

import numpy as np
from river import base
from river.tree import HoeffdingAdaptiveTreeClassifier, HoeffdingAdaptiveTreeRegressor, \
    HoeffdingTreeRegressor, HoeffdingTreeClassifier
from river.metrics import R2, Accuracy
from river.utils import Rolling

from ixai.utils.tracker.welford import WelfordTracker
from .geometric_reservoir_storage import GeometricReservoirStorage
from .base import BaseStorage


NODE_SEPERATOR: str = "|STOP|"


def get_all_tree_paths(node, walked_path: str = '', paths=None) -> List[str]:
    """Collects the path strings of all root-to-leaf paths below ``node``.

    Every branch on a path is encoded as ``"<node>|<repr_split>|<branch index>"`` and every leaf
    as ``"<node>"``; each encoded node is terminated by ``NODE_SEPERATOR``. This is the same
    encoding produced when walking a single data point through the tree, so the strings can be
    compared to identify leaves.

    Args:
        node: The node to start from (typically the root of the tree). Nodes without a
            ``children`` attribute are treated as leaves.
        walked_path (str): Encoded path from the root to ``node``. Defaults to ``''``.
        paths (list[str], optional): Accumulator the finished paths are appended to. A new list
            is created when ``None``.

    Returns:
        list[str]: ``paths`` with one entry per leaf reachable from ``node``.
    """
    if paths is None:
        paths = []
    try:
        children = node.children
    except AttributeError:
        # Leaves have no ``children``: the path ends at this node.
        paths.append(walked_path + str(node) + NODE_SEPERATOR)
        return paths
    # Only the ``children`` lookup is guarded above, so an AttributeError raised while encoding
    # a branch (e.g. a missing ``repr_split``) propagates instead of being misread as a leaf.
    for branch_no, child in enumerate(children):
        child_path = walked_path + "|".join((str(node), str(node.repr_split), str(branch_no)))
        child_path += NODE_SEPERATOR
        get_all_tree_paths(child, walked_path=child_path, paths=paths)
    return paths
def walk_through_tree(
        node: typing.Union["Branch", "Leaf"], x_i: dict, until_leaf: bool = True
) -> typing.Iterable[typing.Union["Branch", "Leaf"]]:
    """Traverses a decision tree given a data point, and a starting node.

    Every visited node is yielded exactly once, starting with ``node`` itself. Branch nodes are
    followed via ``node.next(x_i)``. If that raises ``KeyError`` (presumably because a split
    feature is missing from ``x_i``), the traversal either stops there (``until_leaf=False``) or
    continues into one child drawn at random, weighted by the children's ``total_weight``.

    Args:
        node: The node to start the traversal at (typically the root of the tree).
        x_i (dict): Data point as dict.
        until_leaf (bool): Flag whether to traverse the tree until a leaf node (``True``) or to
            stop at a node whose routing cannot be evaluated for ``x_i`` (``False``).

    Yields:
        The nodes on the walked path, in order.
    """
    yield node
    try:
        next_node = node.next(x_i)
    except AttributeError:
        # Nodes without ``next`` are leaves, which have already been yielded.
        return
    except KeyError:
        if not until_leaf:
            return
        children = node.children
        weights = [child.total_weight for child in children]
        # random.choices rejects weights that sum to zero; fall back to a uniform choice then.
        if sum(weights) <= 0:
            weights = None
        next_node = random.choices(children, weights=weights, k=1)[0]
    # The recursive call yields ``next_node`` first, so it is deliberately not yielded here too
    # (doing so would duplicate it in the path).
    yield from walk_through_tree(next_node, x_i, until_leaf)
class MeanVarRegressor(base.Regressor):
    """Leaf regressor that predicts by sampling from a normal distribution.

    Intended to be used as a leaf model in decision tree regressors. The running mean and
    standard deviation of the observed numerical labels are tracked with a ``WelfordTracker``,
    and every prediction is a fresh draw from a normal distribution with that mean and the
    absolute value of that standard deviation. Input features are ignored.
    """

    def __init__(self):
        self._stat_object = WelfordTracker()

    def predict_one(self, x=None) -> float:
        """Draws one prediction from the current label distribution.

        Args:
            x (Any): Input features (ignored).

        Returns:
            float: A single sample from ``N(mean, |std|)`` of the labels seen so far.
        """
        tracker = self._stat_object
        samples = np.random.normal(loc=tracker.mean, scale=abs(tracker.std), size=1)
        (drawn,) = samples
        return drawn

    def learn_one(self, x: Any, y: base.typing.RegTarget) -> "base.Regressor":
        """Adds one label to the running summary statistics.

        Args:
            x (Any): Input features (ignored).
            y (base.typing.RegTarget): Label; must be convertible to ``float``.

        Returns:
            base.Regressor: The regressor itself.
        """
        self._stat_object.update(value_i=float(y))
        return self
class TreeStorage(BaseStorage):
    """A Tree Storage that trains incremental decision trees for each feature.

    Each stored feature gets its own river Hoeffding Adaptive Tree, trained to predict that
    feature from all other features of the incoming data points. Categorical features use a
    classifier and numerical features use a regressor. For every leaf of every tree, a small
    data reservoir (``GeometricReservoirStorage``) holds data points routed to that leaf. The
    reservoirs are keyed by the encoded root-to-leaf path.

    Attributes:
        feature_names (list[str]): List of features stored.
        cat_feature_names (list[str]): List of categorical features stored.
        num_feature_names (list[str]): List of numerical features stored.
        performances (dict[Any, Union[R2, Accuracy]]): Dictionary of performance metrics per
            incremental decision tree for each feature stored.
        data_reservoirs (dict[str, dict]): Dictionary of data reservoirs for each feature and
            leaf nodes.
    """

    def __init__(
            self,
            cat_feature_names: list,
            num_feature_names: list,
            max_depth: int = 5,
            leaf_reservoir_length: int = 10,
            grace_period: int = 200,
            seed: Optional[int] = None
    ):
        """A Tree Storage that trains incremental decision trees for each feature.

        Args:
            cat_feature_names (list[str]): List of categorical features to be stored.
            num_feature_names (list[str]): List of numerical features to be stored.
            max_depth (int): Maximum tree depth for the incremental decision trees.
                Defaults to 5.
            leaf_reservoir_length (int): Size of the reservoir stored at each leaf node of
                each feature's incremental decision tree. Defaults to 10.
            grace_period (int): Grace period of the underlying river Hoeffding Adaptive Trees.
                Defaults to 200.
            seed (int, optional): Random seed of the underlying river Hoeffding Adaptive Trees.
                Defaults to None.
        """
        self.feature_names = cat_feature_names + num_feature_names
        self.cat_feature_names = cat_feature_names
        self.num_feature_names = num_feature_names
        self._leaf_reservoir_length = leaf_reservoir_length
        self._seen_samples = 0
        # One incremental tree per feature; each is trained to predict its feature from the rest.
        self._storage_x = {
            cat_feature: HoeffdingAdaptiveTreeClassifier(
                max_depth=max_depth, leaf_prediction='nba', binary_split=True,
                grace_period=grace_period, seed=seed)
            for cat_feature in self.cat_feature_names}
        self._storage_x.update({
            num_feature: HoeffdingAdaptiveTreeRegressor(
                max_depth=max_depth, leaf_prediction='adaptive', binary_split=True,
                grace_period=grace_period, seed=seed)
            for num_feature in self.num_feature_names})
        self.performances = {
            num_feature: Rolling(R2(), window_size=1000)
            for num_feature in self.num_feature_names}
        self.performances.update({
            cat_feature: Rolling(Accuracy(), window_size=1000)
            for cat_feature in self.cat_feature_names})
        self.data_reservoirs = {feature: {} for feature in self.feature_names}

    def update(self, x: Dict, y: Optional[Any] = None):
        """Given a data point, it updates the storage.

        For every stored feature present in ``x``:
        1. The feature's tree predicts the feature from the remaining features.
        2. The tree is trained on the data point.
        3. The data point is added to the reservoir of the leaf it is routed to.
        4. The rolling performance metric is updated with the prediction from step 1.

        Args:
            x (dict): Features of a single data point, mapping feature names to values.
            y: Target as float or integer (not used).

        Returns:
            None
        """
        for feature_name in x:
            # Dict lookup instead of a list scan; ``_storage_x`` holds exactly the stored features.
            if feature_name not in self._storage_x:
                continue
            feature_model = self._storage_x[feature_name]
            x_i = {**x}
            y_i = x_i.pop(feature_name)
            # Predict before learning (test-then-train) so the rolling metric measures the
            # performance on unseen data rather than the fit on the sample just learned.
            pred_i = feature_model.predict_one(x_i)
            feature_model.learn_one(x_i, y_i)
            self._update_data_reservoirs(feature_name, x_i, x)
            self.performances[feature_name].update(y_i, pred_i)
        self._seen_samples += 1

    @staticmethod
    def get_path_through_tree(node, x_i) -> str:
        """Given a data point and a starting node, traverses the decision tree.

        The path is encoded like ``get_all_tree_paths`` does it. Branches are encoded as
        ``"<node>|<repr_split>|<index of the child taken>"`` and leaves as ``"<node>"``. Each
        encoded node is followed by ``NODE_SEPERATOR``.

        Args:
            node: Root node of the model.
            x_i: Data point to traverse the tree with.

        Returns:
            str: The walked path through the decision tree.
        """
        visited = []
        for stop in walk_through_tree(node, x_i, until_leaf=True):
            # A root-to-leaf path never contains the same node twice; ignore a repeated yield
            # of the node just visited so the encoded path stays canonical.
            if visited and stop is visited[-1]:
                continue
            visited.append(stop)

        walked_path = ''
        for position, stop in enumerate(visited):
            walked_path += str(stop)
            if hasattr(stop, 'repr_split'):
                next_stop = visited[position + 1] if position + 1 < len(visited) else None
                branch_no = TreeStorage._branch_index(stop, next_stop, x_i)
                walked_path += "|" + str(stop.repr_split) + "|" + str(branch_no)
            walked_path += NODE_SEPERATOR
        return walked_path

    @staticmethod
    def _branch_index(branch, next_stop, x_i) -> int:
        """Returns the index of the child of ``branch`` that the walked path continued into."""
        # Using the child actually visited (instead of re-evaluating ``branch_no``) keeps this
        # working when the split feature is missing from ``x_i`` and a child was drawn at random.
        if next_stop is not None:
            for index, child in enumerate(branch.children):
                if child is next_stop:
                    return index
        return branch.branch_no(x_i)

    def _update_data_reservoirs(self, feature_name, x_i, x):
        """Update the data reservoir at a leaf node with a new sample.

        Args:
            feature_name (str): The feature for which to update the stored data reservoirs.
            x_i (dict): The data point (without ``feature_name``) to find the leaf node with.
            x (dict): The data point to be inserted in the leaf node's reservoir.
        """
        root_node = self._storage_x[feature_name]._root
        data_reservoir = self.data_reservoirs[feature_name]
        leaf_id = self.get_path_through_tree(root_node, x_i)
        if leaf_id not in data_reservoir:
            # An unseen path suggests the tree structure may have changed. Prune reservoirs of
            # paths that no longer exist *before* adding the new one, so the reservoir that is
            # about to receive ``x`` can never be removed by the pruning step.
            self._delete_outdated_reservoirs(feature_name, root_node)
            data_reservoir[leaf_id] = GeometricReservoirStorage(
                size=self._leaf_reservoir_length, store_targets=False, constant_probability=1.0)
        data_reservoir[leaf_id].update(x)

    def __call__(
            self, feature_name: Any
    ) -> Tuple[Union[HoeffdingTreeRegressor, HoeffdingTreeClassifier], str]:
        """Given a feature name, returns its incremental decision tree and its feature type.

        Args:
            feature_name (str): The feature name for which to return the decision tree.

        Returns:
            (model, str): Tuple of the feature's incremental decision tree and ``'cat'`` or
                ``'num'``, depending on whether the feature is stored as categorical or
                numerical.

        Raises:
            ValueError: If `feature_name` is not stored as a categorical feature nor a
                numerical feature.
        """
        if feature_name in self.cat_feature_names:
            return self._storage_x[feature_name], 'cat'
        elif feature_name in self.num_feature_names:
            return self._storage_x[feature_name], 'num'
        else:
            raise ValueError(f"The {feature_name} is not stored.")

    def _delete_outdated_reservoirs(
            self, feature_name: str, root_node: typing.Union["Branch", "Leaf"]
    ):
        """Deletes the outdated reservoirs that are no longer part of any path in the tree.

        Args:
            feature_name (str): The feature name for which outdated leafs might be deleted.
            root_node ("Branch", "Leaf"): The root node of the feature decision tree.
        """
        valid_paths = set(get_all_tree_paths(root_node))
        reservoirs = self.data_reservoirs[feature_name]
        # Iterate over a snapshot of the keys because entries are deleted during the loop.
        for reservoirs_label in list(reservoirs):
            if reservoirs_label not in valid_paths:
                del reservoirs[reservoirs_label]

    def __len__(self):
        """Returns the number of samples observed in the storage object."""
        return self._seen_samples