Source code for explainer.base

"""
This module gathers the base classes for incremental explanation methods.
"""
import copy
import math
import abc
from typing import Union, Sequence, Dict, List, Callable, Any, Optional

from river.metrics.base import Metric

from ixai.imputer import BaseImputer, MarginalImputer
from ixai.storage import GeometricReservoirStorage, UniformReservoirStorage
from ixai.storage.base import BaseStorage
from ixai.utils.tracker.base import Tracker
from ixai.utils.tracker import MultiValueTracker, WelfordTracker, ExponentialSmoothingTracker
from ixai.utils.validators.loss import validate_loss_function
from ixai.utils.validators.model import validate_model_function


class BaseIncrementalExplainer(metaclass=abc.ABCMeta):
    """Abstract root of all incremental explainer algorithms.

    Warning: This class should not be used directly.
    Use derived classes instead.

    Args:
        model_function (Callable): The model function to be explained. It is passed
            through ``validate_model_function`` before being stored.
        feature_names (list): List of feature names to be explained for the model.

    Attributes:
        feature_names (list[typing.Any]): List of feature names that are explained.
        number_of_features (int): Number of features that are explained.
        seen_samples (int): Number of instances observed.
    """

    @abc.abstractmethod
    def __init__(
            self,
            model_function: Callable[[Any], Any],
            feature_names: Sequence[Union[str, int, float]]
    ):
        # The model is validated first so that an invalid model is reported
        # before any other attribute is touched.
        validated_model = validate_model_function(model_function)
        self._model_function = validated_model

        self.feature_names = feature_names
        self.number_of_features: int = len(feature_names)

        # No observation has been processed yet.
        self.seen_samples: int = 0

    def __repr__(self):
        """Return a short summary with the feature count and the samples seen so far."""
        n_features = self.number_of_features
        n_samples = self.seen_samples
        return f"Explainer for {n_features} features after {n_samples} samples."
