Source code for explainer.sage.incremental

"""
This module contains the incremental SAGE explainer.
"""
from typing import Callable, Any, Union, Dict, Sequence, Optional

import numpy as np
from river.metrics.base import Metric

from ixai.explainer.base import BaseIncrementalFeatureImportance, _get_mean_model_output
from ixai.imputer import BaseImputer
from ixai.storage.base import BaseStorage


class IncrementalSage(BaseIncrementalFeatureImportance):
    """Incremental SAGE Explainer

    Computes SAGE importance values incrementally by applying exponential smoothing. For each input
    instance tuple x_i, y_i one update of the explanation procedure is performed.

    Args:
        model_function (Callable): The Model function to be explained (e.g. ``model.predict_one``
            (river), ``model.predict_proba`` (sklearn)).
        loss_function (Union[Metric, Callable[[Any, Dict], float]]): The loss function for which
            the importance values are calculated. This can either be a callable function or a
            predefined ``river.metrics.base.Metric``.

            river.metrics.base.Metric: Any Metric implemented in river (e.g.
                ``river.metrics.CrossEntropy()`` for classification or ``river.metrics.MSE()`` for
                regression).

            callable function: The loss_function needs to follow the signature of
                loss_function(y_true, y_pred) and handle the output dimensions of the model
                function. Smaller values are interpreted as being better if not overridden with
                ``loss_bigger_is_better=True``. ``y_pred`` is passed as a dict.
        feature_names (Sequence[Union[str, int, float]]): List of feature names to be explained for
            the model. Names of mixed types (e.g. ``str`` and ``int``) are supported.
        smoothing_alpha (float, optional): The smoothing parameter for the exponential smoothing of
            the importance values. Should be in the interval ]0,1]. If ``None`` (the default), the
            value is chosen by ``BaseIncrementalFeatureImportance`` (documented as 0.001).
        storage (BaseStorage, optional): Optional incremental data storage Mechanism. If ``None``,
            the base class chooses the default, documented as ``GeometricReservoirStorage(size=100)``
            for dynamic modelling settings (``dynamic_setting=True``) and
            ``UniformReservoirStorage(size=100)`` in static modelling settings
            (``dynamic_setting=False``).
        imputer (BaseImputer, optional): Incremental imputing strategy to be used. If ``None``, the
            base class chooses the default, documented as
            ``MarginalImputer(sampling_strategy='joint')``.
        n_inner_samples (int): Number of model evaluations per feature and explanation step
            (observation). Defaults to 1.
        dynamic_setting (bool): Flag to indicate if the modelling setting is dynamic ``True``
            (changing model, and adaptive explanation) or a static modelling setting ``False``
            (all observations contribute equally to the final importance). Defaults to ``True``.
        loss_bigger_is_better (bool): Flag that indicates if a bigger loss value indicates a better
            fit (``True``) or a smaller one does (``False``). This is only used to represent the
            marginal- and model-loss more sensibly: if ``True``, ``marginal_loss`` and
            ``model_loss`` are negated so that smaller values always mean a better fit (and a
            positive ``explained_loss`` always means the model beats the marginal prediction).
            Defaults to ``False``.

    Attributes:
        marginal_prediction (dict): The current marginal prediction of the model_function
            (smoothed over time).
        n_inner_samples (int): Number of model evaluations per feature and explanation step
            (observation).
    """

    def __init__(
            self,
            model_function: Callable,
            loss_function: Union[Metric, Callable[[Any, Dict], float]],
            feature_names: Sequence[Union[str, int, float]],
            *,
            smoothing_alpha: Optional[float] = None,
            storage: Optional[BaseStorage] = None,
            imputer: Optional[BaseImputer] = None,
            n_inner_samples: int = 1,
            dynamic_setting: bool = True,
            loss_bigger_is_better: bool = False
    ):
        super(IncrementalSage, self).__init__(
            model_function=model_function,
            loss_function=loss_function,
            feature_names=feature_names,
            dynamic_setting=dynamic_setting,
            smoothing_alpha=smoothing_alpha,
            storage=storage,
            imputer=imputer
        )
        # Sign applied to the reported losses. A multiplicative -1 turns a "bigger is better"
        # metric into "smaller is better"; an additive offset (the previous behaviour) only
        # shifted both losses and cancelled out in ``explained_loss``.
        self._loss_direction = -1. if loss_bigger_is_better else 1.
        self.n_inner_samples = n_inner_samples
        self.marginal_prediction: dict = {}

    @property
    def marginal_loss(self):
        """Marginal loss (loss of the model without any features, default prediction loss)
        property, which is smoothed over time. Negated if ``loss_bigger_is_better=True``."""
        return self._marginal_loss_tracker.get() * self._loss_direction

    @property
    def model_loss(self):
        """Model loss (loss of model with features) property, which is smoothed over time.
        Negated if ``loss_bigger_is_better=True``."""
        return self._model_loss_tracker.get() * self._loss_direction

    @property
    def explained_loss(self):
        """Explained loss (difference between the current marginal and model loss) property."""
        return self.marginal_loss - self.model_loss

    def explain_one(
            self,
            x_i: dict,
            y_i: Any,
            n_inner_samples: Optional[int] = None,
            update_storage: bool = True
    ) -> dict:
        """Explain one observation (x_i, y_i).

        On the very first observation no explanation step is performed; the observation is only
        counted and (optionally) added to the storage.

        Args:
            x_i (dict): The input features of the current observation as a dict of feature names
                to feature values.
            y_i (Any): Target label of the current observation.
            n_inner_samples (int, optional): Number of model evaluations per feature for the
                current explanation step (observation). If ``None`` (the default),
                ``self.n_inner_samples`` is used.
            update_storage (bool): Flag if the underlying incremental data storage mechanism is to
                be updated with the new observation (``True``) or not (``False``). Defaults to
                ``True``.

        Returns:
            (dict): The current SAGE feature importance scores.
        """
        if self.seen_samples >= 1:
            if n_inner_samples is None:
                n_inner_samples = self.n_inner_samples

            # Permute positions instead of the names themselves: np.random.permutation on a list
            # of mixed str/int names coerces every name to a numpy string, which then no longer
            # matches the entries of ``features_not_in_s`` (KeyError on ``remove``). Indexing back
            # into the original list keeps the names' native Python types.
            feature_names = list(self.feature_names)
            permutation_chain = [
                feature_names[position]
                for position in np.random.permutation(len(feature_names))
            ]

            # Loss of the full model on this observation.
            y_i_pred = self._model_function(x_i)
            model_loss = self._loss_function(y_i, y_i_pred)
            self._model_loss_tracker.update(model_loss)

            # Loss of the smoothed marginal prediction; it is the starting loss of the chain,
            # i.e. the loss when no feature has been added yet.
            self._marginal_prediction_tracker.update(y_i_pred)
            self.marginal_prediction = self._marginal_prediction_tracker.get_normalized()
            sample_loss = self._loss_function(y_i, self.marginal_prediction)
            self._marginal_loss_tracker.update(sample_loss)

            # Features not yet added along the permutation chain; these are handed to the imputer
            # as ``feature_subset`` (presumably the features whose values get replaced).
            features_not_in_s = set(feature_names)
            marginal_contributions = {}
            for feature in permutation_chain:
                features_not_in_s.remove(feature)
                predictions = self._imputer.impute(
                    feature_subset=features_not_in_s,
                    x_i=x_i,
                    n_samples=n_inner_samples
                )
                y = _get_mean_model_output(predictions)
                feature_loss = self._loss_function(y_i, y)
                # Loss reduction obtained by adding ``feature`` to the coalition.
                marginal_contributions[feature] = sample_loss - feature_loss
                sample_loss = feature_loss

            self._importance_trackers.update(marginal_contributions)
            variances = {
                feature: (marginal_contributions[feature] - self.importance_values[feature]) ** 2
                for feature in feature_names
            }
            self._variance_trackers.update(variances)

        self.seen_samples += 1
        if update_storage:
            self._storage.update(x_i, y_i)
        return self.importance_values