diff --git a/analysis/shap_model.py b/analysis/shap_model.py index 3ced9109c65c7a9a503f8f22daa8070763edf24c..0f6976d972a72b1e93ccbae377cb383f83a26200 100644 --- a/analysis/shap_model.py +++ b/analysis/shap_model.py @@ -1,6 +1,7 @@ -import shap import matplotlib.pyplot as plt +import lib.shap as shap + def shap_calculate(model, x, feature_names): explainer = shap.Explainer(model.predict, x) diff --git a/diagram/__init__.py b/diagram/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/shap/__init__.py b/lib/shap/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..622f0eb97d264a6c8f2a634f9ddd2347f36b70f7 --- /dev/null +++ b/lib/shap/__init__.py @@ -0,0 +1,144 @@ +from ._explanation import Cohorts, Explanation + +# explainers +from .explainers import other +from .explainers._additive import AdditiveExplainer +from .explainers._deep import DeepExplainer +from .explainers._exact import ExactExplainer +from .explainers._explainer import Explainer +from .explainers._gpu_tree import GPUTreeExplainer +from .explainers._gradient import GradientExplainer +from .explainers._kernel import KernelExplainer +from .explainers._linear import LinearExplainer +from .explainers._partition import PartitionExplainer +from .explainers._permutation import PermutationExplainer +from .explainers._sampling import SamplingExplainer +from .explainers._tree import TreeExplainer + +try: + # Version from setuptools-scm + from ._version import version as __version__ +except ImportError: + # Expected when running locally without build + __version__ = "0.0.0-not-built" + +_no_matplotlib_warning = "matplotlib is not installed so plotting is not available! Run `pip install matplotlib` " \ + "to fix this." + + +# plotting (only loaded if matplotlib is present) +def unsupported(*args, **kwargs): + raise ImportError(_no_matplotlib_warning) + + +class UnsupportedModule: + def __getattribute__(self, item): + raise ImportError(_no_matplotlib_warning) + + +try: + import matplotlib # noqa: F401 + have_matplotlib = True +except ImportError: + have_matplotlib = False +if have_matplotlib: + from . 
import plots + from .plots._bar import bar_legacy as bar_plot + from .plots._beeswarm import summary_legacy as summary_plot + from .plots._decision import decision as decision_plot + from .plots._decision import multioutput_decision as multioutput_decision_plot + from .plots._embedding import embedding as embedding_plot + from .plots._force import force as force_plot + from .plots._force import getjs, initjs, save_html + from .plots._group_difference import group_difference as group_difference_plot + from .plots._heatmap import heatmap as heatmap_plot + from .plots._image import image as image_plot + from .plots._monitoring import monitoring as monitoring_plot + from .plots._partial_dependence import partial_dependence as partial_dependence_plot + from .plots._scatter import dependence_legacy as dependence_plot + from .plots._text import text as text_plot + from .plots._violin import violin as violin_plot + from .plots._waterfall import waterfall as waterfall_plot +else: + bar_plot = unsupported + summary_plot = unsupported + decision_plot = unsupported + multioutput_decision_plot = unsupported + embedding_plot = unsupported + force_plot = unsupported + getjs = unsupported + initjs = unsupported + save_html = unsupported + group_difference_plot = unsupported + heatmap_plot = unsupported + image_plot = unsupported + monitoring_plot = unsupported + partial_dependence_plot = unsupported + dependence_plot = unsupported + text_plot = unsupported + violin_plot = unsupported + waterfall_plot = unsupported + # If matplotlib is available, then the plots submodule will be directly available. + # If not, we need to define something that will issue a meaningful warning message + # (rather than ModuleNotFound). + plots = UnsupportedModule() + + +# other stuff :) +from . import datasets, links, utils # noqa: E402 +from .actions._optimizer import ActionOptimizer # noqa: E402 +from .utils import approximate_interactions, sample # noqa: E402 + +#from . import benchmark +from .utils._legacy import kmeans # noqa: E402 + +# Use __all__ to let type checkers know what is part of the public API. 
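+# Illustrative usage of this vendored package (assumes the project root is on
+# sys.path, as in analysis/shap_model.py; `model` and `x` are placeholders, and
+# summary_plot requires matplotlib):
+#
+#     import lib.shap as shap
+#     explainer = shap.Explainer(model.predict, x)
+#     shap_values = explainer(x)
+#     shap.summary_plot(shap_values, x)
+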
+__all__ = [ + "Cohorts", + "Explanation", + + # Explainers + "other", + "AdditiveExplainer", + "DeepExplainer", + "ExactExplainer", + "Explainer", + "GPUTreeExplainer", + "GradientExplainer", + "KernelExplainer", + "LinearExplainer", + "PartitionExplainer", + "PermutationExplainer", + "SamplingExplainer", + "TreeExplainer", + + # Plots + "plots", + "bar_plot", + "summary_plot", + "decision_plot", + "multioutput_decision_plot", + "embedding_plot", + "force_plot", + "getjs", + "initjs", + "save_html", + "group_difference_plot", + "heatmap_plot", + "image_plot", + "monitoring_plot", + "partial_dependence_plot", + "dependence_plot", + "text_plot", + "violin_plot", + "waterfall_plot", + + # Other stuff + "datasets", + "links", + "utils", + "ActionOptimizer", + "approximate_interactions", + "sample", + "kmeans", +] diff --git a/lib/shap/_cext.cp310-win_amd64.pyd b/lib/shap/_cext.cp310-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..5c91703fe2cfb8f7f76402f0f576a9bac07b7e09 Binary files /dev/null and b/lib/shap/_cext.cp310-win_amd64.pyd differ diff --git a/lib/shap/_explanation.py b/lib/shap/_explanation.py new file mode 100644 index 0000000000000000000000000000000000000000..0a67fe4c15acfa3e700560823c435bb5de2563fb --- /dev/null +++ b/lib/shap/_explanation.py @@ -0,0 +1,901 @@ + +import copy +import operator + +import numpy as np +import pandas as pd +import scipy.cluster +import scipy.sparse +import scipy.spatial +import sklearn +from slicer import Alias, Obj, Slicer + +from .utils._exceptions import DimensionError +from .utils._general import OpChain + +op_chain_root = OpChain("shap.Explanation") +class MetaExplanation(type): + """ This metaclass exposes the Explanation object's methods for creating template op chains. + """ + + def __getitem__(cls, item): + return op_chain_root.__getitem__(item) + + @property + def abs(cls): + """ Element-wise absolute value op. + """ + return op_chain_root.abs + + @property + def identity(cls): + """ A no-op. + """ + return op_chain_root.identity + + @property + def argsort(cls): + """ Numpy style argsort. + """ + return op_chain_root.argsort + + @property + def sum(cls): + """ Numpy style sum. + """ + return op_chain_root.sum + + @property + def max(cls): + """ Numpy style max. + """ + return op_chain_root.max + + @property + def min(cls): + """ Numpy style min. + """ + return op_chain_root.min + + @property + def mean(cls): + """ Numpy style mean. + """ + return op_chain_root.mean + + @property + def sample(cls): + """ Numpy style sample. + """ + return op_chain_root.sample + + @property + def hclust(cls): + """ Hierarchical clustering op. + """ + return op_chain_root.hclust + + +class Explanation(metaclass=MetaExplanation): + """ A sliceable set of parallel arrays representing a SHAP explanation. + """ + def __init__( + self, + values, + base_values=None, + data=None, + display_data=None, + instance_names=None, + feature_names=None, + output_names=None, + output_indexes=None, + lower_bounds=None, + upper_bounds=None, + error_std=None, + main_effects=None, + hierarchical_values=None, + clustering=None, + compute_time=None + ): + self.op_history = [] + + self.compute_time = compute_time + + # cloning. 
TODOsomeday: better cloning :) + if issubclass(type(values), Explanation): + e = values + values = e.values + base_values = e.base_values + data = e.data + + self.output_dims = compute_output_dims(values, base_values, data, output_names) + values_shape = _compute_shape(values) + + if output_names is None and len(self.output_dims) == 1: + output_names = [f"Output {i}" for i in range(values_shape[self.output_dims[0]])] + + if len(_compute_shape(feature_names)) == 1: # TODO: should always be an alias once slicer supports per-row aliases + if len(values_shape) >= 2 and len(feature_names) == values_shape[1]: + feature_names = Alias(list(feature_names), 1) + elif len(values_shape) >= 1 and len(feature_names) == values_shape[0]: + feature_names = Alias(list(feature_names), 0) + + if len(_compute_shape(output_names)) == 1: # TODO: should always be an alias once slicer supports per-row aliases + output_names = Alias(list(output_names), self.output_dims[0]) + # if len(values_shape) >= 1 and len(output_names) == values_shape[0]: + # output_names = Alias(list(output_names), 0) + # elif len(values_shape) >= 2 and len(output_names) == values_shape[1]: + # output_names = Alias(list(output_names), 1) + + if output_names is not None and not isinstance(output_names, Alias): + output_names_order = len(_compute_shape(output_names)) + if output_names_order == 0: + pass + elif output_names_order == 1: + output_names = Obj(output_names, self.output_dims) + elif output_names_order == 2: + output_names = Obj(output_names, [0] + list(self.output_dims)) + else: + raise ValueError("shap.Explanation does not yet support output_names of order greater than 3!") + + if not hasattr(base_values, "__len__") or len(base_values) == 0: + pass + elif len(_compute_shape(base_values)) == len(self.output_dims): + base_values = Obj(base_values, list(self.output_dims)) + else: + base_values = Obj(base_values, [0] + list(self.output_dims)) + + self._s = Slicer( + values=values, + base_values=base_values, + data=list_wrap(data), + display_data=list_wrap(display_data), + instance_names=None if instance_names is None else Alias(instance_names, 0), + feature_names=feature_names, + output_names=output_names, + output_indexes=None if output_indexes is None else (self.output_dims, output_indexes), + lower_bounds=list_wrap(lower_bounds), + upper_bounds=list_wrap(upper_bounds), + error_std=list_wrap(error_std), + main_effects=list_wrap(main_effects), + hierarchical_values=list_wrap(hierarchical_values), + clustering=None if clustering is None else Obj(clustering, [0]) + ) + + @property + def shape(self): + """ Compute the shape over potentially complex data nesting. + """ + return _compute_shape(self._s.values) + + @property + def values(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.values + @values.setter + def values(self, new_values): + self._s.values = new_values + + @property + def base_values(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.base_values + @base_values.setter + def base_values(self, new_base_values): + self._s.base_values = new_base_values + + @property + def data(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.data + @data.setter + def data(self, new_data): + self._s.data = new_data + + @property + def display_data(self): + """ Pass-through from the underlying slicer object. 
+ """ + return self._s.display_data + @display_data.setter + def display_data(self, new_display_data): + if issubclass(type(new_display_data), pd.DataFrame): + new_display_data = new_display_data.values + self._s.display_data = new_display_data + + @property + def instance_names(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.instance_names + + @property + def output_names(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.output_names + @output_names.setter + def output_names(self, new_output_names): + self._s.output_names = new_output_names + + @property + def output_indexes(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.output_indexes + + @property + def feature_names(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.feature_names + @feature_names.setter + def feature_names(self, new_feature_names): + self._s.feature_names = new_feature_names + + @property + def lower_bounds(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.lower_bounds + + @property + def upper_bounds(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.upper_bounds + + @property + def error_std(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.error_std + + @property + def main_effects(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.main_effects + @main_effects.setter + def main_effects(self, new_main_effects): + self._s.main_effects = new_main_effects + + @property + def hierarchical_values(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.hierarchical_values + @hierarchical_values.setter + def hierarchical_values(self, new_hierarchical_values): + self._s.hierarchical_values = new_hierarchical_values + + @property + def clustering(self): + """ Pass-through from the underlying slicer object. + """ + return self._s.clustering + @clustering.setter + def clustering(self, new_clustering): + self._s.clustering = new_clustering + + def cohorts(self, cohorts): + """ Split this explanation into several cohorts. + + Parameters + ---------- + cohorts : int or array + If this is an integer then we auto build that many cohorts using a decision tree. If this is + an array then we treat that as an array of cohort names/ids for each instance. + """ + + if isinstance(cohorts, int): + return _auto_cohorts(self, max_cohorts=cohorts) + if isinstance(cohorts, (list, tuple, np.ndarray)): + cohorts = np.array(cohorts) + return Cohorts(**{name: self[cohorts == name] for name in np.unique(cohorts)}) + raise TypeError("The given set of cohort indicators is not recognized! Please give an array or int.") + + def __repr__(self): + """ Display some basic printable info, but not everything. + """ + out = ".values =\n"+self.values.__repr__() + if self.base_values is not None: + out += "\n\n.base_values =\n"+self.base_values.__repr__() + if self.data is not None: + out += "\n\n.data =\n"+self.data.__repr__() + return out + + def __getitem__(self, item): + """ This adds support for OpChain indexing. 
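+
+        For example, ``exp[:, "age"]`` selects a single named feature column
+        ("age" is an illustrative feature name, not one defined in this module).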
+ """ + new_self = None + if not isinstance(item, tuple): + item = (item,) + + # convert any OpChains or magic strings + pos = -1 + for t in item: + pos += 1 + + # skip over Ellipsis + if t is Ellipsis: + pos += len(self.shape) - len(item) + continue + + orig_t = t + if issubclass(type(t), OpChain): + t = t.apply(self) + if issubclass(type(t), (np.int64, np.int32)): # because slicer does not like numpy indexes + t = int(t) + elif issubclass(type(t), np.ndarray): + t = [int(v) for v in t] # slicer wants lists not numpy arrays for indexing + elif issubclass(type(t), Explanation): + t = t.values + elif isinstance(t, str): + + # work around for 2D output_names since they are not yet slicer supported + output_names_dims = [] + if "output_names" in self._s._objects: + output_names_dims = self._s._objects["output_names"].dim + elif "output_names" in self._s._aliases: + output_names_dims = self._s._aliases["output_names"].dim + if pos != 0 and pos in output_names_dims: + if len(output_names_dims) == 1: + t = np.argwhere(np.array(self.output_names) == t)[0][0] + elif len(output_names_dims) == 2: + new_values = [] + new_base_values = [] + new_data = [] + new_self = copy.deepcopy(self) + for i, v in enumerate(self.values): + for j, s in enumerate(self.output_names[i]): + if s == t: + new_values.append(np.array(v[:,j])) + new_data.append(np.array(self.data[i])) + new_base_values.append(self.base_values[i][j]) + + new_self = Explanation( + np.array(new_values), + np.array(new_base_values), + np.array(new_data), + self.display_data, + self.instance_names, + np.array(new_data), + t, # output_names + self.output_indexes, + self.lower_bounds, + self.upper_bounds, + self.error_std, + self.main_effects, + self.hierarchical_values, + self.clustering + ) + new_self.op_history = copy.copy(self.op_history) + # new_self = copy.deepcopy(self) + # new_self.values = np.array(new_values) + # new_self.base_values = np.array(new_base_values) + # new_self.data = np.array(new_data) + # new_self.output_names = t + # new_self.feature_names = np.array(new_data) + # new_self.clustering = None + + # work around for 2D feature_names since they are not yet slicer supported + feature_names_dims = [] + if "feature_names" in self._s._objects: + feature_names_dims = self._s._objects["feature_names"].dim + if pos != 0 and pos in feature_names_dims and len(feature_names_dims) == 2: + new_values = [] + new_data = [] + for i, val_i in enumerate(self.values): + for s,v,d in zip(self.feature_names[i], val_i, self.data[i]): + if s == t: + new_values.append(v) + new_data.append(d) + new_self = copy.deepcopy(self) + new_self.values = new_values + new_self.data = new_data + new_self.feature_names = t + new_self.clustering = None + # return new_self + + if issubclass(type(t), (np.int8, np.int16, np.int32, np.int64)): + t = int(t) + + if t is not orig_t: + tmp = list(item) + tmp[pos] = t + item = tuple(tmp) + + # call slicer for the real work + item = tuple(v for v in item) # SML I cut out: `if not isinstance(v, str)` + if len(item) == 0: + return new_self + if new_self is None: + new_self = copy.copy(self) + new_self._s = new_self._s.__getitem__(item) + new_self.op_history.append({ + "name": "__getitem__", + "args": (item,), + "prev_shape": self.shape + }) + + return new_self + + def __len__(self): + return self.shape[0] + + def __copy__(self): + new_exp = Explanation( + self.values, + self.base_values, + self.data, + self.display_data, + self.instance_names, + self.feature_names, + self.output_names, + self.output_indexes, + 
self.lower_bounds, + self.upper_bounds, + self.error_std, + self.main_effects, + self.hierarchical_values, + self.clustering + ) + new_exp.op_history = copy.copy(self.op_history) + return new_exp + + def _apply_binary_operator(self, other, binary_op, op_name): + new_exp = self.__copy__() + new_exp.op_history = copy.copy(self.op_history) + new_exp.op_history.append({ + "name": op_name, + "args": (other,), + "prev_shape": self.shape + }) + if isinstance(other, Explanation): + new_exp.values = binary_op(new_exp.values, other.values) + if new_exp.data is not None: + new_exp.data = binary_op(new_exp.data, other.data) + if new_exp.base_values is not None: + new_exp.base_values = binary_op(new_exp.base_values, other.base_values) + else: + new_exp.values = binary_op(new_exp.values, other) + if new_exp.data is not None: + new_exp.data = binary_op(new_exp.data, other) + if new_exp.base_values is not None: + new_exp.base_values = binary_op(new_exp.base_values, other) + return new_exp + + def __add__(self, other): + return self._apply_binary_operator(other, operator.add, "__add__") + + def __radd__(self, other): + return self._apply_binary_operator(other, operator.add, "__add__") + + def __sub__(self, other): + return self._apply_binary_operator(other, operator.sub, "__sub__") + + def __rsub__(self, other): + return self._apply_binary_operator(other, operator.sub, "__sub__") + + def __mul__(self, other): + return self._apply_binary_operator(other, operator.mul, "__mul__") + + def __rmul__(self, other): + return self._apply_binary_operator(other, operator.mul, "__mul__") + + def __truediv__(self, other): + return self._apply_binary_operator(other, operator.truediv, "__truediv__") + + # @property + # def abs(self): + # """ Element-size absolute value operator. + # """ + # new_self = copy.copy(self) + # new_self.values = np.abs(new_self.values) + # new_self.op_history.append({ + # "name": "abs", + # "prev_shape": self.shape + # }) + # return new_self + + def _numpy_func(self, fname, **kwargs): + """ Apply a numpy-style function to this Explanation. 
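+
+        For example, ``Explanation.mean(0)`` delegates here as
+        ``self._numpy_func("mean", axis=0)``, collapsing the instance dimension.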
+        """
+        new_self = copy.copy(self)
+        axis = kwargs.get("axis", None)
+
+        # collapse the slicer to right shape
+        if axis == 0:
+            new_self = new_self[0]
+        elif axis == 1:
+            new_self = new_self[1]
+        elif axis == 2:
+            new_self = new_self[2]
+        if axis in [0,1,2]:
+            new_self.op_history = new_self.op_history[:-1] # pop off the slicing operation we just used
+
+        if self.feature_names is not None and not is_1d(self.feature_names) and axis == 0:
+            new_values = self._flatten_feature_names()
+            new_self.feature_names = np.array(list(new_values.keys()))
+            new_self.values = np.array([getattr(np, fname)(v,0) for v in new_values.values()])
+            new_self.clustering = None
+        else:
+            new_self.values = getattr(np, fname)(np.array(self.values), **kwargs)
+            if new_self.data is not None:
+                try:
+                    new_self.data = getattr(np, fname)(np.array(self.data), **kwargs)
+                except Exception:
+                    new_self.data = None
+            if new_self.base_values is not None and issubclass(type(axis), int) and len(self.base_values.shape) > axis:
+                new_self.base_values = getattr(np, fname)(self.base_values, **kwargs)
+            elif issubclass(type(axis), int):
+                new_self.base_values = None
+
+        if axis == 0 and self.clustering is not None and len(self.clustering.shape) == 3:
+            if self.clustering.std(0).sum() < 1e-8:
+                new_self.clustering = self.clustering[0]
+            else:
+                new_self.clustering = None
+
+        new_self.op_history.append({
+            "name": fname,
+            "kwargs": kwargs,
+            "prev_shape": self.shape,
+            "collapsed_instances": axis == 0
+        })
+
+        return new_self
+
+    def mean(self, axis):
+        """ Numpy-style mean function.
+        """
+        return self._numpy_func("mean", axis=axis)
+
+    def max(self, axis):
+        """ Numpy-style max function.
+        """
+        return self._numpy_func("max", axis=axis)
+
+    def min(self, axis):
+        """ Numpy-style min function.
+        """
+        return self._numpy_func("min", axis=axis)
+
+    def sum(self, axis=None, grouping=None):
+        """ Numpy-style sum function.
+        """
+        if grouping is None:
+            return self._numpy_func("sum", axis=axis)
+        elif axis == 1 or len(self.shape) == 1:
+            return group_features(self, grouping)
+        else:
+            raise DimensionError("Only axis = 1 is supported for grouping right now...")
+
+    def hstack(self, other):
+        """ Stack two explanations column-wise.
+        """
+        assert self.shape[0] == other.shape[0], "Can't hstack explanations with different numbers of rows!"
+        assert np.max(np.abs(self.base_values - other.base_values)) < 1e-6, "Can't hstack explanations with different base values!"
+
+        new_exp = Explanation(
+            values=np.hstack([self.values, other.values]),
+            base_values=self.base_values,
+            data=self.data,
+            display_data=self.display_data,
+            instance_names=self.instance_names,
+            feature_names=self.feature_names,
+            output_names=self.output_names,
+            output_indexes=self.output_indexes,
+            lower_bounds=self.lower_bounds,
+            upper_bounds=self.upper_bounds,
+            error_std=self.error_std,
+            main_effects=self.main_effects,
+            hierarchical_values=self.hierarchical_values,
+            clustering=self.clustering,
+        )
+        return new_exp
+
+    # def reshape(self, *args):
+    #     return self._numpy_func("reshape", newshape=args)
+
+    @property
+    def abs(self):
+        return self._numpy_func("abs")
+
+    @property
+    def identity(self):
+        return self
+
+    @property
+    def argsort(self):
+        return self._numpy_func("argsort")
+
+    @property
+    def flip(self):
+        return self._numpy_func("flip")
+
+
+    def hclust(self, metric="sqeuclidean", axis=0):
+        """ Computes an optimal leaf ordering sort order using hclustering.
+
+        hclust(metric="sqeuclidean")
+
+        Parameters
+        ----------
+        metric : string
+            A metric supported by scipy clustering.
+
+        axis : int
+            The axis to cluster along.
+        """
+        values = self.values
+
+        if len(values.shape) != 2:
+            raise DimensionError("The hclust order only supports 2D arrays right now!")
+
+        if axis == 1:
+            values = values.T
+
+        # compute a hierarchical clustering and return the optimal leaf ordering
+        D = scipy.spatial.distance.pdist(values, metric)
+        cluster_matrix = scipy.cluster.hierarchy.complete(D)
+        inds = scipy.cluster.hierarchy.leaves_list(scipy.cluster.hierarchy.optimal_leaf_ordering(cluster_matrix, D))
+        return inds
+
+    def sample(self, max_samples, replace=False, random_state=0):
+        """ Randomly samples the instances (rows) of the Explanation object.
+
+        Parameters
+        ----------
+        max_samples : int
+            The number of rows to sample. Note that if replace=False then fewer than
+            max_samples will be drawn if explanation.shape[0] < max_samples.
+
+        replace : bool
+            Sample with or without replacement.
+        """
+        prev_seed = np.random.seed(random_state)
+        inds = np.random.choice(self.shape[0], min(max_samples, self.shape[0]), replace=replace)
+        np.random.seed(prev_seed)
+        return self[list(inds)]
+
+    def _flatten_feature_names(self):
+        new_values = {}
+        for i in range(len(self.values)):
+            for s,v in zip(self.feature_names[i], self.values[i]):
+                if s not in new_values:
+                    new_values[s] = []
+                new_values[s].append(v)
+        return new_values
+
+    def _use_data_as_feature_names(self):
+        new_values = {}
+        for i in range(len(self.values)):
+            for s,v in zip(self.data[i], self.values[i]):
+                if s not in new_values:
+                    new_values[s] = []
+                new_values[s].append(v)
+        return new_values
+
+    def percentile(self, q, axis=None):
+        new_self = copy.deepcopy(self)
+        if self.feature_names is not None and not is_1d(self.feature_names) and axis == 0:
+            new_values = self._flatten_feature_names()
+            new_self.feature_names = np.array(list(new_values.keys()))
+            new_self.values = np.array([np.percentile(v, q) for v in new_values.values()])
+            new_self.clustering = None
+        else:
+            new_self.values = np.percentile(new_self.values, q, axis)
+            new_self.data = np.percentile(new_self.data, q, axis)
+        #new_self.data = None
+        new_self.op_history.append({
+            "name": "percentile",
+            "args": (axis,),
+            "prev_shape": self.shape,
+            "collapsed_instances": axis == 0
+        })
+        return new_self
+
+def group_features(shap_values, feature_map):
+    # TODOsomeday: support and deal with clusterings
+    reverse_map = {}
+    for name in feature_map:
+        reverse_map[feature_map[name]] = reverse_map.get(feature_map[name], []) + [name]
+
+    curr_names = shap_values.feature_names
+    sv_new = copy.deepcopy(shap_values)
+    found = {}
+    i = 0
+    rank1 = len(shap_values.shape) == 1
+    for name in curr_names:
+        new_name = feature_map.get(name, name)
+        if new_name in found:
+            continue
+        found[new_name] = True
+
+        new_name = feature_map.get(name, name)
+        cols_to_sum = reverse_map.get(new_name, [new_name])
+        old_inds = [curr_names.index(v) for v in cols_to_sum]
+
+        if rank1:
+            sv_new.values[i] = shap_values.values[old_inds].sum()
+            sv_new.data[i] = shap_values.data[old_inds].sum()
+        else:
+            sv_new.values[:,i] = shap_values.values[:,old_inds].sum(1)
+            sv_new.data[:,i] = shap_values.data[:,old_inds].sum(1)
+        sv_new.feature_names[i] = new_name
+        i += 1
+
+    return Explanation(
+        sv_new.values[:i] if rank1 else sv_new.values[:,:i],
+        base_values = sv_new.base_values,
+        data = sv_new.data[:i] if rank1 else sv_new.data[:,:i],
+        display_data = None if sv_new.display_data is None else (sv_new.display_data[:i] if rank1 else sv_new.display_data[:,:i]),
+        instance_names = None,
+        feature_names = None if sv_new.feature_names is None else sv_new.feature_names[:i],
+        output_names = None,
+        output_indexes = None,
+        lower_bounds = None,
+        upper_bounds = None,
+        error_std = None,
+        main_effects = None,
+        hierarchical_values = None,
+        clustering = None
+    )
+
+def compute_output_dims(values, base_values, data, output_names):
+    """ Uses the passed data to infer which dimensions correspond to the model's output.
+    """
+    values_shape = _compute_shape(values)
+
+    # input shape matches the data shape
+    if data is not None:
+        data_shape = _compute_shape(data)
+
+    # if we are not given any data we assume it would be the same shape as the given values
+    else:
+        data_shape = values_shape
+
+    # output shape is known from the base values or output names
+    if output_names is not None:
+        output_shape = _compute_shape(output_names)
+
+        # if our output_names are per sample then we need to drop the sample dimension here
+        if values_shape[-len(output_shape):] != output_shape and \
+           values_shape[-len(output_shape)+1:] == output_shape[1:] and values_shape[0] == output_shape[0]:
+            output_shape = output_shape[1:]
+
+    elif base_values is not None:
+        output_shape = _compute_shape(base_values)[1:]
+    else:
+        output_shape = tuple()
+
+    interaction_order = len(values_shape) - len(data_shape) - len(output_shape)
+    output_dims = range(len(data_shape) + interaction_order, len(values_shape))
+    return tuple(output_dims)
+
+def is_1d(val):
+    return not (isinstance(val[0], list) or isinstance(val[0], np.ndarray))
+
+class Op:
+    pass
+
+class Percentile(Op):
+    def __init__(self, percentile):
+        self.percentile = percentile
+
+    def add_repr(self, s, verbose=False):
+        return "percentile("+s+", "+str(self.percentile)+")"
+
+def _first_item(x):
+    for item in x:
+        return item
+    return None
+
+def _compute_shape(x):
+    if not hasattr(x, "__len__") or isinstance(x, str):
+        return tuple()
+    elif not scipy.sparse.issparse(x) and len(x) > 0 and isinstance(_first_item(x), str):
+        return (None,)
+    else:
+        if isinstance(x, dict):
+            return (len(x),) + _compute_shape(x[next(iter(x))])
+
+        # 2D arrays we just take their shape as-is
+        if len(getattr(x, "shape", tuple())) > 1:
+            return x.shape
+
+        # 1D arrays we need to look inside
+        if len(x) == 0:
+            return (0,)
+        elif len(x) == 1:
+            return (1,) + _compute_shape(_first_item(x))
+        else:
+            first_shape = _compute_shape(_first_item(x))
+            if first_shape == tuple():
+                return (len(x),)
+            else: # we have an array of arrays...
+                matches = np.ones(len(first_shape), dtype=bool)
+                for i in range(1, len(x)):
+                    shape = _compute_shape(x[i])
+                    assert len(shape) == len(first_shape), "Arrays in Explanation objects must have consistent inner dimensions!"
+                    for j in range(0, len(shape)):
+                        matches[j] &= shape[j] == first_shape[j]
+                return (len(x),) + tuple(first_shape[j] if match else None for j, match in enumerate(matches))
+
+class Cohorts:
+    def __init__(self, **kwargs):
+        self.cohorts = kwargs
+        for k in self.cohorts:
+            assert isinstance(self.cohorts[k], Explanation), "All the arguments to a Cohorts set must be Explanation objects!"
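+        # Illustrative usage (the keyword names are arbitrary cohort labels):
+        #   Cohorts(low_risk=shap_values[mask], high_risk=shap_values[~mask])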
+
+    def __getitem__(self, item):
+        new_cohorts = Cohorts()
+        for k in self.cohorts:
+            new_cohorts.cohorts[k] = self.cohorts[k].__getitem__(item)
+        return new_cohorts
+
+    def __getattr__(self, name):
+        new_cohorts = Cohorts()
+        for k in self.cohorts:
+            new_cohorts.cohorts[k] = getattr(self.cohorts[k], name)
+        return new_cohorts
+
+    def __call__(self, *args, **kwargs):
+        new_cohorts = Cohorts()
+        for k in self.cohorts:
+            new_cohorts.cohorts[k] = self.cohorts[k].__call__(*args, **kwargs)
+        return new_cohorts
+
+    def __repr__(self):
+        return f"<shap._explanation.Cohorts object with {len(self.cohorts)} cohorts of sizes: {[v.shape for v in self.cohorts.values()]}>"
+
+
+def _auto_cohorts(shap_values, max_cohorts):
+    """ This uses a DecisionTreeRegressor to build a group of cohorts with similar SHAP values.
+    """
+
+    # fit a decision tree that well separates the SHAP values
+    m = sklearn.tree.DecisionTreeRegressor(max_leaf_nodes=max_cohorts)
+    m.fit(shap_values.data, shap_values.values)
+
+    # group instances by their decision paths
+    paths = m.decision_path(shap_values.data).toarray()
+    path_names = []
+
+    # mark each instance with a path name
+    for i in range(shap_values.shape[0]):
+        name = ""
+        for j in range(len(paths[i])):
+            if paths[i,j] > 0:
+                feature = m.tree_.feature[j]
+                threshold = m.tree_.threshold[j]
+                val = shap_values.data[i,feature]
+                if feature >= 0:
+                    name += str(shap_values.feature_names[feature])
+                    if val < threshold:
+                        name += " < "
+                    else:
+                        name += " >= "
+                    name += str(threshold) + " & "
+        path_names.append(name[:-3]) # the -3 strips off the last unneeded ' & '
+    path_names = np.array(path_names)
+
+    # split the instances into cohorts by their path names
+    cohorts = {}
+    for name in np.unique(path_names):
+        cohorts[name] = shap_values[path_names == name]
+
+    return Cohorts(**cohorts)
+
+def list_wrap(x):
+    """ A helper to patch things since slicer doesn't handle arrays of arrays (it does handle lists of arrays)
+    """
+    if isinstance(x, np.ndarray) and len(x.shape) == 1 and isinstance(x[0], np.ndarray):
+        return [v for v in x]
+    else:
+        return x
diff --git a/lib/shap/_serializable.py b/lib/shap/_serializable.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9dbde3881d5ad3addebad7f721889192acfa97d
--- /dev/null
+++ b/lib/shap/_serializable.py
@@ -0,0 +1,203 @@
+
+import inspect
+import logging
+import pickle
+
+import cloudpickle
+import numpy as np
+
+log = logging.getLogger('shap')
+
+class Serializable:
+    """ This is the superclass of all serializable objects.
+    """
+
+    def save(self, out_file):
+        """ Save the model to the given file stream.
+        """
+        pickle.dump(type(self), out_file)
+
+    @classmethod
+    def load(cls, in_file, instantiate=True):
+        """ This is meant to be overridden by subclasses and called with super.
+
+        We return constructor argument values when not being instantiated. Since there are no
+        constructor arguments for the Serializable class we just return an empty dictionary.
+        """
+        if instantiate:
+            return cls._instantiated_load(in_file)
+        return {}
+
+    @classmethod
+    def _instantiated_load(cls, in_file, **kwargs):
+        """ This is meant to be overridden by subclasses and called with super.
+
+        We return constructor argument values (we have no values to load in this abstract class).
+        """
+        obj_type = pickle.load(in_file)
+        if obj_type is None:
+            return None
+
+        if not inspect.isclass(obj_type) or (not issubclass(obj_type, cls) and (obj_type is not cls)):
+            raise Exception(f"Invalid object type loaded from file. {obj_type} is not a subclass of {cls}.")
+
+        # here we call the constructor with all the arguments we have loaded
+        constructor_args = obj_type.load(in_file, instantiate=False, **kwargs)
+        used_args = inspect.getfullargspec(obj_type.__init__)[0]
+        return obj_type(**{k: constructor_args[k] for k in constructor_args if k in used_args})
+
+
+class Serializer:
+    """ Save data items to an output stream.
+    """
+    def __init__(self, out_stream, block_name, version):
+        self.out_stream = out_stream
+        self.block_name = block_name
+        self.block_version = version
+        self.serializer_version = 0 # update this when the serializer changes
+
+    def __enter__(self):
+        log.debug("serializer_version = %d", self.serializer_version)
+        pickle.dump(self.serializer_version, self.out_stream)
+        log.debug("block_name = %s", self.block_name)
+        pickle.dump(self.block_name, self.out_stream)
+        log.debug("block_version = %d", self.block_version)
+        pickle.dump(self.block_version, self.out_stream)
+        return self
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        log.debug("END_BLOCK___")
+        pickle.dump("END_BLOCK___", self.out_stream)
+
+    def save(self, name, value, encoder="auto"):
+        """ Dump a data item to the current output stream.
+        """
+        log.debug("name = %s", name)
+        pickle.dump(name, self.out_stream)
+        if encoder is None or encoder is False:
+            log.debug("encoder_name = %s", "no_encoder")
+            pickle.dump("no_encoder", self.out_stream)
+        elif callable(encoder):
+            log.debug("encoder_name = %s", "custom_encoder")
+            pickle.dump("custom_encoder", self.out_stream)
+            encoder(value, self.out_stream)
+        elif encoder == ".save" or (isinstance(value, Serializable) and encoder == "auto"):
+            log.debug("encoder_name = %s", "serializable.save")
+            pickle.dump("serializable.save", self.out_stream)
+            if len(inspect.getfullargspec(value.save)[0]) == 3: # backward compat for MLflow, can remove 4/1/2021
+                value.save(self.out_stream, value)
+            else:
+                value.save(self.out_stream)
+        elif encoder == "auto":
+            if isinstance(value, (int, float, str)):
+                log.debug("encoder_name = %s", "pickle.dump")
+                pickle.dump("pickle.dump", self.out_stream)
+                pickle.dump(value, self.out_stream)
+            else:
+                log.debug("encoder_name = %s", "cloudpickle.dump")
+                pickle.dump("cloudpickle.dump", self.out_stream)
+                cloudpickle.dump(value, self.out_stream)
+        else:
+            raise ValueError(f"Unknown encoder type '{encoder}' given for serialization!")
+        log.debug("value = %s", str(value))
+
+class Deserializer:
+    """ Load data items from an input stream.
+    """
+
+    def __init__(self, in_stream, block_name, min_version, max_version):
+        self.in_stream = in_stream
+        self.block_name = block_name
+        self.block_min_version = min_version
+        self.block_max_version = max_version
+
+        # update these when the serializer changes
+        self.serializer_min_version = 0
+        self.serializer_max_version = 0
+
+    def __enter__(self):
+
+        # confirm the serializer version
+        serializer_version = pickle.load(self.in_stream)
+        log.debug("serializer_version = %d", serializer_version)
+        if serializer_version < self.serializer_min_version:
+            raise ValueError(
+                f"The file being loaded was saved with a serializer version of {serializer_version}, " + \
+                f"but the current deserializer in SHAP requires at least version {self.serializer_min_version}."
+            )
+        if serializer_version > self.serializer_max_version:
+            raise ValueError(
+                f"The file being loaded was saved with a serializer version of {serializer_version}, " + \
+                f"but the current deserializer in SHAP only supports up to version {self.serializer_max_version}."
+            )
+
+        # confirm the block name
+        block_name = pickle.load(self.in_stream)
+        log.debug("block_name = %s", block_name)
+        if block_name != self.block_name:
+            raise ValueError(
+                f"The next data block in the file being loaded was supposed to be {self.block_name}, " + \
+                f"but the next block found was {block_name}."
+            )
+
+        # confirm the block version
+        block_version = pickle.load(self.in_stream)
+        log.debug("block_version = %d", block_version)
+        if block_version < self.block_min_version:
+            raise ValueError(
+                f"The file being loaded was saved with a block version of {block_version}, " + \
+                f"but the current deserializer in SHAP requires at least version {self.block_min_version}."
+            )
+        if block_version > self.block_max_version:
+            raise ValueError(
+                f"The file being loaded was saved with a block version of {block_version}, " + \
+                f"but the current deserializer in SHAP only supports up to version {self.block_max_version}."
+            )
+        return self
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        # confirm the block end token
+        for _ in range(100):
+            end_token = pickle.load(self.in_stream)
+            log.debug("end_token = %s", end_token)
+            if end_token == "END_BLOCK___":
+                return
+            self._load_data_value()
+        raise ValueError(
+            f"The data block end token was not found for the block {self.block_name}."
+        )
+
+    def load(self, name, decoder=None):
+        """ Load a data item from the current input stream.
+        """
+        # confirm the block name
+        loaded_name = pickle.load(self.in_stream)
+        log.debug("loaded_name = %s", loaded_name)
+        if loaded_name != name:
+            raise ValueError(
+                f"The next data item in the file being loaded was supposed to be {name}, " + \
+                f"but the next block found was {loaded_name}."
+            ) # We should eventually add support for skipping over unused data items in old formats...
+
+        value = self._load_data_value(decoder)
+        log.debug("value = %s", str(value))
+        return value
+
+    def _load_data_value(self, decoder=None):
+        encoder_name = pickle.load(self.in_stream)
+        log.debug("encoder_name = %s", encoder_name)
+        if encoder_name == "custom_encoder" or callable(decoder):
+            assert callable(decoder), "You must provide a callable custom decoder for custom-encoded data items!"
+            return decoder(self.in_stream)
+        if encoder_name == "no_encoder":
+            return None
+        if encoder_name == "serializable.save":
+            return Serializable.load(self.in_stream)
+        if encoder_name == "numpy.save":
+            return np.load(self.in_stream)
+        if encoder_name == "pickle.dump":
+            return pickle.load(self.in_stream)
+        if encoder_name == "cloudpickle.dump":
+            return cloudpickle.load(self.in_stream)
+
+        raise ValueError(f"Unsupported encoder type found: {encoder_name}")
diff --git a/lib/shap/_version.py b/lib/shap/_version.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1324f87db53df2a3c353bed6a1f3de7f411afff
--- /dev/null
+++ b/lib/shap/_version.py
@@ -0,0 +1,16 @@
+# file generated by setuptools_scm
+# don't change, don't track in version control
+TYPE_CHECKING = False
+if TYPE_CHECKING:
+    from typing import Tuple, Union
+    VERSION_TUPLE = Tuple[Union[int, str], ...]
+else:
+    VERSION_TUPLE = object
+
+version: str
+__version__: str
+__version_tuple__: VERSION_TUPLE
+version_tuple: VERSION_TUPLE
+
+__version__ = version = '0.44.1'
+__version_tuple__ = version_tuple = (0, 44, 1)
diff --git a/lib/shap/actions/__init__.py b/lib/shap/actions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b449ca842bab7e94add7dd4d2a3d386644293b
--- /dev/null
+++ b/lib/shap/actions/__init__.py
@@ -0,0 +1,3 @@
+from ._action import Action
+
+__all__ = ["Action"]
diff --git a/lib/shap/actions/_action.py b/lib/shap/actions/_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..6339e0c105e607226f62a64748ef2c8c124b77bc
--- /dev/null
+++ b/lib/shap/actions/_action.py
@@ -0,0 +1,8 @@
+class Action:
+    """ Abstract action class.
+    """
+    def __lt__(self, other_action):
+        return self.cost < other_action.cost
+
+    def __repr__(self):
+        return f"<Action of cost {self.cost}>"
diff --git a/lib/shap/actions/_optimizer.py b/lib/shap/actions/_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..752af3ff286b6717cf5c1d673cec3ea53583281b
--- /dev/null
+++ b/lib/shap/actions/_optimizer.py
@@ -0,0 +1,92 @@
+import copy
+import queue
+import warnings
+
+from ..utils._exceptions import ConvergenceError, InvalidAction
+from ._action import Action
+
+
+class ActionOptimizer:
+    def __init__(self, model, actions):
+        self.model = model
+        warnings.warn(
+            "Note that ActionOptimizer is still in an alpha state and is subject to API changes."
+        )
+        # actions go into mutually exclusive groups
+        self.action_groups = []
+        for group in actions:
+
+            if issubclass(type(group), Action):
+                group._group_index = len(self.action_groups)
+                group._grouped_index = 0
+                self.action_groups.append([copy.copy(group)])
+            elif issubclass(type(group), list):
+                group = sorted([copy.copy(v) for v in group], key=lambda a: a.cost)
+                for i, v in enumerate(group):
+                    v._group_index = len(self.action_groups)
+                    v._grouped_index = i
+                self.action_groups.append(group)
+            else:
+                raise InvalidAction(
+                    "A passed action was not an Action or list of actions!"
+                )
+
+    def __call__(self, *args, max_evals=10000):
+
+        # init our queue with all the least costly actions
+        q = queue.PriorityQueue()
+        for i in range(len(self.action_groups)):
+            group = self.action_groups[i]
+            q.put((group[0].cost, [group[0]]))
+
+        nevals = 0
+        while not q.empty():
+
+            # see if we have exceeded our runtime budget
+            nevals += 1
+            if nevals > max_evals:
+                raise ConvergenceError(
+                    f"Failed to find a solution with max_evals={max_evals}! Try reducing the number of actions or increasing max_evals."
+                )
+
+            # get the next cheapest set of actions we can do
+            cost, actions = q.get()
+
+            # apply those actions
+            args_tmp = copy.deepcopy(args)
+            for a in actions:
+                a(*args_tmp)
+
+            # if the model is now satisfied we are done!!
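+            # (the priority queue pops action sets in order of total cost, so with
+            # non-negative costs the first satisfying set found is also a cheapest one)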
+            v = self.model(*args_tmp)
+            if v:
+                return actions
+
+            # if not then we add all possible follow-on actions to our queue
+            else:
+                for i in range(len(self.action_groups)):
+                    group = self.action_groups[i]
+
+                    # look to see if we already have an action from this group, if so we need to
+                    # move to a more expensive action in the same group
+                    next_ind = 0
+                    prev_in_group = -1
+                    for j, a in enumerate(actions):
+                        if a._group_index == i:
+                            next_ind = max(next_ind, a._grouped_index + 1)
+                            prev_in_group = j
+
+                    # we are adding a new action type
+                    if prev_in_group == -1:
+                        new_actions = actions + [group[next_ind]]
+                    # we are moving from one action to a more expensive one in the same group
+                    elif next_ind < len(group):
+                        new_actions = copy.copy(actions)
+                        new_actions[prev_in_group] = group[next_ind]
+                    # we don't have a more expensive action left in this group
+                    else:
+                        new_actions = None
+
+                    # add the new option to our queue
+                    if new_actions is not None:
+                        q.put((sum([a.cost for a in new_actions]), new_actions))
diff --git a/lib/shap/benchmark/__init__.py b/lib/shap/benchmark/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1669d7d9244e7f91d53c0798a6050f55ff03f4f
--- /dev/null
+++ b/lib/shap/benchmark/__init__.py
@@ -0,0 +1,9 @@
+from ._compute import ComputeTime
+from ._explanation_error import ExplanationError
+from ._result import BenchmarkResult
+from ._sequential import SequentialMasker
+
+# from . import framework
+# from .. import datasets
+
+__all__ = ["ComputeTime", "ExplanationError", "BenchmarkResult", "SequentialMasker"]
diff --git a/lib/shap/benchmark/_compute.py b/lib/shap/benchmark/_compute.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab46ca5d195dc833f040d84dae3f1b5daca36ee8
--- /dev/null
+++ b/lib/shap/benchmark/_compute.py
@@ -0,0 +1,9 @@
+from ._result import BenchmarkResult
+
+
+class ComputeTime:
+    """ Extracts a runtime benchmark result from the passed Explanation.
+    """
+
+    def __call__(self, explanation, name):
+        return BenchmarkResult("compute time", name, value=explanation.compute_time / explanation.shape[0])
diff --git a/lib/shap/benchmark/_explanation_error.py b/lib/shap/benchmark/_explanation_error.py
new file mode 100644
index 0000000000000000000000000000000000000000..d325adcfe5abaea6bd0efa72f8dab23ebab8b157
--- /dev/null
+++ b/lib/shap/benchmark/_explanation_error.py
@@ -0,0 +1,180 @@
+import time
+
+import numpy as np
+from tqdm.auto import tqdm
+
+from shap import Explanation, links
+from shap.maskers import FixedComposite, Image, Text
+from shap.utils import MaskedModel, partition_tree_shuffle
+from shap.utils._exceptions import DimensionError
+
+from ._result import BenchmarkResult
+
+
+class ExplanationError:
+    """ A measure of the explanation error relative to a model's actual output.
+
+    This benchmark metric measures the discrepancy between the output of the model predicted by an
+    attribution explanation vs. the actual output of the model. This discrepancy is measured over
+    many masking patterns drawn from permutations of the input features.
+
+    For explanations (like Shapley values) that explain the difference between one alternative and another
+    (for example a current sample and typical background feature values) there is possible explanation error
+    for every pattern of mixing foreground and background, or in other words every possible masking pattern.
+    In this class we compute the standard deviation over these explanation errors where masking patterns
+    are drawn from prefixes of random feature permutations. This seems natural, and aligns with Shapley value
+    computations, but of course you could choose to summarize explanation errors in other ways as well.
+    """
+
+    def __init__(self, masker, model, *model_args, batch_size=500, num_permutations=10, link=links.identity, linearize_link=True, seed=38923):
+        """ Build a new explanation error benchmarker with the given masker, model, and model args.
+
+        Parameters
+        ----------
+        masker : function or shap.Masker
+            The masker defines how we hide features during the perturbation process.
+
+        model : function or shap.Model
+            The model we want to evaluate explanations against.
+
+        model_args : ...
+            The list of arguments we will give to the model that we will have explained. When we later call this benchmark
+            object we should pass explanations that have been computed on this same data.
+
+        batch_size : int
+            The maximum batch size we should use when calling the model. For some large NLP models this needs to be set
+            lower (at say 1) to avoid running out of GPU memory.
+
+        num_permutations : int
+            How many permutations we will use to estimate the average explanation error for each sample. If you are running
+            this benchmark on a large dataset with many samples then you can reduce this value since the final result is
+            averaged over samples as well and the averages of both directly combine to reduce variance. So for 10k samples
+            num_permutations=1 is appropriate.
+
+        link : function
+            Allows for a non-linear link function to be used to bridge between the model output space and the explanation
+            space.
+
+        linearize_link : bool
+            Non-linear links can destroy additive separation in generalized linear models, so by linearizing the link we can
+            retain additive separation. See upcoming paper/doc for details.
+        """
+
+        self.masker = masker
+        self.model = model
+        self.model_args = model_args
+        self.num_permutations = num_permutations
+        self.link = link
+        self.linearize_link = linearize_link
+        self.batch_size = batch_size
+        self.seed = seed
+
+        # user must give valid masker
+        underlying_masker = masker.masker if isinstance(masker, FixedComposite) else masker
+        if isinstance(underlying_masker, Text):
+            self.data_type = "text"
+        elif isinstance(underlying_masker, Image):
+            self.data_type = "image"
+        else:
+            self.data_type = "tabular"
+
+    def __call__(self, explanation, name, step_fraction=0.01, indices=[], silent=False):
+        """ Run this benchmark on the given explanation.
+        """
+
+        if isinstance(explanation, np.ndarray):
+            attributions = explanation
+        elif isinstance(explanation, Explanation):
+            attributions = explanation.values
+        else:
+            raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!")
+
+        if len(attributions) != len(self.model_args[0]):
+            emsg = (
+                "The explanation passed must have the same number of rows as "
+                "the self.model_args that were passed!"
+ ) + raise DimensionError(emsg) + + # it is important that we choose the same permutations for the different explanations we are comparing + # so as to avoid needless noise + old_seed = np.random.seed() + np.random.seed(self.seed) + + pbar = None + start_time = time.time() + svals = [] + mask_vals = [] + + for i, args in enumerate(zip(*self.model_args)): + + if len(args[0].shape) != len(attributions[i].shape): + raise ValueError("The passed explanation must have the same dim as the model_args and must not have a vector output!") + + feature_size = np.prod(attributions[i].shape) + sample_attributions = attributions[i].flatten() + + # compute any custom clustering for this row + row_clustering = None + if getattr(self.masker, "clustering", None) is not None: + if isinstance(self.masker.clustering, np.ndarray): + row_clustering = self.masker.clustering + elif callable(self.masker.clustering): + row_clustering = self.masker.clustering(*args) + else: + raise NotImplementedError("The masker passed has a .clustering attribute that is not yet supported by the ExplanationError benchmark!") + + masked_model = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *args) + + total_values = None + for _ in range(self.num_permutations): + masks = [] + mask = np.zeros(feature_size, dtype=bool) + masks.append(mask.copy()) + ordered_inds = np.arange(feature_size) + + # shuffle the indexes so we get a random permutation ordering + if row_clustering is not None: + inds_mask = np.ones(feature_size, dtype=bool) + partition_tree_shuffle(ordered_inds, inds_mask, row_clustering) + else: + np.random.shuffle(ordered_inds) + + increment = max(1, int(feature_size * step_fraction)) + for j in range(0, feature_size, increment): + mask[ordered_inds[np.arange(j, min(feature_size, j+increment))]] = True + masks.append(mask.copy()) + mask_vals.append(masks) + + values = [] + masks_arr = np.array(masks) + for j in range(0, len(masks_arr), self.batch_size): + values.append(masked_model(masks_arr[j:j + self.batch_size])) + values = np.concatenate(values) + base_value = values[0] + for j, v in enumerate(values): + values[j] = (v - (base_value + np.sum(sample_attributions[masks_arr[j]])))**2 + + if total_values is None: + total_values = values + else: + total_values += values + total_values /= self.num_permutations + + svals.append(total_values) + + if pbar is None and time.time() - start_time > 5: + pbar = tqdm(total=len(self.model_args[0]), disable=silent, leave=False, desc=f"ExplanationError for {name}") + pbar.update(i+1) + if pbar is not None: + pbar.update(1) + + if pbar is not None: + pbar.close() + + svals = np.array(svals) + + # reset the random seed so we don't mess up the caller + np.random.seed(old_seed) + + return BenchmarkResult("explanation error", name, value=np.sqrt(np.sum(total_values)/len(total_values))) diff --git a/lib/shap/benchmark/_result.py b/lib/shap/benchmark/_result.py new file mode 100644 index 0000000000000000000000000000000000000000..0c31e0f0fbd2fca82a56486ee707cc222daf487f --- /dev/null +++ b/lib/shap/benchmark/_result.py @@ -0,0 +1,34 @@ +import numpy as np +import sklearn + +sign_defaults = { + "keep positive": 1, + "keep negative": -1, + "remove positive": -1, + "remove negative": 1, + "compute time": -1, + "keep absolute": -1, # the absolute signs are defaults that make sense when scoring losses + "remove absolute": 1, + "explanation error": -1 +} + +class BenchmarkResult: + """ The result of a benchmark run. 
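+
+    Holds the metric name, the method name, and either a scalar value or a
+    performance curve (curve_x/curve_y, with an optional standard deviation band).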
+ """ + + def __init__(self, metric, method, value=None, curve_x=None, curve_y=None, curve_y_std=None, value_sign=None): + self.metric = metric + self.method = method + self.value = value + self.curve_x = curve_x + self.curve_y = curve_y + self.curve_y_std = curve_y_std + self.value_sign = value_sign + if self.value_sign is None and self.metric in sign_defaults: + self.value_sign = sign_defaults[self.metric] + if self.value is None: + self.value = sklearn.metrics.auc(curve_x, (np.array(curve_y) - curve_y[0])) + + @property + def full_name(self): + return self.method + " " + self.metric diff --git a/lib/shap/benchmark/_sequential.py b/lib/shap/benchmark/_sequential.py new file mode 100644 index 0000000000000000000000000000000000000000..4c5eca2233518fceafc500575f5dbf74bc320304 --- /dev/null +++ b/lib/shap/benchmark/_sequential.py @@ -0,0 +1,332 @@ +import time + +import matplotlib.pyplot as pl +import numpy as np +import pandas as pd +import sklearn +from tqdm.auto import tqdm + +from shap import Explanation, links +from shap.maskers import FixedComposite, Image, Text +from shap.utils import MaskedModel + +from ._result import BenchmarkResult + + +class SequentialMasker: + def __init__(self, mask_type, sort_order, masker, model, *model_args, batch_size=500): + + for arg in model_args: + if isinstance(arg, pd.DataFrame): + raise TypeError("DataFrame arguments dont iterate correctly, pass numpy arrays instead!") + + # convert any DataFrames to numpy arrays + # self.model_arg_cols = [] + # self.model_args = [] + # self.has_df = False + # for arg in model_args: + # if isinstance(arg, pd.DataFrame): + # self.model_arg_cols.append(arg.columns) + # self.model_args.append(arg.values) + # self.has_df = True + # else: + # self.model_arg_cols.append(None) + # self.model_args.append(arg) + + # if self.has_df: + # given_model = model + # def new_model(*args): + # df_args = [] + # for i, arg in enumerate(args): + # if self.model_arg_cols[i] is not None: + # df_args.append(pd.DataFrame(arg, columns=self.model_arg_cols[i])) + # else: + # df_args.append(arg) + # return given_model(*df_args) + # model = new_model + + self.inner = SequentialPerturbation( + model, masker, sort_order, mask_type + ) + self.model_args = model_args + self.batch_size = batch_size + + def __call__(self, explanation, name, **kwargs): + return self.inner(name, explanation, *self.model_args, batch_size=self.batch_size, **kwargs) + +class SequentialPerturbation: + def __init__(self, model, masker, sort_order, perturbation, linearize_link=False): + # self.f = lambda masked, x, index: model.predict(masked) + self.model = model if callable(model) else model.predict + self.masker = masker + self.sort_order = sort_order + self.perturbation = perturbation + self.linearize_link = linearize_link + + # define our sort order + if self.sort_order == "positive": + self.sort_order_map = lambda x: np.argsort(-x) + elif self.sort_order == "negative": + self.sort_order_map = lambda x: np.argsort(x) + elif self.sort_order == "absolute": + self.sort_order_map = lambda x: np.argsort(-abs(x)) + else: + raise ValueError("sort_order must be either \"positive\", \"negative\", or \"absolute\"!") + + # user must give valid masker + underlying_masker = masker.masker if isinstance(masker, FixedComposite) else masker + if isinstance(underlying_masker, Text): + self.data_type = "text" + elif isinstance(underlying_masker, Image): + self.data_type = "image" + else: + self.data_type = "tabular" + #raise ValueError("masker must be for \"tabular\", \"text\", or 
\"image\"!") + + self.score_values = [] + self.score_aucs = [] + self.labels = [] + + def __call__(self, name, explanation, *model_args, percent=0.01, indices=[], y=None, label=None, silent=False, debug_mode=False, batch_size=10): + # if explainer is already the attributions + if isinstance(explanation, np.ndarray): + attributions = explanation + elif isinstance(explanation, Explanation): + attributions = explanation.values + else: + raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!") + + assert len(attributions) == len(model_args[0]), "The explanation passed must have the same number of rows as the model_args that were passed!" + + if label is None: + label = "Score %d" % len(self.score_values) + + # convert dataframes + # if isinstance(X, (pd.Series, pd.DataFrame)): + # X = X.values + + # convert all single-sample vectors to matrices + # if not hasattr(attributions[0], "__len__"): + # attributions = np.array([attributions]) + # if not hasattr(X[0], "__len__") and self.data_type == "tabular": + # X = np.array([X]) + + pbar = None + start_time = time.time() + svals = [] + mask_vals = [] + + for i, args in enumerate(zip(*model_args)): + # if self.data_type == "image": + # x_shape, y_shape = attributions[i].shape[0], attributions[i].shape[1] + # feature_size = np.prod([x_shape, y_shape]) + # sample_attributions = attributions[i].mean(2).reshape(feature_size, -1) + # data = X[i].flatten() + # mask_shape = X[i].shape + # else: + feature_size = np.prod(attributions[i].shape) + sample_attributions = attributions[i].flatten() + # data = X[i] + # mask_shape = feature_size + + self.masked_model = MaskedModel(self.model, self.masker, links.identity, self.linearize_link, *args) + + masks = [] + + mask = np.ones(feature_size, dtype=bool) * (self.perturbation == "remove") + masks.append(mask.copy()) + + ordered_inds = self.sort_order_map(sample_attributions) + increment = max(1,int(feature_size*percent)) + for j in range(0, feature_size, increment): + oind_list = [ordered_inds[t] for t in range(j, min(feature_size, j+increment))] + + for oind in oind_list: + if not ((self.sort_order == "positive" and sample_attributions[oind] <= 0) or \ + (self.sort_order == "negative" and sample_attributions[oind] >= 0)): + mask[oind] = self.perturbation == "keep" + + masks.append(mask.copy()) + + mask_vals.append(masks) + + # mask_size = len(range(0, feature_size, increment)) + 1 + values = [] + masks_arr = np.array(masks) + for j in range(0, len(masks_arr), batch_size): + values.append(self.masked_model(masks_arr[j:j + batch_size])) + values = np.concatenate(values) + + svals.append(values) + + if pbar is None and time.time() - start_time > 5: + pbar = tqdm(total=len(model_args[0]), disable=silent, leave=False, desc="SequentialMasker") + pbar.update(i+1) + if pbar is not None: + pbar.update(1) + + if pbar is not None: + pbar.close() + + self.score_values.append(np.array(svals)) + + # if self.sort_order == "negative": + # curve_sign = -1 + # else: + curve_sign = 1 + + self.labels.append(label) + + xs = np.linspace(0, 1, 100) + curves = np.zeros((len(self.score_values[-1]), len(xs))) + for j in range(len(self.score_values[-1])): + xp = np.linspace(0, 1, len(self.score_values[-1][j])) + yp = self.score_values[-1][j] + curves[j,:] = np.interp(xs, xp, yp) + ys = curves.mean(0) + std = curves.std(0) / np.sqrt(curves.shape[0]) + auc = sklearn.metrics.auc(np.linspace(0, 1, len(ys)), curve_sign*(ys-ys[0])) + + if not debug_mode: + return BenchmarkResult(self.perturbation + " " 
+ def score(self, explanation, X, percent=0.01, y=None, label=None, silent=False, debug_mode=False): + ''' + Will be deprecated once MaskedModel support is complete. + ''' + # if explainer is already the attributions + if isinstance(explanation, np.ndarray): + attributions = explanation + elif isinstance(explanation, Explanation): + attributions = explanation.values + else: + raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!") + + if label is None: + label = "Score %d" % len(self.score_values) + + # convert dataframes + if isinstance(X, (pd.Series, pd.DataFrame)): + X = X.values + + # convert all single-sample vectors to matrices + if not hasattr(attributions[0], "__len__"): + attributions = np.array([attributions]) + if not hasattr(X[0], "__len__") and self.data_type == "tabular": + X = np.array([X]) + + pbar = None + start_time = time.time() + svals = [] + mask_vals = [] + + for i in range(len(X)): + if self.data_type == "image": + x_shape, y_shape = attributions[i].shape[0], attributions[i].shape[1] + feature_size = np.prod([x_shape, y_shape]) + sample_attributions = attributions[i].mean(2).reshape(feature_size, -1) + else: + feature_size = attributions[i].shape[0] + sample_attributions = attributions[i] + + if len(attributions[i].shape) == 1 or self.data_type == "tabular": + output_size = 1 + else: + output_size = attributions[i].shape[-1] + + for k in range(output_size): + if self.data_type == "image": + mask_shape = X[i].shape + else: + mask_shape = feature_size + + mask = np.ones(mask_shape, dtype=bool) * (self.perturbation == "remove") + masks = [mask.copy()] + + values = np.zeros(feature_size+1) + # masked, data = self.masker(mask, X[i]) + masked = self.masker(mask, X[i]) + data = None + curr_val = self.f(masked, data, k).mean(0) + + values[0] = curr_val + + if output_size != 1: + test_attributions = sample_attributions[:,k] + else: + test_attributions = sample_attributions + + ordered_inds = self.sort_order_map(test_attributions) + increment = max(1, int(feature_size * percent)) + for j in range(0, feature_size, increment): + oind_list = [ordered_inds[t] for t in range(j, min(feature_size, j+increment))] + + for oind in oind_list: + if not ((self.sort_order == "positive" and test_attributions[oind] <= 0) or \ + (self.sort_order == "negative" and test_attributions[oind] >= 0)): + if self.data_type == "image": + xoind, yoind = oind // attributions[i].shape[1], oind % attributions[i].shape[1] + mask[xoind][yoind] = self.perturbation == "keep" + else: + mask[oind] = self.perturbation == "keep" + + masks.append(mask.copy()) + # masked, data = self.masker(mask, X[i]) + masked = self.masker(mask, X[i]) + curr_val = self.f(masked, data, k).mean(0) + + for t in range(j, min(feature_size, j+increment)): + values[t+1] = curr_val + + svals.append(values) + mask_vals.append(masks) + + if pbar is None and time.time() - start_time > 5: + pbar = tqdm(total=len(X), disable=silent, leave=False) + pbar.update(i+1) + if pbar is not None: + pbar.update(1) + + if pbar is not None: + pbar.close() + + self.score_values.append(np.array(svals)) + + if self.sort_order == "negative": + curve_sign = -1 + else: + curve_sign = 1 + + self.labels.append(label) + + xs = np.linspace(0, 1, 100) + curves = np.zeros((len(self.score_values[-1]), len(xs))) + for j in range(len(self.score_values[-1])): + xp = np.linspace(0, 1, len(self.score_values[-1][j])) + yp = self.score_values[-1][j] + curves[j,:] = np.interp(xs, xp, yp) + ys = curves.mean(0) + + if debug_mode: + aucs = [] + for j in range(len(self.score_values[-1])): + curve = curves[j,:] + auc = sklearn.metrics.auc(np.linspace(0, 1, len(curve)), curve_sign*(curve-curve[0])) + aucs.append(auc) + return mask_vals, curves, aucs + else: + auc = sklearn.metrics.auc(np.linspace(0, 1, len(ys)), curve_sign*(ys-ys[0])) + return xs, ys, auc + + def plot(self, xs, ys, auc): + pl.plot(xs, ys, label="AUC %0.4f" % auc) + pl.legend() + xlabel = "Percent Unmasked" if self.perturbation == "keep" else "Percent Masked" + pl.xlabel(xlabel) + pl.ylabel("Model Output") + pl.show()
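For orientation, a minimal end-to-end run might look as follows. This is a sketch only: it assumes the constructor shown earlier in this file takes (model, masker, sort_order, perturbation), as framework.py below instantiates it, and the dataset and model choices are illustrative, not part of this module.

import shap
import xgboost

# illustrative setup (assumed, not part of this module)
X, y = shap.datasets.diabetes()
model = xgboost.XGBRegressor().fit(X, y)
explanation = shap.TreeExplainer(model)(X)

# assumed constructor signature: (model, masker, sort_order, perturbation)
sp = SequentialPerturbation(model.predict, shap.maskers.Independent(X),
                            "positive", "remove")
result = sp("remove positive", explanation, X.values)  # BenchmarkResult with curve_x/curve_y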
diff --git a/lib/shap/benchmark/experiments.py b/lib/shap/benchmark/experiments.py new file mode 100644 index 0000000000000000000000000000000000000000..42d2527673596208246ca19420c9f0cf79c3b04f --- /dev/null +++ b/lib/shap/benchmark/experiments.py @@ -0,0 +1,414 @@ +import copy +import itertools +import os +import pickle +import random +import subprocess +import sys +import time +from multiprocessing import Pool + +from .. import __version__, datasets +from . import metrics, models + +try: + from queue import Queue +except ImportError: + from Queue import Queue +from threading import Lock, Thread + +regression_metrics = [ + "local_accuracy", + "consistency_guarantees", + "keep_positive_mask", + "keep_positive_resample", + #"keep_positive_impute", + "keep_negative_mask", + "keep_negative_resample", + #"keep_negative_impute", + "keep_absolute_mask__r2", + "keep_absolute_resample__r2", + #"keep_absolute_impute__r2", + "remove_positive_mask", + "remove_positive_resample", + #"remove_positive_impute", + "remove_negative_mask", + "remove_negative_resample", + #"remove_negative_impute", + "remove_absolute_mask__r2", + "remove_absolute_resample__r2", + #"remove_absolute_impute__r2" + "runtime", +] + +binary_classification_metrics = [ + "local_accuracy", + "consistency_guarantees", + "keep_positive_mask", + "keep_positive_resample", + #"keep_positive_impute", + "keep_negative_mask", + "keep_negative_resample", + #"keep_negative_impute", + "keep_absolute_mask__roc_auc", + "keep_absolute_resample__roc_auc", + #"keep_absolute_impute__roc_auc", + "remove_positive_mask", + "remove_positive_resample", + #"remove_positive_impute", + "remove_negative_mask", + "remove_negative_resample", + #"remove_negative_impute", + "remove_absolute_mask__roc_auc", + "remove_absolute_resample__roc_auc", + #"remove_absolute_impute__roc_auc" + "runtime", +] + +human_metrics = [ + "human_and_00", + "human_and_01", + "human_and_11", + "human_or_00", + "human_or_01", + "human_or_11", + "human_xor_00", + "human_xor_01", + "human_xor_11", + "human_sum_00", + "human_sum_01", + "human_sum_11" +] + +linear_regress_methods = [ + "linear_shap_corr", + "linear_shap_ind", + "coef", + "random", + "kernel_shap_1000_meanref", + #"kernel_shap_100_meanref", + #"sampling_shap_10000", + "sampling_shap_1000", + "lime_tabular_regression_1000" + #"sampling_shap_100" +] + +linear_classify_methods = [ + # NEED LIME + "linear_shap_corr", + "linear_shap_ind", + "coef", + "random", + "kernel_shap_1000_meanref", + #"kernel_shap_100_meanref", + #"sampling_shap_10000", + "sampling_shap_1000", + #"lime_tabular_regression_1000" + #"sampling_shap_100" +] + +tree_regress_methods = [ + # NEED tree_shap_ind + # NEED split_count?
+ "tree_shap_tree_path_dependent", + "tree_shap_independent_200", + "saabas", + "random", + "tree_gain", + "kernel_shap_1000_meanref", + "mean_abs_tree_shap", + #"kernel_shap_100_meanref", + #"sampling_shap_10000", + "sampling_shap_1000", + "lime_tabular_regression_1000", + "maple" + #"sampling_shap_100" +] + +rf_regress_methods = [ # methods that only support random forest models + "tree_maple" +] + +tree_classify_methods = [ + # NEED tree_shap_ind + # NEED split_count? + "tree_shap_tree_path_dependent", + "tree_shap_independent_200", + "saabas", + "random", + "tree_gain", + "kernel_shap_1000_meanref", + "mean_abs_tree_shap", + #"kernel_shap_100_meanref", + #"sampling_shap_10000", + "sampling_shap_1000", + "lime_tabular_classification_1000", + "maple" + #"sampling_shap_100" +] + +deep_regress_methods = [ + "deep_shap", + "expected_gradients", + "random", + "kernel_shap_1000_meanref", + "sampling_shap_1000", + #"lime_tabular_regression_1000" +] + +deep_classify_methods = [ + "deep_shap", + "expected_gradients", + "random", + "kernel_shap_1000_meanref", + "sampling_shap_1000", + #"lime_tabular_regression_1000" +] + +_experiments = [] +_experiments += [["corrgroups60", "lasso", m, s] for s in regression_metrics for m in linear_regress_methods] +_experiments += [["corrgroups60", "ridge", m, s] for s in regression_metrics for m in linear_regress_methods] +_experiments += [["corrgroups60", "decision_tree", m, s] for s in regression_metrics for m in tree_regress_methods] +_experiments += [["corrgroups60", "random_forest", m, s] for s in regression_metrics for m in (tree_regress_methods + rf_regress_methods)] +_experiments += [["corrgroups60", "gbm", m, s] for s in regression_metrics for m in tree_regress_methods] +_experiments += [["corrgroups60", "ffnn", m, s] for s in regression_metrics for m in deep_regress_methods] + +_experiments += [["independentlinear60", "lasso", m, s] for s in regression_metrics for m in linear_regress_methods] +_experiments += [["independentlinear60", "ridge", m, s] for s in regression_metrics for m in linear_regress_methods] +_experiments += [["independentlinear60", "decision_tree", m, s] for s in regression_metrics for m in tree_regress_methods] +_experiments += [["independentlinear60", "random_forest", m, s] for s in regression_metrics for m in (tree_regress_methods + rf_regress_methods)] +_experiments += [["independentlinear60", "gbm", m, s] for s in regression_metrics for m in tree_regress_methods] +_experiments += [["independentlinear60", "ffnn", m, s] for s in regression_metrics for m in deep_regress_methods] + +_experiments += [["cric", "lasso", m, s] for s in binary_classification_metrics for m in linear_classify_methods] +_experiments += [["cric", "ridge", m, s] for s in binary_classification_metrics for m in linear_classify_methods] +_experiments += [["cric", "decision_tree", m, s] for s in binary_classification_metrics for m in tree_classify_methods] +_experiments += [["cric", "random_forest", m, s] for s in binary_classification_metrics for m in tree_classify_methods] +_experiments += [["cric", "gbm", m, s] for s in binary_classification_metrics for m in tree_classify_methods] +_experiments += [["cric", "ffnn", m, s] for s in binary_classification_metrics for m in deep_classify_methods] + +_experiments += [["human", "decision_tree", m, s] for s in human_metrics for m in tree_regress_methods] + + +def experiments(dataset=None, model=None, method=None, metric=None): + for experiment in _experiments: + if dataset is not None and dataset != experiment[0]: + 
continue + if model is not None and model != experiment[1]: + continue + if method is not None and method != experiment[2]: + continue + if metric is not None and metric != experiment[3]: + continue + yield experiment + +def run_experiment(experiment, use_cache=True, cache_dir="/tmp"): + dataset_name, model_name, method_name, metric_name = experiment + + # see if we have a cached version + cache_id = __gen_cache_id(experiment) + cache_file = os.path.join(cache_dir, cache_id + ".pickle") + if use_cache and os.path.isfile(cache_file): + with open(cache_file, "rb") as f: + #print(cache_id.replace("__", " ") + " ...loaded from cache.") + return pickle.load(f) + + # compute the scores + print(cache_id.replace("__", " ", 4) + " ...") + sys.stdout.flush() + start = time.time() + X, y = getattr(datasets, dataset_name)() + score = getattr(metrics, metric_name)( + X, y, + getattr(models, dataset_name+"__"+model_name), + method_name + ) + print("...took %f seconds.\n" % (time.time() - start)) + + # cache the scores + with open(cache_file, "wb") as f: + pickle.dump(score, f) + + return score + + +def run_experiments_helper(args): + experiment, cache_dir = args + return run_experiment(experiment, cache_dir=cache_dir) + +def run_experiments(dataset=None, model=None, method=None, metric=None, cache_dir="/tmp", nworkers=1): + experiments_arr = list(experiments(dataset=dataset, model=model, method=method, metric=metric)) + if nworkers == 1: + out = list(map(run_experiments_helper, zip(experiments_arr, itertools.repeat(cache_dir)))) + else: + with Pool(nworkers) as pool: + out = pool.map(run_experiments_helper, zip(experiments_arr, itertools.repeat(cache_dir))) + return list(zip(experiments_arr, out)) + + +nexperiments = 0 +total_sent = 0 +total_done = 0 +total_failed = 0 +host_records = {} +worker_lock = Lock() +ssh_conn_per_min_limit = 0 # set as an argument to run_remote_experiments +def __thread_worker(q, host): + global total_sent, total_done + hostname, python_binary = host.split(":") + while True: + + # make sure we are not sending too many ssh connections to the host + # (if we send too many connections ssh throttling will lock us out) + while True: + all_clear = False + + worker_lock.acquire() + try: + if hostname not in host_records: + host_records[hostname] = [] + + if len(host_records[hostname]) < ssh_conn_per_min_limit: + all_clear = True + elif time.time() - host_records[hostname][-ssh_conn_per_min_limit] > 61: + all_clear = True + finally: + worker_lock.release() + + # if we are clear to send a new ssh connection then break + if all_clear: + break + + # if we are not clear then we sleep and try again + time.sleep(5) + + experiment = q.get() + + # if we are not loading from the cache then we note that we have called the host + cache_dir = "/tmp" + cache_file = os.path.join(cache_dir, __gen_cache_id(experiment) + ".pickle") + if not os.path.isfile(cache_file): + worker_lock.acquire() + try: + host_records[hostname].append(time.time()) + finally: + worker_lock.release() + + # record how many we have sent off for execution + worker_lock.acquire() + try: + total_sent += 1 + __print_status() + finally: + worker_lock.release() + + __run_remote_experiment(experiment, hostname, cache_dir=cache_dir, python_binary=python_binary) + + # record how many are finished + worker_lock.acquire() + try: + total_done += 1 + __print_status() + finally: + worker_lock.release() + + q.task_done() + +def __print_status(): + print("Benchmark task %d of %d done (%d failed, %d running)" % (total_done, nexperiments, total_failed, total_sent - total_done), end="\r") + sys.stdout.flush()
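As a usage note, the local driver above composes like this; the dataset and model names must be ones registered in the experiment lists above ("cric" and "gbm" are taken from _experiments):

# run every cric + gbm experiment across 4 worker processes, caching in /tmp
results = run_experiments(dataset="cric", model="gbm", nworkers=4)
for (dataset, model, method, metric), score in results:
    print(dataset, model, method, metric, score)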
+ +def run_remote_experiments(experiments, thread_hosts, rate_limit=10): + """ Use ssh to run the experiments on remote machines in parallel. + + Parameters + ---------- + experiments : iterable + Output of shap.benchmark.experiments(...). + + thread_hosts : list of strings + Each host has the format "host_name:path_to_python_binary" and can appear multiple times + in the list (one for each parallel execution you want on that machine). + + rate_limit : int + How many ssh connections we make per minute to each host (to avoid throttling issues). + """ + + global ssh_conn_per_min_limit + ssh_conn_per_min_limit = rate_limit + + # first we kill any remaining workers from previous runs + # note we don't check_call because pkill kills our ssh call as well + thread_hosts = copy.copy(thread_hosts) + random.shuffle(thread_hosts) + for host in set(thread_hosts): + hostname, _ = host.split(":") + try: + subprocess.run(["ssh", hostname, "pkill -f shap.benchmark.run_experiment"], timeout=15) + except subprocess.TimeoutExpired: + print("Failed to connect to", hostname, "after 15 seconds! Exiting.") + return + + experiments = copy.copy(list(experiments)) + random.shuffle(experiments) # this way all the hard experiments don't get put on one machine + global nexperiments, total_sent, total_done, total_failed, host_records + nexperiments = len(experiments) + total_sent = 0 + total_done = 0 + total_failed = 0 + host_records = {} + + q = Queue() + + for host in thread_hosts: + worker = Thread(target=__thread_worker, args=(q, host)) + worker.daemon = True + worker.start() + + for experiment in experiments: + q.put(experiment) + + q.join() + +def __run_remote_experiment(experiment, remote, cache_dir="/tmp", python_binary="python"): + global total_failed + dataset_name, model_name, method_name, metric_name = experiment + + # see if we have a cached version + cache_id = __gen_cache_id(experiment) + cache_file = os.path.join(cache_dir, cache_id + ".pickle") + if os.path.isfile(cache_file): + with open(cache_file, "rb") as f: + return pickle.load(f) + + # this is just so we don't dump everything at once on a machine + time.sleep(random.uniform(0, 5)) + + # run the benchmark on the remote machine + #start = time.time() + cmd = "CUDA_VISIBLE_DEVICES=\"\" "+python_binary+" -c \"import shap; shap.benchmark.run_experiment(['{}', '{}', '{}', '{}'], cache_dir='{}')\" &> {}/{}.output".format( + dataset_name, model_name, method_name, metric_name, cache_dir, cache_dir, cache_id + ) + try: + subprocess.check_output(["ssh", remote, cmd]) + except subprocess.CalledProcessError as e: + print("The following command failed on %s:" % remote, file=sys.stderr) + print(cmd, file=sys.stderr) + total_failed += 1 + print(e) + return + + # copy the results back + subprocess.check_output(["scp", remote+":"+cache_file, cache_file]) + + if os.path.isfile(cache_file): + with open(cache_file, "rb") as f: + #print(cache_id.replace("__", " ") + " ...loaded from remote after %f seconds" % (time.time() - start)) + return pickle.load(f) + else: + raise FileNotFoundError("Remote benchmark call finished but no local file was found!") + +def __gen_cache_id(experiment): + dataset_name, model_name, method_name, metric_name = experiment + return "v" + "__".join([__version__, dataset_name, model_name, method_name, metric_name])
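A sketch of driving the remote runner above; the host names here are hypothetical, and each "host_name:path_to_python_binary" entry may repeat to get more parallel slots on that machine, per the docstring:

hosts = [
    "gpu-box-1:/opt/conda/bin/python",   # two slots on gpu-box-1 (hypothetical host)
    "gpu-box-1:/opt/conda/bin/python",
    "gpu-box-2:/usr/bin/python3",        # one slot on gpu-box-2 (hypothetical host)
]
run_remote_experiments(experiments(dataset="corrgroups60"), hosts, rate_limit=10)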
diff --git a/lib/shap/benchmark/framework.py b/lib/shap/benchmark/framework.py new file mode 100644 index 0000000000000000000000000000000000000000..3fe1497514aaf10592228ad879f9a3241b5faf0d --- /dev/null +++ b/lib/shap/benchmark/framework.py @@ -0,0 +1,113 @@ +import itertools as it + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from . import perturbation + + +def update(model, attributions, X, y, masker, sort_order, perturbation_method, scores): + metric = perturbation_method + ' ' + sort_order + sp = perturbation.SequentialPerturbation(model, masker, sort_order, perturbation_method) + xs, ys, auc = sp.model_score(attributions, X, y=y) + scores['metrics'].append(metric) + scores['values'][metric] = [xs, ys, auc] + +def get_benchmark(model, attributions, X, y, masker, metrics): + # convert dataframes + if isinstance(X, (pd.Series, pd.DataFrame)): + X = X.values + if isinstance(masker, (pd.Series, pd.DataFrame)): + masker = masker.values + + # record scores per metric + scores = {'metrics': list(), 'values': dict()} + for sort_order, perturbation_method in list(it.product(metrics['sort_order'], metrics['perturbation'])): + update(model, attributions, X, y, masker, sort_order, perturbation_method, scores) + + return scores + +def get_metrics(benchmarks, selection): + # select metrics to plot using selection function + explainer_metrics = set() + for explainer in benchmarks: + scores = benchmarks[explainer] + if len(explainer_metrics) == 0: + explainer_metrics = set(scores['metrics']) + else: + explainer_metrics = selection(explainer_metrics, set(scores['metrics'])) + + return list(explainer_metrics) + +def trend_plot(benchmarks): + explainer_metrics = get_metrics(benchmarks, lambda x, y: x.union(y)) + + # plot all curves if metric exists + for metric in explainer_metrics: + plt.clf() + + for explainer in benchmarks: + scores = benchmarks[explainer] + if metric in scores['values']: + x, y, auc = scores['values'][metric] + plt.plot(x, y, label=f'{round(auc, 3)} - {explainer}') + + xlabel = 'Percent Unmasked' if 'keep' in metric else 'Percent Masked' + + plt.ylabel('Model Output') + plt.xlabel(xlabel) + plt.title(metric) + plt.legend() + plt.show() + +def compare_plot(benchmarks): + explainer_metrics = get_metrics(benchmarks, lambda x, y: x.intersection(y)) + explainers = list(benchmarks.keys()) + num_explainers = len(explainers) + num_metrics = len(explainer_metrics) + + # dummy start to evenly distribute explainers on the left + # can later be replaced by boolean metrics + aucs = dict() + for i in range(num_explainers): + explainer = explainers[i] + aucs[explainer] = [i/(num_explainers-1)] + + # normalize per metric + for metric in explainer_metrics: + max_auc, min_auc = -float('inf'), float('inf') + + for explainer in explainers: + scores = benchmarks[explainer] + _, _, auc = scores['values'][metric] + min_auc = min(auc, min_auc) + max_auc = max(auc, max_auc) + + for explainer in explainers: + scores = benchmarks[explainer] + _, _, auc = scores['values'][metric] + aucs[explainer].append((auc-min_auc)/(max_auc-min_auc)) + + # plot common curves + ax = plt.gca() + for explainer in explainers: + plt.plot(np.linspace(0, 1, len(explainer_metrics)+1), aucs[explainer], '--o') + + ax.tick_params(which='major', axis='both', labelsize=8) + + ax.set_yticks([i/(num_explainers-1) for i in range(0, num_explainers)]) + ax.set_yticklabels(explainers, rotation=0) + + ax.set_xticks(np.linspace(0, 1, num_metrics+1)) + ax.set_xticklabels([' '] + explainer_metrics, rotation=45, ha='right') + + plt.grid(which='major', axis='x', linestyle='--') + plt.tight_layout() + plt.ylabel('Relative Performance of Each Explanation Method') + plt.xlabel('Evaluation Metrics') + plt.title('Explanation Method Performance Across Metrics') + plt.show()
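A sketch of how get_benchmark and the two plots compose. The fitted model, the precomputed attribution arrays (perm_attrs, tree_attrs), and the reuse of X itself as the masker background are assumptions for illustration:

metrics = {"sort_order": ["positive", "negative"], "perturbation": ["keep", "remove"]}
benchmarks = {
    "permutation": get_benchmark(model.predict, perm_attrs, X, y, X, metrics),
    "tree": get_benchmark(model.predict, tree_attrs, X, y, X, metrics),
}
trend_plot(benchmarks)     # one curve per explainer for each metric
compare_plot(benchmarks)   # normalized AUCs across the shared metrics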
diff --git a/lib/shap/benchmark/measures.py b/lib/shap/benchmark/measures.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a0fe291efd5c1cfab4fb9d5542a670e6ef103a --- /dev/null +++ b/lib/shap/benchmark/measures.py @@ -0,0 +1,424 @@ +import warnings + +import numpy as np +import pandas as pd +import sklearn.utils +from tqdm.auto import tqdm + +_remove_cache = {} +def remove_retrain(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ The model is retrained for each test sample with the important features set to a constant. + + If you want to know how important a set of features is you can ask how the model would be + different if those features had never existed. To determine this we can mask those features + across the entire training and test datasets, then retrain the model. If we compare the + output of this retrained model to the original model we can see the effect produced by knowing + the features we masked. Since for individualized explanation methods each test sample has a + different set of most important features we need to retrain the model for every test sample + to get the change in model performance when a specified fraction of the most important features + are withheld. + """ + + warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") + + # see if we match the last cached call + global _remove_cache + args = (X_train, y_train, X_test, y_test, model_generator, metric) + cache_match = False + if "args" in _remove_cache: + if all(a is b for a,b in zip(_remove_cache["args"], args)) and np.all(_remove_cache["attr_test"] == attr_test): + cache_match = True + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # this is the model we will retrain many times + model_masked = model_generator() + + # mask nmask top features and re-train the model for each test explanation + X_train_tmp = np.zeros(X_train.shape) + X_test_tmp = np.zeros(X_test.shape) + yp_masked_test = np.zeros(y_test.shape) + tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6 + last_nmask = _remove_cache.get("nmask", None) + last_yp_masked_test = _remove_cache.get("yp_masked_test", None) + for i in tqdm(range(len(y_test)), "Retraining for the 'remove' metric"): + if cache_match and last_nmask[i] == nmask[i]: + yp_masked_test[i] = last_yp_masked_test[i] + elif nmask[i] == 0: + yp_masked_test[i] = trained_model.predict(X_test[i:i+1])[0] + else: + # mask out the most important features for this test instance + X_train_tmp[:] = X_train + X_test_tmp[:] = X_test + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + X_train_tmp[:,ordering[:nmask[i]]] = X_train[:,ordering[:nmask[i]]].mean() + X_test_tmp[i,ordering[:nmask[i]]] = X_train[:,ordering[:nmask[i]]].mean() + + # retrain the model and make a prediction + model_masked.fit(X_train_tmp, y_train) + yp_masked_test[i] = model_masked.predict(X_test_tmp[i:i+1])[0] + + # save our results so the next call to us can be faster when there is redundancy + _remove_cache["nmask"] = nmask + _remove_cache["yp_masked_test"] = yp_masked_test + _remove_cache["attr_test"] = attr_test + _remove_cache["args"] = args + + return metric(y_test, yp_masked_test)
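Every measure below reuses the same ordering idiom: add tiny fixed-seed noise to break attribution ties deterministically, argsort descending, then overwrite the top-nmask features with their training means. In isolation, with toy numbers (const_rand is defined at the bottom of this file):

import numpy as np

attr = np.array([0.3, 0.3, -0.1, 0.8])               # per-feature attributions
ordering = np.argsort(-attr + const_rand(4) * 1e-6)  # most important first, ties broken
x = np.array([1.0, 2.0, 3.0, 4.0])
train_mean = np.zeros(4)
nmask = 2
x[ordering[:nmask]] = train_mean[ordering[:nmask]]   # mask the two most important features
print(x)   # feature 3 and one of the tied features 0/1 are now 0.0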
+def remove_mask(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ Each test sample is masked by setting the important features to a constant. + """ + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # mask nmask top features for each test explanation + X_test_tmp = X_test.copy() + tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6 + mean_vals = X_train.mean(0) + for i in range(len(y_test)): + if nmask[i] > 0: + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + X_test_tmp[i,ordering[:nmask[i]]] = mean_vals[ordering[:nmask[i]]] + + yp_masked_test = trained_model.predict(X_test_tmp) + + return metric(y_test, yp_masked_test) + +def remove_impute(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ The model is reevaluated for each test sample with the important features set to an imputed value. + + Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on + being able to estimate the full data covariance matrix (and inverse) accurately. So X_train.shape[0] should + be significantly bigger than X_train.shape[1]. + """ + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # impute the nmask top features for each test explanation + C = np.cov(X_train.T) + C += np.eye(C.shape[0]) * 1e-6 + X_test_tmp = X_test.copy() + yp_masked_test = np.zeros(y_test.shape) + tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6 + mean_vals = X_train.mean(0) + for i in range(len(y_test)): + if nmask[i] > 0: + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + observe_inds = ordering[nmask[i]:] + impute_inds = ordering[:nmask[i]] + + # impute missing data assuming it follows a multivariate normal distribution + Coo_inv = np.linalg.inv(C[observe_inds,:][:,observe_inds]) + Cio = C[impute_inds,:][:,observe_inds] + impute = mean_vals[impute_inds] + Cio @ Coo_inv @ (X_test[i, observe_inds] - mean_vals[observe_inds]) + + X_test_tmp[i, impute_inds] = impute + + yp_masked_test = trained_model.predict(X_test_tmp) + + return metric(y_test, yp_masked_test)
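The imputation above is the standard conditional-Gaussian formula: for observed indices o and imputed indices i, it fills x_i with mu_i + C_io C_oo^{-1} (x_o - mu_o). A self-contained check on synthetic data (all names here are local to the example):

import numpy as np

rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.8, 0.0], [0.8, 1.0, 0.0], [0.0, 0.0, 1.0]])
X_train = rng.multivariate_normal(np.zeros(3), cov, size=5000)
C = np.cov(X_train.T) + np.eye(3) * 1e-6
mu = X_train.mean(0)

x = np.array([2.0, 0.0, -1.0])            # impute feature 1 from features 0 and 2
obs, imp = np.array([0, 2]), np.array([1])
Coo_inv = np.linalg.inv(C[np.ix_(obs, obs)])
Cio = C[np.ix_(imp, obs)]
x[imp] = mu[imp] + Cio @ Coo_inv @ (x[obs] - mu[obs])
print(x[1])   # approximately 0.8 * 2.0 = 1.6, since features 0 and 1 correlate at 0.8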
+ """ + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # how many samples to take + nsamples = 100 + + # keep nkeep top features for each test explanation + N,M = X_test.shape + X_test_tmp = np.tile(X_test, [1, nsamples]).reshape(nsamples * N, M) + tie_breaking_noise = const_rand(M) * 1e-6 + inds = sklearn.utils.resample(np.arange(N), n_samples=nsamples, random_state=random_state) + for i in range(N): + if nmask[i] > 0: + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + X_test_tmp[i*nsamples:(i+1)*nsamples, ordering[:nmask[i]]] = X_train[inds, :][:, ordering[:nmask[i]]] + + yp_masked_test = trained_model.predict(X_test_tmp) + yp_masked_test = np.reshape(yp_masked_test, (N, nsamples)).mean(1) # take the mean output over all samples + + return metric(y_test, yp_masked_test) + +def batch_remove_retrain(nmask_train, nmask_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric): + """ An approximation of holdout that only retraines the model once. + + This is also called ROAR (RemOve And Retrain) in work by Google. It is much more computationally + efficient that the holdout method because it masks the most important features in every sample + and then retrains the model once, instead of retraining the model for every test sample like + the holdout metric. + """ + + warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # mask nmask top features for each explanation + X_train_tmp = X_train.copy() + X_train_mean = X_train.mean(0) + tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6 + for i in range(len(y_train)): + if nmask_train[i] > 0: + ordering = np.argsort(-attr_train[i, :] + tie_breaking_noise) + X_train_tmp[i, ordering[:nmask_train[i]]] = X_train_mean[ordering[:nmask_train[i]]] + X_test_tmp = X_test.copy() + for i in range(len(y_test)): + if nmask_test[i] > 0: + ordering = np.argsort(-attr_test[i, :] + tie_breaking_noise) + X_test_tmp[i, ordering[:nmask_test[i]]] = X_train_mean[ordering[:nmask_test[i]]] + + # train the model with all the given features masked + model_masked = model_generator() + model_masked.fit(X_train_tmp, y_train) + yp_test_masked = model_masked.predict(X_test_tmp) + + return metric(y_test, yp_test_masked) + +_keep_cache = {} +def keep_retrain(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ The model is retrained for each test sample with the non-important features set to a constant. + + If you want to know how important a set of features is you can ask how the model would be + different if only those features had existed. To determine this we can mask the other features + across the entire training and test datasets, then retrain the model. If we apply compare the + output of this retrained model to the original model we can see the effect produced by only + knowning the important features. Since for individualized explanation methods each test sample + has a different set of most important features we need to retrain the model for every test sample + to get the change in model performance when a specified fraction of the most important features + are retained. 
+ """ + + warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") + + # see if we match the last cached call + global _keep_cache + args = (X_train, y_train, X_test, y_test, model_generator, metric) + cache_match = False + if "args" in _keep_cache: + if all(a is b for a,b in zip(_keep_cache["args"], args)) and np.all(_keep_cache["attr_test"] == attr_test): + cache_match = True + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # this is the model we will retrain many times + model_masked = model_generator() + + # keep nkeep top features and re-train the model for each test explanation + X_train_tmp = np.zeros(X_train.shape) + X_test_tmp = np.zeros(X_test.shape) + yp_masked_test = np.zeros(y_test.shape) + tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6 + last_nkeep = _keep_cache.get("nkeep", None) + last_yp_masked_test = _keep_cache.get("yp_masked_test", None) + for i in tqdm(range(len(y_test)), "Retraining for the 'keep' metric"): + if cache_match and last_nkeep[i] == nkeep[i]: + yp_masked_test[i] = last_yp_masked_test[i] + elif nkeep[i] == attr_test.shape[1]: + yp_masked_test[i] = trained_model.predict(X_test[i:i+1])[0] + else: + + # mask out the most important features for this test instance + X_train_tmp[:] = X_train + X_test_tmp[:] = X_test + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + X_train_tmp[:,ordering[nkeep[i]:]] = X_train[:,ordering[nkeep[i]:]].mean() + X_test_tmp[i,ordering[nkeep[i]:]] = X_train[:,ordering[nkeep[i]:]].mean() + + # retrain the model and make a prediction + model_masked.fit(X_train_tmp, y_train) + yp_masked_test[i] = model_masked.predict(X_test_tmp[i:i+1])[0] + + # save our results so the next call to us can be faster when there is redundancy + _keep_cache["nkeep"] = nkeep + _keep_cache["yp_masked_test"] = yp_masked_test + _keep_cache["attr_test"] = attr_test + _keep_cache["args"] = args + + return metric(y_test, yp_masked_test) + +def keep_mask(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ The model is reevaluated for each test sample with the non-important features set to their mean. + """ + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # keep nkeep top features for each test explanation + X_test_tmp = X_test.copy() + yp_masked_test = np.zeros(y_test.shape) + tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6 + mean_vals = X_train.mean(0) + for i in range(len(y_test)): + if nkeep[i] < X_test.shape[1]: + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + X_test_tmp[i,ordering[nkeep[i]:]] = mean_vals[ordering[nkeep[i]:]] + + yp_masked_test = trained_model.predict(X_test_tmp) + + return metric(y_test, yp_masked_test) + +def keep_impute(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ The model is reevaluated for each test sample with the non-important features set to an imputed value. + + Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on + being able to estimate the full data covariance matrix (and inverse) accuractly. So X_train.shape[0] should + be significantly bigger than X_train.shape[1]. 
+ """ + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # keep nkeep top features for each test explanation + C = np.cov(X_train.T) + C += np.eye(C.shape[0]) * 1e-6 + X_test_tmp = X_test.copy() + yp_masked_test = np.zeros(y_test.shape) + tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6 + mean_vals = X_train.mean(0) + for i in range(len(y_test)): + if nkeep[i] < X_test.shape[1]: + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + observe_inds = ordering[:nkeep[i]] + impute_inds = ordering[nkeep[i]:] + + # impute missing data assuming it follows a multivariate normal distribution + Coo_inv = np.linalg.inv(C[observe_inds,:][:,observe_inds]) + Cio = C[impute_inds,:][:,observe_inds] + impute = mean_vals[impute_inds] + Cio @ Coo_inv @ (X_test[i, observe_inds] - mean_vals[observe_inds]) + + X_test_tmp[i, impute_inds] = impute + + yp_masked_test = trained_model.predict(X_test_tmp) + + return metric(y_test, yp_masked_test) + +def keep_resample(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state): + """ The model is reevaluated for each test sample with the non-important features set to resample background values. + """ # why broken? overwriting? + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # how many samples to take + nsamples = 100 + + # keep nkeep top features for each test explanation + N,M = X_test.shape + X_test_tmp = np.tile(X_test, [1, nsamples]).reshape(nsamples * N, M) + tie_breaking_noise = const_rand(M) * 1e-6 + inds = sklearn.utils.resample(np.arange(N), n_samples=nsamples, random_state=random_state) + for i in range(N): + if nkeep[i] < M: + ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise) + X_test_tmp[i*nsamples:(i+1)*nsamples, ordering[nkeep[i]:]] = X_train[inds, :][:, ordering[nkeep[i]:]] + + yp_masked_test = trained_model.predict(X_test_tmp) + yp_masked_test = np.reshape(yp_masked_test, (N, nsamples)).mean(1) # take the mean output over all samples + + return metric(y_test, yp_masked_test) + +def batch_keep_retrain(nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric): + """ An approximation of keep that only retraines the model once. + + This is also called KAR (Keep And Retrain) in work by Google. It is much more computationally + efficient that the keep method because it masks the unimportant features in every sample + and then retrains the model once, instead of retraining the model for every test sample like + the keep metric. 
+ """ + + warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!") + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # mask nkeep top features for each explanation + X_train_tmp = X_train.copy() + X_train_mean = X_train.mean(0) + tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6 + for i in range(len(y_train)): + if nkeep_train[i] < X_train.shape[1]: + ordering = np.argsort(-attr_train[i, :] + tie_breaking_noise) + X_train_tmp[i, ordering[nkeep_train[i]:]] = X_train_mean[ordering[nkeep_train[i]:]] + X_test_tmp = X_test.copy() + for i in range(len(y_test)): + if nkeep_test[i] < X_test.shape[1]: + ordering = np.argsort(-attr_test[i, :] + tie_breaking_noise) + X_test_tmp[i, ordering[nkeep_test[i]:]] = X_train_mean[ordering[nkeep_test[i]:]] + + # train the model with all the features not given masked + model_masked = model_generator() + model_masked.fit(X_train_tmp, y_train) + yp_test_masked = model_masked.predict(X_test_tmp) + + return metric(y_test, yp_test_masked) + +def local_accuracy(X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model): + """ The how well do the features plus a constant base rate sum up to the model output. + """ + + X_train, X_test = to_array(X_train, X_test) + + # how many features to mask + assert X_train.shape[1] == X_test.shape[1] + + # keep nkeep top features and re-train the model for each test explanation + yp_test = trained_model.predict(X_test) + + return metric(yp_test, strip_list(attr_test).sum(1)) + +def to_array(*args): + return [a.values if isinstance(a, pd.DataFrame) else a for a in args] + +def const_rand(size, seed=23980): + """ Generate a random array with a fixed seed. + """ + old_seed = np.random.seed() + np.random.seed(seed) + out = np.random.rand(size) + np.random.seed(old_seed) + return out + +def const_shuffle(arr, seed=23980): + """ Shuffle an array in-place with a fixed seed. + """ + old_seed = np.random.seed() + np.random.seed(seed) + np.random.shuffle(arr) + np.random.seed(old_seed) + +def strip_list(attrs): + """ This assumes that if you have a list of outputs you just want the second one (the second class is the '1' class). + """ + if isinstance(attrs, list): + return attrs[1] + else: + return attrs diff --git a/lib/shap/benchmark/methods.py b/lib/shap/benchmark/methods.py new file mode 100644 index 0000000000000000000000000000000000000000..f52bd3fa9c6e171d21c277ee9e90d394802c90c6 --- /dev/null +++ b/lib/shap/benchmark/methods.py @@ -0,0 +1,148 @@ +import numpy as np +import sklearn + +from .. import ( + DeepExplainer, + GradientExplainer, + KernelExplainer, + LinearExplainer, + SamplingExplainer, + TreeExplainer, + kmeans, +) +from ..explainers import other +from .models import KerasWrap + + +def linear_shap_corr(model, data): + """ Linear SHAP (corr 1000) + """ + return LinearExplainer(model, data, feature_dependence="correlation", nsamples=1000).shap_values + +def linear_shap_ind(model, data): + """ Linear SHAP (ind) + """ + return LinearExplainer(model, data, feature_dependence="independent").shap_values + +def coef(model, data): + """ Coefficients + """ + return other.CoefficentExplainer(model).attributions + +def random(model, data): + """ Random + color = #777777 + linestyle = solid + """ + return other.RandomExplainer().attributions + +def kernel_shap_1000_meanref(model, data): + """ Kernel SHAP 1000 mean ref. 
+ color = red_blue_circle(0.5) + linestyle = solid + """ + return lambda X: KernelExplainer(model.predict, kmeans(data, 1)).shap_values(X, nsamples=1000, l1_reg=0) + +def sampling_shap_1000(model, data): + """ IME 1000 + color = red_blue_circle(0.5) + linestyle = dashed + """ + return lambda X: SamplingExplainer(model.predict, data).shap_values(X, nsamples=1000) + +def tree_shap_tree_path_dependent(model, data): + """ TreeExplainer + color = red_blue_circle(0) + linestyle = solid + """ + return TreeExplainer(model, feature_dependence="tree_path_dependent").shap_values + +def tree_shap_independent_200(model, data): + """ TreeExplainer (independent) + color = red_blue_circle(0) + linestyle = dashed + """ + data_subsample = sklearn.utils.resample(data, replace=False, n_samples=min(200, data.shape[0]), random_state=0) + return TreeExplainer(model, data_subsample, feature_dependence="independent").shap_values + +def mean_abs_tree_shap(model, data): + """ mean(|TreeExplainer|) + color = red_blue_circle(0.25) + linestyle = solid + """ + def f(X): + v = TreeExplainer(model).shap_values(X) + if isinstance(v, list): + return [np.tile(np.abs(sv).mean(0), (X.shape[0], 1)) for sv in v] + else: + return np.tile(np.abs(v).mean(0), (X.shape[0], 1)) + return f + +def saabas(model, data): + """ Saabas + color = red_blue_circle(0) + linestyle = dotted + """ + return lambda X: TreeExplainer(model).shap_values(X, approximate=True) + +def tree_gain(model, data): + """ Gain/Gini Importance + color = red_blue_circle(0.25) + linestyle = dotted + """ + return other.TreeGainExplainer(model).attributions + +def lime_tabular_regression_1000(model, data): + """ LIME Tabular 1000 + color = red_blue_circle(0.75) + """ + return lambda X: other.LimeTabularExplainer(model.predict, data, mode="regression").attributions(X, nsamples=1000) + +def lime_tabular_classification_1000(model, data): + """ LIME Tabular 1000 + color = red_blue_circle(0.75) + """ + return lambda X: other.LimeTabularExplainer(model.predict_proba, data, mode="classification").attributions(X, nsamples=1000)[1] + +def maple(model, data): + """ MAPLE + color = red_blue_circle(0.6) + """ + return lambda X: other.MapleExplainer(model.predict, data).attributions(X, multiply_by_input=False) + +def tree_maple(model, data): + """ Tree MAPLE + color = red_blue_circle(0.6) + linestyle = dashed + """ + return lambda X: other.TreeMapleExplainer(model, data).attributions(X, multiply_by_input=False) + +def deep_shap(model, data): + """ Deep SHAP (DeepLIFT) + """ + if isinstance(model, KerasWrap): + model = model.model + explainer = DeepExplainer(model, kmeans(data, 1).data) + def f(X): + phi = explainer.shap_values(X) + if isinstance(phi, list) and len(phi) == 1: + return phi[0] + else: + return phi + + return f + +def expected_gradients(model, data): + """ Expected Gradients + """ + if isinstance(model, KerasWrap): + model = model.model + explainer = GradientExplainer(model, data) + def f(X): + phi = explainer.shap_values(X) + if isinstance(phi, list) and len(phi) == 1: + return phi[0] + else: + return phi + + return f diff --git a/lib/shap/benchmark/metrics.py b/lib/shap/benchmark/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..1ff8db7a31d236c4b62e2d90ad067fa70b484b09 --- /dev/null +++ b/lib/shap/benchmark/metrics.py @@ -0,0 +1,824 @@ +import hashlib +import os +import time + +import numpy as np +import sklearn + +from .. import __version__ +from . 
import measures, methods + +try: + import dill as pickle +except Exception: + pass + +try: + from sklearn.model_selection import train_test_split +except Exception: + from sklearn.cross_validation import train_test_split + + +def runtime(X, y, model_generator, method_name): + """ Runtime (sec / 1k samples) + transform = "negate_log" + sort_order = 2 + """ + + old_state = np.random.get_state() + np.random.seed(3293) + + # average the method scores over several train/test splits + method_reps = [] + for i in range(3): + X_train, X_test, y_train, _ = train_test_split(__toarray(X), y, test_size=100, random_state=i) + + # define the model we are going to explain + model = model_generator() + model.fit(X_train, y_train) + + # evaluate each method + start = time.time() + explainer = getattr(methods, method_name)(model, X_train) + build_time = time.time() - start + + start = time.time() + explainer(X_test) + explain_time = time.time() - start + + # we always normalize the explain time as though we were explaining 1000 samples + # even if to reduce the runtime of the benchmark we do less (like just 100) + method_reps.append(build_time + explain_time * 1000.0 / X_test.shape[0]) + np.random.set_state(old_state) + + return None, np.mean(method_reps) + +def local_accuracy(X, y, model_generator, method_name): + """ Local Accuracy + transform = "identity" + sort_order = 0 + """ + + def score_map(true, pred): + """ Computes local accuracy as the normalized standard deviation of numerical scores. + """ + return np.std(pred - true) / (np.std(true) + 1e-6) + + def score_function(X_train, X_test, y_train, y_test, attr_function, trained_model, random_state): + return measures.local_accuracy( + X_train, y_train, X_test, y_test, attr_function(X_test), + model_generator, score_map, trained_model + ) + return None, __score_method(X, y, None, model_generator, score_function, method_name) + +def consistency_guarantees(X, y, model_generator, method_name): + """ Consistency Guarantees + transform = "identity" + sort_order = 1 + """ + + # 1.0 - perfect consistency + # 0.8 - guarantees depend on sampling + # 0.6 - guarantees depend on approximation + # 0.0 - no guarantees + guarantees = { + "linear_shap_corr": 1.0, + "linear_shap_ind": 1.0, + "coef": 0.0, + "kernel_shap_1000_meanref": 0.8, + "sampling_shap_1000": 0.8, + "random": 0.0, + "saabas": 0.0, + "tree_gain": 0.0, + "tree_shap_tree_path_dependent": 1.0, + "tree_shap_independent_200": 1.0, + "mean_abs_tree_shap": 1.0, + "lime_tabular_regression_1000": 0.8, + "lime_tabular_classification_1000": 0.8, + "maple": 0.8, + "tree_maple": 0.8, + "deep_shap": 0.6, + "expected_gradients": 0.6 + } + + return None, guarantees[method_name] + +def __mean_pred(true, pred): + """ A trivial metric that is just the mean of the model output.
+ """ + return np.mean(pred) + +def keep_positive_mask(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Positive (mask) + xlabel = "Max fraction of features kept" + ylabel = "Mean model output" + transform = "identity" + sort_order = 4 + """ + return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def keep_negative_mask(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Negative (mask) + xlabel = "Max fraction of features kept" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 5 + """ + return __run_measure(measures.keep_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def keep_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Absolute (mask) + xlabel = "Max fraction of features kept" + ylabel = "R^2" + transform = "identity" + sort_order = 6 + """ + return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) + +def keep_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Absolute (mask) + xlabel = "Max fraction of features kept" + ylabel = "ROC AUC" + transform = "identity" + sort_order = 6 + """ + return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) + +def remove_positive_mask(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Positive (mask) + xlabel = "Max fraction of features removed" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 7 + """ + return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def remove_negative_mask(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Negative (mask) + xlabel = "Max fraction of features removed" + ylabel = "Mean model output" + transform = "identity" + sort_order = 8 + """ + return __run_measure(measures.remove_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def remove_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Absolute (mask) + xlabel = "Max fraction of features removed" + ylabel = "1 - R^2" + transform = "one_minus" + sort_order = 9 + """ + return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) + +def remove_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Absolute (mask) + xlabel = "Max fraction of features removed" + ylabel = "1 - ROC AUC" + transform = "one_minus" + sort_order = 9 + """ + return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) + +def keep_positive_resample(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Positive (resample) + xlabel = "Max fraction of features kept" + ylabel = "Mean model output" + transform = "identity" + sort_order = 10 + """ + return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def keep_negative_resample(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Negative (resample) + xlabel = "Max fraction of features kept" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 11 + """ + return __run_measure(measures.keep_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def 
keep_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Absolute (resample) + xlabel = "Max fraction of features kept" + ylabel = "R^2" + transform = "identity" + sort_order = 12 + """ + return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) + +def keep_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Absolute (resample) + xlabel = "Max fraction of features kept" + ylabel = "ROC AUC" + transform = "identity" + sort_order = 12 + """ + return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) + +def remove_positive_resample(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Positive (resample) + xlabel = "Max fraction of features removed" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 13 + """ + return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def remove_negative_resample(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Negative (resample) + xlabel = "Max fraction of features removed" + ylabel = "Mean model output" + transform = "identity" + sort_order = 14 + """ + return __run_measure(measures.remove_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def remove_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Absolute (resample) + xlabel = "Max fraction of features removed" + ylabel = "1 - R^2" + transform = "one_minus" + sort_order = 15 + """ + return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) + +def remove_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Absolute (resample) + xlabel = "Max fraction of features removed" + ylabel = "1 - ROC AUC" + transform = "one_minus" + sort_order = 15 + """ + return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) + +def keep_positive_impute(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Positive (impute) + xlabel = "Max fraction of features kept" + ylabel = "Mean model output" + transform = "identity" + sort_order = 16 + """ + return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def keep_negative_impute(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Negative (impute) + xlabel = "Max fraction of features kept" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 17 + """ + return __run_measure(measures.keep_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def keep_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Absolute (impute) + xlabel = "Max fraction of features kept" + ylabel = "R^2" + transform = "identity" + sort_order = 18 + """ + return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) + +def keep_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Absolute (impute) + xlabel = "Max fraction of features kept" + ylabel = "ROC AUC" + transform = "identity" + sort_order = 19 + """ + return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 0, num_fcounts,
sklearn.metrics.roc_auc_score) + +def remove_positive_impute(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Positive (impute) + xlabel = "Max fraction of features removed" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 7 + """ + return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def remove_negative_impute(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Negative (impute) + xlabel = "Max fraction of features removed" + ylabel = "Mean model output" + transform = "identity" + sort_order = 8 + """ + return __run_measure(measures.remove_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def remove_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Absolute (impute) + xlabel = "Max fraction of features removed" + ylabel = "1 - R^2" + transform = "one_minus" + sort_order = 9 + """ + return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score) + +def remove_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Absolute (impute) + xlabel = "Max fraction of features removed" + ylabel = "1 - ROC AUC" + transform = "one_minus" + sort_order = 9 + """ + return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score) + +def keep_positive_retrain(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Positive (retrain) + xlabel = "Max fraction of features kept" + ylabel = "Mean model output" + transform = "identity" + sort_order = 6 + """ + return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def keep_negative_retrain(X, y, model_generator, method_name, num_fcounts=11): + """ Keep Negative (retrain) + xlabel = "Max fraction of features kept" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 7 + """ + return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def remove_positive_retrain(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Positive (retrain) + xlabel = "Max fraction of features removed" + ylabel = "Negative mean model output" + transform = "negate" + sort_order = 11 + """ + return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred) + +def remove_negative_retrain(X, y, model_generator, method_name, num_fcounts=11): + """ Remove Negative (retrain) + xlabel = "Max fraction of features removed" + ylabel = "Mean model output" + transform = "identity" + sort_order = 12 + """ + return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred) + +def __run_measure(measure, X, y, model_generator, method_name, attribution_sign, num_fcounts, summary_function): + + def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trained_model, random_state): + if attribution_sign == 0: + A = np.abs(__strip_list(attr_function(X_test))) + else: + A = attribution_sign * __strip_list(attr_function(X_test)) + nmask = np.ones(len(y_test)) * fcount + nmask = np.minimum(nmask, np.array(A >= 0).sum(1)).astype(int) + return measure( + nmask, X_train, y_train, X_test, y_test, A, + model_generator, summary_function, trained_model, random_state + ) + fcounts = __intlogspace(0, X.shape[1], num_fcounts) +
return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name) + +def batch_remove_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Batch Remove Absolute (retrain) + xlabel = "Fraction of features removed" + ylabel = "1 - R^2" + transform = "one_minus" + sort_order = 13 + """ + return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts) + +def batch_keep_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11): + """ Batch Keep Absolute (retrain) + xlabel = "Fraction of features kept" + ylabel = "R^2" + transform = "identity" + sort_order = 13 + """ + return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts) + +def batch_remove_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Batch Remove Absolute (retrain) + xlabel = "Fraction of features removed" + ylabel = "1 - ROC AUC" + transform = "one_minus" + sort_order = 13 + """ + return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts) + +def batch_keep_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11): + """ Batch Keep Absolute (retrain) + xlabel = "Fraction of features kept" + ylabel = "ROC AUC" + transform = "identity" + sort_order = 13 + """ + return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts) + +def __run_batch_abs_metric(metric, X, y, model_generator, method_name, loss, num_fcounts): + def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trained_model): + A_train = np.abs(__strip_list(attr_function(X_train))) + nkeep_train = (np.ones(len(y_train)) * fcount).astype(int) + #nkeep_train = np.minimum(nkeep_train, np.array(A_train > 0).sum(1)).astype(int) + A_test = np.abs(__strip_list(attr_function(X_test))) + nkeep_test = (np.ones(len(y_test)) * fcount).astype(int) + #nkeep_test = np.minimum(nkeep_test, np.array(A_test >= 0).sum(1)).astype(int) + return metric( + nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, A_train, A_test, + model_generator, loss + ) + fcounts = __intlogspace(0, X.shape[1], num_fcounts) + return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name) + +_attribution_cache = {} +def __score_method(X, y, fcounts, model_generator, score_function, method_name, nreps=10, test_size=100, cache_dir="/tmp"): + """ Test an explanation method. 
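+
+    Trains the model from model_generator on nreps train/test splits (fitted
+    models are cached in cache_dir), computes attributions for each split with
+    the named method (also cached), evaluates score_function at every feature
+    count in fcounts (or once if fcounts is None), and averages over the splits.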
+ """ + + try: + pickle + except NameError: + raise ImportError("The 'dill' package could not be loaded and is needed for the benchmark!") + + old_seed = np.random.seed() + np.random.seed(3293) + + # average the method scores over several train/test splits + method_reps = [] + + data_hash = hashlib.sha256(__toarray(X).flatten()).hexdigest() + hashlib.sha256(__toarray(y)).hexdigest() + for i in range(nreps): + X_train, X_test, y_train, y_test = train_test_split(__toarray(X), y, test_size=test_size, random_state=i) + + # define the model we are going to explain, caching so we onlu build it once + model_id = "model_cache__v" + "__".join([__version__, data_hash, model_generator.__name__])+".pickle" + cache_file = os.path.join(cache_dir, model_id + ".pickle") + if os.path.isfile(cache_file): + with open(cache_file, "rb") as f: + model = pickle.load(f) + else: + model = model_generator() + model.fit(X_train, y_train) + with open(cache_file, "wb") as f: + pickle.dump(model, f) + + attr_key = "_".join([model_generator.__name__, method_name, str(test_size), str(nreps), str(i), data_hash]) + def score(attr_function): + def cached_attr_function(X_inner): + if attr_key not in _attribution_cache: + _attribution_cache[attr_key] = attr_function(X_inner) + return _attribution_cache[attr_key] + + #cached_attr_function = lambda X: __check_cache(attr_function, X) + if fcounts is None: + return score_function(X_train, X_test, y_train, y_test, cached_attr_function, model, i) + else: + scores = [] + for f in fcounts: + scores.append(score_function(f, X_train, X_test, y_train, y_test, cached_attr_function, model, i)) + return np.array(scores) + + # evaluate the method (only building the attribution function if we need to) + if attr_key not in _attribution_cache: + method_reps.append(score(getattr(methods, method_name)(model, X_train))) + else: + method_reps.append(score(None)) + + np.random.seed(old_seed) + return np.array(method_reps).mean(0) + + +# used to memoize explainer functions so we don't waste time re-explaining the same object +__cache0 = None +__cache_X0 = None +__cache_f0 = None +__cache1 = None +__cache_X1 = None +__cache_f1 = None +def __check_cache(f, X): + global __cache0, __cache_X0, __cache_f0 + global __cache1, __cache_X1, __cache_f1 + if X is __cache_X0 and f is __cache_f0: + return __cache0 + elif X is __cache_X1 and f is __cache_f1: + return __cache1 + else: + __cache_f1 = __cache_f0 + __cache_X1 = __cache_X0 + __cache1 = __cache0 + __cache_f0 = f + __cache_X0 = X + __cache0 = f(X) + return __cache0 + +def __intlogspace(start, end, count): + return np.unique(np.round(start + (end-start) * (np.logspace(0, 1, count, endpoint=True) - 1) / 9).astype(int)) + +def __toarray(X): + """ Converts DataFrames to numpy arrays. + """ + if hasattr(X, "values"): + X = X.values + return X + +def __strip_list(attrs): + """ This assumes that if you have a list of outputs you just want the second one (the second class). 
+ """ + if isinstance(attrs, list): + return attrs[1] + else: + return attrs + +def _fit_human(model_generator, val00, val01, val11): + # force the model to fit a function with almost entirely zero background + N = 1000000 + M = 3 + X = np.zeros((N,M)) + X.shape + y = np.ones(N) * val00 + X[0:1000, 0] = 1 + y[0:1000] = val01 + for i in range(0,1000000,1000): + X[i, 1] = 1 + y[i] = val01 + y[0] = val11 + model = model_generator() + model.fit(X, y) + return model + +def _human_and(X, model_generator, method_name, fever, cough): + assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!" + + # these are from the sickness_score mturk user study experiment + X_test = np.zeros((100,3)) + if not fever and not cough: + human_consensus = np.array([0., 0., 0.]) + X_test[0,:] = np.array([[0., 0., 1.]]) + elif not fever and cough: + human_consensus = np.array([0., 2., 0.]) + X_test[0,:] = np.array([[0., 1., 1.]]) + elif fever and cough: + human_consensus = np.array([5., 5., 0.]) + X_test[0,:] = np.array([[1., 1., 1.]]) + + # force the model to fit an XOR function with almost entirely zero background + model = _fit_human(model_generator, 0, 2, 10) + + attr_function = getattr(methods, method_name)(model, X) + methods_attrs = attr_function(X_test) + return "human", (human_consensus, methods_attrs[0,:]) + +def human_and_00(X, y, model_generator, method_name): + """ AND (false/false) + + This tests how well a feature attribution method agrees with human intuition + for an AND operation combined with linear effects. This metric deals + specifically with the question of credit allocation for the following function + when all three inputs are true: + if fever: +2 points + if cough: +2 points + if fever and cough: +6 points + + transform = "identity" + sort_order = 0 + """ + return _human_and(X, model_generator, method_name, False, False) + +def human_and_01(X, y, model_generator, method_name): + """ AND (false/true) + + This tests how well a feature attribution method agrees with human intuition + for an AND operation combined with linear effects. This metric deals + specifically with the question of credit allocation for the following function + when all three inputs are true: + if fever: +2 points + if cough: +2 points + if fever and cough: +6 points + + transform = "identity" + sort_order = 1 + """ + return _human_and(X, model_generator, method_name, False, True) + +def human_and_11(X, y, model_generator, method_name): + """ AND (true/true) + + This tests how well a feature attribution method agrees with human intuition + for an AND operation combined with linear effects. This metric deals + specifically with the question of credit allocation for the following function + when all three inputs are true: + if fever: +2 points + if cough: +2 points + if fever and cough: +6 points + + transform = "identity" + sort_order = 2 + """ + return _human_and(X, model_generator, method_name, True, True) + + +def _human_or(X, model_generator, method_name, fever, cough): + assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!" 
+
+    # these are from the sickness_score mturk user study experiment
+    X_test = np.zeros((100,3))
+    if not fever and not cough:
+        human_consensus = np.array([0., 0., 0.])
+        X_test[0,:] = np.array([[0., 0., 1.]])
+    elif not fever and cough:
+        human_consensus = np.array([0., 8., 0.])
+        X_test[0,:] = np.array([[0., 1., 1.]])
+    elif fever and cough:
+        human_consensus = np.array([5., 5., 0.])
+        X_test[0,:] = np.array([[1., 1., 1.]])
+
+    # force the model to fit an OR function with almost entirely zero background
+    model = _fit_human(model_generator, 0, 8, 10)
+
+    attr_function = getattr(methods, method_name)(model, X)
+    methods_attrs = attr_function(X_test)
+    return "human", (human_consensus, methods_attrs[0,:])
+
+def human_or_00(X, y, model_generator, method_name):
+    """ OR (false/false)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for an OR operation combined with linear effects. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever and cough are both false:
+    if fever: +2 points
+    if cough: +2 points
+    if fever or cough: +6 points
+
+    transform = "identity"
+    sort_order = 0
+    """
+    return _human_or(X, model_generator, method_name, False, False)
+
+def human_or_01(X, y, model_generator, method_name):
+    """ OR (false/true)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for an OR operation combined with linear effects. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever is false and cough is true:
+    if fever: +2 points
+    if cough: +2 points
+    if fever or cough: +6 points
+
+    transform = "identity"
+    sort_order = 1
+    """
+    return _human_or(X, model_generator, method_name, False, True)
+
+def human_or_11(X, y, model_generator, method_name):
+    """ OR (true/true)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for an OR operation combined with linear effects. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever and cough are both true:
+    if fever: +2 points
+    if cough: +2 points
+    if fever or cough: +6 points
+
+    transform = "identity"
+    sort_order = 2
+    """
+    return _human_or(X, model_generator, method_name, True, True)
+
+
+def _human_xor(X, model_generator, method_name, fever, cough):
+    assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"
+
+    # these are from the sickness_score mturk user study experiment
+    X_test = np.zeros((100,3))
+    if not fever and not cough:
+        human_consensus = np.array([0., 0., 0.])
+        X_test[0,:] = np.array([[0., 0., 1.]])
+    elif not fever and cough:
+        human_consensus = np.array([0., 8., 0.])
+        X_test[0,:] = np.array([[0., 1., 1.]])
+    elif fever and cough:
+        human_consensus = np.array([2., 2., 0.])
+        X_test[0,:] = np.array([[1., 1., 1.]])
+
+    # force the model to fit an XOR function with almost entirely zero background
+    model = _fit_human(model_generator, 0, 8, 4)
+
+    attr_function = getattr(methods, method_name)(model, X)
+    methods_attrs = attr_function(X_test)
+    return "human", (human_consensus, methods_attrs[0,:])
+
+def human_xor_00(X, y, model_generator, method_name):
+    """ XOR (false/false)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for an eXclusive OR operation combined with linear effects.
This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever and cough are both false:
+    if fever: +2 points
+    if cough: +2 points
+    if fever or cough but not both: +6 points
+
+    transform = "identity"
+    sort_order = 3
+    """
+    return _human_xor(X, model_generator, method_name, False, False)
+
+def human_xor_01(X, y, model_generator, method_name):
+    """ XOR (false/true)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for an eXclusive OR operation combined with linear effects. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever is false and cough is true:
+    if fever: +2 points
+    if cough: +2 points
+    if fever or cough but not both: +6 points
+
+    transform = "identity"
+    sort_order = 4
+    """
+    return _human_xor(X, model_generator, method_name, False, True)
+
+def human_xor_11(X, y, model_generator, method_name):
+    """ XOR (true/true)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for an eXclusive OR operation combined with linear effects. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever and cough are both true:
+    if fever: +2 points
+    if cough: +2 points
+    if fever or cough but not both: +6 points
+
+    transform = "identity"
+    sort_order = 5
+    """
+    return _human_xor(X, model_generator, method_name, True, True)
+
+
+def _human_sum(X, model_generator, method_name, fever, cough):
+    assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"
+
+    # these are from the sickness_score mturk user study experiment
+    X_test = np.zeros((100,3))
+    if not fever and not cough:
+        human_consensus = np.array([0., 0., 0.])
+        X_test[0,:] = np.array([[0., 0., 1.]])
+    elif not fever and cough:
+        human_consensus = np.array([0., 2., 0.])
+        X_test[0,:] = np.array([[0., 1., 1.]])
+    elif fever and cough:
+        human_consensus = np.array([2., 2., 0.])
+        X_test[0,:] = np.array([[1., 1., 1.]])
+
+    # force the model to fit a SUM function with almost entirely zero background
+    model = _fit_human(model_generator, 0, 2, 4)
+
+    attr_function = getattr(methods, method_name)(model, X)
+    methods_attrs = attr_function(X_test)
+    return "human", (human_consensus, methods_attrs[0,:])
+
+def human_sum_00(X, y, model_generator, method_name):
+    """ SUM (false/false)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for a SUM operation. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever and cough are both false:
+    if fever: +2 points
+    if cough: +2 points
+
+    transform = "identity"
+    sort_order = 0
+    """
+    return _human_sum(X, model_generator, method_name, False, False)
+
+def human_sum_01(X, y, model_generator, method_name):
+    """ SUM (false/true)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for a SUM operation. This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever is false and cough is true:
+    if fever: +2 points
+    if cough: +2 points
+
+    transform = "identity"
+    sort_order = 1
+    """
+    return _human_sum(X, model_generator, method_name, False, True)
+
+def human_sum_11(X, y, model_generator, method_name):
+    """ SUM (true/true)
+
+    This tests how well a feature attribution method agrees with human intuition
+    for a SUM operation.
This metric deals
+    specifically with the question of credit allocation for the following function
+    when fever and cough are both true:
+    if fever: +2 points
+    if cough: +2 points
+
+    transform = "identity"
+    sort_order = 2
+    """
+    return _human_sum(X, model_generator, method_name, True, True)
diff --git a/lib/shap/benchmark/models.py b/lib/shap/benchmark/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c0a886f49712b9ffcb5120d170a5a568d238999
--- /dev/null
+++ b/lib/shap/benchmark/models.py
@@ -0,0 +1,230 @@
+import numpy as np
+import sklearn
+import sklearn.ensemble
+import sklearn.linear_model
+import sklearn.tree
+from sklearn.preprocessing import StandardScaler
+
+
+class KerasWrap:
+    """ A wrapper that allows us to set parameters in the constructor and do a reset before fitting.
+    """
+    def __init__(self, model, epochs, flatten_output=False):
+        self.model = model
+        self.epochs = epochs
+        self.flatten_output = flatten_output
+        self.init_weights = None
+        self.scaler = StandardScaler()
+
+    def fit(self, X, y, verbose=0):
+        if self.init_weights is None:
+            self.init_weights = self.model.get_weights()
+        else:
+            self.model.set_weights(self.init_weights)
+        # train on the same standardized inputs that predict() will see
+        X = self.scaler.fit_transform(X)
+        return self.model.fit(X, y, epochs=self.epochs, verbose=verbose)
+
+    def predict(self, X):
+        X = self.scaler.transform(X)
+        if self.flatten_output:
+            return self.model.predict(X).flatten()
+        else:
+            return self.model.predict(X)
+
+
+# These models are all tuned for the corrgroups60 dataset
+
+def corrgroups60__lasso():
+    """ Lasso Regression
+    """
+    return sklearn.linear_model.Lasso(alpha=0.1)
+
+def corrgroups60__ridge():
+    """ Ridge Regression
+    """
+    return sklearn.linear_model.Ridge(alpha=1.0)
+
+def corrgroups60__decision_tree():
+    """ Decision Tree
+    """
+
+    # max_depth was chosen to minimise test error
+    return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=6)
+
+def corrgroups60__random_forest():
+    """ Random Forest
+    """
+    return sklearn.ensemble.RandomForestRegressor(100, random_state=0)
+
+def corrgroups60__gbm():
+    """ Gradient Boosted Trees
+    """
+    import xgboost
+
+    # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split
+    return xgboost.XGBRegressor(max_depth=6, n_estimators=50, learning_rate=0.1, n_jobs=8, random_state=0)
+
+def corrgroups60__ffnn():
+    """ 4-Layer Neural Network
+    """
+    from tensorflow.keras.layers import Dense
+    from tensorflow.keras.models import Sequential
+
+    model = Sequential()
+    model.add(Dense(32, activation='relu', input_dim=60))
+    model.add(Dense(20, activation='relu'))
+    model.add(Dense(20, activation='relu'))
+    model.add(Dense(1))
+
+    model.compile(optimizer='adam',
+                  loss='mean_squared_error',
+                  metrics=['mean_squared_error'])
+
+    return KerasWrap(model, 30, flatten_output=True)
+
+
+def independentlinear60__lasso():
+    """ Lasso Regression
+    """
+    return sklearn.linear_model.Lasso(alpha=0.1)
+
+def independentlinear60__ridge():
+    """ Ridge Regression
+    """
+    return sklearn.linear_model.Ridge(alpha=1.0)
+
+def independentlinear60__decision_tree():
+    """ Decision Tree
+    """
+
+    # max_depth was chosen to minimise test error
+    return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=4)
+
+def independentlinear60__random_forest():
+    """ Random Forest
+    """
+    return sklearn.ensemble.RandomForestRegressor(100, random_state=0)
+
+def independentlinear60__gbm():
+    """ Gradient Boosted Trees
+    """
+    import xgboost
+
+    # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split
+    return xgboost.XGBRegressor(max_depth=6, n_estimators=100, learning_rate=0.1, n_jobs=8, random_state=0)
+
+def independentlinear60__ffnn():
+    """ 4-Layer Neural Network
+    """
+    from tensorflow.keras.layers import Dense
+    from tensorflow.keras.models import Sequential
+
+    model = Sequential()
+    model.add(Dense(32, activation='relu', input_dim=60))
+    model.add(Dense(20, activation='relu'))
+    model.add(Dense(20, activation='relu'))
+    model.add(Dense(1))
+
+    model.compile(optimizer='adam',
+                  loss='mean_squared_error',
+                  metrics=['mean_squared_error'])
+
+    return KerasWrap(model, 30, flatten_output=True)
+
+
+def cric__lasso():
+    """ Lasso Regression
+    """
+    # liblinear is needed to support the l1 penalty
+    model = sklearn.linear_model.LogisticRegression(penalty="l1", C=0.002, solver="liblinear")
+
+    # we want to explain the raw probability outputs of the model
+    model.predict = lambda X: model.predict_proba(X)[:,1]
+
+    return model
+
+def cric__ridge():
+    """ Ridge Regression
+    """
+    model = sklearn.linear_model.LogisticRegression(penalty="l2")
+
+    # we want to explain the raw probability outputs of the model
+    model.predict = lambda X: model.predict_proba(X)[:,1]
+
+    return model
+
+def cric__decision_tree():
+    """ Decision Tree
+    """
+    model = sklearn.tree.DecisionTreeClassifier(random_state=0, max_depth=4)
+
+    # we want to explain the raw probability outputs of the tree
+    model.predict = lambda X: model.predict_proba(X)[:,1]
+
+    return model
+
+def cric__random_forest():
+    """ Random Forest
+    """
+    model = sklearn.ensemble.RandomForestClassifier(100, random_state=0)
+
+    # we want to explain the raw probability outputs of the trees
+    model.predict = lambda X: model.predict_proba(X)[:,1]
+
+    return model
+
+def cric__gbm():
+    """ Gradient Boosted Trees
+    """
+    import xgboost
+
+    # max_depth and subsample match the params used for the full cric data in the paper
+    # learning_rate was set a bit higher to allow for faster runtimes
+    # n_estimators was chosen based on a train/test split of the data
+    model = xgboost.XGBClassifier(max_depth=5, n_estimators=400, learning_rate=0.01, subsample=0.2, n_jobs=8, random_state=0)
+
+    # we want to explain the margin, not the transformed probability outputs
+    model.__orig_predict = model.predict
+    model.predict = lambda X: model.__orig_predict(X, output_margin=True)
+
+    return model
+
+def cric__ffnn():
+    """ 4-Layer Neural Network
+    """
+    from tensorflow.keras.layers import Dense, Dropout
+    from tensorflow.keras.models import Sequential
+
+    model = Sequential()
+    model.add(Dense(10, activation='relu', input_dim=336))
+    model.add(Dropout(0.5))
+    model.add(Dense(10, activation='relu'))
+    model.add(Dropout(0.5))
+    model.add(Dense(1, activation='sigmoid'))
+
+    model.compile(optimizer='adam',
+                  loss='binary_crossentropy',
+                  metrics=['accuracy'])
+
+    return KerasWrap(model, 30, flatten_output=True)
+
+
+def human__decision_tree():
+    """ Decision Tree
+    """
+
+    # build data
+    N = 1000000
+    M = 3
+    X = np.zeros((N,M))
+    y = np.zeros(N)
+    X[0, 0] = 1
+    y[0] = 8
+    X[1, 1] = 1
+    y[1] = 8
+    X[2, 0:2] = 1
+    y[2] = 4
+
+    # fit model
+    xor_model = sklearn.tree.DecisionTreeRegressor(max_depth=2)
+    xor_model.fit(X, y)
+
+    return xor_model
diff --git a/lib/shap/benchmark/plots.py b/lib/shap/benchmark/plots.py
new file mode 100644
index 0000000000000000000000000000000000000000..56bb204b756f8d978b708eea53a5899ea4de52e4
--- /dev/null
+++ b/lib/shap/benchmark/plots.py
@@ -0,0 +1,566 @@
+import base64
+import io
+import os
+
+import numpy as np
+import sklearn
+from matplotlib.colors import LinearSegmentedColormap
+
+from ..
import __version__ +from ..plots import colors +from . import methods, metrics, models +from .experiments import run_experiments + +try: + import matplotlib + import matplotlib.pyplot as pl + from IPython.display import HTML +except ImportError: + pass + + +metadata = { + # "runtime": { + # "title": "Runtime", + # "sort_order": 1 + # }, + # "local_accuracy": { + # "title": "Local Accuracy", + # "sort_order": 2 + # }, + # "consistency_guarantees": { + # "title": "Consistency Guarantees", + # "sort_order": 3 + # }, + # "keep_positive_mask": { + # "title": "Keep Positive (mask)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "Mean model output", + # "sort_order": 4 + # }, + # "keep_negative_mask": { + # "title": "Keep Negative (mask)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "Negative mean model output", + # "sort_order": 5 + # }, + # "keep_absolute_mask__r2": { + # "title": "Keep Absolute (mask)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "R^2", + # "sort_order": 6 + # }, + # "keep_absolute_mask__roc_auc": { + # "title": "Keep Absolute (mask)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "ROC AUC", + # "sort_order": 6 + # }, + # "remove_positive_mask": { + # "title": "Remove Positive (mask)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "Negative mean model output", + # "sort_order": 7 + # }, + # "remove_negative_mask": { + # "title": "Remove Negative (mask)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "Mean model output", + # "sort_order": 8 + # }, + # "remove_absolute_mask__r2": { + # "title": "Remove Absolute (mask)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "1 - R^2", + # "sort_order": 9 + # }, + # "remove_absolute_mask__roc_auc": { + # "title": "Remove Absolute (mask)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "1 - ROC AUC", + # "sort_order": 9 + # }, + # "keep_positive_resample": { + # "title": "Keep Positive (resample)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "Mean model output", + # "sort_order": 10 + # }, + # "keep_negative_resample": { + # "title": "Keep Negative (resample)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "Negative mean model output", + # "sort_order": 11 + # }, + # "keep_absolute_resample__r2": { + # "title": "Keep Absolute (resample)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "R^2", + # "sort_order": 12 + # }, + # "keep_absolute_resample__roc_auc": { + # "title": "Keep Absolute (resample)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "ROC AUC", + # "sort_order": 12 + # }, + # "remove_positive_resample": { + # "title": "Remove Positive (resample)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "Negative mean model output", + # "sort_order": 13 + # }, + # "remove_negative_resample": { + # "title": "Remove Negative (resample)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "Mean model output", + # "sort_order": 14 + # }, + # "remove_absolute_resample__r2": { + # "title": "Remove Absolute (resample)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "1 - R^2", + # "sort_order": 15 + # }, + # "remove_absolute_resample__roc_auc": { + # "title": "Remove Absolute (resample)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "1 - ROC AUC", + # "sort_order": 15 + # }, + # "remove_positive_retrain": { + # "title": "Remove Positive (retrain)", + # "xlabel": "Max fraction of features 
removed", + # "ylabel": "Negative mean model output", + # "sort_order": 11 + # }, + # "remove_negative_retrain": { + # "title": "Remove Negative (retrain)", + # "xlabel": "Max fraction of features removed", + # "ylabel": "Mean model output", + # "sort_order": 12 + # }, + # "keep_positive_retrain": { + # "title": "Keep Positive (retrain)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "Mean model output", + # "sort_order": 6 + # }, + # "keep_negative_retrain": { + # "title": "Keep Negative (retrain)", + # "xlabel": "Max fraction of features kept", + # "ylabel": "Negative mean model output", + # "sort_order": 7 + # }, + # "batch_remove_absolute__r2": { + # "title": "Batch Remove Absolute", + # "xlabel": "Fraction of features removed", + # "ylabel": "1 - R^2", + # "sort_order": 13 + # }, + # "batch_keep_absolute__r2": { + # "title": "Batch Keep Absolute", + # "xlabel": "Fraction of features kept", + # "ylabel": "R^2", + # "sort_order": 8 + # }, + # "batch_remove_absolute__roc_auc": { + # "title": "Batch Remove Absolute", + # "xlabel": "Fraction of features removed", + # "ylabel": "1 - ROC AUC", + # "sort_order": 13 + # }, + # "batch_keep_absolute__roc_auc": { + # "title": "Batch Keep Absolute", + # "xlabel": "Fraction of features kept", + # "ylabel": "ROC AUC", + # "sort_order": 8 + # }, + + # "linear_shap_corr": { + # "title": "Linear SHAP (corr)" + # }, + # "linear_shap_ind": { + # "title": "Linear SHAP (ind)" + # }, + # "coef": { + # "title": "Coefficients" + # }, + # "random": { + # "title": "Random" + # }, + # "kernel_shap_1000_meanref": { + # "title": "Kernel SHAP 1000 mean ref." + # }, + # "sampling_shap_1000": { + # "title": "Sampling SHAP 1000" + # }, + # "tree_shap_tree_path_dependent": { + # "title": "Tree SHAP" + # }, + # "saabas": { + # "title": "Saabas" + # }, + # "tree_gain": { + # "title": "Gain/Gini Importance" + # }, + # "mean_abs_tree_shap": { + # "title": "mean(|Tree SHAP|)" + # }, + # "lasso_regression": { + # "title": "Lasso Regression" + # }, + # "ridge_regression": { + # "title": "Ridge Regression" + # }, + # "gbm_regression": { + # "title": "Gradient Boosting Regression" + # } +} + +benchmark_color_map = { + "tree_shap": "#1E88E5", + "deep_shap": "#1E88E5", + "linear_shap_corr": "#1E88E5", + "linear_shap_ind": "#ff0d57", + "coef": "#13B755", + "random": "#999999", + "const_random": "#666666", + "kernel_shap_1000_meanref": "#7C52FF" +} + +# negated_metrics = [ +# "runtime", +# "remove_positive_retrain", +# "remove_positive_mask", +# "remove_positive_resample", +# "keep_negative_retrain", +# "keep_negative_mask", +# "keep_negative_resample" +# ] + +# one_minus_metrics = [ +# "remove_absolute_mask__r2", +# "remove_absolute_mask__roc_auc", +# "remove_absolute_resample__r2", +# "remove_absolute_resample__roc_auc" +# ] + +def get_method_color(method): + for line in getattr(methods, method).__doc__.split("\n"): + line = line.strip() + if line.startswith("color = "): + v = line.split("=")[1].strip() + if v.startswith("red_blue_circle("): + return colors.red_blue_circle(float(v[16:-1])) + else: + return v + return "#000000" + +def get_method_linestyle(method): + for line in getattr(methods, method).__doc__.split("\n"): + line = line.strip() + if line.startswith("linestyle = "): + return line.split("=")[1].strip() + return "solid" + +def get_metric_attr(metric, attr): + for line in getattr(metrics, metric).__doc__.split("\n"): + line = line.strip() + + # string + prefix = attr+" = \"" + suffix = "\"" + if line.startswith(prefix) and line.endswith(suffix): + 
return line[len(prefix):-len(suffix)] + + # number + prefix = attr+" = " + if line.startswith(prefix): + return float(line[len(prefix):]) + return "" + +def plot_curve(dataset, model, metric, cmap=benchmark_color_map): + experiments = run_experiments(dataset=dataset, model=model, metric=metric) + pl.figure() + method_arr = [] + for (name,(fcounts,scores)) in experiments: + _,_,method,_ = name + transform = get_metric_attr(metric, "transform") + if transform == "negate": + scores = -scores + elif transform == "one_minus": + scores = 1 - scores + auc = sklearn.metrics.auc(fcounts, scores) / fcounts[-1] + method_arr.append((auc, method, scores)) + for (auc,method,scores) in sorted(method_arr): + method_title = getattr(methods, method).__doc__.split("\n")[0].strip() + label = f"{auc:6.3f} - " + method_title + pl.plot( + fcounts / fcounts[-1], scores, label=label, + color=get_method_color(method), linewidth=2, + linestyle=get_method_linestyle(method) + ) + metric_title = getattr(metrics, metric).__doc__.split("\n")[0].strip() + pl.xlabel(get_metric_attr(metric, "xlabel")) + pl.ylabel(get_metric_attr(metric, "ylabel")) + model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip() + pl.title(metric_title + " - " + model_title) + pl.gca().xaxis.set_ticks_position('bottom') + pl.gca().yaxis.set_ticks_position('left') + pl.gca().spines['right'].set_visible(False) + pl.gca().spines['top'].set_visible(False) + ahandles, alabels = pl.gca().get_legend_handles_labels() + pl.legend(reversed(ahandles), reversed(alabels)) + return pl.gcf() + +def plot_human(dataset, model, metric, cmap=benchmark_color_map): + experiments = run_experiments(dataset=dataset, model=model, metric=metric) + pl.figure() + method_arr = [] + for (name,(fcounts,scores)) in experiments: + _,_,method,_ = name + diff_sum = np.sum(np.abs(scores[1] - scores[0])) + method_arr.append((diff_sum, method, scores[0], scores[1])) + + inds = np.arange(3) # the x locations for the groups + inc_width = (1.0 / len(method_arr)) * 0.8 + width = inc_width * 0.9 + pl.bar(inds, method_arr[0][2], width, label="Human Consensus", color="black", edgecolor="white") + i = 1 + line_style_to_hatch = { + "dashed": "///", + "dotted": "..." 
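+        # bar charts have no line styles, so dashed/dotted methods are drawn
+        # with hatch patterns instead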
+ } + for (diff_sum, method, _, methods_attrs) in sorted(method_arr): + method_title = getattr(methods, method).__doc__.split("\n")[0].strip() + label = f"{diff_sum:.2f} - " + method_title + pl.bar( + inds + inc_width * i, methods_attrs.flatten(), width, label=label, edgecolor="white", + color=get_method_color(method), hatch=line_style_to_hatch.get(get_method_linestyle(method), None) + ) + i += 1 + metric_title = getattr(metrics, metric).__doc__.split("\n")[0].strip() + pl.xlabel("Features in the model") + pl.ylabel("Feature attribution value") + model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip() + pl.title(metric_title + " - " + model_title) + pl.gca().xaxis.set_ticks_position('bottom') + pl.gca().yaxis.set_ticks_position('left') + pl.gca().spines['right'].set_visible(False) + pl.gca().spines['top'].set_visible(False) + ahandles, alabels = pl.gca().get_legend_handles_labels() + #pl.legend(ahandles, alabels) + pl.xticks(np.array([0, 1, 2, 3]) - (inc_width + width)/2, ["", "", "", ""]) + + pl.gca().xaxis.set_minor_locator(matplotlib.ticker.FixedLocator([0.4, 1.4, 2.4])) + pl.gca().xaxis.set_minor_formatter(matplotlib.ticker.FixedFormatter(["Fever", "Cough", "Headache"])) + pl.gca().tick_params(which='minor', length=0) + + pl.axhline(0, color="#aaaaaa", linewidth=0.5) + + box = pl.gca().get_position() + pl.gca().set_position([ + box.x0, box.y0 + box.height * 0.3, + box.width, box.height * 0.7 + ]) + + # Put a legend below current axis + pl.gca().legend(ahandles, alabels, loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2) + + return pl.gcf() + +def _human_score_map(human_consensus, methods_attrs): + """ Converts human agreement differences to numerical scores for coloring. + """ + + v = 1 - min(np.sum(np.abs(methods_attrs - human_consensus)) / (np.abs(human_consensus).sum() + 1), 1.0) + return v + +def make_grid(scores, dataset, model, normalize=True, transform=True): + color_vals = {} + metric_sort_order = {} + for (_,_,method,metric),(fcounts,score) in filter(lambda x: x[0][0] == dataset and x[0][1] == model, scores): + metric_sort_order[metric] = get_metric_attr(metric, "sort_order") + if metric not in color_vals: + color_vals[metric] = {} + + if transform: + transform_type = get_metric_attr(metric, "transform") + if transform_type == "negate": + score = -score + elif transform_type == "one_minus": + score = 1 - score + elif transform_type == "negate_log": + score = -np.log10(score) + + if fcounts is None: + color_vals[metric][method] = score + elif fcounts == "human": + color_vals[metric][method] = _human_score_map(*score) + else: + auc = sklearn.metrics.auc(fcounts, score) / fcounts[-1] + color_vals[metric][method] = auc + # print(metric_sort_order) + # col_keys = sorted(list(color_vals.keys()), key=lambda v: metric_sort_order[v]) + # print(col_keys) + col_keys = list(color_vals.keys()) + row_keys = list({v for k in col_keys for v in color_vals[k].keys()}) + + data = -28567 * np.ones((len(row_keys), len(col_keys))) + + for i in range(len(row_keys)): + for j in range(len(col_keys)): + data[i,j] = color_vals[col_keys[j]][row_keys[i]] + + assert np.sum(data == -28567) == 0, "There are missing data values!" 
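+
+    # normalize each metric (column) to [0, 1] so metrics with different output
+    # scales can share a single color scale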
+    if normalize:
+        data = (data - data.min(0)) / (data.max(0) - data.min(0) + 1e-8)
+
+    # sort by performance
+    inds = np.argsort(-data.mean(1))
+    row_keys = [row_keys[i] for i in inds]
+    data = data[inds,:]
+
+    return row_keys, col_keys, data
+
+
+
+red_blue_solid = LinearSegmentedColormap('red_blue_solid', {
+    'red': ((0.0, 198./255, 198./255),
+            (1.0, 5./255, 5./255)),
+
+    'green': ((0.0, 34./255, 34./255),
+              (1.0, 198./255, 198./255)),
+
+    'blue': ((0.0, 5./255, 5./255),
+             (1.0, 24./255, 24./255)),
+
+    'alpha': ((0.0, 1, 1),
+              (1.0, 1, 1))
+})
+def plot_grids(dataset, model_names, out_dir=None):
+
+    if out_dir is not None:
+        os.mkdir(out_dir)
+
+    scores = []
+    for model in model_names:
+        scores.extend(run_experiments(dataset=dataset, model=model))
+
+    prefix = ""
+    out = "" # background: rgb(30, 136, 229)
+
+    # out += "<div style='font-size: 24px; text-align: center; padding: 20px'>SHAP Benchmark</div>\n"
+    # out += "<div style='height: 1px; background: #ddd'></div>\n"
+    #out += "<div style='height: 7px; background: rgb(30, 136, 229)'></div>"
+
+    out += "<div style='height: 10px'></div>\n" # box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2), 0 6px 20px 0 rgba(0, 0, 0, 0.19);
+    out += "<table style='border-width: 1px; font-size: 14px; margin-left: 40px'>\n"
\n" + for ind,model in enumerate(model_names): + row_keys, col_keys, data = make_grid(scores, dataset, model) +# print(data) +# print(colors.red_blue_solid(0.)) +# print(colors.red_blue_solid(1.)) +# return + for metric in col_keys: + save_plot = False + if metric.startswith("human_"): + plot_human(dataset, model, metric) + save_plot = True + elif metric not in ["local_accuracy", "runtime", "consistency_guarantees"]: + plot_curve(dataset, model, metric) + save_plot = True + + if save_plot: + buf = io.BytesIO() + pl.gcf().set_size_inches(1200.0/175,1000.0/175) + pl.savefig(buf, format='png', dpi=175) + if out_dir is not None: + pl.savefig(f"{out_dir}/plot_{dataset}_{model}_{metric}.pdf", format='pdf') + pl.close() + buf.seek(0) + data_uri = base64.b64encode(buf.read()).decode('utf-8').replace('\n', '') + plot_id = "plot__"+dataset+"__"+model+"__"+metric + prefix += f"" + + model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip() + + if ind == 0: + out += "" + for j in range(data.shape[1]): + metric_title = getattr(metrics, col_keys[j]).__doc__.split("\n")[0].strip() + out += "" + out += "\n" + out += "
" + metric_title + "
\n" + out += "\n" + out += "\n" % (data.shape[1], model_title) + for i in range(data.shape[0]): + out += "" +# if i == 0: +# out += "" % (data.shape[0], model_name) + method_title = getattr(methods, row_keys[i]).__doc__.split("\n")[0].strip() + out += "\n" + for j in range(data.shape[1]): + plot_id = "plot__"+dataset+"__"+model+"__"+col_keys[j] + out += "\n" + out += "\n" # + + out += "" % (data.shape[1] + 1) + out += "
%s
%s
" + method_title + "" % plot_id + #out += "
" + out += "
" + #out += "
" + out += "
" + + out += "
\n" + out += "
SHAP Benchmark v"+__version__+"
\n" +# select { +# margin: 50px; +# width: 150px; +# padding: 5px 35px 5px 5px; +# font-size: 16px; +# border: 1px solid #ccc; +# height: 34px; +# -webkit-appearance: none; +# -moz-appearance: none; +# appearance: none; +# background: url(http://www.stackoverflow.com/favicon.ico) 96% / 15% no-repeat #eee; +# } + #out += "
Dataset:
\n" + + out += "\n" + #out += "" + #out += "
CRIC
\n" + out += "
\n" + + # output the legend + out += "\n" + out += "\n" + legend_size = 21 + for i in range(legend_size-9): + out += "" + out += "" + out += "\n" # + out += "\n" + out += "
Higher score
" + val = (legend_size-i-1) / (legend_size-1) + out += "
" + out += "
Lower score
\n" + + if out_dir is not None: + with open(out_dir + "/index.html", "w") as f: + f.write("
") + f.write(prefix) + f.write(out) + f.write("
") + else: + return HTML(prefix + out) diff --git a/lib/shap/cext/_cext.cc b/lib/shap/cext/_cext.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d8cf379bd1e2d8bf500eaf6544ae1264160586a --- /dev/null +++ b/lib/shap/cext/_cext.cc @@ -0,0 +1,560 @@ +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION + +#include +#include +#include "tree_shap.h" +#include + +static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args); +static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args); +static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args); +static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args); +static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args); + +static PyMethodDef module_methods[] = { + {"dense_tree_shap", _cext_dense_tree_shap, METH_VARARGS, "C implementation of Tree SHAP for dense."}, + {"dense_tree_predict", _cext_dense_tree_predict, METH_VARARGS, "C implementation of tree predictions."}, + {"dense_tree_update_weights", _cext_dense_tree_update_weights, METH_VARARGS, "C implementation of tree node weight compuatations."}, + {"dense_tree_saabas", _cext_dense_tree_saabas, METH_VARARGS, "C implementation of Saabas (rough fast approximation to Tree SHAP)."}, + {"compute_expectations", _cext_compute_expectations, METH_VARARGS, "Compute expectations of internal nodes."}, + {NULL, NULL, 0, NULL} +}; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "_cext", + "This module provides an interface for a fast Tree SHAP implementation.", + -1, + module_methods, + NULL, + NULL, + NULL, + NULL +}; +#endif + +#if PY_MAJOR_VERSION >= 3 +PyMODINIT_FUNC PyInit__cext(void) +#else +PyMODINIT_FUNC init_cext(void) +#endif +{ + #if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&moduledef); + if (!module) return NULL; + #else + PyObject *module = Py_InitModule("_cext", module_methods); + if (!module) return; + #endif + + /* Load `numpy` functionality. */ + import_array(); + + #if PY_MAJOR_VERSION >= 3 + return module; + #endif +} + +static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args) +{ + PyObject *children_left_obj; + PyObject *children_right_obj; + PyObject *node_sample_weight_obj; + PyObject *values_obj; + + /* Parse the input tuple */ + if (!PyArg_ParseTuple( + args, "OOOO", &children_left_obj, &children_right_obj, &node_sample_weight_obj, &values_obj + )) return NULL; + + /* Interpret the input objects as numpy arrays. */ + PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *node_sample_weight_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weight_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY); + + /* If that didn't work, throw an exception. 
*/ + if (children_left_array == NULL || children_right_array == NULL || + values_array == NULL || node_sample_weight_array == NULL) { + Py_XDECREF(children_left_array); + Py_XDECREF(children_right_array); + //PyArray_ResolveWritebackIfCopy(values_array); + Py_XDECREF(values_array); + Py_XDECREF(node_sample_weight_array); + return NULL; + } + + TreeEnsemble tree; + + // number of outputs + tree.num_outputs = PyArray_DIM(values_array, 1); + + /* Get pointers to the data as C-types. */ + tree.children_left = (int*)PyArray_DATA(children_left_array); + tree.children_right = (int*)PyArray_DATA(children_right_array); + tree.values = (tfloat*)PyArray_DATA(values_array); + tree.node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weight_array); + + const int max_depth = compute_expectations(tree); + + // clean up the created python objects + Py_XDECREF(children_left_array); + Py_XDECREF(children_right_array); + //PyArray_ResolveWritebackIfCopy(values_array); + Py_XDECREF(values_array); + Py_XDECREF(node_sample_weight_array); + + PyObject *ret = Py_BuildValue("i", max_depth); + return ret; +} + + +static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args) +{ + PyObject *children_left_obj; + PyObject *children_right_obj; + PyObject *children_default_obj; + PyObject *features_obj; + PyObject *thresholds_obj; + PyObject *values_obj; + PyObject *node_sample_weights_obj; + int max_depth; + PyObject *X_obj; + PyObject *X_missing_obj; + PyObject *y_obj; + PyObject *R_obj; + PyObject *R_missing_obj; + int tree_limit; + PyObject *out_contribs_obj; + int feature_dependence; + int model_output; + PyObject *base_offset_obj; + bool interactions; + + /* Parse the input tuple */ + if (!PyArg_ParseTuple( + args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj, + &features_obj, &thresholds_obj, &values_obj, &node_sample_weights_obj, + &max_depth, &X_obj, &X_missing_obj, &y_obj, &R_obj, &R_missing_obj, &tree_limit, &base_offset_obj, + &out_contribs_obj, &feature_dependence, &model_output, &interactions + )) return NULL; + + /* Interpret the input objects as numpy arrays. 
*/ + PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *node_sample_weights_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weights_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY); + PyArrayObject *y_array = NULL; + if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *R_array = NULL; + if (R_obj != Py_None) R_array = (PyArrayObject*)PyArray_FROM_OTF(R_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *R_missing_array = NULL; + if (R_missing_obj != Py_None) R_missing_array = (PyArrayObject*)PyArray_FROM_OTF(R_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY); + PyArrayObject *out_contribs_array = (PyArrayObject*)PyArray_FROM_OTF(out_contribs_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY); + PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY); + + /* If that didn't work, throw an exception. Note that R and y are optional. 
*/
+    if (children_left_array == NULL || children_right_array == NULL ||
+        children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
+        values_array == NULL || node_sample_weights_array == NULL || X_array == NULL ||
+        X_missing_array == NULL || out_contribs_array == NULL) {
+        Py_XDECREF(children_left_array);
+        Py_XDECREF(children_right_array);
+        Py_XDECREF(children_default_array);
+        Py_XDECREF(features_array);
+        Py_XDECREF(thresholds_array);
+        Py_XDECREF(values_array);
+        Py_XDECREF(node_sample_weights_array);
+        Py_XDECREF(X_array);
+        Py_XDECREF(X_missing_array);
+        if (y_array != NULL) Py_XDECREF(y_array);
+        if (R_array != NULL) Py_XDECREF(R_array);
+        if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
+        //PyArray_ResolveWritebackIfCopy(out_contribs_array);
+        Py_XDECREF(out_contribs_array);
+        Py_XDECREF(base_offset_array);
+        return NULL;
+    }
+
+    const unsigned num_X = PyArray_DIM(X_array, 0);
+    const unsigned M = PyArray_DIM(X_array, 1);
+    const unsigned max_nodes = PyArray_DIM(values_array, 1);
+    const unsigned num_outputs = PyArray_DIM(values_array, 2);
+    unsigned num_R = 0;
+    if (R_array != NULL) num_R = PyArray_DIM(R_array, 0);
+
+    // Get pointers to the data as C-types
+    int *children_left = (int*)PyArray_DATA(children_left_array);
+    int *children_right = (int*)PyArray_DATA(children_right_array);
+    int *children_default = (int*)PyArray_DATA(children_default_array);
+    int *features = (int*)PyArray_DATA(features_array);
+    tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
+    tfloat *values = (tfloat*)PyArray_DATA(values_array);
+    tfloat *node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weights_array);
+    tfloat *X = (tfloat*)PyArray_DATA(X_array);
+    bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
+    tfloat *y = NULL;
+    if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
+    tfloat *R = NULL;
+    if (R_array != NULL) R = (tfloat*)PyArray_DATA(R_array);
+    bool *R_missing = NULL;
+    if (R_missing_array != NULL) R_missing = (bool*)PyArray_DATA(R_missing_array);
+    tfloat *out_contribs = (tfloat*)PyArray_DATA(out_contribs_array);
+    tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
+
+    // these are just wrapper objects for all the pointers and numbers associated with
+    // the ensemble tree model and the dataset we are explaining
+    TreeEnsemble trees = TreeEnsemble(
+        children_left, children_right, children_default, features, thresholds, values,
+        node_sample_weights, max_depth, tree_limit, base_offset,
+        max_nodes, num_outputs
+    );
+    ExplanationDataset data = ExplanationDataset(X, X_missing, y, R, R_missing, num_X, M, num_R);
+
+    dense_tree_shap(trees, data, out_contribs, feature_dependence, model_output, interactions);
+
+    // retrieve return value before python cleanup of objects
+    tfloat ret_value = (double)values[0];
+
+    // clean up the created python objects
+    Py_XDECREF(children_left_array);
+    Py_XDECREF(children_right_array);
+    Py_XDECREF(children_default_array);
+    Py_XDECREF(features_array);
+    Py_XDECREF(thresholds_array);
+    Py_XDECREF(values_array);
+    Py_XDECREF(node_sample_weights_array);
+    Py_XDECREF(X_array);
+    Py_XDECREF(X_missing_array);
+    if (y_array != NULL) Py_XDECREF(y_array);
+    if (R_array != NULL) Py_XDECREF(R_array);
+    if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
+    //PyArray_ResolveWritebackIfCopy(out_contribs_array);
+    Py_XDECREF(out_contribs_array);
+    Py_XDECREF(base_offset_array);
+
+    /* Build the output tuple */
+    PyObject *ret = Py_BuildValue("d", ret_value);
+    return ret;
+}
+
+
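+/* Note: every PyArray_FROM_OTF call above returns a new reference (or NULL on
+   failure), which is why each exit path releases the arrays with Py_XDECREF;
+   the NPY_ARRAY_INOUT_ARRAY arrays (out_contribs, base_offset) are how results
+   are handed back to the caller's numpy buffers. */
+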
+static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args)
+{
+    PyObject *children_left_obj;
+    PyObject *children_right_obj;
+    PyObject *children_default_obj;
+    PyObject *features_obj;
+    PyObject *thresholds_obj;
+    PyObject *values_obj;
+    int max_depth;
+    int tree_limit;
+    PyObject *base_offset_obj;
+    int model_output;
+    PyObject *X_obj;
+    PyObject *X_missing_obj;
+    PyObject *y_obj;
+    PyObject *out_pred_obj;
+
+    /* Parse the input tuple */
+    if (!PyArg_ParseTuple(
+        args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
+        &features_obj, &thresholds_obj, &values_obj, &max_depth, &tree_limit, &base_offset_obj, &model_output,
+        &X_obj, &X_missing_obj, &y_obj, &out_pred_obj
+    )) return NULL;
+
+    /* Interpret the input objects as numpy arrays. */
+    PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *y_array = NULL;
+    if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *out_pred_array = (PyArrayObject*)PyArray_FROM_OTF(out_pred_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
+
+    /* If that didn't work, throw an exception. Note that R and y are optional. */
+    if (children_left_array == NULL || children_right_array == NULL ||
+        children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
+        values_array == NULL || X_array == NULL ||
+        X_missing_array == NULL || out_pred_array == NULL) {
+        Py_XDECREF(children_left_array);
+        Py_XDECREF(children_right_array);
+        Py_XDECREF(children_default_array);
+        Py_XDECREF(features_array);
+        Py_XDECREF(thresholds_array);
+        Py_XDECREF(values_array);
+        Py_XDECREF(base_offset_array);
+        Py_XDECREF(X_array);
+        Py_XDECREF(X_missing_array);
+        if (y_array != NULL) Py_XDECREF(y_array);
+        //PyArray_ResolveWritebackIfCopy(out_pred_array);
+        Py_XDECREF(out_pred_array);
+        return NULL;
+    }
+
+    const unsigned num_X = PyArray_DIM(X_array, 0);
+    const unsigned M = PyArray_DIM(X_array, 1);
+    const unsigned max_nodes = PyArray_DIM(values_array, 1);
+    const unsigned num_outputs = PyArray_DIM(values_array, 2);
+
+    const unsigned num_offsets = PyArray_DIM(base_offset_array, 0);
+    if (num_offsets != num_outputs) {
+        std::cerr << "The passed base_offset array does not have the same number of outputs as the values array: " << num_offsets << " vs.
" << num_outputs << std::endl; + return NULL; + } + + // Get pointers to the data as C-types + int *children_left = (int*)PyArray_DATA(children_left_array); + int *children_right = (int*)PyArray_DATA(children_right_array); + int *children_default = (int*)PyArray_DATA(children_default_array); + int *features = (int*)PyArray_DATA(features_array); + tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array); + tfloat *values = (tfloat*)PyArray_DATA(values_array); + tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array); + tfloat *X = (tfloat*)PyArray_DATA(X_array); + bool *X_missing = (bool*)PyArray_DATA(X_missing_array); + tfloat *y = NULL; + if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array); + tfloat *out_pred = (tfloat*)PyArray_DATA(out_pred_array); + + // these are just wrapper objects for all the pointers and numbers associated with + // the ensemble tree model and the dataset we are explaining + TreeEnsemble trees = TreeEnsemble( + children_left, children_right, children_default, features, thresholds, values, + NULL, max_depth, tree_limit, base_offset, + max_nodes, num_outputs + ); + ExplanationDataset data = ExplanationDataset(X, X_missing, y, NULL, NULL, num_X, M, 0); + + dense_tree_predict(out_pred, trees, data, model_output); + + // clean up the created python objects + Py_XDECREF(children_left_array); + Py_XDECREF(children_right_array); + Py_XDECREF(children_default_array); + Py_XDECREF(features_array); + Py_XDECREF(thresholds_array); + Py_XDECREF(values_array); + Py_XDECREF(base_offset_array); + Py_XDECREF(X_array); + Py_XDECREF(X_missing_array); + if (y_array != NULL) Py_XDECREF(y_array); + //PyArray_ResolveWritebackIfCopy(out_pred_array); + Py_XDECREF(out_pred_array); + + /* Build the output tuple */ + PyObject *ret = Py_BuildValue("d", (double)values[0]); + return ret; +} + + +static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args) +{ + PyObject *children_left_obj; + PyObject *children_right_obj; + PyObject *children_default_obj; + PyObject *features_obj; + PyObject *thresholds_obj; + PyObject *values_obj; + int tree_limit; + PyObject *node_sample_weight_obj; + PyObject *X_obj; + PyObject *X_missing_obj; + + /* Parse the input tuple */ + if (!PyArg_ParseTuple( + args, "OOOOOOiOOO", &children_left_obj, &children_right_obj, &children_default_obj, + &features_obj, &thresholds_obj, &values_obj, &tree_limit, &node_sample_weight_obj, &X_obj, &X_missing_obj + )) return NULL; + + /* Interpret the input objects as numpy arrays. 
*/ + PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY); + PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *node_sample_weight_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weight_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY); + PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY); + PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY); + + /* If that didn't work, throw an exception. */ + if (children_left_array == NULL || children_right_array == NULL || + children_default_array == NULL || features_array == NULL || thresholds_array == NULL || + values_array == NULL || node_sample_weight_array == NULL || X_array == NULL || + X_missing_array == NULL) { + Py_XDECREF(children_left_array); + Py_XDECREF(children_right_array); + Py_XDECREF(children_default_array); + Py_XDECREF(features_array); + Py_XDECREF(thresholds_array); + Py_XDECREF(values_array); + //PyArray_ResolveWritebackIfCopy(node_sample_weight_array); + Py_XDECREF(node_sample_weight_array); + Py_XDECREF(X_array); + Py_XDECREF(X_missing_array); + std::cerr << "Found a NULL input array in _cext_dense_tree_update_weights!\n"; + return NULL; + } + + const unsigned num_X = PyArray_DIM(X_array, 0); + const unsigned M = PyArray_DIM(X_array, 1); + const unsigned max_nodes = PyArray_DIM(values_array, 1); + + // Get pointers to the data as C-types + int *children_left = (int*)PyArray_DATA(children_left_array); + int *children_right = (int*)PyArray_DATA(children_right_array); + int *children_default = (int*)PyArray_DATA(children_default_array); + int *features = (int*)PyArray_DATA(features_array); + tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array); + tfloat *values = (tfloat*)PyArray_DATA(values_array); + tfloat *node_sample_weight = (tfloat*)PyArray_DATA(node_sample_weight_array); + tfloat *X = (tfloat*)PyArray_DATA(X_array); + bool *X_missing = (bool*)PyArray_DATA(X_missing_array); + + // these are just wrapper objects for all the pointers and numbers associated with + // the ensemble tree model and the dataset we are explaining + TreeEnsemble trees = TreeEnsemble( + children_left, children_right, children_default, features, thresholds, values, + node_sample_weight, 0, tree_limit, 0, max_nodes, 0 + ); + ExplanationDataset data = ExplanationDataset(X, X_missing, NULL, NULL, NULL, num_X, M, 0); + + dense_tree_update_weights(trees, data); + + // clean up the created python objects + Py_XDECREF(children_left_array); + Py_XDECREF(children_right_array); + Py_XDECREF(children_default_array); + Py_XDECREF(features_array); + Py_XDECREF(thresholds_array); + Py_XDECREF(values_array); + // PyArray_ResolveWritebackIfCopy(node_sample_weight_array); + Py_XDECREF(node_sample_weight_array); + Py_XDECREF(X_array); + Py_XDECREF(X_missing_array); + + /* Build the output tuple */ + PyObject *ret = Py_BuildValue("d", 1); + return ret; +} + + 
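+
+/* Saabas values are computed from a single pass down each tree: the change in
+   the node expected value at every split is credited to the split feature.
+   This is fast but only a rough approximation of the Shapley values that
+   dense_tree_shap computes. */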
+static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args)
+{
+    PyObject *children_left_obj;
+    PyObject *children_right_obj;
+    PyObject *children_default_obj;
+    PyObject *features_obj;
+    PyObject *thresholds_obj;
+    PyObject *values_obj;
+    int max_depth;
+    int tree_limit;
+    PyObject *base_offset_obj;
+    int model_output;
+    PyObject *X_obj;
+    PyObject *X_missing_obj;
+    PyObject *y_obj;
+    PyObject *out_pred_obj;
+
+    /* Parse the input tuple */
+    if (!PyArg_ParseTuple(
+        args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
+        &features_obj, &thresholds_obj, &values_obj, &max_depth, &tree_limit, &base_offset_obj, &model_output,
+        &X_obj, &X_missing_obj, &y_obj, &out_pred_obj
+    )) return NULL;
+
+    /* Interpret the input objects as numpy arrays. */
+    PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *y_array = NULL;
+    if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *out_pred_array = (PyArrayObject*)PyArray_FROM_OTF(out_pred_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+
+    /* If that didn't work, throw an exception. Note that y is optional here. */
+    if (children_left_array == NULL || children_right_array == NULL ||
+        children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
+        values_array == NULL || base_offset_array == NULL || X_array == NULL ||
+        X_missing_array == NULL || out_pred_array == NULL) {
+        Py_XDECREF(children_left_array);
+        Py_XDECREF(children_right_array);
+        Py_XDECREF(children_default_array);
+        Py_XDECREF(features_array);
+        Py_XDECREF(thresholds_array);
+        Py_XDECREF(values_array);
+        Py_XDECREF(base_offset_array);
+        Py_XDECREF(X_array);
+        Py_XDECREF(X_missing_array);
+        Py_XDECREF(y_array);
+        //PyArray_ResolveWritebackIfCopy(out_pred_array);
+        Py_XDECREF(out_pred_array);
+        PyErr_SetString(PyExc_ValueError, "Found a NULL input array in _cext_dense_tree_saabas!");
+        return NULL;
+    }
+
+    const unsigned num_X = PyArray_DIM(X_array, 0);
+    const unsigned M = PyArray_DIM(X_array, 1);
+    const unsigned max_nodes = PyArray_DIM(values_array, 1);
+    const unsigned num_outputs = PyArray_DIM(values_array, 2);
+
+    // Get pointers to the data as C-types
+    int *children_left = (int*)PyArray_DATA(children_left_array);
+    int *children_right = (int*)PyArray_DATA(children_right_array);
+    int *children_default = (int*)PyArray_DATA(children_default_array);
+    int *features = (int*)PyArray_DATA(features_array);
+    tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
+    tfloat *values = (tfloat*)PyArray_DATA(values_array);
+    tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
+    tfloat *X = (tfloat*)PyArray_DATA(X_array);
+    bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
+    tfloat *y = NULL;
+    if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
+    tfloat *out_pred = (tfloat*)PyArray_DATA(out_pred_array);
+
+    // these are just wrapper objects for all the pointers and numbers associated with
+    // the ensemble tree model and the dataset we are explaining
+    TreeEnsemble trees = TreeEnsemble(
+        children_left, children_right, children_default, features, thresholds, values,
+        NULL, max_depth, tree_limit, base_offset,
+        max_nodes, num_outputs
+    );
+    ExplanationDataset data = ExplanationDataset(X, X_missing, y, NULL, NULL, num_X, M, 0);
+
+    dense_tree_saabas(out_pred, trees, data);
+
+    // retrieve the return value before the numpy arrays (and the memory behind
+    // `values`) are released below
+    const double ret_value = (double)values[0];
+
+    // clean up the created python objects
+    Py_XDECREF(children_left_array);
+    Py_XDECREF(children_right_array);
+    Py_XDECREF(children_default_array);
+    Py_XDECREF(features_array);
+    Py_XDECREF(thresholds_array);
+    Py_XDECREF(values_array);
+    Py_XDECREF(base_offset_array);
+    Py_XDECREF(X_array);
+    Py_XDECREF(X_missing_array);
+    Py_XDECREF(y_array);
+    //PyArray_ResolveWritebackIfCopy(out_pred_array);
+    Py_XDECREF(out_pred_array);
+
+    /* Build the output value */
+    PyObject *ret = Py_BuildValue("d", ret_value);
+    return ret;
+}
diff --git a/lib/shap/cext/_cext_gpu.cc b/lib/shap/cext/_cext_gpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..94eec5c24862d03938af3125ce96dd5746de64c6
--- /dev/null
+++ b/lib/shap/cext/_cext_gpu.cc
@@ -0,0 +1,187 @@
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+#include <Python.h>
+#include <numpy/arrayobject.h>
+#include "tree_shap.h"
+#include <iostream>
+
+static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args);
+
+static PyMethodDef module_methods[] = {
+    {"dense_tree_shap", _cext_dense_tree_shap, METH_VARARGS, "C implementation of Tree SHAP for dense."},
+    {NULL, NULL, 0, NULL}
+};
+
+#if PY_MAJOR_VERSION >= 3
+static struct PyModuleDef moduledef = {
+    PyModuleDef_HEAD_INIT,
+    "_cext_gpu",
+    "This module provides an interface for a fast Tree SHAP implementation.",
+    -1,
+    module_methods,
+    NULL,
+    NULL,
+    NULL,
+    NULL
+};
+#endif
+
+#if PY_MAJOR_VERSION >= 3
+PyMODINIT_FUNC PyInit__cext_gpu(void)
+#else
+PyMODINIT_FUNC init_cext_gpu(void)
+#endif
+{
+    #if PY_MAJOR_VERSION >= 3
+    PyObject *module = PyModule_Create(&moduledef);
+    if (!module) return NULL;
+    #else
+    PyObject *module = Py_InitModule("_cext_gpu", module_methods);
+    if (!module) return;
+    #endif
+
+    /* Load `numpy` functionality. */
+    import_array();
+
+    #if PY_MAJOR_VERSION >= 3
+    return module;
+    #endif
+}
+
+void dense_tree_shap_gpu(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
+                         const int feature_dependence, unsigned model_transform, bool interactions);
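As a reading aid (editorial note, not part of the patch), the parse format string used by the wrapper below decomposes as:

/* "OOOOOOOiOOOOOiOOiib"
 *  O x 7  children_left .. node_sample_weights  (numpy arrays)
 *  i      max_depth
 *  O x 5  X, X_missing, y, R, R_missing         (y/R/R_missing may be None)
 *  i      tree_limit
 *  O      base_offset
 *  O      out_contribs                          (filled in place)
 *  i, i   feature_dependence, model_output
 *  b      interactions                          (written as an unsigned char)
 */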
+
+static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
+{
+    PyObject *children_left_obj;
+    PyObject *children_right_obj;
+    PyObject *children_default_obj;
+    PyObject *features_obj;
+    PyObject *thresholds_obj;
+    PyObject *values_obj;
+    PyObject *node_sample_weights_obj;
+    int max_depth;
+    PyObject *X_obj;
+    PyObject *X_missing_obj;
+    PyObject *y_obj;
+    PyObject *R_obj;
+    PyObject *R_missing_obj;
+    int tree_limit;
+    PyObject *out_contribs_obj;
+    int feature_dependence;
+    int model_output;
+    PyObject *base_offset_obj;
+    unsigned char interactions; /* the "b" format fills an unsigned char, not a bool */
+
+    /* Parse the input tuple */
+    if (!PyArg_ParseTuple(
+        args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj,
+        &features_obj, &thresholds_obj, &values_obj, &node_sample_weights_obj,
+        &max_depth, &X_obj, &X_missing_obj, &y_obj, &R_obj, &R_missing_obj, &tree_limit, &base_offset_obj,
+        &out_contribs_obj, &feature_dependence, &model_output, &interactions
+    )) return NULL;
+
+    /* Interpret the input objects as numpy arrays. */
+    PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *node_sample_weights_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weights_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *y_array = NULL;
+    if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *R_array = NULL;
+    if (R_obj != Py_None) R_array = (PyArrayObject*)PyArray_FROM_OTF(R_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *R_missing_array = NULL;
+    if (R_missing_obj != Py_None) R_missing_array = (PyArrayObject*)PyArray_FROM_OTF(R_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
+    PyArrayObject *out_contribs_array = (PyArrayObject*)PyArray_FROM_OTF(out_contribs_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
+    PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
+
+    /* If that didn't work, throw an exception. Note that y, R and R_missing are optional. */
+    if (children_left_array == NULL || children_right_array == NULL ||
+        children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
+        values_array == NULL || node_sample_weights_array == NULL || X_array == NULL ||
+        X_missing_array == NULL || out_contribs_array == NULL || base_offset_array == NULL) {
+        Py_XDECREF(children_left_array);
+        Py_XDECREF(children_right_array);
+        Py_XDECREF(children_default_array);
+        Py_XDECREF(features_array);
+        Py_XDECREF(thresholds_array);
+        Py_XDECREF(values_array);
+        Py_XDECREF(node_sample_weights_array);
+        Py_XDECREF(X_array);
+        Py_XDECREF(X_missing_array);
+        Py_XDECREF(y_array);
+        Py_XDECREF(R_array);
+        Py_XDECREF(R_missing_array);
+        //PyArray_ResolveWritebackIfCopy(out_contribs_array);
+        Py_XDECREF(out_contribs_array);
+        Py_XDECREF(base_offset_array);
+        PyErr_SetString(PyExc_ValueError, "Found a NULL input array in _cext_dense_tree_shap!");
+        return NULL;
+    }
+
+    const unsigned num_X = PyArray_DIM(X_array, 0);
+    const unsigned M = PyArray_DIM(X_array, 1);
+    const unsigned max_nodes = PyArray_DIM(values_array, 1);
+    const unsigned num_outputs = PyArray_DIM(values_array, 2);
+    unsigned num_R = 0;
+    if (R_array != NULL) num_R = PyArray_DIM(R_array, 0);
+
+    // Get pointers to the data as C-types
+    int *children_left = (int*)PyArray_DATA(children_left_array);
+    int *children_right = (int*)PyArray_DATA(children_right_array);
+    int *children_default = (int*)PyArray_DATA(children_default_array);
+    int *features = (int*)PyArray_DATA(features_array);
+    tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
+    tfloat *values = (tfloat*)PyArray_DATA(values_array);
+    tfloat *node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weights_array);
+    tfloat *X = (tfloat*)PyArray_DATA(X_array);
+    bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
+    tfloat *y = NULL;
+    if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
+    tfloat *R = NULL;
+    if (R_array != NULL) R = (tfloat*)PyArray_DATA(R_array);
+    bool *R_missing = NULL;
+    if (R_missing_array != NULL) R_missing = (bool*)PyArray_DATA(R_missing_array);
+    tfloat *out_contribs = (tfloat*)PyArray_DATA(out_contribs_array);
+    tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
+
+    // these are just wrapper objects for all the pointers and numbers associated with
+    // the ensemble tree model and the dataset we are explaining
+    TreeEnsemble trees = TreeEnsemble(
+        children_left, children_right, children_default, features, thresholds, values,
+        node_sample_weights, max_depth, tree_limit, base_offset,
+        max_nodes, num_outputs
+    );
+    ExplanationDataset data = ExplanationDataset(X, X_missing, y, R, R_missing, num_X, M, num_R);
+
+    dense_tree_shap_gpu(trees, data, out_contribs, feature_dependence, model_output, interactions);
+
+    // retrieve the return value before the python cleanup of objects
+    const double ret_value = (double)values[0];
+
+    // clean up the created python objects
+    Py_XDECREF(children_left_array);
+    Py_XDECREF(children_right_array);
+    Py_XDECREF(children_default_array);
+    Py_XDECREF(features_array);
+    Py_XDECREF(thresholds_array);
+    Py_XDECREF(values_array);
+    Py_XDECREF(node_sample_weights_array);
+    Py_XDECREF(X_array);
+    Py_XDECREF(X_missing_array);
+    Py_XDECREF(y_array);
+    Py_XDECREF(R_array);
+    Py_XDECREF(R_missing_array);
+    //PyArray_ResolveWritebackIfCopy(out_contribs_array);
+    Py_XDECREF(out_contribs_array);
+    Py_XDECREF(base_offset_array);
+
+    /* Build the output value */
+    PyObject *ret = Py_BuildValue("d", ret_value);
+    return ret;
+}
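The commented-out PyArray_ResolveWritebackIfCopy calls above matter when NumPy hands back a temporary copy for an NPY_ARRAY_INOUT_ARRAY request: writes land in the copy and are lost unless flushed back. A conservative cleanup sketch (editorial; assumes NumPy >= 1.14, where the call exists and is a no-op when no copy was made):

/* flush any write-back copy to the caller's array, then drop the reference */
if (out_contribs_array != NULL) {
    PyArray_ResolveWritebackIfCopy(out_contribs_array);
    Py_DECREF(out_contribs_array);
}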
diff --git a/lib/shap/cext/_cext_gpu.cu b/lib/shap/cext/_cext_gpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..bdfe49f269b92a92912cafe34c5a667a061279a8
--- /dev/null
+++ b/lib/shap/cext/_cext_gpu.cu
@@ -0,0 +1,353 @@
+#include <iostream>
+
+#include "gpu_treeshap.h"
+#include "tree_shap.h"
+
+const float inf = std::numeric_limits<tfloat>::infinity();
+
+struct ShapSplitCondition {
+  ShapSplitCondition() = default;
+  ShapSplitCondition(tfloat feature_lower_bound, tfloat feature_upper_bound,
+                     bool is_missing_branch)
+      : feature_lower_bound(feature_lower_bound),
+        feature_upper_bound(feature_upper_bound),
+        is_missing_branch(is_missing_branch) {
+    assert(feature_lower_bound <= feature_upper_bound);
+  }
+
+  /*! Feature values > lower and <= upper flow down this path. */
+  tfloat feature_lower_bound;
+  tfloat feature_upper_bound;
+  /*! Do missing values flow down this path? */
+  bool is_missing_branch;
+
+  // Does this instance flow down this path?
+  __host__ __device__ bool EvaluateSplit(float x) const {
+    // is nan
+    if (isnan(x)) {
+      return is_missing_branch;
+    }
+    return x > feature_lower_bound && x <= feature_upper_bound;
+  }
+
+  // Combine two split conditions on the same feature
+  __host__ __device__ void
+  Merge(const ShapSplitCondition &other) {  // Combine duplicate features
+    feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
+    feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
+    is_missing_branch = is_missing_branch && other.is_missing_branch;
+  }
+};
+
+// Inspired by: https://en.cppreference.com/w/cpp/iterator/size
+// Limited implementation of std::size for arrays
+template <class T, std::size_t N>
+constexpr size_t array_size(const T (&array)[N]) noexcept
+{
+  return N;
+}
+
+void RecurseTree(
+    unsigned pos, const TreeEnsemble &tree,
+    std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> *tmp_path,
+    std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
+    size_t *path_idx, int num_outputs) {
+  if (tree.is_leaf(pos)) {
+    for (auto j = 0ull; j < num_outputs; j++) {
+      auto v = tree.values[pos * num_outputs + j];
+      if (v == 0.0) {
+        // The tree has no output for this class, don't bother adding the path
+        continue;
+      }
+      // Go back over path, setting v, path_idx
+      for (auto &e : *tmp_path) {
+        e.v = v;
+        e.group = j;
+        e.path_idx = *path_idx;
+      }
+
+      paths->insert(paths->end(), tmp_path->begin(), tmp_path->end());
+      // Increment path index
+      (*path_idx)++;
+    }
+    return;
+  }
+
+  // Add the left split to the path
+  unsigned left_child = tree.children_left[pos];
+  double left_zero_fraction =
+      tree.node_sample_weights[left_child] / tree.node_sample_weights[pos];
+  // Encode the range of feature values that flow down this path
+  tmp_path->emplace_back(0, tree.features[pos], 0,
+                         ShapSplitCondition{-inf, tree.thresholds[pos], false},
+                         left_zero_fraction, 0.0f);
+
+  RecurseTree(left_child, tree, tmp_path, paths, path_idx, num_outputs);
+
+  // Replace it with the right split and recurse down that branch
+  tmp_path->back() = gpu_treeshap::PathElement<ShapSplitCondition>(
+      0, tree.features[pos], 0,
+      ShapSplitCondition{tree.thresholds[pos], inf, false},
+      1.0 - left_zero_fraction, 0.0f);
+
+  RecurseTree(tree.children_right[pos], tree, tmp_path, paths, path_idx,
+              num_outputs);
+
+  tmp_path->pop_back();
+}
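To make the recursion concrete, a hand-worked editorial sketch of what ExtractPaths below yields for a single-output stump that splits on feature 2 at threshold 0.5, with 70% of the training weight going left (leaf values 1.5 and -0.3 are assumed for illustration):

// Left leaf (value 1.5):   unique path_idx 0
//   {feature -1, bounds (-inf, inf),  zero_fraction 1.0, v 1.5}   <- root element
//   {feature  2, bounds (-inf, 0.5],  zero_fraction 0.7, v 1.5}
// Right leaf (value -0.3): unique path_idx 1
//   {feature -1, bounds (-inf, inf),  zero_fraction 1.0, v -0.3}
//   {feature  2, bounds (0.5, inf),   zero_fraction 0.3, v -0.3}
// Every element of a unique path carries the leaf value v; zero_fraction is
// the probability of taking that branch when the feature is not in the
// active set.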
+
+std::vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+ExtractPaths(const TreeEnsemble &trees) {
+  std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> paths;
+  size_t path_idx = 0;
+  for (auto i = 0; i < trees.tree_limit; i++) {
+    TreeEnsemble tree;
+    trees.get_tree(tree, i);
+    std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> tmp_path;
+    tmp_path.reserve(tree.max_depth);
+    tmp_path.emplace_back(0, -1, 0, ShapSplitCondition{-inf, inf, false}, 1.0,
+                          0.0f);
+    RecurseTree(0, tree, &tmp_path, &paths, &path_idx, tree.num_outputs);
+  }
+  return paths;
+}
+
+class DeviceExplanationDataset {
+  thrust::device_vector<tfloat> data;
+  thrust::device_vector<bool> missing;
+  size_t num_features;
+  size_t num_rows;
+
+ public:
+  DeviceExplanationDataset(const ExplanationDataset &host_data,
+                           bool background_dataset = false) {
+    num_features = host_data.M;
+    if (background_dataset) {
+      num_rows = host_data.num_R;
+      data = thrust::device_vector<tfloat>(
+          host_data.R, host_data.R + host_data.num_R * host_data.M);
+      missing = thrust::device_vector<bool>(host_data.R_missing,
+                                            host_data.R_missing +
+                                                host_data.num_R * host_data.M);
+    } else {
+      num_rows = host_data.num_X;
+      data = thrust::device_vector<tfloat>(
+          host_data.X, host_data.X + host_data.num_X * host_data.M);
+      missing = thrust::device_vector<bool>(host_data.X_missing,
+                                            host_data.X_missing +
+                                                host_data.num_X * host_data.M);
+    }
+  }
+
+  class DenseDatasetWrapper {
+    const tfloat *data;
+    const bool *missing;
+    int num_rows;
+    int num_cols;
+
+   public:
+    DenseDatasetWrapper() = default;
+    DenseDatasetWrapper(const tfloat *data, const bool *missing, int num_rows,
+                        int num_cols)
+        : data(data), missing(missing), num_rows(num_rows), num_cols(num_cols) {
+    }
+    __device__ tfloat GetElement(size_t row_idx, size_t col_idx) const {
+      auto idx = row_idx * num_cols + col_idx;
+      if (missing[idx]) {
+        return std::numeric_limits<tfloat>::quiet_NaN();
+      }
+      return data[idx];
+    }
+    __host__ __device__ size_t NumRows() const { return num_rows; }
+    __host__ __device__ size_t NumCols() const { return num_cols; }
+  };
+
+  DenseDatasetWrapper GetDeviceAccessor() {
+    return DenseDatasetWrapper(data.data().get(), missing.data().get(),
+                               num_rows, num_features);
+  }
+};
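DeviceExplanationDataset adapts the host arrays to the access contract gpu_treeshap expects: a trivially copyable object exposing NumRows()/NumCols()/GetElement() on the device, with NaN signalling a missing value. A minimal alternative wrapper (editorial sketch; assumes row-major device data with NaNs already written in place) would satisfy the same contract:

struct RawDeviceDataset {
  const tfloat *data;  // device pointer, row-major, NaN = missing
  size_t rows, cols;
  __device__ tfloat GetElement(size_t r, size_t c) const {
    return data[r * cols + c];
  }
  __host__ __device__ size_t NumRows() const { return rows; }
  __host__ __device__ size_t NumCols() const { return cols; }
};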
+
+inline void dense_tree_path_dependent_gpu(
+    const TreeEnsemble &trees, const ExplanationDataset &data,
+    tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
+  auto paths = ExtractPaths(trees);
+  DeviceExplanationDataset device_data(data);
+  DeviceExplanationDataset::DenseDatasetWrapper X =
+      device_data.GetDeviceAccessor();
+
+  thrust::device_vector<double> phis((X.NumCols() + 1) * X.NumRows() *
+                                     trees.num_outputs);
+  gpu_treeshap::GPUTreeShap(X, paths.begin(), paths.end(), trees.num_outputs,
+                            phis.begin(), phis.end());
+  // Add the base offset term to bias
+  thrust::device_vector<tfloat> base_offset(
+      trees.base_offset, trees.base_offset + trees.num_outputs);
+  auto counting = thrust::make_counting_iterator(size_t(0));
+  auto d_phis = phis.data().get();
+  auto d_base_offset = base_offset.data().get();
+  size_t num_groups = trees.num_outputs;
+  thrust::for_each(counting, counting + X.NumRows() * trees.num_outputs,
+                   [=] __device__(size_t idx) {
+                     size_t row_idx = idx / num_groups;
+                     size_t group = idx % num_groups;
+                     auto phi_idx = gpu_treeshap::IndexPhi(
+                         row_idx, num_groups, group, X.NumCols(), X.NumCols());
+                     d_phis[phi_idx] += d_base_offset[group];
+                   });
+
+  // Shap uses a slightly different layout for multiclass
+  thrust::device_vector<double> transposed_phis(phis.size());
+  auto d_transposed_phis = transposed_phis.data();
+  thrust::for_each(
+      counting, counting + phis.size(), [=] __device__(size_t idx) {
+        size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1)};
+        size_t old_idx[array_size(old_shape)];
+        gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
+        // Define new tensor format, switch num_groups axis to end
+        size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), num_groups};
+        size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[1]};
+        size_t transposed_idx =
+            gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
+        d_transposed_phis[transposed_idx] = d_phis[idx];
+      });
+  thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
+}
+
+inline void
+dense_tree_independent_gpu(const TreeEnsemble &trees,
+                           const ExplanationDataset &data, tfloat *out_contribs,
+                           tfloat transform(const tfloat, const tfloat)) {
+  auto paths = ExtractPaths(trees);
+  DeviceExplanationDataset device_data(data);
+  DeviceExplanationDataset::DenseDatasetWrapper X =
+      device_data.GetDeviceAccessor();
+  DeviceExplanationDataset background_device_data(data, true);
+  DeviceExplanationDataset::DenseDatasetWrapper R =
+      background_device_data.GetDeviceAccessor();
+
+  thrust::device_vector<double> phis((X.NumCols() + 1) * X.NumRows() *
+                                     trees.num_outputs);
+  gpu_treeshap::GPUTreeShapInterventional(X, R, paths.begin(), paths.end(),
+                                          trees.num_outputs, phis.begin(),
+                                          phis.end());
+  // Add the base offset term to bias
+  thrust::device_vector<tfloat> base_offset(
+      trees.base_offset, trees.base_offset + trees.num_outputs);
+  auto counting = thrust::make_counting_iterator(size_t(0));
+  auto d_phis = phis.data().get();
+  auto d_base_offset = base_offset.data().get();
+  size_t num_groups = trees.num_outputs;
+  thrust::for_each(counting, counting + X.NumRows() * trees.num_outputs,
+                   [=] __device__(size_t idx) {
+                     size_t row_idx = idx / num_groups;
+                     size_t group = idx % num_groups;
+                     auto phi_idx = gpu_treeshap::IndexPhi(
+                         row_idx, num_groups, group, X.NumCols(), X.NumCols());
+                     d_phis[phi_idx] += d_base_offset[group];
+                   });
+
+  // Shap uses a slightly different layout for multiclass
+  thrust::device_vector<double> transposed_phis(phis.size());
+  auto d_transposed_phis = transposed_phis.data();
+  thrust::for_each(
+      counting, counting + phis.size(), [=] __device__(size_t idx) {
+        size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1)};
+        size_t old_idx[array_size(old_shape)];
+        gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
+        // Define new tensor format, switch num_groups axis to end
+        size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), num_groups};
+        size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[1]};
+        size_t transposed_idx =
+            gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
+        d_transposed_phis[transposed_idx] = d_phis[idx];
+      });
+  thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
+}
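The transpose step in both functions above converts gpu_treeshap's (row, group, feature) output layout into the (row, feature, group) layout shap expects. A worked editorial example with 2 rows, 3 outputs and 2 features (so NumCols()+1 = 3 slots per row, including the bias column):

// IndexPhi(row=1, num_groups=3, group=2, num_columns=2, column_idx=0)
//   = (1*3 + 2) * (2+1) + 0 = 15        // position in the kernel output
// After the transpose, the same value lands at
// TensorIdxToFlatIdx({2,3,3}, {1,0,2}) = 1*9 + 0*3 + 2 = 11
//   i.e. phis[row=1][feature=0][group=2] in the layout shap consumes.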
+
+inline void dense_tree_path_dependent_interactions_gpu(
+    const TreeEnsemble &trees, const ExplanationDataset &data,
+    tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
+  auto paths = ExtractPaths(trees);
+  DeviceExplanationDataset device_data(data);
+  DeviceExplanationDataset::DenseDatasetWrapper X =
+      device_data.GetDeviceAccessor();
+
+  thrust::device_vector<double> phis((X.NumCols() + 1) * (X.NumCols() + 1) *
+                                     X.NumRows() * trees.num_outputs);
+  gpu_treeshap::GPUTreeShapInteractions(X, paths.begin(), paths.end(),
+                                        trees.num_outputs, phis.begin(),
+                                        phis.end());
+  // Add the base offset term to bias
+  thrust::device_vector<tfloat> base_offset(
+      trees.base_offset, trees.base_offset + trees.num_outputs);
+  auto counting = thrust::make_counting_iterator(size_t(0));
+  auto d_phis = phis.data().get();
+  auto d_base_offset = base_offset.data().get();
+  size_t num_groups = trees.num_outputs;
+  thrust::for_each(counting, counting + X.NumRows() * num_groups,
+                   [=] __device__(size_t idx) {
+                     size_t row_idx = idx / num_groups;
+                     size_t group = idx % num_groups;
+                     auto phi_idx = gpu_treeshap::IndexPhiInteractions(
+                         row_idx, num_groups, group, X.NumCols(), X.NumCols(),
+                         X.NumCols());
+                     d_phis[phi_idx] += d_base_offset[group];
+                   });
+  // Shap uses a slightly different layout for multiclass
+  thrust::device_vector<double> transposed_phis(phis.size());
+  auto d_transposed_phis = transposed_phis.data();
+  thrust::for_each(
+      counting, counting + phis.size(), [=] __device__(size_t idx) {
+        size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1),
+                              (X.NumCols() + 1)};
+        size_t old_idx[array_size(old_shape)];
+        gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
+        // Define new tensor format, switch num_groups axis to end
+        size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), (X.NumCols() + 1),
+                              num_groups};
+        size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[3], old_idx[1]};
+        size_t transposed_idx =
+            gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
+        d_transposed_phis[transposed_idx] = d_phis[idx];
+      });
+  thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
+}
+
+void dense_tree_shap_gpu(const TreeEnsemble &trees,
+                         const ExplanationDataset &data, tfloat *out_contribs,
+                         const int feature_dependence, unsigned model_transform,
+                         bool interactions) {
+  // see what transform (if any) we have
+  transform_f transform = get_transform(model_transform);
+
+  // dispatch to the correct algorithm handler
+  switch (feature_dependence) {
+    case FEATURE_DEPENDENCE::independent:
+      if (interactions) {
+        std::cerr << "FEATURE_DEPENDENCE::independent with interactions not yet "
+                     "supported\n";
+      } else {
+        dense_tree_independent_gpu(trees, data, out_contribs, transform);
+      }
+      return;
+
+    case FEATURE_DEPENDENCE::tree_path_dependent:
+      if (interactions) {
+        dense_tree_path_dependent_interactions_gpu(trees, data, out_contribs,
+                                                   transform);
+      } else {
+        dense_tree_path_dependent_gpu(trees, data, out_contribs, transform);
+      }
+      return;
+
+    case FEATURE_DEPENDENCE::global_path_dependent:
+      std::cerr << "FEATURE_DEPENDENCE::global_path_dependent not supported\n";
+      return;
+    default:
+      std::cerr << "Unknown feature dependence option\n";
+      return;
+  }
+}
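For orientation, a minimal host-side sketch of how this entry point is driven. The buffer names are placeholders, and the MODEL_TRANSFORM::identity constant is assumed to come from tree_shap.h alongside FEATURE_DEPENDENCE; none of this is part of the patch:

// after filling the model arrays and the dataset buffers:
TreeEnsemble trees(children_left, children_right, children_default, features,
                   thresholds, values, node_sample_weights, max_depth,
                   tree_limit, base_offset, max_nodes, num_outputs);
ExplanationDataset data(X, X_missing, /*y=*/NULL, /*R=*/NULL, /*R_missing=*/NULL,
                        num_X, M, /*num_R=*/0);
std::vector<tfloat> contribs(num_X * (M + 1) * num_outputs, 0);
dense_tree_shap_gpu(trees, data, contribs.data(),
                    FEATURE_DEPENDENCE::tree_path_dependent,
                    MODEL_TRANSFORM::identity, /*interactions=*/false);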
diff --git a/lib/shap/cext/gpu_treeshap.h b/lib/shap/cext/gpu_treeshap.h
new file mode 100644
index 0000000000000000000000000000000000000000..1666f153f0380762bb2bc84dc8eff59e5a8fae2b
--- /dev/null
+++ b/lib/shap/cext/gpu_treeshap.h
@@ -0,0 +1,1535 @@
+/*
+ * Copyright (c) 2020, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include <thrust/copy.h>
+#include <thrust/device_allocator.h>
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+#include <thrust/iterator/counting_iterator.h>
+#include <thrust/iterator/discard_iterator.h>
+#if (CUDART_VERSION >= 11000)
+#include <cub/cub.cuh>
+#else
+// Hack to get cub device reduce on older toolkits
+#include <thrust/system/cuda/detail/cub/device/device_reduce.cuh>
+using namespace thrust::cuda_cub;
+#endif
+#include <algorithm>
+#include <functional>
+#include <set>
+#include <stdexcept>
+#include <utility>
+#include <vector>
+
+namespace gpu_treeshap {
+
+struct XgboostSplitCondition {
+  XgboostSplitCondition() = default;
+  XgboostSplitCondition(float feature_lower_bound, float feature_upper_bound,
+                        bool is_missing_branch)
+      : feature_lower_bound(feature_lower_bound),
+        feature_upper_bound(feature_upper_bound),
+        is_missing_branch(is_missing_branch) {
+    assert(feature_lower_bound <= feature_upper_bound);
+  }
+
+  /*! Feature values >= lower and < upper flow down this path. */
+  float feature_lower_bound;
+  float feature_upper_bound;
+  /*! Do missing values flow down this path? */
+  bool is_missing_branch;
+
+  // Does this instance flow down this path?
+  __host__ __device__ bool EvaluateSplit(float x) const {
+    // is nan
+    if (isnan(x)) {
+      return is_missing_branch;
+    }
+    return x >= feature_lower_bound && x < feature_upper_bound;
+  }
+
+  // Combine two split conditions on the same feature
+  __host__ __device__ void Merge(
+      const XgboostSplitCondition& other) {  // Combine duplicate features
+    feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
+    feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
+    is_missing_branch = is_missing_branch && other.is_missing_branch;
+  }
+};
+
+/*!
+ * An element of a unique path through a decision tree. Can implement various
+ * types of splits via the templated SplitConditionT. Some decision tree
+ * implementations may wish to use double precision or single precision, some
+ * may use < or <= as the threshold, missing values can be handled differently,
+ * categoricals may be supported.
+ *
+ * \tparam SplitConditionT A split condition implementing the methods
+ * EvaluateSplit and Merge.
+ */
+template <typename SplitConditionT>
+struct PathElement {
+  using split_type = SplitConditionT;
+  __host__ __device__ PathElement(size_t path_idx, int64_t feature_idx,
+                                  int group, SplitConditionT split_condition,
+                                  double zero_fraction, float v)
+      : path_idx(path_idx),
+        feature_idx(feature_idx),
+        group(group),
+        split_condition(split_condition),
+        zero_fraction(zero_fraction),
+        v(v) {}
+
+  PathElement() = default;
+  __host__ __device__ bool IsRoot() const { return feature_idx == -1; }
+
+  template <typename DatasetT>
+  __host__ __device__ bool EvaluateSplit(DatasetT X, size_t row_idx) const {
+    if (this->IsRoot()) {
+      return true;
+    }
+    return split_condition.EvaluateSplit(X.GetElement(row_idx, feature_idx));
+  }
+
+  /*! Unique path index. */
+  size_t path_idx;
+  /*! Feature of this split, -1 indicates bias term. */
+  int64_t feature_idx;
+  /*! Indicates class for multiclass problems. */
+  int group;
+  SplitConditionT split_condition;
+  /*! Probability of following this path when feature_idx is not in the active
+   * set.
*/ + double zero_fraction; + float v; // Leaf weight at the end of the path +}; + +// Helper function that accepts an index into a flat contiguous array and the +// dimensions of a tensor and returns the indices with respect to the tensor +template +__device__ void FlatIdxToTensorIdx(T flat_idx, const T (&shape)[N], + T (&out_idx)[N]) { + T current_size = shape[0]; + for (auto i = 1ull; i < N; i++) { + current_size *= shape[i]; + } + for (auto i = 0ull; i < N; i++) { + current_size /= shape[i]; + out_idx[i] = flat_idx / current_size; + flat_idx -= current_size * out_idx[i]; + } +} + +// Given a shape and coordinates into a tensor, return the index into the +// backing storage one-dimensional array +template +__device__ T TensorIdxToFlatIdx(const T (&shape)[N], const T (&tensor_idx)[N]) { + T current_size = shape[0]; + for (auto i = 1ull; i < N; i++) { + current_size *= shape[i]; + } + T idx = 0; + for (auto i = 0ull; i < N; i++) { + current_size /= shape[i]; + idx += tensor_idx[i] * current_size; + } + return idx; +} + +// Maps values to the phi array according to row, group and column +__host__ __device__ inline size_t IndexPhi(size_t row_idx, size_t num_groups, + size_t group, size_t num_columns, + size_t column_idx) { + return (row_idx * num_groups + group) * (num_columns + 1) + column_idx; +} + +__host__ __device__ inline size_t IndexPhiInteractions(size_t row_idx, + size_t num_groups, + size_t group, + size_t num_columns, + size_t i, size_t j) { + size_t matrix_size = (num_columns + 1) * (num_columns + 1); + size_t matrix_offset = (row_idx * num_groups + group) * matrix_size; + return matrix_offset + i * (num_columns + 1) + j; +} + +namespace detail { + +// Shorthand for creating a device vector with an appropriate allocator type +template +using RebindVector = + thrust::device_vector::other>; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__) +__device__ __forceinline__ double atomicAddDouble(double* address, double val) { + return atomicAdd(address, val); +} +#else // In device code and CUDA < 600 +__device__ __forceinline__ double atomicAddDouble(double* address, + double val) { // NOLINT + unsigned long long int* address_as_ull = // NOLINT + (unsigned long long int*)address; // NOLINT + unsigned long long int old = *address_as_ull, assumed; // NOLINT + + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + + // Note: uses integer comparison to avoid hang in case of NaN (since NaN != + // NaN) + } while (assumed != old); + + return __longlong_as_double(old); +} +#endif + +__forceinline__ __device__ unsigned int lanemask32_lt() { + unsigned int lanemask32_lt; + asm volatile("mov.u32 %0, %%lanemask_lt;" : "=r"(lanemask32_lt)); + return (lanemask32_lt); +} + +// Like a coalesced group, except we can make the assumption that all threads in +// a group are next to each other. This makes shuffle operations much cheaper. 
+class ContiguousGroup { + public: + __device__ ContiguousGroup(uint32_t mask) : mask_(mask) {} + + __device__ uint32_t size() const { return __popc(mask_); } + __device__ uint32_t thread_rank() const { + return __popc(mask_ & lanemask32_lt()); + } + template + __device__ T shfl(T val, uint32_t src) const { + return __shfl_sync(mask_, val, src + __ffs(mask_) - 1); + } + template + __device__ T shfl_up(T val, uint32_t delta) const { + return __shfl_up_sync(mask_, val, delta); + } + __device__ uint32_t ballot(int predicate) const { + return __ballot_sync(mask_, predicate) >> (__ffs(mask_) - 1); + } + + template + __device__ T reduce(T val, OpT op) { + for (int i = 1; i < this->size(); i *= 2) { + T shfl = shfl_up(val, i); + if (static_cast(thread_rank()) - i >= 0) { + val = op(val, shfl); + } + } + return shfl(val, size() - 1); + } + uint32_t mask_; +}; + +// Separate the active threads by labels +// This functionality is available in cuda 11.0 on cc >=7.0 +// We reimplement for backwards compatibility +// Assumes partitions are contiguous +inline __device__ ContiguousGroup active_labeled_partition(uint32_t mask, + int label) { +#if __CUDA_ARCH__ >= 700 + uint32_t subgroup_mask = __match_any_sync(mask, label); +#else + uint32_t subgroup_mask = 0; + for (int i = 0; i < 32;) { + int current_label = __shfl_sync(mask, label, i); + uint32_t ballot = __ballot_sync(mask, label == current_label); + if (label == current_label) { + subgroup_mask = ballot; + } + uint32_t completed_mask = + (1 << (32 - __clz(ballot))) - 1; // Threads that have finished + // Find the start of the next group, mask off completed threads from active + // threads Then use ffs - 1 to find the position of the next group + int next_i = __ffs(mask & ~completed_mask) - 1; + if (next_i == -1) break; // -1 indicates all finished + assert(next_i > i); // Prevent infinite loops when the constraints not met + i = next_i; + } +#endif + return ContiguousGroup(subgroup_mask); +} + +// Group of threads where each thread holds a path element +class GroupPath { + protected: + const ContiguousGroup& g_; + // These are combined so we can communicate them in a single 64 bit shuffle + // instruction + float zero_one_fraction_[2]; + float pweight_; + int unique_depth_; + + public: + __device__ GroupPath(const ContiguousGroup& g, float zero_fraction, + float one_fraction) + : g_(g), + zero_one_fraction_{zero_fraction, one_fraction}, + pweight_(g.thread_rank() == 0 ? 
1.0f : 0.0f), + unique_depth_(0) {} + + // Cooperatively extend the path with a group of threads + // Each thread maintains pweight for its path element in register + __device__ void Extend() { + unique_depth_++; + + // Broadcast the zero and one fraction from the newly added path element + // Combine 2 shuffle operations into 64 bit word + const size_t rank = g_.thread_rank(); + const float inv_unique_depth = + __fdividef(1.0f, static_cast(unique_depth_ + 1)); + uint64_t res = g_.shfl(*reinterpret_cast(&zero_one_fraction_), + unique_depth_); + const float new_zero_fraction = reinterpret_cast(&res)[0]; + const float new_one_fraction = reinterpret_cast(&res)[1]; + float left_pweight = g_.shfl_up(pweight_, 1); + + // pweight of threads with rank < unique_depth_ is 0 + // We use max(x,0) to avoid using a branch + // pweight_ *= + // new_zero_fraction * max(unique_depth_ - rank, 0llu) * inv_unique_depth; + pweight_ = __fmul_rn( + __fmul_rn(pweight_, new_zero_fraction), + __fmul_rn(max(unique_depth_ - rank, size_t(0)), inv_unique_depth)); + + // pweight_ += new_one_fraction * left_pweight * rank * inv_unique_depth; + pweight_ = __fmaf_rn(__fmul_rn(new_one_fraction, left_pweight), + __fmul_rn(rank, inv_unique_depth), pweight_); + } + + // Each thread unwinds the path for its feature and returns the sum + __device__ float UnwoundPathSum() { + float next_one_portion = g_.shfl(pweight_, unique_depth_); + float total = 0.0f; + const float zero_frac_div_unique_depth = __fdividef( + zero_one_fraction_[0], static_cast(unique_depth_ + 1)); + for (int i = unique_depth_ - 1; i >= 0; i--) { + float ith_pweight = g_.shfl(pweight_, i); + float precomputed = + __fmul_rn((unique_depth_ - i), zero_frac_div_unique_depth); + const float tmp = + __fdividef(__fmul_rn(next_one_portion, unique_depth_ + 1), i + 1); + total = __fmaf_rn(tmp, zero_one_fraction_[1], total); + next_one_portion = __fmaf_rn(-tmp, precomputed, ith_pweight); + float numerator = + __fmul_rn(__fsub_rn(1.0f, zero_one_fraction_[1]), ith_pweight); + if (precomputed > 0.0f) { + total += __fdividef(numerator, precomputed); + } + } + + return total; + } +}; + +// Has different permutation weightings to the above +// Used in Taylor Shapley interaction index +class TaylorGroupPath : GroupPath { + public: + __device__ TaylorGroupPath(const ContiguousGroup& g, float zero_fraction, + float one_fraction) + : GroupPath(g, zero_fraction, one_fraction) {} + + // Extend the path is normal, all reweighting can happen in UnwoundPathSum + __device__ void Extend() { GroupPath::Extend(); } + + // Each thread unwinds the path for its feature and returns the sum + // We use a different permutation weighting for Taylor interactions + // As if the total number of features was one larger + __device__ float UnwoundPathSum() { + float one_fraction = zero_one_fraction_[1]; + float zero_fraction = zero_one_fraction_[0]; + float next_one_portion = g_.shfl(pweight_, unique_depth_) / + static_cast(unique_depth_ + 2); + + float total = 0.0f; + for (int i = unique_depth_ - 1; i >= 0; i--) { + float ith_pweight = + g_.shfl(pweight_, i) * (static_cast(unique_depth_ - i + 1) / + static_cast(unique_depth_ + 2)); + if (one_fraction > 0.0f) { + const float tmp = + next_one_portion * (unique_depth_ + 2) / ((i + 1) * one_fraction); + + total += tmp; + next_one_portion = + ith_pweight - tmp * zero_fraction * + ((unique_depth_ - i + 1) / + static_cast(unique_depth_ + 2)); + } else if (zero_fraction > 0.0f) { + total += + (ith_pweight / zero_fraction) / + ((unique_depth_ - i + 1) / 
static_cast(unique_depth_ + 2)); + } + } + + return 2 * total; + } +}; + +template +__device__ float ComputePhi(const PathElement& e, + size_t row_idx, const DatasetT& X, + const ContiguousGroup& group, float zero_fraction) { + float one_fraction = + e.EvaluateSplit(X, row_idx); + GroupPath path(group, zero_fraction, one_fraction); + size_t unique_path_length = group.size(); + + // Extend the path + for (auto unique_depth = 1ull; unique_depth < unique_path_length; + unique_depth++) { + path.Extend(); + } + + float sum = path.UnwoundPathSum(); + return sum * (one_fraction - zero_fraction) * e.v; +} + +inline __host__ __device__ size_t DivRoundUp(size_t a, size_t b) { + return (a + b - 1) / b; +} + +template +void __device__ +ConfigureThread(const DatasetT& X, const size_t bins_per_row, + const PathElement* path_elements, + const size_t* bin_segments, size_t* start_row, size_t* end_row, + PathElement* e, bool* thread_active) { + // Partition work + // Each warp processes a set of training instances applied to a path + size_t tid = kBlockSize * blockIdx.x + threadIdx.x; + const size_t warp_size = 32; + size_t warp_rank = tid / warp_size; + if (warp_rank >= bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp)) { + *thread_active = false; + return; + } + size_t bin_idx = warp_rank % bins_per_row; + size_t bank = warp_rank / bins_per_row; + size_t path_start = bin_segments[bin_idx]; + size_t path_end = bin_segments[bin_idx + 1]; + uint32_t thread_rank = threadIdx.x % warp_size; + if (thread_rank >= path_end - path_start) { + *thread_active = false; + } else { + *e = path_elements[path_start + thread_rank]; + *start_row = bank * kRowsPerWarp; + *end_row = min((bank + 1) * kRowsPerWarp, X.NumRows()); + *thread_active = true; + } +} + +#define GPUTREESHAP_MAX_THREADS_PER_BLOCK 256 +#define FULL_MASK 0xffffffff + +template +__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK) + ShapKernel(DatasetT X, size_t bins_per_row, + const PathElement* path_elements, + const size_t* bin_segments, size_t num_groups, double* phis) { + // Use shared memory for structs, otherwise nvcc puts in local memory + __shared__ DatasetT s_X; + s_X = X; + __shared__ PathElement s_elements[kBlockSize]; + PathElement& e = s_elements[threadIdx.x]; + + size_t start_row, end_row; + bool thread_active; + ConfigureThread( + s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e, + &thread_active); + uint32_t mask = __ballot_sync(FULL_MASK, thread_active); + if (!thread_active) return; + + float zero_fraction = e.zero_fraction; + auto labelled_group = active_labeled_partition(mask, e.path_idx); + + for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) { + float phi = ComputePhi(e, row_idx, X, labelled_group, zero_fraction); + + if (!e.IsRoot()) { + atomicAddDouble(&phis[IndexPhi(row_idx, num_groups, e.group, X.NumCols(), + e.feature_idx)], + phi); + } + } +} + +template +void ComputeShap( + DatasetT X, + const thrust::device_vector& bin_segments, + const thrust::device_vector, PathAllocatorT>& + path_elements, + size_t num_groups, double* phis) { + size_t bins_per_row = bin_segments.size() - 1; + const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK; + const int warps_per_block = kBlockThreads / 32; + const int kRowsPerWarp = 1024; + size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp); + + const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block); + + ShapKernel + <<>>( + X, bins_per_row, path_elements.data().get(), + bin_segments.data().get(), 
num_groups, phis); +} + +template +__device__ float ComputePhiCondition(const PathElement& e, + size_t row_idx, const DatasetT& X, + const ContiguousGroup& group, + int64_t condition_feature) { + float one_fraction = e.EvaluateSplit(X, row_idx); + PathT path(group, e.zero_fraction, one_fraction); + size_t unique_path_length = group.size(); + float condition_on_fraction = 1.0f; + float condition_off_fraction = 1.0f; + + // Extend the path + for (auto i = 1ull; i < unique_path_length; i++) { + bool is_condition_feature = + group.shfl(e.feature_idx, i) == condition_feature; + float o_i = group.shfl(one_fraction, i); + float z_i = group.shfl(e.zero_fraction, i); + + if (is_condition_feature) { + condition_on_fraction = o_i; + condition_off_fraction = z_i; + } else { + path.Extend(); + } + } + float sum = path.UnwoundPathSum(); + if (e.feature_idx == condition_feature) { + return 0.0f; + } + float phi = sum * (one_fraction - e.zero_fraction) * e.v; + return phi * (condition_on_fraction - condition_off_fraction) * 0.5f; +} + +// If there is a feature in the path we are conditioning on, swap it to the end +// of the path +template +inline __device__ void SwapConditionedElement( + PathElement** e, PathElement* s_elements, + uint32_t condition_rank, const ContiguousGroup& group) { + auto last_rank = group.size() - 1; + auto this_rank = group.thread_rank(); + if (this_rank == last_rank) { + *e = &s_elements[(threadIdx.x - this_rank) + condition_rank]; + } else if (this_rank == condition_rank) { + *e = &s_elements[(threadIdx.x - this_rank) + last_rank]; + } +} + +template +__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK) + ShapInteractionsKernel(DatasetT X, size_t bins_per_row, + const PathElement* path_elements, + const size_t* bin_segments, size_t num_groups, + double* phis_interactions) { + // Use shared memory for structs, otherwise nvcc puts in local memory + __shared__ DatasetT s_X; + s_X = X; + __shared__ PathElement s_elements[kBlockSize]; + PathElement* e = &s_elements[threadIdx.x]; + + size_t start_row, end_row; + bool thread_active; + ConfigureThread( + s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e, + &thread_active); + uint32_t mask = __ballot_sync(FULL_MASK, thread_active); + if (!thread_active) return; + + auto labelled_group = active_labeled_partition(mask, e->path_idx); + + for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) { + float phi = ComputePhi(*e, row_idx, X, labelled_group, e->zero_fraction); + if (!e->IsRoot()) { + auto phi_offset = + IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(), + e->feature_idx, e->feature_idx); + atomicAddDouble(phis_interactions + phi_offset, phi); + } + + for (auto condition_rank = 1ull; condition_rank < labelled_group.size(); + condition_rank++) { + e = &s_elements[threadIdx.x]; + int64_t condition_feature = + labelled_group.shfl(e->feature_idx, condition_rank); + SwapConditionedElement(&e, s_elements, condition_rank, labelled_group); + float x = ComputePhiCondition(*e, row_idx, X, labelled_group, + condition_feature); + if (!e->IsRoot()) { + auto phi_offset = + IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(), + e->feature_idx, condition_feature); + atomicAddDouble(phis_interactions + phi_offset, x); + // Subtract effect from diagonal + auto phi_diag = + IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(), + e->feature_idx, e->feature_idx); + atomicAddDouble(phis_interactions + phi_diag, -x); + } + } + } +} + +template +void ComputeShapInteractions( 
+ DatasetT X, + const thrust::device_vector& bin_segments, + const thrust::device_vector, PathAllocatorT>& + path_elements, + size_t num_groups, double* phis) { + size_t bins_per_row = bin_segments.size() - 1; + const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK; + const int warps_per_block = kBlockThreads / 32; + const int kRowsPerWarp = 100; + size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp); + + const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block); + + ShapInteractionsKernel + <<>>( + X, bins_per_row, path_elements.data().get(), + bin_segments.data().get(), num_groups, phis); +} + +template +__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK) + ShapTaylorInteractionsKernel( + DatasetT X, size_t bins_per_row, + const PathElement* path_elements, + const size_t* bin_segments, size_t num_groups, + double* phis_interactions) { + // Use shared memory for structs, otherwise nvcc puts in local memory + __shared__ DatasetT s_X; + if (threadIdx.x == 0) { + s_X = X; + } + __syncthreads(); + __shared__ PathElement s_elements[kBlockSize]; + PathElement* e = &s_elements[threadIdx.x]; + + size_t start_row, end_row; + bool thread_active; + ConfigureThread( + s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e, + &thread_active); + uint32_t mask = __ballot_sync(FULL_MASK, thread_active); + if (!thread_active) return; + + auto labelled_group = active_labeled_partition(mask, e->path_idx); + + for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) { + for (auto condition_rank = 1ull; condition_rank < labelled_group.size(); + condition_rank++) { + e = &s_elements[threadIdx.x]; + // Compute the diagonal terms + // TODO(Rory): this can be more efficient + float reduce_input = + e->IsRoot() || labelled_group.thread_rank() == condition_rank + ? 
1.0f + : e->zero_fraction; + float reduce = + labelled_group.reduce(reduce_input, thrust::multiplies()); + if (labelled_group.thread_rank() == condition_rank) { + float one_fraction = e->split_condition.EvaluateSplit( + X.GetElement(row_idx, e->feature_idx)); + auto phi_offset = + IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(), + e->feature_idx, e->feature_idx); + atomicAddDouble(phis_interactions + phi_offset, + reduce * (one_fraction - e->zero_fraction) * e->v); + } + + int64_t condition_feature = + labelled_group.shfl(e->feature_idx, condition_rank); + + SwapConditionedElement(&e, s_elements, condition_rank, labelled_group); + + float x = ComputePhiCondition( + *e, row_idx, X, labelled_group, condition_feature); + if (!e->IsRoot()) { + auto phi_offset = + IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(), + e->feature_idx, condition_feature); + atomicAddDouble(phis_interactions + phi_offset, x); + } + } + } +} + +template +void ComputeShapTaylorInteractions( + DatasetT X, + const thrust::device_vector& bin_segments, + const thrust::device_vector, PathAllocatorT>& + path_elements, + size_t num_groups, double* phis) { + size_t bins_per_row = bin_segments.size() - 1; + const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK; + const int warps_per_block = kBlockThreads / 32; + const int kRowsPerWarp = 100; + size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp); + + const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block); + + ShapTaylorInteractionsKernel + <<>>( + X, bins_per_row, path_elements.data().get(), + bin_segments.data().get(), num_groups, phis); +} + + +inline __host__ __device__ int64_t Factorial(int64_t x) { + int64_t y = 1; + for (auto i = 2; i <= x; i++) { + y *= i; + } + return y; +} + +// Compute factorials in log space using lgamma to avoid overflow +inline __host__ __device__ double W(double s, double n) { + assert(n - s - 1 >= 0); + return exp(lgamma(s + 1) - lgamma(n + 1) + lgamma(n - s)); +} + +template +__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK) + ShapInterventionalKernel(DatasetT X, DatasetT R, size_t bins_per_row, + const PathElement* path_elements, + const size_t* bin_segments, size_t num_groups, + double* phis) { + // Cache W coefficients + __shared__ float s_W[33][33]; + for (int i = threadIdx.x; i < 33 * 33; i += kBlockSize) { + auto s = i % 33; + auto n = i / 33; + if (n - s - 1 >= 0) { + s_W[s][n] = W(s, n); + } else { + s_W[s][n] = 0.0; + } + } + + __syncthreads(); + + __shared__ PathElement s_elements[kBlockSize]; + PathElement& e = s_elements[threadIdx.x]; + + size_t start_row, end_row; + bool thread_active; + ConfigureThread( + X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e, + &thread_active); + + uint32_t mask = __ballot_sync(FULL_MASK, thread_active); + if (!thread_active) return; + + auto labelled_group = active_labeled_partition(mask, e.path_idx); + + for (int64_t x_idx = start_row; x_idx < end_row; x_idx++) { + float result = 0.0f; + bool x_cond = e.EvaluateSplit(X, x_idx); + uint32_t x_ballot = labelled_group.ballot(x_cond); + for (int64_t r_idx = 0; r_idx < R.NumRows(); r_idx++) { + bool r_cond = e.EvaluateSplit(R, r_idx); + uint32_t r_ballot = labelled_group.ballot(r_cond); + assert(!e.IsRoot() || + (x_cond == r_cond)); // These should be the same for the root + uint32_t s = __popc(x_ballot & ~r_ballot); + uint32_t n = __popc(x_ballot ^ r_ballot); + float tmp = 0.0f; + // Theorem 1 + if (x_cond && !r_cond) { + tmp += s_W[s - 1][n]; 
+ } + tmp -= s_W[s][n] * (r_cond && !x_cond); + + // No foreground samples make it to this leaf, increment bias + if (e.IsRoot() && s == 0) { + tmp += 1.0f; + } + // If neither foreground or background go down this path, ignore this path + bool reached_leaf = !labelled_group.ballot(!x_cond && !r_cond); + tmp *= reached_leaf; + result += tmp; + } + + if (result != 0.0) { + result /= R.NumRows(); + + // Root writes bias + auto feature = e.IsRoot() ? X.NumCols() : e.feature_idx; + atomicAddDouble( + &phis[IndexPhi(x_idx, num_groups, e.group, X.NumCols(), feature)], + result * e.v); + } + } +} + +template +void ComputeShapInterventional( + DatasetT X, DatasetT R, + const thrust::device_vector& bin_segments, + const thrust::device_vector, PathAllocatorT>& + path_elements, + size_t num_groups, double* phis) { + size_t bins_per_row = bin_segments.size() - 1; + const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK; + const int warps_per_block = kBlockThreads / 32; + const int kRowsPerWarp = 100; + size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp); + + const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block); + + ShapInterventionalKernel + <<>>( + X, R, bins_per_row, path_elements.data().get(), + bin_segments.data().get(), num_groups, phis); +} + +template +void GetBinSegments(const PathVectorT& paths, const SizeVectorT& bin_map, + SizeVectorT* bin_segments) { + DeviceAllocatorT alloc; + size_t num_bins = + thrust::reduce(thrust::cuda::par(alloc), bin_map.begin(), bin_map.end(), + size_t(0), thrust::maximum()) + + 1; + bin_segments->resize(num_bins + 1, 0); + auto counting = thrust::make_counting_iterator(0llu); + auto d_paths = paths.data().get(); + auto d_bin_segments = bin_segments->data().get(); + auto d_bin_map = bin_map.data(); + thrust::for_each_n(counting, paths.size(), [=] __device__(size_t idx) { + auto path_idx = d_paths[idx].path_idx; + atomicAdd(reinterpret_cast(d_bin_segments) + // NOLINT + d_bin_map[path_idx], + 1); + }); + thrust::exclusive_scan(thrust::cuda::par(alloc), bin_segments->begin(), + bin_segments->end(), bin_segments->begin()); +} + +struct DeduplicateKeyTransformOp { + template + __device__ thrust::pair operator()( + const PathElement& e) { + return {e.path_idx, e.feature_idx}; + } +}; + +inline void CheckCuda(cudaError_t err) { + if (err != cudaSuccess) { + throw thrust::system_error(err, thrust::cuda_category()); + } +} + +template +class DiscardOverload : public thrust::discard_iterator { + public: + using value_type = Return; // NOLINT +}; + +template +void DeduplicatePaths(PathVectorT* device_paths, + PathVectorT* deduplicated_paths) { + DeviceAllocatorT alloc; + // Sort by feature + thrust::sort(thrust::cuda::par(alloc), device_paths->begin(), + device_paths->end(), + [=] __device__(const PathElement& a, + const PathElement& b) { + if (a.path_idx < b.path_idx) return true; + if (b.path_idx < a.path_idx) return false; + + if (a.feature_idx < b.feature_idx) return true; + if (b.feature_idx < a.feature_idx) return false; + return false; + }); + + deduplicated_paths->resize(device_paths->size()); + + using Pair = thrust::pair; + auto key_transform = thrust::make_transform_iterator( + device_paths->begin(), DeduplicateKeyTransformOp()); + + thrust::device_vector d_num_runs_out(1); + size_t* h_num_runs_out; + CheckCuda(cudaMallocHost(&h_num_runs_out, sizeof(size_t))); + + auto combine = [] __device__(PathElement a, + PathElement b) { + // Combine duplicate features + a.split_condition.Merge(b.split_condition); + a.zero_fraction 
*= b.zero_fraction; + return a; + }; // NOLINT + size_t temp_size = 0; + CheckCuda(cub::DeviceReduce::ReduceByKey( + nullptr, temp_size, key_transform, DiscardOverload(), + device_paths->begin(), deduplicated_paths->begin(), + d_num_runs_out.begin(), combine, device_paths->size())); + using TempAlloc = RebindVector; + TempAlloc tmp(temp_size); + CheckCuda(cub::DeviceReduce::ReduceByKey( + tmp.data().get(), temp_size, key_transform, DiscardOverload(), + device_paths->begin(), deduplicated_paths->begin(), + d_num_runs_out.begin(), combine, device_paths->size())); + + CheckCuda(cudaMemcpy(h_num_runs_out, d_num_runs_out.data().get(), + sizeof(size_t), cudaMemcpyDeviceToHost)); + deduplicated_paths->resize(*h_num_runs_out); + CheckCuda(cudaFreeHost(h_num_runs_out)); +} + +template +void SortPaths(PathVectorT* paths, const SizeVectorT& bin_map) { + auto d_bin_map = bin_map.data(); + DeviceAllocatorT alloc; + thrust::sort(thrust::cuda::par(alloc), paths->begin(), paths->end(), + [=] __device__(const PathElement& a, + const PathElement& b) { + size_t a_bin = d_bin_map[a.path_idx]; + size_t b_bin = d_bin_map[b.path_idx]; + if (a_bin < b_bin) return true; + if (b_bin < a_bin) return false; + + if (a.path_idx < b.path_idx) return true; + if (b.path_idx < a.path_idx) return false; + + if (a.feature_idx < b.feature_idx) return true; + if (b.feature_idx < a.feature_idx) return false; + return false; + }); +} + +using kv = std::pair; + +struct BFDCompare { + bool operator()(const kv& lhs, const kv& rhs) const { + if (lhs.second == rhs.second) { + return lhs.first < rhs.first; + } + return lhs.second < rhs.second; + } +}; + +// Best Fit Decreasing bin packing +// Efficient O(nlogn) implementation with balanced tree using std::set +template +std::vector BFDBinPacking(const IntVectorT& counts, + int bin_limit = 32) { + thrust::host_vector counts_host(counts); + std::vector path_lengths(counts_host.size()); + for (auto i = 0ull; i < counts_host.size(); i++) { + path_lengths[i] = {i, counts_host[i]}; + } + + std::sort(path_lengths.begin(), path_lengths.end(), + [&](const kv& a, const kv& b) { + std::greater<> op; + return op(a.second, b.second); + }); + + // map unique_id -> bin + std::vector bin_map(counts_host.size()); + std::set bin_capacities; + bin_capacities.insert({bin_capacities.size(), bin_limit}); + for (auto pair : path_lengths) { + int new_size = pair.second; + auto itr = bin_capacities.lower_bound({0, new_size}); + // Does not fit in any bin + if (itr == bin_capacities.end()) { + size_t new_bin_idx = bin_capacities.size(); + bin_capacities.insert({new_bin_idx, bin_limit - new_size}); + bin_map[pair.first] = new_bin_idx; + } else { + kv entry = *itr; + entry.second -= new_size; + bin_map[pair.first] = entry.first; + bin_capacities.erase(itr); + bin_capacities.insert(entry); + } + } + + return bin_map; +} + +// First Fit Decreasing bin packing +// Inefficient O(n^2) implementation +template +std::vector FFDBinPacking(const IntVectorT& counts, + int bin_limit = 32) { + thrust::host_vector counts_host(counts); + std::vector path_lengths(counts_host.size()); + for (auto i = 0ull; i < counts_host.size(); i++) { + path_lengths[i] = {i, counts_host[i]}; + } + std::sort(path_lengths.begin(), path_lengths.end(), + [&](const kv& a, const kv& b) { + std::greater<> op; + return op(a.second, b.second); + }); + + // map unique_id -> bin + std::vector bin_map(counts_host.size()); + std::vector bin_capacities(path_lengths.size(), bin_limit); + for (auto pair : path_lengths) { + int new_size = pair.second; + for 
(auto j = 0ull; j < bin_capacities.size(); j++) { + int& capacity = bin_capacities[j]; + + if (capacity >= new_size) { + capacity -= new_size; + bin_map[pair.first] = j; + break; + } + } + } + + return bin_map; +} + +// Next Fit bin packing +// O(n) implementation +template +std::vector NFBinPacking(const IntVectorT& counts, int bin_limit = 32) { + thrust::host_vector counts_host(counts); + std::vector bin_map(counts_host.size()); + size_t current_bin = 0; + int current_capacity = bin_limit; + for (auto i = 0ull; i < counts_host.size(); i++) { + int new_size = counts_host[i]; + size_t path_idx = i; + if (new_size <= current_capacity) { + current_capacity -= new_size; + bin_map[path_idx] = current_bin; + } else { + current_capacity = bin_limit - new_size; + bin_map[path_idx] = ++current_bin; + } + } + return bin_map; +} + +template +void GetPathLengths(const PathVectorT& device_paths, + LengthVectorT* path_lengths) { + path_lengths->resize( + static_cast>(device_paths.back()).path_idx + + 1, + 0); + auto counting = thrust::make_counting_iterator(0llu); + auto d_paths = device_paths.data().get(); + auto d_lengths = path_lengths->data().get(); + thrust::for_each_n(counting, device_paths.size(), [=] __device__(size_t idx) { + auto path_idx = d_paths[idx].path_idx; + atomicAdd(d_lengths + path_idx, 1ull); + }); +} + +struct PathTooLongOp { + __device__ size_t operator()(size_t length) { return length > 32; } +}; + +template +struct IncorrectVOp { + const PathElement* paths; + __device__ size_t operator()(size_t idx) { + auto a = paths[idx - 1]; + auto b = paths[idx]; + return a.path_idx == b.path_idx && a.v != b.v; + } +}; + +template +void ValidatePaths(const PathVectorT& device_paths, + const LengthVectorT& path_lengths) { + DeviceAllocatorT alloc; + PathTooLongOp too_long_op; + auto invalid_length = + thrust::any_of(thrust::cuda::par(alloc), path_lengths.begin(), + path_lengths.end(), too_long_op); + + if (invalid_length) { + throw std::invalid_argument("Tree depth must be < 32"); + } + + IncorrectVOp incorrect_v_op{device_paths.data().get()}; + auto counting = thrust::counting_iterator(0); + auto incorrect_v = + thrust::any_of(thrust::cuda::par(alloc), counting + 1, + counting + device_paths.size(), incorrect_v_op); + + if (incorrect_v) { + throw std::invalid_argument( + "Leaf value v should be the same across a single path"); + } +} + +template +void PreprocessPaths(PathVectorT* device_paths, PathVectorT* deduplicated_paths, + SizeVectorT* bin_segments) { + // Sort paths by length and feature + detail::DeduplicatePaths( + device_paths, deduplicated_paths); + using int_vector = RebindVector; + int_vector path_lengths; + detail::GetPathLengths(*deduplicated_paths, + &path_lengths); + SizeVectorT device_bin_map = detail::BFDBinPacking(path_lengths); + ValidatePaths(*deduplicated_paths, + path_lengths); + detail::SortPaths(deduplicated_paths, device_bin_map); + detail::GetBinSegments( + *deduplicated_paths, device_bin_map, bin_segments); +} + +struct PathIdxTransformOp { + template + __device__ size_t operator()(const PathElement& e) { + return e.path_idx; + } +}; + +struct GroupIdxTransformOp { + template + __device__ size_t operator()(const PathElement& e) { + return e.group; + } +}; + +struct BiasTransformOp { + template + __device__ double operator()(const PathElement& e) { + return e.zero_fraction * e.v; + } +}; + +// While it is possible to compute bias in the primary kernel, we do it here +// using double precision to avoid numerical stability issues +template +void ComputeBias(const 
PathVectorT& device_paths, DoubleVectorT* bias) { + using double_vector = thrust::device_vector< + double, typename DeviceAllocatorT::template rebind::other>; + PathVectorT sorted_paths(device_paths); + DeviceAllocatorT alloc; + // Make sure groups are contiguous + thrust::sort(thrust::cuda::par(alloc), sorted_paths.begin(), + sorted_paths.end(), + [=] __device__(const PathElement& a, + const PathElement& b) { + if (a.group < b.group) return true; + if (b.group < a.group) return false; + + if (a.path_idx < b.path_idx) return true; + if (b.path_idx < a.path_idx) return false; + + return false; + }); + // Combine zero fraction for all paths + auto path_key = thrust::make_transform_iterator(sorted_paths.begin(), + PathIdxTransformOp()); + PathVectorT combined(sorted_paths.size()); + auto combined_out = thrust::reduce_by_key( + thrust::cuda ::par(alloc), path_key, path_key + sorted_paths.size(), + sorted_paths.begin(), thrust::make_discard_iterator(), combined.begin(), + thrust::equal_to(), + [=] __device__(PathElement a, + const PathElement& b) { + a.zero_fraction *= b.zero_fraction; + return a; + }); + size_t num_paths = combined_out.second - combined.begin(); + // Combine bias for each path, over each group + using size_vector = thrust::device_vector< + size_t, typename DeviceAllocatorT::template rebind::other>; + size_vector keys_out(num_paths); + double_vector values_out(num_paths); + auto group_key = + thrust::make_transform_iterator(combined.begin(), GroupIdxTransformOp()); + auto values = + thrust::make_transform_iterator(combined.begin(), BiasTransformOp()); + + auto out_itr = thrust::reduce_by_key(thrust::cuda::par(alloc), group_key, + group_key + num_paths, values, + keys_out.begin(), values_out.begin()); + + // Write result + size_t n = out_itr.first - keys_out.begin(); + auto counting = thrust::make_counting_iterator(0llu); + auto d_keys_out = keys_out.data().get(); + auto d_values_out = values_out.data().get(); + auto d_bias = bias->data().get(); + thrust::for_each_n(counting, n, [=] __device__(size_t idx) { + d_bias[d_keys_out[idx]] = d_values_out[idx]; + }); +} + +}; // namespace detail + +/*! + * Compute feature contributions on the GPU given a set of unique paths through + * a tree ensemble and a dataset. Uses device memory proportional to the tree + * ensemble size. + * + * \exception std::invalid_argument Thrown when an invalid argument error + * condition occurs. \tparam PathIteratorT Thrust type iterator, may be + * thrust::device_ptr for device memory, or stl iterator/raw pointer for host + * memory. \tparam PhiIteratorT Thrust type iterator, may be + * thrust::device_ptr for device memory, or stl iterator/raw pointer for host + * memory. Value type must be floating point. \tparam DatasetT User-specified + * dataset container. \tparam DeviceAllocatorT Optional thrust style + * allocator. + * + * \param X Thin wrapper over a dataset allocated in device memory. X + * should be trivially copyable as a kernel parameter (i.e. contain only + * pointers to actual data) and must implement the methods + * NumRows()/NumCols()/GetElement(size_t row_idx, size_t col_idx) as __device__ + * functions. GetElement may return NaN where the feature value is missing. + * \param begin Iterator to paths, where separate paths are delineated by + * PathElement.path_idx. Each unique path should contain 1 + * root with feature_idx = -1 and zero_fraction = 1.0. The ordering of path + * elements inside a unique path does not matter - the result will be the same. 
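+ * As an illustration only (hypothetical values): a single-split tree with two
+ * leaves could be encoded as two unique paths, where path 0 consists of a root
+ * element {path_idx = 0, feature_idx = -1, zero_fraction = 1.0, v = leaf_0}
+ * followed by a split element {path_idx = 0, feature_idx = 2,
+ * zero_fraction = 0.5, v = leaf_0}, and path 1 likewise ends in leaf_1; note
+ * that v is the same leaf value for every element of a given path.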
+ * Paths may contain duplicate features. See the PathElement class for more + * information. \param end Path end iterator. \param num_groups Number + * of output groups. In multiclass classification the algorithm outputs feature + * contributions per output class. \param phis_begin Begin iterator for output + * phis. \param phis_end End iterator for output phis. + */ +template , + typename DatasetT, typename PathIteratorT, typename PhiIteratorT> +void GPUTreeShap(DatasetT X, PathIteratorT begin, PathIteratorT end, + size_t num_groups, PhiIteratorT phis_begin, + PhiIteratorT phis_end) { + if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return; + + if (size_t(phis_end - phis_begin) < + X.NumRows() * (X.NumCols() + 1) * num_groups) { + throw std::invalid_argument( + "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * " + "num_groups"); + } + + using size_vector = detail::RebindVector; + using double_vector = detail::RebindVector; + using path_vector = detail::RebindVector< + typename std::iterator_traits::value_type, + DeviceAllocatorT>; + using split_condition = + typename std::iterator_traits::value_type::split_type; + + // Compute the global bias + double_vector temp_phi(phis_end - phis_begin, 0.0); + path_vector device_paths(begin, end); + double_vector bias(num_groups, 0.0); + detail::ComputeBias(device_paths, &bias); + auto d_bias = bias.data().get(); + auto d_temp_phi = temp_phi.data().get(); + thrust::for_each_n(thrust::make_counting_iterator(0llu), + X.NumRows() * num_groups, [=] __device__(size_t idx) { + size_t group = idx % num_groups; + size_t row_idx = idx / num_groups; + d_temp_phi[IndexPhi(row_idx, num_groups, group, + X.NumCols(), X.NumCols())] += + d_bias[group]; + }); + + path_vector deduplicated_paths; + size_vector device_bin_segments; + detail::PreprocessPaths( + &device_paths, &deduplicated_paths, &device_bin_segments); + + detail::ComputeShap(X, device_bin_segments, deduplicated_paths, num_groups, + temp_phi.data().get()); + thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin); +} + +/*! + * Compute feature interaction contributions on the GPU given a set of unique + * paths through a tree ensemble and a dataset. Uses device memory + * proportional to the tree ensemble size. + * + * \exception std::invalid_argument Thrown when an invalid argument error + * condition occurs. + * \tparam DeviceAllocatorT Optional thrust style allocator. + * \tparam DatasetT User-specified dataset container. + * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr + * for device memory, or stl iterator/raw pointer for + * host memory. + * \tparam PhiIteratorT Thrust type iterator, may be thrust::device_ptr + * for device memory, or stl iterator/raw pointer for + * host memory. Value type must be floating point. + * + * \param X Thin wrapper over a dataset allocated in device memory. X + * should be trivially copyable as a kernel parameter (i.e. + * contain only pointers to actual data) and must implement + * the methods NumRows()/NumCols()/GetElement(size_t row_idx, + * size_t col_idx) as __device__ functions. GetElement may + * return NaN where the feature value is missing. + * \param begin Iterator to paths, where separate paths are delineated by + * PathElement.path_idx. Each unique path should contain 1 + * root with feature_idx = -1 and zero_fraction = 1.0. The + * ordering of path elements inside a unique path does not + * matter - the result will be the same. Paths may contain + * duplicate features. 
See the PathElement class for more + * information. + * \param end Path end iterator. + * \param num_groups Number of output groups. In multiclass classification the + * algorithm outputs feature contributions per output class. + * \param phis_begin Begin iterator for output phis. + * \param phis_end End iterator for output phis. + */ +template , + typename DatasetT, typename PathIteratorT, typename PhiIteratorT> +void GPUTreeShapInteractions(DatasetT X, PathIteratorT begin, PathIteratorT end, + size_t num_groups, PhiIteratorT phis_begin, + PhiIteratorT phis_end) { + if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return; + if (size_t(phis_end - phis_begin) < + X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) { + throw std::invalid_argument( + "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * " + "(X.NumCols() + 1) * " + "num_groups"); + } + + using size_vector = detail::RebindVector; + using double_vector = detail::RebindVector; + using path_vector = detail::RebindVector< + typename std::iterator_traits::value_type, + DeviceAllocatorT>; + using split_condition = + typename std::iterator_traits::value_type::split_type; + + // Compute the global bias + double_vector temp_phi(phis_end - phis_begin, 0.0); + path_vector device_paths(begin, end); + double_vector bias(num_groups, 0.0); + detail::ComputeBias(device_paths, &bias); + auto d_bias = bias.data().get(); + auto d_temp_phi = temp_phi.data().get(); + thrust::for_each_n( + thrust::make_counting_iterator(0llu), X.NumRows() * num_groups, + [=] __device__(size_t idx) { + size_t group = idx % num_groups; + size_t row_idx = idx / num_groups; + d_temp_phi[IndexPhiInteractions(row_idx, num_groups, group, X.NumCols(), + X.NumCols(), X.NumCols())] += + d_bias[group]; + }); + + path_vector deduplicated_paths; + size_vector device_bin_segments; + detail::PreprocessPaths( + &device_paths, &deduplicated_paths, &device_bin_segments); + + detail::ComputeShapInteractions(X, device_bin_segments, deduplicated_paths, + num_groups, temp_phi.data().get()); + thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin); +} + +/*! + * Compute feature interaction contributions using the Shapley Taylor index on + * the GPU, given a set of unique paths through a tree ensemble and a dataset. + * Uses device memory proportional to the tree ensemble size. + * + * \exception std::invalid_argument Thrown when an invalid argument error + * condition occurs. + * \tparam PhiIteratorT Thrust type iterator, may be thrust::device_ptr + * for device memory, or stl iterator/raw pointer for + * host memory. Value type must be floating point. + * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr + * for device memory, or stl iterator/raw pointer for + * host memory. + * \tparam DatasetT User-specified dataset container. + * \tparam DeviceAllocatorT Optional thrust style allocator. + * + * \param X Thin wrapper over a dataset allocated in device memory. X + * should be trivially copyable as a kernel parameter (i.e. + * contain only pointers to actual data) and must implement + * the methods NumRows()/NumCols()/GetElement(size_t row_idx, + * size_t col_idx) as __device__ functions. GetElement may + * return NaN where the feature value is missing. + * \param begin Iterator to paths, where separate paths are delineated by + * PathElement.path_idx. Each unique path should contain 1 + * root with feature_idx = -1 and zero_fraction = 1.0. 
The + * ordering of path elements inside a unique path does not + * matter - the result will be the same. Paths may contain + * duplicate features. See the PathElement class for more + * information. + * \param end Path end iterator. + * \param num_groups Number of output groups. In multiclass classification the + * algorithm outputs feature contributions per output class. + * \param phis_begin Begin iterator for output phis. + * \param phis_end End iterator for output phis. + */ +template , + typename DatasetT, typename PathIteratorT, typename PhiIteratorT> +void GPUTreeShapTaylorInteractions(DatasetT X, PathIteratorT begin, + PathIteratorT end, size_t num_groups, + PhiIteratorT phis_begin, + PhiIteratorT phis_end) { + using phis_type = typename std::iterator_traits::value_type; + static_assert(std::is_floating_point::value, + "Phis type must be floating point"); + + if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return; + + if (size_t(phis_end - phis_begin) < + X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) { + throw std::invalid_argument( + "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * " + "(X.NumCols() + 1) * " + "num_groups"); + } + + using size_vector = detail::RebindVector; + using double_vector = detail::RebindVector; + using path_vector = detail::RebindVector< + typename std::iterator_traits::value_type, + DeviceAllocatorT>; + using split_condition = + typename std::iterator_traits::value_type::split_type; + + // Compute the global bias + double_vector temp_phi(phis_end - phis_begin, 0.0); + path_vector device_paths(begin, end); + double_vector bias(num_groups, 0.0); + detail::ComputeBias(device_paths, &bias); + auto d_bias = bias.data().get(); + auto d_temp_phi = temp_phi.data().get(); + thrust::for_each_n( + thrust::make_counting_iterator(0llu), X.NumRows() * num_groups, + [=] __device__(size_t idx) { + size_t group = idx % num_groups; + size_t row_idx = idx / num_groups; + d_temp_phi[IndexPhiInteractions(row_idx, num_groups, group, X.NumCols(), + X.NumCols(), X.NumCols())] += + d_bias[group]; + }); + + path_vector deduplicated_paths; + size_vector device_bin_segments; + detail::PreprocessPaths( + &device_paths, &deduplicated_paths, &device_bin_segments); + + detail::ComputeShapTaylorInteractions(X, device_bin_segments, + deduplicated_paths, num_groups, + temp_phi.data().get()); + thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin); +} + +/*! + * Compute feature contributions on the GPU given a set of unique paths through a tree ensemble + * and a dataset. Uses device memory proportional to the tree ensemble size. This variant + * implements the interventional tree shap algorithm described here: + * https://drafts.distill.pub/HughChen/its_blog/ + * + * It requires a background dataset R. + * + * \exception std::invalid_argument Thrown when an invalid argument error condition occurs. + * \tparam DeviceAllocatorT Optional thrust style allocator. + * \tparam DatasetT User-specified dataset container. + * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr for device memory, or + * stl iterator/raw pointer for host memory. + * + * \param X Thin wrapper over a dataset allocated in device memory. X should be trivially + * copyable as a kernel parameter (i.e. contain only pointers to actual data) and + * must implement the methods NumRows()/NumCols()/GetElement(size_t row_idx, + * size_t col_idx) as __device__ functions. GetElement may return NaN where the + * feature value is missing. 
+ * \param R Background dataset.
+ * \param begin Iterator to paths, where separate paths are delineated by
+ *                   PathElement.path_idx. Each unique path should contain 1 root with feature_idx =
+ *                   -1 and zero_fraction = 1.0. The ordering of path elements inside a unique path
+ *                   does not matter - the result will be the same. Paths may contain duplicate
+ *                   features. See the PathElement class for more information.
+ * \param end Path end iterator.
+ * \param num_groups Number of output groups. In multiclass classification the algorithm outputs
+ *                   feature contributions per output class.
+ * \param phis_begin Begin iterator for output phis.
+ * \param phis_end End iterator for output phis.
+ */
+template <typename DeviceAllocatorT = thrust::device_allocator<int>,
+          typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
+void GPUTreeShapInterventional(DatasetT X, DatasetT R, PathIteratorT begin,
+                               PathIteratorT end, size_t num_groups,
+                               PhiIteratorT phis_begin, PhiIteratorT phis_end) {
+  if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
+
+  if (size_t(phis_end - phis_begin) <
+      X.NumRows() * (X.NumCols() + 1) * num_groups) {
+    throw std::invalid_argument(
+        "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
+        "num_groups");
+  }
+
+  using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
+  using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
+  using path_vector = detail::RebindVector<
+      typename std::iterator_traits<PathIteratorT>::value_type,
+      DeviceAllocatorT>;
+  using split_condition =
+      typename std::iterator_traits<PathIteratorT>::value_type::split_type;
+
+  double_vector temp_phi(phis_end - phis_begin, 0.0);
+  path_vector device_paths(begin, end);
+
+  path_vector deduplicated_paths;
+  size_vector device_bin_segments;
+  detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
+      &device_paths, &deduplicated_paths, &device_bin_segments);
+  detail::ComputeShapInterventional(X, R, device_bin_segments,
+                                    deduplicated_paths, num_groups,
+                                    temp_phi.data().get());
+  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
+}
+}  // namespace gpu_treeshap
diff --git a/lib/shap/cext/tree_shap.h b/lib/shap/cext/tree_shap.h
new file mode 100644
index 0000000000000000000000000000000000000000..eb5eef3c567f36397f48a75b16136010910d2d76
--- /dev/null
+++ b/lib/shap/cext/tree_shap.h
@@ -0,0 +1,1460 @@
+/**
+ * Fast recursive computation of SHAP values in trees.
+ * See https://arxiv.org/abs/1802.03888 for details.
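+ *
+ * The path-dependent algorithm implemented below runs in O(T L D^2) time for
+ * an ensemble of T trees with at most L leaves and maximum depth D, rather
+ * than the exponential cost of evaluating the Shapley value summation
+ * directly.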
+ *
+ * Scott Lundberg, 2018 (independent algorithm courtesy of Hugh Chen 2018)
+ */
+
+#include <algorithm>
+#include <iostream>
+#include <fstream>
+#include <cmath>
+#include <ctime>
+#include <Python.h>
+#if defined(_WIN32) || defined(WIN32)
+    #include <malloc.h>
+#elif defined(__MVS__)
+    #include <stdlib.h>
+#else
+    #include <alloca.h>
+#endif
+using namespace std;
+
+typedef double tfloat;
+typedef tfloat (* transform_f)(const tfloat margin, const tfloat y);
+
+namespace FEATURE_DEPENDENCE {
+    const unsigned independent = 0;
+    const unsigned tree_path_dependent = 1;
+    const unsigned global_path_dependent = 2;
+}
+
+struct TreeEnsemble {
+    int *children_left;
+    int *children_right;
+    int *children_default;
+    int *features;
+    tfloat *thresholds;
+    tfloat *values;
+    tfloat *node_sample_weights;
+    unsigned max_depth;
+    unsigned tree_limit;
+    tfloat *base_offset;
+    unsigned max_nodes;
+    unsigned num_outputs;
+
+    TreeEnsemble() {}
+    TreeEnsemble(int *children_left, int *children_right, int *children_default, int *features,
+                 tfloat *thresholds, tfloat *values, tfloat *node_sample_weights,
+                 unsigned max_depth, unsigned tree_limit, tfloat *base_offset,
+                 unsigned max_nodes, unsigned num_outputs) :
+        children_left(children_left), children_right(children_right),
+        children_default(children_default), features(features), thresholds(thresholds),
+        values(values), node_sample_weights(node_sample_weights),
+        max_depth(max_depth), tree_limit(tree_limit),
+        base_offset(base_offset), max_nodes(max_nodes), num_outputs(num_outputs) {}
+
+    void get_tree(TreeEnsemble &tree, const unsigned i) const {
+        const unsigned d = i * max_nodes;
+
+        tree.children_left = children_left + d;
+        tree.children_right = children_right + d;
+        tree.children_default = children_default + d;
+        tree.features = features + d;
+        tree.thresholds = thresholds + d;
+        tree.values = values + d * num_outputs;
+        tree.node_sample_weights = node_sample_weights + d;
+        tree.max_depth = max_depth;
+        tree.tree_limit = 1;
+        tree.base_offset = base_offset;
+        tree.max_nodes = max_nodes;
+        tree.num_outputs = num_outputs;
+    }
+
+    bool is_leaf(unsigned pos) const {
+        return children_left[pos] < 0;
+    }
+
+    void allocate(unsigned tree_limit_in, unsigned max_nodes_in, unsigned num_outputs_in) {
+        tree_limit = tree_limit_in;
+        max_nodes = max_nodes_in;
+        num_outputs = num_outputs_in;
+        children_left = new int[tree_limit * max_nodes];
+        children_right = new int[tree_limit * max_nodes];
+        children_default = new int[tree_limit * max_nodes];
+        features = new int[tree_limit * max_nodes];
+        thresholds = new tfloat[tree_limit * max_nodes];
+        values = new tfloat[tree_limit * max_nodes * num_outputs];
+        node_sample_weights = new tfloat[tree_limit * max_nodes];
+    }
+
+    void free() {
+        delete[] children_left;
+        delete[] children_right;
+        delete[] children_default;
+        delete[] features;
+        delete[] thresholds;
+        delete[] values;
+        delete[] node_sample_weights;
+    }
+};
+
+struct ExplanationDataset {
+    tfloat *X;
+    bool *X_missing;
+    tfloat *y;
+    tfloat *R;
+    bool *R_missing;
+    unsigned num_X;
+    unsigned M;
+    unsigned num_R;
+
+    ExplanationDataset() {}
+    ExplanationDataset(tfloat *X, bool *X_missing, tfloat *y, tfloat *R, bool *R_missing, unsigned num_X,
+                       unsigned M, unsigned num_R) :
+        X(X), X_missing(X_missing), y(y), R(R), R_missing(R_missing), num_X(num_X), M(M), num_R(num_R) {}
+
+    void get_x_instance(ExplanationDataset &instance, const unsigned i) const {
+        instance.M = M;
+        instance.X = X + i * M;
+        instance.X_missing = X_missing + i * M;
+        instance.num_X = 1;
+    }
+};
+
+
+// data we keep about our decision path
+// note that pweight is included for
convenience and is not tied with the other attributes +// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them +struct PathElement { + int feature_index; + tfloat zero_fraction; + tfloat one_fraction; + tfloat pweight; + PathElement() {} + PathElement(int i, tfloat z, tfloat o, tfloat w) : + feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {} +}; + +inline tfloat logistic_transform(const tfloat margin, const tfloat y) { + return 1 / (1 + exp(-margin)); +} + +inline tfloat logistic_nlogloss_transform(const tfloat margin, const tfloat y) { + return log(1 + exp(margin)) - y * margin; // y is in {0, 1} +} + +inline tfloat squared_loss_transform(const tfloat margin, const tfloat y) { + return (margin - y) * (margin - y); +} + +namespace MODEL_TRANSFORM { + const unsigned identity = 0; + const unsigned logistic = 1; + const unsigned logistic_nlogloss = 2; + const unsigned squared_loss = 3; +} + +inline transform_f get_transform(unsigned model_transform) { + transform_f transform = NULL; + switch (model_transform) { + case MODEL_TRANSFORM::logistic: + transform = logistic_transform; + break; + + case MODEL_TRANSFORM::logistic_nlogloss: + transform = logistic_nlogloss_transform; + break; + + case MODEL_TRANSFORM::squared_loss: + transform = squared_loss_transform; + break; + } + + return transform; +} + +inline tfloat *tree_predict(unsigned i, const TreeEnsemble &trees, const tfloat *x, const bool *x_missing) { + const unsigned offset = i * trees.max_nodes; + unsigned node = 0; + while (true) { + const unsigned pos = offset + node; + const unsigned feature = trees.features[pos]; + + // we hit a leaf so return a pointer to the values + if (trees.is_leaf(pos)) { + return trees.values + pos * trees.num_outputs; + } + + // otherwise we are at an internal node and need to recurse + if (x_missing[feature]) { + node = trees.children_default[pos]; + } else if (x[feature] <= trees.thresholds[pos]) { + node = trees.children_left[pos]; + } else { + node = trees.children_right[pos]; + } + } +} + +inline void dense_tree_predict(tfloat *out, const TreeEnsemble &trees, const ExplanationDataset &data, unsigned model_transform) { + tfloat *row_out = out; + const tfloat *x = data.X; + const bool *x_missing = data.X_missing; + + // see what transform (if any) we have + transform_f transform = get_transform(model_transform); + + for (unsigned i = 0; i < data.num_X; ++i) { + + // add the base offset + for (unsigned k = 0; k < trees.num_outputs; ++k) { + row_out[k] += trees.base_offset[k]; + } + + // add the leaf values from each tree + for (unsigned j = 0; j < trees.tree_limit; ++j) { + const tfloat *leaf_value = tree_predict(j, trees, x, x_missing); + + for (unsigned k = 0; k < trees.num_outputs; ++k) { + row_out[k] += leaf_value[k]; + } + } + + // apply any needed transform + if (transform != NULL) { + const tfloat y_i = data.y == NULL ? 
0 : data.y[i]; + for (unsigned k = 0; k < trees.num_outputs; ++k) { + row_out[k] = transform(row_out[k], y_i); + } + } + + x += data.M; + x_missing += data.M; + row_out += trees.num_outputs; + } +} + +inline void tree_update_weights(unsigned i, TreeEnsemble &trees, const tfloat *x, const bool *x_missing) { + const unsigned offset = i * trees.max_nodes; + unsigned node = 0; + while (true) { + const unsigned pos = offset + node; + const unsigned feature = trees.features[pos]; + + // Record that a sample passed through this node + trees.node_sample_weights[pos] += 1.0; + + // we hit a leaf so return a pointer to the values + if (trees.children_left[pos] < 0) break; + + // otherwise we are at an internal node and need to recurse + if (x_missing[feature]) { + node = trees.children_default[pos]; + } else if (x[feature] <= trees.thresholds[pos]) { + node = trees.children_left[pos]; + } else { + node = trees.children_right[pos]; + } + } +} + +inline void dense_tree_update_weights(TreeEnsemble &trees, const ExplanationDataset &data) { + const tfloat *x = data.X; + const bool *x_missing = data.X_missing; + + for (unsigned i = 0; i < data.num_X; ++i) { + + // add the leaf values from each tree + for (unsigned j = 0; j < trees.tree_limit; ++j) { + tree_update_weights(j, trees, x, x_missing); + } + + x += data.M; + x_missing += data.M; + } +} + +inline void tree_saabas(tfloat *out, const TreeEnsemble &tree, const ExplanationDataset &data) { + unsigned curr_node = 0; + unsigned next_node = 0; + while (true) { + + // we hit a leaf and are done + if (tree.children_left[curr_node] < 0) return; + + // otherwise we are at an internal node and need to recurse + const unsigned feature = tree.features[curr_node]; + if (data.X_missing[feature]) { + next_node = tree.children_default[curr_node]; + } else if (data.X[feature] <= tree.thresholds[curr_node]) { + next_node = tree.children_left[curr_node]; + } else { + next_node = tree.children_right[curr_node]; + } + + // assign credit to this feature as the difference in values at the current node vs. the next node + for (unsigned i = 0; i < tree.num_outputs; ++i) { + out[feature * tree.num_outputs + i] += tree.values[next_node * tree.num_outputs + i] - tree.values[curr_node * tree.num_outputs + i]; + } + + curr_node = next_node; + } +} + +/** + * This runs Tree SHAP with a per tree path conditional dependence assumption. 
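+ *
+ * Note: each tree is explained with the greedy Saabas attribution implemented
+ * in tree_saabas above, so the summed result is the fast Saabas approximation
+ * to the path-dependent SHAP values rather than the exact computation.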
+ */ +inline void dense_tree_saabas(tfloat *out_contribs, const TreeEnsemble& trees, const ExplanationDataset &data) { + tfloat *instance_out_contribs; + TreeEnsemble tree; + ExplanationDataset instance; + + // build explanation for each sample + for (unsigned i = 0; i < data.num_X; ++i) { + instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs; + data.get_x_instance(instance, i); + + // aggregate the effect of explaining each tree + // (this works because of the linearity property of Shapley values) + for (unsigned j = 0; j < trees.tree_limit; ++j) { + trees.get_tree(tree, j); + tree_saabas(instance_out_contribs, tree, instance); + } + + // apply the base offset to the bias term + for (unsigned j = 0; j < trees.num_outputs; ++j) { + instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j]; + } + } +} + + +// extend our decision path with a fraction of one and zero extensions +inline void extend_path(PathElement *unique_path, unsigned unique_depth, + tfloat zero_fraction, tfloat one_fraction, int feature_index) { + unique_path[unique_depth].feature_index = feature_index; + unique_path[unique_depth].zero_fraction = zero_fraction; + unique_path[unique_depth].one_fraction = one_fraction; + unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f); + for (int i = unique_depth - 1; i >= 0; i--) { + unique_path[i + 1].pweight += one_fraction * unique_path[i].pweight * (i + 1) + / static_cast(unique_depth + 1); + unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i) + / static_cast(unique_depth + 1); + } +} + +// undo a previous extension of the decision path +inline void unwind_path(PathElement *unique_path, unsigned unique_depth, unsigned path_index) { + const tfloat one_fraction = unique_path[path_index].one_fraction; + const tfloat zero_fraction = unique_path[path_index].zero_fraction; + tfloat next_one_portion = unique_path[unique_depth].pweight; + + for (int i = unique_depth - 1; i >= 0; --i) { + if (one_fraction != 0) { + const tfloat tmp = unique_path[i].pweight; + unique_path[i].pweight = next_one_portion * (unique_depth + 1) + / static_cast((i + 1) * one_fraction); + next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i) + / static_cast(unique_depth + 1); + } else { + unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1)) + / static_cast(zero_fraction * (unique_depth - i)); + } + } + + for (unsigned i = path_index; i < unique_depth; ++i) { + unique_path[i].feature_index = unique_path[i+1].feature_index; + unique_path[i].zero_fraction = unique_path[i+1].zero_fraction; + unique_path[i].one_fraction = unique_path[i+1].one_fraction; + } +} + +// determine what the total permutation weight would be if +// we unwound a previous extension in the decision path +inline tfloat unwound_path_sum(const PathElement *unique_path, unsigned unique_depth, + unsigned path_index) { + const tfloat one_fraction = unique_path[path_index].one_fraction; + const tfloat zero_fraction = unique_path[path_index].zero_fraction; + tfloat next_one_portion = unique_path[unique_depth].pweight; + tfloat total = 0; + + if (one_fraction != 0) { + for (int i = unique_depth - 1; i >= 0; --i) { + const tfloat tmp = next_one_portion / static_cast((i + 1) * one_fraction); + total += tmp; + next_one_portion = unique_path[i].pweight - tmp * zero_fraction * (unique_depth - i); + } + } else { + for (int i = unique_depth - 1; i >= 0; --i) { + total += unique_path[i].pweight / (zero_fraction * 
(unique_depth - i)); + } + } + return total * (unique_depth + 1); +} + +// recursive computation of SHAP values for a decision tree +inline void tree_shap_recursive(const unsigned num_outputs, const int *children_left, + const int *children_right, + const int *children_default, const int *features, + const tfloat *thresholds, const tfloat *values, + const tfloat *node_sample_weight, + const tfloat *x, const bool *x_missing, tfloat *phi, + unsigned node_index, unsigned unique_depth, + PathElement *parent_unique_path, tfloat parent_zero_fraction, + tfloat parent_one_fraction, int parent_feature_index, + int condition, unsigned condition_feature, + tfloat condition_fraction) { + + // stop if we have no weight coming down to us + if (condition_fraction == 0) return; + + // extend the unique path + PathElement *unique_path = parent_unique_path + unique_depth + 1; + std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path); + + if (condition == 0 || condition_feature != static_cast(parent_feature_index)) { + extend_path(unique_path, unique_depth, parent_zero_fraction, + parent_one_fraction, parent_feature_index); + } + const unsigned split_index = features[node_index]; + + // leaf node + if (children_right[node_index] < 0) { + for (unsigned i = 1; i <= unique_depth; ++i) { + const tfloat w = unwound_path_sum(unique_path, unique_depth, i); + const PathElement &el = unique_path[i]; + const unsigned phi_offset = el.feature_index * num_outputs; + const unsigned values_offset = node_index * num_outputs; + const tfloat scale = w * (el.one_fraction - el.zero_fraction) * condition_fraction; + for (unsigned j = 0; j < num_outputs; ++j) { + phi[phi_offset + j] += scale * values[values_offset + j]; + } + } + + // internal node + } else { + // find which branch is "hot" (meaning x would follow it) + unsigned hot_index = 0; + if (x_missing[split_index]) { + hot_index = children_default[node_index]; + } else if (x[split_index] <= thresholds[node_index]) { + hot_index = children_left[node_index]; + } else { + hot_index = children_right[node_index]; + } + const unsigned cold_index = (static_cast(hot_index) == children_left[node_index] ? 
+ children_right[node_index] : children_left[node_index]); + const tfloat w = node_sample_weight[node_index]; + const tfloat hot_zero_fraction = node_sample_weight[hot_index] / w; + const tfloat cold_zero_fraction = node_sample_weight[cold_index] / w; + tfloat incoming_zero_fraction = 1; + tfloat incoming_one_fraction = 1; + + // see if we have already split on this feature, + // if so we undo that split so we can redo it for this node + unsigned path_index = 0; + for (; path_index <= unique_depth; ++path_index) { + if (static_cast(unique_path[path_index].feature_index) == split_index) break; + } + if (path_index != unique_depth + 1) { + incoming_zero_fraction = unique_path[path_index].zero_fraction; + incoming_one_fraction = unique_path[path_index].one_fraction; + unwind_path(unique_path, unique_depth, path_index); + unique_depth -= 1; + } + + // divide up the condition_fraction among the recursive calls + tfloat hot_condition_fraction = condition_fraction; + tfloat cold_condition_fraction = condition_fraction; + if (condition > 0 && split_index == condition_feature) { + cold_condition_fraction = 0; + unique_depth -= 1; + } else if (condition < 0 && split_index == condition_feature) { + hot_condition_fraction *= hot_zero_fraction; + cold_condition_fraction *= cold_zero_fraction; + unique_depth -= 1; + } + + tree_shap_recursive( + num_outputs, children_left, children_right, children_default, features, thresholds, values, + node_sample_weight, x, x_missing, phi, hot_index, unique_depth + 1, unique_path, + hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction, + split_index, condition, condition_feature, hot_condition_fraction + ); + + tree_shap_recursive( + num_outputs, children_left, children_right, children_default, features, thresholds, values, + node_sample_weight, x, x_missing, phi, cold_index, unique_depth + 1, unique_path, + cold_zero_fraction * incoming_zero_fraction, 0, + split_index, condition, condition_feature, cold_condition_fraction + ); + } +} + +inline int compute_expectations(TreeEnsemble &tree, int i = 0, int depth = 0) { + unsigned max_depth = 0; + + if (tree.children_right[i] >= 0) { + const unsigned li = tree.children_left[i]; + const unsigned ri = tree.children_right[i]; + const unsigned depth_left = compute_expectations(tree, li, depth + 1); + const unsigned depth_right = compute_expectations(tree, ri, depth + 1); + const tfloat left_weight = tree.node_sample_weights[li]; + const tfloat right_weight = tree.node_sample_weights[ri]; + const unsigned li_offset = li * tree.num_outputs; + const unsigned ri_offset = ri * tree.num_outputs; + const unsigned i_offset = i * tree.num_outputs; + for (unsigned j = 0; j < tree.num_outputs; ++j) { + if ((left_weight == 0) && (right_weight == 0)) { + tree.values[i_offset + j] = 0.0; + } else { + const tfloat v = (left_weight * tree.values[li_offset + j] + right_weight * tree.values[ri_offset + j]) / (left_weight + right_weight); + tree.values[i_offset + j] = v; + } + } + max_depth = std::max(depth_left, depth_right) + 1; + } + + if (depth == 0) tree.max_depth = max_depth; + + return max_depth; +} + +inline void tree_shap(const TreeEnsemble& tree, const ExplanationDataset &data, + tfloat *out_contribs, int condition, unsigned condition_feature) { + + // update the reference value with the expected value of the tree's predictions + if (condition == 0) { + for (unsigned j = 0; j < tree.num_outputs; ++j) { + out_contribs[data.M * tree.num_outputs + j] += tree.values[j]; + } + } + + // Pre-allocate space for the unique path 
data + const unsigned maxd = tree.max_depth + 2; // need a bit more space than the max depth + PathElement *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2]; + + tree_shap_recursive( + tree.num_outputs, tree.children_left, tree.children_right, tree.children_default, + tree.features, tree.thresholds, tree.values, tree.node_sample_weights, data.X, + data.X_missing, out_contribs, 0, 0, unique_path_data, 1, 1, -1, condition, + condition_feature, 1 + ); + + delete[] unique_path_data; +} + + +inline unsigned build_merged_tree_recursive(TreeEnsemble &out_tree, const TreeEnsemble &trees, + const tfloat *data, const bool *data_missing, int *data_inds, + const unsigned num_background_data_inds, unsigned num_data_inds, + unsigned M, unsigned row = 0, unsigned i = 0, unsigned pos = 0, + tfloat *leaf_value = NULL) { + //tfloat new_leaf_value[trees.num_outputs]; + tfloat *new_leaf_value = (tfloat *) alloca(sizeof(tfloat) * trees.num_outputs); // allocate on the stack + unsigned row_offset = row * trees.max_nodes; + + // we have hit a terminal leaf!!! + if (trees.children_left[row_offset + i] < 0 && row + 1 == trees.tree_limit) { + + // create the leaf node + const tfloat *vals = trees.values + (row * trees.max_nodes + i) * trees.num_outputs; + if (leaf_value == NULL) { + for (unsigned j = 0; j < trees.num_outputs; ++j) { + out_tree.values[pos * trees.num_outputs + j] = vals[j]; + } + } else { + for (unsigned j = 0; j < trees.num_outputs; ++j) { + out_tree.values[pos * trees.num_outputs + j] = leaf_value[j] + vals[j]; + } + } + out_tree.children_left[pos] = -1; + out_tree.children_right[pos] = -1; + out_tree.children_default[pos] = -1; + out_tree.features[pos] = -1; + out_tree.thresholds[pos] = 0; + out_tree.node_sample_weights[pos] = num_background_data_inds; + + return pos; + } + + // we hit an intermediate leaf (so just add the value to our accumulator and move to the next tree) + if (trees.children_left[row_offset + i] < 0) { + + // accumulate the value of this original leaf so it will land on all eventual terminal leaves + const tfloat *vals = trees.values + (row * trees.max_nodes + i) * trees.num_outputs; + if (leaf_value == NULL) { + for (unsigned j = 0; j < trees.num_outputs; ++j) { + new_leaf_value[j] = vals[j]; + } + } else { + for (unsigned j = 0; j < trees.num_outputs; ++j) { + new_leaf_value[j] = leaf_value[j] + vals[j]; + } + } + leaf_value = new_leaf_value; + + // move forward to the next tree + row += 1; + row_offset += trees.max_nodes; + i = 0; + } + + // split the data inds by this node's threshold + const tfloat t = trees.thresholds[row_offset + i]; + const int f = trees.features[row_offset + i]; + const bool right_default = trees.children_default[row_offset + i] == trees.children_right[row_offset + i]; + int low_ptr = 0; + int high_ptr = num_data_inds - 1; + unsigned num_left_background_data_inds = 0; + int low_data_ind; + while (low_ptr <= high_ptr) { + low_data_ind = data_inds[low_ptr]; + const int data_ind = std::abs(low_data_ind) * M + f; + const bool is_missing = data_missing[data_ind]; + if ((!is_missing && data[data_ind] > t) || (right_default && is_missing)) { + data_inds[low_ptr] = data_inds[high_ptr]; + data_inds[high_ptr] = low_data_ind; + high_ptr -= 1; + } else { + if (low_data_ind >= 0) ++num_left_background_data_inds; // negative data_inds are not background samples + low_ptr += 1; + } + } + int *left_data_inds = data_inds; + const unsigned num_left_data_inds = low_ptr; + int *right_data_inds = data_inds + low_ptr; + const unsigned num_right_data_inds = 
num_data_inds - num_left_data_inds; + const unsigned num_right_background_data_inds = num_background_data_inds - num_left_background_data_inds; + + // all the data went right, so we skip creating this node and just recurse right + if (num_left_data_inds == 0) { + return build_merged_tree_recursive( + out_tree, trees, data, data_missing, data_inds, + num_background_data_inds, num_data_inds, M, row, + trees.children_right[row_offset + i], pos, leaf_value + ); + + // all the data went left, so we skip creating this node and just recurse left + } else if (num_right_data_inds == 0) { + return build_merged_tree_recursive( + out_tree, trees, data, data_missing, data_inds, + num_background_data_inds, num_data_inds, M, row, + trees.children_left[row_offset + i], pos, leaf_value + ); + + // data went both ways so we create this node and recurse down both paths + } else { + + // build the left subtree + const unsigned new_pos = build_merged_tree_recursive( + out_tree, trees, data, data_missing, left_data_inds, + num_left_background_data_inds, num_left_data_inds, M, row, + trees.children_left[row_offset + i], pos + 1, leaf_value + ); + + // fill in the data for this node + out_tree.children_left[pos] = pos + 1; + out_tree.children_right[pos] = new_pos + 1; + if (trees.children_left[row_offset + i] == trees.children_default[row_offset + i]) { + out_tree.children_default[pos] = pos + 1; + } else { + out_tree.children_default[pos] = new_pos + 1; + } + + out_tree.features[pos] = trees.features[row_offset + i]; + out_tree.thresholds[pos] = trees.thresholds[row_offset + i]; + out_tree.node_sample_weights[pos] = num_background_data_inds; + + // build the right subtree + return build_merged_tree_recursive( + out_tree, trees, data, data_missing, right_data_inds, + num_right_background_data_inds, num_right_data_inds, M, row, + trees.children_right[row_offset + i], new_pos + 1, leaf_value + ); + } +} + + +inline void build_merged_tree(TreeEnsemble &out_tree, const ExplanationDataset &data, const TreeEnsemble &trees) { + + // create a joint data matrix from both X and R matrices + tfloat *joined_data = new tfloat[(data.num_X + data.num_R) * data.M]; + std::copy(data.X, data.X + data.num_X * data.M, joined_data); + std::copy(data.R, data.R + data.num_R * data.M, joined_data + data.num_X * data.M); + bool *joined_data_missing = new bool[(data.num_X + data.num_R) * data.M]; + std::copy(data.X_missing, data.X_missing + data.num_X * data.M, joined_data_missing); + std::copy(data.R_missing, data.R_missing + data.num_R * data.M, joined_data_missing + data.num_X * data.M); + + // create an starting array of data indexes we will recursively sort + int *data_inds = new int[data.num_X + data.num_R]; + for (unsigned i = 0; i < data.num_X; ++i) data_inds[i] = i; + for (unsigned i = data.num_X; i < data.num_X + data.num_R; ++i) { + data_inds[i] = -i; // a negative index means it won't be recorded as a background sample + } + + build_merged_tree_recursive( + out_tree, trees, joined_data, joined_data_missing, data_inds, data.num_R, + data.num_X + data.num_R, data.M + ); + + delete[] joined_data; + delete[] joined_data_missing; + delete[] data_inds; +} + + +// Independent Tree SHAP functions below here +// ------------------------------------------ +struct Node { + short cl, cr, cd, pnode, feat, pfeat; // uint_16 + float thres, value; + char from_flag; +}; + +#define FROM_NEITHER 0 +#define FROM_X_NOT_R 1 +#define FROM_R_NOT_X 2 + +// https://www.geeksforgeeks.org/space-and-time-efficient-binomial-coefficient/ +inline int 
bin_coeff(int n, int k) { + int res = 1; + if (k > n - k) + k = n - k; + for (int i = 0; i < k; ++i) { + res *= (n - i); + res /= (i + 1); + } + return res; +} + +// note this only handles single output models, so multi-output models get explained using multiple passes +inline void tree_shap_indep(const unsigned max_depth, const unsigned num_feats, + const unsigned num_nodes, const tfloat *x, + const bool *x_missing, const tfloat *r, + const bool *r_missing, tfloat *out_contribs, + float *pos_lst, float *neg_lst, signed short *feat_hist, + float *memoized_weights, int *node_stack, Node *mytree) { + +// const bool DEBUG = true; +// ofstream myfile; +// if (DEBUG) { +// myfile.open ("/homes/gws/hughchen/shap/out.txt",fstream::app); +// myfile << "Entering tree_shap_indep\n"; +// } + int ns_ctr = 0; + std::fill_n(feat_hist, num_feats, 0); + short node = 0, feat, cl, cr, cd, pnode, pfeat = -1; + short next_xnode = -1, next_rnode = -1; + short next_node = -1, from_child = -1; + float thres, pos_x = 0, neg_x = 0, pos_r = 0, neg_r = 0; + char from_flag; + unsigned M = 0, N = 0; + + Node curr_node = mytree[node]; + feat = curr_node.feat; + thres = curr_node.thres; + cl = curr_node.cl; + cr = curr_node.cr; + cd = curr_node.cd; + + // short circuit when this is a stump tree (with no splits) + if (cl < 0) { + out_contribs[num_feats] += curr_node.value; + return; + } + +// if (DEBUG) { +// myfile << "\nNode: " << node << "\n"; +// myfile << "x[feat]: " << x[feat] << ", r[feat]: " << r[feat] << "\n"; +// myfile << "thres: " << thres << "\n"; +// } + + if (x_missing[feat]) { + next_xnode = cd; + } else if (x[feat] > thres) { + next_xnode = cr; + } else if (x[feat] <= thres) { + next_xnode = cl; + } + + if (r_missing[feat]) { + next_rnode = cd; + } else if (r[feat] > thres) { + next_rnode = cr; + } else if (r[feat] <= thres) { + next_rnode = cl; + } + + if (next_xnode != next_rnode) { + mytree[next_xnode].from_flag = FROM_X_NOT_R; + mytree[next_rnode].from_flag = FROM_R_NOT_X; + } else { + mytree[next_xnode].from_flag = FROM_NEITHER; + } + + // Check if x and r go the same way + if (next_xnode == next_rnode) { + next_node = next_xnode; + } + + // If not, go left + if (next_node < 0) { + next_node = cl; + if (next_rnode == next_node) { // rpath + N = N+1; + feat_hist[feat] -= 1; + } else if (next_xnode == next_node) { // xpath + M = M+1; + N = N+1; + feat_hist[feat] += 1; + } + } + node_stack[ns_ctr] = node; + ns_ctr += 1; + while (true) { + node = next_node; + curr_node = mytree[node]; + feat = curr_node.feat; + thres = curr_node.thres; + cl = curr_node.cl; + cr = curr_node.cr; + cd = curr_node.cd; + pnode = curr_node.pnode; + pfeat = curr_node.pfeat; + from_flag = curr_node.from_flag; + + + +// if (DEBUG) { +// myfile << "\nNode: " << node << "\n"; +// myfile << "N: " << N << ", M: " << M << "\n"; +// myfile << "from_flag==FROM_X_NOT_R: " << (from_flag==FROM_X_NOT_R) << "\n"; +// myfile << "from_flag==FROM_R_NOT_X: " << (from_flag==FROM_R_NOT_X) << "\n"; +// myfile << "from_flag==FROM_NEITHER: " << (from_flag==FROM_NEITHER) << "\n"; +// myfile << "feat_hist[feat]: " << feat_hist[feat] << "\n"; +// } + + // At a leaf + if (cl < 0) { + // if (DEBUG) { + // myfile << "At a leaf\n"; + // } + + if (M == 0) { + out_contribs[num_feats] += mytree[node].value; + } + + // Currently assuming a single output + if (N != 0) { + if (M != 0) { + pos_lst[node] = mytree[node].value * memoized_weights[N + max_depth * (M-1)]; + } + if (M != N) { + neg_lst[node] = -mytree[node].value * memoized_weights[N + max_depth * M]; + 
} + } +// if (DEBUG) { +// myfile << "pos_lst[node]: " << pos_lst[node] << "\n"; +// myfile << "neg_lst[node]: " << neg_lst[node] << "\n"; +// } + // Pop from node_stack + ns_ctr -= 1; + next_node = node_stack[ns_ctr]; + from_child = node; + // Unwind + if (feat_hist[pfeat] > 0) { + feat_hist[pfeat] -= 1; + } else if (feat_hist[pfeat] < 0) { + feat_hist[pfeat] += 1; + } + if (feat_hist[pfeat] == 0) { + if (from_flag == FROM_X_NOT_R) { + N = N-1; + M = M-1; + } else if (from_flag == FROM_R_NOT_X) { + N = N-1; + } + } + continue; + } + + const bool x_right = x[feat] > thres; + const bool r_right = r[feat] > thres; + + if (x_missing[feat]) { + next_xnode = cd; + } else if (x_right) { + next_xnode = cr; + } else if (!x_right) { + next_xnode = cl; + } + + if (r_missing[feat]) { + next_rnode = cd; + } else if (r_right) { + next_rnode = cr; + } else if (!r_right) { + next_rnode = cl; + } + + if (next_xnode >= 0) { + if (next_xnode != next_rnode) { + mytree[next_xnode].from_flag = FROM_X_NOT_R; + mytree[next_rnode].from_flag = FROM_R_NOT_X; + } else { + mytree[next_xnode].from_flag = FROM_NEITHER; + } + } + + // Arriving at node from parent + if (from_child == -1) { + // if (DEBUG) { + // myfile << "Arriving at node from parent\n"; + // } + node_stack[ns_ctr] = node; + ns_ctr += 1; + next_node = -1; + + // if (DEBUG) { + // myfile << "feat_hist[feat]" << feat_hist[feat] << "\n"; + // } + // Feature is set upstream + if (feat_hist[feat] > 0) { + next_node = next_xnode; + feat_hist[feat] += 1; + } else if (feat_hist[feat] < 0) { + next_node = next_rnode; + feat_hist[feat] -= 1; + } + + // x and r go the same way + if (next_node < 0) { + if (next_xnode == next_rnode) { + next_node = next_xnode; + } + } + + // Go down one path + if (next_node >= 0) { + continue; + } + + // Go down both paths, but go left first + next_node = cl; + if (next_rnode == next_node) { + N = N+1; + feat_hist[feat] -= 1; + } else if (next_xnode == next_node) { + M = M+1; + N = N+1; + feat_hist[feat] += 1; + } + from_child = -1; + continue; + } + + // Arriving at node from child + if (from_child != -1) { +// if (DEBUG) { +// myfile << "Arriving at node from child\n"; +// } + next_node = -1; + // Check if we should unroll immediately + if ((next_rnode == next_xnode) || (feat_hist[feat] != 0)) { + next_node = pnode; + } + + // Came from a single path, so unroll + if (next_node >= 0) { +// if (DEBUG) { +// myfile << "Came from a single path, so unroll\n"; +// } + // At the root node + if (node == 0) { + break; + } + // Update and unroll + pos_lst[node] = pos_lst[from_child]; + neg_lst[node] = neg_lst[from_child]; + +// if (DEBUG) { +// myfile << "pos_lst[node]: " << pos_lst[node] << "\n"; +// myfile << "neg_lst[node]: " << neg_lst[node] << "\n"; +// } + from_child = node; + ns_ctr -= 1; + + // Unwind + if (feat_hist[pfeat] > 0) { + feat_hist[pfeat] -= 1; + } else if (feat_hist[pfeat] < 0) { + feat_hist[pfeat] += 1; + } + if (feat_hist[pfeat] == 0) { + if (from_flag == FROM_X_NOT_R) { + N = N-1; + M = M-1; + } else if (from_flag == FROM_R_NOT_X) { + N = N-1; + } + } + continue; + // Go right - Arriving from the left child + } else if (from_child == cl) { +// if (DEBUG) { +// myfile << "Go right - Arriving from the left child\n"; +// } + node_stack[ns_ctr] = node; + ns_ctr += 1; + next_node = cr; + if (next_xnode == next_node) { + M = M+1; + N = N+1; + feat_hist[feat] += 1; + } else if (next_rnode == next_node) { + N = N+1; + feat_hist[feat] -= 1; + } + from_child = -1; + continue; + // Compute stuff and unroll - Arriving from the 
right child + } else if (from_child == cr) { +// if (DEBUG) { +// myfile << "Compute stuff and unroll - Arriving from the right child\n"; +// } + pos_x = 0; + neg_x = 0; + pos_r = 0; + neg_r = 0; + if ((next_xnode == cr) && (next_rnode == cl)) { + pos_x = pos_lst[cr]; + neg_x = neg_lst[cr]; + pos_r = pos_lst[cl]; + neg_r = neg_lst[cl]; + } else if ((next_xnode == cl) && (next_rnode == cr)) { + pos_x = pos_lst[cl]; + neg_x = neg_lst[cl]; + pos_r = pos_lst[cr]; + neg_r = neg_lst[cr]; + } + // out_contribs needs to have been initialized as all zeros + // if (pos_x + neg_r != 0) { + // std::cout << "val " << pos_x + neg_r << "\n"; + // } + out_contribs[feat] += pos_x + neg_r; + pos_lst[node] = pos_x + pos_r; + neg_lst[node] = neg_x + neg_r; + +// if (DEBUG) { +// myfile << "out_contribs[feat]: " << out_contribs[feat] << "\n"; +// myfile << "pos_lst[node]: " << pos_lst[node] << "\n"; +// myfile << "neg_lst[node]: " << neg_lst[node] << "\n"; +// } + + // Check if at root + if (node == 0) { + break; + } + + // Pop + ns_ctr -= 1; + next_node = node_stack[ns_ctr]; + from_child = node; + + // Unwind + if (feat_hist[pfeat] > 0) { + feat_hist[pfeat] -= 1; + } else if (feat_hist[pfeat] < 0) { + feat_hist[pfeat] += 1; + } + if (feat_hist[pfeat] == 0) { + if (from_flag == FROM_X_NOT_R) { + N = N-1; + M = M-1; + } else if (from_flag == FROM_R_NOT_X) { + N = N-1; + } + } + continue; + } + } + } + // if (DEBUG) { + // myfile.close(); + // } +} + + +inline void print_progress_bar(tfloat &last_print, tfloat start_time, unsigned i, unsigned total_count) { + const tfloat elapsed_seconds = difftime(time(NULL), start_time); + + if (elapsed_seconds > 10 && elapsed_seconds - last_print > 0.5) { + const tfloat fraction = static_cast(i) / total_count; + const double total_seconds = elapsed_seconds / fraction; + last_print = elapsed_seconds; + + PySys_WriteStderr( + "\r%3.0f%%|%.*s%.*s| %d/%d [%02d:%02d<%02d:%02d] ", + fraction * 100, int(0.5 + fraction*20), "===================", + 20-int(0.5 + fraction*20), " ", + i, total_count, + int(elapsed_seconds/60), int(elapsed_seconds) % 60, + int((total_seconds - elapsed_seconds)/60), int(total_seconds - elapsed_seconds) % 60 + ); + + // Get handle to python stderr file and flush it (https://mail.python.org/pipermail/python-list/2004-November/294912.html) + PyObject *pyStderr = PySys_GetObject("stderr"); + if (pyStderr) { + PyObject *result = PyObject_CallMethod(pyStderr, "flush", NULL); + Py_XDECREF(result); + } + } +} + +/** + * Runs Tree SHAP with feature independence assumptions on dense data. 
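+ *
+ * A minimal calling sketch (illustrative only; it assumes the TreeEnsemble
+ * arrays, max_depth, and the background set R/R_missing in the
+ * ExplanationDataset have already been populated, and that the output buffer
+ * starts zeroed):
+ *
+ *   std::vector<tfloat> contribs(data.num_X * (data.M + 1) * trees.num_outputs, 0);
+ *   dense_independent(trees, data, contribs.data(), NULL);  // NULL: no output transform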
+ */ +inline void dense_independent(const TreeEnsemble& trees, const ExplanationDataset &data, + tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) { + + // reformat the trees for faster access + Node *node_trees = new Node[trees.tree_limit * trees.max_nodes]; + for (unsigned i = 0; i < trees.tree_limit; ++i) { + Node *node_tree = node_trees + i * trees.max_nodes; + for (unsigned j = 0; j < trees.max_nodes; ++j) { + const unsigned en_ind = i * trees.max_nodes + j; + node_tree[j].cl = trees.children_left[en_ind]; + node_tree[j].cr = trees.children_right[en_ind]; + node_tree[j].cd = trees.children_default[en_ind]; + if (j == 0) { + node_tree[j].pnode = 0; + } + if (trees.children_left[en_ind] >= 0) { // relies on all unused entries having negative values in them + node_tree[trees.children_left[en_ind]].pnode = j; + node_tree[trees.children_left[en_ind]].pfeat = trees.features[en_ind]; + } + if (trees.children_right[en_ind] >= 0) { // relies on all unused entries having negative values in them + node_tree[trees.children_right[en_ind]].pnode = j; + node_tree[trees.children_right[en_ind]].pfeat = trees.features[en_ind]; + } + + node_tree[j].thres = trees.thresholds[en_ind]; + node_tree[j].feat = trees.features[en_ind]; + } + } + + // preallocate arrays needed by the algorithm + float *pos_lst = new float[trees.max_nodes]; + float *neg_lst = new float[trees.max_nodes]; + int *node_stack = new int[(unsigned) trees.max_depth]; + signed short *feat_hist = new signed short[data.M]; + tfloat *tmp_out_contribs = new tfloat[(data.M + 1)]; + + // precompute all the weight coefficients + float *memoized_weights = new float[(trees.max_depth+1) * (trees.max_depth+1)]; + for (unsigned n = 0; n <= trees.max_depth; ++n) { + for (unsigned m = 0; m <= trees.max_depth; ++m) { + memoized_weights[n + trees.max_depth * m] = 1.0 / (n * bin_coeff(n-1, m)); + } + } + + // compute the explanations for each sample + tfloat *instance_out_contribs; + tfloat rescale_factor = 1.0; + tfloat margin_x = 0; + tfloat margin_r = 0; + time_t start_time = time(NULL); + tfloat last_print = 0; + for (unsigned oind = 0; oind < trees.num_outputs; ++oind) { + // set the values in the reformatted tree to the current output index + for (unsigned i = 0; i < trees.tree_limit; ++i) { + Node *node_tree = node_trees + i * trees.max_nodes; + for (unsigned j = 0; j < trees.max_nodes; ++j) { + const unsigned en_ind = i * trees.max_nodes + j; + node_tree[j].value = trees.values[en_ind * trees.num_outputs + oind]; + } + } + + // loop over all the samples + for (unsigned i = 0; i < data.num_X; ++i) { + const tfloat *x = data.X + i * data.M; + const bool *x_missing = data.X_missing + i * data.M; + instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs; + const tfloat y_i = data.y == NULL ? 
0 : data.y[i]; + + print_progress_bar(last_print, start_time, oind * data.num_X + i, data.num_X * trees.num_outputs); + + // compute the model's margin output for x + if (transform != NULL) { + margin_x = trees.base_offset[oind]; + for (unsigned k = 0; k < trees.tree_limit; ++k) { + margin_x += tree_predict(k, trees, x, x_missing)[oind]; + } + } + + for (unsigned j = 0; j < data.num_R; ++j) { + const tfloat *r = data.R + j * data.M; + const bool *r_missing = data.R_missing + j * data.M; + std::fill_n(tmp_out_contribs, (data.M + 1), 0); + + // compute the model's margin output for r + if (transform != NULL) { + margin_r = trees.base_offset[oind]; + for (unsigned k = 0; k < trees.tree_limit; ++k) { + margin_r += tree_predict(k, trees, r, r_missing)[oind]; + } + } + + for (unsigned k = 0; k < trees.tree_limit; ++k) { + tree_shap_indep( + trees.max_depth, data.M, trees.max_nodes, x, x_missing, r, r_missing, + tmp_out_contribs, pos_lst, neg_lst, feat_hist, memoized_weights, + node_stack, node_trees + k * trees.max_nodes + ); + } + + // compute the rescale factor + if (transform != NULL) { + if (margin_x == margin_r) { + rescale_factor = 1.0; + } else { + rescale_factor = (*transform)(margin_x, y_i) - (*transform)(margin_r, y_i); + rescale_factor /= margin_x - margin_r; + } + } + + // add the effect of the current reference to our running total + // this is where we can do per reference scaling for non-linear transformations + for (unsigned k = 0; k < data.M; ++k) { + instance_out_contribs[k * trees.num_outputs + oind] += tmp_out_contribs[k] * rescale_factor; + } + + // Add the base offset + if (transform != NULL) { + instance_out_contribs[data.M * trees.num_outputs + oind] += (*transform)(trees.base_offset[oind] + tmp_out_contribs[data.M], 0); + } else { + instance_out_contribs[data.M * trees.num_outputs + oind] += trees.base_offset[oind] + tmp_out_contribs[data.M]; + } + } + + // average the results over all the references. + for (unsigned j = 0; j < (data.M + 1); ++j) { + instance_out_contribs[j * trees.num_outputs + oind] /= data.num_R; + } + + // apply the base offset to the bias term + // for (unsigned j = 0; j < trees.num_outputs; ++j) { + // instance_out_contribs[data.M * trees.num_outputs + j] += (*transform)(trees.base_offset[j], 0); + // } + } + } + + delete[] tmp_out_contribs; + delete[] node_trees; + delete[] pos_lst; + delete[] neg_lst; + delete[] node_stack; + delete[] feat_hist; + delete[] memoized_weights; +} + + +/** + * This runs Tree SHAP with a per tree path conditional dependence assumption. 
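+ *
+ * A note on why the per-tree loop below is sound: Shapley values are linear in
+ * the model, so for an ensemble f(x) = sum_t f_t(x) the attributions satisfy
+ * phi_i(f) = sum_t phi_i(f_t), letting each tree be explained on its own and
+ * the results accumulated into instance_out_contribs.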
+ */
+inline void dense_tree_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
+                                      tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
+    tfloat *instance_out_contribs;
+    TreeEnsemble tree;
+    ExplanationDataset instance;
+
+    // build explanation for each sample
+    for (unsigned i = 0; i < data.num_X; ++i) {
+        instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
+        data.get_x_instance(instance, i);
+
+        // aggregate the effect of explaining each tree
+        // (this works because of the linearity property of Shapley values)
+        for (unsigned j = 0; j < trees.tree_limit; ++j) {
+            trees.get_tree(tree, j);
+            tree_shap(tree, instance, instance_out_contribs, 0, 0);
+        }
+
+        // apply the base offset to the bias term
+        for (unsigned j = 0; j < trees.num_outputs; ++j) {
+            instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
+        }
+    }
+}
+
+// phi = np.zeros((self._current_X.shape[1] + 1, self._current_X.shape[1] + 1, self.n_outputs))
+// phi_diag = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
+// for t in range(self.tree_limit):
+//     self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_diag)
+//     for j in self.trees[t].unique_features:
+//         phi_on = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
+//         phi_off = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
+//         self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_on, 1, j)
+//         self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_off, -1, j)
+//         phi[j] += np.true_divide(np.subtract(phi_on,phi_off),2.0)
+//         phi_diag[j] -= np.sum(np.true_divide(np.subtract(phi_on,phi_off),2.0))
+// for j in range(self._current_X.shape[1]+1):
+//     phi[j][j] = phi_diag[j]
+// phi /= self.tree_limit
+// return phi
+
+inline void dense_tree_interactions_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
+                                                   tfloat *out_contribs,
+                                                   tfloat transform(const tfloat, const tfloat)) {
+
+    // build a list of all the unique features in each tree
+    const unsigned amount_of_unique_features = std::min(data.M, trees.max_nodes);
+    int *unique_features = new int[trees.tree_limit * amount_of_unique_features];
+    std::fill(unique_features, unique_features + trees.tree_limit * amount_of_unique_features, -1);
+    for (unsigned j = 0; j < trees.tree_limit; ++j) {
+        const int *features_row = trees.features + j * trees.max_nodes;
+        int *unique_features_row = unique_features + j * amount_of_unique_features;
+        for (unsigned k = 0; k < trees.max_nodes; ++k) {
+            for (unsigned l = 0; l < amount_of_unique_features; ++l) {
+                if (features_row[k] == unique_features_row[l]) break;
+                if (unique_features_row[l] < 0) {
+                    unique_features_row[l] = features_row[k];
+                    break;
+                }
+            }
+        }
+    }
+
+    // build an interaction explanation for each sample
+    tfloat *instance_out_contribs;
+    TreeEnsemble tree;
+    ExplanationDataset instance;
+    const unsigned contrib_row_size = (data.M + 1) * trees.num_outputs;
+    tfloat *diag_contribs = new tfloat[contrib_row_size];
+    tfloat *on_contribs = new tfloat[contrib_row_size];
+    tfloat *off_contribs = new tfloat[contrib_row_size];
+    for (unsigned i = 0; i < data.num_X; ++i) {
+        instance_out_contribs = out_contribs + i * (data.M + 1) * contrib_row_size;
+        data.get_x_instance(instance, i);
+
+        // aggregate the effect of explaining each tree
+        // (this works because of the linearity property of Shapley values)
+        std::fill(diag_contribs, diag_contribs + contrib_row_size, 0);
+        for (unsigned j = 0; j < trees.tree_limit; ++j) {
+            trees.get_tree(tree, j);
+            tree_shap(tree, instance, diag_contribs, 0, 0);
+
+            const int *unique_features_row = unique_features + j * amount_of_unique_features;
+            for (unsigned k = 0; k < amount_of_unique_features; ++k) {
+                const int ind = unique_features_row[k];
+                if (ind < 0) break; // < 0 means we have seen all the features for this tree
+
+                // compute the shap value with this feature held on and off
+                std::fill(on_contribs, on_contribs + contrib_row_size, 0);
+                std::fill(off_contribs, off_contribs + contrib_row_size, 0);
+                tree_shap(tree, instance, on_contribs, 1, ind);
+                tree_shap(tree, instance, off_contribs, -1, ind);
+
+                // save the difference between on and off as the interaction value
+                for (unsigned l = 0; l < contrib_row_size; ++l) {
+                    const tfloat val = (on_contribs[l] - off_contribs[l]) / 2;
+                    instance_out_contribs[ind * contrib_row_size + l] += val;
+                    diag_contribs[l] -= val;
+                }
+            }
+        }
+
+        // set the diagonal
+        for (unsigned j = 0; j < data.M + 1; ++j) {
+            const unsigned offset = j * contrib_row_size + j * trees.num_outputs;
+            for (unsigned k = 0; k < trees.num_outputs; ++k) {
+                instance_out_contribs[offset + k] = diag_contribs[j * trees.num_outputs + k];
+            }
+        }
+
+        // apply the base offset to the bias term
+        const unsigned last_ind = (data.M * (data.M + 1) + data.M) * trees.num_outputs;
+        for (unsigned j = 0; j < trees.num_outputs; ++j) {
+            instance_out_contribs[last_ind + j] += trees.base_offset[j];
+        }
+    }
+
+    delete[] diag_contribs;
+    delete[] on_contribs;
+    delete[] off_contribs;
+    delete[] unique_features;
+}
+
+/**
+ * This runs Tree SHAP with a global path conditional dependence assumption.
+ *
+ * By first merging all the trees in a tree ensemble into an equivalent single tree
+ * this method allows arbitrary marginal transformations and also ensures that all the
+ * evaluations of the model are consistent with some training data point.
+ */
+inline void dense_global_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
+                                        tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
+
+    // allocate space for our new merged tree (we save enough room to totally split all samples if need be)
+    TreeEnsemble merged_tree;
+    merged_tree.allocate(1, (data.num_X + data.num_R) * 2, trees.num_outputs);
+
+    // collapse the ensemble of trees into a single tree that has the same behavior
+    // for all the X and R samples in the dataset
+    build_merged_tree(merged_tree, data, trees);
+
+    // compute the expected value and depth of the new merged tree
+    compute_expectations(merged_tree);
+
+    // explain each sample using our new merged tree
+    ExplanationDataset instance;
+    tfloat *instance_out_contribs;
+    for (unsigned i = 0; i < data.num_X; ++i) {
+        instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
+        data.get_x_instance(instance, i);
+
+        // since we now just have a single merged tree we can just use the tree_path_dependent algorithm
+        tree_shap(merged_tree, instance, instance_out_contribs, 0, 0);
+
+        // apply the base offset to the bias term
+        for (unsigned j = 0; j < trees.num_outputs; ++j) {
+            instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
+        }
+    }
+
+    merged_tree.free();
+}
+
+
+/**
+ * The main method for computing Tree SHAP on models using dense data.
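+ *
+ * A hedged call sketch (the real caller lives in the Python binding layer, and
+ * the model_transform code 0 is assumed here to select the identity transform
+ * inside get_transform):
+ *
+ *     dense_tree_shap(trees, data, out_contribs,
+ *                     FEATURE_DEPENDENCE::tree_path_dependent, 0, false);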
+ */ +inline void dense_tree_shap(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs, + const int feature_dependence, unsigned model_transform, bool interactions) { + + // see what transform (if any) we have + transform_f transform = get_transform(model_transform); + + // dispatch to the correct algorithm handler + switch (feature_dependence) { + case FEATURE_DEPENDENCE::independent: + if (interactions) { + std::cerr << "FEATURE_DEPENDENCE::independent does not support interactions!\n"; + } else dense_independent(trees, data, out_contribs, transform); + return; + + case FEATURE_DEPENDENCE::tree_path_dependent: + if (interactions) dense_tree_interactions_path_dependent(trees, data, out_contribs, transform); + else dense_tree_path_dependent(trees, data, out_contribs, transform); + return; + + case FEATURE_DEPENDENCE::global_path_dependent: + if (interactions) { + std::cerr << "FEATURE_DEPENDENCE::global_path_dependent does not support interactions!\n"; + } else dense_global_path_dependent(trees, data, out_contribs, transform); + return; + } +} diff --git a/lib/shap/datasets.py b/lib/shap/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e33bce1f22ed58e8d26c233fea1fbf270b17cbbf --- /dev/null +++ b/lib/shap/datasets.py @@ -0,0 +1,309 @@ +import os +from urllib.request import urlretrieve + +import numpy as np +import pandas as pd +import sklearn.datasets + +import shap + +github_data_url = "https://github.com/shap/shap/raw/master/data/" + + +def imagenet50(display=False, resolution=224, n_points=None): + """ This is a set of 50 images representative of ImageNet images. + + This dataset was collected by randomly finding a working ImageNet link and then pasting the + original ImageNet image into Google image search restricted to images licensed for reuse. A + similar image (now with rights to reuse) was downloaded as a rough replacement for the original + ImageNet image. The point is to have a random sample of ImageNet for use as a background + distribution for explaining models trained on ImageNet data. + + Note that because the images are only rough replacements the labels might no longer be correct. + """ + + prefix = github_data_url + "imagenet50_" + X = np.load(cache(f"{prefix}{resolution}x{resolution}.npy")).astype(np.float32) + y = np.loadtxt(cache(f"{prefix}labels.csv")) + + if n_points is not None: + X = shap.utils.sample(X, n_points, random_state=0) + y = shap.utils.sample(y, n_points, random_state=0) + + return X, y + + +def california(display=False, n_points=None): + """ Return the california housing data in a nice package. """ + + d = sklearn.datasets.fetch_california_housing() + df = pd.DataFrame(data=d.data, columns=d.feature_names) + target = d.target + + if n_points is not None: + df = shap.utils.sample(df, n_points, random_state=0) + target = shap.utils.sample(target, n_points, random_state=0) + + return df, target + + +def linnerud(display=False, n_points=None): + """ Return the linnerud data in a nice package (multi-target regression). """ + + d = sklearn.datasets.load_linnerud() + X = pd.DataFrame(d.data, columns=d.feature_names) + y = pd.DataFrame(d.target, columns=d.target_names) + + if n_points is not None: + X = shap.utils.sample(X, n_points, random_state=0) + y = shap.utils.sample(y, n_points, random_state=0) + + return X, y + + +def imdb(display=False, n_points=None): + """ Return the classic IMDB sentiment analysis training data in a nice package. 
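+
+    A minimal usage sketch (downloads and caches the corpus on first use):
+
+        corpus, labels = shap.datasets.imdb(n_points=500)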
+
+    Full data is at: http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
+    Paper to cite when using the data is: http://www.aclweb.org/anthology/P11-1015
+    """
+
+    with open(cache(github_data_url + "imdb_train.txt"), encoding="utf-8") as f:
+        data = f.readlines()
+    y = np.ones(25000, dtype=bool)
+    y[:12500] = 0
+
+    if n_points is not None:
+        data = shap.utils.sample(data, n_points, random_state=0)
+        y = shap.utils.sample(y, n_points, random_state=0)
+
+    return data, y
+
+
+def communitiesandcrime(display=False, n_points=None):
+    """ Predict total number of non-violent crimes per 100K population.
+
+    This dataset is from the classic UCI Machine Learning repository:
+    https://archive.ics.uci.edu/ml/datasets/Communities+and+Crime+Unnormalized
+    """
+
+    raw_data = pd.read_csv(
+        cache(github_data_url + "CommViolPredUnnormalizedData.txt"),
+        na_values="?"
+    )
+
+    # find the indices where the total violent crimes are known
+    valid_inds = np.where(np.invert(np.isnan(raw_data.iloc[:,-2])))[0]
+
+    if n_points is not None:
+        valid_inds = shap.utils.sample(valid_inds, n_points, random_state=0)
+
+    y = np.array(raw_data.iloc[valid_inds,-2], dtype=float)
+
+    # extract the predictive features and remove columns with missing values
+    X = raw_data.iloc[valid_inds,5:-18]
+    valid_cols = np.where(np.isnan(X.values).sum(0) == 0)[0]
+    X = X.iloc[:,valid_cols]
+
+    return X, y
+
+
+def diabetes(display=False, n_points=None):
+    """ Return the diabetes data in a nice package. """
+
+    d = sklearn.datasets.load_diabetes()
+    df = pd.DataFrame(data=d.data, columns=d.feature_names)
+    target = d.target
+
+    if n_points is not None:
+        df = shap.utils.sample(df, n_points, random_state=0)
+        target = shap.utils.sample(target, n_points, random_state=0)
+
+    return df, target
+
+
+def iris(display=False, n_points=None):
+    """ Return the classic iris data in a nice package. """
+
+    d = sklearn.datasets.load_iris()
+    df = pd.DataFrame(data=d.data, columns=d.feature_names)
+    target = d.target
+
+    if n_points is not None:
+        df = shap.utils.sample(df, n_points, random_state=0)
+        target = shap.utils.sample(target, n_points, random_state=0)
+
+    if display:
+        return df, [d.target_names[v] for v in target]
+    return df, target
+
+
+def adult(display=False, n_points=None):
+    """ Return the Adult census data in a nice package.
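+
+    A minimal usage sketch (the default frame is integer-coded; pass
+    display=True for the human-readable columns):
+
+        X, y = shap.datasets.adult(n_points=1000)
+        X_display, y_display = shap.datasets.adult(display=True, n_points=1000)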
""" + dtypes = [ + ("Age", "float32"), ("Workclass", "category"), ("fnlwgt", "float32"), + ("Education", "category"), ("Education-Num", "float32"), ("Marital Status", "category"), + ("Occupation", "category"), ("Relationship", "category"), ("Race", "category"), + ("Sex", "category"), ("Capital Gain", "float32"), ("Capital Loss", "float32"), + ("Hours per week", "float32"), ("Country", "category"), ("Target", "category") + ] + raw_data = pd.read_csv( + cache(github_data_url + "adult.data"), + names=[d[0] for d in dtypes], + na_values="?", + dtype=dict(dtypes) + ) + + if n_points is not None: + raw_data = shap.utils.sample(raw_data, n_points, random_state=0) + + data = raw_data.drop(["Education"], axis=1) # redundant with Education-Num + filt_dtypes = list(filter(lambda x: x[0] not in ["Target", "Education"], dtypes)) + data["Target"] = data["Target"] == " >50K" + rcode = { + "Not-in-family": 0, + "Unmarried": 1, + "Other-relative": 2, + "Own-child": 3, + "Husband": 4, + "Wife": 5 + } + for k, dtype in filt_dtypes: + if dtype == "category": + if k == "Relationship": + data[k] = np.array([rcode[v.strip()] for v in data[k]]) + else: + data[k] = data[k].cat.codes + + if display: + return raw_data.drop(["Education", "Target", "fnlwgt"], axis=1), data["Target"].values + return data.drop(["Target", "fnlwgt"], axis=1), data["Target"].values + + +def nhanesi(display=False, n_points=None): + """ A nicely packaged version of NHANES I data with surivival times as labels. + """ + X = pd.read_csv(cache(github_data_url + "NHANESI_X.csv"), index_col=0) + y = pd.read_csv(cache(github_data_url + "NHANESI_y.csv"), index_col=0)["y"] + + if n_points is not None: + X = shap.utils.sample(X, n_points, random_state=0) + y = shap.utils.sample(y, n_points, random_state=0) + + if display: + X_display = X.copy() + # X_display["sex_isFemale"] = ["Female" if v else "Male" for v in X["sex_isFemale"]] + return X_display, np.array(y) + return X, np.array(y) + + +def corrgroups60(display=False, n_points=1_000): + """ Correlated Groups 60 + + A simulated dataset with tight correlations among distinct groups of features. + """ + + # set a constant seed + old_seed = np.random.seed() + np.random.seed(0) + + # generate dataset with known correlation + N, M = n_points, 60 + + # set one coefficient from each group of 3 to 1 + beta = np.zeros(M) + beta[0:30:3] = 1 + + # build a correlation matrix with groups of 3 tightly correlated features + C = np.eye(M) + for i in range(0,30,3): + C[i,i+1] = C[i+1,i] = 0.99 + C[i,i+2] = C[i+2,i] = 0.99 + C[i+1,i+2] = C[i+2,i+1] = 0.99 + def f(X): + return np.matmul(X, beta) + + # Make sure the sample correlation is a perfect match + X_start = np.random.randn(N, M) + X_centered = X_start - X_start.mean(0) + Sigma = np.matmul(X_centered.T, X_centered) / X_centered.shape[0] + W = np.linalg.cholesky(np.linalg.inv(Sigma)).T + X_white = np.matmul(X_centered, W.T) + assert np.linalg.norm(np.corrcoef(np.matmul(X_centered, W.T).T) - np.eye(M)) < 1e-6 # ensure this decorrelates the data + + # create the final data + X_final = np.matmul(X_white, np.linalg.cholesky(C).T) + X = X_final + y = f(X) + np.random.randn(N) * 1e-2 + + # restore the previous numpy random seed + np.random.seed(old_seed) + + return pd.DataFrame(X), y + + +def independentlinear60(display=False, n_points=1_000): + """ A simulated dataset with tight correlations among distinct groups of features. 
+    """
+
+    # save the global RNG state so it can be restored below, then set a constant seed
+    # (np.random.seed() returns None, so it cannot be used to capture the old state)
+    old_state = np.random.get_state()
+    np.random.seed(0)
+
+    # generate dataset with known correlation
+    N, M = n_points, 60
+
+    # set one coefficient from each group of 3 to 1
+    beta = np.zeros(M)
+    beta[0:30:3] = 1
+
+    def f(X):
+        return np.matmul(X, beta)
+
+    # Make sure the sample correlation is a perfect match
+    X_start = np.random.randn(N, M)
+    X = X_start - X_start.mean(0)
+    y = f(X) + np.random.randn(N) * 1e-2
+
+    # restore the previous numpy random state
+    np.random.set_state(old_state)
+
+    return pd.DataFrame(X), y
+
+
+def a1a(n_points=None):
+    """ A sparse dataset in scipy csr matrix format.
+    """
+    data, target = sklearn.datasets.load_svmlight_file(cache(github_data_url + 'a1a.svmlight'))
+
+    if n_points is not None:
+        data = shap.utils.sample(data, n_points, random_state=0)
+        target = shap.utils.sample(target, n_points, random_state=0)
+
+    return data, target
+
+
+def rank():
+    """ Ranking datasets from the LightGBM repository.
+    """
+    rank_data_url = 'https://raw.githubusercontent.com/Microsoft/LightGBM/master/examples/lambdarank/'
+    x_train, y_train = sklearn.datasets.load_svmlight_file(cache(rank_data_url + 'rank.train'))
+    x_test, y_test = sklearn.datasets.load_svmlight_file(cache(rank_data_url + 'rank.test'))
+    q_train = np.loadtxt(cache(rank_data_url + 'rank.train.query'))
+    q_test = np.loadtxt(cache(rank_data_url + 'rank.test.query'))
+
+    return x_train, y_train, x_test, y_test, q_train, q_test
+
+
+def cache(url, file_name=None):
+    """ Loads a file from the URL and caches it locally.
+    """
+    if file_name is None:
+        file_name = os.path.basename(url)
+    data_dir = os.path.join(os.path.dirname(__file__), "cached_data")
+    os.makedirs(data_dir, exist_ok=True)
+
+    file_path = os.path.join(data_dir, file_name)
+    if not os.path.isfile(file_path):
+        urlretrieve(url, file_path)
+
+    return file_path
diff --git a/lib/shap/explainers/__init__.py b/lib/shap/explainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..de3a0edd4fbc26753e622730bd043f49d592ad23
--- /dev/null
+++ b/lib/shap/explainers/__init__.py
@@ -0,0 +1,38 @@
+from ._additive import AdditiveExplainer
+from ._deep import DeepExplainer
+from ._exact import ExactExplainer
+from ._gpu_tree import GPUTreeExplainer
+from ._gradient import GradientExplainer
+from ._kernel import KernelExplainer
+from ._linear import LinearExplainer
+from ._partition import PartitionExplainer
+from ._permutation import PermutationExplainer
+from ._sampling import SamplingExplainer
+from ._tree import TreeExplainer
+
+# Alternative legacy "short-form" aliases, which are kept here for backwards-compatibility
+Additive = AdditiveExplainer
+Deep = DeepExplainer
+Exact = ExactExplainer
+GPUTree = GPUTreeExplainer
+Gradient = GradientExplainer
+Kernel = KernelExplainer
+Linear = LinearExplainer
+Partition = PartitionExplainer
+Permutation = PermutationExplainer
+Sampling = SamplingExplainer
+Tree = TreeExplainer
+
+__all__ = [
+    "AdditiveExplainer",
+    "DeepExplainer",
+    "ExactExplainer",
+    "GPUTreeExplainer",
+    "GradientExplainer",
+    "KernelExplainer",
+    "LinearExplainer",
+    "PartitionExplainer",
+    "PermutationExplainer",
+    "SamplingExplainer",
+    "TreeExplainer",
+]
diff --git a/lib/shap/explainers/_additive.py b/lib/shap/explainers/_additive.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb5dc47e9c45cd2e0b731ae53a56006f413634a
--- /dev/null
+++ b/lib/shap/explainers/_additive.py
@@ -0,0 +1,187 @@
+import numpy as np
+
+from ..utils import MaskedModel, safe_isinstance
+from ._explainer import Explainer
+
+
+class AdditiveExplainer(Explainer):
+    """ Computes SHAP values for generalized additive models.
+
+    This assumes that the model only has first-order effects. Extending this to
+    second- and third-order effects is future work (if you apply this to those models right now
+    you will get incorrect answers that fail additivity).
+    """
+
+    def __init__(self, model, masker, link=None, feature_names=None, linearize_link=True):
+        """ Build an Additive explainer for the given model using the given masker object.
+
+        Parameters
+        ----------
+        model : function
+            A callable python object that executes the model given a set of input data samples.
+
+        masker : function or numpy.array or pandas.DataFrame
+            A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`.
+            It takes a binary mask and an input sample and returns a matrix of masked samples. These
+            masked samples are evaluated using the model function and the outputs are then averaged.
+            As a shortcut for the standard masking used by SHAP you can pass a background data matrix
+            instead of a function and that matrix will be used for masking. To use a clustering
+            game structure you can pass a shap.maskers.Tabular(data, hclustering=\"correlation\") object, but
+            note that this structure information has no effect on the explanations of additive models.
+        """
+        super().__init__(model, masker, feature_names=feature_names, linearize_link=linearize_link)
+
+        if safe_isinstance(model, "interpret.glassbox.ExplainableBoostingClassifier"):
+            self.model = model.decision_function
+
+            if self.masker is None:
+                self._expected_value = model.intercept_
+                # num_features = len(model.additive_terms_)
+
+                # fm = MaskedModel(self.model, self.masker, self.link, np.zeros(num_features))
+                # masks = np.ones((1, num_features), dtype=bool)
+                # outputs = fm(masks)
+                # self.model(np.zeros(num_features))
+                # self._zero_offset = self.model(np.zeros(num_features))#model.intercept_#outputs[0]
+                # self._input_offsets = np.zeros(num_features) #* self._zero_offset
+                raise NotImplementedError("Masker not given and we don't yet support pulling the distribution centering directly from the EBM model!")
+            return
+
+        # here we need to compute the offsets ourselves because we can't pull them directly from a model we know about
+        assert safe_isinstance(self.masker, "shap.maskers.Independent"), "The Additive explainer only supports the Tabular masker at the moment!"
+
+        # pre-compute per-feature offsets
+        fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, np.zeros(self.masker.shape[1]))
+        masks = np.ones((self.masker.shape[1]+1, self.masker.shape[1]), dtype=bool)
+        for i in range(1, self.masker.shape[1]+1):
+            masks[i,i-1] = False
+        outputs = fm(masks)
+        self._zero_offset = outputs[0]
+        self._input_offsets = np.zeros(masker.shape[1])
+        for i in range(1, self.masker.shape[1]+1):
+            self._input_offsets[i-1] = outputs[i] - self._zero_offset
+
+        self._expected_value = self._input_offsets.sum() + self._zero_offset
+
+    def __call__(self, *args, max_evals=None, silent=False):
+        """ Explains the output of model(*args), where args represents one or more parallel iterable args.
+        """
+
+        # we entirely rely on the general call implementation, we override just to remove **kwargs
+        # from the function signature
+        return super().__call__(*args, max_evals=max_evals, silent=silent)
+
+    @staticmethod
+    def supports_model_with_masker(model, masker):
+        """ Determines if this explainer can handle the given model.
+ + This is an abstract static method meant to be implemented by each subclass. + """ + if safe_isinstance(model, "interpret.glassbox.ExplainableBoostingClassifier"): + if model.interactions != 0: + raise NotImplementedError("Need to add support for interaction effects!") + return True + + return False + + def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent): + """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes). + """ + + x = row_args[0] + inputs = np.zeros((len(x), len(x))) + for i in range(len(x)): + inputs[i,i] = x[i] + + phi = self.model(inputs) - self._zero_offset - self._input_offsets + + return { + "values": phi, + "expected_values": self._expected_value, + "mask_shapes": [a.shape for a in row_args], + "main_effects": phi, + "clustering": getattr(self.masker, "clustering", None) + } + +# class AdditiveExplainer(Explainer): +# """ Computes SHAP values for generalized additive models. + +# This assumes that the model only has first order effects. Extending this to +# 2nd and third order effects is future work (if you apply this to those models right now +# you will get incorrect answers that fail additivity). + +# Parameters +# ---------- +# model : function or ExplainableBoostingRegressor +# User supplied additive model either as either a function or a model object. + +# data : numpy.array, pandas.DataFrame +# The background dataset to use for computing conditional expectations. +# feature_perturbation : "interventional" +# Only the standard interventional SHAP values are supported by AdditiveExplainer right now. +# """ + +# def __init__(self, model, data, feature_perturbation="interventional"): +# if feature_perturbation != "interventional": +# raise Exception("Unsupported type of feature_perturbation provided: " + feature_perturbation) + +# if safe_isinstance(model, "interpret.glassbox.ebm.ebm.ExplainableBoostingRegressor"): +# self.f = model.predict +# elif callable(model): +# self.f = model +# else: +# raise ValueError("The passed model must be a recognized object or a function!") + +# # convert dataframes +# if isinstance(data, (pd.Series, pd.DataFrame)): +# data = data.values +# self.data = data + +# # compute the expected value of the model output +# self.expected_value = self.f(data).mean() + +# # pre-compute per-feature offsets +# tmp = np.zeros(data.shape) +# self._zero_offset = self.f(tmp).mean() +# self._feature_offset = np.zeros(data.shape[1]) +# for i in range(data.shape[1]): +# tmp[:,i] = data[:,i] +# self._feature_offset[i] = self.f(tmp).mean() - self._zero_offset +# tmp[:,i] = 0 + + +# def shap_values(self, X): +# """ Estimate the SHAP values for a set of samples. + +# Parameters +# ---------- +# X : numpy.array, pandas.DataFrame or scipy.csr_matrix +# A matrix of samples (# samples x # features) on which to explain the model's output. + +# Returns +# ------- +# For models with a single output this returns a matrix of SHAP values +# (# samples x # features). Each row sums to the difference between the model output for that +# sample and the expected value of the model output (which is stored as expected_value +# attribute of the explainer). +# """ + +# # convert dataframes +# if isinstance(X, (pd.Series, pd.DataFrame)): +# X = X.values + +# # assert isinstance(X, np.ndarray), "Unknown instance type: " + str(type(X)) +# assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!" 
+
+#         # convert dataframes
+#         if isinstance(X, (pd.Series, pd.DataFrame)):
+#             X = X.values
+
+#         phi = np.zeros(X.shape)
+#         tmp = np.zeros(X.shape)
+#         for i in range(X.shape[1]):
+#             tmp[:,i] = X[:,i]
+#             phi[:,i] = self.f(tmp) - self._zero_offset - self._feature_offset[i]
+#             tmp[:,i] = 0
+
+#         return phi
diff --git a/lib/shap/explainers/_deep/__init__.py b/lib/shap/explainers/_deep/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef84549b87e36d13dc65b68282de31b3661163fa
--- /dev/null
+++ b/lib/shap/explainers/_deep/__init__.py
@@ -0,0 +1,125 @@
+from .._explainer import Explainer
+from .deep_pytorch import PyTorchDeep
+from .deep_tf import TFDeep
+
+
+class DeepExplainer(Explainer):
+    """ Meant to approximate SHAP values for deep learning models.
+
+    This is an enhanced version of the DeepLIFT algorithm (Deep SHAP) where, similar to Kernel SHAP, we
+    approximate the conditional expectations of SHAP values using a selection of background samples.
+    Lundberg and Lee, NIPS 2017 showed that the per node attribution rules in DeepLIFT (Shrikumar,
+    Greenside, and Kundaje, arXiv 2017) can be chosen to approximate Shapley values. By integrating
+    over many background samples Deep estimates approximate SHAP values such that they sum
+    up to the difference between the expected model output on the passed background samples and the
+    current model output (f(x) - E[f(x)]).
+
+    Examples
+    --------
+    See :ref:`Deep Explainer Examples <deep_explainer_examples>`
+    """
+
+    def __init__(self, model, data, session=None, learning_phase_flags=None):
+        """ An explainer object for a differentiable model using a given background dataset.
+
+        Note that the complexity of the method scales linearly with the number of background data
+        samples. Passing the entire training dataset as `data` will give very accurate expected
+        values, but will be unreasonably expensive. The variance of the expectation estimates scales by
+        roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate,
+        and 1000 samples a very good estimate of the expected values.
+
+        Parameters
+        ----------
+        model : if framework == 'tensorflow', (input : [tf.Tensor], output : tf.Tensor)
+            A pair of TensorFlow tensors (or a list and a tensor) that specifies the input and
+            output of the model to be explained. Note that SHAP values are specific to a single
+            output value, so the output tf.Tensor should be a single dimensional output (,1).
+
+            if framework == 'pytorch', an nn.Module object (model), or a tuple (model, layer),
+            where both are nn.Module objects
+            The model is an nn.Module object which takes as input a tensor (or list of tensors) of
+            shape data, and returns a single dimensional output.
+            If the input is a tuple, the returned shap values will be for the input of the
+            layer argument. layer must be a layer in the model, i.e. model.conv2
+
+        data :
+            if framework == 'tensorflow': [numpy.array] or [pandas.DataFrame]
+            if framework == 'pytorch': [torch.tensor]
+            The background dataset to use for integrating out features. Deep integrates
+            over these samples. The data passed here must match the input tensors given in the
+            first argument. Note that since these samples are integrated over for each sample you
+            should only use something like 100 or 1000 random background samples, not the whole
+            training dataset.
+
+        if framework == 'tensorflow':
+
+        session : None or tensorflow.Session
+            The TensorFlow session that has the model we are explaining. If None is passed then
+            we do our best to find the right session, first looking for a keras session, then
+            falling back to the default TensorFlow session.
+
+        learning_phase_flags : None or list of tensors
+            If you have your own custom learning phase flags pass them here. When explaining a prediction
+            we need to ensure we are not in training mode, since this changes the behavior of ops like
+            batch norm or dropout. If None is passed then we look for tensors in the graph that look like
+            learning phase flags (this works for Keras models). Note that we assume all the flags should
+            have a value of False during predictions (and hence explanations).
+        """
+        # first, we need to find the framework
+        if type(model) is tuple:
+            a, b = model
+            try:
+                a.named_parameters()
+                framework = 'pytorch'
+            except Exception:
+                framework = 'tensorflow'
+        else:
+            try:
+                model.named_parameters()
+                framework = 'pytorch'
+            except Exception:
+                framework = 'tensorflow'
+
+        if framework == 'tensorflow':
+            self.explainer = TFDeep(model, data, session, learning_phase_flags)
+        elif framework == 'pytorch':
+            self.explainer = PyTorchDeep(model, data)
+
+        self.expected_value = self.explainer.expected_value
+        self.explainer.framework = framework
+
+    def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
+        """ Return approximate SHAP values for the model applied to the data given by X.
+
+        Parameters
+        ----------
+        X : list,
+            if framework == 'tensorflow': numpy.array, or pandas.DataFrame
+            if framework == 'pytorch': torch.tensor
+            A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
+            explain the model's output.
+
+        ranked_outputs : None or int
+            If ranked_outputs is None then we explain all the outputs in a multi-output model. If
+            ranked_outputs is a positive integer then we only explain that many of the top model
+            outputs (where "top" is determined by output_rank_order). Note that this causes a pair
+            of values to be returned (shap_values, indexes), where shap_values is a list of numpy
+            arrays for each of the output ranks, and indexes is a matrix that indicates for each sample
+            which output indexes were chosen as "top".
+
+        output_rank_order : "max", "min", or "max_abs"
+            How to order the model outputs when using ranked_outputs, either by maximum, minimum, or
+            maximum absolute value.
+
+        Returns
+        -------
+        array or list
+            For models with a single output this returns a tensor of SHAP values with the same shape
+            as X. For a model with multiple outputs this returns a list of SHAP value tensors, each of
+            which are the same shape as X. If ranked_outputs is None then this list of tensors matches
+            the number of model outputs. If ranked_outputs is a positive integer a pair is returned
+            (shap_values, indexes), where shap_values is a list of tensors with a length of
+            ranked_outputs, and indexes is a matrix that indicates for each sample which output indexes
+            were chosen as "top".
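+
+        Examples
+        --------
+        A minimal sketch (assumes a trained differentiable `model` and a
+        `background` batch of roughly 100 samples; the names are illustrative)::
+
+            explainer = DeepExplainer(model, background)
+            shap_values = explainer.shap_values(X_test[:10])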
+ """ + return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity) diff --git a/lib/shap/explainers/_deep/deep_pytorch.py b/lib/shap/explainers/_deep/deep_pytorch.py new file mode 100644 index 0000000000000000000000000000000000000000..670e4ddb64ffde2e88cdefa7c391b1c63673f146 --- /dev/null +++ b/lib/shap/explainers/_deep/deep_pytorch.py @@ -0,0 +1,386 @@ +import warnings + +import numpy as np +from packaging import version + +from .._explainer import Explainer +from .deep_utils import _check_additivity + +torch = None + + +class PyTorchDeep(Explainer): + + def __init__(self, model, data): + # try and import pytorch + global torch + if torch is None: + import torch + if version.parse(torch.__version__) < version.parse("0.4"): + warnings.warn("Your PyTorch version is older than 0.4 and not supported.") + + # check if we have multiple inputs + self.multi_input = False + if isinstance(data, list): + self.multi_input = True + if not isinstance(data, list): + data = [data] + self.data = data + self.layer = None + self.input_handle = None + self.interim = False + self.interim_inputs_shape = None + self.expected_value = None # to keep the DeepExplainer base happy + if type(model) == tuple: + self.interim = True + model, layer = model + model = model.eval() + self.layer = layer + self.add_target_handle(self.layer) + + # if we are taking an interim layer, the 'data' is going to be the input + # of the interim layer; we will capture this using a forward hook + with torch.no_grad(): + _ = model(*data) + interim_inputs = self.layer.target_input + if type(interim_inputs) is tuple: + # this should always be true, but just to be safe + self.interim_inputs_shape = [i.shape for i in interim_inputs] + else: + self.interim_inputs_shape = [interim_inputs.shape] + self.target_handle.remove() + del self.layer.target_input + self.model = model.eval() + + self.multi_output = False + self.num_outputs = 1 + with torch.no_grad(): + outputs = model(*data) + + # also get the device everything is running on + self.device = outputs.device + if outputs.shape[1] > 1: + self.multi_output = True + self.num_outputs = outputs.shape[1] + self.expected_value = outputs.mean(0).cpu().numpy() + + def add_target_handle(self, layer): + input_handle = layer.register_forward_hook(get_target_input) + self.target_handle = input_handle + + def add_handles(self, model, forward_handle, backward_handle): + """ + Add handles to all non-container layers in the model. 
+ Recursively for non-container layers + """ + handles_list = [] + model_children = list(model.children()) + if model_children: + for child in model_children: + handles_list.extend(self.add_handles(child, forward_handle, backward_handle)) + else: # leaves + handles_list.append(model.register_forward_hook(forward_handle)) + handles_list.append(model.register_full_backward_hook(backward_handle)) + return handles_list + + def remove_attributes(self, model): + """ + Removes the x and y attributes which were added by the forward handles + Recursively searches for non-container layers + """ + for child in model.children(): + if 'nn.modules.container' in str(type(child)): + self.remove_attributes(child) + else: + try: + del child.x + except AttributeError: + pass + try: + del child.y + except AttributeError: + pass + + def gradient(self, idx, inputs): + self.model.zero_grad() + X = [x.requires_grad_() for x in inputs] + outputs = self.model(*X) + selected = [val for val in outputs[:, idx]] + grads = [] + if self.interim: + interim_inputs = self.layer.target_input + for idx, input in enumerate(interim_inputs): + grad = torch.autograd.grad(selected, input, + retain_graph=True if idx + 1 < len(interim_inputs) else None, + allow_unused=True)[0] + if grad is not None: + grad = grad.cpu().numpy() + else: + grad = torch.zeros_like(X[idx]).cpu().numpy() + grads.append(grad) + del self.layer.target_input + return grads, [i.detach().cpu().numpy() for i in interim_inputs] + else: + for idx, x in enumerate(X): + grad = torch.autograd.grad(selected, x, + retain_graph=True if idx + 1 < len(X) else None, + allow_unused=True)[0] + if grad is not None: + grad = grad.cpu().numpy() + else: + grad = torch.zeros_like(X[idx]).cpu().numpy() + grads.append(grad) + return grads + + def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True): + # X ~ self.model_input + # X_data ~ self.data + + # check if we have multiple inputs + if not self.multi_input: + assert not isinstance(X, list), "Expected a single tensor model input!" + X = [X] + else: + assert isinstance(X, list), "Expected a list of model inputs!" + + X = [x.detach().to(self.device) for x in X] + + model_output_values = None + + if ranked_outputs is not None and self.multi_output: + with torch.no_grad(): + model_output_values = self.model(*X) + # rank and determine the model outputs that we will explain + if output_rank_order == "max": + _, model_output_ranks = torch.sort(model_output_values, descending=True) + elif output_rank_order == "min": + _, model_output_ranks = torch.sort(model_output_values, descending=False) + elif output_rank_order == "max_abs": + _, model_output_ranks = torch.sort(torch.abs(model_output_values), descending=True) + else: + emsg = "output_rank_order must be max, min, or max_abs!" 
+ raise ValueError(emsg) + model_output_ranks = model_output_ranks[:, :ranked_outputs] + else: + model_output_ranks = (torch.ones((X[0].shape[0], self.num_outputs)).int() * + torch.arange(0, self.num_outputs).int()) + + # add the gradient handles + handles = self.add_handles(self.model, add_interim_values, deeplift_grad) + if self.interim: + self.add_target_handle(self.layer) + + # compute the attributions + output_phis = [] + for i in range(model_output_ranks.shape[1]): + phis = [] + if self.interim: + for k in range(len(self.interim_inputs_shape)): + phis.append(np.zeros((X[0].shape[0], ) + self.interim_inputs_shape[k][1: ])) + else: + for k in range(len(X)): + phis.append(np.zeros(X[k].shape)) + for j in range(X[0].shape[0]): + # tile the inputs to line up with the background data samples + tiled_X = [X[t][j:j + 1].repeat( + (self.data[t].shape[0],) + tuple([1 for k in range(len(X[t].shape) - 1)])) for t + in range(len(X))] + joint_x = [torch.cat((tiled_X[t], self.data[t]), dim=0) for t in range(len(X))] + # run attribution computation graph + feature_ind = model_output_ranks[j, i] + sample_phis = self.gradient(feature_ind, joint_x) + # assign the attributions to the right part of the output arrays + if self.interim: + sample_phis, output = sample_phis + x, data = [], [] + for k in range(len(output)): + x_temp, data_temp = np.split(output[k], 2) + x.append(x_temp) + data.append(data_temp) + for t in range(len(self.interim_inputs_shape)): + phis[t][j] = (sample_phis[t][self.data[t].shape[0]:] * (x[t] - data[t])).mean(0) + else: + for t in range(len(X)): + phis[t][j] = (torch.from_numpy(sample_phis[t][self.data[t].shape[0]:]).to(self.device) * (X[t][j: j + 1] - self.data[t])).cpu().detach().numpy().mean(0) + output_phis.append(phis[0] if not self.multi_input else phis) + # cleanup; remove all gradient handles + for handle in handles: + handle.remove() + self.remove_attributes(self.model) + if self.interim: + self.target_handle.remove() + + # check that the SHAP values sum up to the model output + if check_additivity: + if model_output_values is None: + with torch.no_grad(): + model_output_values = self.model(*X) + + _check_additivity(self, model_output_values.cpu(), output_phis) + + if not self.multi_output: + return output_phis[0] + elif ranked_outputs is not None: + return output_phis, model_output_ranks + else: + return output_phis + +# Module hooks + + +def deeplift_grad(module, grad_input, grad_output): + """The backward hook which computes the deeplift + gradient for an nn.Module + """ + # first, get the module type + module_type = module.__class__.__name__ + # first, check the module is supported + if module_type in op_handler: + if op_handler[module_type].__name__ not in ['passthrough', 'linear_1d']: + return op_handler[module_type](module, grad_input, grad_output) + else: + warnings.warn(f'unrecognized nn.Module: {module_type}') + return grad_input + + +def add_interim_values(module, input, output): + """The forward hook used to save interim tensors, detached + from the graph. 
Used to calculate the multipliers + """ + try: + del module.x + except AttributeError: + pass + try: + del module.y + except AttributeError: + pass + module_type = module.__class__.__name__ + if module_type in op_handler: + func_name = op_handler[module_type].__name__ + # First, check for cases where we don't need to save the x and y tensors + if func_name == 'passthrough': + pass + else: + # check only the 0th input varies + for i in range(len(input)): + if i != 0 and type(output) is tuple: + assert input[i] == output[i], "Only the 0th input may vary!" + # if a new method is added, it must be added here too. This ensures tensors + # are only saved if necessary + if func_name in ['maxpool', 'nonlinear_1d']: + # only save tensors if necessary + if type(input) is tuple: + setattr(module, 'x', torch.nn.Parameter(input[0].detach())) + else: + setattr(module, 'x', torch.nn.Parameter(input.detach())) + if type(output) is tuple: + setattr(module, 'y', torch.nn.Parameter(output[0].detach())) + else: + setattr(module, 'y', torch.nn.Parameter(output.detach())) + + +def get_target_input(module, input, output): + """A forward hook which saves the tensor - attached to its graph. + Used if we want to explain the interim outputs of a model + """ + try: + del module.target_input + except AttributeError: + pass + setattr(module, 'target_input', input) + + +def passthrough(module, grad_input, grad_output): + """No change made to gradients""" + return None + + +def maxpool(module, grad_input, grad_output): + pool_to_unpool = { + 'MaxPool1d': torch.nn.functional.max_unpool1d, + 'MaxPool2d': torch.nn.functional.max_unpool2d, + 'MaxPool3d': torch.nn.functional.max_unpool3d + } + pool_to_function = { + 'MaxPool1d': torch.nn.functional.max_pool1d, + 'MaxPool2d': torch.nn.functional.max_pool2d, + 'MaxPool3d': torch.nn.functional.max_pool3d + } + delta_in = module.x[: int(module.x.shape[0] / 2)] - module.x[int(module.x.shape[0] / 2):] + dup0 = [2] + [1 for i in delta_in.shape[1:]] + # we also need to check if the output is a tuple + y, ref_output = torch.chunk(module.y, 2) + cross_max = torch.max(y, ref_output) + diffs = torch.cat([cross_max - ref_output, y - cross_max], 0) + + # all of this just to unpool the outputs + with torch.no_grad(): + _, indices = pool_to_function[module.__class__.__name__]( + module.x, module.kernel_size, module.stride, module.padding, + module.dilation, module.ceil_mode, True) + xmax_pos, rmax_pos = torch.chunk(pool_to_unpool[module.__class__.__name__]( + grad_output[0] * diffs, indices, module.kernel_size, module.stride, + module.padding, list(module.x.shape)), 2) + + grad_input = [None for _ in grad_input] + grad_input[0] = torch.where(torch.abs(delta_in) < 1e-7, torch.zeros_like(delta_in), + (xmax_pos + rmax_pos) / delta_in).repeat(dup0) + + return tuple(grad_input) + + +def linear_1d(module, grad_input, grad_output): + """No change made to gradients.""" + return None + + +def nonlinear_1d(module, grad_input, grad_output): + delta_out = module.y[: int(module.y.shape[0] / 2)] - module.y[int(module.y.shape[0] / 2):] + + delta_in = module.x[: int(module.x.shape[0] / 2)] - module.x[int(module.x.shape[0] / 2):] + dup0 = [2] + [1 for i in delta_in.shape[1:]] + # handles numerical instabilities where delta_in is very small by + # just taking the gradient in those cases + grads = [None for _ in grad_input] + grads[0] = torch.where(torch.abs(delta_in.repeat(dup0)) < 1e-6, grad_input[0], + grad_output[0] * (delta_out / delta_in).repeat(dup0)) + return tuple(grads) + + +op_handler = {} + +# 
passthrough ops, where we make no change to the gradient +op_handler['Dropout3d'] = passthrough +op_handler['Dropout2d'] = passthrough +op_handler['Dropout'] = passthrough +op_handler['AlphaDropout'] = passthrough + +op_handler['Conv1d'] = linear_1d +op_handler['Conv2d'] = linear_1d +op_handler['Conv3d'] = linear_1d +op_handler['ConvTranspose1d'] = linear_1d +op_handler['ConvTranspose2d'] = linear_1d +op_handler['ConvTranspose3d'] = linear_1d +op_handler['Linear'] = linear_1d +op_handler['AvgPool1d'] = linear_1d +op_handler['AvgPool2d'] = linear_1d +op_handler['AvgPool3d'] = linear_1d +op_handler['AdaptiveAvgPool1d'] = linear_1d +op_handler['AdaptiveAvgPool2d'] = linear_1d +op_handler['AdaptiveAvgPool3d'] = linear_1d +op_handler['BatchNorm1d'] = linear_1d +op_handler['BatchNorm2d'] = linear_1d +op_handler['BatchNorm3d'] = linear_1d + +op_handler['LeakyReLU'] = nonlinear_1d +op_handler['ReLU'] = nonlinear_1d +op_handler['ELU'] = nonlinear_1d +op_handler['Sigmoid'] = nonlinear_1d +op_handler["Tanh"] = nonlinear_1d +op_handler["Softplus"] = nonlinear_1d +op_handler['Softmax'] = nonlinear_1d + +op_handler['MaxPool1d'] = maxpool +op_handler['MaxPool2d'] = maxpool +op_handler['MaxPool3d'] = maxpool diff --git a/lib/shap/explainers/_deep/deep_tf.py b/lib/shap/explainers/_deep/deep_tf.py new file mode 100644 index 0000000000000000000000000000000000000000..55c3d7db282e078d8a6d4d57d4f0d1aeccca992d --- /dev/null +++ b/lib/shap/explainers/_deep/deep_tf.py @@ -0,0 +1,763 @@ +import warnings + +import numpy as np +from packaging import version + +from ...utils._exceptions import DimensionError +from .._explainer import Explainer +from ..tf_utils import _get_graph, _get_model_inputs, _get_model_output, _get_session +from .deep_utils import _check_additivity + +tf = None +tf_ops = None +tf_backprop = None +tf_execute = None +tf_gradients_impl = None + +def custom_record_gradient(op_name, inputs, attrs, results): + """ This overrides tensorflow.python.eager.backprop._record_gradient. + + We need to override _record_gradient in order to get gradient backprop to + get called for ResourceGather operations. In order to make this work we + temporarily "lie" about the input type to prevent the node from getting + pruned from the gradient backprop process. We then reset the type directly + afterwards back to what it was (an integer type). + """ + reset_input = False + if op_name == "ResourceGather" and inputs[1].dtype == tf.int32: + inputs[1].__dict__["_dtype"] = tf.float32 + reset_input = True + try: + out = tf_backprop._record_gradient("shap_"+op_name, inputs, attrs, results) + except AttributeError: + out = tf_backprop.record_gradient("shap_"+op_name, inputs, attrs, results) + + if reset_input: + inputs[1].__dict__["_dtype"] = tf.int32 + + return out + +class TFDeep(Explainer): + """ + Using tf.gradients to implement the backpropagation was + inspired by the gradient-based implementation approach proposed by Ancona et al, ICLR 2018. Note + that this package does not currently use the reveal-cancel rule for ReLu units proposed in DeepLIFT. + """ + + def __init__(self, model, data, session=None, learning_phase_flags=None): + """ An explainer object for a deep model using a given background dataset. + + Note that the complexity of the method scales linearly with the number of background data + samples. Passing the entire training dataset as `data` will give very accurate expected + values, but will be computationally expensive. 
The variance of the expectation estimates scales by + roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate, + and 1000 samples a very good estimate of the expected values. + + Parameters + ---------- + model : tf.keras.Model or (input : [tf.Operation], output : tf.Operation) + A keras model object or a pair of TensorFlow operations (or a list and an op) that + specifies the input and output of the model to be explained. Note that SHAP values + are specific to a single output value, so you get an explanation for each element of + the output tensor (which must be a flat rank one vector). + + data : [numpy.array] or [pandas.DataFrame] or function + The background dataset to use for integrating out features. DeepExplainer integrates + over all these samples for each explanation. The data passed here must match the input + operations given to the model. If a function is supplied, it must be a function that + takes a particular input example and generates the background dataset for that example + session : None or tensorflow.Session + The TensorFlow session that has the model we are explaining. If None is passed then + we do our best to find the right session, first looking for a keras session, then + falling back to the default TensorFlow session. + + learning_phase_flags : None or list of tensors + If you have your own custom learning phase flags pass them here. When explaining a prediction + we need to ensure we are not in training mode, since this changes the behavior of ops like + batch norm or dropout. If None is passed then we look for tensors in the graph that look like + learning phase flags (this works for Keras models). Note that we assume all the flags should + have a value of False during predictions (and hence explanations). + + """ + # try to import tensorflow + global tf, tf_ops, tf_backprop, tf_execute, tf_gradients_impl + if tf is None: + from tensorflow.python.eager import backprop as tf_backprop + from tensorflow.python.eager import execute as tf_execute + from tensorflow.python.framework import ( + ops as tf_ops, + ) + from tensorflow.python.ops import ( + gradients_impl as tf_gradients_impl, + ) + if not hasattr(tf_gradients_impl, "_IsBackpropagatable"): + from tensorflow.python.ops import gradients_util as tf_gradients_impl + import tensorflow as tf + if version.parse(tf.__version__) < version.parse("1.4.0"): + warnings.warn("Your TensorFlow version is older than 1.4.0 and not supported.") + + if version.parse(tf.__version__) >= version.parse("2.4.0"): + warnings.warn("Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.") + + # determine the model inputs and outputs + self.model_inputs = _get_model_inputs(model) + self.model_output = _get_model_output(model) + assert not isinstance(self.model_output, list), "The model output to be explained must be a single tensor!" + assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!" 
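+        # a rank-1 output tensor (batch,) is treated below as a single output,
+        # while a rank-2 tensor (batch, n_outputs) yields one explanation per output column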
+ self.multi_output = True + if len(self.model_output.shape) == 1: + self.multi_output = False + + if tf.executing_eagerly(): + if isinstance(model, tuple) or isinstance(model, list): + assert len(model) == 2, "When a tuple is passed it must be of the form (inputs, outputs)" + from tensorflow.keras import Model + self.model = Model(model[0], model[1]) + else: + self.model = model + + # check if we have multiple inputs + self.multi_input = True + if not isinstance(self.model_inputs, list) or len(self.model_inputs) == 1: + self.multi_input = False + if not isinstance(self.model_inputs, list): + self.model_inputs = [self.model_inputs] + if not isinstance(data, list) and (hasattr(data, "__call__") is False): + data = [data] + self.data = data + + self._vinputs = {} # used to track what op inputs depends on the model inputs + self.orig_grads = {} + + if not tf.executing_eagerly(): + self.session = _get_session(session) + + self.graph = _get_graph(self) + + # if no learning phase flags were given we go looking for them + # ...this will catch the one that keras uses + # we need to find them since we want to make sure learning phase flags are set to False + if learning_phase_flags is None: + self.learning_phase_ops = [] + for op in self.graph.get_operations(): + if 'learning_phase' in op.name and op.type == "Const" and len(op.outputs[0].shape) == 0: + if op.outputs[0].dtype == tf.bool: + self.learning_phase_ops.append(op) + self.learning_phase_flags = [op.outputs[0] for op in self.learning_phase_ops] + else: + self.learning_phase_ops = [t.op for t in learning_phase_flags] + + # save the expected output of the model + # if self.data is a function, set self.expected_value to None + if (hasattr(self.data, '__call__')): + self.expected_value = None + else: + if self.data[0].shape[0] > 5000: + warnings.warn("You have provided over 5k background samples! For better performance consider using smaller random sample.") + if not tf.executing_eagerly(): + self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0) + else: + #if type(self.model)is tuple: + # self.fModel(cnn.inputs, cnn.get_layer(theNameYouWant).outputs) + self.expected_value = tf.reduce_mean(self.model(self.data), 0) + + if not tf.executing_eagerly(): + self._init_between_tensors(self.model_output.op, self.model_inputs) + + # make a blank array that will get lazily filled in with the SHAP value computation + # graphs for each output. 
Lazy is important since if there are 1000 outputs and we + # only explain the top 5 it would be a waste to build graphs for the other 995 + if not self.multi_output: + self.phi_symbolics = [None] + else: + noutputs = self.model_output.shape.as_list()[1] + if noutputs is not None: + self.phi_symbolics = [None for i in range(noutputs)] + else: + raise DimensionError("The model output tensor to be explained cannot have a static shape in dim 1 of None!") + + def _get_model_output(self, model): + if len(model.layers[-1]._inbound_nodes) == 0: + if len(model.outputs) > 1: + warnings.warn("Only one model output supported.") + return model.outputs[0] + else: + return model.layers[-1].output + + def _init_between_tensors(self, out_op, model_inputs): + # find all the operations in the graph between our inputs and outputs + tensor_blacklist = tensors_blocked_by_false(self.learning_phase_ops) # don't follow learning phase branches + dependence_breakers = [k for k in op_handlers if op_handlers[k] == break_dependence] + back_ops = backward_walk_ops( + [out_op], tensor_blacklist, + dependence_breakers + ) + start_ops = [] + for minput in model_inputs: + for op in minput.consumers(): + start_ops.append(op) + self.between_ops = forward_walk_ops( + start_ops, + tensor_blacklist, dependence_breakers, + within_ops=back_ops + ) + + # note all the tensors that are on the path between the inputs and the output + self.between_tensors = {} + for op in self.between_ops: + for t in op.outputs: + self.between_tensors[t.name] = True + for t in model_inputs: + self.between_tensors[t.name] = True + + # save what types are being used + self.used_types = {} + for op in self.between_ops: + self.used_types[op.type] = True + + def _variable_inputs(self, op): + """ Return which inputs of this operation are variable (i.e. depend on the model inputs). + """ + if op not in self._vinputs: + out = np.zeros(len(op.inputs), dtype=bool) + for i,t in enumerate(op.inputs): + out[i] = t.name in self.between_tensors + self._vinputs[op] = out + return self._vinputs[op] + + def phi_symbolic(self, i): + """ Get the SHAP value computation graph for a given model output. + """ + if self.phi_symbolics[i] is None: + + if not tf.executing_eagerly(): + def anon(): + out = self.model_output[:,i] if self.multi_output else self.model_output + return tf.gradients(out, self.model_inputs) + + self.phi_symbolics[i] = self.execute_with_overridden_gradients(anon) + else: + @tf.function + def grad_graph(shap_rAnD): + phase = tf.keras.backend.learning_phase() + tf.keras.backend.set_learning_phase(0) + + with tf.GradientTape(watch_accessed_variables=False) as tape: + tape.watch(shap_rAnD) + out = self.model(shap_rAnD) + if self.multi_output: + out = out[:,i] + + self._init_between_tensors(out.op, shap_rAnD) + x_grad = tape.gradient(out, shap_rAnD) + tf.keras.backend.set_learning_phase(phase) + return x_grad + + self.phi_symbolics[i] = grad_graph + + return self.phi_symbolics[i] + + def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True): + # check if we have multiple inputs + if not self.multi_input: + if isinstance(X, list) and len(X) != 1: + raise ValueError("Expected a single tensor as model input!") + elif not isinstance(X, list): + X = [X] + else: + assert isinstance(X, list), "Expected a list of model inputs!" + assert len(self.model_inputs) == len(X), "Number of model inputs (%d) does not match the number given (%d)!" 
+
+        # rank and determine the model outputs that we will explain
+        if ranked_outputs is not None and self.multi_output:
+            if not tf.executing_eagerly():
+                model_output_values = self.run(self.model_output, self.model_inputs, X)
+            else:
+                model_output_values = self.model(X)
+
+            if output_rank_order == "max":
+                model_output_ranks = np.argsort(-model_output_values)
+            elif output_rank_order == "min":
+                model_output_ranks = np.argsort(model_output_values)
+            elif output_rank_order == "max_abs":
+                model_output_ranks = np.argsort(np.abs(model_output_values))
+            else:
+                emsg = "output_rank_order must be max, min, or max_abs!"
+                raise ValueError(emsg)
+            model_output_ranks = model_output_ranks[:, :ranked_outputs]
+        else:
+            model_output_ranks = np.tile(np.arange(len(self.phi_symbolics)), (X[0].shape[0], 1))
+
+        # compute the attributions
+        output_phis = []
+        for i in range(model_output_ranks.shape[1]):
+            phis = []
+            for k in range(len(X)):
+                phis.append(np.zeros(X[k].shape))
+            for j in range(X[0].shape[0]):
+                if hasattr(self.data, '__call__'):
+                    bg_data = self.data([X[t][j] for t in range(len(X))])
+                    if not isinstance(bg_data, list):
+                        bg_data = [bg_data]
+                else:
+                    bg_data = self.data
+
+                # tile the inputs to line up with the background data samples
+                tiled_X = [np.tile(X[t][j:j+1], (bg_data[t].shape[0],) + tuple([1 for k in range(len(X[t].shape)-1)])) for t in range(len(X))]
+
+                # we use the first sample for the current sample and the rest for the references
+                joint_input = [np.concatenate([tiled_X[t], bg_data[t]], 0) for t in range(len(X))]
+
+                # run attribution computation graph
+                feature_ind = model_output_ranks[j, i]
+                sample_phis = self.run(self.phi_symbolic(feature_ind), self.model_inputs, joint_input)
+
+                # assign the attributions to the right part of the output arrays
+                for t in range(len(X)):
+                    phis[t][j] = (sample_phis[t][bg_data[t].shape[0]:] * (X[t][j] - bg_data[t])).mean(0)
+
+            output_phis.append(phis[0] if not self.multi_input else phis)
+
+        # check that the SHAP values sum up to the model output
+        if check_additivity:
+            if not tf.executing_eagerly():
+                model_output = self.run(self.model_output, self.model_inputs, X)
+            else:
+                model_output = self.model(X)
+
+            _check_additivity(self, model_output, output_phis)
+
+        if not self.multi_output:
+            return output_phis[0]
+        elif ranked_outputs is not None:
+            return output_phis, model_output_ranks
+        else:
+            return output_phis
+
+    def run(self, out, model_inputs, X):
+        """ Runs the model while also setting the learning phase flags to False.
+        """
+        if not tf.executing_eagerly():
+            feed_dict = dict(zip(model_inputs, X))
+            for t in self.learning_phase_flags:
+                feed_dict[t] = False
+            return self.session.run(out, feed_dict)
+        else:
+            def anon():
+                tf_execute.record_gradient = custom_record_gradient
+
+                # build inputs that are correctly shaped, typed, and tf-wrapped
+                inputs = []
+                for i in range(len(X)):
+                    shape = list(self.model_inputs[i].shape)
+                    shape[0] = -1
+                    data = X[i].reshape(shape)
+                    v = tf.constant(data, dtype=self.model_inputs[i].dtype)
+                    inputs.append(v)
+                final_out = out(inputs)
+                try:
+                    tf_execute.record_gradient = tf_backprop._record_gradient
+                except AttributeError:
+                    tf_execute.record_gradient = tf_backprop.record_gradient
+
+                return final_out
+            return self.execute_with_overridden_gradients(anon)
+
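+    # Editorial note: in eager mode run() temporarily swaps TensorFlow's internal
+    # record_gradient hook for shap's custom_record_gradient, and in both modes
+    # execute_with_overridden_gradients() patches the gradient registry so that
+    # the op handlers below intercept gradient creation.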
+ """ + type_name = op.type[5:] if op.type.startswith("shap_") else op.type + out = op_handlers[type_name](self, op, *grads) # we cut off the shap_ prefix before the lookup + return out + + def execute_with_overridden_gradients(self, f): + # replace the gradients for all the non-linear activations + # we do this by hacking our way into the registry (TODO: find a public API for this if it exists) + reg = tf_ops._gradient_registry._registry + ops_not_in_registry = ['TensorListReserve'] + # NOTE: location_tag taken from tensorflow source for None type ops + location_tag = ("UNKNOWN", "UNKNOWN", "UNKNOWN", "UNKNOWN", "UNKNOWN") + # TODO: unclear why some ops are not in the registry with TF 2.0 like TensorListReserve + for non_reg_ops in ops_not_in_registry: + reg[non_reg_ops] = {'type': None, 'location': location_tag} + for n in op_handlers: + if n in reg: + self.orig_grads[n] = reg[n]["type"] + reg["shap_"+n] = { + "type": self.custom_grad, + "location": reg[n]["location"] + } + reg[n]["type"] = self.custom_grad + + # In TensorFlow 1.10 they started pruning out nodes that they think can't be backpropped + # unfortunately that includes the index of embedding layers so we disable that check here + if hasattr(tf_gradients_impl, "_IsBackpropagatable"): + orig_IsBackpropagatable = tf_gradients_impl._IsBackpropagatable + tf_gradients_impl._IsBackpropagatable = lambda tensor: True + + # define the computation graph for the attribution values using a custom gradient-like computation + try: + out = f() + finally: + # reinstate the backpropagatable check + if hasattr(tf_gradients_impl, "_IsBackpropagatable"): + tf_gradients_impl._IsBackpropagatable = orig_IsBackpropagatable + + # restore the original gradient definitions + for n in op_handlers: + if n in reg: + del reg["shap_"+n] + reg[n]["type"] = self.orig_grads[n] + for non_reg_ops in ops_not_in_registry: + del reg[non_reg_ops] + if not tf.executing_eagerly(): + return out + else: + return [v.numpy() for v in out] + +def tensors_blocked_by_false(ops): + """ Follows a set of ops assuming their value is False and find blocked Switch paths. + + This is used to prune away parts of the model graph that are only used during the training + phase (like dropout, batch norm, etc.). 
+ """ + blocked = [] + def recurse(op): + if op.type == "Switch": + blocked.append(op.outputs[1]) # the true path is blocked since we assume the ops we trace are False + else: + for out in op.outputs: + for c in out.consumers(): + recurse(c) + for op in ops: + recurse(op) + + return blocked + +def backward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist): + found_ops = [] + op_stack = [op for op in start_ops] + while len(op_stack) > 0: + op = op_stack.pop() + if op.type not in op_type_blacklist and op not in found_ops: + found_ops.append(op) + for input in op.inputs: + if input not in tensor_blacklist: + op_stack.append(input.op) + return found_ops + +def forward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist, within_ops): + found_ops = [] + op_stack = [op for op in start_ops] + while len(op_stack) > 0: + op = op_stack.pop() + if op.type not in op_type_blacklist and op in within_ops and op not in found_ops: + found_ops.append(op) + for out in op.outputs: + if out not in tensor_blacklist: + for c in out.consumers(): + op_stack.append(c) + return found_ops + + +def softmax(explainer, op, *grads): + """ Just decompose softmax into its components and recurse, we can handle all of them :) + + We assume the 'axis' is the last dimension because the TF codebase swaps the 'axis' to + the last dimension before the softmax op if 'axis' is not already the last dimension. + We also don't subtract the max before tf.exp for numerical stability since that might + mess up the attributions and it seems like TensorFlow doesn't define softmax that way + (according to the docs) + """ + in0 = op.inputs[0] + in0_max = tf.reduce_max(in0, axis=-1, keepdims=True, name="in0_max") + in0_centered = in0 - in0_max + evals = tf.exp(in0_centered, name="custom_exp") + rsum = tf.reduce_sum(evals, axis=-1, keepdims=True) + div = evals / rsum + + # mark these as in-between the inputs and outputs + for op in [evals.op, rsum.op, div.op, in0_centered.op]: + for t in op.outputs: + if t.name not in explainer.between_tensors: + explainer.between_tensors[t.name] = False + + out = tf.gradients(div, in0_centered, grad_ys=grads[0])[0] + + # remove the names we just added + for op in [evals.op, rsum.op, div.op, in0_centered.op]: + for t in op.outputs: + if explainer.between_tensors[t.name] is False: + del explainer.between_tensors[t.name] + + # rescale to account for our shift by in0_max (which we did for numerical stability) + xin0,rin0 = tf.split(in0, 2) + xin0_centered,rin0_centered = tf.split(in0_centered, 2) + delta_in0 = xin0 - rin0 + dup0 = [2] + [1 for i in delta_in0.shape[1:]] + return tf.where( + tf.tile(tf.abs(delta_in0), dup0) < 1e-6, + out, + out * tf.tile((xin0_centered - rin0_centered) / delta_in0, dup0) + ) + +def maxpool(explainer, op, *grads): + xin0,rin0 = tf.split(op.inputs[0], 2) + xout,rout = tf.split(op.outputs[0], 2) + delta_in0 = xin0 - rin0 + dup0 = [2] + [1 for i in delta_in0.shape[1:]] + cross_max = tf.maximum(xout, rout) + diffs = tf.concat([cross_max - rout, xout - cross_max], 0) + if op.type.startswith("shap_"): + op.type = op.type[5:] + xmax_pos,rmax_pos = tf.split(explainer.orig_grads[op.type](op, grads[0] * diffs), 2) + return tf.tile(tf.where( + tf.abs(delta_in0) < 1e-7, + tf.zeros_like(delta_in0), + (xmax_pos + rmax_pos) / delta_in0 + ), dup0) + +def gather(explainer, op, *grads): + #params = op.inputs[0] + indices = op.inputs[1] + #axis = op.inputs[2] + var = explainer._variable_inputs(op) + if var[1] and not var[0]: + assert len(indices.shape) == 2, "Only scalar indices supported 
+def maxpool(explainer, op, *grads):
+    xin0, rin0 = tf.split(op.inputs[0], 2)
+    xout, rout = tf.split(op.outputs[0], 2)
+    delta_in0 = xin0 - rin0
+    dup0 = [2] + [1 for i in delta_in0.shape[1:]]
+    cross_max = tf.maximum(xout, rout)
+    diffs = tf.concat([cross_max - rout, xout - cross_max], 0)
+    if op.type.startswith("shap_"):
+        op.type = op.type[5:]
+    xmax_pos, rmax_pos = tf.split(explainer.orig_grads[op.type](op, grads[0] * diffs), 2)
+    return tf.tile(tf.where(
+        tf.abs(delta_in0) < 1e-7,
+        tf.zeros_like(delta_in0),
+        (xmax_pos + rmax_pos) / delta_in0
+    ), dup0)
+
+
+def gather(explainer, op, *grads):
+    #params = op.inputs[0]
+    indices = op.inputs[1]
+    #axis = op.inputs[2]
+    var = explainer._variable_inputs(op)
+    if var[1] and not var[0]:
+        assert len(indices.shape) == 2, "Only scalar indices supported right now in GatherV2!"
+
+        xin1, rin1 = tf.split(tf.cast(op.inputs[1], tf.float32), 2)
+        xout, rout = tf.split(op.outputs[0], 2)
+        dup_in1 = [2] + [1 for i in xin1.shape[1:]]
+        dup_out = [2] + [1 for i in xout.shape[1:]]
+        delta_in1_t = tf.tile(xin1 - rin1, dup_in1)
+        out_sum = tf.reduce_sum(grads[0] * tf.tile(xout - rout, dup_out), list(range(len(indices.shape), len(grads[0].shape))))
+        if op.type == "ResourceGather":
+            return [None, tf.where(
+                tf.abs(delta_in1_t) < 1e-6,
+                tf.zeros_like(delta_in1_t),
+                out_sum / delta_in1_t
+            )]
+        return [None, tf.where(
+            tf.abs(delta_in1_t) < 1e-6,
+            tf.zeros_like(delta_in1_t),
+            out_sum / delta_in1_t
+        ), None]
+    elif var[0] and not var[1]:
+        if op.type.startswith("shap_"):
+            op.type = op.type[5:]
+        return [explainer.orig_grads[op.type](op, grads[0]), None]  # linear in this case
+    else:
+        raise ValueError("Axis not yet supported to be varying for gather op!")
+
+
+def linearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
+    def handler(explainer, op, *grads):
+        var = explainer._variable_inputs(op)
+        if var[input_ind0] and not var[input_ind1]:
+            return linearity_1d_handler(input_ind0, explainer, op, *grads)
+        elif var[input_ind1] and not var[input_ind0]:
+            return linearity_1d_handler(input_ind1, explainer, op, *grads)
+        elif var[input_ind0] and var[input_ind1]:
+            return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
+        else:
+            return [None for _ in op.inputs]  # no inputs vary, we must be hidden by a switch function
+    return handler
+
+
+def nonlinearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
+    def handler(explainer, op, *grads):
+        var = explainer._variable_inputs(op)
+        if var[input_ind0] and not var[input_ind1]:
+            return nonlinearity_1d_handler(input_ind0, explainer, op, *grads)
+        elif var[input_ind1] and not var[input_ind0]:
+            return nonlinearity_1d_handler(input_ind1, explainer, op, *grads)
+        elif var[input_ind0] and var[input_ind1]:
+            return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
+        else:
+            return [None for _ in op.inputs]  # no inputs vary, we must be hidden by a switch function
+    return handler
+
+
+def nonlinearity_1d(input_ind):
+    def handler(explainer, op, *grads):
+        return nonlinearity_1d_handler(input_ind, explainer, op, *grads)
+    return handler
+
+
+def nonlinearity_1d_handler(input_ind, explainer, op, *grads):
+    # make sure only the given input varies
+    op_inputs = op.inputs
+    if op_inputs is None:
+        op_inputs = op.outputs[0].op.inputs
+
+    for i in range(len(op_inputs)):
+        if i != input_ind:
+            assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
+
+    xin0, rin0 = tf.split(op_inputs[input_ind], 2)
+    xout, rout = tf.split(op.outputs[input_ind], 2)
+    delta_in0 = xin0 - rin0
+    if delta_in0.shape is None:
+        dup0 = [2, 1]
+    else:
+        dup0 = [2] + [1 for i in delta_in0.shape[1:]]
+    out = [None for _ in op_inputs]
+    if op.type.startswith("shap_"):
+        op.type = op.type[5:]
+    orig_grad = explainer.orig_grads[op.type](op, grads[0])
+    out[input_ind] = tf.where(
+        tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
+        orig_grad[input_ind] if len(op_inputs) > 1 else orig_grad,
+        grads[0] * tf.tile((xout - rout) / delta_in0, dup0)
+    )
+    return out
+
+
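+# Editorial sketch: when both inputs vary, the handler below applies the exact
+# two-player Shapley split of the op. Writing out11 = f(x0, x1), out00 = f(r0, r1),
+# out10 = f(x0, r1) and out01 = f(r0, x1), the returned multipliers are
+#
+#     phi0 = 0.5 * ((out11 - out01) + (out10 - out00)) / (x0 - r0)
+#     phi1 = 0.5 * ((out11 - out10) + (out01 - out00)) / (x1 - r1)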
+def nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads):
+    if not (input_ind0 == 0 and input_ind1 == 1):
+        emsg = "TODO: Can't yet handle double inputs that are not first!"
+        raise Exception(emsg)
+    xout, rout = tf.split(op.outputs[0], 2)
+    in0 = op.inputs[input_ind0]
+    in1 = op.inputs[input_ind1]
+    xin0, rin0 = tf.split(in0, 2)
+    xin1, rin1 = tf.split(in1, 2)
+    delta_in0 = xin0 - rin0
+    delta_in1 = xin1 - rin1
+    dup0 = [2] + [1 for i in delta_in0.shape[1:]]
+    out10 = op_func(xin0, rin1)
+    out01 = op_func(rin0, xin1)
+    out11, out00 = xout, rout
+    out0 = 0.5 * (out11 - out01 + out10 - out00)
+    out0 = grads[0] * tf.tile(out0 / delta_in0, dup0)
+    out1 = 0.5 * (out11 - out10 + out01 - out00)
+    out1 = grads[0] * tf.tile(out1 / delta_in1, dup0)
+
+    # avoid divide-by-zero NaNs
+    out0 = tf.where(tf.abs(tf.tile(delta_in0, dup0)) < 1e-7, tf.zeros_like(out0), out0)
+    out1 = tf.where(tf.abs(tf.tile(delta_in1, dup0)) < 1e-7, tf.zeros_like(out1), out1)
+
+    # see if due to broadcasting our gradient shapes don't match our input shapes
+    if np.any(np.array(out1.shape) != np.array(in1.shape)):
+        broadcast_index = np.where(np.array(out1.shape) != np.array(in1.shape))[0][0]
+        out1 = tf.reduce_sum(out1, axis=broadcast_index, keepdims=True)
+    elif np.any(np.array(out0.shape) != np.array(in0.shape)):
+        broadcast_index = np.where(np.array(out0.shape) != np.array(in0.shape))[0][0]
+        out0 = tf.reduce_sum(out0, axis=broadcast_index, keepdims=True)
+
+    return [out0, out1]
+
+
+def linearity_1d(input_ind):
+    def handler(explainer, op, *grads):
+        return linearity_1d_handler(input_ind, explainer, op, *grads)
+    return handler
+
+
+def linearity_1d_handler(input_ind, explainer, op, *grads):
+    # make sure only the given input varies
+    for i in range(len(op.inputs)):
+        if i != input_ind:
+            assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
+    if op.type.startswith("shap_"):
+        op.type = op.type[5:]
+    return explainer.orig_grads[op.type](op, *grads)
+
+
+def linearity_with_excluded(input_inds):
+    def handler(explainer, op, *grads):
+        return linearity_with_excluded_handler(input_inds, explainer, op, *grads)
+    return handler
+
+
+def linearity_with_excluded_handler(input_inds, explainer, op, *grads):
+    # make sure the given inputs don't vary (negative indices are measured from the end of the list)
+    for i in range(len(op.inputs)):
+        if i in input_inds or i - len(op.inputs) in input_inds:
+            assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
+    if op.type.startswith("shap_"):
+        op.type = op.type[5:]
+    return explainer.orig_grads[op.type](op, *grads)
+
+
+def passthrough(explainer, op, *grads):
+    if op.type.startswith("shap_"):
+        op.type = op.type[5:]
+    return explainer.orig_grads[op.type](op, *grads)
+
+
+def break_dependence(explainer, op, *grads):
+    """ This function name is used to break attribution dependence in the graph traversal.
+
+    These operation types may be connected above input data values in the graph but their outputs
+    don't depend on the input values (for example they just depend on the shape).
+    """
+    return [None for _ in op.inputs]
+
+
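+# Editorial note: the table below maps raw TensorFlow op type names to attribution
+# handlers. A hypothetical extension for another elementwise nonlinearity would
+# reuse the generic factories, e.g.:
+#
+#     op_handlers["Softsign"] = nonlinearity_1d(0)  # hypothetical entry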
+ """ + return [None for _ in op.inputs] + + +op_handlers = {} + +# ops that are always linear +op_handlers["Identity"] = passthrough +op_handlers["StridedSlice"] = passthrough +op_handlers["Squeeze"] = passthrough +op_handlers["ExpandDims"] = passthrough +op_handlers["Pack"] = passthrough +op_handlers["BiasAdd"] = passthrough +op_handlers["Unpack"] = passthrough +op_handlers["Add"] = passthrough +op_handlers["Sub"] = passthrough +op_handlers["Merge"] = passthrough +op_handlers["Sum"] = passthrough +op_handlers["Mean"] = passthrough +op_handlers["Cast"] = passthrough +op_handlers["Transpose"] = passthrough +op_handlers["Enter"] = passthrough +op_handlers["Exit"] = passthrough +op_handlers["NextIteration"] = passthrough +op_handlers["Tile"] = passthrough +op_handlers["TensorArrayScatterV3"] = passthrough +op_handlers["TensorArrayReadV3"] = passthrough +op_handlers["TensorArrayWriteV3"] = passthrough + + +# ops that don't pass any attributions to their inputs +op_handlers["Shape"] = break_dependence +op_handlers["RandomUniform"] = break_dependence +op_handlers["ZerosLike"] = break_dependence +#op_handlers["StopGradient"] = break_dependence # this allows us to stop attributions when we want to (like softmax re-centering) + +# ops that are linear and only allow a single input to vary +op_handlers["Reshape"] = linearity_1d(0) +op_handlers["Pad"] = linearity_1d(0) +op_handlers["ReverseV2"] = linearity_1d(0) +op_handlers["ConcatV2"] = linearity_with_excluded([-1]) +op_handlers["Conv2D"] = linearity_1d(0) +op_handlers["Switch"] = linearity_1d(0) +op_handlers["AvgPool"] = linearity_1d(0) +op_handlers["FusedBatchNorm"] = linearity_1d(0) + +# ops that are nonlinear and only allow a single input to vary +op_handlers["Relu"] = nonlinearity_1d(0) +op_handlers["Elu"] = nonlinearity_1d(0) +op_handlers["Sigmoid"] = nonlinearity_1d(0) +op_handlers["Tanh"] = nonlinearity_1d(0) +op_handlers["Softplus"] = nonlinearity_1d(0) +op_handlers["Exp"] = nonlinearity_1d(0) +op_handlers["ClipByValue"] = nonlinearity_1d(0) +op_handlers["Rsqrt"] = nonlinearity_1d(0) +op_handlers["Square"] = nonlinearity_1d(0) +op_handlers["Max"] = nonlinearity_1d(0) + +# ops that are nonlinear and allow two inputs to vary +op_handlers["SquaredDifference"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: (x - y) * (x - y)) +op_handlers["Minimum"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.minimum(x, y)) +op_handlers["Maximum"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.maximum(x, y)) + +# ops that allow up to two inputs to vary are are linear when only one input varies +op_handlers["Mul"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x * y) +op_handlers["RealDiv"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x / y) +op_handlers["MatMul"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.matmul(x, y)) + +# ops that need their own custom attribution functions +op_handlers["GatherV2"] = gather +op_handlers["ResourceGather"] = gather +op_handlers["MaxPool"] = maxpool +op_handlers["Softmax"] = softmax + + +# TODO items +# TensorArrayGatherV3 +# Max +# TensorArraySizeV3 +# Range diff --git a/lib/shap/explainers/_deep/deep_utils.py b/lib/shap/explainers/_deep/deep_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..701a7c45feb4ea79326312cc1bfe8d49769d810a --- /dev/null +++ b/lib/shap/explainers/_deep/deep_utils.py @@ -0,0 +1,23 @@ +import numpy as np + + +def _check_additivity(explainer, model_output_values, output_phis): + TOLERANCE = 1e-2 + + assert 
diff --git a/lib/shap/explainers/_deep/deep_utils.py b/lib/shap/explainers/_deep/deep_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..701a7c45feb4ea79326312cc1bfe8d49769d810a
--- /dev/null
+++ b/lib/shap/explainers/_deep/deep_utils.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+def _check_additivity(explainer, model_output_values, output_phis):
+    TOLERANCE = 1e-2
+
+    assert len(explainer.expected_value) == model_output_values.shape[1], "Length of expected values and model outputs does not match."
+
+    for t in range(len(explainer.expected_value)):
+        if not explainer.multi_input:
+            diffs = model_output_values[:, t] - explainer.expected_value[t] - output_phis[t].sum(axis=tuple(range(1, output_phis[t].ndim)))
+        else:
+            diffs = model_output_values[:, t] - explainer.expected_value[t]
+            for i in range(len(output_phis[t])):
+                diffs -= output_phis[t][i].sum(axis=tuple(range(1, output_phis[t][i].ndim)))
+
+        maxdiff = np.abs(diffs).max()
+
+        assert maxdiff < TOLERANCE, "The SHAP explanations do not sum up to the model's output! This is either because of a " \
+                                    "rounding error or because an operator in your computation graph was not fully supported. If " \
+                                    "the sum difference is significant compared to the scale of your model outputs, please post " \
+                                    f"as a github issue, with a reproducible example so we can debug it. Used framework: {explainer.framework} - Max. diff: {maxdiff} - Tolerance: {TOLERANCE}"
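+
+# Editorial sketch: "additivity" here means the attributions plus the baseline
+# reconstruct the model output, i.e. for each explained sample approximately
+#
+#     model(x) ~= explainer.expected_value + sum(phis for x)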
diff --git a/lib/shap/explainers/_exact.py b/lib/shap/explainers/_exact.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ade9a82112b713bc908896d2c30b606287fd4dd
--- /dev/null
+++ b/lib/shap/explainers/_exact.py
@@ -0,0 +1,366 @@
+import logging
+
+import numpy as np
+from numba import njit
+
+from .. import links
+from ..models import Model
+from ..utils import (
+    MaskedModel,
+    delta_minimization_order,
+    make_masks,
+    shapley_coefficients,
+)
+from ._explainer import Explainer
+
+log = logging.getLogger('shap')
+
+
+class ExactExplainer(Explainer):
+    """ Computes SHAP values via an optimized exact enumeration.
+
+    This works well for standard Shapley value maskers for models with fewer than ~15 features that vary
+    from the background per sample. It also works well for Owen values from hclustering structured
+    maskers when there are fewer than ~100 features that vary from the background per sample. This
+    explainer minimizes the number of function evaluations needed by ordering the masking sets to
+    minimize sequential differences. This is done using gray codes for standard Shapley values
+    and a greedy sorting method for hclustering structured maskers.
+    """
+
+    def __init__(self, model, masker, link=links.identity, linearize_link=True, feature_names=None):
+        """ Build an explainers.Exact object for the given model using the given masker object.
+
+        Parameters
+        ----------
+        model : function
+            A callable python object that executes the model given a set of input data samples.
+
+        masker : function or numpy.array or pandas.DataFrame
+            A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`.
+            It takes a binary mask and an input sample and returns a matrix of masked samples. These
+            masked samples are evaluated using the model function and the outputs are then averaged.
+            As a shortcut for the standard masking used by SHAP you can pass a background data matrix
+            instead of a function and that matrix will be used for masking. To use a clustering
+            game structure you can pass a shap.maskers.TabularPartitions(data) object.
+
+        link : function
+            The link function used to map between the output units of the model and the SHAP value units. By
+            default it is shap.links.identity, but shap.links.logit can be useful so that expectations are
+            computed in probability units while explanations remain in the (more naturally additive) log-odds
+            units. For more details on how link functions work see any overview of link functions for
+            generalized linear models.
+
+        linearize_link : bool
+            If we use a non-linear link function to take expectations then models that are additive with respect to that
+            link function for a single background sample will no longer be additive when using a background masker with
+            many samples. This for example means that a linear logistic regression model would have interaction effects
+            that arise from the non-linear changes in expectation averaging. To retain the additivity of the model while
+            still respecting the link function we linearize the link function by default.
+        """  # TODO link to the link linearization paper when done
+        super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names)
+
+        self.model = Model(model)
+
+        if getattr(masker, "clustering", None) is not None:
+            self._partition_masks, self._partition_masks_inds = partition_masks(masker.clustering)
+            self._partition_delta_indexes = partition_delta_indexes(masker.clustering, self._partition_masks)
+
+        self._gray_code_cache = {}  # used to avoid regenerating the same gray code patterns
+
+    def __call__(self, *args, max_evals=100000, main_effects=False, error_bounds=False, batch_size="auto", interactions=1, silent=False):
+        """ Explains the output of model(*args), where args represents one or more parallel iterators.
+        """
+
+        # we entirely rely on the general call implementation, we override just to remove **kwargs
+        # from the function signature
+        return super().__call__(
+            *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
+            batch_size=batch_size, interactions=interactions, silent=silent
+        )
+
+    def _cached_gray_codes(self, n):
+        if n not in self._gray_code_cache:
+            self._gray_code_cache[n] = gray_code_indexes(n)
+        return self._gray_code_cache[n]
+
+    def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, interactions, silent):
+        """ Explains a single row and returns a dictionary of attribution results
+        (values, expected values, mask shapes, main effects, and clustering).
+        """
+
+        # build a masked version of the model for the current input sample
+        fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)
+
+        # do the standard Shapley values
+        inds = None
+        if getattr(self.masker, "clustering", None) is None:
+
+            # see which elements we actually need to perturb
+            inds = fm.varying_inputs()
+
+            # make sure we have enough evals
+            if max_evals is not None and max_evals != "auto" and max_evals < 2**len(inds):
+                raise ValueError(
+                    f"It takes {2**len(inds)} masked evaluations to run the Exact explainer on this instance, but max_evals={max_evals}!"
+                )
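+
+            # Editorial note: the cost doubles with every varying feature, e.g. 10
+            # varying features already require 2**10 = 1024 masked evaluations.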
+
+            # generate the masks in gray code order (so that we change the inputs as little
+            # as possible while we iterate to minimize the need to re-eval when the inputs
+            # don't vary from the background)
+            delta_indexes = self._cached_gray_codes(len(inds))
+
+            # map to a larger mask that includes the invariant entries
+            extended_delta_indexes = np.zeros(2**len(inds), dtype=int)
+            for i in range(2**len(inds)):
+                if delta_indexes[i] == MaskedModel.delta_mask_noop_value:
+                    extended_delta_indexes[i] = delta_indexes[i]
+                else:
+                    extended_delta_indexes[i] = inds[delta_indexes[i]]
+
+            # run the model
+            outputs = fm(extended_delta_indexes, zero_index=0, batch_size=batch_size)
+
+            # Shapley values
+            # Care: Need to distinguish between `True` and `1`
+            if interactions is False or (interactions == 1 and interactions is not True):
+
+                # loop over all the outputs to update the rows
+                coeff = shapley_coefficients(len(inds))
+                row_values = np.zeros((len(fm),) + outputs.shape[1:])
+                mask = np.zeros(len(fm), dtype=bool)
+                _compute_grey_code_row_values(row_values, mask, inds, outputs, coeff, extended_delta_indexes, MaskedModel.delta_mask_noop_value)
+
+            # Shapley-Taylor interaction values
+            elif interactions is True or interactions == 2:
+
+                # loop over all the outputs to update the rows
+                coeff = shapley_coefficients(len(inds))
+                row_values = np.zeros((len(fm), len(fm)) + outputs.shape[1:])
+                mask = np.zeros(len(fm), dtype=bool)
+                _compute_grey_code_row_values_st(row_values, mask, inds, outputs, coeff, extended_delta_indexes, MaskedModel.delta_mask_noop_value)
+
+            elif interactions > 2:
+                raise NotImplementedError("Currently the Exact explainer does not support interactions higher than order 2!")
+
+        # do a partition tree constrained version of Shapley values
+        else:
+
+            # make sure we have enough evals
+            if max_evals is not None and max_evals != "auto" and max_evals < len(fm)**2:
+                raise ValueError(
+                    f"It takes {len(fm)**2} masked evaluations to run the Exact explainer on this instance, but max_evals={max_evals}!"
+                )
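+
+            # Editorial note: the clustering-constrained path scales quadratically
+            # instead, e.g. 100 features need at most 100**2 = 10000 evaluations.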
+
+            # generate the masks in a hclust order (so that we change the inputs as little
+            # as possible while we iterate to minimize the need to re-eval when the inputs
+            # don't vary from the background)
+            delta_indexes = self._partition_delta_indexes
+
+            # run the model
+            outputs = fm(delta_indexes, batch_size=batch_size)
+
+            # loop over each output feature
+            row_values = np.zeros((len(fm),) + outputs.shape[1:])
+            for i in range(len(fm)):
+                on_outputs = outputs[self._partition_masks_inds[i][1]]
+                off_outputs = outputs[self._partition_masks_inds[i][0]]
+                row_values[i] = (on_outputs - off_outputs).mean(0)
+
+        # compute the main effects if we need to
+        main_effect_values = None
+        if main_effects or interactions is True or interactions == 2:
+            if inds is None:
+                inds = np.arange(len(fm))
+            main_effect_values = fm.main_effects(inds)
+            if interactions is True or interactions == 2:
+                for i in range(len(fm)):
+                    row_values[i, i] = main_effect_values[i]
+
+        return {
+            "values": row_values,
+            "expected_values": outputs[0],
+            "mask_shapes": fm.mask_shapes,
+            "main_effects": main_effect_values if main_effects else None,
+            "clustering": getattr(self.masker, "clustering", None)
+        }
+
+
+@njit
+def _compute_grey_code_row_values(row_values, mask, inds, outputs, shapley_coeff, extended_delta_indexes, noop_code):
+    set_size = 0
+    M = len(inds)
+    for i in range(2**M):
+
+        # update the mask
+        delta_ind = extended_delta_indexes[i]
+        if delta_ind != noop_code:
+            mask[delta_ind] = ~mask[delta_ind]
+            if mask[delta_ind]:
+                set_size += 1
+            else:
+                set_size -= 1
+
+        # update the output row values
+        on_coeff = shapley_coeff[set_size-1]
+        if set_size < M:
+            off_coeff = shapley_coeff[set_size]
+        out = outputs[i]
+        for j in inds:
+            if mask[j]:
+                row_values[j] += out * on_coeff
+            else:
+                row_values[j] -= out * off_coeff
+
+
+@njit
+def _compute_grey_code_row_values_st(row_values, mask, inds, outputs, shapley_coeff, extended_delta_indexes, noop_code):
+    set_size = 0
+    M = len(inds)
+    for i in range(2**M):
+
+        # update the mask
+        delta_ind = extended_delta_indexes[i]
+        if delta_ind != noop_code:
+            mask[delta_ind] = ~mask[delta_ind]
+            if mask[delta_ind]:
+                set_size += 1
+            else:
+                set_size -= 1
+
+        # distribute the effect of this mask set over all the terms it impacts
+        out = outputs[i]
+        for j in range(M):
+            for k in range(j+1, M):
+                if not mask[j] and not mask[k]:
+                    delta = out * shapley_coeff[set_size]  # * 2
+                elif (not mask[j] and mask[k]) or (mask[j] and not mask[k]):
+                    delta = -out * shapley_coeff[set_size - 1]  # * 2
+                else:  # both true
+                    delta = out * shapley_coeff[set_size - 2]  # * 2
+                row_values[j, k] += delta
+                row_values[k, j] += delta
+
+
+def partition_delta_indexes(partition_tree, all_masks):
+    """ Return a delta-index encoded array of all the masks possible while following the given partition tree.
+    """
+
+    # convert the masks to delta index format
+    mask = np.zeros(all_masks.shape[1], dtype=bool)
+    delta_inds = []
+    for i in range(len(all_masks)):
+        inds = np.where(mask ^ all_masks[i, :])[0]
+
+        for j in inds[:-1]:
+            delta_inds.append(-j - 1)  # negative + (-1) means we have more inds still to change...
+        if len(inds) == 0:
+            delta_inds.append(MaskedModel.delta_mask_noop_value)
+        else:
+            delta_inds.extend(inds[-1:])
+        mask = all_masks[i, :]
+
+    return np.array(delta_inds)
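+
+# Editorial sketch: the delta-index encoding above stores, per step, which single
+# mask bit flips next. E.g. the mask sequence [000, 001, 011] would be encoded as
+# [noop, 2, 1] (0-based positions reading each mask left to right, with noop
+# meaning "no change"; multi-bit changes use the negative encoding shown above).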
+ """ + + M = partition_tree.shape[0] + 1 + mask_matrix = make_masks(partition_tree) + all_masks = [] + m00 = np.zeros(M, dtype=bool) + all_masks.append(m00) + all_masks.append(~m00) + #inds_stack = [0,1] + inds_lists = [[[], []] for i in range(M)] + _partition_masks_recurse(len(partition_tree)-1, m00, 0, 1, inds_lists, mask_matrix, partition_tree, M, all_masks) + + all_masks = np.array(all_masks) + + # we resort the clustering matrix to minimize the sequential difference between the masks + # this minimizes the number of model evaluations we need to run when the background sometimes + # matches the foreground. We seem to average about 1.5 feature changes per mask with this + # approach. This is not as clean as the grey code ordering, but a perfect 1 feature change + # ordering is not possible with a clustering tree + order = delta_minimization_order(all_masks) + inverse_order = np.arange(len(order))[np.argsort(order)] + + for inds_list0,inds_list1 in inds_lists: + for i in range(len(inds_list0)): + inds_list0[i] = inverse_order[inds_list0[i]] + inds_list1[i] = inverse_order[inds_list1[i]] + + # Care: inds_lists have different lengths, so partition_masks_inds is a "ragged" array. See GH #3063 + partition_masks = all_masks[order] + partition_masks_inds = [[np.array(on), np.array(off)] for on, off in inds_lists] + return partition_masks, partition_masks_inds + +# TODO: this should be a jit function... which would require preallocating the inds_lists (sizes are 2**depth of that ind) +# TODO: we could also probable avoid making the masks at all and just record the deltas if we want... +def _partition_masks_recurse(index, m00, ind00, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks): + if index < 0: + inds_lists[index + M][0].append(ind00) + inds_lists[index + M][1].append(ind11) + return + + # get our children indexes + left_index = int(partition_tree[index,0] - M) + right_index = int(partition_tree[index,1] - M) + + # build more refined masks + m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix + m10[:] += mask_matrix[left_index+M, :] + m01 = m00.copy() + m01[:] += mask_matrix[right_index+M, :] + + # record the new masks we made + ind01 = len(all_masks) + all_masks.append(m01) + ind10 = len(all_masks) + all_masks.append(m10) + + # inds_stack.append(len(all_masks) - 2) + # inds_stack.append(len(all_masks) - 1) + + # recurse left and right with both 1 (True) and 0 (False) contexts + _partition_masks_recurse(left_index, m00, ind00, ind10, inds_lists, mask_matrix, partition_tree, M, all_masks) + _partition_masks_recurse(right_index, m10, ind10, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks) + _partition_masks_recurse(left_index, m01, ind01, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks) + _partition_masks_recurse(right_index, m00, ind00, ind01, inds_lists, mask_matrix, partition_tree, M, all_masks) + + +def gray_code_masks(nbits): + """ Produces an array of all binary patterns of size nbits in gray code order. + + This is based on code from: http://code.activestate.com/recipes/576592-gray-code-generatoriterator/ + """ + out = np.zeros((2**nbits, nbits), dtype=bool) + li = np.zeros(nbits, dtype=bool) + + for term in range(2, (1<