class BaseIncrementalFeatureImportance(BaseIncrementalExplainer):
    """Base class for incremental feature importance explainer algorithms.

    Warning: This class should not be used directly.
    Use derived classes instead.

    Args:
        model_function (Callable): The model function to be explained.
        loss_function (Union[Metric, Callable]): The loss function; it is passed through
            ``validate_loss_function`` before being stored.
        feature_names (Sequence): Names of the features to be explained.
        storage (BaseStorage, optional): Data storage. If `None`, a
            ``GeometricReservoirStorage`` (dynamic setting) or a ``UniformReservoirStorage``
            (static setting) with ``store_targets=False`` and ``size=100`` is created.
        imputer (BaseImputer, optional): Imputer. If `None`, a ``MarginalImputer`` in
            ``'joint'`` mode operating on the storage is created.
        dynamic_setting (bool): If `True`, all trackers use exponential smoothing; otherwise
            Welford trackers are used. Defaults to `False`.
        smoothing_alpha (float, optional): The smoothing parameter. Defaults to `0.001` when
            `None` is given. In the dynamic setting it must lie in the interval ]0,1].
    """

    @abc.abstractmethod
    def __init__(
            self,
            model_function: Callable[[Any], Any],
            loss_function: Union[Metric, Callable[[Any, Dict], float]],
            feature_names: Sequence[Union[str, int, float]],
            storage: Optional[BaseStorage] = None,
            imputer: Optional[BaseImputer] = None,
            dynamic_setting: bool = False,
            smoothing_alpha: Optional[float] = None
    ):
        super().__init__(model_function, feature_names)
        self._loss_function = validate_loss_function(loss_function)

        self._smoothing_alpha = 0.001 if smoothing_alpha is None else smoothing_alpha
        if dynamic_setting:
            # Validate the *resolved* value: the raw argument may be None, in which case
            # the comparison itself would raise a TypeError instead of using the default.
            assert 0. < self._smoothing_alpha <= 1., \
                f"The smoothing parameter needs to be in the range" \
                f" of ']0,1]' and not " \
                f"'{self._smoothing_alpha}'."
            base_tracker = ExponentialSmoothingTracker(alpha=self._smoothing_alpha)
        else:
            base_tracker = WelfordTracker()

        # Every tracker receives its own deep copy so that no state is shared between them.
        self._marginal_loss_tracker: Tracker = copy.deepcopy(base_tracker)
        self._model_loss_tracker: Tracker = copy.deepcopy(base_tracker)
        self._marginal_prediction_tracker: MultiValueTracker = MultiValueTracker(
            copy.deepcopy(base_tracker))
        self._importance_trackers: MultiValueTracker = MultiValueTracker(
            copy.deepcopy(base_tracker))
        self._variance_trackers: MultiValueTracker = MultiValueTracker(
            copy.deepcopy(base_tracker))

        self._storage: BaseStorage = storage
        if self._storage is None:
            if dynamic_setting:
                self._storage = GeometricReservoirStorage(store_targets=False, size=100)
            else:
                self._storage = UniformReservoirStorage(store_targets=False, size=100)

        self._imputer: BaseImputer = imputer
        if self._imputer is None:
            self._imputer = MarginalImputer(self._model_function, 'joint', self._storage)

    @abc.abstractmethod
    def explain_one(self, *args, **kwargs):
        """Explain a single observation. Must be implemented by derived classes."""
        raise NotImplementedError

    @property
    def importance_values(self):
        """Incremental Importance Values property."""
        return self._importance_trackers.get()

    @property
    def variances(self):
        """Incremental Variances values property."""
        return self._variance_trackers.get()

    def get_normalized_importance_values(self, mode: str = 'sum') -> dict:
        """Normalizes the importance scores.

        Args:
            mode (str): The normalization mode to be applied. Possible values are 'sum' and
                'delta'. Defaults to 'sum'.
                - sum: Normalizes the importance scores by division through the sum of
                    importance scores.
                - delta: Normalizes the importance scores by division through the difference
                    between the max of the importance scores and the min of the importance
                    scores.

        Returns:
            (dict): The normalized importance values.

        Raises:
            NotImplementedError: If `mode` is neither 'sum' nor 'delta'.
        """
        return self._normalize_importance_values(self.importance_values, mode=mode)

    def get_confidence_bound(self, delta: float):
        """Calculates Delta-Confidence Bounds.

        For every feature the bound is computed as
        ``(1 - alpha) ** seen_samples
        + (1 / sqrt(delta)) * sqrt(variance) * sqrt(alpha / (2 - alpha))``,
        where ``alpha`` is the smoothing parameter.

        Args:
            delta (float): The confidence parameter. Must be a value in the interval of ]0,1].

        Returns:
            (dict): The upper confidence bound around the point estimate of the importance
                values. This value needs to be added to the top and bottom of the point
                estimate.
        """
        assert 0 < delta <= 1., f"Delta must be float in the interval of ]0,1] and not {delta}."
        alpha = self._smoothing_alpha
        # All of the following terms are independent of the feature, so they are computed
        # once. The variances property is also read only once instead of once per feature.
        bias_term = (1 - alpha) ** self.seen_samples
        inverse_sqrt_delta = 1 / math.sqrt(delta)
        alpha_term = math.sqrt(alpha / (2 - alpha))
        variances = self.variances
        return {
            feature_name:
                bias_term
                + inverse_sqrt_delta * math.sqrt(variances[feature_name]) * alpha_term
            for feature_name in self.feature_names
        }

    @staticmethod
    def _normalize_importance_values(importance_values: dict, mode: str = 'sum') -> dict:
        """Divides every importance value by a mode-dependent normalization factor.

        Args:
            importance_values (dict): Mapping of feature to importance value.
            mode (str): Either 'sum' (divide by the sum of all values) or 'delta' (divide by
                the difference between the maximum and the minimum value).

        Returns:
            (dict): The normalized values. All values are `0.0` if the factor is zero, and
                the result is empty if `importance_values` is empty.

        Raises:
            NotImplementedError: If `mode` is neither 'sum' nor 'delta'.
        """
        if mode not in ('sum', 'delta'):
            raise NotImplementedError(f"The mode must be either 'sum', or 'delta' not '{mode}'.")
        if not importance_values:
            # max()/min() raise a ValueError on an empty sequence; nothing to normalize here.
            return {}
        importance_values_list = list(importance_values.values())
        if mode == 'delta':
            factor = max(importance_values_list) - min(importance_values_list)
        else:
            factor = sum(importance_values_list)
        try:
            return {feature: importance_value / factor
                    for feature, importance_value in importance_values.items()}
        except ZeroDivisionError:
            return {feature: 0.0 for feature in importance_values}

    def update_storage(self, x_i: dict, y_i: Optional[Any] = None):
        """Manually updates the data storage with the given observation.

        Args:
            x_i (dict): The input features of the current observation.
            y_i (Any, optional): Target label of the current observation. Defaults to `None`
        """
        self._storage.update(x=x_i, y=y_i)
def _get_mean_model_output(model_outputs: List[dict]) -> dict: """Calculates the mean values of a list of dict model outputs. Args: model_outputs (list[dict]): List of model outputs. Returns: (dict) The mean model output, where every label value is the average of all individual label values. """ all_labels = {label for model_output in model_outputs for label in model_output} mean_output = {label: sum([output.get(label, 0) for output in model_outputs]) / len(model_outputs) for label in all_labels} return mean_output