Source code for explainer.sage.batch

"""
This module contains the Batch SAGE explainer.
"""

import random
from typing import Callable, Any, Sequence, Union, Dict, Optional, List

import numpy as np
from river.metrics.base import Metric
from tqdm import tqdm

from ixai.explainer.base import _get_mean_model_output
from ixai.imputer import BaseImputer, MarginalImputer
from ixai.storage import BatchStorage
from ixai.storage.base import BaseStorage
from ixai.utils.validators import validate_model_function, validate_loss_function


[docs]class BatchSage: """Batch SAGE Explainer Computes SAGE importance values according to its original definition in https://arxiv.org/abs/2004.00668. A Storage is updated with all observations from a stream and an explanation is created with access to all of these observations at once. This can be computationally challenging for large amounts of observations. Args: model_function (Callable[[Any], Any]): The Model function to be explained (e.g. model.predict_one (river), model.predict_proba (sklearn)). loss_function (Union[Metric, Callable[[Any, Dict], float]]): The loss function for which the importance values are calculated. This can either be a callable function or a predefined river.metric.base.Metric.<br> - river.metric.base.Metric: Any Metric implemented in river (e.g. river.metrics.CrossEntropy() for classification or river.metrics.MSE() for regression).<br> - callable function: The loss_function needs to follow the signature of loss_function(y_true, y_pred) and handle the output dimensions of the model function. Smaller values are interpreted as being better if not overriden with `loss_bigger_is_better=True`. `y_pred` is passed as a dict. feature_names (Sequence[Union[str, int, float]]): List of feature names to be explained for the model. storage (BaseStorage, optional): Optional incremental data storage Mechanism. Defaults to `GeometricReservoirStorage(size=100)` for dynamic modelling settings (`dynamic_setting=True`) and `UniformReservoirStorage(size=100)` in static modelling settings (`dynamic_setting=False`). imputer (BaseImputer, optional): Incremental imputing strategy to be used. Defaults to `MarginalImputer(sampling_strategy='joint')`. n_inner_samples (int): Number of model evaluation per feature and explanation step (observation). Defaults to 1. Attributes: feature_names (Sequence[Union[str, int, float]]): The feature names of the dataset. n_inner_samples (int): Number of model evaluation per feature and explanation step (observation). """ def __init__( self, model_function: Callable[[Any], Any], feature_names: Sequence[Union[str, int, float]], loss_function: Union[Metric, Callable[[Any, Dict], float]], n_inner_samples: int = 1, storage: Optional[BaseStorage] = None, imputer: Optional[BaseImputer] = None, ): self.feature_names = feature_names self.n_inner_samples = n_inner_samples self._model_function = validate_model_function(model_function) self._loss_function = validate_loss_function(loss_function) self._storage: BaseStorage = storage if self._storage is None: self._storage = BatchStorage(store_targets=True) self._imputer: BaseImputer = imputer if self._imputer is None: self._imputer = MarginalImputer( sampling_strategy='joint', model_function=self._model_function, storage_object=self._storage) self.importance_values = {feature_name: 0. for feature_name in self.feature_names}
[docs] def update_storage(self, x_i: dict, y_i: Any): """Updates the data storage with one observation (x_i, y_i). Args: x_i (dict): The input features of the current observation. y_i (Any): Target label of the current observation. """ self._storage.update(x=x_i, y=y_i)
[docs] def explain_one( self, x_i, y_i, n_inner_samples: Optional[int] = None, original_sage: bool = False, verbose: bool = True ) -> dict: """Explain one observation (x_i, y_i) with all data stored. Args: x_i (dict): The input features of the current observation as a dict of feature names to feature values. y_i (Any): Target label of the current observation. n_inner_samples (int, optional): Number of model evaluation per feature for the current explanation step (observation). Defaults to `None`. original_sage (bool): Flag indicating if the original definition of SAGE is used (`True`) or not (`False`). Defaults to `False`. Returns: (dict): The current SAGE feature importance scores. """ self._storage.update(x_i, y_i) x_data, y_data = self._storage.get_data() if original_sage: self.explain_many_original( x_data=x_data, y_data=y_data, n_inner_samples=n_inner_samples, verbose=verbose) else: self.explain_many( x_data=x_data, y_data=y_data, n_inner_samples=n_inner_samples, verbose=verbose) return self.importance_values
[docs] def explain_many( self, x_data: List[dict], y_data: List[Any], n_inner_samples: Optional[int] = None, verbose: bool = True ) -> dict: """Explain one observation (x_i, y_i) with all data stored. Args: x_data (List[dict]): A list of input data to be explained, as dicts mapping from feature names to feature values. y_data (List[Any]): Target label of the current observation. n_inner_samples (int, optional): Number of model evaluation per feature for the current explanation step (observation). Defaults to `None`. verbose (bool): Flag indicating if the explanation should print to console (`True`) or not (`False`). Returns: (dict): The current SAGE feature importance scores. """ if n_inner_samples is None: n_inner_samples = self.n_inner_samples sage_values = {feature: 0. for feature in self.feature_names} n_data = len(x_data) all_predictions = self._model_function(x_data) marginal_prediction = _get_mean_model_output(all_predictions) for n, (x_i, y_i) in tqdm(enumerate(zip(x_data, y_data), start=1), total=n_data, disable=not verbose): permutation_chain = np.random.permutation(self.feature_names) loss_previous = self._loss_function(y_true=y_i, y_prediction=marginal_prediction) features_not_in_s = set(self.feature_names) for feature in permutation_chain: features_not_in_s.remove(feature) predictions = self._imputer.impute( feature_subset=features_not_in_s, x_i=x_i, n_samples=n_inner_samples ) y = _get_mean_model_output(predictions) feature_loss = self._loss_function(y_true=y_i, y_prediction=y) marginal_contribution = loss_previous - feature_loss sage_values[feature] += marginal_contribution loss_previous = feature_loss n_data = n self.importance_values = {feature: sage_value / n_data for feature, sage_value in sage_values.items()} return self.importance_values
[docs] def explain_many_original( self, x_data: List[dict], y_data: List[Any], n_inner_samples: Optional[int] = None, verbose: bool = True ) -> dict: """Explain one observation (x_i, y_i) with all data stored according to the original definition in https://arxiv.org/abs/2004.00668. Args: x_data (List[dict]): A list of input data to be explained, as dicts mapping from feature names to feature values. y_data (List[Any]): Target label of the current observation. n_inner_samples (int, optional): Number of model evaluation per feature for the current explanation step (observation). Defaults to `None`. verbose (bool): Flag indicating if the explanation should print to console (`True`) or not (`False`). Returns: (dict): The current SAGE feature importance scores. """ if n_inner_samples is None: n_inner_samples = self.n_inner_samples sage_values = {feature: 0. for feature in self.feature_names} n_data = len(x_data) all_predictions = self._model_function(x_data) marginal_prediction = _get_mean_model_output(all_predictions) for n, (x_i, y_i) in tqdm(enumerate(zip(x_data, y_data), start=1), total=n_data, disable=not verbose): permutation_chain = np.random.permutation(self.feature_names) x_s = {} loss_previous = self._loss_function(y_true=y_i, y_prediction=marginal_prediction) for feature in permutation_chain: x_s[feature] = x_i[feature] predictions = [] for _ in range(1, n_inner_samples + 1): x_marginal = x_data[random.randint(0, n_data - 1)] x_marginal = {**x_marginal, **x_s} predictions.append(self._model_function(x_marginal)) y = _get_mean_model_output(predictions) feature_loss = self._loss_function(y_true=y_i, y_prediction=y) marginal_contribution = loss_previous - feature_loss sage_values[feature] += marginal_contribution loss_previous = feature_loss n_data = n self.importance_values = {feature: sage_value / n_data for feature, sage_value in sage_values.items()} return self.importance_values