LLH committed on
Commit
c95b9af
·
1 Parent(s): b71863b

2024/02/14/10:51

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. analysis/shap_model.py +2 -1
  2. diagram/__init__.py +0 -0
  3. lib/__init__.py +0 -0
  4. lib/shap/__init__.py +144 -0
  5. lib/shap/_cext.cp310-win_amd64.pyd +0 -0
  6. lib/shap/_explanation.py +901 -0
  7. lib/shap/_serializable.py +204 -0
  8. lib/shap/_version.py +16 -0
  9. lib/shap/actions/__init__.py +3 -0
  10. lib/shap/actions/_action.py +8 -0
  11. lib/shap/actions/_optimizer.py +92 -0
  12. lib/shap/benchmark/__init__.py +9 -0
  13. lib/shap/benchmark/_compute.py +9 -0
  14. lib/shap/benchmark/_explanation_error.py +181 -0
  15. lib/shap/benchmark/_result.py +34 -0
  16. lib/shap/benchmark/_sequential.py +332 -0
  17. lib/shap/benchmark/experiments.py +414 -0
  18. lib/shap/benchmark/framework.py +113 -0
  19. lib/shap/benchmark/measures.py +424 -0
  20. lib/shap/benchmark/methods.py +148 -0
  21. lib/shap/benchmark/metrics.py +824 -0
  22. lib/shap/benchmark/models.py +230 -0
  23. lib/shap/benchmark/plots.py +566 -0
  24. lib/shap/cext/_cext.cc +560 -0
  25. lib/shap/cext/_cext_gpu.cc +187 -0
  26. lib/shap/cext/_cext_gpu.cu +353 -0
  27. lib/shap/cext/gpu_treeshap.h +1535 -0
  28. lib/shap/cext/tree_shap.h +1460 -0
  29. lib/shap/datasets.py +309 -0
  30. lib/shap/explainers/__init__.py +38 -0
  31. lib/shap/explainers/_additive.py +187 -0
  32. lib/shap/explainers/_deep/__init__.py +125 -0
  33. lib/shap/explainers/_deep/deep_pytorch.py +386 -0
  34. lib/shap/explainers/_deep/deep_tf.py +763 -0
  35. lib/shap/explainers/_deep/deep_utils.py +23 -0
  36. lib/shap/explainers/_exact.py +366 -0
  37. lib/shap/explainers/_explainer.py +457 -0
  38. lib/shap/explainers/_gpu_tree.py +179 -0
  39. lib/shap/explainers/_gradient.py +592 -0
  40. lib/shap/explainers/_kernel.py +696 -0
  41. lib/shap/explainers/_linear.py +406 -0
  42. lib/shap/explainers/_partition.py +681 -0
  43. lib/shap/explainers/_permutation.py +217 -0
  44. lib/shap/explainers/_sampling.py +199 -0
  45. lib/shap/explainers/_tree.py +0 -0
  46. lib/shap/explainers/other/__init__.py +26 -0
  47. lib/shap/explainers/other/_coefficient.py +17 -0
  48. lib/shap/explainers/other/_lime.py +73 -0
  49. lib/shap/explainers/other/_maple.py +306 -0
  50. lib/shap/explainers/other/_random.py +79 -0
analysis/shap_model.py CHANGED
@@ -1,6 +1,7 @@
1
- import shap
2
  import matplotlib.pyplot as plt
3
 
 
 
4
 
5
  def shap_calculate(model, x, feature_names):
6
  explainer = shap.Explainer(model.predict, x)
 
 
1
  import matplotlib.pyplot as plt
2
 
3
+ import lib.shap as shap
4
+
5
 
6
  def shap_calculate(model, x, feature_names):
7
  explainer = shap.Explainer(model.predict, x)
diagram/__init__.py ADDED
File without changes
lib/__init__.py ADDED
File without changes
lib/shap/__init__.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Top-level shap package namespace.

Re-exports the Explanation container, all explainer classes, the plotting
functions (only when matplotlib is importable), and assorted utility helpers.
"""
from ._explanation import Cohorts, Explanation

# explainers
from .explainers import other
from .explainers._additive import AdditiveExplainer
from .explainers._deep import DeepExplainer
from .explainers._exact import ExactExplainer
from .explainers._explainer import Explainer
from .explainers._gpu_tree import GPUTreeExplainer
from .explainers._gradient import GradientExplainer
from .explainers._kernel import KernelExplainer
from .explainers._linear import LinearExplainer
from .explainers._partition import PartitionExplainer
from .explainers._permutation import PermutationExplainer
from .explainers._sampling import SamplingExplainer
from .explainers._tree import TreeExplainer

try:
    # Version from setuptools-scm
    from ._version import version as __version__
except ImportError:
    # Expected when running locally without build
    __version__ = "0.0.0-not-built"

_no_matplotlib_warning = "matplotlib is not installed so plotting is not available! Run `pip install matplotlib` " \
                         "to fix this."


# plotting (only loaded if matplotlib is present)
def unsupported(*args, **kwargs):
    # Stand-in for every plotting function when matplotlib is missing.
    raise ImportError(_no_matplotlib_warning)


class UnsupportedModule:
    # Stand-in for the whole `plots` submodule: any attribute access raises
    # a helpful ImportError instead of a bare ModuleNotFound.
    def __getattribute__(self, item):
        raise ImportError(_no_matplotlib_warning)


try:
    import matplotlib  # noqa: F401
    have_matplotlib = True
except ImportError:
    have_matplotlib = False
if have_matplotlib:
    from . import plots
    from .plots._bar import bar_legacy as bar_plot
    from .plots._beeswarm import summary_legacy as summary_plot
    from .plots._decision import decision as decision_plot
    from .plots._decision import multioutput_decision as multioutput_decision_plot
    from .plots._embedding import embedding as embedding_plot
    from .plots._force import force as force_plot
    from .plots._force import getjs, initjs, save_html
    from .plots._group_difference import group_difference as group_difference_plot
    from .plots._heatmap import heatmap as heatmap_plot
    from .plots._image import image as image_plot
    from .plots._monitoring import monitoring as monitoring_plot
    from .plots._partial_dependence import partial_dependence as partial_dependence_plot
    from .plots._scatter import dependence_legacy as dependence_plot
    from .plots._text import text as text_plot
    from .plots._violin import violin as violin_plot
    from .plots._waterfall import waterfall as waterfall_plot
else:
    bar_plot = unsupported
    summary_plot = unsupported
    decision_plot = unsupported
    multioutput_decision_plot = unsupported
    embedding_plot = unsupported
    force_plot = unsupported
    getjs = unsupported
    initjs = unsupported
    save_html = unsupported
    group_difference_plot = unsupported
    heatmap_plot = unsupported
    image_plot = unsupported
    monitoring_plot = unsupported
    partial_dependence_plot = unsupported
    dependence_plot = unsupported
    text_plot = unsupported
    violin_plot = unsupported
    waterfall_plot = unsupported
    # If matplotlib is available, then the plots submodule will be directly available.
    # If not, we need to define something that will issue a meaningful warning message
    # (rather than ModuleNotFound).
    plots = UnsupportedModule()


# other stuff :)
from . import datasets, links, utils  # noqa: E402
from .actions._optimizer import ActionOptimizer  # noqa: E402
from .utils import approximate_interactions, sample  # noqa: E402

#from . import benchmark
from .utils._legacy import kmeans  # noqa: E402

# Use __all__ to let type checkers know what is part of the public API.
__all__ = [
    "Cohorts",
    "Explanation",

    # Explainers
    "other",
    "AdditiveExplainer",
    "DeepExplainer",
    "ExactExplainer",
    "Explainer",
    "GPUTreeExplainer",
    "GradientExplainer",
    "KernelExplainer",
    "LinearExplainer",
    "PartitionExplainer",
    "PermutationExplainer",
    "SamplingExplainer",
    "TreeExplainer",

    # Plots
    "plots",
    "bar_plot",
    "summary_plot",
    "decision_plot",
    "multioutput_decision_plot",
    "embedding_plot",
    "force_plot",
    "getjs",
    "initjs",
    "save_html",
    "group_difference_plot",
    "heatmap_plot",
    "image_plot",
    "monitoring_plot",
    "partial_dependence_plot",
    "dependence_plot",
    "text_plot",
    "violin_plot",
    "waterfall_plot",

    # Other stuff
    "datasets",
    "links",
    "utils",
    "ActionOptimizer",
    "approximate_interactions",
    "sample",
    "kmeans",
]
lib/shap/_cext.cp310-win_amd64.pyd ADDED
Binary file (44 kB). View file
 
lib/shap/_explanation.py ADDED
@@ -0,0 +1,901 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import copy
3
+ import operator
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import scipy.cluster
8
+ import scipy.sparse
9
+ import scipy.spatial
10
+ import sklearn
11
+ from slicer import Alias, Obj, Slicer
12
+
13
+ from .utils._exceptions import DimensionError
14
+ from .utils._general import OpChain
15
+
16
op_chain_root = OpChain("shap.Explanation")


def _template_op(attr_name, doc):
    """Build a class-level property that forwards to ``op_chain_root``."""
    def _get(cls):
        return getattr(op_chain_root, attr_name)
    _get.__doc__ = doc
    return property(_get)


class MetaExplanation(type):
    """This metaclass exposes the Explanation object's methods for creating
    template op chains (e.g. ``shap.Explanation.abs``) on the class itself.
    """

    def __getitem__(cls, item):
        return op_chain_root.__getitem__(item)

    abs = _template_op("abs", "Element-wise absolute value op.")
    identity = _template_op("identity", "A no-op.")
    argsort = _template_op("argsort", "Numpy style argsort.")
    sum = _template_op("sum", "Numpy style sum.")
    max = _template_op("max", "Numpy style max.")
    min = _template_op("min", "Numpy style min.")
    mean = _template_op("mean", "Numpy style mean.")
    sample = _template_op("sample", "Numpy style sample.")
    hclust = _template_op("hclust", "Hierarchical clustering op.")
79
class Explanation(metaclass=MetaExplanation):
    """ A sliceable set of parallel arrays representing a SHAP explanation.
    """
    def __init__(
        self,
        values,                     # SHAP values array, or another Explanation to clone
        base_values=None,           # expected model output(s) over the background
        data=None,                  # the input data the values explain
        display_data=None,          # human-readable version of `data` for plots
        instance_names=None,        # per-row names (aliased to axis 0)
        feature_names=None,         # per-column names
        output_names=None,          # names of the model outputs
        output_indexes=None,
        lower_bounds=None,
        upper_bounds=None,
        error_std=None,
        main_effects=None,
        hierarchical_values=None,
        clustering=None,
        compute_time=None
    ):
        # Record of the ops (slicing, abs, mean, ...) applied to this object.
        self.op_history = []

        self.compute_time = compute_time

        # cloning. TODOsomeday: better cloning :)
        if issubclass(type(values), Explanation):
            e = values
            values = e.values
            base_values = e.base_values
            data = e.data

        # Which axes of `values` are output dimensions (vs. instance/feature axes).
        self.output_dims = compute_output_dims(values, base_values, data, output_names)
        values_shape = _compute_shape(values)

        # Default output names ("Output 0", ...) when there is exactly one output axis.
        if output_names is None and len(self.output_dims) == 1:
            output_names = [f"Output {i}" for i in range(values_shape[self.output_dims[0]])]

        # Attach 1D feature names to whichever axis their length matches
        # (columns first, then rows).
        if len(_compute_shape(feature_names)) == 1:  # TODO: should always be an alias once slicer supports per-row aliases
            if len(values_shape) >= 2 and len(feature_names) == values_shape[1]:
                feature_names = Alias(list(feature_names), 1)
            elif len(values_shape) >= 1 and len(feature_names) == values_shape[0]:
                feature_names = Alias(list(feature_names), 0)

        if len(_compute_shape(output_names)) == 1:  # TODO: should always be an alias once slicer supports per-row aliases
            output_names = Alias(list(output_names), self.output_dims[0])
            # if len(values_shape) >= 1 and len(output_names) == values_shape[0]:
            #     output_names = Alias(list(output_names), 0)
            # elif len(values_shape) >= 2 and len(output_names) == values_shape[1]:
            #     output_names = Alias(list(output_names), 1)

        # Higher-order output names are wrapped as slicer Obj's bound to the
        # output dims (order-2 names also span the instance axis 0).
        if output_names is not None and not isinstance(output_names, Alias):
            output_names_order = len(_compute_shape(output_names))
            if output_names_order == 0:
                pass
            elif output_names_order == 1:
                output_names = Obj(output_names, self.output_dims)
            elif output_names_order == 2:
                output_names = Obj(output_names, [0] + list(self.output_dims))
            else:
                raise ValueError("shap.Explanation does not yet support output_names of order greater than 3!")

        # Bind base_values to the output dims (plus axis 0 when per-instance).
        if not hasattr(base_values, "__len__") or len(base_values) == 0:
            pass
        elif len(_compute_shape(base_values)) == len(self.output_dims):
            base_values = Obj(base_values, list(self.output_dims))
        else:
            base_values = Obj(base_values, [0] + list(self.output_dims))

        # All parallel arrays live in one Slicer so indexing slices them together.
        self._s = Slicer(
            values=values,
            base_values=base_values,
            data=list_wrap(data),
            display_data=list_wrap(display_data),
            instance_names=None if instance_names is None else Alias(instance_names, 0),
            feature_names=feature_names,
            output_names=output_names,
            output_indexes=None if output_indexes is None else (self.output_dims, output_indexes),
            lower_bounds=list_wrap(lower_bounds),
            upper_bounds=list_wrap(upper_bounds),
            error_std=list_wrap(error_std),
            main_effects=list_wrap(main_effects),
            hierarchical_values=list_wrap(hierarchical_values),
            clustering=None if clustering is None else Obj(clustering, [0])
        )
164
+
165
    @property
    def shape(self):
        """ Compute the shape over potentially complex data nesting.
        """
        return _compute_shape(self._s.values)

    @property
    def values(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.values
    @values.setter
    def values(self, new_values):
        self._s.values = new_values

    @property
    def base_values(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.base_values
    @base_values.setter
    def base_values(self, new_base_values):
        self._s.base_values = new_base_values

    @property
    def data(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.data
    @data.setter
    def data(self, new_data):
        self._s.data = new_data

    @property
    def display_data(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.display_data
    @display_data.setter
    def display_data(self, new_display_data):
        # DataFrames are stored as their underlying ndarray.
        if issubclass(type(new_display_data), pd.DataFrame):
            new_display_data = new_display_data.values
        self._s.display_data = new_display_data

    @property
    def instance_names(self):
        """ Pass-through from the underlying slicer object. (read-only)
        """
        return self._s.instance_names

    @property
    def output_names(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.output_names
    @output_names.setter
    def output_names(self, new_output_names):
        self._s.output_names = new_output_names

    @property
    def output_indexes(self):
        """ Pass-through from the underlying slicer object. (read-only)
        """
        return self._s.output_indexes

    @property
    def feature_names(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.feature_names
    @feature_names.setter
    def feature_names(self, new_feature_names):
        self._s.feature_names = new_feature_names

    @property
    def lower_bounds(self):
        """ Pass-through from the underlying slicer object. (read-only)
        """
        return self._s.lower_bounds

    @property
    def upper_bounds(self):
        """ Pass-through from the underlying slicer object. (read-only)
        """
        return self._s.upper_bounds

    @property
    def error_std(self):
        """ Pass-through from the underlying slicer object. (read-only)
        """
        return self._s.error_std

    @property
    def main_effects(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.main_effects
    @main_effects.setter
    def main_effects(self, new_main_effects):
        self._s.main_effects = new_main_effects

    @property
    def hierarchical_values(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.hierarchical_values
    @hierarchical_values.setter
    def hierarchical_values(self, new_hierarchical_values):
        self._s.hierarchical_values = new_hierarchical_values

    @property
    def clustering(self):
        """ Pass-through from the underlying slicer object.
        """
        return self._s.clustering
    @clustering.setter
    def clustering(self, new_clustering):
        self._s.clustering = new_clustering
284
+ def cohorts(self, cohorts):
285
+ """ Split this explanation into several cohorts.
286
+
287
+ Parameters
288
+ ----------
289
+ cohorts : int or array
290
+ If this is an integer then we auto build that many cohorts using a decision tree. If this is
291
+ an array then we treat that as an array of cohort names/ids for each instance.
292
+ """
293
+
294
+ if isinstance(cohorts, int):
295
+ return _auto_cohorts(self, max_cohorts=cohorts)
296
+ if isinstance(cohorts, (list, tuple, np.ndarray)):
297
+ cohorts = np.array(cohorts)
298
+ return Cohorts(**{name: self[cohorts == name] for name in np.unique(cohorts)})
299
+ raise TypeError("The given set of cohort indicators is not recognized! Please give an array or int.")
300
+
301
+ def __repr__(self):
302
+ """ Display some basic printable info, but not everything.
303
+ """
304
+ out = ".values =\n"+self.values.__repr__()
305
+ if self.base_values is not None:
306
+ out += "\n\n.base_values =\n"+self.base_values.__repr__()
307
+ if self.data is not None:
308
+ out += "\n\n.data =\n"+self.data.__repr__()
309
+ return out
310
+
311
    def __getitem__(self, item):
        """ This adds support for OpChain indexing.

        Each element of `item` may be a plain index, an OpChain template, an
        Explanation (its values are used), or a string naming an output or
        feature.  After normalization the real slicing is delegated to the
        underlying Slicer.
        """
        new_self = None
        if not isinstance(item, tuple):
            item = (item,)

        # convert any OpChains or magic strings
        pos = -1
        for t in item:
            pos += 1

            # skip over Ellipsis
            if t is Ellipsis:
                pos += len(self.shape) - len(item)
                continue

            orig_t = t
            if issubclass(type(t), OpChain):
                t = t.apply(self)
                if issubclass(type(t), (np.int64, np.int32)):  # because slicer does not like numpy indexes
                    t = int(t)
                elif issubclass(type(t), np.ndarray):
                    t = [int(v) for v in t]  # slicer wants lists not numpy arrays for indexing
            elif issubclass(type(t), Explanation):
                t = t.values
            elif isinstance(t, str):

                # work around for 2D output_names since they are not yet slicer supported
                output_names_dims = []
                if "output_names" in self._s._objects:
                    output_names_dims = self._s._objects["output_names"].dim
                elif "output_names" in self._s._aliases:
                    output_names_dims = self._s._aliases["output_names"].dim
                if pos != 0 and pos in output_names_dims:
                    if len(output_names_dims) == 1:
                        # simple case: translate the name into its column index
                        t = np.argwhere(np.array(self.output_names) == t)[0][0]
                    elif len(output_names_dims) == 2:
                        # ragged per-row output names: rebuild the arrays keeping
                        # only the entries whose name matches `t`
                        new_values = []
                        new_base_values = []
                        new_data = []
                        new_self = copy.deepcopy(self)
                        for i, v in enumerate(self.values):
                            for j, s in enumerate(self.output_names[i]):
                                if s == t:
                                    new_values.append(np.array(v[:,j]))
                                    new_data.append(np.array(self.data[i]))
                                    new_base_values.append(self.base_values[i][j])

                        # NOTE(review): new_data is passed both as `data` and as
                        # `feature_names` (positional arg 6) below — looks
                        # suspicious, but matches the original; confirm intent.
                        new_self = Explanation(
                            np.array(new_values),
                            np.array(new_base_values),
                            np.array(new_data),
                            self.display_data,
                            self.instance_names,
                            np.array(new_data),
                            t, # output_names
                            self.output_indexes,
                            self.lower_bounds,
                            self.upper_bounds,
                            self.error_std,
                            self.main_effects,
                            self.hierarchical_values,
                            self.clustering
                        )
                        new_self.op_history = copy.copy(self.op_history)
                        # new_self = copy.deepcopy(self)
                        # new_self.values = np.array(new_values)
                        # new_self.base_values = np.array(new_base_values)
                        # new_self.data = np.array(new_data)
                        # new_self.output_names = t
                        # new_self.feature_names = np.array(new_data)
                        # new_self.clustering = None

                # work around for 2D feature_names since they are not yet slicer supported
                feature_names_dims = []
                if "feature_names" in self._s._objects:
                    feature_names_dims = self._s._objects["feature_names"].dim
                if pos != 0 and pos in feature_names_dims and len(feature_names_dims) == 2:
                    # ragged per-row feature names: keep only matching entries
                    new_values = []
                    new_data = []
                    for i, val_i in enumerate(self.values):
                        for s,v,d in zip(self.feature_names[i], val_i, self.data[i]):
                            if s == t:
                                new_values.append(v)
                                new_data.append(d)
                    new_self = copy.deepcopy(self)
                    new_self.values = new_values
                    new_self.data = new_data
                    new_self.feature_names = t
                    new_self.clustering = None
                    # return new_self

            # slicer does not accept numpy integer scalars
            if issubclass(type(t), (np.int8, np.int16, np.int32, np.int64)):
                t = int(t)

            # write the converted element back into the item tuple
            if t is not orig_t:
                tmp = list(item)
                tmp[pos] = t
                item = tuple(tmp)

        # call slicer for the real work
        item = tuple(v for v in item) # SML I cut out: `if not isinstance(v, str)`
        if len(item) == 0:
            return new_self
        if new_self is None:
            new_self = copy.copy(self)
        new_self._s = new_self._s.__getitem__(item)
        new_self.op_history.append({
            "name": "__getitem__",
            "args": (item,),
            "prev_shape": self.shape
        })

        return new_self
426
+
427
    def __len__(self):
        """ The length is the size of the first axis of ``.values``. """
        return self.shape[0]
429
+
430
+ def __copy__(self):
431
+ new_exp = Explanation(
432
+ self.values,
433
+ self.base_values,
434
+ self.data,
435
+ self.display_data,
436
+ self.instance_names,
437
+ self.feature_names,
438
+ self.output_names,
439
+ self.output_indexes,
440
+ self.lower_bounds,
441
+ self.upper_bounds,
442
+ self.error_std,
443
+ self.main_effects,
444
+ self.hierarchical_values,
445
+ self.clustering
446
+ )
447
+ new_exp.op_history = copy.copy(self.op_history)
448
+ return new_exp
449
+
450
+ def _apply_binary_operator(self, other, binary_op, op_name):
451
+ new_exp = self.__copy__()
452
+ new_exp.op_history = copy.copy(self.op_history)
453
+ new_exp.op_history.append({
454
+ "name": op_name,
455
+ "args": (other,),
456
+ "prev_shape": self.shape
457
+ })
458
+ if isinstance(other, Explanation):
459
+ new_exp.values = binary_op(new_exp.values, other.values)
460
+ if new_exp.data is not None:
461
+ new_exp.data = binary_op(new_exp.data, other.data)
462
+ if new_exp.base_values is not None:
463
+ new_exp.base_values = binary_op(new_exp.base_values, other.base_values)
464
+ else:
465
+ new_exp.values = binary_op(new_exp.values, other)
466
+ if new_exp.data is not None:
467
+ new_exp.data = binary_op(new_exp.data, other)
468
+ if new_exp.base_values is not None:
469
+ new_exp.base_values = binary_op(new_exp.base_values, other)
470
+ return new_exp
471
+
472
+ def __add__(self, other):
473
+ return self._apply_binary_operator(other, operator.add, "__add__")
474
+
475
+ def __radd__(self, other):
476
+ return self._apply_binary_operator(other, operator.add, "__add__")
477
+
478
+ def __sub__(self, other):
479
+ return self._apply_binary_operator(other, operator.sub, "__sub__")
480
+
481
+ def __rsub__(self, other):
482
+ return self._apply_binary_operator(other, operator.sub, "__sub__")
483
+
484
+ def __mul__(self, other):
485
+ return self._apply_binary_operator(other, operator.mul, "__mul__")
486
+
487
+ def __rmul__(self, other):
488
+ return self._apply_binary_operator(other, operator.mul, "__mul__")
489
+
490
+ def __truediv__(self, other):
491
+ return self._apply_binary_operator(other, operator.truediv, "__truediv__")
492
+
493
+ # @property
494
+ # def abs(self):
495
+ # """ Element-size absolute value operator.
496
+ # """
497
+ # new_self = copy.copy(self)
498
+ # new_self.values = np.abs(new_self.values)
499
+ # new_self.op_history.append({
500
+ # "name": "abs",
501
+ # "prev_shape": self.shape
502
+ # })
503
+ # return new_self
504
+
505
    def _numpy_func(self, fname, **kwargs):
        """ Apply a numpy-style function to this Explanation.

        Parameters
        ----------
        fname : str
            Name of a numpy function (e.g. "mean", "abs") looked up via
            ``getattr(np, fname)`` and applied to values/data/base_values.
        **kwargs
            Forwarded to the numpy function; ``axis`` additionally controls
            how the slicer metadata is collapsed.
        """
        new_self = copy.copy(self)
        axis = kwargs.get("axis", None)

        # collapse the slicer to right shape: index the reduced axis so the
        # aliases/objects bound to it are dropped, then discard that slicing
        # op from the history (only the numpy op should be recorded)
        if axis == 0:
            new_self = new_self[0]
        elif axis == 1:
            new_self = new_self[1]
        elif axis == 2:
            new_self = new_self[2]
        if axis in [0,1,2]:
            new_self.op_history = new_self.op_history[:-1] # pop off the slicing operation we just used

        if self.feature_names is not None and not is_1d(self.feature_names) and axis == 0:
            # ragged per-row feature names: group values by name first, then
            # reduce each group with the numpy function
            new_values = self._flatten_feature_names()
            new_self.feature_names = np.array(list(new_values.keys()))
            new_self.values = np.array([getattr(np, fname)(v,0) for v in new_values.values()])
            new_self.clustering = None
        else:
            new_self.values = getattr(np, fname)(np.array(self.values), **kwargs)
            if new_self.data is not None:
                # best-effort: data may not support the op (e.g. strings)
                try:
                    new_self.data = getattr(np, fname)(np.array(self.data), **kwargs)
                except Exception:
                    new_self.data = None
            # only reduce base_values when the requested axis exists for it
            if new_self.base_values is not None and issubclass(type(axis), int) and len(self.base_values.shape) > axis:
                new_self.base_values = getattr(np, fname)(self.base_values, **kwargs)
            elif issubclass(type(axis), int):
                new_self.base_values = None

        # keep a shared clustering only if it is identical across all rows
        if axis == 0 and self.clustering is not None and len(self.clustering.shape) == 3:
            if self.clustering.std(0).sum() < 1e-8:
                new_self.clustering = self.clustering[0]
            else:
                new_self.clustering = None

        new_self.op_history.append({
            "name": fname,
            "kwargs": kwargs,
            "prev_shape": self.shape,
            "collapsed_instances": axis == 0
        })

        return new_self
552
+
553
    def mean(self, axis):
        """ Numpy-style mean function.
        """
        return self._numpy_func("mean", axis=axis)

    def max(self, axis):
        """ Numpy-style max function.
        """
        return self._numpy_func("max", axis=axis)

    def min(self, axis):
        """ Numpy-style min function.
        """
        return self._numpy_func("min", axis=axis)
567
+
568
+ def sum(self, axis=None, grouping=None):
569
+ """ Numpy-style mean function.
570
+ """
571
+ if grouping is None:
572
+ return self._numpy_func("sum", axis=axis)
573
+ elif axis == 1 or len(self.shape) == 1:
574
+ return group_features(self, grouping)
575
+ else:
576
+ raise DimensionError("Only axis = 1 is supported for grouping right now...")
577
+
578
+ def hstack(self, other):
579
+ """ Stack two explanations column-wise.
580
+ """
581
+ assert self.shape[0] == other.shape[0], "Can't hstack explanations with different numbers of rows!"
582
+ assert np.max(np.abs(self.base_values - other.base_values)) < 1e-6, "Can't hstack explanations with different base values!"
583
+
584
+ new_exp = Explanation(
585
+ values=np.hstack([self.values, other.values]),
586
+ base_values=self.base_values,
587
+ data=self.data,
588
+ display_data=self.display_data,
589
+ instance_names=self.instance_names,
590
+ feature_names=self.feature_names,
591
+ output_names=self.output_names,
592
+ output_indexes=self.output_indexes,
593
+ lower_bounds=self.lower_bounds,
594
+ upper_bounds=self.upper_bounds,
595
+ error_std=self.error_std,
596
+ main_effects=self.main_effects,
597
+ hierarchical_values=self.hierarchical_values,
598
+ clustering=self.clustering,
599
+ )
600
+ return new_exp
601
+
602
+ # def reshape(self, *args):
603
+ # return self._numpy_func("reshape", newshape=args)
604
+
605
    @property
    def abs(self):
        """ Element-wise absolute value, returned as a new Explanation. """
        return self._numpy_func("abs")

    @property
    def identity(self):
        """ A no-op: returns this Explanation unchanged. """
        return self

    @property
    def argsort(self):
        """ Numpy-style argsort applied via ``_numpy_func``. """
        return self._numpy_func("argsort")

    @property
    def flip(self):
        """ Numpy-style flip applied via ``_numpy_func``. """
        return self._numpy_func("flip")
620
+
621
+
622
    def hclust(self, metric="sqeuclidean", axis=0):
        """ Computes an optimal leaf ordering sort order using hclustering.

        hclust(metric="sqeuclidean")

        Parameters
        ----------
        metric : string
            A metric supported by scipy clustering.

        axis : int
            The axis to cluster along (0 clusters rows; 1 clusters columns
            by transposing first).

        Returns
        -------
        Array of leaf indices from complete-linkage clustering with
        optimal leaf ordering.
        """
        values = self.values

        if len(values.shape) != 2:
            raise DimensionError("The hclust order only supports 2D arrays right now!")

        if axis == 1:
            values = values.T

        # compute a hierarchical clustering and return the optimal leaf ordering
        D = scipy.spatial.distance.pdist(values, metric)
        cluster_matrix = scipy.cluster.hierarchy.complete(D)
        inds = scipy.cluster.hierarchy.leaves_list(scipy.cluster.hierarchy.optimal_leaf_ordering(cluster_matrix, D))
        return inds
648
+
649
+ def sample(self, max_samples, replace=False, random_state=0):
650
+ """ Randomly samples the instances (rows) of the Explanation object.
651
+
652
+ Parameters
653
+ ----------
654
+ max_samples : int
655
+ The number of rows to sample. Note that if replace=False then less than
656
+ fewer than max_samples will be drawn if explanation.shape[0] < max_samples.
657
+
658
+ replace : bool
659
+ Sample with or without replacement.
660
+ """
661
+ prev_seed = np.random.seed(random_state)
662
+ inds = np.random.choice(self.shape[0], min(max_samples, self.shape[0]), replace=replace)
663
+ np.random.seed(prev_seed)
664
+ return self[list(inds)]
665
+
666
+ def _flatten_feature_names(self):
667
+ new_values = {}
668
+ for i in range(len(self.values)):
669
+ for s,v in zip(self.feature_names[i], self.values[i]):
670
+ if s not in new_values:
671
+ new_values[s] = []
672
+ new_values[s].append(v)
673
+ return new_values
674
+
675
+ def _use_data_as_feature_names(self):
676
+ new_values = {}
677
+ for i in range(len(self.values)):
678
+ for s,v in zip(self.data[i], self.values[i]):
679
+ if s not in new_values:
680
+ new_values[s] = []
681
+ new_values[s].append(v)
682
+ return new_values
683
+
684
    def percentile(self, q, axis=None):
        """ Numpy-style percentile of the values (and data) along an axis.

        Parameters
        ----------
        q : float or array-like
            Percentile(s) in [0, 100], passed to ``np.percentile``.
        axis : int, optional
            Axis to reduce along.
        """
        new_self = copy.deepcopy(self)
        if self.feature_names is not None and not is_1d(self.feature_names) and axis == 0:
            # ragged per-row feature names: group values by name, then take
            # the percentile of each group
            new_values = self._flatten_feature_names()
            new_self.feature_names = np.array(list(new_values.keys()))
            new_self.values = np.array([np.percentile(v, q) for v in new_values.values()])
            new_self.clustering = None
        else:
            new_self.values = np.percentile(new_self.values, q, axis)
            # NOTE(review): this assumes data is not None — np.percentile
            # would raise on a None data attribute; confirm callers.
            new_self.data = np.percentile(new_self.data, q, axis)
        #new_self.data = None
        new_self.op_history.append({
            "name": "percentile",
            "args": (axis,),
            "prev_shape": self.shape,
            "collapsed_instances": axis == 0
        })
        return new_self
702
+
703
def group_features(shap_values, feature_map):
    """ Sum the SHAP values (and data) of features that map to the same group name.

    Parameters
    ----------
    shap_values : shap.Explanation
        The explanation whose feature columns should be grouped. May be rank 1
        (a single row) or rank 2 (samples x features).

    feature_map : dict
        Maps original feature names to new (group) names. Features sharing the
        same new name have their values and data summed into one column.

    Returns
    -------
    shap.Explanation
        A new explanation with one column per group, in first-seen order.
    """
    # TODOsomeday: support and deal with clusterings
    reverse_map = {}
    for name in feature_map:
        reverse_map[feature_map[name]] = reverse_map.get(feature_map[name], []) + [name]

    curr_names = shap_values.feature_names
    sv_new = copy.deepcopy(shap_values)
    found = {}
    i = 0
    rank1 = len(shap_values.shape) == 1
    for name in curr_names:
        new_name = feature_map.get(name, name)
        if new_name in found:
            continue
        found[new_name] = True

        # sum every original column that maps into this group
        cols_to_sum = reverse_map.get(new_name, [new_name])
        old_inds = [curr_names.index(v) for v in cols_to_sum]

        if rank1:
            sv_new.values[i] = shap_values.values[old_inds].sum()
            sv_new.data[i] = shap_values.data[old_inds].sum()
        else:
            sv_new.values[:,i] = shap_values.values[:,old_inds].sum(1)
            sv_new.data[:,i] = shap_values.data[:,old_inds].sum(1)
        sv_new.feature_names[i] = new_name
        i += 1

    return Explanation(
        sv_new.values[:i] if rank1 else sv_new.values[:,:i],
        base_values = sv_new.base_values,
        data = sv_new.data[:i] if rank1 else sv_new.data[:,:i],
        # bugfix: the rank1 branch previously sliced display_data with [:,:i]
        # in BOTH branches, which raises IndexError for 1-D display_data
        display_data = None if sv_new.display_data is None else (sv_new.display_data[:i] if rank1 else sv_new.display_data[:,:i]),
        instance_names = None,
        feature_names = None if sv_new.feature_names is None else sv_new.feature_names[:i],
        output_names = None,
        output_indexes = None,
        lower_bounds = None,
        upper_bounds = None,
        error_std = None,
        main_effects = None,
        hierarchical_values = None,
        clustering = None
    )
749
+
750
def compute_output_dims(values, base_values, data, output_names):
    """ Uses the passed data to infer which dimensions correspond to the model's output.

    Parameters
    ----------
    values : array-like
        The attribution values whose trailing dimensions we want to classify.

    base_values : array-like or None
        Expected model output(s); used to infer the output shape when
        output_names is not given.

    data : array-like or None
        The model input data; when None the data shape is assumed to equal
        the values shape.

    output_names : array-like or None
        Names of the model outputs, if known.

    Returns
    -------
    tuple of int
        Indices of the dimensions of `values` that correspond to model outputs.
    """
    values_shape = _compute_shape(values)

    # input shape matches the data shape
    if data is not None:
        data_shape = _compute_shape(data)

    # if we are not given any data we assume it would be the same shape as the given values
    else:
        data_shape = values_shape

    # output shape is known from the base values or output names
    if output_names is not None:
        output_shape = _compute_shape(output_names)

        # if our output_names are per sample then we need to drop the sample dimension here
        if values_shape[-len(output_shape):] != output_shape and \
            values_shape[-len(output_shape)+1:] == output_shape[1:] and values_shape[0] == output_shape[0]:
            output_shape = output_shape[1:]

    elif base_values is not None:
        # base_values are (samples, *outputs), so drop the sample dimension
        output_shape = _compute_shape(base_values)[1:]
    else:
        output_shape = tuple()

    # anything between the data dims and the trailing output dims is interaction order;
    # the output dims are always the last len(output_shape) dimensions of values
    interaction_order = len(values_shape) - len(data_shape) - len(output_shape)
    output_dims = range(len(data_shape) + interaction_order, len(values_shape))
    return tuple(output_dims)
780
+
781
def is_1d(val):
    """ Return True when the first element of val is not itself a list or ndarray. """
    first = val[0]
    return not isinstance(first, (list, np.ndarray))
783
+
784
class Op:
    """ Base class for operations recorded on Explanation objects (see Percentile). """
    pass


class Percentile(Op):
    """ Records a percentile operation, used when building an Explanation's repr. """

    def __init__(self, percentile):
        # the percentile value (e.g. 95) this op represents
        self.percentile = percentile

    def add_repr(self, s, verbose=False):
        """ Wrap the string s in a percentile(...) description. """
        return f"percentile({s}, {self.percentile})"
793
+
794
+ def _first_item(x):
795
+ for item in x:
796
+ return item
797
+ return None
798
+
799
def _compute_shape(x):
    """ Recursively infer the shape of a (possibly ragged) nested structure.

    Returns a tuple like numpy's .shape, except that None marks a dimension
    whose size is inconsistent across elements (ragged), and (None,) is
    returned for sequences of strings, whose length carries no structural
    meaning here.
    """
    # scalars (and strings) have no dimensions
    if not hasattr(x, "__len__") or isinstance(x, str):
        return tuple()
    elif not scipy.sparse.issparse(x) and len(x) > 0 and isinstance(_first_item(x), str):
        # a sequence of strings is treated as one unsized dimension
        return (None,)
    else:
        if isinstance(x, dict):
            # a dict contributes its length plus the shape of one of its values
            return (len(x),) + _compute_shape(x[next(iter(x))])

        # 2D arrays we just take their shape as-is
        if len(getattr(x, "shape", tuple())) > 1:
            return x.shape

        # 1D arrays we need to look inside
        if len(x) == 0:
            return (0,)
        elif len(x) == 1:
            return (1,) + _compute_shape(_first_item(x))
        else:
            first_shape = _compute_shape(_first_item(x))
            if first_shape == tuple():
                return (len(x),)
            else: # we have an array of arrays...
                # track which inner dimensions agree across all elements
                matches = np.ones(len(first_shape), dtype=bool)
                for i in range(1, len(x)):
                    shape = _compute_shape(x[i])
                    assert len(shape) == len(first_shape), "Arrays in Explanation objects must have consistent inner dimensions!"
                    for j in range(0, len(shape)):
                        matches[j] &= shape[j] == first_shape[j]
                # dimensions that disagree anywhere are reported as None (ragged)
                return (len(x),) + tuple(first_shape[j] if match else None for j, match in enumerate(matches))
829
+
830
class Cohorts:
    """ A named collection of Explanation objects, one per cohort.

    Indexing, attribute access, and calls are broadcast to every cohort and
    returned as a new Cohorts object.
    """

    def __init__(self, **kwargs):
        self.cohorts = kwargs
        for cohort in self.cohorts.values():
            assert isinstance(cohort, Explanation), "All the arguments to a Cohorts set must be Explanation objects!"

    def __getitem__(self, item):
        # slice every cohort the same way
        sliced = Cohorts()
        sliced.cohorts = {key: exp[item] for key, exp in self.cohorts.items()}
        return sliced

    def __getattr__(self, name):
        # forward attribute access to each cohort
        forwarded = Cohorts()
        forwarded.cohorts = {key: getattr(exp, name) for key, exp in self.cohorts.items()}
        return forwarded

    def __call__(self, *args, **kwargs):
        # call each cohort with the same arguments
        called = Cohorts()
        called.cohorts = {key: exp(*args, **kwargs) for key, exp in self.cohorts.items()}
        return called

    def __repr__(self):
        return f"<shap._explanation.Cohorts object with {len(self.cohorts)} cohorts of sizes: {[v.shape for v in self.cohorts.values()]}>"
856
+
857
+
858
def _auto_cohorts(shap_values, max_cohorts):
    """ This uses a DecisionTreeRegressor to build a group of cohorts with similar SHAP values.

    Each leaf of the fitted tree becomes one cohort, named by the conjunction
    of the decision-rule thresholds on the path to that leaf.
    """

    # fit a decision tree that well separates the SHAP values
    m = sklearn.tree.DecisionTreeRegressor(max_leaf_nodes=max_cohorts)
    m.fit(shap_values.data, shap_values.values)

    # group instances by their decision paths
    paths = m.decision_path(shap_values.data).toarray()
    path_names = []

    # mark each instance with a path name
    for i in range(shap_values.shape[0]):
        name = ""
        for j in range(len(paths[i])):
            if paths[i,j] > 0:
                feature = m.tree_.feature[j]
                threshold = m.tree_.threshold[j]
                # NOTE(review): data is indexed before the `feature >= 0` leaf check
                # below; leaf nodes have a negative feature index so this reads an
                # arbitrary column — harmless since the value is then unused, but
                # worth confirming/reordering upstream
                val = shap_values.data[i,feature]
                if feature >= 0:
                    name += str(shap_values.feature_names[feature])
                    if val < threshold:
                        name += " < "
                    else:
                        name += " >= "
                    name += str(threshold) + " & "
        path_names.append(name[:-3]) # the -3 strips off the last unneeded ' & '
    path_names = np.array(path_names)

    # split the instances into cohorts by their path names
    cohorts = {}
    for name in np.unique(path_names):
        cohorts[name] = shap_values[path_names == name]

    return Cohorts(**cohorts)
894
+
895
def list_wrap(x):
    """ A helper to patch things since slicer doesn't handle arrays of arrays (it does handle lists of arrays)
    """
    is_ragged_array = (
        isinstance(x, np.ndarray)
        and len(x.shape) == 1
        and isinstance(x[0], np.ndarray)
    )
    return list(x) if is_ragged_array else x
lib/shap/_serializable.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import inspect
3
+ import logging
4
+ import pickle
5
+
6
+ import cloudpickle
7
+ import numpy as np
8
+
9
+ log = logging.getLogger('shap')
10
+
11
class Serializable:
    """ This is the superclass of all serializable objects.
    """

    def save(self, out_file):
        """ Save the model to the given file stream.
        """
        # record the concrete type so load() can dispatch to the right subclass
        pickle.dump(type(self), out_file)

    @classmethod
    def load(cls, in_file, instantiate=True):
        """ This is meant to be overridden by subclasses and called with super.

        We return constructor argument values when not being instantiated. Since there are no
        constructor arguments for the Serializable class we just return an empty dictionary.
        """
        return cls._instantiated_load(in_file) if instantiate else {}

    @classmethod
    def _instantiated_load(cls, in_file, **kwargs):
        """ This is meant to be overridden by subclasses and called with super.

        We return constructor argument values (we have no values to load in this abstract class).
        """
        loaded_type = pickle.load(in_file)
        if loaded_type is None:
            return None

        valid_type = inspect.isclass(loaded_type) and (issubclass(loaded_type, cls) or loaded_type is cls)
        if not valid_type:
            raise Exception(f"Invalid object type loaded from file. {loaded_type} is not a subclass of {cls}.")

        # gather the saved constructor arguments, then keep only the ones the
        # concrete type's __init__ actually accepts
        constructor_args = loaded_type.load(in_file, instantiate=False, **kwargs)
        accepted = inspect.getfullargspec(loaded_type.__init__)[0]
        return loaded_type(**{k: v for k, v in constructor_args.items() if k in accepted})
48
+
49
+
50
class Serializer:
    """ Save data items to an input stream.

    Used as a context manager: __enter__ writes a block header (serializer
    version, block name, block version) and __exit__ writes an end-of-block
    token, so Deserializer can validate the stream. The pickle.dump order here
    defines the on-disk format and must mirror Deserializer exactly.
    """
    def __init__(self, out_stream, block_name, version):
        # binary file-like stream we pickle into
        self.out_stream = out_stream
        # name/version identifying the data block being written
        self.block_name = block_name
        self.block_version = version
        self.serializer_version = 0 # update this when the serializer changes

    def __enter__(self):
        # write the block header: serializer version, block name, block version
        log.debug("serializer_version = %d", self.serializer_version)
        pickle.dump(self.serializer_version, self.out_stream)
        log.debug("block_name = %s", self.block_name)
        pickle.dump(self.block_name, self.out_stream)
        log.debug("block_version = %d", self.block_version)
        pickle.dump(self.block_version, self.out_stream)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # terminate the block so Deserializer knows where it ends
        log.debug("END_BLOCK___")
        pickle.dump("END_BLOCK___", self.out_stream)

    def save(self, name, value, encoder="auto"):
        """ Dump a data item to the current input stream.

        Each item is written as (name, encoder_name, payload) so the matching
        Deserializer.load can validate the name and pick the right decoder.
        """
        log.debug("name = %s", name)
        pickle.dump(name, self.out_stream)
        if encoder is None or encoder is False:
            # no payload is written; Deserializer will return None for this item
            log.debug("encoder_name = %s", "no_encoder")
            pickle.dump("no_encoder", self.out_stream)
        elif callable(encoder):
            # caller must supply the matching custom decoder at load time
            log.debug("encoder_name = %s", "custom_encoder")
            pickle.dump("custom_encoder", self.out_stream)
            encoder(value, self.out_stream)
        elif encoder == ".save" or (isinstance(value, Serializable) and encoder == "auto"):
            log.debug("encoder_name = %s", "serializable.save")
            pickle.dump("serializable.save", self.out_stream)
            if len(inspect.getfullargspec(value.save)[0]) == 3: # backward compat for MLflow, can remove 4/1/2021
                value.save(self.out_stream, value)
            else:
                value.save(self.out_stream)
        elif encoder == "auto":
            # primitives go through pickle; everything else through cloudpickle
            if isinstance(value, (int, float, str)):
                log.debug("encoder_name = %s", "pickle.dump")
                pickle.dump("pickle.dump", self.out_stream)
                pickle.dump(value, self.out_stream)
            else:
                log.debug("encoder_name = %s", "cloudpickle.dump")
                pickle.dump("cloudpickle.dump", self.out_stream)
                cloudpickle.dump(value, self.out_stream)
        else:
            raise ValueError(f"Unknown encoder type '{encoder}' given for serialization!")
        log.debug("value = %s", str(value))
103
+
104
class Deserializer:
    """ Load data items from an input stream.

    The mirror of Serializer: used as a context manager, __enter__ validates
    the block header (serializer version, block name, block version) and
    __exit__ consumes items until the end-of-block token is found.
    """

    def __init__(self, in_stream, block_name, min_version, max_version):
        self.in_stream = in_stream
        self.block_name = block_name
        # range of block versions this reader understands
        self.block_min_version = min_version
        self.block_max_version = max_version

        # update these when the serializer changes
        self.serializer_min_version = 0
        self.serializer_max_version = 0

    def __enter__(self):

        # confirm the serializer version
        serializer_version = pickle.load(self.in_stream)
        log.debug("serializer_version = %d", serializer_version)
        if serializer_version < self.serializer_min_version:
            raise ValueError(
                f"The file being loaded was saved with a serializer version of {serializer_version}, " + \
                f"but the current deserializer in SHAP requires at least version {self.serializer_min_version}."
            )
        if serializer_version > self.serializer_max_version:
            raise ValueError(
                f"The file being loaded was saved with a serializer version of {serializer_version}, " + \
                f"but the current deserializer in SHAP only support up to version {self.serializer_max_version}."
            )

        # confirm the block name
        block_name = pickle.load(self.in_stream)
        log.debug("block_name = %s", block_name)
        if block_name != self.block_name:
            raise ValueError(
                f"The next data block in the file being loaded was supposed to be {self.block_name}, " + \
                f"but the next block found was {block_name}."
            )

        # confirm the block version
        block_version = pickle.load(self.in_stream)
        log.debug("block_version = %d", block_version)
        if block_version < self.block_min_version:
            raise ValueError(
                f"The file being loaded was saved with a block version of {block_version}, " + \
                f"but the current deserializer in SHAP requires at least version {self.block_min_version}."
            )
        if block_version > self.block_max_version:
            raise ValueError(
                f"The file being loaded was saved with a block version of {block_version}, " + \
                f"but the current deserializer in SHAP only support up to version {self.block_max_version}."
            )
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # confirm the block end token, skipping over (bounded) unread items
        for _ in range(100):
            end_token = pickle.load(self.in_stream)
            log.debug("end_token = %s", end_token)
            if end_token == "END_BLOCK___":
                return
            self._load_data_value()
        # bugfix: typo "wsa" -> "was" in the error message
        raise ValueError(
            f"The data block end token was not found for the block {self.block_name}."
        )

    def load(self, name, decoder=None):
        """ Load a data item from the current input stream.
        """
        # confirm the block name
        loaded_name = pickle.load(self.in_stream)
        log.debug("loaded_name = %s", loaded_name)
        # bugfix: removed stray debug print("loaded_name", loaded_name) left in library code
        if loaded_name != name:
            raise ValueError(
                f"The next data item in the file being loaded was supposed to be {name}, " + \
                f"but the next block found was {loaded_name}."
            ) # We should eventually add support for skipping over unused data items in old formats...

        value = self._load_data_value(decoder)
        log.debug("value = %s", str(value))
        return value

    def _load_data_value(self, decoder=None):
        # dispatch on the encoder name recorded by Serializer.save
        encoder_name = pickle.load(self.in_stream)
        log.debug("encoder_name = %s", encoder_name)
        if encoder_name == "custom_encoder" or callable(decoder):
            assert callable(decoder), "You must provide a callable custom decoder for the data item {name}!"
            return decoder(self.in_stream)
        if encoder_name == "no_encoder":
            return None
        if encoder_name == "serializable.save":
            return Serializable.load(self.in_stream)
        if encoder_name == "numpy.save":
            return np.load(self.in_stream)
        if encoder_name == "pickle.dump":
            return pickle.load(self.in_stream)
        if encoder_name == "cloudpickle.dump":
            return cloudpickle.load(self.in_stream)

        raise ValueError(f"Unsupported encoder type found: {encoder_name}")
lib/shap/_version.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # file generated by setuptools_scm
2
+ # don't change, don't track in version control
3
+ TYPE_CHECKING = False
4
+ if TYPE_CHECKING:
5
+ from typing import Tuple, Union
6
+ VERSION_TUPLE = Tuple[Union[int, str], ...]
7
+ else:
8
+ VERSION_TUPLE = object
9
+
10
+ version: str
11
+ __version__: str
12
+ __version_tuple__: VERSION_TUPLE
13
+ version_tuple: VERSION_TUPLE
14
+
15
+ __version__ = version = '0.44.1'
16
+ __version_tuple__ = version_tuple = (0, 44, 1)
lib/shap/actions/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from ._action import Action
2
+
3
+ __all__ = ["Action"]
lib/shap/actions/_action.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ class Action:
2
+ """ Abstract action class.
3
+ """
4
+ def __lt__(self, other_action):
5
+ return self.cost < other_action.cost
6
+
7
+ def __repr__(self):
8
+ return f"<Action '{self.__str__()}'>"
lib/shap/actions/_optimizer.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import queue
3
+ import warnings
4
+
5
+ from ..utils._exceptions import ConvergenceError, InvalidAction
6
+ from ._action import Action
7
+
8
+
9
class ActionOptimizer:
    """ Best-first search for the cheapest set of actions that satisfies a model.

    Actions are arranged into mutually exclusive groups: a bare Action forms a
    singleton group, and a list of Actions forms a group of alternatives
    (sorted cheapest-first). Candidate action sets are expanded from a
    priority queue ordered by total cost, so the first satisfying set
    returned is the cheapest one explored.
    """

    def __init__(self, model, actions):
        # model: callable returning truthy once the (mutated) args are acceptable
        self.model = model
        # bugfix: typo "subjust" -> "subject" in the user-facing warning
        warnings.warn(
            "Note that ActionOptimizer is still in an alpha state and is subject to API changes."
        )
        # actions go into mutually exclusive groups
        self.action_groups = []
        for group in actions:

            if issubclass(type(group), Action):
                # a single action becomes its own singleton group
                group._group_index = len(self.action_groups)
                group._grouped_index = 0
                self.action_groups.append([copy.copy(group)])
            elif issubclass(type(group), list):
                # a list of alternatives is sorted cheapest-first
                group = sorted([copy.copy(v) for v in group], key=lambda a: a.cost)
                for i, v in enumerate(group):
                    v._group_index = len(self.action_groups)
                    v._grouped_index = i
                self.action_groups.append(group)
            else:
                raise InvalidAction(
                    "A passed action was not an Action or list of actions!"
                )

    def __call__(self, *args, max_evals=10000):
        """ Return the cheapest list of actions that makes self.model(*args) truthy.

        Raises
        ------
        ConvergenceError
            If no satisfying action set is found within max_evals expansions.
        """

        # init our queue with all the least costly actions
        q = queue.PriorityQueue()
        for i in range(len(self.action_groups)):
            group = self.action_groups[i]
            q.put((group[0].cost, [group[0]]))

        nevals = 0
        while not q.empty():

            # see if we have exceeded our runtime budget
            nevals += 1
            if nevals > max_evals:
                raise ConvergenceError(
                    f"Failed to find a solution with max_evals={max_evals}! Try reducing the number of actions or increasing max_evals."
                )

            # get the next cheapest set of actions we can do
            cost, actions = q.get()

            # apply those actions on a deep copy so the caller's args are untouched
            args_tmp = copy.deepcopy(args)
            for a in actions:
                a(*args_tmp)

            # if the model is now satisfied we are done!!
            v = self.model(*args_tmp)
            if v:
                return actions

            # if not then we add all possible follow-on actions to our queue
            else:
                for i in range(len(self.action_groups)):
                    group = self.action_groups[i]

                    # look to see if we already have an action from this group, if so we need to
                    # move to a more expensive action in the same group
                    next_ind = 0
                    prev_in_group = -1
                    for j, a in enumerate(actions):
                        if a._group_index == i:
                            next_ind = max(next_ind, a._grouped_index + 1)
                            prev_in_group = j

                    # we are adding a new action type
                    if prev_in_group == -1:
                        new_actions = actions + [group[next_ind]]
                    # we are moving from one action to a more expensive one in the same group
                    elif next_ind < len(group):
                        new_actions = copy.copy(actions)
                        new_actions[prev_in_group] = group[next_ind]
                    # we don't have a more expensive action left in this group
                    else:
                        new_actions = None

                    # add the new option to our queue
                    if new_actions is not None:
                        q.put((sum(a.cost for a in new_actions), new_actions))
lib/shap/benchmark/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from ._compute import ComputeTime
2
+ from ._explanation_error import ExplanationError
3
+ from ._result import BenchmarkResult
4
+ from ._sequential import SequentialMasker
5
+
6
+ # from . import framework
7
+ # from .. import datasets
8
+
9
+ __all__ = ["ComputeTime", "ExplanationError", "BenchmarkResult", "SequentialMasker"]
lib/shap/benchmark/_compute.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from ._result import BenchmarkResult
2
+
3
+
4
class ComputeTime:
    """ Extracts a runtime benchmark result from the passed Explanation.
    """

    def __call__(self, explanation, name):
        # normalize total compute time to a per-sample figure
        per_sample_time = explanation.compute_time / explanation.shape[0]
        return BenchmarkResult("compute time", name, value=per_sample_time)
lib/shap/benchmark/_explanation_error.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import numpy as np
4
+ from tqdm.auto import tqdm
5
+
6
+ from shap import Explanation, links
7
+ from shap.maskers import FixedComposite, Image, Text
8
+ from shap.utils import MaskedModel, partition_tree_shuffle
9
+ from shap.utils._exceptions import DimensionError
10
+
11
+ from ._result import BenchmarkResult
12
+
13
+
14
class ExplanationError:
    """ A measure of the explanation error relative to a model's actual output.

    This benchmark metric measures the discrepancy between the output of the model predicted by an
    attribution explanation vs. the actual output of the model. This discrepancy is measured over
    many masking patterns drawn from permutations of the input features.

    For explanations (like Shapley values) that explain the difference between one alternative and another
    (for example a current sample and typical background feature values) there is possible explanation error
    for every pattern of mixing foreground and background, or other words every possible masking pattern.
    In this class we compute the standard deviation over these explanation errors where masking patterns
    are drawn from prefixes of random feature permutations. This seems natural, and aligns with Shapley value
    computations, but of course you could choose to summarize explanation errors in others ways as well.
    """

    def __init__(self, masker, model, *model_args, batch_size=500, num_permutations=10, link=links.identity, linearize_link=True, seed=38923):
        """ Build a new explanation error benchmarker with the given masker, model, and model args.

        Parameters
        ----------
        masker : function or shap.Masker
            The masker defines how we hide features during the perturbation process.

        model : function or shap.Model
            The model we want to evaluate explanations against.

        model_args : ...
            The list of arguments we will give to the model that we will have explained. When we later call this benchmark
            object we should pass explanations that have been computed on this same data.

        batch_size : int
            The maximum batch size we should use when calling the model. For some large NLP models this needs to be set
            lower (at say 1) to avoid running out of GPU memory.

        num_permutations : int
            How many permutations we will use to estimate the average explanation error for each sample. If you are running
            this benchmark on a large dataset with many samples then you can reduce this value since the final result is
            averaged over samples as well and the averages of both directly combine to reduce variance. So for 10k samples
            num_permutations=1 is appropreiate.

        link : function
            Allows for a non-linear link function to be used to bringe between the model output space and the explanation
            space.

        linearize_link : bool
            Non-linear links can destroy additive separation in generalized linear models, so by linearizing the link we can
            retain additive separation. See upcoming paper/doc for details.
        """

        self.masker = masker
        self.model = model
        self.model_args = model_args
        self.num_permutations = num_permutations
        self.link = link
        self.linearize_link = linearize_link
        # NOTE(review): model_args is assigned twice (also a few lines above) — redundant but harmless
        self.model_args = model_args
        self.batch_size = batch_size
        self.seed = seed

        # user must give valid masker
        # classify the data modality from the (possibly wrapped) masker type
        underlying_masker = masker.masker if isinstance(masker, FixedComposite) else masker
        if isinstance(underlying_masker, Text):
            self.data_type = "text"
        elif isinstance(underlying_masker, Image):
            self.data_type = "image"
        else:
            self.data_type = "tabular"

    # NOTE(review): indices=[] is a mutable default argument (it is never mutated
    # or even read here, but it should be None by convention)
    def __call__(self, explanation, name, step_fraction=0.01, indices=[], silent=False):
        """ Run this benchmark on the given explanation.

        Parameters
        ----------
        explanation : numpy.ndarray or shap.Explanation
            Attribution values for the rows of self.model_args.

        name : str
            Label used in the returned BenchmarkResult and the progress bar.

        step_fraction : float
            Fraction of features unmasked per step along each permutation.

        silent : bool
            Disable the progress bar when True.
        """

        if isinstance(explanation, np.ndarray):
            attributions = explanation
        elif isinstance(explanation, Explanation):
            attributions = explanation.values
        else:
            raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!")

        if len(attributions) != len(self.model_args[0]):
            emsg = (
                "The explanation passed must have the same number of rows as "
                "the self.model_args that were passed!"
            )
            raise DimensionError(emsg)

        # it is important that we choose the same permutations for the different explanations we are comparing
        # so as to avoid needless noise
        # NOTE(review): np.random.seed() returns None, so np.random.seed(old_seed)
        # below re-seeds the global RNG from entropy rather than restoring the
        # caller's state — confirm and consider a local RandomState instead
        old_seed = np.random.seed()
        np.random.seed(self.seed)

        pbar = None
        start_time = time.time()
        svals = []
        mask_vals = []

        for i, args in enumerate(zip(*self.model_args)):

            if len(args[0].shape) != len(attributions[i].shape):
                raise ValueError("The passed explanation must have the same dim as the model_args and must not have a vector output!")

            feature_size = np.prod(attributions[i].shape)
            sample_attributions = attributions[i].flatten()

            # compute any custom clustering for this row
            row_clustering = None
            if getattr(self.masker, "clustering", None) is not None:
                if isinstance(self.masker.clustering, np.ndarray):
                    row_clustering = self.masker.clustering
                elif callable(self.masker.clustering):
                    row_clustering = self.masker.clustering(*args)
                else:
                    raise NotImplementedError("The masker passed has a .clustering attribute that is not yet supported by the ExplanationError benchmark!")

            masked_model = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *args)

            total_values = None
            for _ in range(self.num_permutations):
                masks = []
                mask = np.zeros(feature_size, dtype=bool)
                masks.append(mask.copy())
                ordered_inds = np.arange(feature_size)

                # shuffle the indexes so we get a random permutation ordering
                if row_clustering is not None:
                    inds_mask = np.ones(feature_size, dtype=bool)
                    partition_tree_shuffle(ordered_inds, inds_mask, row_clustering)
                else:
                    np.random.shuffle(ordered_inds)

                # unmask the features in permutation order, step_fraction at a time,
                # recording each intermediate mask
                increment = max(1, int(feature_size * step_fraction))
                for j in range(0, feature_size, increment):
                    mask[ordered_inds[np.arange(j, min(feature_size, j+increment))]] = True
                    masks.append(mask.copy())
                mask_vals.append(masks)

                # evaluate the masked model in batches over all recorded masks
                values = []
                masks_arr = np.array(masks)
                for j in range(0, len(masks_arr), self.batch_size):
                    values.append(masked_model(masks_arr[j:j + self.batch_size]))
                values = np.concatenate(values)
                base_value = values[0]
                # squared error between the model output and the additive prediction
                # implied by the attributions of the unmasked features
                for j, v in enumerate(values):
                    values[j] = (v - (base_value + np.sum(sample_attributions[masks_arr[j]])))**2

                if total_values is None:
                    total_values = values
                else:
                    total_values += values
            total_values /= self.num_permutations

            svals.append(total_values)

            if pbar is None and time.time() - start_time > 5:
                pbar = tqdm(total=len(self.model_args[0]), disable=silent, leave=False, desc=f"ExplanationError for {name}")
                pbar.update(i+1)
            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        svals = np.array(svals)

        # reset the random seed so we don't mess up the caller
        np.random.seed(old_seed)

        # NOTE(review): the final value is computed from total_values, which at this
        # point only holds the LAST sample's averaged errors; svals (all samples) is
        # built above but unused — confirm whether svals was intended here
        return BenchmarkResult("explanation error", name, value=np.sqrt(np.sum(total_values)/len(total_values)))
lib/shap/benchmark/_result.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import sklearn
3
+
4
# Default orientation of each metric's scalar value: +1 means larger is better,
# -1 means smaller is better (consumed by BenchmarkResult when value_sign is
# not given explicitly).
sign_defaults = {
    "keep positive": 1,
    "keep negative": -1,
    "remove positive": -1,
    "remove negative": 1,
    "compute time": -1,
    "keep absolute": -1, # the absolute signs are defaults that make sense when scoring losses
    "remove absolute": 1,
    "explanation error": -1
}
14
+
15
class BenchmarkResult:
    """ The result of a benchmark run.

    Holds either a scalar `value` or a curve (`curve_x`/`curve_y`); when only a
    curve is given, the scalar is derived as the area under the shifted curve.
    """

    def __init__(self, metric, method, value=None, curve_x=None, curve_y=None, curve_y_std=None, value_sign=None):
        self.metric = metric
        self.method = method
        self.value = value
        self.curve_x = curve_x
        self.curve_y = curve_y
        self.curve_y_std = curve_y_std
        self.value_sign = value_sign
        # fall back to the conventional orientation for known metrics
        if self.value_sign is None:
            self.value_sign = sign_defaults.get(self.metric)
        # summarize a curve into one scalar: area under the curve shifted to start at 0
        if self.value is None:
            self.value = sklearn.metrics.auc(curve_x, (np.array(curve_y) - curve_y[0]))

    @property
    def full_name(self):
        return f"{self.method} {self.metric}"
lib/shap/benchmark/_sequential.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import matplotlib.pyplot as pl
4
+ import numpy as np
5
+ import pandas as pd
6
+ import sklearn
7
+ from tqdm.auto import tqdm
8
+
9
+ from shap import Explanation, links
10
+ from shap.maskers import FixedComposite, Image, Text
11
+ from shap.utils import MaskedModel
12
+
13
+ from ._result import BenchmarkResult
14
+
15
+
16
class SequentialMasker:
    """ Bind a model, masker and evaluation data to a SequentialPerturbation benchmark.

    Parameters
    ----------
    mask_type : str
        Passed through as the perturbation type ("keep" or "remove").

    sort_order : str
        "positive", "negative" or "absolute" attribution ordering.

    masker : shap masker
        The masker used to hide features from the model.

    model : callable or object with a ``predict`` method
        The model being evaluated.

    *model_args : numpy arrays
        The evaluation samples; DataFrames are rejected (see below).

    batch_size : int
        How many masked samples are sent to the model per call.
    """

    def __init__(self, mask_type, sort_order, masker, model, *model_args, batch_size=500):

        # DataFrames iterate over column labels rather than rows, which would
        # silently break the per-sample loop inside SequentialPerturbation
        for arg in model_args:
            if isinstance(arg, pd.DataFrame):
                raise TypeError("DataFrame arguments dont iterate correctly, pass numpy arrays instead!")

        self.inner = SequentialPerturbation(
            model, masker, sort_order, mask_type
        )
        self.model_args = model_args
        self.batch_size = batch_size

    def __call__(self, explanation, name, **kwargs):
        """ Score `explanation` against the bound data, returning a BenchmarkResult. """
        return self.inner(name, explanation, *self.model_args, batch_size=self.batch_size, **kwargs)
56
+
57
class SequentialPerturbation:
    """ Score an explanation by perturbing features in attribution order.

    Starting from either a fully masked input (perturbation="keep") or a fully
    unmasked input (perturbation="remove"), features are flipped one increment
    at a time in the order implied by the attribution values, and the model
    output is recorded after each step to build a performance curve.
    """

    def __init__(self, model, masker, sort_order, perturbation, linearize_link=False):
        # self.f = lambda masked, x, index: model.predict(masked)
        # accept either a plain callable or an object exposing .predict
        self.model = model if callable(model) else model.predict
        self.masker = masker
        self.sort_order = sort_order
        self.perturbation = perturbation
        self.linearize_link = linearize_link

        # define our sort order: each map returns feature indices ordered by
        # most-positive, most-negative, or largest-magnitude attribution first
        if self.sort_order == "positive":
            self.sort_order_map = lambda x: np.argsort(-x)
        elif self.sort_order == "negative":
            self.sort_order_map = lambda x: np.argsort(x)
        elif self.sort_order == "absolute":
            self.sort_order_map = lambda x: np.argsort(-abs(x))
        else:
            raise ValueError("sort_order must be either \"positive\", \"negative\", or \"absolute\"!")

        # user must give valid masker; FixedComposite wraps the real masker
        underlying_masker = masker.masker if isinstance(masker, FixedComposite) else masker
        if isinstance(underlying_masker, Text):
            self.data_type = "text"
        elif isinstance(underlying_masker, Image):
            self.data_type = "image"
        else:
            # any other masker type falls back to tabular handling
            self.data_type = "tabular"
            #raise ValueError("masker must be for \"tabular\", \"text\", or \"image\"!")

        # results accumulated across every __call__ / score invocation
        self.score_values = []
        self.score_aucs = []
        self.labels = []

    def __call__(self, name, explanation, *model_args, percent=0.01, indices=[], y=None, label=None, silent=False, debug_mode=False, batch_size=10):
        """ Run the perturbation benchmark over `model_args` using `explanation`.

        Returns a BenchmarkResult, or (mask_vals, curves, aucs) when debug_mode.
        NOTE(review): `indices=[]` is a mutable default argument and appears
        unused in this method — confirm before removing.
        """
        # if explainer is already the attributions
        if isinstance(explanation, np.ndarray):
            attributions = explanation
        elif isinstance(explanation, Explanation):
            attributions = explanation.values
        else:
            raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!")

        assert len(attributions) == len(model_args[0]), "The explanation passed must have the same number of rows as the model_args that were passed!"

        if label is None:
            label = "Score %d" % len(self.score_values)

        # convert dataframes
        # if isinstance(X, (pd.Series, pd.DataFrame)):
        #     X = X.values

        # convert all single-sample vectors to matrices
        # if not hasattr(attributions[0], "__len__"):
        #     attributions = np.array([attributions])
        # if not hasattr(X[0], "__len__") and self.data_type == "tabular":
        #     X = np.array([X])

        pbar = None
        start_time = time.time()
        svals = []      # per-sample model-output curves
        mask_vals = []  # per-sample list of the masks that were evaluated

        for i, args in enumerate(zip(*model_args)):
            # if self.data_type == "image":
            #     x_shape, y_shape = attributions[i].shape[0], attributions[i].shape[1]
            #     feature_size = np.prod([x_shape, y_shape])
            #     sample_attributions = attributions[i].mean(2).reshape(feature_size, -1)
            #     data = X[i].flatten()
            #     mask_shape = X[i].shape
            # else:
            feature_size = np.prod(attributions[i].shape)
            sample_attributions = attributions[i].flatten()
            # data = X[i]
            # mask_shape = feature_size

            self.masked_model = MaskedModel(self.model, self.masker, links.identity, self.linearize_link, *args)

            masks = []

            # "remove" starts from an all-True mask, "keep" from all-False
            mask = np.ones(feature_size, dtype=bool) * (self.perturbation == "remove")
            masks.append(mask.copy())

            ordered_inds = self.sort_order_map(sample_attributions)
            # flip `increment` features per step (at least one)
            increment = max(1,int(feature_size*percent))
            for j in range(0, feature_size, increment):
                oind_list = [ordered_inds[t] for t in range(j, min(feature_size, j+increment))]

                for oind in oind_list:
                    # skip features whose attribution sign disagrees with the sort order
                    if not ((self.sort_order == "positive" and sample_attributions[oind] <= 0) or \
                            (self.sort_order == "negative" and sample_attributions[oind] >= 0)):
                        mask[oind] = self.perturbation == "keep"

                masks.append(mask.copy())

            mask_vals.append(masks)

            # mask_size = len(range(0, feature_size, increment)) + 1
            # evaluate all masks for this sample in batches
            values = []
            masks_arr = np.array(masks)
            for j in range(0, len(masks_arr), batch_size):
                values.append(self.masked_model(masks_arr[j:j + batch_size]))
            values = np.concatenate(values)

            svals.append(values)

            # only show a progress bar once the run has taken more than 5s
            if pbar is None and time.time() - start_time > 5:
                pbar = tqdm(total=len(model_args[0]), disable=silent, leave=False, desc="SequentialMasker")
                pbar.update(i+1)
            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        self.score_values.append(np.array(svals))

        # if self.sort_order == "negative":
        #     curve_sign = -1
        # else:
        curve_sign = 1

        self.labels.append(label)

        # resample every per-sample curve onto a common 100-point grid,
        # then average them into one mean curve with standard errors
        xs = np.linspace(0, 1, 100)
        curves = np.zeros((len(self.score_values[-1]), len(xs)))
        for j in range(len(self.score_values[-1])):
            xp = np.linspace(0, 1, len(self.score_values[-1][j]))
            yp = self.score_values[-1][j]
            curves[j,:] = np.interp(xs, xp, yp)
        ys = curves.mean(0)
        std = curves.std(0) / np.sqrt(curves.shape[0])
        auc = sklearn.metrics.auc(np.linspace(0, 1, len(ys)), curve_sign*(ys-ys[0]))

        if not debug_mode:
            return BenchmarkResult(self.perturbation + " " + self.sort_order, name, curve_x=xs, curve_y=ys, curve_y_std=std)
        else:
            # debug mode returns the raw masks and per-sample curves/aucs
            aucs = []
            for j in range(len(self.score_values[-1])):
                curve = curves[j,:]
                auc = sklearn.metrics.auc(np.linspace(0, 1, len(curve)), curve_sign*(curve-curve[0]))
                aucs.append(auc)
            return mask_vals, curves, aucs

    def score(self, explanation, X, percent=0.01, y=None, label=None, silent=False, debug_mode=False):
        '''
        Will be deprecated once MaskedModel is in complete support

        Legacy scoring path that calls the masker/model directly rather than
        through MaskedModel.
        NOTE(review): this method uses `self.f`, whose assignment in __init__
        is commented out, so as written it raises AttributeError — confirm
        whether this path is intentionally dead.
        '''
        # if explainer is already the attributions
        if isinstance(explanation, np.ndarray):
            attributions = explanation
        elif isinstance(explanation, Explanation):
            attributions = explanation.values

        if label is None:
            label = "Score %d" % len(self.score_values)

        # convert dataframes
        if isinstance(X, (pd.Series, pd.DataFrame)):
            X = X.values

        # convert all single-sample vectors to matrices
        if not hasattr(attributions[0], "__len__"):
            attributions = np.array([attributions])
        if not hasattr(X[0], "__len__") and self.data_type == "tabular":
            X = np.array([X])

        pbar = None
        start_time = time.time()
        svals = []
        mask_vals = []

        for i in range(len(X)):
            if self.data_type == "image":
                # treat each pixel as one feature, averaging over channels
                x_shape, y_shape = attributions[i].shape[0], attributions[i].shape[1]
                feature_size = np.prod([x_shape, y_shape])
                sample_attributions = attributions[i].mean(2).reshape(feature_size, -1)
            else:
                feature_size = attributions[i].shape[0]
                sample_attributions = attributions[i]

            # multi-output attributions are scored one output at a time
            if len(attributions[i].shape) == 1 or self.data_type == "tabular":
                output_size = 1
            else:
                output_size = attributions[i].shape[-1]

            for k in range(output_size):
                if self.data_type == "image":
                    mask_shape = X[i].shape
                else:
                    mask_shape = feature_size

                mask = np.ones(mask_shape, dtype=bool) * (self.perturbation == "remove")
                masks = [mask.copy()]

                values = np.zeros(feature_size+1)
                # masked, data = self.masker(mask, X[i])
                masked = self.masker(mask, X[i])
                data = None
                curr_val = self.f(masked, data, k).mean(0)

                values[0] = curr_val

                if output_size != 1:
                    test_attributions = sample_attributions[:,k]
                else:
                    test_attributions = sample_attributions

                ordered_inds = self.sort_order_map(test_attributions)
                increment = max(1,int(feature_size*percent))
                for j in range(0, feature_size, increment):
                    oind_list = [ordered_inds[t] for t in range(j, min(feature_size, j+increment))]

                    for oind in oind_list:
                        if not ((self.sort_order == "positive" and test_attributions[oind] <= 0) or \
                                (self.sort_order == "negative" and test_attributions[oind] >= 0)):
                            if self.data_type == "image":
                                # map the flat feature index back to pixel coordinates
                                xoind, yoind = oind // attributions[i].shape[1], oind % attributions[i].shape[1]
                                mask[xoind][yoind] = self.perturbation == "keep"
                            else:
                                mask[oind] = self.perturbation == "keep"

                    masks.append(mask.copy())
                    # masked, data = self.masker(mask, X[i])
                    masked = self.masker(mask, X[i])
                    curr_val = self.f(masked, data, k).mean(0)

                    # every feature flipped in this step gets the same output value
                    for t in range(j, min(feature_size, j+increment)):
                        values[t+1] = curr_val

                svals.append(values)
                mask_vals.append(masks)

            if pbar is None and time.time() - start_time > 5:
                pbar = tqdm(total=len(X), disable=silent, leave=False)
                pbar.update(i+1)
            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        self.score_values.append(np.array(svals))

        if self.sort_order == "negative":
            curve_sign = -1
        else:
            curve_sign = 1

        self.labels.append(label)

        # resample all curves onto a common grid and average them
        xs = np.linspace(0, 1, 100)
        curves = np.zeros((len(self.score_values[-1]), len(xs)))
        for j in range(len(self.score_values[-1])):
            xp = np.linspace(0, 1, len(self.score_values[-1][j]))
            yp = self.score_values[-1][j]
            curves[j,:] = np.interp(xs, xp, yp)
        ys = curves.mean(0)

        if debug_mode:
            aucs = []
            for j in range(len(self.score_values[-1])):
                curve = curves[j,:]
                auc = sklearn.metrics.auc(np.linspace(0, 1, len(curve)), curve_sign*(curve-curve[0]))
                aucs.append(auc)
            return mask_vals, curves, aucs
        else:
            auc = sklearn.metrics.auc(np.linspace(0, 1, len(ys)), curve_sign*(ys-ys[0]))
            return xs, ys, auc

    def plot(self, xs, ys, auc):
        """ Plot a single benchmark curve with its AUC in the legend. """
        pl.plot(xs, ys, label="AUC %0.4f" % auc)
        pl.legend()
        xlabel = "Percent Unmasked" if self.perturbation == "keep" else "Percent Masked"
        pl.xlabel(xlabel)
        pl.ylabel("Model Output")
        pl.show()
lib/shap/benchmark/experiments.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import itertools
3
+ import os
4
+ import pickle
5
+ import random
6
+ import subprocess
7
+ import sys
8
+ import time
9
+ from multiprocessing import Pool
10
+
11
+ from .. import __version__, datasets
12
+ from . import metrics, models
13
+
14
+ try:
15
+ from queue import Queue
16
+ except ImportError:
17
+ from Queue import Queue
18
+ from threading import Lock, Thread
19
+
20
# Metric suites: which benchmark metrics are run for each task type.
# Entries commented out are known configurations that are currently disabled.
regression_metrics = [
    "local_accuracy",
    "consistency_guarantees",
    "keep_positive_mask",
    "keep_positive_resample",
    #"keep_positive_impute",
    "keep_negative_mask",
    "keep_negative_resample",
    #"keep_negative_impute",
    "keep_absolute_mask__r2",
    "keep_absolute_resample__r2",
    #"keep_absolute_impute__r2",
    "remove_positive_mask",
    "remove_positive_resample",
    #"remove_positive_impute",
    "remove_negative_mask",
    "remove_negative_resample",
    #"remove_negative_impute",
    "remove_absolute_mask__r2",
    "remove_absolute_resample__r2",
    #"remove_absolute_impute__r2"
    "runtime",
]

binary_classification_metrics = [
    "local_accuracy",
    "consistency_guarantees",
    "keep_positive_mask",
    "keep_positive_resample",
    #"keep_positive_impute",
    "keep_negative_mask",
    "keep_negative_resample",
    #"keep_negative_impute",
    "keep_absolute_mask__roc_auc",
    "keep_absolute_resample__roc_auc",
    #"keep_absolute_impute__roc_auc",
    "remove_positive_mask",
    "remove_positive_resample",
    #"remove_positive_impute",
    "remove_negative_mask",
    "remove_negative_resample",
    #"remove_negative_impute",
    "remove_absolute_mask__roc_auc",
    "remove_absolute_resample__roc_auc",
    #"remove_absolute_impute__roc_auc"
    "runtime",
]

# Human-agreement metrics over small synthetic AND/OR/XOR/SUM functions.
human_metrics = [
    "human_and_00",
    "human_and_01",
    "human_and_11",
    "human_or_00",
    "human_or_01",
    "human_or_11",
    "human_xor_00",
    "human_xor_01",
    "human_xor_11",
    "human_sum_00",
    "human_sum_01",
    "human_sum_11"
]

# Method suites: which explanation methods are applicable to each model family.
linear_regress_methods = [
    "linear_shap_corr",
    "linear_shap_ind",
    "coef",
    "random",
    "kernel_shap_1000_meanref",
    #"kernel_shap_100_meanref",
    #"sampling_shap_10000",
    "sampling_shap_1000",
    "lime_tabular_regression_1000"
    #"sampling_shap_100"
]

linear_classify_methods = [
    # NEED LIME
    "linear_shap_corr",
    "linear_shap_ind",
    "coef",
    "random",
    "kernel_shap_1000_meanref",
    #"kernel_shap_100_meanref",
    #"sampling_shap_10000",
    "sampling_shap_1000",
    #"lime_tabular_regression_1000"
    #"sampling_shap_100"
]

tree_regress_methods = [
    # NEED tree_shap_ind
    # NEED split_count?
    "tree_shap_tree_path_dependent",
    "tree_shap_independent_200",
    "saabas",
    "random",
    "tree_gain",
    "kernel_shap_1000_meanref",
    "mean_abs_tree_shap",
    #"kernel_shap_100_meanref",
    #"sampling_shap_10000",
    "sampling_shap_1000",
    "lime_tabular_regression_1000",
    "maple"
    #"sampling_shap_100"
]

rf_regress_methods = [ # methods that only support random forest models
    "tree_maple"
]

tree_classify_methods = [
    # NEED tree_shap_ind
    # NEED split_count?
    "tree_shap_tree_path_dependent",
    "tree_shap_independent_200",
    "saabas",
    "random",
    "tree_gain",
    "kernel_shap_1000_meanref",
    "mean_abs_tree_shap",
    #"kernel_shap_100_meanref",
    #"sampling_shap_10000",
    "sampling_shap_1000",
    "lime_tabular_classification_1000",
    "maple"
    #"sampling_shap_100"
]

deep_regress_methods = [
    "deep_shap",
    "expected_gradients",
    "random",
    "kernel_shap_1000_meanref",
    "sampling_shap_1000",
    #"lime_tabular_regression_1000"
]

deep_classify_methods = [
    "deep_shap",
    "expected_gradients",
    "random",
    "kernel_shap_1000_meanref",
    "sampling_shap_1000",
    #"lime_tabular_regression_1000"
]

# Full cross-product of experiments: each entry is [dataset, model, method, metric].
_experiments = []
_experiments += [["corrgroups60", "lasso", m, s] for s in regression_metrics for m in linear_regress_methods]
_experiments += [["corrgroups60", "ridge", m, s] for s in regression_metrics for m in linear_regress_methods]
_experiments += [["corrgroups60", "decision_tree", m, s] for s in regression_metrics for m in tree_regress_methods]
_experiments += [["corrgroups60", "random_forest", m, s] for s in regression_metrics for m in (tree_regress_methods + rf_regress_methods)]
_experiments += [["corrgroups60", "gbm", m, s] for s in regression_metrics for m in tree_regress_methods]
_experiments += [["corrgroups60", "ffnn", m, s] for s in regression_metrics for m in deep_regress_methods]

_experiments += [["independentlinear60", "lasso", m, s] for s in regression_metrics for m in linear_regress_methods]
_experiments += [["independentlinear60", "ridge", m, s] for s in regression_metrics for m in linear_regress_methods]
_experiments += [["independentlinear60", "decision_tree", m, s] for s in regression_metrics for m in tree_regress_methods]
_experiments += [["independentlinear60", "random_forest", m, s] for s in regression_metrics for m in (tree_regress_methods + rf_regress_methods)]
_experiments += [["independentlinear60", "gbm", m, s] for s in regression_metrics for m in tree_regress_methods]
_experiments += [["independentlinear60", "ffnn", m, s] for s in regression_metrics for m in deep_regress_methods]

_experiments += [["cric", "lasso", m, s] for s in binary_classification_metrics for m in linear_classify_methods]
_experiments += [["cric", "ridge", m, s] for s in binary_classification_metrics for m in linear_classify_methods]
_experiments += [["cric", "decision_tree", m, s] for s in binary_classification_metrics for m in tree_classify_methods]
_experiments += [["cric", "random_forest", m, s] for s in binary_classification_metrics for m in tree_classify_methods]
_experiments += [["cric", "gbm", m, s] for s in binary_classification_metrics for m in tree_classify_methods]
_experiments += [["cric", "ffnn", m, s] for s in binary_classification_metrics for m in deep_classify_methods]

_experiments += [["human", "decision_tree", m, s] for s in human_metrics for m in tree_regress_methods]
191
+
192
+
193
def experiments(dataset=None, model=None, method=None, metric=None):
    """ Iterate over the registered benchmark experiments, optionally filtered.

    Each yielded experiment is a [dataset, model, method, metric] list; any
    filter argument left as None matches all experiments on that field.
    """
    filters = (dataset, model, method, metric)
    for experiment in _experiments:
        if all(f is None or f == field for f, field in zip(filters, experiment)):
            yield experiment
204
+
205
def run_experiment(experiment, use_cache=True, cache_dir="/tmp"):
    """ Run a single benchmark experiment and cache the resulting score.

    Parameters
    ----------
    experiment : sequence of four strings
        (dataset_name, model_name, method_name, metric_name).

    use_cache : bool
        When True, return a previously pickled score if one exists.

    cache_dir : str
        Directory where pickled scores are stored.

    NOTE(review): results are loaded with pickle; this assumes the cache dir
    contains only files this code wrote — confirm it is never shared with
    untrusted writers.
    """
    dataset_name, model_name, method_name, metric_name = experiment

    # see if we have a cached version
    cache_id = __gen_cache_id(experiment)
    cache_file = os.path.join(cache_dir, cache_id + ".pickle")
    if use_cache and os.path.isfile(cache_file):
        with open(cache_file, "rb") as f:
            #print(cache_id.replace("__", " ") + " ...loaded from cache.")
            return pickle.load(f)

    # compute the scores; dataset/model/metric are looked up by name on their modules
    print(cache_id.replace("__", " ", 4) + " ...")
    sys.stdout.flush()
    start = time.time()
    X,y = getattr(datasets, dataset_name)()
    score = getattr(metrics, metric_name)(
        X, y,
        getattr(models, dataset_name+"__"+model_name),
        method_name
    )
    print("...took %f seconds.\n" % (time.time() - start))

    # cache the scores
    with open(cache_file, "wb") as f:
        pickle.dump(score, f)

    return score
233
+
234
+
235
def run_experiments_helper(args):
    """ Adapter for Pool.map: unpack one (experiment, cache_dir) pair. """
    experiment, cache_directory = args
    return run_experiment(experiment, cache_dir=cache_directory)
238
+
239
def run_experiments(dataset=None, model=None, method=None, metric=None, cache_dir="/tmp", nworkers=1):
    """ Run all matching experiments, optionally fanned out over worker processes.

    Returns a list of (experiment, score) pairs in the order the experiments
    were selected.
    """
    experiments_arr = list(experiments(dataset=dataset, model=model, method=method, metric=metric))
    tasks = zip(experiments_arr, itertools.repeat(cache_dir))
    if nworkers == 1:
        # run serially in-process
        out = [run_experiments_helper(task) for task in tasks]
    else:
        with Pool(nworkers) as pool:
            out = pool.map(run_experiments_helper, tasks)
    return list(zip(experiments_arr, out))
247
+
248
+
249
# Shared state for the remote ssh execution machinery below. All of these are
# read/written by the dispatch threads while holding `worker_lock`.
nexperiments = 0       # total number of experiments queued for this run
total_sent = 0         # how many experiments have been dispatched so far
total_done = 0         # how many experiments have finished
total_failed = 0       # how many remote commands failed
host_records = {}      # hostname -> timestamps of recent ssh connections
worker_lock = Lock()   # guards the counters and host_records above
ssh_conn_per_min_limit = 0 # set as an argument to run_remote_experiments
256
def __thread_worker(q, host):
    """ Dispatch-thread loop: pull experiments off `q` and run them on `host`.

    `host` has the form "hostname:path_to_python_binary". Runs forever; the
    threads are started as daemons so they die with the main process.
    """
    global total_sent, total_done
    hostname, python_binary = host.split(":")
    while True:

        # make sure we are not sending too many ssh connections to the host
        # (if we send too many connections ssh throttling will lock us out)
        while True:
            all_clear = False

            worker_lock.acquire()
            try:
                if hostname not in host_records:
                    host_records[hostname] = []

                # clear if we are under the per-minute budget, or the oldest
                # connection inside the budget window is more than a minute old
                if len(host_records[hostname]) < ssh_conn_per_min_limit:
                    all_clear = True
                elif time.time() - host_records[hostname][-ssh_conn_per_min_limit] > 61:
                    all_clear = True
            finally:
                worker_lock.release()

            # if we are clear to send a new ssh connection then break
            if all_clear:
                break

            # if we are not clear then we sleep and try again
            time.sleep(5)

        experiment = q.get()

        # if we are not loading from the cache then we note that we have called the host
        cache_dir = "/tmp"
        cache_file = os.path.join(cache_dir, __gen_cache_id(experiment) + ".pickle")
        if not os.path.isfile(cache_file):
            worker_lock.acquire()
            try:
                host_records[hostname].append(time.time())
            finally:
                worker_lock.release()

        # record how many we have sent off for execution
        worker_lock.acquire()
        try:
            total_sent += 1
            __print_status()
        finally:
            worker_lock.release()

        __run_remote_experiment(experiment, hostname, cache_dir=cache_dir, python_binary=python_binary)

        # record how many are finished
        worker_lock.acquire()
        try:
            total_done += 1
            __print_status()
        finally:
            worker_lock.release()

        q.task_done()
316
+
317
def __print_status():
    """ Overwrite the console status line with the current progress counters. """
    running = total_sent - total_done
    status = "Benchmark task %d of %d done (%d failed, %d running)" % (total_done, nexperiments, total_failed, running)
    print(status, end="\r")
    sys.stdout.flush()
320
+
321
+
322
def run_remote_experiments(experiments, thread_hosts, rate_limit=10):
    """ Use ssh to run the experiments on remote machines in parallel.

    Parameters
    ----------
    experiments : iterable
        Output of shap.benchmark.experiments(...).

    thread_hosts : list of strings
        Each host has the format "host_name:path_to_python_binary" and can appear multiple times
        in the list (one for each parallel execution you want on that machine).

    rate_limit : int
        How many ssh connections we make per minute to each host (to avoid throttling issues).
    """

    global ssh_conn_per_min_limit
    ssh_conn_per_min_limit = rate_limit

    # first we kill any remaining workers from previous runs
    # note we don't check_call because pkill kills our ssh call as well
    thread_hosts = copy.copy(thread_hosts)
    random.shuffle(thread_hosts)
    for host in set(thread_hosts):
        hostname,_ = host.split(":")
        try:
            subprocess.run(["ssh", hostname, "pkill -f shap.benchmark.run_experiment"], timeout=15)
        except subprocess.TimeoutExpired:
            print("Failed to connect to", hostname, "after 15 seconds! Exiting.")
            return

    experiments = copy.copy(list(experiments))
    random.shuffle(experiments) # this way all the hard experiments don't get put on one machine
    global nexperiments, total_sent, total_done, total_failed, host_records
    nexperiments = len(experiments)
    total_sent = 0
    total_done = 0
    total_failed = 0
    host_records = {}

    q = Queue()

    for host in thread_hosts:
        worker = Thread(target=__thread_worker, args=(q, host))
        # Thread.setDaemon() is deprecated since Python 3.10; assign the
        # daemon attribute directly instead
        worker.daemon = True
        worker.start()

    for experiment in experiments:
        q.put(experiment)

    q.join()
373
+
374
def __run_remote_experiment(experiment, remote, cache_dir="/tmp", python_binary="python"):
    """ Execute one experiment on `remote` over ssh and copy back the pickled score.

    Returns the unpickled score, or None if the remote command failed.
    NOTE(review): results are loaded with pickle from files produced by our own
    remote command — confirm the cache directory is not shared with untrusted
    writers.
    """
    global total_failed
    dataset_name, model_name, method_name, metric_name = experiment

    # see if we have a cached version
    cache_id = __gen_cache_id(experiment)
    cache_file = os.path.join(cache_dir, cache_id + ".pickle")
    if os.path.isfile(cache_file):
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    # this is just so we don't dump everything at once on a machine
    time.sleep(random.uniform(0,5))

    # run the benchmark on the remote machine, disabling any GPUs and
    # redirecting output to a per-experiment log file in cache_dir
    #start = time.time()
    cmd = "CUDA_VISIBLE_DEVICES=\"\" "+python_binary+" -c \"import shap; shap.benchmark.run_experiment(['{}', '{}', '{}', '{}'], cache_dir='{}')\" &> {}/{}.output".format(
        dataset_name, model_name, method_name, metric_name, cache_dir, cache_dir, cache_id
    )
    try:
        subprocess.check_output(["ssh", remote, cmd])
    except subprocess.CalledProcessError as e:
        print("The following command failed on %s:" % remote, file=sys.stderr)
        print(cmd, file=sys.stderr)
        total_failed += 1
        print(e)
        return

    # copy the results back
    subprocess.check_output(["scp", remote+":"+cache_file, cache_file])

    if os.path.isfile(cache_file):
        with open(cache_file, "rb") as f:
            #print(cache_id.replace("__", " ") + " ...loaded from remote after %f seconds" % (time.time() - start))
            return pickle.load(f)
    else:
        raise FileNotFoundError("Remote benchmark call finished but no local file was found!")
411
+
412
def __gen_cache_id(experiment):
    """ Build the cache key: "v<version>__<dataset>__<model>__<method>__<metric>". """
    dataset_name, model_name, method_name, metric_name = experiment
    return f"v{__version__}__{dataset_name}__{model_name}__{method_name}__{metric_name}"
lib/shap/benchmark/framework.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools as it
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ from . import perturbation
8
+
9
+
10
def update(model, attributions, X, y, masker, sort_order, perturbation_method, scores):
    """ Run one SequentialPerturbation configuration and record its curve in `scores`. """
    metric = f"{perturbation_method} {sort_order}"
    runner = perturbation.SequentialPerturbation(model, masker, sort_order, perturbation_method)
    xs, ys, auc = runner.model_score(attributions, X, y=y)
    scores['metrics'].append(metric)
    scores['values'][metric] = [xs, ys, auc]
16
+
17
def get_benchmark(model, attributions, X, y, masker, metrics):
    """ Evaluate every (sort_order, perturbation) combination and collect the scores.

    Returns a dict with a 'metrics' list (insertion order) and a 'values'
    mapping from metric name to [xs, ys, auc].
    """
    # unwrap pandas containers into raw numpy arrays
    if isinstance(X, (pd.Series, pd.DataFrame)):
        X = X.values
    if isinstance(masker, (pd.Series, pd.DataFrame)):
        masker = masker.values

    # record scores per metric, iterating the same order as itertools.product
    scores = {'metrics': [], 'values': {}}
    for sort_order in metrics['sort_order']:
        for perturbation_method in metrics['perturbation']:
            update(model, attributions, X, y, masker, sort_order, perturbation_method, scores)

    return scores
30
+
31
def get_metrics(benchmarks, selection):
    """ Combine the metric names recorded for each explainer into one list.

    `selection` is a binary set operation (e.g. union or intersection) applied
    pairwise across explainers; the first non-empty metric set seeds the result.
    """
    combined = set()
    for scores in benchmarks.values():
        current = set(scores['metrics'])
        combined = current if not combined else selection(combined, current)
    return list(combined)
42
+
43
def trend_plot(benchmarks):
    """ Plot each explainer's performance curve, one figure per metric.

    Parameters
    ----------
    benchmarks : dict
        Maps explainer name -> scores dict as produced by get_benchmark.
    """
    explainer_metrics = get_metrics(benchmarks, lambda x, y: x.union(y))

    # plot all curves if metric exists
    for metric in explainer_metrics:
        plt.clf()

        for explainer in benchmarks:
            scores = benchmarks[explainer]
            if metric in scores['values']:
                x, y, auc = scores['values'][metric]
                plt.plot(x, y, label=f'{round(auc, 3)} - {explainer}')

        # default first so xlabel is always bound (the original raised
        # UnboundLocalError for metrics mentioning neither keep nor remove);
        # the original keep-then-remove precedence is preserved
        xlabel = 'Percent Masked'
        if 'keep' in metric:
            xlabel = 'Percent Unmasked'
        if 'remove' in metric:
            xlabel = 'Percent Masked'

        plt.ylabel('Model Output')
        plt.xlabel(xlabel)
        plt.title(metric)
        plt.legend()
        plt.show()
66
+
67
def compare_plot(benchmarks):
    """ Plot the per-metric normalized performance of every explainer on one chart.

    Each explainer becomes one polyline; the first point is a dummy used only to
    spread the explainer labels evenly on the y axis, and each subsequent point
    is that explainer's min-max normalized AUC for one metric.

    NOTE(review): several divisions assume num_explainers >= 2 and that the
    AUCs for each metric are not all equal; otherwise this raises
    ZeroDivisionError — confirm callers guarantee that.
    """
    explainer_metrics = get_metrics(benchmarks, lambda x, y: x.intersection(y))
    explainers = list(benchmarks.keys())
    num_explainers = len(explainers)
    num_metrics = len(explainer_metrics)

    # dummy start to evenly distribute explainers on the left
    # can later be replaced by boolean metrics
    aucs = dict()
    for i in range(num_explainers):
        explainer = explainers[i]
        aucs[explainer] = [i/(num_explainers-1)]

    # normalize per metric: min-max scale each metric's AUCs across explainers
    for metric in explainer_metrics:
        max_auc, min_auc = -float('inf'), float('inf')

        for explainer in explainers:
            scores = benchmarks[explainer]
            _, _, auc = scores['values'][metric]
            min_auc = min(auc, min_auc)
            max_auc = max(auc, max_auc)

        for explainer in explainers:
            scores = benchmarks[explainer]
            _, _, auc = scores['values'][metric]
            aucs[explainer].append((auc-min_auc)/(max_auc-min_auc))

    # plot common curves
    ax = plt.gca()
    for explainer in explainers:
        plt.plot(np.linspace(0, 1, len(explainer_metrics)+1), aucs[explainer], '--o')

    ax.tick_params(which='major', axis='both', labelsize=8)

    # y ticks label the explainers at their dummy starting heights
    ax.set_yticks([i/(num_explainers-1) for i in range(0, num_explainers)])
    ax.set_yticklabels(explainers, rotation=0)

    # x ticks label the metrics (first slot is the dummy start column)
    ax.set_xticks(np.linspace(0, 1, num_metrics+1))
    ax.set_xticklabels([' '] + explainer_metrics, rotation=45, ha='right')

    plt.grid(which='major', axis='x', linestyle='--')
    plt.tight_layout()
    plt.ylabel('Relative Performance of Each Explanation Method')
    plt.xlabel('Evaluation Metrics')
    plt.title('Explanation Method Performance Across Metrics')
    plt.show()
lib/shap/benchmark/measures.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ import sklearn.utils
6
+ from tqdm.auto import tqdm
7
+
8
# module-level cache so repeated calls with the same data/attributions can
# reuse per-sample retraining results (see the cache_match logic below)
_remove_cache = {}
def remove_retrain(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is retrained for each test sample with the important features set to a constant.

    If you want to know how important a set of features is you can ask how the model would be
    different if those features had never existed. To determine this we can mask those features
    across the entire training and test datasets, then retrain the model. If we compare the
    output of this retrained model to the original model we can see the effect produced by knowing
    the features we masked. Since for individualized explanation methods each test sample has a
    different set of most important features we need to retrain the model for every test sample
    to get the change in model performance when a specified fraction of the most important features
    are withheld.
    """

    warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")

    # see if we match the last cached call; identity comparison on the args plus
    # equality on the attributions decides whether cached predictions are reusable
    global _remove_cache
    args = (X_train, y_train, X_test, y_test, model_generator, metric)
    cache_match = False
    if "args" in _remove_cache:
        if all(a is b for a,b in zip(_remove_cache["args"], args)) and np.all(_remove_cache["attr_test"] == attr_test):
            cache_match = True

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # this is the model we will retrain many times
    model_masked = model_generator()

    # mask nmask top features and re-train the model for each test explanation
    X_train_tmp = np.zeros(X_train.shape)
    X_test_tmp = np.zeros(X_test.shape)
    yp_masked_test = np.zeros(y_test.shape)
    # tiny fixed noise deterministically breaks ties between equal attribution values
    # NOTE(review): random_state is not forwarded to const_rand here, unlike
    # remove_mask/remove_impute — confirm intended
    tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
    last_nmask = _remove_cache.get("nmask", None)
    last_yp_masked_test = _remove_cache.get("yp_masked_test", None)
    for i in tqdm(range(len(y_test)), "Retraining for the 'remove' metric"):
        if cache_match and last_nmask[i] == nmask[i]:
            # same mask count as the cached run — reuse the cached prediction
            yp_masked_test[i] = last_yp_masked_test[i]
        elif nmask[i] == 0:
            # nothing masked, the original trained model applies unchanged
            yp_masked_test[i] = trained_model.predict(X_test[i:i+1])[0]
        else:
            # mask out the most important features for this test instance
            X_train_tmp[:] = X_train
            X_test_tmp[:] = X_test
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            # NOTE(review): .mean() is a single scalar over the whole selected
            # submatrix, while remove_mask uses per-column means — confirm intended
            X_train_tmp[:,ordering[:nmask[i]]] = X_train[:,ordering[:nmask[i]]].mean()
            X_test_tmp[i,ordering[:nmask[i]]] = X_train[:,ordering[:nmask[i]]].mean()

            # retrain the model and make a prediction
            model_masked.fit(X_train_tmp, y_train)
            yp_masked_test[i] = model_masked.predict(X_test_tmp[i:i+1])[0]

    # save our results so the next call to us can be faster when there is redundancy
    _remove_cache["nmask"] = nmask
    _remove_cache["yp_masked_test"] = yp_masked_test
    _remove_cache["attr_test"] = attr_test
    _remove_cache["args"] = args

    return metric(y_test, yp_masked_test)
71
+
72
def remove_mask(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ Each test sample is masked by setting the important features to a constant.

    For row i, the nmask[i] features with the highest attribution are replaced
    with their training-set means before re-scoring the already-trained model.
    """

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # tiny fixed noise deterministically breaks ties between equal attribution values
    noise = const_rand(X_train.shape[1], random_state) * 1e-6
    feature_means = X_train.mean(0)

    # mask the top-attributed features of each row with the training mean
    masked = X_test.copy()
    for i in range(len(y_test)):
        count = nmask[i]
        if count > 0:
            order = np.argsort(-attr_test[i,:] + noise)
            top = order[:count]
            masked[i, top] = feature_means[top]

    return metric(y_test, trained_model.predict(masked))
93
+
94
def remove_impute(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is reevaluated for each test sample with the important features set to an imputed value.

    Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on
    being able to estimate the full data covariance matrix (and inverse) accurately. So X_train.shape[0] should
    be significantly bigger than X_train.shape[1].
    """

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # covariance of the training features, with a small ridge added so the
    # observed-block submatrix inverted below stays well conditioned
    C = np.cov(X_train.T)
    C += np.eye(C.shape[0]) * 1e-6
    X_test_tmp = X_test.copy()
    yp_masked_test = np.zeros(y_test.shape)  # NOTE(review): dead store, overwritten by predict() below
    # tiny fixed noise deterministically breaks ties between equal attribution values
    tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
    mean_vals = X_train.mean(0)
    for i in range(len(y_test)):
        if nmask[i] > 0:
            # the nmask[i] highest-attribution features are imputed; the rest stay observed
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            observe_inds = ordering[nmask[i]:]
            impute_inds = ordering[:nmask[i]]

            # impute missing data assuming it follows a multivariate normal
            # distribution: conditional mean mu_i + C_io C_oo^-1 (x_o - mu_o)
            Coo_inv = np.linalg.inv(C[observe_inds,:][:,observe_inds])
            Cio = C[impute_inds,:][:,observe_inds]
            impute = mean_vals[impute_inds] + Cio @ Coo_inv @ (X_test[i, observe_inds] - mean_vals[observe_inds])

            X_test_tmp[i, impute_inds] = impute

    yp_masked_test = trained_model.predict(X_test_tmp)

    return metric(y_test, yp_masked_test)
130
+
131
def remove_resample(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is reevaluated for each test sample with the important features set to resample background values.
    """

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # how many background samples to draw per test row
    nsamples = 100

    # each test row is repeated nsamples times so the masked features can be
    # filled with nsamples different background draws from the training data
    N,M = X_test.shape
    X_test_tmp = np.tile(X_test, [1, nsamples]).reshape(nsamples * N, M)
    # NOTE(review): random_state is not forwarded to const_rand here while
    # other measures in this file pass it — confirm intended
    tie_breaking_noise = const_rand(M) * 1e-6
    inds = sklearn.utils.resample(np.arange(N), n_samples=nsamples, random_state=random_state)
    for i in range(N):
        if nmask[i] > 0:
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            X_test_tmp[i*nsamples:(i+1)*nsamples, ordering[:nmask[i]]] = X_train[inds, :][:, ordering[:nmask[i]]]

    yp_masked_test = trained_model.predict(X_test_tmp)
    yp_masked_test = np.reshape(yp_masked_test, (N, nsamples)).mean(1) # take the mean output over all samples

    return metric(y_test, yp_masked_test)
157
+
158
def batch_remove_retrain(nmask_train, nmask_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric):
    """ An approximation of holdout that only retrains the model once.

    This is also called ROAR (RemOve And Retrain) in work by Google. It is much more computationally
    efficient than the holdout method because it masks the most important features in every sample
    and then retrains the model once, instead of retraining the model for every test sample like
    the holdout metric.
    """

    warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # mask nmask top features for each explanation (train and test masked alike)
    X_train_tmp = X_train.copy()
    X_train_mean = X_train.mean(0)
    # tiny fixed noise deterministically breaks ties between equal attribution values
    tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
    for i in range(len(y_train)):
        if nmask_train[i] > 0:
            ordering = np.argsort(-attr_train[i, :] + tie_breaking_noise)
            X_train_tmp[i, ordering[:nmask_train[i]]] = X_train_mean[ordering[:nmask_train[i]]]
    X_test_tmp = X_test.copy()
    for i in range(len(y_test)):
        if nmask_test[i] > 0:
            ordering = np.argsort(-attr_test[i, :] + tie_breaking_noise)
            X_test_tmp[i, ordering[:nmask_test[i]]] = X_train_mean[ordering[:nmask_test[i]]]

    # train the model with all the given features masked
    model_masked = model_generator()
    model_masked.fit(X_train_tmp, y_train)
    yp_test_masked = model_masked.predict(X_test_tmp)

    return metric(y_test, yp_test_masked)
194
+
195
# module-level cache mirroring _remove_cache, used by keep_retrain below
_keep_cache = {}
def keep_retrain(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is retrained for each test sample with the non-important features set to a constant.

    If you want to know how important a set of features is you can ask how the model would be
    different if only those features had existed. To determine this we can mask the other features
    across the entire training and test datasets, then retrain the model. If we compare the
    output of this retrained model to the original model we can see the effect produced by only
    knowing the important features. Since for individualized explanation methods each test sample
    has a different set of most important features we need to retrain the model for every test sample
    to get the change in model performance when a specified fraction of the most important features
    are retained.
    """

    warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")

    # see if we match the last cached call; identity on the args plus equality
    # on the attributions decides whether cached predictions are reusable
    global _keep_cache
    args = (X_train, y_train, X_test, y_test, model_generator, metric)
    cache_match = False
    if "args" in _keep_cache:
        if all(a is b for a,b in zip(_keep_cache["args"], args)) and np.all(_keep_cache["attr_test"] == attr_test):
            cache_match = True

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # this is the model we will retrain many times
    model_masked = model_generator()

    # keep nkeep top features and re-train the model for each test explanation
    X_train_tmp = np.zeros(X_train.shape)
    X_test_tmp = np.zeros(X_test.shape)
    yp_masked_test = np.zeros(y_test.shape)
    # tiny fixed noise deterministically breaks ties between equal attribution values
    # NOTE(review): random_state is not forwarded to const_rand here — confirm intended
    tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
    last_nkeep = _keep_cache.get("nkeep", None)
    last_yp_masked_test = _keep_cache.get("yp_masked_test", None)
    for i in tqdm(range(len(y_test)), "Retraining for the 'keep' metric"):
        if cache_match and last_nkeep[i] == nkeep[i]:
            # same keep count as the cached run — reuse the cached prediction
            yp_masked_test[i] = last_yp_masked_test[i]
        elif nkeep[i] == attr_test.shape[1]:
            # everything kept, the original trained model applies unchanged
            yp_masked_test[i] = trained_model.predict(X_test[i:i+1])[0]
        else:

            # mask out all but the most important features for this test instance
            X_train_tmp[:] = X_train
            X_test_tmp[:] = X_test
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            # NOTE(review): .mean() is a scalar over the whole selected submatrix,
            # unlike the per-column means used by keep_mask — confirm intended
            X_train_tmp[:,ordering[nkeep[i]:]] = X_train[:,ordering[nkeep[i]:]].mean()
            X_test_tmp[i,ordering[nkeep[i]:]] = X_train[:,ordering[nkeep[i]:]].mean()

            # retrain the model and make a prediction
            model_masked.fit(X_train_tmp, y_train)
            yp_masked_test[i] = model_masked.predict(X_test_tmp[i:i+1])[0]

    # save our results so the next call to us can be faster when there is redundancy
    _keep_cache["nkeep"] = nkeep
    _keep_cache["yp_masked_test"] = yp_masked_test
    _keep_cache["attr_test"] = attr_test
    _keep_cache["args"] = args

    return metric(y_test, yp_masked_test)
259
+
260
def keep_mask(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is reevaluated for each test sample with the non-important features set to their mean.

    For row i, the nkeep[i] features with the highest attribution keep their
    observed values and every other feature is replaced with its training-set
    mean; the already-trained model is then re-scored on the masked data.
    """

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # keep nkeep top features for each test explanation
    X_test_tmp = X_test.copy()
    # tiny fixed noise deterministically breaks ties between equal attribution values
    tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
    mean_vals = X_train.mean(0)
    for i in range(len(y_test)):
        if nkeep[i] < X_test.shape[1]:
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            X_test_tmp[i,ordering[nkeep[i]:]] = mean_vals[ordering[nkeep[i]:]]

    # (the original code pre-allocated yp_masked_test with np.zeros and then
    # immediately overwrote it here; the dead assignment has been removed)
    yp_masked_test = trained_model.predict(X_test_tmp)

    return metric(y_test, yp_masked_test)
282
+
283
def keep_impute(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is reevaluated for each test sample with the non-important features set to an imputed value.

    Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on
    being able to estimate the full data covariance matrix (and inverse) accurately. So X_train.shape[0] should
    be significantly bigger than X_train.shape[1].
    """

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # covariance of the training features, with a small ridge added so the
    # observed-block submatrix inverted below stays well conditioned
    C = np.cov(X_train.T)
    C += np.eye(C.shape[0]) * 1e-6
    X_test_tmp = X_test.copy()
    yp_masked_test = np.zeros(y_test.shape)  # NOTE(review): dead store, overwritten by predict() below
    # tiny fixed noise deterministically breaks ties between equal attribution values
    tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
    mean_vals = X_train.mean(0)
    for i in range(len(y_test)):
        if nkeep[i] < X_test.shape[1]:
            # the nkeep[i] highest-attribution features stay observed; the rest are imputed
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            observe_inds = ordering[:nkeep[i]]
            impute_inds = ordering[nkeep[i]:]

            # impute missing data assuming it follows a multivariate normal
            # distribution: conditional mean mu_i + C_io C_oo^-1 (x_o - mu_o)
            Coo_inv = np.linalg.inv(C[observe_inds,:][:,observe_inds])
            Cio = C[impute_inds,:][:,observe_inds]
            impute = mean_vals[impute_inds] + Cio @ Coo_inv @ (X_test[i, observe_inds] - mean_vals[observe_inds])

            X_test_tmp[i, impute_inds] = impute

    yp_masked_test = trained_model.predict(X_test_tmp)

    return metric(y_test, yp_masked_test)
319
+
320
def keep_resample(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
    """ The model is reevaluated for each test sample with the non-important features set to resample background values.
    """
    # NOTE(review): the original author left the comment "why broken? overwriting?"
    # here — the concern is unresolved; verify the resampled assignment below

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # how many background samples to draw per test row
    nsamples = 100

    # each test row is repeated nsamples times so the non-kept features can be
    # filled with nsamples different background draws from the training data
    N,M = X_test.shape
    X_test_tmp = np.tile(X_test, [1, nsamples]).reshape(nsamples * N, M)
    # NOTE(review): random_state is not forwarded to const_rand here — confirm intended
    tie_breaking_noise = const_rand(M) * 1e-6
    inds = sklearn.utils.resample(np.arange(N), n_samples=nsamples, random_state=random_state)
    for i in range(N):
        if nkeep[i] < M:
            ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
            X_test_tmp[i*nsamples:(i+1)*nsamples, ordering[nkeep[i]:]] = X_train[inds, :][:, ordering[nkeep[i]:]]

    yp_masked_test = trained_model.predict(X_test_tmp)
    yp_masked_test = np.reshape(yp_masked_test, (N, nsamples)).mean(1) # take the mean output over all samples

    return metric(y_test, yp_masked_test)
346
+
347
def batch_keep_retrain(nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric):
    """ An approximation of keep that only retrains the model once.

    This is also called KAR (Keep And Retrain) in work by Google. It is much more computationally
    efficient than the keep method because it masks the unimportant features in every sample
    and then retrains the model once, instead of retraining the model for every test sample like
    the keep metric.
    """

    warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # keep only the nkeep top features for each explanation (train and test alike)
    X_train_tmp = X_train.copy()
    X_train_mean = X_train.mean(0)
    # tiny fixed noise deterministically breaks ties between equal attribution values
    tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
    for i in range(len(y_train)):
        if nkeep_train[i] < X_train.shape[1]:
            ordering = np.argsort(-attr_train[i, :] + tie_breaking_noise)
            X_train_tmp[i, ordering[nkeep_train[i]:]] = X_train_mean[ordering[nkeep_train[i]:]]
    X_test_tmp = X_test.copy()
    for i in range(len(y_test)):
        if nkeep_test[i] < X_test.shape[1]:
            ordering = np.argsort(-attr_test[i, :] + tie_breaking_noise)
            X_test_tmp[i, ordering[nkeep_test[i]:]] = X_train_mean[ordering[nkeep_test[i]:]]

    # train the model with all the features not given masked
    model_masked = model_generator()
    model_masked.fit(X_train_tmp, y_train)
    yp_test_masked = model_masked.predict(X_test_tmp)

    return metric(y_test, yp_test_masked)
383
+
384
def local_accuracy(X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model):
    """ How well the features plus a constant base rate sum up to the model output.

    Compares the trained model's predictions against the per-sample sum of the
    attribution values using the supplied metric.
    """

    X_train, X_test = to_array(X_train, X_test)

    # train and test must share the same feature space
    assert X_train.shape[1] == X_test.shape[1]

    # compare raw model output with the summed attributions
    predictions = trained_model.predict(X_test)
    attribution_totals = strip_list(attr_test).sum(1)
    return metric(predictions, attribution_totals)
397
+
398
def to_array(*args):
    """ Return *args* as a list, converting any pandas DataFrames to raw numpy arrays. """
    converted = []
    for arg in args:
        if isinstance(arg, pd.DataFrame):
            converted.append(arg.values)
        else:
            converted.append(arg)
    return converted
400
+
401
def const_rand(size, seed=23980):
    """ Generate a random array with a fixed seed.

    Parameters
    ----------
    size : int
        Length of the returned 1-D array of uniform [0, 1) draws.
    seed : int
        Seed for the draw; the same (size, seed) pair always returns the same values.

    Returns
    -------
    numpy.ndarray
    """
    # Use an isolated legacy RandomState rather than seeding the global RNG.
    # The previous code did ``old_seed = np.random.seed()`` — but np.random.seed()
    # returns None, so ``np.random.seed(old_seed)`` reseeded the global state from
    # OS entropy instead of restoring it. RandomState(seed).rand(size) produces
    # the exact same values as np.random.seed(seed); np.random.rand(size) while
    # leaving the global RNG state untouched.
    return np.random.RandomState(seed).rand(size)
409
+
410
def const_shuffle(arr, seed=23980):
    """ Shuffle an array in-place with a fixed seed.

    Parameters
    ----------
    arr : array-like
        Mutable sequence shuffled in place.
    seed : int
        Seed for the shuffle; the same seed always yields the same permutation.
    """
    # Use an isolated legacy RandomState rather than seeding the global RNG.
    # The previous code did ``old_seed = np.random.seed()`` — but np.random.seed()
    # returns None, so the global state was never actually restored. The
    # RandomState-based shuffle produces the exact same permutation as
    # np.random.seed(seed); np.random.shuffle(arr) without touching global state.
    np.random.RandomState(seed).shuffle(arr)
417
+
418
def strip_list(attrs):
    """ Unwrap per-class attribution lists.

    This assumes that if you have a list of outputs you just want the second
    one (the second class is the '1' class); anything else passes through.
    """
    return attrs[1] if isinstance(attrs, list) else attrs
lib/shap/benchmark/methods.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import sklearn
3
+
4
+ from .. import (
5
+ DeepExplainer,
6
+ GradientExplainer,
7
+ KernelExplainer,
8
+ LinearExplainer,
9
+ SamplingExplainer,
10
+ TreeExplainer,
11
+ kmeans,
12
+ )
13
+ from ..explainers import other
14
+ from .models import KerasWrap
15
+
16
+
17
def linear_shap_corr(model, data):
    """ Linear SHAP (corr 1000)
    """
    # returns the explainer's bound shap_values method as the attribution function
    return LinearExplainer(model, data, feature_dependence="correlation", nsamples=1000).shap_values
21
+
22
def linear_shap_ind(model, data):
    """ Linear SHAP (ind)
    """
    # returns the explainer's bound shap_values method as the attribution function
    return LinearExplainer(model, data, feature_dependence="independent").shap_values
26
+
27
def coef(model, data):
    """ Coefficients
    """
    # uses the model's raw coefficients as attributions; data is unused
    return other.CoefficentExplainer(model).attributions
31
+
32
def random(model, data):
    """ Random
    color = #777777
    linestyle = solid
    """
    # baseline explainer with random attributions; the docstring key = value
    # lines are plot metadata (presumably parsed by the benchmark framework)
    return other.RandomExplainer().attributions
38
+
39
def kernel_shap_1000_meanref(model, data):
    """ Kernel SHAP 1000 mean ref.
    color = red_blue_circle(0.5)
    linestyle = solid
    """
    # a single kmeans centroid of data serves as the background reference;
    # l1_reg=0 disables the regularized feature selection
    return lambda X: KernelExplainer(model.predict, kmeans(data, 1)).shap_values(X, nsamples=1000, l1_reg=0)
45
+
46
def sampling_shap_1000(model, data):
    """ IME 1000
    color = red_blue_circle(0.5)
    linestyle = dashed
    """
    # sampling-based SHAP estimate with a 1000-sample budget
    return lambda X: SamplingExplainer(model.predict, data).shap_values(X, nsamples=1000)
52
+
53
def tree_shap_tree_path_dependent(model, data):
    """ TreeExplainer
    color = red_blue_circle(0)
    linestyle = solid
    """
    # path-dependent TreeSHAP needs no background data; data is unused
    return TreeExplainer(model, feature_dependence="tree_path_dependent").shap_values
59
+
60
def tree_shap_independent_200(model, data):
    """ TreeExplainer (independent)
    color = red_blue_circle(0)
    linestyle = dashed
    """
    # a fixed-seed subsample of at most 200 rows serves as the background set
    data_subsample = sklearn.utils.resample(data, replace=False, n_samples=min(200, data.shape[0]), random_state=0)
    return TreeExplainer(model, data_subsample, feature_dependence="independent").shap_values
67
+
68
def mean_abs_tree_shap(model, data):
    """ mean(|TreeExplainer|)
    color = red_blue_circle(0.25)
    linestyle = solid
    """
    def f(X):
        # global importance: the mean |SHAP| over X, tiled so every row gets
        # the same attribution vector (multi-output models return a list)
        v = TreeExplainer(model).shap_values(X)
        if isinstance(v, list):
            return [np.tile(np.abs(sv).mean(0), (X.shape[0], 1)) for sv in v]
        else:
            return np.tile(np.abs(v).mean(0), (X.shape[0], 1))
    return f
80
+
81
def saabas(model, data):
    """ Saabas
    color = red_blue_circle(0)
    linestyle = dotted
    """
    # approximate=True selects the Saabas-style tree attribution
    return lambda X: TreeExplainer(model).shap_values(X, approximate=True)
87
+
88
def tree_gain(model, data):
    """ Gain/Gini Importance
    color = red_blue_circle(0.25)
    linestyle = dotted
    """
    # classic split-gain importance from the tree model; data is unused
    return other.TreeGainExplainer(model).attributions
94
+
95
def lime_tabular_regression_1000(model, data):
    """ LIME Tabular 1000
    color = red_blue_circle(0.75)
    """
    # LIME in regression mode with a 1000-sample perturbation budget
    return lambda X: other.LimeTabularExplainer(model.predict, data, mode="regression").attributions(X, nsamples=1000)
100
+
101
def lime_tabular_classification_1000(model, data):
    """ LIME Tabular 1000
    color = red_blue_circle(0.75)
    """
    # classification mode on predict_proba; [1] keeps the positive-class attributions
    return lambda X: other.LimeTabularExplainer(model.predict_proba, data, mode="classification").attributions(X, nsamples=1000)[1]
106
+
107
def maple(model, data):
    """ MAPLE
    color = red_blue_circle(0.6)
    """
    # MAPLE local-linear attributions on the model's predict function
    return lambda X: other.MapleExplainer(model.predict, data).attributions(X, multiply_by_input=False)
112
+
113
def tree_maple(model, data):
    """ Tree MAPLE
    color = red_blue_circle(0.6)
    linestyle = dashed
    """
    # tree-specific MAPLE variant working directly on the tree model
    return lambda X: other.TreeMapleExplainer(model, data).attributions(X, multiply_by_input=False)
119
+
120
def deep_shap(model, data):
    """ Deep SHAP (DeepLIFT)
    """
    # unwrap the benchmark's Keras wrapper to expose the underlying model
    if isinstance(model, KerasWrap):
        model = model.model
    # a single kmeans centroid of the data serves as the background reference
    explainer = DeepExplainer(model, kmeans(data, 1).data)
    def f(X):
        phi = explainer.shap_values(X)
        # single-output models come back as a one-element list; unwrap it
        if isinstance(phi, list) and len(phi) == 1:
            return phi[0]
        else:
            return phi

    return f
134
+
135
def expected_gradients(model, data):
    """ Expected Gradients
    """
    # unwrap the benchmark's Keras wrapper to expose the underlying model
    if isinstance(model, KerasWrap):
        model = model.model
    explainer = GradientExplainer(model, data)
    def f(X):
        phi = explainer.shap_values(X)
        # single-output models come back as a one-element list; unwrap it
        if isinstance(phi, list) and len(phi) == 1:
            return phi[0]
        else:
            return phi

    return f
lib/shap/benchmark/metrics.py ADDED
@@ -0,0 +1,824 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import time
4
+
5
+ import numpy as np
6
+ import sklearn
7
+
8
+ from .. import __version__
9
+ from . import measures, methods
10
+
11
+ try:
12
+ import dill as pickle
13
+ except Exception:
14
+ pass
15
+
16
+ try:
17
+ from sklearn.model_selection import train_test_split
18
+ except Exception:
19
+ from sklearn.cross_validation import train_test_split
20
+
21
+
22
def runtime(X, y, model_generator, method_name):
    """ Runtime (sec / 1k samples)
    transform = "negate_log"
    sort_order = 2
    """
    # save the global RNG state so it can be restored afterwards; the previous
    # implementation used ``old_seed = np.random.seed()``, but np.random.seed()
    # returns None, so ``np.random.seed(old_seed)`` reseeded from OS entropy
    # instead of restoring the state
    old_state = np.random.get_state()
    np.random.seed(3293)

    try:
        # average the method scores over several train/test splits
        method_reps = []
        for i in range(3):
            X_train, X_test, y_train, _ = train_test_split(__toarray(X), y, test_size=100, random_state=i)

            # define the model we are going to explain
            model = model_generator()
            model.fit(X_train, y_train)

            # time how long the explainer takes to build...
            start = time.time()
            explainer = getattr(methods, method_name)(model, X_train)
            build_time = time.time() - start

            # ...and how long it takes to explain the test set
            start = time.time()
            explainer(X_test)
            explain_time = time.time() - start

            # we always normalize the explain time as though we were explaining 1000 samples
            # even if to reduce the runtime of the benchmark we do less (like just 100)
            method_reps.append(build_time + explain_time * 1000.0 / X_test.shape[0])
    finally:
        # restore the caller's RNG state even if a method raises
        np.random.set_state(old_state)

    return None, np.mean(method_reps)
55
+
56
def local_accuracy(X, y, model_generator, method_name):
    """ Local Accuracy
    transform = "identity"
    sort_order = 0
    """
    # the docstring key = value lines above are metadata (presumably parsed
    # by the benchmark framework) — keep them verbatim

    def score_map(true, pred):
        """ Computes local accuracy as the normalized standard deviation of numerical scores.
        """
        return np.std(pred - true) / (np.std(true) + 1e-6)

    def score_function(X_train, X_test, y_train, y_test, attr_function, trained_model, random_state):
        # delegate to the shared measure implementation with our score_map
        return measures.local_accuracy(
            X_train, y_train, X_test, y_test, attr_function(X_test),
            model_generator, score_map, trained_model
        )
    return None, __score_method(X, y, None, model_generator, score_function, method_name)
73
+
74
def consistency_guarantees(X, y, model_generator, method_name):
    """ Consistency Guarantees
    transform = "identity"
    sort_order = 1
    """

    # Scoring scheme:
    # 1.0 - perfect consistency
    # 0.8 - guarantees depend on sampling
    # 0.6 - guarantees depend on approximation
    # 0.0 - no guarantees
    exact_methods = [
        "linear_shap_corr", "linear_shap_ind", "tree_shap_tree_path_dependent",
        "tree_shap_independent_200", "mean_abs_tree_shap",
    ]
    sampling_methods = [
        "kernel_shap_1000_meanref", "sampling_shap_1000",
        "lime_tabular_regression_1000", "lime_tabular_classification_1000",
        "maple", "tree_maple",
    ]
    approximation_methods = ["deep_shap", "expected_gradients"]
    unguaranteed_methods = ["coef", "random", "saabas", "tree_gain"]

    guarantees = {}
    for names, score in ((exact_methods, 1.0), (sampling_methods, 0.8),
                         (approximation_methods, 0.6), (unguaranteed_methods, 0.0)):
        for name in names:
            guarantees[name] = score

    return None, guarantees[method_name]
105
+
106
def __mean_pred(true, pred):
    """ A trivial metric that is just the mean of the model output.

    The ground-truth argument is intentionally ignored; only the average
    prediction matters for the keep/remove positive and negative metrics.
    """
    return np.asarray(pred).mean()
110
+
111
def keep_positive_mask(X, y, model_generator, method_name, num_fcounts=11):
    """ Keep Positive (mask)
    xlabel = "Max fraction of features kept"
    ylabel = "Mean model output"
    transform = "identity"
    sort_order = 4
    """
    # the docstring key = value lines are plot metadata (presumably parsed by
    # the framework); the 1/-1/0 argument across the positive/negative/absolute
    # variants presumably selects the attribution sign — defined in __run_measure
    return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
119
+
120
def keep_negative_mask(X, y, model_generator, method_name, num_fcounts=11):
    """ Keep Negative (mask)
    xlabel = "Max fraction of features kept"
    ylabel = "Negative mean model output"
    transform = "negate"
    sort_order = 5
    """
    # mask-based keep measure scored by mean model output; -1 presumably
    # selects the negative attribution direction (see __run_measure)
    return __run_measure(measures.keep_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
128
+
129
def keep_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11):
    """ Keep Absolute (mask)
    xlabel = "Max fraction of features kept"
    ylabel = "R^2"
    transform = "identity"
    sort_order = 6
    """
    # absolute-attribution variant (0) scored with R^2 for regression models
    return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
137
+
138
def keep_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
    """ Keep Absolute (mask)
    xlabel = "Max fraction of features kept"
    ylabel = "ROC AUC"
    transform = "identity"
    sort_order = 6
    """
    # absolute-attribution variant (0) scored with ROC AUC for classifiers
    return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
146
+
147
def remove_positive_mask(X, y, model_generator, method_name, num_fcounts=11):
    """ Remove Positive (mask)
    xlabel = "Max fraction of features removed"
    ylabel = "Negative mean model output"
    transform = "negate"
    sort_order = 7
    """
    # mask-based remove measure; 1 presumably selects the positive attribution
    # direction (see __run_measure)
    return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
155
+
156
def remove_negative_mask(X, y, model_generator, method_name, num_fcounts=11):
    """ Remove Negative (mask)
    xlabel = "Max fraction of features removed"
    ylabel = "Mean model output"
    transform = "identity"
    sort_order = 8
    """
    # mask-based remove measure; -1 presumably selects the negative attribution
    # direction (see __run_measure)
    return __run_measure(measures.remove_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
164
+
165
+ def remove_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11):
166
+ """ Remove Absolute (mask)
167
+ xlabel = "Max fraction of features removed"
168
+ ylabel = "1 - R^2"
169
+ transform = "one_minus"
170
+ sort_order = 9
171
+ """
172
+ return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
173
+
174
+ def remove_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
175
+ """ Remove Absolute (mask)
176
+ xlabel = "Max fraction of features removed"
177
+ ylabel = "1 - ROC AUC"
178
+ transform = "one_minus"
179
+ sort_order = 9
180
+ """
181
+ return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
182
+
183
+ def keep_positive_resample(X, y, model_generator, method_name, num_fcounts=11):
184
+ """ Keep Positive (resample)
185
+ xlabel = "Max fraction of features kept"
186
+ ylabel = "Mean model output"
187
+ transform = "identity"
188
+ sort_order = 10
189
+ """
190
+ return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
191
+
192
+ def keep_negative_resample(X, y, model_generator, method_name, num_fcounts=11):
193
+ """ Keep Negative (resample)
194
+ xlabel = "Max fraction of features kept"
195
+ ylabel = "Negative mean model output"
196
+ transform = "negate"
197
+ sort_order = 11
198
+ """
199
+ return __run_measure(measures.keep_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
200
+
201
+ def keep_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11):
202
+ """ Keep Absolute (resample)
203
+ xlabel = "Max fraction of features kept"
204
+ ylabel = "R^2"
205
+ transform = "identity"
206
+ sort_order = 12
207
+ """
208
+ return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
209
+
210
+ def keep_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
211
+ """ Keep Absolute (resample)
212
+ xlabel = "Max fraction of features kept"
213
+ ylabel = "ROC AUC"
214
+ transform = "identity"
215
+ sort_order = 12
216
+ """
217
+ return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
218
+
219
+ def remove_positive_resample(X, y, model_generator, method_name, num_fcounts=11):
220
+ """ Remove Positive (resample)
221
+ xlabel = "Max fraction of features removed"
222
+ ylabel = "Negative mean model output"
223
+ transform = "negate"
224
+ sort_order = 13
225
+ """
226
+ return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
227
+
228
+ def remove_negative_resample(X, y, model_generator, method_name, num_fcounts=11):
229
+ """ Remove Negative (resample)
230
+ xlabel = "Max fraction of features removed"
231
+ ylabel = "Mean model output"
232
+ transform = "identity"
233
+ sort_order = 14
234
+ """
235
+ return __run_measure(measures.remove_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
236
+
237
+ def remove_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11):
238
+ """ Remove Absolute (resample)
239
+ xlabel = "Max fraction of features removed"
240
+ ylabel = "1 - R^2"
241
+ transform = "one_minus"
242
+ sort_order = 15
243
+ """
244
+ return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
245
+
246
+ def remove_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
247
+ """ Remove Absolute (resample)
248
+ xlabel = "Max fraction of features removed"
249
+ ylabel = "1 - ROC AUC"
250
+ transform = "one_minus"
251
+ sort_order = 15
252
+ """
253
+ return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
254
+
255
+ def keep_positive_impute(X, y, model_generator, method_name, num_fcounts=11):
256
+ """ Keep Positive (impute)
257
+ xlabel = "Max fraction of features kept"
258
+ ylabel = "Mean model output"
259
+ transform = "identity"
260
+ sort_order = 16
261
+ """
262
+ return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
263
+
264
+ def keep_negative_impute(X, y, model_generator, method_name, num_fcounts=11):
265
+ """ Keep Negative (impute)
266
+ xlabel = "Max fraction of features kept"
267
+ ylabel = "Negative mean model output"
268
+ transform = "negate"
269
+ sort_order = 17
270
+ """
271
+ return __run_measure(measures.keep_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
272
+
273
+ def keep_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11):
274
+ """ Keep Absolute (impute)
275
+ xlabel = "Max fraction of features kept"
276
+ ylabel = "R^2"
277
+ transform = "identity"
278
+ sort_order = 18
279
+ """
280
+ return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
281
+
282
def keep_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
    """ Keep Absolute (impute)
    xlabel = "Max fraction of features kept"
    ylabel = "ROC AUC"
    transform = "identity"
    sort_order = 19
    """
    # NOTE: the docstring fields above are machine-parsed by the benchmark
    # framework, so their exact text must be preserved.
    # Bug fix: this previously ran measures.keep_mask, i.e. the masking
    # variant already covered by keep_absolute_mask__roc_auc. The impute
    # variant must score measures.keep_impute, matching
    # keep_absolute_impute__r2 directly above.
    return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
290
+
291
+ def remove_positive_impute(X, y, model_generator, method_name, num_fcounts=11):
292
+ """ Remove Positive (impute)
293
+ xlabel = "Max fraction of features removed"
294
+ ylabel = "Negative mean model output"
295
+ transform = "negate"
296
+ sort_order = 7
297
+ """
298
+ return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
299
+
300
+ def remove_negative_impute(X, y, model_generator, method_name, num_fcounts=11):
301
+ """ Remove Negative (impute)
302
+ xlabel = "Max fraction of features removed"
303
+ ylabel = "Mean model output"
304
+ transform = "identity"
305
+ sort_order = 8
306
+ """
307
+ return __run_measure(measures.remove_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
308
+
309
+ def remove_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11):
310
+ """ Remove Absolute (impute)
311
+ xlabel = "Max fraction of features removed"
312
+ ylabel = "1 - R^2"
313
+ transform = "one_minus"
314
+ sort_order = 9
315
+ """
316
+ return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
317
+
318
def remove_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
    """ Remove Absolute (impute)
    xlabel = "Max fraction of features removed"
    ylabel = "1 - ROC AUC"
    transform = "one_minus"
    sort_order = 9
    """
    # NOTE: the docstring fields above are machine-parsed by the benchmark
    # framework, so their exact text must be preserved.
    # Bug fix: this previously ran measures.remove_mask, i.e. the masking
    # variant already covered by remove_absolute_mask__roc_auc. The impute
    # variant must score measures.remove_impute, matching
    # remove_absolute_impute__r2 directly above.
    return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
326
+
327
+ def keep_positive_retrain(X, y, model_generator, method_name, num_fcounts=11):
328
+ """ Keep Positive (retrain)
329
+ xlabel = "Max fraction of features kept"
330
+ ylabel = "Mean model output"
331
+ transform = "identity"
332
+ sort_order = 6
333
+ """
334
+ return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
335
+
336
+ def keep_negative_retrain(X, y, model_generator, method_name, num_fcounts=11):
337
+ """ Keep Negative (retrain)
338
+ xlabel = "Max fraction of features kept"
339
+ ylabel = "Negative mean model output"
340
+ transform = "negate"
341
+ sort_order = 7
342
+ """
343
+ return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
344
+
345
+ def remove_positive_retrain(X, y, model_generator, method_name, num_fcounts=11):
346
+ """ Remove Positive (retrain)
347
+ xlabel = "Max fraction of features removed"
348
+ ylabel = "Negative mean model output"
349
+ transform = "negate"
350
+ sort_order = 11
351
+ """
352
+ return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
353
+
354
+ def remove_negative_retrain(X, y, model_generator, method_name, num_fcounts=11):
355
+ """ Remove Negative (retrain)
356
+ xlabel = "Max fraction of features removed"
357
+ ylabel = "Mean model output"
358
+ transform = "identity"
359
+ sort_order = 12
360
+ """
361
+ return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
362
+
363
def __run_measure(measure, X, y, model_generator, method_name, attribution_sign, num_fcounts, summary_function):
    """ Evaluate one keep/remove measure over a log-spaced grid of feature counts.

    Parameters
    ----------
    measure : callable
        One of the measures.* functions (e.g. keep_mask, remove_resample).
    attribution_sign : int
        1 keeps/removes positively-attributed features, -1 negatively-attributed
        ones, and 0 ranks features by absolute attribution.
    num_fcounts : int
        Number of grid points between 0 and the total feature count.
    summary_function : callable
        Reduces model outputs to a scalar score (e.g. __mean_pred, r2_score).

    Returns
    -------
    (fcounts, scores) where scores is averaged over train/test splits by
    __score_method.
    """

    def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trained_model, random_state):
        # orient the attribution matrix so that "larger is better" for the
        # chosen sign convention (0 means rank by magnitude)
        if attribution_sign == 0:
            A = np.abs(__strip_list(attr_function(X_test)))
        else:
            A = attribution_sign * __strip_list(attr_function(X_test))
        # per-sample budget of features to act on, clipped to the number of
        # features with non-negative (oriented) attribution in that sample
        nmask = np.ones(len(y_test)) * fcount
        nmask = np.minimum(nmask, np.array(A >= 0).sum(1)).astype(int)
        return measure(
            nmask, X_train, y_train, X_test, y_test, A,
            model_generator, summary_function, trained_model, random_state
        )
    fcounts = __intlogspace(0, X.shape[1], num_fcounts)
    return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name)
378
+
379
+ def batch_remove_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11):
380
+ """ Batch Remove Absolute (retrain)
381
+ xlabel = "Fraction of features removed"
382
+ ylabel = "1 - R^2"
383
+ transform = "one_minus"
384
+ sort_order = 13
385
+ """
386
+ return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts)
387
+
388
+ def batch_keep_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11):
389
+ """ Batch Keep Absolute (retrain)
390
+ xlabel = "Fraction of features kept"
391
+ ylabel = "R^2"
392
+ transform = "identity"
393
+ sort_order = 13
394
+ """
395
+ return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts)
396
+
397
+ def batch_remove_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
398
+ """ Batch Remove Absolute (retrain)
399
+ xlabel = "Fraction of features removed"
400
+ ylabel = "1 - ROC AUC"
401
+ transform = "one_minus"
402
+ sort_order = 13
403
+ """
404
+ return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts)
405
+
406
+ def batch_keep_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
407
+ """ Batch Keep Absolute (retrain)
408
+ xlabel = "Fraction of features kept"
409
+ ylabel = "ROC AUC"
410
+ transform = "identity"
411
+ sort_order = 13
412
+ """
413
+ return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts)
414
+
415
def __run_batch_abs_metric(metric, X, y, model_generator, method_name, loss, num_fcounts):
    """ Evaluate a batch retrain metric over a log-spaced grid of feature counts.

    Parameters
    ----------
    metric : callable
        One of the measures.batch_* functions (batch_keep_retrain /
        batch_remove_retrain).
    loss : callable
        Scoring function such as sklearn.metrics.r2_score.

    Returns
    -------
    (fcounts, scores) where scores is averaged over train/test splits by
    __score_method.
    """
    # Bug fix: __score_method always calls score_function with a trailing
    # random_state argument (see its fcounts loop); the inner closure here
    # previously did not accept it, so every batch metric raised a TypeError.
    # The argument is accepted (and unused) to match the calling convention.
    def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trained_model, random_state):
        # rank features by absolute attribution on both splits
        A_train = np.abs(__strip_list(attr_function(X_train)))
        nkeep_train = (np.ones(len(y_train)) * fcount).astype(int)
        #nkeep_train = np.minimum(nkeep_train, np.array(A_train > 0).sum(1)).astype(int)
        A_test = np.abs(__strip_list(attr_function(X_test)))
        nkeep_test = (np.ones(len(y_test)) * fcount).astype(int)
        #nkeep_test = np.minimum(nkeep_test, np.array(A_test >= 0).sum(1)).astype(int)
        return metric(
            nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, A_train, A_test,
            model_generator, loss
        )
    fcounts = __intlogspace(0, X.shape[1], num_fcounts)
    return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name)
429
+
430
_attribution_cache = {}
def __score_method(X, y, fcounts, model_generator, score_function, method_name, nreps=10, test_size=100, cache_dir="/tmp"):
    """ Test an explanation method.

    Averages score_function over nreps train/test splits, caching the fitted
    model on disk and the attribution values in-process so repeated benchmark
    runs are cheap.

    Returns the per-fcount scores averaged over the splits (or a scalar score
    when fcounts is None).
    """

    # `pickle` is bound at module import time (dill is used if available);
    # fail loudly if it never got imported
    try:
        pickle
    except NameError:
        raise ImportError("The 'dill' package could not be loaded and is needed for the benchmark!")

    # Bug fix: np.random.seed() is not a getter -- it returns None, so the
    # previous old_seed/np.random.seed(old_seed) pair silently re-seeded from
    # OS entropy instead of restoring the caller's RNG. Save and restore the
    # full generator state instead.
    old_state = np.random.get_state()
    np.random.seed(3293)

    # average the method scores over several train/test splits
    method_reps = []

    # content hash of the dataset, used in both the model and attribution cache keys
    data_hash = hashlib.sha256(__toarray(X).flatten()).hexdigest() + hashlib.sha256(__toarray(y)).hexdigest()
    for i in range(nreps):
        X_train, X_test, y_train, y_test = train_test_split(__toarray(X), y, test_size=test_size, random_state=i)

        # define the model we are going to explain, caching so we only build it once
        # (bug fix: ".pickle" used to be appended both here and in the join below,
        # producing "...pickle.pickle" filenames)
        model_id = "model_cache__v" + "__".join([__version__, data_hash, model_generator.__name__])
        cache_file = os.path.join(cache_dir, model_id + ".pickle")
        if os.path.isfile(cache_file):
            with open(cache_file, "rb") as f:
                model = pickle.load(f)
        else:
            model = model_generator()
            model.fit(X_train, y_train)
            with open(cache_file, "wb") as f:
                pickle.dump(model, f)

        # key for the in-process attribution cache, unique per split and method
        attr_key = "_".join([model_generator.__name__, method_name, str(test_size), str(nreps), str(i), data_hash])
        def score(attr_function):
            # memoize the (single) attribution call for this split
            def cached_attr_function(X_inner):
                if attr_key not in _attribution_cache:
                    _attribution_cache[attr_key] = attr_function(X_inner)
                return _attribution_cache[attr_key]

            if fcounts is None:
                return score_function(X_train, X_test, y_train, y_test, cached_attr_function, model, i)
            else:
                scores = []
                for f in fcounts:
                    scores.append(score_function(f, X_train, X_test, y_train, y_test, cached_attr_function, model, i))
                return np.array(scores)

        # evaluate the method (only building the attribution function if we need to)
        if attr_key not in _attribution_cache:
            method_reps.append(score(getattr(methods, method_name)(model, X_train)))
        else:
            method_reps.append(score(None))

    np.random.set_state(old_state)
    return np.array(method_reps).mean(0)
486
+
487
+
488
# used to memoize explainer functions so we don't waste time re-explaining the same object
# (a two-slot cache keyed on object identity; slot 0 is most recent)
__cache0 = None
__cache_X0 = None
__cache_f0 = None
__cache1 = None
__cache_X1 = None
__cache_f1 = None
def __check_cache(f, X):
    """ Return f(X), memoizing the two most recent (f, X) pairs.

    Hits require the exact same objects (`is` comparison), so this only helps
    when callers reuse references. NOTE(review): appears unused in this
    module -- only a commented-out call site remains in __score_method, which
    now keeps its own keyed cache.
    """
    global __cache0, __cache_X0, __cache_f0
    global __cache1, __cache_X1, __cache_f1
    if X is __cache_X0 and f is __cache_f0:
        return __cache0
    elif X is __cache_X1 and f is __cache_f1:
        return __cache1
    else:
        # miss: demote slot 0 to slot 1, then fill slot 0 with the new result
        __cache_f1 = __cache_f0
        __cache_X1 = __cache_X0
        __cache1 = __cache0
        __cache_f0 = f
        __cache_X0 = X
        __cache0 = f(X)
        return __cache0
510
+
511
def __intlogspace(start, end, count):
    """ Return the unique integers of a logarithmic grid spanning [start, end]. """
    # np.logspace(0, 1, count) runs from 1 to 10; rescale that to a 0..1
    # fraction, stretch it onto [start, end], then round and deduplicate.
    fractions = (np.logspace(0, 1, count, endpoint=True) - 1) / 9
    grid = np.round(start + (end - start) * fractions).astype(int)
    return np.unique(grid)
513
+
514
def __toarray(X):
    """ Converts DataFrames to numpy arrays.

    Anything exposing a ``values`` attribute (e.g. a pandas DataFrame or
    Series) is unwrapped; all other inputs are returned untouched.
    """
    return X.values if hasattr(X, "values") else X
520
+
521
def __strip_list(attrs):
    """ This assumes that if you have a list of outputs you just want the second one (the second class).
    """
    # only genuine lists are unwrapped; tuples/arrays pass through unchanged
    return attrs[1] if isinstance(attrs, list) else attrs
528
+
529
def _fit_human(model_generator, val00, val01, val11):
    """ Fit a model to a synthetic two-binary-feature target function.

    The training data is almost entirely the all-zero background labeled
    val00; rows with only feature 0 set or only feature 1 set are labeled
    val01, and the single row with both features set is labeled val11.
    Feature 2 is never set and acts as a decoy.

    Returns the fitted model produced by model_generator().
    """
    # force the model to fit a function with almost entirely zero background
    N = 1000000
    M = 3
    X = np.zeros((N, M))
    # (removed a dead no-op `X.shape` expression statement that was here)
    y = np.ones(N) * val00
    X[0:1000, 0] = 1
    y[0:1000] = val01
    for i in range(0, 1000000, 1000):
        X[i, 1] = 1
        y[i] = val01
    y[0] = val11  # row 0 ends up with both features set
    model = model_generator()
    model.fit(X, y)
    return model
545
+
546
def _human_and(X, model_generator, method_name, fever, cough):
    """ Compare a method's attributions to human consensus for the AND scoring function.

    Fits a model via _fit_human (values 0/2/10), explains the single scenario
    row selected by the fever/cough flags, and returns
    ("human", (consensus_attributions, method_attributions)).

    NOTE(review): the (fever=True, cough=False) combination has no branch and
    would raise a NameError on human_consensus -- no visible caller uses it.
    """
    assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"

    # these are from the sickness_score mturk user study experiment
    X_test = np.zeros((100,3))
    if not fever and not cough:
        human_consensus = np.array([0., 0., 0.])
        X_test[0,:] = np.array([[0., 0., 1.]])
    elif not fever and cough:
        human_consensus = np.array([0., 2., 0.])
        X_test[0,:] = np.array([[0., 1., 1.]])
    elif fever and cough:
        human_consensus = np.array([5., 5., 0.])
        X_test[0,:] = np.array([[1., 1., 1.]])

    # force the model to fit an AND function with almost entirely zero background
    model = _fit_human(model_generator, 0, 2, 10)

    # explain only the first (scenario) row of X_test
    attr_function = getattr(methods, method_name)(model, X)
    methods_attrs = attr_function(X_test)
    return "human", (human_consensus, methods_attrs[0,:])
567
+
568
+ def human_and_00(X, y, model_generator, method_name):
569
+ """ AND (false/false)
570
+
571
+ This tests how well a feature attribution method agrees with human intuition
572
+ for an AND operation combined with linear effects. This metric deals
573
+ specifically with the question of credit allocation for the following function
574
+ when all three inputs are true:
575
+ if fever: +2 points
576
+ if cough: +2 points
577
+ if fever and cough: +6 points
578
+
579
+ transform = "identity"
580
+ sort_order = 0
581
+ """
582
+ return _human_and(X, model_generator, method_name, False, False)
583
+
584
+ def human_and_01(X, y, model_generator, method_name):
585
+ """ AND (false/true)
586
+
587
+ This tests how well a feature attribution method agrees with human intuition
588
+ for an AND operation combined with linear effects. This metric deals
589
+ specifically with the question of credit allocation for the following function
590
+ when all three inputs are true:
591
+ if fever: +2 points
592
+ if cough: +2 points
593
+ if fever and cough: +6 points
594
+
595
+ transform = "identity"
596
+ sort_order = 1
597
+ """
598
+ return _human_and(X, model_generator, method_name, False, True)
599
+
600
+ def human_and_11(X, y, model_generator, method_name):
601
+ """ AND (true/true)
602
+
603
+ This tests how well a feature attribution method agrees with human intuition
604
+ for an AND operation combined with linear effects. This metric deals
605
+ specifically with the question of credit allocation for the following function
606
+ when all three inputs are true:
607
+ if fever: +2 points
608
+ if cough: +2 points
609
+ if fever and cough: +6 points
610
+
611
+ transform = "identity"
612
+ sort_order = 2
613
+ """
614
+ return _human_and(X, model_generator, method_name, True, True)
615
+
616
+
617
def _human_or(X, model_generator, method_name, fever, cough):
    """ Compare a method's attributions to human consensus for the OR scoring function.

    Fits a model via _fit_human (values 0/8/10), explains the single scenario
    row selected by the fever/cough flags, and returns
    ("human", (consensus_attributions, method_attributions)).

    NOTE(review): the (fever=True, cough=False) combination has no branch and
    would raise a NameError on human_consensus -- no visible caller uses it.
    """
    assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"

    # these are from the sickness_score mturk user study experiment
    X_test = np.zeros((100,3))
    if not fever and not cough:
        human_consensus = np.array([0., 0., 0.])
        X_test[0,:] = np.array([[0., 0., 1.]])
    elif not fever and cough:
        human_consensus = np.array([0., 8., 0.])
        X_test[0,:] = np.array([[0., 1., 1.]])
    elif fever and cough:
        human_consensus = np.array([5., 5., 0.])
        X_test[0,:] = np.array([[1., 1., 1.]])

    # force the model to fit an OR function with almost entirely zero background
    model = _fit_human(model_generator, 0, 8, 10)

    # explain only the first (scenario) row of X_test
    attr_function = getattr(methods, method_name)(model, X)
    methods_attrs = attr_function(X_test)
    return "human", (human_consensus, methods_attrs[0,:])
638
+
639
+ def human_or_00(X, y, model_generator, method_name):
640
+ """ OR (false/false)
641
+
642
+ This tests how well a feature attribution method agrees with human intuition
643
+ for an OR operation combined with linear effects. This metric deals
644
+ specifically with the question of credit allocation for the following function
645
+ when all three inputs are true:
646
+ if fever: +2 points
647
+ if cough: +2 points
648
+ if fever or cough: +6 points
649
+
650
+ transform = "identity"
651
+ sort_order = 0
652
+ """
653
+ return _human_or(X, model_generator, method_name, False, False)
654
+
655
+ def human_or_01(X, y, model_generator, method_name):
656
+ """ OR (false/true)
657
+
658
+ This tests how well a feature attribution method agrees with human intuition
659
+ for an OR operation combined with linear effects. This metric deals
660
+ specifically with the question of credit allocation for the following function
661
+ when all three inputs are true:
662
+ if fever: +2 points
663
+ if cough: +2 points
664
+ if fever or cough: +6 points
665
+
666
+ transform = "identity"
667
+ sort_order = 1
668
+ """
669
+ return _human_or(X, model_generator, method_name, False, True)
670
+
671
+ def human_or_11(X, y, model_generator, method_name):
672
+ """ OR (true/true)
673
+
674
+ This tests how well a feature attribution method agrees with human intuition
675
+ for an OR operation combined with linear effects. This metric deals
676
+ specifically with the question of credit allocation for the following function
677
+ when all three inputs are true:
678
+ if fever: +2 points
679
+ if cough: +2 points
680
+ if fever or cough: +6 points
681
+
682
+ transform = "identity"
683
+ sort_order = 2
684
+ """
685
+ return _human_or(X, model_generator, method_name, True, True)
686
+
687
+
688
def _human_xor(X, model_generator, method_name, fever, cough):
    """ Compare a method's attributions to human consensus for the XOR scoring function.

    Fits a model via _fit_human (values 0/8/4), explains the single scenario
    row selected by the fever/cough flags, and returns
    ("human", (consensus_attributions, method_attributions)).

    NOTE(review): the (fever=True, cough=False) combination has no branch and
    would raise a NameError on human_consensus -- no visible caller uses it.
    """
    assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"

    # these are from the sickness_score mturk user study experiment
    X_test = np.zeros((100,3))
    if not fever and not cough:
        human_consensus = np.array([0., 0., 0.])
        X_test[0,:] = np.array([[0., 0., 1.]])
    elif not fever and cough:
        human_consensus = np.array([0., 8., 0.])
        X_test[0,:] = np.array([[0., 1., 1.]])
    elif fever and cough:
        human_consensus = np.array([2., 2., 0.])
        X_test[0,:] = np.array([[1., 1., 1.]])

    # force the model to fit an XOR function with almost entirely zero background
    model = _fit_human(model_generator, 0, 8, 4)

    # explain only the first (scenario) row of X_test
    attr_function = getattr(methods, method_name)(model, X)
    methods_attrs = attr_function(X_test)
    return "human", (human_consensus, methods_attrs[0,:])
709
+
710
+ def human_xor_00(X, y, model_generator, method_name):
711
+ """ XOR (false/false)
712
+
713
+ This tests how well a feature attribution method agrees with human intuition
714
+ for an eXclusive OR operation combined with linear effects. This metric deals
715
+ specifically with the question of credit allocation for the following function
716
+ when all three inputs are true:
717
+ if fever: +2 points
718
+ if cough: +2 points
719
+ if fever or cough but not both: +6 points
720
+
721
+ transform = "identity"
722
+ sort_order = 3
723
+ """
724
+ return _human_xor(X, model_generator, method_name, False, False)
725
+
726
+ def human_xor_01(X, y, model_generator, method_name):
727
+ """ XOR (false/true)
728
+
729
+ This tests how well a feature attribution method agrees with human intuition
730
+ for an eXclusive OR operation combined with linear effects. This metric deals
731
+ specifically with the question of credit allocation for the following function
732
+ when all three inputs are true:
733
+ if fever: +2 points
734
+ if cough: +2 points
735
+ if fever or cough but not both: +6 points
736
+
737
+ transform = "identity"
738
+ sort_order = 4
739
+ """
740
+ return _human_xor(X, model_generator, method_name, False, True)
741
+
742
+ def human_xor_11(X, y, model_generator, method_name):
743
+ """ XOR (true/true)
744
+
745
+ This tests how well a feature attribution method agrees with human intuition
746
+ for an eXclusive OR operation combined with linear effects. This metric deals
747
+ specifically with the question of credit allocation for the following function
748
+ when all three inputs are true:
749
+ if fever: +2 points
750
+ if cough: +2 points
751
+ if fever or cough but not both: +6 points
752
+
753
+ transform = "identity"
754
+ sort_order = 5
755
+ """
756
+ return _human_xor(X, model_generator, method_name, True, True)
757
+
758
+
759
def _human_sum(X, model_generator, method_name, fever, cough):
    """ Compare a method's attributions to human consensus for the additive (SUM) scoring function.

    Fits a model via _fit_human (values 0/2/4), explains the single scenario
    row selected by the fever/cough flags, and returns
    ("human", (consensus_attributions, method_attributions)).

    NOTE(review): the (fever=True, cough=False) combination has no branch and
    would raise a NameError on human_consensus -- no visible caller uses it.
    """
    assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"

    # these are from the sickness_score mturk user study experiment
    X_test = np.zeros((100,3))
    if not fever and not cough:
        human_consensus = np.array([0., 0., 0.])
        X_test[0,:] = np.array([[0., 0., 1.]])
    elif not fever and cough:
        human_consensus = np.array([0., 2., 0.])
        X_test[0,:] = np.array([[0., 1., 1.]])
    elif fever and cough:
        human_consensus = np.array([2., 2., 0.])
        X_test[0,:] = np.array([[1., 1., 1.]])

    # force the model to fit a purely additive function with almost entirely zero background
    model = _fit_human(model_generator, 0, 2, 4)

    # explain only the first (scenario) row of X_test
    attr_function = getattr(methods, method_name)(model, X)
    methods_attrs = attr_function(X_test)
    return "human", (human_consensus, methods_attrs[0,:])
780
+
781
+ def human_sum_00(X, y, model_generator, method_name):
782
+ """ SUM (false/false)
783
+
784
+ This tests how well a feature attribution method agrees with human intuition
785
+ for a SUM operation. This metric deals
786
+ specifically with the question of credit allocation for the following function
787
+ when all three inputs are true:
788
+ if fever: +2 points
789
+ if cough: +2 points
790
+
791
+ transform = "identity"
792
+ sort_order = 0
793
+ """
794
+ return _human_sum(X, model_generator, method_name, False, False)
795
+
796
+ def human_sum_01(X, y, model_generator, method_name):
797
+ """ SUM (false/true)
798
+
799
+ This tests how well a feature attribution method agrees with human intuition
800
+ for a SUM operation. This metric deals
801
+ specifically with the question of credit allocation for the following function
802
+ when all three inputs are true:
803
+ if fever: +2 points
804
+ if cough: +2 points
805
+
806
+ transform = "identity"
807
+ sort_order = 1
808
+ """
809
+ return _human_sum(X, model_generator, method_name, False, True)
810
+
811
+ def human_sum_11(X, y, model_generator, method_name):
812
+ """ SUM (true/true)
813
+
814
+ This tests how well a feature attribution method agrees with human intuition
815
+ for a SUM operation. This metric deals
816
+ specifically with the question of credit allocation for the following function
817
+ when all three inputs are true:
818
+ if fever: +2 points
819
+ if cough: +2 points
820
+
821
+ transform = "identity"
822
+ sort_order = 2
823
+ """
824
+ return _human_sum(X, model_generator, method_name, True, True)
lib/shap/benchmark/models.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.linear_model
import sklearn.tree
from sklearn.preprocessing import StandardScaler
5
+
6
+
7
class KerasWrap:
    """ A wrapper that allows us to set parameters in the constructor and do a reset before fitting.
    """
    def __init__(self, model, epochs, flatten_output=False):
        self.model = model                    # a compiled Keras model
        self.epochs = epochs                  # epochs used for each fit() call
        self.flatten_output = flatten_output  # flatten predictions to 1-D when True
        self.init_weights = None              # snapshot of the untrained weights
        self.scaler = StandardScaler()

    def fit(self, X, y, verbose=0):
        """ Reset the network to its initial weights and train it on (X, y). """
        # snapshot the untrained weights on first use, restore them afterwards
        # so repeated fits start from the same initialization
        if self.init_weights is None:
            self.init_weights = self.model.get_weights()
        else:
            self.model.set_weights(self.init_weights)
        # Bug fix: the scaler was fitted here but the network was trained on
        # the *unscaled* X, while predict() fed it *scaled* X. Training and
        # prediction must see the same standardized representation.
        X = self.scaler.fit_transform(X)
        return self.model.fit(X, y, epochs=self.epochs, verbose=verbose)

    def predict(self, X):
        """ Predict on X after applying the scaling learned during fit(). """
        X = self.scaler.transform(X)
        if self.flatten_output:
            return self.model.predict(X).flatten()
        else:
            return self.model.predict(X)
31
+
32
+
33
+ # This models are all tuned for the corrgroups60 dataset
34
+
35
+ def corrgroups60__lasso():
36
+ """ Lasso Regression
37
+ """
38
+ return sklearn.linear_model.Lasso(alpha=0.1)
39
+
40
+ def corrgroups60__ridge():
41
+ """ Ridge Regression
42
+ """
43
+ return sklearn.linear_model.Ridge(alpha=1.0)
44
+
45
+ def corrgroups60__decision_tree():
46
+ """ Decision Tree
47
+ """
48
+
49
+ # max_depth was chosen to minimise test error
50
+ return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=6)
51
+
52
+ def corrgroups60__random_forest():
53
+ """ Random Forest
54
+ """
55
+ return sklearn.ensemble.RandomForestRegressor(100, random_state=0)
56
+
57
+ def corrgroups60__gbm():
58
+ """ Gradient Boosted Trees
59
+ """
60
+ import xgboost
61
+
62
+ # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split
63
+ return xgboost.XGBRegressor(max_depth=6, n_estimators=50, learning_rate=0.1, n_jobs=8, random_state=0)
64
+
65
+ def corrgroups60__ffnn():
66
+ """ 4-Layer Neural Network
67
+ """
68
+ from tensorflow.keras.layers import Dense
69
+ from tensorflow.keras.models import Sequential
70
+
71
+ model = Sequential()
72
+ model.add(Dense(32, activation='relu', input_dim=60))
73
+ model.add(Dense(20, activation='relu'))
74
+ model.add(Dense(20, activation='relu'))
75
+ model.add(Dense(1))
76
+
77
+ model.compile(optimizer='adam',
78
+ loss='mean_squared_error',
79
+ metrics=['mean_squared_error'])
80
+
81
+ return KerasWrap(model, 30, flatten_output=True)
82
+
83
+
84
+ def independentlinear60__lasso():
85
+ """ Lasso Regression
86
+ """
87
+ return sklearn.linear_model.Lasso(alpha=0.1)
88
+
89
+ def independentlinear60__ridge():
90
+ """ Ridge Regression
91
+ """
92
+ return sklearn.linear_model.Ridge(alpha=1.0)
93
+
94
+ def independentlinear60__decision_tree():
95
+ """ Decision Tree
96
+ """
97
+
98
+ # max_depth was chosen to minimise test error
99
+ return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=4)
100
+
101
+ def independentlinear60__random_forest():
102
+ """ Random Forest
103
+ """
104
+ return sklearn.ensemble.RandomForestRegressor(100, random_state=0)
105
+
106
+ def independentlinear60__gbm():
107
+ """ Gradient Boosted Trees
108
+ """
109
+ import xgboost
110
+
111
+ # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split
112
+ return xgboost.XGBRegressor(max_depth=6, n_estimators=100, learning_rate=0.1, n_jobs=8, random_state=0)
113
+
114
+ def independentlinear60__ffnn():
115
+ """ 4-Layer Neural Network
116
+ """
117
+ from tensorflow.keras.layers import Dense
118
+ from tensorflow.keras.models import Sequential
119
+
120
+ model = Sequential()
121
+ model.add(Dense(32, activation='relu', input_dim=60))
122
+ model.add(Dense(20, activation='relu'))
123
+ model.add(Dense(20, activation='relu'))
124
+ model.add(Dense(1))
125
+
126
+ model.compile(optimizer='adam',
127
+ loss='mean_squared_error',
128
+ metrics=['mean_squared_error'])
129
+
130
+ return KerasWrap(model, 30, flatten_output=True)
131
+
132
+
133
def cric__lasso():
    """ Lasso Regression
    """
    # NOTE: the docstring title above is displayed by the benchmark plots,
    # so its exact text must be preserved.
    # Bug fix: an L1 penalty requires the liblinear solver -- scikit-learn's
    # default lbfgs solver raises a ValueError for penalty="l1".
    model = sklearn.linear_model.LogisticRegression(penalty="l1", C=0.002, solver="liblinear")

    # we want to explain the raw probability outputs of the model
    model.predict = lambda X: model.predict_proba(X)[:,1]

    return model
142
+
143
+ def cric__ridge():
144
+ """ Ridge Regression
145
+ """
146
+ model = sklearn.linear_model.LogisticRegression(penalty="l2")
147
+
148
+ # we want to explain the raw probability outputs of the trees
149
+ model.predict = lambda X: model.predict_proba(X)[:,1]
150
+
151
+ return model
152
+
153
+ def cric__decision_tree():
154
+ """ Decision Tree
155
+ """
156
+ model = sklearn.tree.DecisionTreeClassifier(random_state=0, max_depth=4)
157
+
158
+ # we want to explain the raw probability outputs of the trees
159
+ model.predict = lambda X: model.predict_proba(X)[:,1]
160
+
161
+ return model
162
+
163
+ def cric__random_forest():
164
+ """ Random Forest
165
+ """
166
+ model = sklearn.ensemble.RandomForestClassifier(100, random_state=0)
167
+
168
+ # we want to explain the raw probability outputs of the trees
169
+ model.predict = lambda X: model.predict_proba(X)[:,1]
170
+
171
+ return model
172
+
173
+ def cric__gbm():
174
+ """ Gradient Boosted Trees
175
+ """
176
+ import xgboost
177
+
178
+ # max_depth and subsample match the params used for the full cric data in the paper
179
+ # learning_rate was set a bit higher to allow for faster runtimes
180
+ # n_estimators was chosen based on a train/test split of the data
181
+ model = xgboost.XGBClassifier(max_depth=5, n_estimators=400, learning_rate=0.01, subsample=0.2, n_jobs=8, random_state=0)
182
+
183
+ # we want to explain the margin, not the transformed probability outputs
184
+ model.__orig_predict = model.predict
185
+ model.predict = lambda X: model.__orig_predict(X, output_margin=True)
186
+
187
+ return model
188
+
189
+ def cric__ffnn():
190
+ """ 4-Layer Neural Network
191
+ """
192
+ from tensorflow.keras.layers import Dense, Dropout
193
+ from tensorflow.keras.models import Sequential
194
+
195
+ model = Sequential()
196
+ model.add(Dense(10, activation='relu', input_dim=336))
197
+ model.add(Dropout(0.5))
198
+ model.add(Dense(10, activation='relu'))
199
+ model.add(Dropout(0.5))
200
+ model.add(Dense(1, activation='sigmoid'))
201
+
202
+ model.compile(optimizer='adam',
203
+ loss='binary_crossentropy',
204
+ metrics=['accuracy'])
205
+
206
+ return KerasWrap(model, 30, flatten_output=True)
207
+
208
+
209
def human__decision_tree():
    """ Decision Tree
    """
    # NOTE: the docstring title above is displayed by the benchmark plots,
    # so its exact text must be preserved.

    # build data: an almost entirely zero background plus three informative
    # rows encoding the credit-allocation pattern (each single feature -> 8,
    # both together -> 4); removed a dead no-op `X.shape` statement
    N = 1000000
    M = 3
    X = np.zeros((N, M))
    y = np.zeros(N)
    X[0, 0] = 1
    y[0] = 8
    X[1, 1] = 1
    y[1] = 8
    X[2, 0:2] = 1
    y[2] = 4

    # fit model
    xor_model = sklearn.tree.DecisionTreeRegressor(max_depth=2)
    xor_model.fit(X, y)

    return xor_model
lib/shap/benchmark/plots.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import os
4
+
5
+ import numpy as np
6
+ import sklearn
7
+ from matplotlib.colors import LinearSegmentedColormap
8
+
9
+ from .. import __version__
10
+ from ..plots import colors
11
+ from . import methods, metrics, models
12
+ from .experiments import run_experiments
13
+
14
+ try:
15
+ import matplotlib
16
+ import matplotlib.pyplot as pl
17
+ from IPython.display import HTML
18
+ except ImportError:
19
+ pass
20
+
21
+
22
# Per-metric display metadata (titles, axis labels, sort order). Every entry
# is commented out in the original source, so this dict is intentionally
# empty: display attributes are instead parsed from the metric and method
# docstrings at runtime (see get_metric_attr / get_method_color in this file).
metadata = {}
220
+
221
# Fixed hex colors for well-known attribution methods in benchmark plots.
benchmark_color_map = dict(
    tree_shap="#1E88E5",
    deep_shap="#1E88E5",
    linear_shap_corr="#1E88E5",
    linear_shap_ind="#ff0d57",
    coef="#13B755",
    random="#999999",
    const_random="#666666",
    kernel_shap_1000_meanref="#7C52FF",
)
231
+
232
+ # negated_metrics = [
233
+ # "runtime",
234
+ # "remove_positive_retrain",
235
+ # "remove_positive_mask",
236
+ # "remove_positive_resample",
237
+ # "keep_negative_retrain",
238
+ # "keep_negative_mask",
239
+ # "keep_negative_resample"
240
+ # ]
241
+
242
+ # one_minus_metrics = [
243
+ # "remove_absolute_mask__r2",
244
+ # "remove_absolute_mask__roc_auc",
245
+ # "remove_absolute_resample__r2",
246
+ # "remove_absolute_resample__roc_auc"
247
+ # ]
248
+
249
def get_method_color(method):
    """ Look up the plot color declared in an attribution method's docstring.

    Scans the docstring of `methods.<method>` for a line of the form
    ``color = <value>``, where the value is either a literal color string or
    a ``red_blue_circle(<float>)`` expression. Returns black if no color
    line is found.
    """
    for raw_line in getattr(methods, method).__doc__.split("\n"):
        stripped = raw_line.strip()
        if not stripped.startswith("color = "):
            continue
        value = stripped.split("=")[1].strip()
        if value.startswith("red_blue_circle("):
            # e.g. "red_blue_circle(0.5)" -> colors.red_blue_circle(0.5)
            return colors.red_blue_circle(float(value[16:-1]))
        return value
    return "#000000"
259
+
260
def get_method_linestyle(method):
    """ Look up the plot linestyle declared in a method's docstring.

    Scans the docstring of `methods.<method>` for ``linestyle = <value>``;
    defaults to "solid" when no such line exists.
    """
    for raw_line in getattr(methods, method).__doc__.split("\n"):
        stripped = raw_line.strip()
        if stripped.startswith("linestyle = "):
            return stripped.split("=")[1].strip()
    return "solid"
266
+
267
def get_metric_attr(metric, attr):
    """ Parse an attribute value out of a metric's docstring.

    Looks for a line of the form ``<attr> = "<string>"`` (returned as the
    quoted string) or ``<attr> = <number>`` (returned as a float). Returns
    an empty string when the attribute is not declared.
    """
    for raw_line in getattr(metrics, metric).__doc__.split("\n"):
        stripped = raw_line.strip()

        # quoted string form: attr = "value"
        quoted_prefix = attr + " = \""
        if stripped.startswith(quoted_prefix) and stripped.endswith("\""):
            return stripped[len(quoted_prefix):-1]

        # bare numeric form: attr = 3.5
        bare_prefix = attr + " = "
        if stripped.startswith(bare_prefix):
            return float(stripped[len(bare_prefix):])
    return ""
282
+
283
def plot_curve(dataset, model, metric, cmap=benchmark_color_map):
    """ Plot one performance curve per attribution method for a single metric.

    Runs (or loads cached) experiments for the dataset/model/metric
    combination, applies the metric's declared score transform, and draws
    each method's curve, ranked and labeled by its normalized area under
    the curve.

    Parameters
    ----------
    dataset : str
        Dataset name; `dataset + "__" + model` must name a builder in `models`.
    model : str
        Model identifier within the dataset.
    metric : str
        Name of a metric defined in the benchmark `metrics` module.
    cmap : dict
        Unused; kept for a backwards-compatible signature.

    Returns
    -------
    matplotlib.figure.Figure
    """
    experiments = run_experiments(dataset=dataset, model=model, metric=metric)
    pl.figure()

    # Fix: the transform depends only on `metric`, so look it up once
    # instead of re-parsing the metric docstring on every loop iteration.
    transform = get_metric_attr(metric, "transform")

    method_arr = []
    for (name, (fcounts, scores)) in experiments:
        _, _, method, _ = name
        if transform == "negate":
            scores = -scores
        elif transform == "one_minus":
            scores = 1 - scores
        # normalized area under the curve, used to rank methods in the legend
        auc = sklearn.metrics.auc(fcounts, scores) / fcounts[-1]
        method_arr.append((auc, method, scores))
    for (auc, method, scores) in sorted(method_arr):
        method_title = getattr(methods, method).__doc__.split("\n")[0].strip()
        label = f"{auc:6.3f} - " + method_title
        # NOTE(review): `fcounts` here is the value left over from the last
        # experiment -- assumes all experiments share one fcounts grid.
        pl.plot(
            fcounts / fcounts[-1], scores, label=label,
            color=get_method_color(method), linewidth=2,
            linestyle=get_method_linestyle(method)
        )
    metric_title = getattr(metrics, metric).__doc__.split("\n")[0].strip()
    pl.xlabel(get_metric_attr(metric, "xlabel"))
    pl.ylabel(get_metric_attr(metric, "ylabel"))
    model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip()
    pl.title(metric_title + " - " + model_title)
    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('left')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    ahandles, alabels = pl.gca().get_legend_handles_labels()
    # reversed so the best-scoring (last-plotted) method appears first
    pl.legend(reversed(ahandles), reversed(alabels))
    return pl.gcf()
316
+
317
def plot_human(dataset, model, metric, cmap=benchmark_color_map):
    """ Plot a grouped bar chart comparing each method's attributions to human consensus.

    The first bar in each group is the human consensus attribution
    (``scores[0]`` of the first experiment); each method then contributes one
    bar per feature, labeled with its total absolute difference from the
    consensus. ``cmap`` is currently unused but kept in the signature.
    Returns the matplotlib figure.
    """
    experiments = run_experiments(dataset=dataset, model=model, metric=metric)
    pl.figure()
    method_arr = []
    for (name,(fcounts,scores)) in experiments:
        _,_,method,_ = name
        # total L1 distance between the method's attributions (scores[1])
        # and the human consensus (scores[0]); used to sort the bars
        diff_sum = np.sum(np.abs(scores[1] - scores[0]))
        method_arr.append((diff_sum, method, scores[0], scores[1]))

    # three features are hard-coded to match the human-agreement scenarios
    # (minor tick labels below: Fever, Cough, Headache)
    inds = np.arange(3) # the x locations for the groups
    inc_width = (1.0 / len(method_arr)) * 0.8
    width = inc_width * 0.9
    pl.bar(inds, method_arr[0][2], width, label="Human Consensus", color="black", edgecolor="white")
    i = 1
    # dashed/dotted line styles become hatch patterns on the bars
    line_style_to_hatch = {
        "dashed": "///",
        "dotted": "..."
    }
    for (diff_sum, method, _, methods_attrs) in sorted(method_arr):
        method_title = getattr(methods, method).__doc__.split("\n")[0].strip()
        label = f"{diff_sum:.2f} - " + method_title
        pl.bar(
            inds + inc_width * i, methods_attrs.flatten(), width, label=label, edgecolor="white",
            color=get_method_color(method), hatch=line_style_to_hatch.get(get_method_linestyle(method), None)
        )
        i += 1
    metric_title = getattr(metrics, metric).__doc__.split("\n")[0].strip()
    pl.xlabel("Features in the model")
    pl.ylabel("Feature attribution value")
    model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip()
    pl.title(metric_title + " - " + model_title)
    pl.gca().xaxis.set_ticks_position('bottom')
    pl.gca().yaxis.set_ticks_position('left')
    pl.gca().spines['right'].set_visible(False)
    pl.gca().spines['top'].set_visible(False)
    ahandles, alabels = pl.gca().get_legend_handles_labels()
    #pl.legend(ahandles, alabels)
    # hide the major tick labels; feature names are drawn as minor ticks
    pl.xticks(np.array([0, 1, 2, 3]) - (inc_width + width)/2, ["", "", "", ""])

    pl.gca().xaxis.set_minor_locator(matplotlib.ticker.FixedLocator([0.4, 1.4, 2.4]))
    pl.gca().xaxis.set_minor_formatter(matplotlib.ticker.FixedFormatter(["Fever", "Cough", "Headache"]))
    pl.gca().tick_params(which='minor', length=0)

    pl.axhline(0, color="#aaaaaa", linewidth=0.5)

    # shrink the axes upward to make room for the legend below
    box = pl.gca().get_position()
    pl.gca().set_position([
        box.x0, box.y0 + box.height * 0.3,
        box.width, box.height * 0.7
    ])

    # Put a legend below current axis
    pl.gca().legend(ahandles, alabels, loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2)

    return pl.gcf()
372
+
373
+ def _human_score_map(human_consensus, methods_attrs):
374
+ """ Converts human agreement differences to numerical scores for coloring.
375
+ """
376
+
377
+ v = 1 - min(np.sum(np.abs(methods_attrs - human_consensus)) / (np.abs(human_consensus).sum() + 1), 1.0)
378
+ return v
379
+
380
def make_grid(scores, dataset, model, normalize=True, transform=True):
    """ Assemble per-method, per-metric scores into a 2D grid for display.

    Filters `scores` to the given dataset/model, optionally applies each
    metric's declared transform, collapses curve results to a normalized AUC,
    optionally min-max normalizes each metric column, and sorts methods by
    mean score (best first).

    Returns (row_keys, col_keys, data): method names, metric names, and the
    score matrix with data[i, j] = score of method i on metric j.
    """
    color_vals = {}
    metric_sort_order = {}
    for (_,_,method,metric),(fcounts,score) in filter(lambda x: x[0][0] == dataset and x[0][1] == model, scores):
        metric_sort_order[metric] = get_metric_attr(metric, "sort_order")
        if metric not in color_vals:
            color_vals[metric] = {}

        if transform:
            # higher-is-better normalization declared per metric
            transform_type = get_metric_attr(metric, "transform")
            if transform_type == "negate":
                score = -score
            elif transform_type == "one_minus":
                score = 1 - score
            elif transform_type == "negate_log":
                score = -np.log10(score)

        if fcounts is None:
            # scalar metric: use the score directly
            color_vals[metric][method] = score
        elif fcounts == "human":
            # human-agreement metric: score is (human_consensus, method_attrs)
            color_vals[metric][method] = _human_score_map(*score)
        else:
            # curve metric: collapse to area under the curve, normalized
            auc = sklearn.metrics.auc(fcounts, score) / fcounts[-1]
            color_vals[metric][method] = auc
    # print(metric_sort_order)
    # col_keys = sorted(list(color_vals.keys()), key=lambda v: metric_sort_order[v])
    # print(col_keys)
    col_keys = list(color_vals.keys())
    row_keys = list({v for k in col_keys for v in color_vals[k].keys()})

    # sentinel fill value so missing (method, metric) pairs can be detected
    data = -28567 * np.ones((len(row_keys), len(col_keys)))

    for i in range(len(row_keys)):
        for j in range(len(col_keys)):
            data[i,j] = color_vals[col_keys[j]][row_keys[i]]

    assert np.sum(data == -28567) == 0, "There are missing data values!"

    if normalize:
        # min-max normalize each metric column (epsilon avoids 0/0)
        data = (data - data.min(0)) / (data.max(0) - data.min(0) + 1e-8)

    # sort by performance (mean score across metrics), best first
    inds = np.argsort(-data.mean(1))
    row_keys = [row_keys[i] for i in inds]
    data = data[inds,:]

    return row_keys, col_keys, data
427
+
428
+
429
+
430
# Fully opaque two-point colormap: 0.0 maps to a red (198, 34, 5) and
# 1.0 maps to a green (5, 198, 24).
red_blue_solid = LinearSegmentedColormap('red_blue_solid', {
    'red': ((0.0, 198./255, 198./255),
            (1.0, 5./255, 5./255)),

    'green': ((0.0, 34./255, 34./255),
              (1.0, 198./255, 198./255)),

    'blue': ((0.0, 5./255, 5./255),
             (1.0, 24./255, 24./255)),

    'alpha': ((0.0, 1, 1),
              (1.0, 1, 1))
})
443
def plot_grids(dataset, model_names, out_dir=None):
    """ Render the benchmark results as an interactive HTML grid of scores.

    Runs all experiments for every model in `model_names` on `dataset`,
    builds one score grid per model via make_grid, and emits an HTML page in
    which each cell opens a base64-embedded detail plot. If `out_dir` is
    given, the page (index.html) and PDF copies of the plots are written
    there (the directory must not already exist, since os.mkdir raises
    otherwise); otherwise an IPython HTML object is returned for notebooks.
    """

    if out_dir is not None:
        os.mkdir(out_dir)

    # gather experiment scores for every requested model
    scores = []
    for model in model_names:
        scores.extend(run_experiments(dataset=dataset, model=model))

    # `prefix` collects CSS and the hidden overlay divs for the detail
    # plots; `out` collects the visible page body
    prefix = "<style type='text/css'> .shap_benchmark__select:focus { outline-width: 0 }</style>"
    out = "" # background: rgb(30, 136, 229)

    # out += "<div style='font-weight: regular; font-size: 24px; text-align: center; background: #f8f8f8; color: #000; padding: 20px;'>SHAP Benchmark</div>\n"
    # out += "<div style='height: 1px; background: #ddd;'></div>\n"
    #out += "<div style='height: 7px; background-image: linear-gradient(to right, rgb(30, 136, 229), rgb(255, 13, 87));'></div>"

    # fixed header band holding the rotated metric column titles
    out += "<div style='position: fixed; left: 0px; top: 0px; right: 0px; height: 230px; background: #fff;'>\n" # box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2), 0 6px 20px 0 rgba(0, 0, 0, 0.19);
    out += "<div style='position: absolute; bottom: 0px; left: 0px; right: 0px;' align='center'><table style='border-width: 1px; margin-right: 100px'>\n"
    for ind,model in enumerate(model_names):
        row_keys, col_keys, data = make_grid(scores, dataset, model)
        # print(data)
        # print(colors.red_blue_solid(0.))
        # print(colors.red_blue_solid(1.))
        # return
        # render one detail plot per metric and embed it as a hidden overlay
        for metric in col_keys:
            save_plot = False
            if metric.startswith("human_"):
                plot_human(dataset, model, metric)
                save_plot = True
            elif metric not in ["local_accuracy", "runtime", "consistency_guarantees"]:
                plot_curve(dataset, model, metric)
                save_plot = True

            if save_plot:
                buf = io.BytesIO()
                pl.gcf().set_size_inches(1200.0/175,1000.0/175)
                pl.savefig(buf, format='png', dpi=175)
                if out_dir is not None:
                    pl.savefig(f"{out_dir}/plot_{dataset}_{model}_{metric}.pdf", format='pdf')
                pl.close()
                buf.seek(0)
                data_uri = base64.b64encode(buf.read()).decode('utf-8').replace('\n', '')
                plot_id = "plot__"+dataset+"__"+model+"__"+metric
                # clicking the overlay hides it again
                prefix += f"<div onclick='document.getElementById(\"{plot_id}\").style.display = \"none\"' style='display: none; position: fixed; z-index: 10000; left: 0px; right: 0px; top: 0px; bottom: 0px; background: rgba(255,255,255,0.9);' id='{plot_id}'>"
                prefix += "<img width='600' height='500' style='margin-left: auto; margin-right: auto; margin-top: 230px; box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2), 0 6px 20px 0 rgba(0, 0, 0, 0.19);' src='data:image/png;base64,%s'>" % data_uri
                prefix += "</div>"

        model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip()

        # the rotated metric titles are emitted only once, for the first model
        if ind == 0:
            out += "<tr><td style='background: #fff; width: 250px'></td></td>"
            for j in range(data.shape[1]):
                metric_title = getattr(metrics, col_keys[j]).__doc__.split("\n")[0].strip()
                out += "<td style='width: 40px; min-width: 40px; background: #fff; text-align: right;'><div style='margin-left: 10px; margin-bottom: -5px; white-space: nowrap; transform: rotate(-45deg); transform-origin: left top 0; width: 1.5em; margin-top: 8em'>" + metric_title + "</div></td>"
            out += "</tr>\n"
            out += "</table></div></div>\n"
        # one table per model: method rows x metric columns of score cells
        out += "<table style='border-width: 1px; margin-right: 100px; margin-top: 230px;'>\n"
        out += "<tr><td style='background: #fff'></td><td colspan='%d' style='background: #fff; font-weight: bold; text-align: center; margin-top: 10px;'>%s</td></tr>\n" % (data.shape[1], model_title)
        for i in range(data.shape[0]):
            out += "<tr>"
            # if i == 0:
            #     out += "<td rowspan='%d' style='background: #fff; text-align: center; white-space: nowrap; vertical-align: middle; '><div style='font-weight: bold; transform: rotate(-90deg); transform-origin: left top 0; width: 1.5em; margin-top: 8em'>%s</div></td>" % (data.shape[0], model_name)
            method_title = getattr(methods, row_keys[i]).__doc__.split("\n")[0].strip()
            out += "<td style='background: #ffffff; text-align: right; width: 250px' title='shap.LinearExplainer(model)'>" + method_title + "</td>\n"
            for j in range(data.shape[1]):
                plot_id = "plot__"+dataset+"__"+model+"__"+col_keys[j]
                # clicking a cell shows the corresponding detail-plot overlay
                out += "<td onclick='document.getElementById(\"%s\").style.display = \"block\"' style='padding: 0px; padding-left: 0px; padding-right: 0px; border-left: 0px solid #999; width: 42px; min-width: 42px; height: 34px; background-color: #fff'>" % plot_id
                #out += "<div style='opacity: "+str(2*(max(1-data[i,j], data[i,j])-0.5))+"; background-color: rgb" + str(tuple(v*255 for v in colors.red_blue_solid(0. if data[i,j] < 0.5 else 1.)[:-1])) + "; height: "+str((30*max(1-data[i,j], data[i,j])))+"px; margin-left: auto; margin-right: auto; width:"+str((30*max(1-data[i,j], data[i,j])))+"px'></div>"
                # square whose size and color encode the normalized score
                out += "<div style='opacity: "+str(1)+"; background-color: rgb" + str(tuple(int(v*255) for v in colors.red_blue_no_bounds(5*(data[i,j]-0.8))[:-1])) + "; height: "+str(30*data[i,j])+"px; margin-left: auto; margin-right: auto; width:"+str(30*data[i,j])+"px'></div>"
                #out += "<div style='float: left; background-color: #eee; height: 10px; width: "+str((40*(1-data[i,j])))+"px'></div>"
                out += "</td>\n"
            out += "</tr>\n" #

        out += "<tr><td colspan='%d' style='background: #fff'></td></tr>" % (data.shape[1] + 1)
        out += "</table>"

    # fixed top bar: page title plus a dataset selection drop-down
    out += "<div style='position: fixed; left: 0px; top: 0px; right: 0px; text-align: left; padding: 20px; text-align: right'>\n"
    out += "<div style='float: left; font-weight: regular; font-size: 24px; color: #000;'>SHAP Benchmark <span style='font-size: 14px; color: #777777;'>v"+__version__+"</span></div>\n"
    # select {
    #     margin: 50px;
    #     width: 150px;
    #     padding: 5px 35px 5px 5px;
    #     font-size: 16px;
    #     border: 1px solid #ccc;
    #     height: 34px;
    #     -webkit-appearance: none;
    #     -moz-appearance: none;
    #     appearance: none;
    #     background: url(http://www.stackoverflow.com/favicon.ico) 96% / 15% no-repeat #eee;
    # }
    #out += "<div style='display: inline-block; margin-right: 20px; font-weight: normal; text-decoration: none; font-size: 18px; color: #000;'>Dataset:</div>\n"

    out += "<select id='shap_benchmark__select' onchange=\"document.location = '../' + this.value + '/index.html'\"dir='rtl' class='shap_benchmark__select' style='font-weight: normal; font-size: 20px; color: #000; padding: 10px; background: #fff; border: 1px solid #fff; -webkit-appearance: none; appearance: none;'>\n"
    out += "<option value='human' "+("selected" if dataset == "human" else "")+">Agreement with Human Intuition</option>\n"
    out += "<option value='corrgroups60' "+("selected" if dataset == "corrgroups60" else "")+">Correlated Groups 60 Dataset</option>\n"
    out += "<option value='independentlinear60' "+("selected" if dataset == "independentlinear60" else "")+">Independent Linear 60 Dataset</option>\n"
    #out += "<option>CRIC</option>\n"
    out += "</select>\n"
    #out += "<script> document.onload = function() { document.getElementById('shap_benchmark__select').value = '"+dataset+"'; }</script>"
    #out += "<div style='display: inline-block; margin-left: 20px; font-weight: normal; text-decoration: none; font-size: 18px; color: #000;'>CRIC</div>\n"
    out += "</div>\n"

    # output the legend
    out += "<table style='border-width: 0px; width: 100px; position: fixed; right: 50px; top: 200px; background: rgba(255, 255, 255, 0.9)'>\n"
    out += "<tr><td style='background: #fff; font-weight: normal; text-align: center'>Higher score</td></tr>\n"
    legend_size = 21
    for i in range(legend_size-9):
        out += "<tr>"
        out += "<td style='padding: 0px; padding-left: 0px; padding-right: 0px; border-left: 0px solid #999; height: 34px'>"
        val = (legend_size-i-1) / (legend_size-1)
        out += "<div style='opacity: 1; background-color: rgb" + str(tuple(int(v*255) for v in colors.red_blue_no_bounds(5*(val-0.8)))[:-1]) + "; height: "+str(30*val)+"px; margin-left: auto; margin-right: auto; width:"+str(30*val)+"px'></div>"
        out += "</td>"
        out += "</tr>\n" #
    out += "<tr><td style='background: #fff; font-weight: normal; text-align: center'>Lower score</td></tr>\n"
    out += "</table>\n"

    if out_dir is not None:
        with open(out_dir + "/index.html", "w") as f:
            f.write("<html><body style='margin: 0px; font-size: 16px; font-family: \"Myriad Pro\", Arial, sans-serif;'><center>")
            f.write(prefix)
            f.write(out)
            f.write("</center></body></html>")
    else:
        return HTML(prefix + out)
lib/shap/cext/_cext.cc ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
2
+
3
+ #include <Python.h>
4
+ #include <numpy/arrayobject.h>
5
+ #include "tree_shap.h"
6
+ #include <iostream>
7
+
8
/* Forward declarations for the module's Python-callable entry points. */
static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args);
static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args);
static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args);
static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args);
static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args);

/* Method table mapping Python-visible names to the C implementations.
   All entries take positional arguments only (METH_VARARGS); the list is
   terminated by the NULL sentinel entry required by CPython. */
static PyMethodDef module_methods[] = {
    {"dense_tree_shap", _cext_dense_tree_shap, METH_VARARGS, "C implementation of Tree SHAP for dense."},
    {"dense_tree_predict", _cext_dense_tree_predict, METH_VARARGS, "C implementation of tree predictions."},
    {"dense_tree_update_weights", _cext_dense_tree_update_weights, METH_VARARGS, "C implementation of tree node weight compuatations."},
    {"dense_tree_saabas", _cext_dense_tree_saabas, METH_VARARGS, "C implementation of Saabas (rough fast approximation to Tree SHAP)."},
    {"compute_expectations", _cext_compute_expectations, METH_VARARGS, "Compute expectations of internal nodes."},
    {NULL, NULL, 0, NULL}
};
22
+
23
#if PY_MAJOR_VERSION >= 3
/* Module definition used only on Python 3 (Python 2 uses Py_InitModule). */
static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,
    "_cext",  /* module name */
    "This module provides an interface for a fast Tree SHAP implementation.",
    -1,       /* no per-interpreter module state; state kept in globals */
    module_methods,
    NULL,
    NULL,
    NULL,
    NULL
};
#endif
36
+
37
/* Module entry point: named PyInit__cext on Python 3, init_cext on Python 2. */
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC PyInit__cext(void)
#else
PyMODINIT_FUNC init_cext(void)
#endif
{
#if PY_MAJOR_VERSION >= 3
    PyObject *module = PyModule_Create(&moduledef);
    if (!module) return NULL;
#else
    PyObject *module = Py_InitModule("_cext", module_methods);
    if (!module) return;
#endif

    /* Load `numpy` functionality (must run before any PyArray_* call). */
    import_array();

#if PY_MAJOR_VERSION >= 3
    return module;
#endif
}
58
+
59
/* compute_expectations(children_left, children_right, node_sample_weight, values)
 *
 * Bridges Python arrays to the C++ compute_expectations() in tree_shap.h.
 * The `values` array is requested writable (NPY_ARRAY_INOUT_ARRAY) and is
 * updated in place; the function returns the tree's max depth as a Python int.
 */
static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args)
{
    PyObject *children_left_obj;
    PyObject *children_right_obj;
    PyObject *node_sample_weight_obj;
    PyObject *values_obj;

    /* Parse the input tuple */
    if (!PyArg_ParseTuple(
        args, "OOOO", &children_left_obj, &children_right_obj, &node_sample_weight_obj, &values_obj
    )) return NULL;

    /* Interpret the input objects as numpy arrays. */
    PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *node_sample_weight_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weight_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);

    /* If that didn't work, throw an exception (XDECREF tolerates NULLs). */
    if (children_left_array == NULL || children_right_array == NULL ||
        values_array == NULL || node_sample_weight_array == NULL) {
        Py_XDECREF(children_left_array);
        Py_XDECREF(children_right_array);
        //PyArray_ResolveWritebackIfCopy(values_array);
        Py_XDECREF(values_array);
        Py_XDECREF(node_sample_weight_array);
        return NULL;
    }

    TreeEnsemble tree;

    // number of outputs
    tree.num_outputs = PyArray_DIM(values_array, 1);

    /* Get pointers to the data as C-types.
       NOTE(review): casting to tfloat* assumes tfloat (from tree_shap.h)
       is double, matching the NPY_DOUBLE conversion above -- confirm. */
    tree.children_left = (int*)PyArray_DATA(children_left_array);
    tree.children_right = (int*)PyArray_DATA(children_right_array);
    tree.values = (tfloat*)PyArray_DATA(values_array);
    tree.node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weight_array);

    const int max_depth = compute_expectations(tree);

    // clean up the created python objects
    Py_XDECREF(children_left_array);
    Py_XDECREF(children_right_array);
    //PyArray_ResolveWritebackIfCopy(values_array);
    Py_XDECREF(values_array);
    Py_XDECREF(node_sample_weight_array);

    PyObject *ret = Py_BuildValue("i", max_depth);
    return ret;
}
111
+
112
+
113
// Python entry point for the CPU Tree SHAP implementation.  Converts the
// tree-ensemble arrays and the explanation dataset from Python objects into
// raw C buffers, runs dense_tree_shap, and writes the SHAP values into the
// caller-supplied out_contribs array in place.
static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
{
    // Raw Python objects received from the caller (converted to arrays below).
    PyObject *children_left_obj;
    PyObject *children_right_obj;
    PyObject *children_default_obj;
    PyObject *features_obj;
    PyObject *thresholds_obj;
    PyObject *values_obj;
    PyObject *node_sample_weights_obj;
    int max_depth;
    PyObject *X_obj;
    PyObject *X_missing_obj;
    PyObject *y_obj;
    PyObject *R_obj;
    PyObject *R_missing_obj;
    int tree_limit;
    PyObject *out_contribs_obj;
    int feature_dependence;
    int model_output;
    PyObject *base_offset_obj;
    bool interactions;

    /* Parse the input tuple */
    // NOTE(review): the "b" format writes an unsigned char; storing it through
    // &interactions relies on sizeof(bool) == 1 on the target ABI — confirm.
    if (!PyArg_ParseTuple(
        args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj,
        &features_obj, &thresholds_obj, &values_obj, &node_sample_weights_obj,
        &max_depth, &X_obj, &X_missing_obj, &y_obj, &R_obj, &R_missing_obj, &tree_limit, &base_offset_obj,
        &out_contribs_obj, &feature_dependence, &model_output, &interactions
    )) return NULL;

    /* Interpret the input objects as numpy arrays (each conversion returns a
       new reference that must be released before returning). */
    PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *node_sample_weights_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weights_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
    // y (per-sample labels) and R/R_missing (background data) are optional.
    PyArrayObject *y_array = NULL;
    if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *R_array = NULL;
    if (R_obj != Py_None) R_array = (PyArrayObject*)PyArray_FROM_OTF(R_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *R_missing_array = NULL;
    if (R_missing_obj != Py_None) R_missing_array = (PyArrayObject*)PyArray_FROM_OTF(R_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *out_contribs_array = (PyArrayObject*)PyArray_FROM_OTF(out_contribs_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
    PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);

    /* If that didn't work, throw an exception. Note that R and y are optional. */
    if (children_left_array == NULL || children_right_array == NULL ||
        children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
        values_array == NULL || node_sample_weights_array == NULL || X_array == NULL ||
        X_missing_array == NULL || out_contribs_array == NULL) {
        Py_XDECREF(children_left_array);
        Py_XDECREF(children_right_array);
        Py_XDECREF(children_default_array);
        Py_XDECREF(features_array);
        Py_XDECREF(thresholds_array);
        Py_XDECREF(values_array);
        Py_XDECREF(node_sample_weights_array);
        Py_XDECREF(X_array);
        Py_XDECREF(X_missing_array);
        if (y_array != NULL) Py_XDECREF(y_array);
        if (R_array != NULL) Py_XDECREF(R_array);
        if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
        // NOTE(review): PyArray_ResolveWritebackIfCopy is commented out here;
        // correctness relies on out_contribs already being a well-behaved
        // array so no writeback copy is made — confirm against callers.
        //PyArray_ResolveWritebackIfCopy(out_contribs_array);
        Py_XDECREF(out_contribs_array);
        Py_XDECREF(base_offset_array);
        return NULL;
    }

    // Problem dimensions: samples, features, nodes per tree, model outputs.
    const unsigned num_X = PyArray_DIM(X_array, 0);
    const unsigned M = PyArray_DIM(X_array, 1);
    const unsigned max_nodes = PyArray_DIM(values_array, 1);
    const unsigned num_outputs = PyArray_DIM(values_array, 2);
    unsigned num_R = 0;
    if (R_array != NULL) num_R = PyArray_DIM(R_array, 0);

    // Get pointers to the data as C-types
    int *children_left = (int*)PyArray_DATA(children_left_array);
    int *children_right = (int*)PyArray_DATA(children_right_array);
    int *children_default = (int*)PyArray_DATA(children_default_array);
    int *features = (int*)PyArray_DATA(features_array);
    tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
    tfloat *values = (tfloat*)PyArray_DATA(values_array);
    tfloat *node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weights_array);
    tfloat *X = (tfloat*)PyArray_DATA(X_array);
    bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
    tfloat *y = NULL;
    if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
    tfloat *R = NULL;
    if (R_array != NULL) R = (tfloat*)PyArray_DATA(R_array);
    bool *R_missing = NULL;
    if (R_missing_array != NULL) R_missing = (bool*)PyArray_DATA(R_missing_array);
    tfloat *out_contribs = (tfloat*)PyArray_DATA(out_contribs_array);
    tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);

    // these are just a wrapper objects for all the pointers and numbers associated with
    // the ensemble tree model and the dataset we are explaining
    TreeEnsemble trees = TreeEnsemble(
        children_left, children_right, children_default, features, thresholds, values,
        node_sample_weights, max_depth, tree_limit, base_offset,
        max_nodes, num_outputs
    );
    ExplanationDataset data = ExplanationDataset(X, X_missing, y, R, R_missing, num_X, M, num_R);

    dense_tree_shap(trees, data, out_contribs, feature_dependence, model_output, interactions);

    // retrieve return value before python cleanup of objects
    // (values points into values_array, which is released just below)
    tfloat ret_value = (double)values[0];

    // clean up the created python objects
    Py_XDECREF(children_left_array);
    Py_XDECREF(children_right_array);
    Py_XDECREF(children_default_array);
    Py_XDECREF(features_array);
    Py_XDECREF(thresholds_array);
    Py_XDECREF(values_array);
    Py_XDECREF(node_sample_weights_array);
    Py_XDECREF(X_array);
    Py_XDECREF(X_missing_array);
    if (y_array != NULL) Py_XDECREF(y_array);
    if (R_array != NULL) Py_XDECREF(R_array);
    if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
    //PyArray_ResolveWritebackIfCopy(out_contribs_array);
    Py_XDECREF(out_contribs_array);
    Py_XDECREF(base_offset_array);

    /* Build the output tuple */
    PyObject *ret = Py_BuildValue("d", ret_value);
    return ret;
}
246
+
247
+
248
+ static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args)
249
+ {
250
+ PyObject *children_left_obj;
251
+ PyObject *children_right_obj;
252
+ PyObject *children_default_obj;
253
+ PyObject *features_obj;
254
+ PyObject *thresholds_obj;
255
+ PyObject *values_obj;
256
+ int max_depth;
257
+ int tree_limit;
258
+ PyObject *base_offset_obj;
259
+ int model_output;
260
+ PyObject *X_obj;
261
+ PyObject *X_missing_obj;
262
+ PyObject *y_obj;
263
+ PyObject *out_pred_obj;
264
+
265
+ /* Parse the input tuple */
266
+ if (!PyArg_ParseTuple(
267
+ args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
268
+ &features_obj, &thresholds_obj, &values_obj, &max_depth, &tree_limit, &base_offset_obj, &model_output,
269
+ &X_obj, &X_missing_obj, &y_obj, &out_pred_obj
270
+ )) return NULL;
271
+
272
+ /* Interpret the input objects as numpy arrays. */
273
+ PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
274
+ PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
275
+ PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
276
+ PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
277
+ PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
278
+ PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
279
+ PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
280
+ PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
281
+ PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
282
+ PyArrayObject *y_array = NULL;
283
+ if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
284
+ PyArrayObject *out_pred_array = (PyArrayObject*)PyArray_FROM_OTF(out_pred_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
285
+
286
+ /* If that didn't work, throw an exception. Note that R and y are optional. */
287
+ if (children_left_array == NULL || children_right_array == NULL ||
288
+ children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
289
+ values_array == NULL || X_array == NULL ||
290
+ X_missing_array == NULL || out_pred_array == NULL) {
291
+ Py_XDECREF(children_left_array);
292
+ Py_XDECREF(children_right_array);
293
+ Py_XDECREF(children_default_array);
294
+ Py_XDECREF(features_array);
295
+ Py_XDECREF(thresholds_array);
296
+ Py_XDECREF(values_array);
297
+ Py_XDECREF(base_offset_array);
298
+ Py_XDECREF(X_array);
299
+ Py_XDECREF(X_missing_array);
300
+ if (y_array != NULL) Py_XDECREF(y_array);
301
+ //PyArray_ResolveWritebackIfCopy(out_pred_array);
302
+ Py_XDECREF(out_pred_array);
303
+ return NULL;
304
+ }
305
+
306
+ const unsigned num_X = PyArray_DIM(X_array, 0);
307
+ const unsigned M = PyArray_DIM(X_array, 1);
308
+ const unsigned max_nodes = PyArray_DIM(values_array, 1);
309
+ const unsigned num_outputs = PyArray_DIM(values_array, 2);
310
+
311
+ const unsigned num_offsets = PyArray_DIM(base_offset_array, 0);
312
+ if (num_offsets != num_outputs) {
313
+ std::cerr << "The passed base_offset array does that have the same number of outputs as the values array: " << num_offsets << " vs. " << num_outputs << std::endl;
314
+ return NULL;
315
+ }
316
+
317
+ // Get pointers to the data as C-types
318
+ int *children_left = (int*)PyArray_DATA(children_left_array);
319
+ int *children_right = (int*)PyArray_DATA(children_right_array);
320
+ int *children_default = (int*)PyArray_DATA(children_default_array);
321
+ int *features = (int*)PyArray_DATA(features_array);
322
+ tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
323
+ tfloat *values = (tfloat*)PyArray_DATA(values_array);
324
+ tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
325
+ tfloat *X = (tfloat*)PyArray_DATA(X_array);
326
+ bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
327
+ tfloat *y = NULL;
328
+ if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
329
+ tfloat *out_pred = (tfloat*)PyArray_DATA(out_pred_array);
330
+
331
+ // these are just wrapper objects for all the pointers and numbers associated with
332
+ // the ensemble tree model and the dataset we are explaining
333
+ TreeEnsemble trees = TreeEnsemble(
334
+ children_left, children_right, children_default, features, thresholds, values,
335
+ NULL, max_depth, tree_limit, base_offset,
336
+ max_nodes, num_outputs
337
+ );
338
+ ExplanationDataset data = ExplanationDataset(X, X_missing, y, NULL, NULL, num_X, M, 0);
339
+
340
+ dense_tree_predict(out_pred, trees, data, model_output);
341
+
342
+ // clean up the created python objects
343
+ Py_XDECREF(children_left_array);
344
+ Py_XDECREF(children_right_array);
345
+ Py_XDECREF(children_default_array);
346
+ Py_XDECREF(features_array);
347
+ Py_XDECREF(thresholds_array);
348
+ Py_XDECREF(values_array);
349
+ Py_XDECREF(base_offset_array);
350
+ Py_XDECREF(X_array);
351
+ Py_XDECREF(X_missing_array);
352
+ if (y_array != NULL) Py_XDECREF(y_array);
353
+ //PyArray_ResolveWritebackIfCopy(out_pred_array);
354
+ Py_XDECREF(out_pred_array);
355
+
356
+ /* Build the output tuple */
357
+ PyObject *ret = Py_BuildValue("d", (double)values[0]);
358
+ return ret;
359
+ }
360
+
361
+
362
+ static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args)
363
+ {
364
+ PyObject *children_left_obj;
365
+ PyObject *children_right_obj;
366
+ PyObject *children_default_obj;
367
+ PyObject *features_obj;
368
+ PyObject *thresholds_obj;
369
+ PyObject *values_obj;
370
+ int tree_limit;
371
+ PyObject *node_sample_weight_obj;
372
+ PyObject *X_obj;
373
+ PyObject *X_missing_obj;
374
+
375
+ /* Parse the input tuple */
376
+ if (!PyArg_ParseTuple(
377
+ args, "OOOOOOiOOO", &children_left_obj, &children_right_obj, &children_default_obj,
378
+ &features_obj, &thresholds_obj, &values_obj, &tree_limit, &node_sample_weight_obj, &X_obj, &X_missing_obj
379
+ )) return NULL;
380
+
381
+ /* Interpret the input objects as numpy arrays. */
382
+ PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
383
+ PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
384
+ PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
385
+ PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
386
+ PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
387
+ PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
388
+ PyArrayObject *node_sample_weight_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weight_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
389
+ PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
390
+ PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
391
+
392
+ /* If that didn't work, throw an exception. */
393
+ if (children_left_array == NULL || children_right_array == NULL ||
394
+ children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
395
+ values_array == NULL || node_sample_weight_array == NULL || X_array == NULL ||
396
+ X_missing_array == NULL) {
397
+ Py_XDECREF(children_left_array);
398
+ Py_XDECREF(children_right_array);
399
+ Py_XDECREF(children_default_array);
400
+ Py_XDECREF(features_array);
401
+ Py_XDECREF(thresholds_array);
402
+ Py_XDECREF(values_array);
403
+ //PyArray_ResolveWritebackIfCopy(node_sample_weight_array);
404
+ Py_XDECREF(node_sample_weight_array);
405
+ Py_XDECREF(X_array);
406
+ Py_XDECREF(X_missing_array);
407
+ std::cerr << "Found a NULL input array in _cext_dense_tree_update_weights!\n";
408
+ return NULL;
409
+ }
410
+
411
+ const unsigned num_X = PyArray_DIM(X_array, 0);
412
+ const unsigned M = PyArray_DIM(X_array, 1);
413
+ const unsigned max_nodes = PyArray_DIM(values_array, 1);
414
+
415
+ // Get pointers to the data as C-types
416
+ int *children_left = (int*)PyArray_DATA(children_left_array);
417
+ int *children_right = (int*)PyArray_DATA(children_right_array);
418
+ int *children_default = (int*)PyArray_DATA(children_default_array);
419
+ int *features = (int*)PyArray_DATA(features_array);
420
+ tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
421
+ tfloat *values = (tfloat*)PyArray_DATA(values_array);
422
+ tfloat *node_sample_weight = (tfloat*)PyArray_DATA(node_sample_weight_array);
423
+ tfloat *X = (tfloat*)PyArray_DATA(X_array);
424
+ bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
425
+
426
+ // these are just wrapper objects for all the pointers and numbers associated with
427
+ // the ensemble tree model and the dataset we are explaining
428
+ TreeEnsemble trees = TreeEnsemble(
429
+ children_left, children_right, children_default, features, thresholds, values,
430
+ node_sample_weight, 0, tree_limit, 0, max_nodes, 0
431
+ );
432
+ ExplanationDataset data = ExplanationDataset(X, X_missing, NULL, NULL, NULL, num_X, M, 0);
433
+
434
+ dense_tree_update_weights(trees, data);
435
+
436
+ // clean up the created python objects
437
+ Py_XDECREF(children_left_array);
438
+ Py_XDECREF(children_right_array);
439
+ Py_XDECREF(children_default_array);
440
+ Py_XDECREF(features_array);
441
+ Py_XDECREF(thresholds_array);
442
+ Py_XDECREF(values_array);
443
+ // PyArray_ResolveWritebackIfCopy(node_sample_weight_array);
444
+ Py_XDECREF(node_sample_weight_array);
445
+ Py_XDECREF(X_array);
446
+ Py_XDECREF(X_missing_array);
447
+
448
+ /* Build the output tuple */
449
+ PyObject *ret = Py_BuildValue("d", 1);
450
+ return ret;
451
+ }
452
+
453
+
454
+ static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args)
455
+ {
456
+ PyObject *children_left_obj;
457
+ PyObject *children_right_obj;
458
+ PyObject *children_default_obj;
459
+ PyObject *features_obj;
460
+ PyObject *thresholds_obj;
461
+ PyObject *values_obj;
462
+ int max_depth;
463
+ int tree_limit;
464
+ PyObject *base_offset_obj;
465
+ int model_output;
466
+ PyObject *X_obj;
467
+ PyObject *X_missing_obj;
468
+ PyObject *y_obj;
469
+ PyObject *out_pred_obj;
470
+
471
+
472
+ /* Parse the input tuple */
473
+ if (!PyArg_ParseTuple(
474
+ args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
475
+ &features_obj, &thresholds_obj, &values_obj, &max_depth, &tree_limit, &base_offset_obj, &model_output,
476
+ &X_obj, &X_missing_obj, &y_obj, &out_pred_obj
477
+ )) return NULL;
478
+
479
+ /* Interpret the input objects as numpy arrays. */
480
+ PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
481
+ PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
482
+ PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
483
+ PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
484
+ PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
485
+ PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
486
+ PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
487
+ PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
488
+ PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
489
+ PyArrayObject *y_array = NULL;
490
+ if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
491
+ PyArrayObject *out_pred_array = (PyArrayObject*)PyArray_FROM_OTF(out_pred_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
492
+
493
+ /* If that didn't work, throw an exception. Note that R and y are optional. */
494
+ if (children_left_array == NULL || children_right_array == NULL ||
495
+ children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
496
+ values_array == NULL || X_array == NULL ||
497
+ X_missing_array == NULL || out_pred_array == NULL) {
498
+ Py_XDECREF(children_left_array);
499
+ Py_XDECREF(children_right_array);
500
+ Py_XDECREF(children_default_array);
501
+ Py_XDECREF(features_array);
502
+ Py_XDECREF(thresholds_array);
503
+ Py_XDECREF(values_array);
504
+ Py_XDECREF(base_offset_array);
505
+ Py_XDECREF(X_array);
506
+ Py_XDECREF(X_missing_array);
507
+ if (y_array != NULL) Py_XDECREF(y_array);
508
+ //PyArray_ResolveWritebackIfCopy(out_pred_array);
509
+ Py_XDECREF(out_pred_array);
510
+ return NULL;
511
+ }
512
+
513
+ const unsigned num_X = PyArray_DIM(X_array, 0);
514
+ const unsigned M = PyArray_DIM(X_array, 1);
515
+ const unsigned max_nodes = PyArray_DIM(values_array, 1);
516
+ const unsigned num_outputs = PyArray_DIM(values_array, 2);
517
+
518
+ // Get pointers to the data as C-types
519
+ int *children_left = (int*)PyArray_DATA(children_left_array);
520
+ int *children_right = (int*)PyArray_DATA(children_right_array);
521
+ int *children_default = (int*)PyArray_DATA(children_default_array);
522
+ int *features = (int*)PyArray_DATA(features_array);
523
+ tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
524
+ tfloat *values = (tfloat*)PyArray_DATA(values_array);
525
+ tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
526
+ tfloat *X = (tfloat*)PyArray_DATA(X_array);
527
+ bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
528
+ tfloat *y = NULL;
529
+ if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
530
+ tfloat *out_pred = (tfloat*)PyArray_DATA(out_pred_array);
531
+
532
+ // these are just wrapper objects for all the pointers and numbers associated with
533
+ // the ensemble tree model and the dataset we are explaining
534
+ TreeEnsemble trees = TreeEnsemble(
535
+ children_left, children_right, children_default, features, thresholds, values,
536
+ NULL, max_depth, tree_limit, base_offset,
537
+ max_nodes, num_outputs
538
+ );
539
+ ExplanationDataset data = ExplanationDataset(X, X_missing, y, NULL, NULL, num_X, M, 0);
540
+
541
+ dense_tree_saabas(out_pred, trees, data);
542
+
543
+ // clean up the created python objects
544
+ Py_XDECREF(children_left_array);
545
+ Py_XDECREF(children_right_array);
546
+ Py_XDECREF(children_default_array);
547
+ Py_XDECREF(features_array);
548
+ Py_XDECREF(thresholds_array);
549
+ Py_XDECREF(values_array);
550
+ Py_XDECREF(base_offset_array);
551
+ Py_XDECREF(X_array);
552
+ Py_XDECREF(X_missing_array);
553
+ if (y_array != NULL) Py_XDECREF(y_array);
554
+ //PyArray_ResolveWritebackIfCopy(out_pred_array);
555
+ Py_XDECREF(out_pred_array);
556
+
557
+ /* Build the output tuple */
558
+ PyObject *ret = Py_BuildValue("d", (double)values[0]);
559
+ return ret;
560
+ }
lib/shap/cext/_cext_gpu.cc ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
2
+
3
+ #include <Python.h>
4
+ #include <numpy/arrayobject.h>
5
+ #include "tree_shap.h"
6
+ #include <iostream>
7
+
8
+ static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args);
9
+
10
+ static PyMethodDef module_methods[] = {
11
+ {"dense_tree_shap", _cext_dense_tree_shap, METH_VARARGS, "C implementation of Tree SHAP for dense."},
12
+ {NULL, NULL, 0, NULL}
13
+ };
14
+
15
+ #if PY_MAJOR_VERSION >= 3
16
+ static struct PyModuleDef moduledef = {
17
+ PyModuleDef_HEAD_INIT,
18
+ "_cext_gpu",
19
+ "This module provides an interface for a fast Tree SHAP implementation.",
20
+ -1,
21
+ module_methods,
22
+ NULL,
23
+ NULL,
24
+ NULL,
25
+ NULL
26
+ };
27
+ #endif
28
+
29
+ #if PY_MAJOR_VERSION >= 3
30
+ PyMODINIT_FUNC PyInit__cext_gpu(void)
31
+ #else
32
+ PyMODINIT_FUNC init_cext(void)
33
+ #endif
34
+ {
35
+ #if PY_MAJOR_VERSION >= 3
36
+ PyObject *module = PyModule_Create(&moduledef);
37
+ if (!module) return NULL;
38
+ #else
39
+ PyObject *module = Py_InitModule("_cext", module_methods);
40
+ if (!module) return;
41
+ #endif
42
+
43
+ /* Load `numpy` functionality. */
44
+ import_array();
45
+
46
+ #if PY_MAJOR_VERSION >= 3
47
+ return module;
48
+ #endif
49
+ }
50
+
51
// Implemented in _cext_gpu.cu; runs Tree SHAP on the GPU.
void dense_tree_shap_gpu(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
                         const int feature_dependence, unsigned model_transform, bool interactions);

// Python entry point for the GPU Tree SHAP implementation.  Mirrors the CPU
// wrapper in _cext.cc: converts the tree-ensemble arrays and the explanation
// dataset into raw C buffers, calls dense_tree_shap_gpu, and writes the SHAP
// values into the caller-supplied out_contribs array in place.
static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
{
    // Raw Python objects received from the caller (converted to arrays below).
    PyObject *children_left_obj;
    PyObject *children_right_obj;
    PyObject *children_default_obj;
    PyObject *features_obj;
    PyObject *thresholds_obj;
    PyObject *values_obj;
    PyObject *node_sample_weights_obj;
    int max_depth;
    PyObject *X_obj;
    PyObject *X_missing_obj;
    PyObject *y_obj;
    PyObject *R_obj;
    PyObject *R_missing_obj;
    int tree_limit;
    PyObject *out_contribs_obj;
    int feature_dependence;
    int model_output;
    PyObject *base_offset_obj;
    bool interactions;

    /* Parse the input tuple */
    // NOTE(review): the "b" format writes an unsigned char; storing it through
    // &interactions relies on sizeof(bool) == 1 on the target ABI — confirm.
    if (!PyArg_ParseTuple(
        args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj,
        &features_obj, &thresholds_obj, &values_obj, &node_sample_weights_obj,
        &max_depth, &X_obj, &X_missing_obj, &y_obj, &R_obj, &R_missing_obj, &tree_limit, &base_offset_obj,
        &out_contribs_obj, &feature_dependence, &model_output, &interactions
    )) return NULL;

    /* Interpret the input objects as numpy arrays (each conversion returns a
       new reference that must be released before returning). */
    PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *node_sample_weights_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weights_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
    // y (per-sample labels) and R/R_missing (background data) are optional.
    PyArrayObject *y_array = NULL;
    if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *R_array = NULL;
    if (R_obj != Py_None) R_array = (PyArrayObject*)PyArray_FROM_OTF(R_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *R_missing_array = NULL;
    if (R_missing_obj != Py_None) R_missing_array = (PyArrayObject*)PyArray_FROM_OTF(R_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
    PyArrayObject *out_contribs_array = (PyArrayObject*)PyArray_FROM_OTF(out_contribs_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
    PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);

    /* If that didn't work, throw an exception. Note that R and y are optional. */
    if (children_left_array == NULL || children_right_array == NULL ||
        children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
        values_array == NULL || node_sample_weights_array == NULL || X_array == NULL ||
        X_missing_array == NULL || out_contribs_array == NULL) {
        Py_XDECREF(children_left_array);
        Py_XDECREF(children_right_array);
        Py_XDECREF(children_default_array);
        Py_XDECREF(features_array);
        Py_XDECREF(thresholds_array);
        Py_XDECREF(values_array);
        Py_XDECREF(node_sample_weights_array);
        Py_XDECREF(X_array);
        Py_XDECREF(X_missing_array);
        if (y_array != NULL) Py_XDECREF(y_array);
        if (R_array != NULL) Py_XDECREF(R_array);
        if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
        //PyArray_ResolveWritebackIfCopy(out_contribs_array);
        Py_XDECREF(out_contribs_array);
        Py_XDECREF(base_offset_array);
        return NULL;
    }

    // Problem dimensions: samples, features, nodes per tree, model outputs.
    const unsigned num_X = PyArray_DIM(X_array, 0);
    const unsigned M = PyArray_DIM(X_array, 1);
    const unsigned max_nodes = PyArray_DIM(values_array, 1);
    const unsigned num_outputs = PyArray_DIM(values_array, 2);
    unsigned num_R = 0;
    if (R_array != NULL) num_R = PyArray_DIM(R_array, 0);

    // Get pointers to the data as C-types
    int *children_left = (int*)PyArray_DATA(children_left_array);
    int *children_right = (int*)PyArray_DATA(children_right_array);
    int *children_default = (int*)PyArray_DATA(children_default_array);
    int *features = (int*)PyArray_DATA(features_array);
    tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
    tfloat *values = (tfloat*)PyArray_DATA(values_array);
    tfloat *node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weights_array);
    tfloat *X = (tfloat*)PyArray_DATA(X_array);
    bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
    tfloat *y = NULL;
    if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
    tfloat *R = NULL;
    if (R_array != NULL) R = (tfloat*)PyArray_DATA(R_array);
    bool *R_missing = NULL;
    if (R_missing_array != NULL) R_missing = (bool*)PyArray_DATA(R_missing_array);
    tfloat *out_contribs = (tfloat*)PyArray_DATA(out_contribs_array);
    tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);

    // these are just a wrapper objects for all the pointers and numbers associated with
    // the ensemble tree model and the dataset we are explaining
    TreeEnsemble trees = TreeEnsemble(
        children_left, children_right, children_default, features, thresholds, values,
        node_sample_weights, max_depth, tree_limit, base_offset,
        max_nodes, num_outputs
    );
    ExplanationDataset data = ExplanationDataset(X, X_missing, y, R, R_missing, num_X, M, num_R);

    dense_tree_shap_gpu(trees, data, out_contribs, feature_dependence, model_output, interactions);


    // retrieve return value before python cleanup of objects
    // (values points into values_array, which is released just below)
    tfloat ret_value = (double)values[0];

    // clean up the created python objects
    Py_XDECREF(children_left_array);
    Py_XDECREF(children_right_array);
    Py_XDECREF(children_default_array);
    Py_XDECREF(features_array);
    Py_XDECREF(thresholds_array);
    Py_XDECREF(values_array);
    Py_XDECREF(node_sample_weights_array);
    Py_XDECREF(X_array);
    Py_XDECREF(X_missing_array);
    if (y_array != NULL) Py_XDECREF(y_array);
    if (R_array != NULL) Py_XDECREF(R_array);
    if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
    //PyArray_ResolveWritebackIfCopy(out_contribs_array);
    Py_XDECREF(out_contribs_array);
    Py_XDECREF(base_offset_array);

    /* Build the output tuple */
    PyObject *ret = Py_BuildValue("d", ret_value);
    return ret;
}
lib/shap/cext/_cext_gpu.cu ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <Python.h>
2
+
3
+ #include "gpu_treeshap.h"
4
+ #include "tree_shap.h"
5
+
6
+ const float inf = std::numeric_limits<tfloat>::infinity();
7
+
8
// The interval of feature values (plus missing-value handling) that flows
// down one path element of a decision tree; used by the gpu_treeshap kernels.
struct ShapSplitCondition {
  ShapSplitCondition() = default;
  ShapSplitCondition(tfloat feature_lower_bound, tfloat feature_upper_bound,
                     bool is_missing_branch)
      : feature_lower_bound(feature_lower_bound),
        feature_upper_bound(feature_upper_bound),
        is_missing_branch(is_missing_branch) {
    assert(feature_lower_bound <= feature_upper_bound);
  }

  /*! Feature values > lower and <= upper flow down this path
      (matches EvaluateSplit below; bounds are half-open on the left). */
  tfloat feature_lower_bound;
  tfloat feature_upper_bound;
  /*! Do missing values flow down this path? */
  bool is_missing_branch;

  // Does this instance flow down this path?
  __host__ __device__ bool EvaluateSplit(float x) const {
    // NaN encodes a missing feature value
    if (isnan(x)) {
      return is_missing_branch;
    }
    return x > feature_lower_bound && x <= feature_upper_bound;
  }

  // Combine two split conditions on the same feature by intersecting their
  // intervals; missing values pass only if both conditions accept them.
  __host__ __device__ void
  Merge(const ShapSplitCondition &other) { // Combine duplicate features
    feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
    feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
    is_missing_branch = is_missing_branch && other.is_missing_branch;
  }
};
41
+
42
+
43
// Compile-time element count of a built-in array — a minimal stand-in for
// std::size (C++17), restricted to arrays of known bound.
// Inspired by: https://en.cppreference.com/w/cpp/iterator/size
template <class T, size_t N>
constexpr size_t array_size(const T (&)[N]) noexcept { return N; }
50
+
51
// Depth-first traversal of one tree that turns every root-to-leaf path into a
// gpu_treeshap path (one copy per non-zero output class), appending results
// to *paths.  *tmp_path holds the partial path of the branch being explored
// and *path_idx is a running unique id shared across the whole ensemble.
void RecurseTree(
    unsigned pos, const TreeEnsemble &tree,
    std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> *tmp_path,
    std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
    size_t *path_idx, int num_outputs) {
  if (tree.is_leaf(pos)) {
    for (auto j = 0ull; j < num_outputs; j++) {
      auto v = tree.values[pos * num_outputs + j];
      if (v == 0.0) {
        // The tree has no output for this class, don't bother adding the path
        continue;
      }
      // Go back over path, setting v, path_idx
      for (auto &e : *tmp_path) {
        e.v = v;
        e.group = j;
        e.path_idx = *path_idx;
      }

      paths->insert(paths->end(), tmp_path->begin(), tmp_path->end());
      // Increment path index
      (*path_idx)++;
    }
    return;
  }

  // Add the left split to the path; the cover ratio of the left child gives
  // the probability ("zero fraction") of taking this branch.
  unsigned left_child = tree.children_left[pos];
  double left_zero_fraction =
      tree.node_sample_weights[left_child] / tree.node_sample_weights[pos];
  // Encode the range of feature values that flow down this path
  tmp_path->emplace_back(0, tree.features[pos], 0,
                         ShapSplitCondition{-inf, tree.thresholds[pos], false},
                         left_zero_fraction, 0.0f);

  RecurseTree(left_child, tree, tmp_path, paths, path_idx, num_outputs);

  // Replace our entry with the RIGHT split and recurse down that branch
  // (the previous comment mislabeled this as the left split).
  tmp_path->back() = gpu_treeshap::PathElement<ShapSplitCondition>(
      0, tree.features[pos], 0,
      ShapSplitCondition{tree.thresholds[pos], inf, false},
      1.0 - left_zero_fraction, 0.0f);

  RecurseTree(tree.children_right[pos], tree, tmp_path, paths, path_idx,
              num_outputs);

  // Pop our entry before returning to the parent frame.
  tmp_path->pop_back();
}
99
+
100
// Flatten every tree of the ensemble into the flat per-path representation
// consumed by GPUTreeShap. Path indices are unique across the whole ensemble.
std::vector<gpu_treeshap::PathElement<ShapSplitCondition>>
ExtractPaths(const TreeEnsemble &trees) {
  std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> paths;
  size_t path_idx = 0;
  for (auto i = 0; i < trees.tree_limit; i++) {
    TreeEnsemble tree;
    trees.get_tree(tree, i);
    std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> tmp_path;
    tmp_path.reserve(tree.max_depth);
    // Root element: feature_idx -1 marks the bias term; the split condition
    // (-inf, inf) accepts every value.
    tmp_path.emplace_back(0, -1, 0, ShapSplitCondition{-inf, inf, false}, 1.0,
                          0.0f);
    RecurseTree(0, tree, &tmp_path, &paths, &path_idx, tree.num_outputs);
  }
  return paths;
}
115
+
116
// Owns device copies of a host ExplanationDataset (either the foreground X
// or the background R matrix) and hands out a lightweight non-owning view
// for use inside kernels.
class DeviceExplanationDataset {
  thrust::device_vector<tfloat> data;     // row-major feature values
  thrust::device_vector<bool> missing;    // per-element missing flags
  size_t num_features;
  size_t num_rows;

 public:
  // Copies X (default) or, when background_dataset is true, the background
  // matrix R to the device.
  DeviceExplanationDataset(const ExplanationDataset &host_data,
                           bool background_dataset = false) {
    num_features = host_data.M;
    if (background_dataset) {
      num_rows = host_data.num_R;
      data = thrust::device_vector<tfloat>(
          host_data.R, host_data.R + host_data.num_R * host_data.M);
      missing = thrust::device_vector<bool>(host_data.R_missing,
                                            host_data.R_missing +
                                                host_data.num_R * host_data.M);

    } else {
      num_rows = host_data.num_X;
      data = thrust::device_vector<tfloat>(
          host_data.X, host_data.X + host_data.num_X * host_data.M);
      missing = thrust::device_vector<bool>(host_data.X_missing,
                                            host_data.X_missing +
                                                host_data.num_X * host_data.M);
    }
  }

  // Non-owning, trivially copyable view over the device buffers; satisfies
  // the DatasetT interface expected by the gpu_treeshap kernels.
  class DenseDatasetWrapper {
    const tfloat *data;
    const bool *missing;
    int num_rows;
    int num_cols;

   public:
    DenseDatasetWrapper() = default;
    DenseDatasetWrapper(const tfloat *data, const bool *missing, int num_rows,
                        int num_cols)
        : data(data), missing(missing), num_rows(num_rows), num_cols(num_cols) {
    }
    // Missing elements are surfaced as NaN so split conditions can route
    // them down the missing branch.
    __device__ tfloat GetElement(size_t row_idx, size_t col_idx) const {
      auto idx = row_idx * num_cols + col_idx;
      if (missing[idx]) {
        return std::numeric_limits<tfloat>::quiet_NaN();
      }
      return data[idx];
    }
    __host__ __device__ size_t NumRows() const { return num_rows; }
    __host__ __device__ size_t NumCols() const { return num_cols; }
  };

  DenseDatasetWrapper GetDeviceAccessor() {
    return DenseDatasetWrapper(data.data().get(), missing.data().get(),
                               num_rows, num_features);
  }
};
172
+
173
// GPU path-dependent TreeSHAP: computes per-feature SHAP values for every row
// of data.X and writes them (plus the bias column) into out_contribs in
// SHAP's (row, column, group) layout.
// NOTE(review): the `transform` parameter is accepted for signature parity
// with the CPU implementation but is never applied here — confirm upstream.
inline void dense_tree_path_dependent_gpu(
    const TreeEnsemble &trees, const ExplanationDataset &data,
    tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
  auto paths = ExtractPaths(trees);
  DeviceExplanationDataset device_data(data);
  DeviceExplanationDataset::DenseDatasetWrapper X =
      device_data.GetDeviceAccessor();

  // One phi per (row, group, feature) plus a bias column per (row, group).
  thrust::device_vector<float> phis((X.NumCols() + 1) * X.NumRows() *
                                    trees.num_outputs);
  gpu_treeshap::GPUTreeShap(X, paths.begin(), paths.end(), trees.num_outputs,
                            phis.begin(), phis.end());
  // Add the base offset term to bias
  thrust::device_vector<double> base_offset(
      trees.base_offset, trees.base_offset + trees.num_outputs);
  auto counting = thrust::make_counting_iterator(size_t(0));
  auto d_phis = phis.data().get();
  auto d_base_offset = base_offset.data().get();
  size_t num_groups = trees.num_outputs;
  thrust::for_each(counting, counting + X.NumRows() * trees.num_outputs,
                   [=] __device__(size_t idx) {
                     size_t row_idx = idx / num_groups;
                     size_t group = idx % num_groups;
                     // column_idx == NumCols() addresses the bias slot.
                     auto phi_idx = gpu_treeshap::IndexPhi(
                         row_idx, num_groups, group, X.NumCols(), X.NumCols());
                     d_phis[phi_idx] += d_base_offset[group];
                   });

  // Shap uses a slightly different layout for multiclass
  // (groups become the innermost axis), so transpose before copying out.
  thrust::device_vector<float> transposed_phis(phis.size());
  auto d_transposed_phis = transposed_phis.data();
  thrust::for_each(
      counting, counting + phis.size(), [=] __device__(size_t idx) {
        size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1)};
        size_t old_idx[array_size(old_shape)];
        gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
        // Define new tensor format, switch num_groups axis to end
        size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), num_groups};
        size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[1]};
        size_t transposed_idx =
            gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
        d_transposed_phis[transposed_idx] = d_phis[idx];
      });
  thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
}
218
+
219
// GPU interventional (feature-independent) TreeSHAP: like the path-dependent
// variant above but marginalizes over the background dataset R instead of
// using training-set cover statistics.
// NOTE(review): `transform` is accepted but never applied here — confirm
// upstream.
inline void
dense_tree_independent_gpu(const TreeEnsemble &trees,
                           const ExplanationDataset &data, tfloat *out_contribs,
                           tfloat transform(const tfloat, const tfloat)) {
  auto paths = ExtractPaths(trees);
  DeviceExplanationDataset device_data(data);
  DeviceExplanationDataset::DenseDatasetWrapper X =
      device_data.GetDeviceAccessor();
  // Background (reference) rows used to marginalize out inactive features.
  DeviceExplanationDataset background_device_data(data, true);
  DeviceExplanationDataset::DenseDatasetWrapper R =
      background_device_data.GetDeviceAccessor();

  thrust::device_vector<float> phis((X.NumCols() + 1) * X.NumRows() *
                                    trees.num_outputs);
  gpu_treeshap::GPUTreeShapInterventional(X, R, paths.begin(), paths.end(),
                                          trees.num_outputs, phis.begin(),
                                          phis.end());
  // Add the base offset term to bias
  thrust::device_vector<double> base_offset(
      trees.base_offset, trees.base_offset + trees.num_outputs);
  auto counting = thrust::make_counting_iterator(size_t(0));
  auto d_phis = phis.data().get();
  auto d_base_offset = base_offset.data().get();
  size_t num_groups = trees.num_outputs;
  thrust::for_each(counting, counting + X.NumRows() * trees.num_outputs,
                   [=] __device__(size_t idx) {
                     size_t row_idx = idx / num_groups;
                     size_t group = idx % num_groups;
                     // column_idx == NumCols() addresses the bias slot.
                     auto phi_idx = gpu_treeshap::IndexPhi(
                         row_idx, num_groups, group, X.NumCols(), X.NumCols());
                     d_phis[phi_idx] += d_base_offset[group];
                   });

  // Shap uses a slightly different layout for multiclass
  // (groups become the innermost axis), so transpose before copying out.
  thrust::device_vector<float> transposed_phis(phis.size());
  auto d_transposed_phis = transposed_phis.data();
  thrust::for_each(
      counting, counting + phis.size(), [=] __device__(size_t idx) {
        size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1)};
        size_t old_idx[array_size(old_shape)];
        gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
        // Define new tensor format, switch num_groups axis to end
        size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), num_groups};
        size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[1]};
        size_t transposed_idx =
            gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
        d_transposed_phis[transposed_idx] = d_phis[idx];
      });
  thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
}
269
+
270
// GPU path-dependent SHAP interaction values: produces a
// (NumCols()+1) x (NumCols()+1) matrix per (row, group) and writes it to
// out_contribs in SHAP's (row, i, j, group) layout.
// NOTE(review): `transform` is accepted but never applied here — confirm
// upstream.
inline void dense_tree_path_dependent_interactions_gpu(
    const TreeEnsemble &trees, const ExplanationDataset &data,
    tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
  auto paths = ExtractPaths(trees);
  DeviceExplanationDataset device_data(data);
  DeviceExplanationDataset::DenseDatasetWrapper X =
      device_data.GetDeviceAccessor();

  thrust::device_vector<float> phis((X.NumCols() + 1) * (X.NumCols() + 1) *
                                    X.NumRows() * trees.num_outputs);
  gpu_treeshap::GPUTreeShapInteractions(X, paths.begin(), paths.end(),
                                        trees.num_outputs, phis.begin(),
                                        phis.end());
  // Add the base offset term to bias
  thrust::device_vector<double> base_offset(
      trees.base_offset, trees.base_offset + trees.num_outputs);
  auto counting = thrust::make_counting_iterator(size_t(0));
  auto d_phis = phis.data().get();
  auto d_base_offset = base_offset.data().get();
  size_t num_groups = trees.num_outputs;
  thrust::for_each(counting, counting + X.NumRows() * num_groups,
                   [=] __device__(size_t idx) {
                     size_t row_idx = idx / num_groups;
                     size_t group = idx % num_groups;
                     // (bias, bias) cell of the interaction matrix.
                     auto phi_idx = gpu_treeshap::IndexPhiInteractions(
                         row_idx, num_groups, group, X.NumCols(), X.NumCols(),
                         X.NumCols());
                     d_phis[phi_idx] += d_base_offset[group];
                   });
  // Shap uses a slightly different layout for multiclass
  // (groups become the innermost axis), so transpose before copying out.
  thrust::device_vector<float> transposed_phis(phis.size());
  auto d_transposed_phis = transposed_phis.data();
  thrust::for_each(
      counting, counting + phis.size(), [=] __device__(size_t idx) {
        size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1),
                              (X.NumCols() + 1)};
        size_t old_idx[array_size(old_shape)];
        gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
        // Define new tensor format, switch num_groups axis to end
        size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), (X.NumCols() + 1),
                              num_groups};
        size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[3], old_idx[1]};
        size_t transposed_idx =
            gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
        d_transposed_phis[transposed_idx] = d_phis[idx];
      });
  thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
}
318
+
319
// Entry point for GPU TreeSHAP: dispatches to the appropriate algorithm
// based on the feature-dependence mode and whether interaction values were
// requested. Unsupported combinations print to stderr and return without
// touching out_contribs.
void dense_tree_shap_gpu(const TreeEnsemble &trees,
                         const ExplanationDataset &data, tfloat *out_contribs,
                         const int feature_dependence, unsigned model_transform,
                         bool interactions) {
  // see what transform (if any) we have
  transform_f transform = get_transform(model_transform);

  // dispatch to the correct algorithm handler
  switch (feature_dependence) {
  case FEATURE_DEPENDENCE::independent:
    if (interactions) {
      std::cerr << "FEATURE_DEPENDENCE::independent with interactions not yet "
                   "supported\n";
    } else {
      dense_tree_independent_gpu(trees, data, out_contribs, transform);
    }
    return;

  case FEATURE_DEPENDENCE::tree_path_dependent:
    if (interactions) {
      dense_tree_path_dependent_interactions_gpu(trees, data, out_contribs,
                                                 transform);
    } else {
      dense_tree_path_dependent_gpu(trees, data, out_contribs, transform);
    }
    return;

  case FEATURE_DEPENDENCE::global_path_dependent:
    std::cerr << "FEATURE_DEPENDENCE::global_path_dependent not supported\n";
    return;
  default:
    std::cerr << "Unknown feature dependence option\n";
    return;
  }
}
lib/shap/cext/gpu_treeshap.h ADDED
@@ -0,0 +1,1535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2020, NVIDIA CORPORATION.
3
+ *
4
+ * Licensed under the Apache License, Version 2.0 (the "License");
5
+ * you may not use this file except in compliance with the License.
6
+ * You may obtain a copy of the License at
7
+ *
8
+ * http://www.apache.org/licenses/LICENSE-2.0
9
+ *
10
+ * Unless required by applicable law or agreed to in writing, software
11
+ * distributed under the License is distributed on an "AS IS" BASIS,
12
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ * See the License for the specific language governing permissions and
14
+ * limitations under the License.
15
+ */
16
+
17
+ #pragma once
18
+ #include <thrust/device_allocator.h>
19
+ #include <thrust/device_vector.h>
20
+ #include <thrust/iterator/discard_iterator.h>
21
+ #include <thrust/logical.h>
22
+ #include <thrust/reduce.h>
23
+ #include <thrust/host_vector.h>
24
+ #if (CUDART_VERSION >= 11000)
25
+ #include <cub/cub.cuh>
26
+ #else
27
+ // Hack to get cub device reduce on older toolkits
28
+ #include <thrust/system/cuda/detail/cub/device/device_reduce.cuh>
29
+ using namespace thrust::cuda_cub;
30
+ #endif
31
+ #include <algorithm>
32
+ #include <functional>
33
+ #include <set>
34
+ #include <stdexcept>
35
+ #include <utility>
36
+ #include <vector>
37
+
38
+ namespace gpu_treeshap {
39
+
40
// Split condition matching XGBoost semantics: values in [lower, upper) follow
// the path, NaNs follow the recorded missing branch.
struct XgboostSplitCondition {
  XgboostSplitCondition() = default;
  XgboostSplitCondition(float feature_lower_bound, float feature_upper_bound,
                        bool is_missing_branch)
      : feature_lower_bound(feature_lower_bound),
        feature_upper_bound(feature_upper_bound),
        is_missing_branch(is_missing_branch) {
    assert(feature_lower_bound <= feature_upper_bound);
  }

  /*! Feature values >= lower and < upper flow down this path. */
  float feature_lower_bound;
  float feature_upper_bound;
  /*! Do missing values flow down this path? */
  bool is_missing_branch;

  // Does this instance flow down this path?
  __host__ __device__ bool EvaluateSplit(float x) const {
    // is nan
    if (isnan(x)) {
      return is_missing_branch;
    }
    return x >= feature_lower_bound && x < feature_upper_bound;
  }

  // Combine two split conditions on the same feature
  __host__ __device__ void Merge(
      const XgboostSplitCondition& other) {  // Combine duplicate features
    // Intersect intervals; missing flows down only if both accept missing.
    feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
    feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
    is_missing_branch = is_missing_branch && other.is_missing_branch;
  }
};
73
+
74
+ /*!
75
+ * An element of a unique path through a decision tree. Can implement various
76
+ * types of splits via the templated SplitConditionT. Some decision tree
77
+ * implementations may wish to use double precision or single precision, some
78
+ * may use < or <= as the threshold, missing values can be handled differently,
79
+ * categoricals may be supported.
80
+ *
81
+ * \tparam SplitConditionT A split condition implementing the methods
82
+ * EvaluateSplit and Merge.
83
+ */
84
template <typename SplitConditionT>
struct PathElement {
  using split_type = SplitConditionT;
  __host__ __device__ PathElement(size_t path_idx, int64_t feature_idx,
                                  int group, SplitConditionT split_condition,
                                  double zero_fraction, float v)
      : path_idx(path_idx),
        feature_idx(feature_idx),
        group(group),
        split_condition(split_condition),
        zero_fraction(zero_fraction),
        v(v) {}

  PathElement() = default;
  // The root element of a path is marked with feature index -1 (bias term).
  __host__ __device__ bool IsRoot() const { return feature_idx == -1; }

  // True when row `row_idx` of X satisfies this element's split condition;
  // the root element accepts every row.
  template <typename DatasetT>
  __host__ __device__ bool EvaluateSplit(DatasetT X, size_t row_idx) const {
    if (this->IsRoot()) {
      // NOTE(review): returning 1.0 from a bool function — converts to true,
      // but `return true;` would be clearer.
      return 1.0;
    }
    return split_condition.EvaluateSplit(X.GetElement(row_idx, feature_idx));
  }

  /*! Unique path index. */
  size_t path_idx;
  /*! Feature of this split, -1 indicates bias term. */
  int64_t feature_idx;
  /*! Indicates class for multiclass problems. */
  int group;
  SplitConditionT split_condition;
  /*! Probability of following this path when feature_idx is not in the active
   * set. */
  double zero_fraction;
  float v;  // Leaf weight at the end of the path
};
120
+
121
// Helper function that accepts an index into a flat contiguous array and the
// dimensions of a tensor and returns the indices with respect to the tensor
// (row-major order; out_idx[0] varies slowest).
template <typename T, size_t N>
__device__ void FlatIdxToTensorIdx(T flat_idx, const T (&shape)[N],
                                   T (&out_idx)[N]) {
  // Total number of elements, then peel off one axis stride per iteration.
  T current_size = shape[0];
  for (auto i = 1ull; i < N; i++) {
    current_size *= shape[i];
  }
  for (auto i = 0ull; i < N; i++) {
    current_size /= shape[i];
    out_idx[i] = flat_idx / current_size;
    flat_idx -= current_size * out_idx[i];
  }
}
136
+
137
// Given a shape and coordinates into a tensor, return the index into the
// backing storage one-dimensional array (row-major; inverse of
// FlatIdxToTensorIdx).
template <typename T, size_t N>
__device__ T TensorIdxToFlatIdx(const T (&shape)[N], const T (&tensor_idx)[N]) {
  T current_size = shape[0];
  for (auto i = 1ull; i < N; i++) {
    current_size *= shape[i];
  }
  T idx = 0;
  for (auto i = 0ull; i < N; i++) {
    // current_size becomes the stride of axis i.
    current_size /= shape[i];
    idx += tensor_idx[i] * current_size;
  }
  return idx;
}
152
+
153
// Maps values to the phi array according to row, group and column.
// Layout is (row, group, column) with num_columns + 1 slots per group; the
// extra slot (column_idx == num_columns) holds the bias term.
__host__ __device__ inline size_t IndexPhi(size_t row_idx, size_t num_groups,
                                           size_t group, size_t num_columns,
                                           size_t column_idx) {
  return (row_idx * num_groups + group) * (num_columns + 1) + column_idx;
}
159
+
160
// Index into the interaction-values array: one (num_columns+1)^2 matrix per
// (row, group), addressed by feature pair (i, j).
__host__ __device__ inline size_t IndexPhiInteractions(size_t row_idx,
                                                       size_t num_groups,
                                                       size_t group,
                                                       size_t num_columns,
                                                       size_t i, size_t j) {
  size_t matrix_size = (num_columns + 1) * (num_columns + 1);
  size_t matrix_offset = (row_idx * num_groups + group) * matrix_size;
  return matrix_offset + i * (num_columns + 1) + j;
}
169
+
170
+ namespace detail {
171
+
172
// Shorthand for creating a device vector with an appropriate allocator type:
// rebinds the caller-supplied allocator to element type T.
template <class T, class DeviceAllocatorT>
using RebindVector =
    thrust::device_vector<T,
                          typename DeviceAllocatorT::template rebind<T>::other>;
177
+
178
// Double-precision atomic add. Hardware atomicAdd(double*) requires
// compute capability >= 6.0; older architectures fall back to a CAS loop.
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
__device__ __forceinline__ double atomicAddDouble(double* address, double val) {
  return atomicAdd(address, val);
}
#else  // In device code and CUDA < 600
__device__ __forceinline__ double atomicAddDouble(double* address,
                                                  double val) {  // NOLINT
  unsigned long long int* address_as_ull =                       // NOLINT
      (unsigned long long int*)address;                          // NOLINT
  unsigned long long int old = *address_as_ull, assumed;         // NOLINT

  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));

    // Note: uses integer comparison to avoid hang in case of NaN (since NaN !=
    // NaN)
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif
201
+
202
// Reads the %lanemask_lt PTX special register: a bitmask of all warp lanes
// with an id lower than the calling thread's lane id.
__forceinline__ __device__ unsigned int lanemask32_lt() {
  unsigned int lanemask32_lt;
  asm volatile("mov.u32 %0, %%lanemask_lt;" : "=r"(lanemask32_lt));
  return (lanemask32_lt);
}
207
+
208
// Like a coalesced group, except we can make the assumption that all threads
// in a group are next to each other. This makes shuffle operations much
// cheaper.
class ContiguousGroup {
 public:
  __device__ ContiguousGroup(uint32_t mask) : mask_(mask) {}

  // Number of threads in the group.
  __device__ uint32_t size() const { return __popc(mask_); }
  // Rank of the calling thread within the group (0-based).
  __device__ uint32_t thread_rank() const {
    return __popc(mask_ & lanemask32_lt());
  }
  // Shuffle from group-relative rank `src` (offset by the group's first lane).
  template <typename T>
  __device__ T shfl(T val, uint32_t src) const {
    return __shfl_sync(mask_, val, src + __ffs(mask_) - 1);
  }
  template <typename T>
  __device__ T shfl_up(T val, uint32_t delta) const {
    return __shfl_up_sync(mask_, val, delta);
  }
  // Ballot over the group, shifted so bit 0 corresponds to the group's
  // first member.
  __device__ uint32_t ballot(int predicate) const {
    return __ballot_sync(mask_, predicate) >> (__ffs(mask_) - 1);
  }

  // Inclusive-scan-style reduction; every member receives the final value
  // held by the last rank.
  template <typename T, typename OpT>
  __device__ T reduce(T val, OpT op) {
    for (int i = 1; i < this->size(); i *= 2) {
      T shfl = shfl_up(val, i);
      if (static_cast<int>(thread_rank()) - i >= 0) {
        val = op(val, shfl);
      }
    }
    return shfl(val, size() - 1);
  }
  uint32_t mask_;  // lane mask of the group's members
};
242
+
243
// Separate the active threads by labels
// This functionality is available in cuda 11.0 on cc >=7.0
// We reimplement for backwards compatibility
// Assumes partitions are contiguous
inline __device__ ContiguousGroup active_labeled_partition(uint32_t mask,
                                                           int label) {
#if __CUDA_ARCH__ >= 700
  // Hardware match instruction groups threads by equal label directly.
  uint32_t subgroup_mask = __match_any_sync(mask, label);
#else
  // Software fallback: walk the contiguous label runs one at a time.
  uint32_t subgroup_mask = 0;
  for (int i = 0; i < 32;) {
    int current_label = __shfl_sync(mask, label, i);
    uint32_t ballot = __ballot_sync(mask, label == current_label);
    if (label == current_label) {
      subgroup_mask = ballot;
    }
    uint32_t completed_mask =
        (1 << (32 - __clz(ballot))) - 1;  // Threads that have finished
    // Find the start of the next group, mask off completed threads from active
    // threads Then use ffs - 1 to find the position of the next group
    int next_i = __ffs(mask & ~completed_mask) - 1;
    if (next_i == -1) break;  // -1 indicates all finished
    assert(next_i > i);  // Prevent infinite loops when the constraints not met
    i = next_i;
  }
#endif
  return ContiguousGroup(subgroup_mask);
}
271
+
272
// Group of threads where each thread holds a path element.
// Implements the TreeSHAP EXTEND/UNWIND recursion cooperatively across a
// warp subgroup: each thread keeps one permutation weight (pweight) in a
// register and exchanges values via shuffles.
class GroupPath {
 protected:
  const ContiguousGroup& g_;
  // These are combined so we can communicate them in a single 64 bit shuffle
  // instruction
  float zero_one_fraction_[2];
  float pweight_;      // permutation weight for this thread's path element
  int unique_depth_;   // number of elements extended so far

 public:
  __device__ GroupPath(const ContiguousGroup& g, float zero_fraction,
                       float one_fraction)
      : g_(g),
        zero_one_fraction_{zero_fraction, one_fraction},
        pweight_(g.thread_rank() == 0 ? 1.0f : 0.0f),
        unique_depth_(0) {}

  // Cooperatively extend the path with a group of threads
  // Each thread maintains pweight for its path element in register
  __device__ void Extend() {
    unique_depth_++;

    // Broadcast the zero and one fraction from the newly added path element
    // Combine 2 shuffle operations into 64 bit word
    const size_t rank = g_.thread_rank();
    const float inv_unique_depth =
        __fdividef(1.0f, static_cast<float>(unique_depth_ + 1));
    uint64_t res = g_.shfl(*reinterpret_cast<uint64_t*>(&zero_one_fraction_),
                           unique_depth_);
    const float new_zero_fraction = reinterpret_cast<float*>(&res)[0];
    const float new_one_fraction = reinterpret_cast<float*>(&res)[1];
    float left_pweight = g_.shfl_up(pweight_, 1);

    // pweight of threads with rank < unique_depth_ is 0
    // We use max(x,0) to avoid using a branch
    // pweight_ *=
    //   new_zero_fraction * max(unique_depth_ - rank, 0llu) * inv_unique_depth;
    pweight_ = __fmul_rn(
        __fmul_rn(pweight_, new_zero_fraction),
        __fmul_rn(max(unique_depth_ - rank, size_t(0)), inv_unique_depth));

    // pweight_ += new_one_fraction * left_pweight * rank * inv_unique_depth;
    pweight_ = __fmaf_rn(__fmul_rn(new_one_fraction, left_pweight),
                         __fmul_rn(rank, inv_unique_depth), pweight_);
  }

  // Each thread unwinds the path for its feature and returns the sum
  __device__ float UnwoundPathSum() {
    float next_one_portion = g_.shfl(pweight_, unique_depth_);
    float total = 0.0f;
    const float zero_frac_div_unique_depth = __fdividef(
        zero_one_fraction_[0], static_cast<float>(unique_depth_ + 1));
    for (int i = unique_depth_ - 1; i >= 0; i--) {
      float ith_pweight = g_.shfl(pweight_, i);
      float precomputed =
          __fmul_rn((unique_depth_ - i), zero_frac_div_unique_depth);
      const float tmp =
          __fdividef(__fmul_rn(next_one_portion, unique_depth_ + 1), i + 1);
      total = __fmaf_rn(tmp, zero_one_fraction_[1], total);
      next_one_portion = __fmaf_rn(-tmp, precomputed, ith_pweight);
      float numerator =
          __fmul_rn(__fsub_rn(1.0f, zero_one_fraction_[1]), ith_pweight);
      // Guard against division by zero when the zero fraction is 0.
      if (precomputed > 0.0f) {
        total += __fdividef(numerator, precomputed);
      }
    }

    return total;
  }
};
343
+
344
// Has different permutation weightings to the above
// Used in Taylor Shapley interaction index
class TaylorGroupPath : GroupPath {
 public:
  __device__ TaylorGroupPath(const ContiguousGroup& g, float zero_fraction,
                             float one_fraction)
      : GroupPath(g, zero_fraction, one_fraction) {}

  // Extend the path is normal, all reweighting can happen in UnwoundPathSum
  __device__ void Extend() { GroupPath::Extend(); }

  // Each thread unwinds the path for its feature and returns the sum
  // We use a different permutation weighting for Taylor interactions
  // As if the total number of features was one larger
  __device__ float UnwoundPathSum() {
    float one_fraction = zero_one_fraction_[1];
    float zero_fraction = zero_one_fraction_[0];
    float next_one_portion = g_.shfl(pweight_, unique_depth_) /
                             static_cast<float>(unique_depth_ + 2);

    float total = 0.0f;
    for (int i = unique_depth_ - 1; i >= 0; i--) {
      float ith_pweight =
          g_.shfl(pweight_, i) * (static_cast<float>(unique_depth_ - i + 1) /
                                  static_cast<float>(unique_depth_ + 2));
      // Branch on which fraction is non-zero to avoid dividing by zero.
      if (one_fraction > 0.0f) {
        const float tmp =
            next_one_portion * (unique_depth_ + 2) / ((i + 1) * one_fraction);

        total += tmp;
        next_one_portion =
            ith_pweight - tmp * zero_fraction *
                              ((unique_depth_ - i + 1) /
                               static_cast<float>(unique_depth_ + 2));
      } else if (zero_fraction > 0.0f) {
        total +=
            (ith_pweight / zero_fraction) /
            ((unique_depth_ - i + 1) / static_cast<float>(unique_depth_ + 2));
      }
    }

    // Factor of 2 from the symmetrized Taylor interaction weighting.
    return 2 * total;
  }
};
388
+
389
// Computes the SHAP contribution of path element `e` for row `row_idx`:
// extends the full unique path cooperatively, unwinds it for this thread's
// feature, and scales by (one_fraction - zero_fraction) * leaf value.
template <typename DatasetT, typename SplitConditionT>
__device__ float ComputePhi(const PathElement<SplitConditionT>& e,
                            size_t row_idx, const DatasetT& X,
                            const ContiguousGroup& group, float zero_fraction) {
  // 1 if this row satisfies the split, 0 otherwise.
  float one_fraction =
      e.EvaluateSplit(X, row_idx);
  GroupPath path(group, zero_fraction, one_fraction);
  size_t unique_path_length = group.size();

  // Extend the path
  for (auto unique_depth = 1ull; unique_depth < unique_path_length;
       unique_depth++) {
    path.Extend();
  }

  float sum = path.UnwoundPathSum();
  return sum * (one_fraction - zero_fraction) * e.v;
}
407
+
408
// Integer ceiling division: smallest k with k * b >= a (b must be non-zero).
inline __host__ __device__ size_t DivRoundUp(size_t a, size_t b) {
  return (a + b - 1) / b;
}
411
+
412
// Partitions the kernel's work: each warp handles one (path bin, row bank)
// pair, and each lane of the warp holds one path element of that bin.
// Outputs the lane's path element, the row range to process, and whether the
// lane has any work at all.
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
          typename SplitConditionT>
void __device__
ConfigureThread(const DatasetT& X, const size_t bins_per_row,
                const PathElement<SplitConditionT>* path_elements,
                const size_t* bin_segments, size_t* start_row, size_t* end_row,
                PathElement<SplitConditionT>* e, bool* thread_active) {
  // Partition work
  // Each warp processes a set of training instances applied to a path
  size_t tid = kBlockSize * blockIdx.x + threadIdx.x;
  const size_t warp_size = 32;
  size_t warp_rank = tid / warp_size;
  if (warp_rank >= bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp)) {
    *thread_active = false;
    return;
  }
  size_t bin_idx = warp_rank % bins_per_row;
  size_t bank = warp_rank / bins_per_row;
  // [path_start, path_end) is this bin's slice of path_elements.
  size_t path_start = bin_segments[bin_idx];
  size_t path_end = bin_segments[bin_idx + 1];
  uint32_t thread_rank = threadIdx.x % warp_size;
  if (thread_rank >= path_end - path_start) {
    // More lanes than path elements in this bin: the surplus lanes idle.
    *thread_active = false;
  } else {
    *e = path_elements[path_start + thread_rank];
    *start_row = bank * kRowsPerWarp;
    *end_row = min((bank + 1) * kRowsPerWarp, X.NumRows());
    *thread_active = true;
  }
}
442
+
443
+ #define GPUTREESHAP_MAX_THREADS_PER_BLOCK 256
444
+ #define FULL_MASK 0xffffffff
445
+
446
+ template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
447
+ typename SplitConditionT>
448
+ __global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
449
+ ShapKernel(DatasetT X, size_t bins_per_row,
450
+ const PathElement<SplitConditionT>* path_elements,
451
+ const size_t* bin_segments, size_t num_groups, double* phis) {
452
+ // Use shared memory for structs, otherwise nvcc puts in local memory
453
+ __shared__ DatasetT s_X;
454
+ s_X = X;
455
+ __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
456
+ PathElement<SplitConditionT>& e = s_elements[threadIdx.x];
457
+
458
+ size_t start_row, end_row;
459
+ bool thread_active;
460
+ ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
461
+ s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e,
462
+ &thread_active);
463
+ uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
464
+ if (!thread_active) return;
465
+
466
+ float zero_fraction = e.zero_fraction;
467
+ auto labelled_group = active_labeled_partition(mask, e.path_idx);
468
+
469
+ for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
470
+ float phi = ComputePhi(e, row_idx, X, labelled_group, zero_fraction);
471
+
472
+ if (!e.IsRoot()) {
473
+ atomicAddDouble(&phis[IndexPhi(row_idx, num_groups, e.group, X.NumCols(),
474
+ e.feature_idx)],
475
+ phi);
476
+ }
477
+ }
478
+ }
479
+
480
// Host-side launcher for ShapKernel: sizes the grid so every (bin, row-chunk)
// work item is covered by one warp. `phis` must point to device memory sized
// for every (row, group, feature + bias) entry.
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
          typename SplitConditionT>
void ComputeShap(
    DatasetT X,
    const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
    const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
        path_elements,
    size_t num_groups, double* phis) {
  // bin_segments holds num_bins + 1 offsets.
  size_t bins_per_row = bin_segments.size() - 1;
  const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
  const int warps_per_block = kBlockThreads / 32;
  // Each warp handles up to this many rows for one bin of paths.
  const int kRowsPerWarp = 1024;
  size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);

  const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);

  ShapKernel<DatasetT, kBlockThreads, kRowsPerWarp>
      <<<grid_size, kBlockThreads>>>(
          X, bins_per_row, path_elements.data().get(),
          bin_segments.data().get(), num_groups, phis);
}
501
+
502
// Compute a conditional phi: the contribution of element e with one feature
// (condition_feature) held fixed either on or off. The conditioned feature is
// excluded from path extension; its on/off fractions scale the result.
// PathT selects the path weighting (GroupPath or TaylorGroupPath).
template <typename PathT, typename DatasetT, typename SplitConditionT>
__device__ float ComputePhiCondition(const PathElement<SplitConditionT>& e,
                                     size_t row_idx, const DatasetT& X,
                                     const ContiguousGroup& group,
                                     int64_t condition_feature) {
  float one_fraction = e.EvaluateSplit(X, row_idx);
  PathT path(group, e.zero_fraction, one_fraction);
  size_t unique_path_length = group.size();
  float condition_on_fraction = 1.0f;
  float condition_off_fraction = 1.0f;

  // Extend the path
  for (auto i = 1ull; i < unique_path_length; i++) {
    // Broadcast lane i's feature to decide if it is the conditioned one.
    bool is_condition_feature =
        group.shfl(e.feature_idx, i) == condition_feature;
    float o_i = group.shfl(one_fraction, i);
    float z_i = group.shfl(e.zero_fraction, i);

    if (is_condition_feature) {
      // Record the conditioned feature's fractions instead of extending.
      condition_on_fraction = o_i;
      condition_off_fraction = z_i;
    } else {
      path.Extend();
    }
  }
  float sum = path.UnwoundPathSum();
  // The conditioned feature itself contributes nothing in this pass.
  if (e.feature_idx == condition_feature) {
    return 0.0f;
  }
  float phi = sum * (one_fraction - e.zero_fraction) * e.v;
  // 0.5 factor: the interaction is split symmetrically between both orderings.
  return phi * (condition_on_fraction - condition_off_fraction) * 0.5f;
}
534
+
535
// If there is a feature in the path we are conditioning on, swap it to the end
// of the path. Only the per-lane element pointers are exchanged (between the
// last lane and the conditioned lane); the shared-memory array itself is not
// modified.
template <typename SplitConditionT>
inline __device__ void SwapConditionedElement(
    PathElement<SplitConditionT>** e, PathElement<SplitConditionT>* s_elements,
    uint32_t condition_rank, const ContiguousGroup& group) {
  auto last_rank = group.size() - 1;
  auto this_rank = group.thread_rank();
  // (threadIdx.x - this_rank) is the block-local index of the group's rank-0
  // lane, so indexing from it addresses another rank's element.
  if (this_rank == last_rank) {
    *e = &s_elements[(threadIdx.x - this_rank) + condition_rank];
  } else if (this_rank == condition_rank) {
    *e = &s_elements[(threadIdx.x - this_rank) + last_rank];
  }
}
549
+
550
+ template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
551
+ typename SplitConditionT>
552
+ __global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
553
+ ShapInteractionsKernel(DatasetT X, size_t bins_per_row,
554
+ const PathElement<SplitConditionT>* path_elements,
555
+ const size_t* bin_segments, size_t num_groups,
556
+ double* phis_interactions) {
557
+ // Use shared memory for structs, otherwise nvcc puts in local memory
558
+ __shared__ DatasetT s_X;
559
+ s_X = X;
560
+ __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
561
+ PathElement<SplitConditionT>* e = &s_elements[threadIdx.x];
562
+
563
+ size_t start_row, end_row;
564
+ bool thread_active;
565
+ ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
566
+ s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e,
567
+ &thread_active);
568
+ uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
569
+ if (!thread_active) return;
570
+
571
+ auto labelled_group = active_labeled_partition(mask, e->path_idx);
572
+
573
+ for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
574
+ float phi = ComputePhi(*e, row_idx, X, labelled_group, e->zero_fraction);
575
+ if (!e->IsRoot()) {
576
+ auto phi_offset =
577
+ IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
578
+ e->feature_idx, e->feature_idx);
579
+ atomicAddDouble(phis_interactions + phi_offset, phi);
580
+ }
581
+
582
+ for (auto condition_rank = 1ull; condition_rank < labelled_group.size();
583
+ condition_rank++) {
584
+ e = &s_elements[threadIdx.x];
585
+ int64_t condition_feature =
586
+ labelled_group.shfl(e->feature_idx, condition_rank);
587
+ SwapConditionedElement(&e, s_elements, condition_rank, labelled_group);
588
+ float x = ComputePhiCondition<GroupPath>(*e, row_idx, X, labelled_group,
589
+ condition_feature);
590
+ if (!e->IsRoot()) {
591
+ auto phi_offset =
592
+ IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
593
+ e->feature_idx, condition_feature);
594
+ atomicAddDouble(phis_interactions + phi_offset, x);
595
+ // Subtract effect from diagonal
596
+ auto phi_diag =
597
+ IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
598
+ e->feature_idx, e->feature_idx);
599
+ atomicAddDouble(phis_interactions + phi_diag, -x);
600
+ }
601
+ }
602
+ }
603
+ }
604
+
605
// Host-side launcher for ShapInteractionsKernel. Same grid sizing scheme as
// ComputeShap, but with fewer rows per warp since each row does O(path^2) work.
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
          typename SplitConditionT>
void ComputeShapInteractions(
    DatasetT X,
    const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
    const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
        path_elements,
    size_t num_groups, double* phis) {
  size_t bins_per_row = bin_segments.size() - 1;
  const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
  const int warps_per_block = kBlockThreads / 32;
  // Smaller than ComputeShap's 1024: interaction work per row is heavier.
  const int kRowsPerWarp = 100;
  size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);

  const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);

  ShapInteractionsKernel<DatasetT, kBlockThreads, kRowsPerWarp>
      <<<grid_size, kBlockThreads>>>(
          X, bins_per_row, path_elements.data().get(),
          bin_segments.data().get(), num_groups, phis);
}
626
+
627
// Shapley-Taylor interaction kernel: like ShapInteractionsKernel but using
// TaylorGroupPath weighting, and with the diagonal computed directly from a
// warp-wide product of zero fractions rather than by subtraction.
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
          typename SplitConditionT>
__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
    ShapTaylorInteractionsKernel(
        DatasetT X, size_t bins_per_row,
        const PathElement<SplitConditionT>* path_elements,
        const size_t* bin_segments, size_t num_groups,
        double* phis_interactions) {
  // Use shared memory for structs, otherwise nvcc puts in local memory
  __shared__ DatasetT s_X;
  if (threadIdx.x == 0) {
    s_X = X;
  }
  __syncthreads();
  __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
  PathElement<SplitConditionT>* e = &s_elements[threadIdx.x];

  size_t start_row, end_row;
  bool thread_active;
  ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
      s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e,
      &thread_active);
  uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
  if (!thread_active) return;

  auto labelled_group = active_labeled_partition(mask, e->path_idx);

  for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
    for (auto condition_rank = 1ull; condition_rank < labelled_group.size();
         condition_rank++) {
      // Reset the element pointer; SwapConditionedElement may redirect it.
      e = &s_elements[threadIdx.x];
      // Compute the diagonal terms
      // TODO(Rory): this can be more efficient
      // Product over all other elements' zero fractions (root and the
      // conditioned element contribute the neutral factor 1).
      float reduce_input =
          e->IsRoot() || labelled_group.thread_rank() == condition_rank
              ? 1.0f
              : e->zero_fraction;
      float reduce =
          labelled_group.reduce(reduce_input, thrust::multiplies<float>());
      if (labelled_group.thread_rank() == condition_rank) {
        float one_fraction = e->split_condition.EvaluateSplit(
            X.GetElement(row_idx, e->feature_idx));
        auto phi_offset =
            IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
                                 e->feature_idx, e->feature_idx);
        atomicAddDouble(phis_interactions + phi_offset,
                        reduce * (one_fraction - e->zero_fraction) * e->v);
      }

      // Broadcast the conditioned feature from the conditioned lane.
      int64_t condition_feature =
          labelled_group.shfl(e->feature_idx, condition_rank);

      SwapConditionedElement(&e, s_elements, condition_rank, labelled_group);

      float x = ComputePhiCondition<TaylorGroupPath>(
          *e, row_idx, X, labelled_group, condition_feature);
      if (!e->IsRoot()) {
        auto phi_offset =
            IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
                                 e->feature_idx, condition_feature);
        atomicAddDouble(phis_interactions + phi_offset, x);
      }
    }
  }
}
692
+
693
// Host-side launcher for ShapTaylorInteractionsKernel; grid sizing mirrors
// ComputeShapInteractions.
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
          typename SplitConditionT>
void ComputeShapTaylorInteractions(
    DatasetT X,
    const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
    const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
        path_elements,
    size_t num_groups, double* phis) {
  size_t bins_per_row = bin_segments.size() - 1;
  const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
  const int warps_per_block = kBlockThreads / 32;
  // Heavier per-row work than the plain SHAP kernel, hence fewer rows.
  const int kRowsPerWarp = 100;
  size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);

  const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);

  ShapTaylorInteractionsKernel<DatasetT, kBlockThreads, kRowsPerWarp>
      <<<grid_size, kBlockThreads>>>(
          X, bins_per_row, path_elements.data().get(),
          bin_segments.data().get(), num_groups, phis);
}
714
+
715
+
716
+ inline __host__ __device__ int64_t Factorial(int64_t x) {
717
+ int64_t y = 1;
718
+ for (auto i = 2; i <= x; i++) {
719
+ y *= i;
720
+ }
721
+ return y;
722
+ }
723
+
724
// Compute factorials in log space using lgamma to avoid overflow
// Shapley permutation weight s! * (n - s - 1)! / n!, evaluated as
// exp(lgamma(s + 1) - lgamma(n + 1) + lgamma(n - s)).
inline __host__ __device__ double W(double s, double n) {
  // Caller must guarantee s <= n - 1 (see the guard in the kernel's cache fill).
  assert(n - s - 1 >= 0);
  return exp(lgamma(s + 1) - lgamma(n + 1) + lgamma(n - s));
}
729
+
730
// Interventional SHAP kernel: computes contributions against a background
// (reference) dataset R, averaging over all background rows. Warp ballots
// encode, per path, which elements each foreground/background row satisfies.
template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
          typename SplitConditionT>
__global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
    ShapInterventionalKernel(DatasetT X, DatasetT R, size_t bins_per_row,
                             const PathElement<SplitConditionT>* path_elements,
                             const size_t* bin_segments, size_t num_groups,
                             double* phis) {
  // Cache W coefficients
  // 33x33 covers all (s, n) pairs for paths of depth < 32; entries with
  // n - s - 1 < 0 are invalid and zeroed.
  __shared__ float s_W[33][33];
  for (int i = threadIdx.x; i < 33 * 33; i += kBlockSize) {
    auto s = i % 33;
    auto n = i / 33;
    if (n - s - 1 >= 0) {
      s_W[s][n] = W(s, n);
    } else {
      s_W[s][n] = 0.0;
    }
  }

  __syncthreads();

  __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
  PathElement<SplitConditionT>& e = s_elements[threadIdx.x];

  size_t start_row, end_row;
  bool thread_active;
  ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
      X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e,
      &thread_active);

  uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
  if (!thread_active) return;

  auto labelled_group = active_labeled_partition(mask, e.path_idx);

  for (int64_t x_idx = start_row; x_idx < end_row; x_idx++) {
    float result = 0.0f;
    bool x_cond = e.EvaluateSplit(X, x_idx);
    uint32_t x_ballot = labelled_group.ballot(x_cond);
    for (int64_t r_idx = 0; r_idx < R.NumRows(); r_idx++) {
      bool r_cond = e.EvaluateSplit(R, r_idx);
      uint32_t r_ballot = labelled_group.ballot(r_cond);
      assert(!e.IsRoot() ||
             (x_cond == r_cond));  // These should be the same for the root
      // s: elements only the foreground row satisfies; n: elements where the
      // two rows disagree.
      uint32_t s = __popc(x_ballot & ~r_ballot);
      uint32_t n = __popc(x_ballot ^ r_ballot);
      float tmp = 0.0f;
      // Theorem 1
      // (x_cond && !r_cond implies s >= 1, so s - 1 cannot underflow here.)
      if (x_cond && !r_cond) {
        tmp += s_W[s - 1][n];
      }
      tmp -= s_W[s][n] * (r_cond && !x_cond);

      // No foreground samples make it to this leaf, increment bias
      if (e.IsRoot() && s == 0) {
        tmp += 1.0f;
      }
      // If neither foreground or background go down this path, ignore this path
      bool reached_leaf = !labelled_group.ballot(!x_cond && !r_cond);
      tmp *= reached_leaf;
      result += tmp;
    }

    if (result != 0.0) {
      // Average over the background rows.
      result /= R.NumRows();

      // Root writes bias
      auto feature = e.IsRoot() ? X.NumCols() : e.feature_idx;
      atomicAddDouble(
          &phis[IndexPhi(x_idx, num_groups, e.group, X.NumCols(), feature)],
          result * e.v);
    }
  }
}
804
+
805
// Host-side launcher for ShapInterventionalKernel; R is the background
// (reference) dataset averaged over inside the kernel.
template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
          typename SplitConditionT>
void ComputeShapInterventional(
    DatasetT X, DatasetT R,
    const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
    const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
        path_elements,
    size_t num_groups, double* phis) {
  size_t bins_per_row = bin_segments.size() - 1;
  const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
  const int warps_per_block = kBlockThreads / 32;
  // Each foreground row loops over all of R, so keep row chunks small.
  const int kRowsPerWarp = 100;
  size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);

  const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);

  ShapInterventionalKernel<DatasetT, kBlockThreads, kRowsPerWarp>
      <<<grid_size, kBlockThreads>>>(
          X, R, bins_per_row, path_elements.data().get(),
          bin_segments.data().get(), num_groups, phis);
}
826
+
827
// Build CSR-style segment offsets for the bins: bin_segments[i]..[i+1] bounds
// the elements of bin i once paths are sorted by bin. Implemented as an
// atomic histogram of per-bin counts followed by an exclusive scan.
template <typename PathVectorT, typename SizeVectorT, typename DeviceAllocatorT>
void GetBinSegments(const PathVectorT& paths, const SizeVectorT& bin_map,
                    SizeVectorT* bin_segments) {
  DeviceAllocatorT alloc;
  // Number of bins = max bin id + 1.
  size_t num_bins =
      thrust::reduce(thrust::cuda::par(alloc), bin_map.begin(), bin_map.end(),
                     size_t(0), thrust::maximum<size_t>()) +
      1;
  bin_segments->resize(num_bins + 1, 0);
  auto counting = thrust::make_counting_iterator(0llu);
  auto d_paths = paths.data().get();
  auto d_bin_segments = bin_segments->data().get();
  auto d_bin_map = bin_map.data();
  // Histogram: count how many path elements fall in each bin.
  thrust::for_each_n(counting, paths.size(), [=] __device__(size_t idx) {
    auto path_idx = d_paths[idx].path_idx;
    atomicAdd(reinterpret_cast<unsigned long long*>(d_bin_segments) +  // NOLINT
                  d_bin_map[path_idx],
              1);
  });
  // Counts -> starting offsets.
  thrust::exclusive_scan(thrust::cuda::par(alloc), bin_segments->begin(),
                         bin_segments->end(), bin_segments->begin());
}
849
+
850
// Key functor for the ReduceByKey in DeduplicatePaths: elements sharing the
// same (path, feature) pair are duplicates and get merged.
struct DeduplicateKeyTransformOp {
  template <typename SplitConditionT>
  __device__ thrust::pair<size_t, int64_t> operator()(
      const PathElement<SplitConditionT>& e) {
    return {e.path_idx, e.feature_idx};
  }
};
857
+
858
// Convert a failing CUDA status code into a C++ exception.
inline void CheckCuda(cudaError_t err) {
  if (err != cudaSuccess) {
    throw thrust::system_error(err, thrust::cuda_category());
  }
}
863
+
864
// discard_iterator that also advertises a concrete value_type.
// NOTE(review): apparently needed because cub::DeviceReduce::ReduceByKey
// inspects the key output iterator's value_type — confirm against cub docs.
template <typename Return>
class DiscardOverload : public thrust::discard_iterator<Return> {
 public:
  using value_type = Return;  // NOLINT
};
869
+
870
// Merge duplicate (path, feature) elements: sorts so duplicates are adjacent,
// then reduces each run into one element whose split condition is the merge of
// the duplicates and whose zero fraction is their product. Output goes to
// *deduplicated_paths; *device_paths is left sorted (its contents reordered).
template <typename PathVectorT, typename DeviceAllocatorT,
          typename SplitConditionT>
void DeduplicatePaths(PathVectorT* device_paths,
                      PathVectorT* deduplicated_paths) {
  DeviceAllocatorT alloc;
  // Sort by feature
  thrust::sort(thrust::cuda::par(alloc), device_paths->begin(),
               device_paths->end(),
               [=] __device__(const PathElement<SplitConditionT>& a,
                              const PathElement<SplitConditionT>& b) {
                 if (a.path_idx < b.path_idx) return true;
                 if (b.path_idx < a.path_idx) return false;

                 if (a.feature_idx < b.feature_idx) return true;
                 if (b.feature_idx < a.feature_idx) return false;
                 return false;
               });

  // Upper bound; shrunk to the real size after the reduction.
  deduplicated_paths->resize(device_paths->size());

  using Pair = thrust::pair<size_t, int64_t>;
  auto key_transform = thrust::make_transform_iterator(
      device_paths->begin(), DeduplicateKeyTransformOp());

  thrust::device_vector<size_t> d_num_runs_out(1);
  // Pinned host memory for the device->host copy of the run count.
  size_t* h_num_runs_out;
  CheckCuda(cudaMallocHost(&h_num_runs_out, sizeof(size_t)));

  auto combine = [] __device__(PathElement<SplitConditionT> a,
                               PathElement<SplitConditionT> b) {
    // Combine duplicate features
    a.split_condition.Merge(b.split_condition);
    a.zero_fraction *= b.zero_fraction;
    return a;
  };  // NOLINT
  // First call with nullptr only queries the temporary storage size.
  size_t temp_size = 0;
  CheckCuda(cub::DeviceReduce::ReduceByKey(
      nullptr, temp_size, key_transform, DiscardOverload<Pair>(),
      device_paths->begin(), deduplicated_paths->begin(),
      d_num_runs_out.begin(), combine, device_paths->size()));
  using TempAlloc = RebindVector<char, DeviceAllocatorT>;
  TempAlloc tmp(temp_size);
  // Second call performs the actual reduction.
  CheckCuda(cub::DeviceReduce::ReduceByKey(
      tmp.data().get(), temp_size, key_transform, DiscardOverload<Pair>(),
      device_paths->begin(), deduplicated_paths->begin(),
      d_num_runs_out.begin(), combine, device_paths->size()));

  CheckCuda(cudaMemcpy(h_num_runs_out, d_num_runs_out.data().get(),
                       sizeof(size_t), cudaMemcpyDeviceToHost));
  deduplicated_paths->resize(*h_num_runs_out);
  CheckCuda(cudaFreeHost(h_num_runs_out));
}
922
+
923
// Sort path elements lexicographically by (bin, path, feature) so that each
// bin's elements are contiguous, matching the segments from GetBinSegments.
template <typename PathVectorT, typename SplitConditionT, typename SizeVectorT,
          typename DeviceAllocatorT>
void SortPaths(PathVectorT* paths, const SizeVectorT& bin_map) {
  auto d_bin_map = bin_map.data();
  DeviceAllocatorT alloc;
  thrust::sort(thrust::cuda::par(alloc), paths->begin(), paths->end(),
               [=] __device__(const PathElement<SplitConditionT>& a,
                              const PathElement<SplitConditionT>& b) {
                 size_t a_bin = d_bin_map[a.path_idx];
                 size_t b_bin = d_bin_map[b.path_idx];
                 if (a_bin < b_bin) return true;
                 if (b_bin < a_bin) return false;

                 if (a.path_idx < b.path_idx) return true;
                 if (b.path_idx < a.path_idx) return false;

                 if (a.feature_idx < b.feature_idx) return true;
                 if (b.feature_idx < a.feature_idx) return false;
                 return false;
               });
}
944
+
945
+ using kv = std::pair<size_t, int>;
946
+
947
+ struct BFDCompare {
948
+ bool operator()(const kv& lhs, const kv& rhs) const {
949
+ if (lhs.second == rhs.second) {
950
+ return lhs.first < rhs.first;
951
+ }
952
+ return lhs.second < rhs.second;
953
+ }
954
+ };
955
+
956
// Best Fit Decreasing bin packing
// Efficient O(nlogn) implementation with balanced tree using std::set
// Packs paths (by length) into warp-sized bins; returns, for each path index,
// the bin it was assigned to. Items longer than bin_limit each get their own
// (over-full) bin.
template <typename IntVectorT>
std::vector<size_t> BFDBinPacking(const IntVectorT& counts,
                                  int bin_limit = 32) {
  thrust::host_vector<int> counts_host(counts);
  std::vector<kv> path_lengths(counts_host.size());
  for (auto i = 0ull; i < counts_host.size(); i++) {
    path_lengths[i] = {i, counts_host[i]};
  }

  // Process items longest-first (the "Decreasing" in BFD).
  std::sort(path_lengths.begin(), path_lengths.end(),
            [&](const kv& a, const kv& b) {
              std::greater<> op;
              return op(a.second, b.second);
            });

  // map unique_id -> bin
  std::vector<size_t> bin_map(counts_host.size());
  // Set of (bin index, remaining capacity), ordered by capacity (BFDCompare).
  std::set<kv, BFDCompare> bin_capacities;
  bin_capacities.insert({bin_capacities.size(), bin_limit});
  for (auto pair : path_lengths) {
    int new_size = pair.second;
    // Smallest bin whose remaining capacity is >= new_size ("best fit").
    auto itr = bin_capacities.lower_bound({0, new_size});
    // Does not fit in any bin
    if (itr == bin_capacities.end()) {
      size_t new_bin_idx = bin_capacities.size();
      bin_capacities.insert({new_bin_idx, bin_limit - new_size});
      bin_map[pair.first] = new_bin_idx;
    } else {
      // std::set keys are immutable: remove, adjust capacity, re-insert.
      kv entry = *itr;
      entry.second -= new_size;
      bin_map[pair.first] = entry.first;
      bin_capacities.erase(itr);
      bin_capacities.insert(entry);
    }
  }

  return bin_map;
}
996
+
997
+ // First Fit Decreasing bin packing
998
+ // Inefficient O(n^2) implementation
999
+ template <typename IntVectorT>
1000
+ std::vector<size_t> FFDBinPacking(const IntVectorT& counts,
1001
+ int bin_limit = 32) {
1002
+ thrust::host_vector<int> counts_host(counts);
1003
+ std::vector<kv> path_lengths(counts_host.size());
1004
+ for (auto i = 0ull; i < counts_host.size(); i++) {
1005
+ path_lengths[i] = {i, counts_host[i]};
1006
+ }
1007
+ std::sort(path_lengths.begin(), path_lengths.end(),
1008
+ [&](const kv& a, const kv& b) {
1009
+ std::greater<> op;
1010
+ return op(a.second, b.second);
1011
+ });
1012
+
1013
+ // map unique_id -> bin
1014
+ std::vector<size_t> bin_map(counts_host.size());
1015
+ std::vector<int> bin_capacities(path_lengths.size(), bin_limit);
1016
+ for (auto pair : path_lengths) {
1017
+ int new_size = pair.second;
1018
+ for (auto j = 0ull; j < bin_capacities.size(); j++) {
1019
+ int& capacity = bin_capacities[j];
1020
+
1021
+ if (capacity >= new_size) {
1022
+ capacity -= new_size;
1023
+ bin_map[pair.first] = j;
1024
+ break;
1025
+ }
1026
+ }
1027
+ }
1028
+
1029
+ return bin_map;
1030
+ }
1031
+
1032
+ // Next Fit bin packing
1033
+ // O(n) implementation
1034
+ template <typename IntVectorT>
1035
+ std::vector<size_t> NFBinPacking(const IntVectorT& counts, int bin_limit = 32) {
1036
+ thrust::host_vector<int> counts_host(counts);
1037
+ std::vector<size_t> bin_map(counts_host.size());
1038
+ size_t current_bin = 0;
1039
+ int current_capacity = bin_limit;
1040
+ for (auto i = 0ull; i < counts_host.size(); i++) {
1041
+ int new_size = counts_host[i];
1042
+ size_t path_idx = i;
1043
+ if (new_size <= current_capacity) {
1044
+ current_capacity -= new_size;
1045
+ bin_map[path_idx] = current_bin;
1046
+ } else {
1047
+ current_capacity = bin_limit - new_size;
1048
+ bin_map[path_idx] = ++current_bin;
1049
+ }
1050
+ }
1051
+ return bin_map;
1052
+ }
1053
+
1054
// Count the number of elements in each unique path via an atomic histogram.
// Assumes device_paths is non-empty and its last element carries the largest
// path_idx (i.e. the vector is sorted by path_idx, as after DeduplicatePaths).
template <typename DeviceAllocatorT, typename SplitConditionT,
          typename PathVectorT, typename LengthVectorT>
void GetPathLengths(const PathVectorT& device_paths,
                    LengthVectorT* path_lengths) {
  // One counter per path id, zero-initialised.
  path_lengths->resize(
      static_cast<PathElement<SplitConditionT>>(device_paths.back()).path_idx +
          1,
      0);
  auto counting = thrust::make_counting_iterator(0llu);
  auto d_paths = device_paths.data().get();
  auto d_lengths = path_lengths->data().get();
  thrust::for_each_n(counting, device_paths.size(), [=] __device__(size_t idx) {
    auto path_idx = d_paths[idx].path_idx;
    atomicAdd(d_lengths + path_idx, 1ull);
  });
}
1070
+
1071
// Predicate: 1 if a path has more than 32 elements (too long for one warp).
struct PathTooLongOp {
  __device__ size_t operator()(size_t length) { return length > 32; }
};
1074
+
1075
// Predicate over adjacent sorted elements: 1 if two elements of the same path
// disagree on the leaf value v, which is invalid input.
template <typename SplitConditionT>
struct IncorrectVOp {
  const PathElement<SplitConditionT>* paths;
  __device__ size_t operator()(size_t idx) {
    // Caller starts the counting iterator at 1, so idx - 1 is always valid.
    auto a = paths[idx - 1];
    auto b = paths[idx];
    return a.path_idx == b.path_idx && a.v != b.v;
  }
};
1084
+
1085
// Validate user-supplied paths: every path must fit in a warp (< 32 deep after
// deduplication) and all elements of a path must share the same leaf value v.
// Throws std::invalid_argument on violation.
template <typename DeviceAllocatorT, typename SplitConditionT,
          typename PathVectorT, typename LengthVectorT>
void ValidatePaths(const PathVectorT& device_paths,
                   const LengthVectorT& path_lengths) {
  DeviceAllocatorT alloc;
  PathTooLongOp too_long_op;
  auto invalid_length =
      thrust::any_of(thrust::cuda::par(alloc), path_lengths.begin(),
                     path_lengths.end(), too_long_op);

  if (invalid_length) {
    throw std::invalid_argument("Tree depth must be < 32");
  }

  // Compare each element with its predecessor; start at index 1.
  IncorrectVOp<SplitConditionT> incorrect_v_op{device_paths.data().get()};
  auto counting = thrust::counting_iterator<size_t>(0);
  auto incorrect_v =
      thrust::any_of(thrust::cuda::par(alloc), counting + 1,
                     counting + device_paths.size(), incorrect_v_op);

  if (incorrect_v) {
    throw std::invalid_argument(
        "Leaf value v should be the same across a single path");
  }
}
1110
+
1111
// Full path preprocessing pipeline: deduplicate (path, feature) elements,
// measure path lengths, bin-pack paths into warp-sized bins, validate, then
// sort elements by bin and compute per-bin segment offsets for the kernels.
template <typename DeviceAllocatorT, typename SplitConditionT,
          typename PathVectorT, typename SizeVectorT>
void PreprocessPaths(PathVectorT* device_paths, PathVectorT* deduplicated_paths,
                     SizeVectorT* bin_segments) {
  // Sort paths by length and feature
  detail::DeduplicatePaths<PathVectorT, DeviceAllocatorT, SplitConditionT>(
      device_paths, deduplicated_paths);
  using int_vector = RebindVector<int, DeviceAllocatorT>;
  int_vector path_lengths;
  detail::GetPathLengths<DeviceAllocatorT, SplitConditionT>(*deduplicated_paths,
                                                            &path_lengths);
  // Best Fit Decreasing packing keeps warp occupancy high.
  SizeVectorT device_bin_map = detail::BFDBinPacking(path_lengths);
  ValidatePaths<DeviceAllocatorT, SplitConditionT>(*deduplicated_paths,
                                                   path_lengths);
  detail::SortPaths<PathVectorT, SplitConditionT, SizeVectorT,
                    DeviceAllocatorT>(deduplicated_paths, device_bin_map);
  detail::GetBinSegments<PathVectorT, SizeVectorT, DeviceAllocatorT>(
      *deduplicated_paths, device_bin_map, bin_segments);
}
1130
+
1131
// Key functor: extract the path index, for reductions keyed by path.
struct PathIdxTransformOp {
  template <typename SplitConditionT>
  __device__ size_t operator()(const PathElement<SplitConditionT>& e) {
    return e.path_idx;
  }
};
1137
+
1138
// Key functor: extract the output group, for reductions keyed by group.
struct GroupIdxTransformOp {
  template <typename SplitConditionT>
  __device__ size_t operator()(const PathElement<SplitConditionT>& e) {
    return e.group;
  }
};
1144
+
1145
// Value functor: a path's bias term is its combined zero fraction times its
// leaf value (computed in double for accumulation accuracy).
struct BiasTransformOp {
  template <typename SplitConditionT>
  __device__ double operator()(const PathElement<SplitConditionT>& e) {
    return e.zero_fraction * e.v;
  }
};
1151
+
1152
// While it is possible to compute bias in the primary kernel, we do it here
// using double precision to avoid numerical stability issues.
// Pipeline: sort a copy of the paths by (group, path), collapse each path to
// one element carrying the product of its zero fractions, then sum
// zero_fraction * v per group and scatter the sums into *bias (one per group).
template <typename PathVectorT, typename DoubleVectorT,
          typename DeviceAllocatorT, typename SplitConditionT>
void ComputeBias(const PathVectorT& device_paths, DoubleVectorT* bias) {
  using double_vector = thrust::device_vector<
      double, typename DeviceAllocatorT::template rebind<double>::other>;
  // Work on a copy; the caller's path vector is left untouched.
  PathVectorT sorted_paths(device_paths);
  DeviceAllocatorT alloc;
  // Make sure groups are contiguous
  thrust::sort(thrust::cuda::par(alloc), sorted_paths.begin(),
               sorted_paths.end(),
               [=] __device__(const PathElement<SplitConditionT>& a,
                              const PathElement<SplitConditionT>& b) {
                 if (a.group < b.group) return true;
                 if (b.group < a.group) return false;

                 if (a.path_idx < b.path_idx) return true;
                 if (b.path_idx < a.path_idx) return false;

                 return false;
               });
  // Combine zero fraction for all paths
  auto path_key = thrust::make_transform_iterator(sorted_paths.begin(),
                                                  PathIdxTransformOp());
  PathVectorT combined(sorted_paths.size());
  auto combined_out = thrust::reduce_by_key(
      thrust::cuda::par(alloc), path_key, path_key + sorted_paths.size(),
      sorted_paths.begin(), thrust::make_discard_iterator(), combined.begin(),
      thrust::equal_to<size_t>(),
      [=] __device__(PathElement<SplitConditionT> a,
                     const PathElement<SplitConditionT>& b) {
        a.zero_fraction *= b.zero_fraction;
        return a;
      });
  size_t num_paths = combined_out.second - combined.begin();
  // Combine bias for each path, over each group
  using size_vector = thrust::device_vector<
      size_t, typename DeviceAllocatorT::template rebind<size_t>::other>;
  size_vector keys_out(num_paths);
  double_vector values_out(num_paths);
  auto group_key =
      thrust::make_transform_iterator(combined.begin(), GroupIdxTransformOp());
  auto values =
      thrust::make_transform_iterator(combined.begin(), BiasTransformOp());

  auto out_itr = thrust::reduce_by_key(thrust::cuda::par(alloc), group_key,
                                       group_key + num_paths, values,
                                       keys_out.begin(), values_out.begin());

  // Write result
  // Scatter the per-group sums into the caller's bias array by group id.
  size_t n = out_itr.first - keys_out.begin();
  auto counting = thrust::make_counting_iterator(0llu);
  auto d_keys_out = keys_out.data().get();
  auto d_values_out = values_out.data().get();
  auto d_bias = bias->data().get();
  thrust::for_each_n(counting, n, [=] __device__(size_t idx) {
    d_bias[d_keys_out[idx]] = d_values_out[idx];
  });
}
1212
+
1213
+ }; // namespace detail
1214
+
1215
/*!
 * Compute feature contributions on the GPU given a set of unique paths through
 * a tree ensemble and a dataset. Uses device memory proportional to the tree
 * ensemble size.
 *
 * \exception std::invalid_argument Thrown when an invalid argument error
 * condition occurs. \tparam PathIteratorT Thrust type iterator, may be
 * thrust::device_ptr for device memory, or stl iterator/raw pointer for host
 * memory. \tparam PhiIteratorT Thrust type iterator, may be
 * thrust::device_ptr for device memory, or stl iterator/raw pointer for host
 * memory. Value type must be floating point. \tparam DatasetT User-specified
 * dataset container. \tparam DeviceAllocatorT Optional thrust style
 * allocator.
 *
 * \param X          Thin wrapper over a dataset allocated in device memory. X
 * should be trivially copyable as a kernel parameter (i.e. contain only
 * pointers to actual data) and must implement the methods
 * NumRows()/NumCols()/GetElement(size_t row_idx, size_t col_idx) as __device__
 * functions. GetElement may return NaN where the feature value is missing.
 * \param begin      Iterator to paths, where separate paths are delineated by
 *                   PathElement.path_idx. Each unique path should contain 1
 * root with feature_idx = -1 and zero_fraction = 1.0. The ordering of path
 * elements inside a unique path does not matter - the result will be the same.
 * Paths may contain duplicate features. See the PathElement class for more
 * information. \param end        Path end iterator. \param num_groups Number
 * of output groups. In multiclass classification the algorithm outputs feature
 * contributions per output class. \param phis_begin Begin iterator for output
 * phis. \param phis_end   End iterator for output phis.
 */
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
          typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShap(DatasetT X, PathIteratorT begin, PathIteratorT end,
                 size_t num_groups, PhiIteratorT phis_begin,
                 PhiIteratorT phis_end) {
  // Enforce the documented "value type must be floating point" contract at
  // compile time, consistent with GPUTreeShapTaylorInteractions; integral
  // phis would silently truncate the accumulated contributions.
  using phis_type = typename std::iterator_traits<PhiIteratorT>::value_type;
  static_assert(std::is_floating_point<phis_type>::value,
                "Phis type must be floating point");

  // Nothing to do for an empty dataset or an empty path set.
  if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;

  // Output layout is one phi per (row, feature + bias slot, group).
  if (size_t(phis_end - phis_begin) <
      X.NumRows() * (X.NumCols() + 1) * num_groups) {
    throw std::invalid_argument(
        "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
        "num_groups");
  }

  using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
  using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
  using path_vector = detail::RebindVector<
      typename std::iterator_traits<PathIteratorT>::value_type,
      DeviceAllocatorT>;
  using split_condition =
      typename std::iterator_traits<PathIteratorT>::value_type::split_type;

  // Accumulate in double precision on the device, then copy to phis_begin.
  double_vector temp_phi(phis_end - phis_begin, 0.0);
  path_vector device_paths(begin, end);

  // Compute the global bias and scatter it into every row's bias slot.
  double_vector bias(num_groups, 0.0);
  detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT,
                      split_condition>(device_paths, &bias);
  auto d_bias = bias.data().get();
  auto d_temp_phi = temp_phi.data().get();
  thrust::for_each_n(thrust::make_counting_iterator(0llu),
                     X.NumRows() * num_groups, [=] __device__(size_t idx) {
                       size_t group = idx % num_groups;
                       size_t row_idx = idx / num_groups;
                       d_temp_phi[IndexPhi(row_idx, num_groups, group,
                                           X.NumCols(), X.NumCols())] +=
                           d_bias[group];
                     });

  // Deduplicate features within paths and bin paths for the kernel launch.
  path_vector deduplicated_paths;
  size_vector device_bin_segments;
  detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
      &device_paths, &deduplicated_paths, &device_bin_segments);

  detail::ComputeShap(X, device_bin_segments, deduplicated_paths, num_groups,
                      temp_phi.data().get());
  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
}
1292
+
1293
/*!
 * Compute feature interaction contributions on the GPU given a set of unique
 * paths through a tree ensemble and a dataset. Uses device memory
 * proportional to the tree ensemble size.
 *
 * \exception std::invalid_argument Thrown when an invalid argument error
 * condition occurs.
 * \tparam  DeviceAllocatorT  Optional thrust style allocator.
 * \tparam  DatasetT          User-specified dataset container.
 * \tparam  PathIteratorT     Thrust type iterator, may be thrust::device_ptr
 *                            for device memory, or stl iterator/raw pointer
 *                            for host memory.
 * \tparam  PhiIteratorT      Thrust type iterator, may be thrust::device_ptr
 *                            for device memory, or stl iterator/raw pointer
 *                            for host memory. Value type must be floating
 *                            point.
 *
 * \param X          Thin wrapper over a dataset allocated in device memory. X
 *                   should be trivially copyable as a kernel parameter (i.e.
 *                   contain only pointers to actual data) and must implement
 *                   the methods NumRows()/NumCols()/GetElement(size_t row_idx,
 *                   size_t col_idx) as __device__ functions. GetElement may
 *                   return NaN where the feature value is missing.
 * \param begin      Iterator to paths, where separate paths are delineated by
 *                   PathElement.path_idx. Each unique path should contain 1
 *                   root with feature_idx = -1 and zero_fraction = 1.0. The
 *                   ordering of path elements inside a unique path does not
 *                   matter - the result will be the same. Paths may contain
 *                   duplicate features. See the PathElement class for more
 *                   information.
 * \param end        Path end iterator.
 * \param num_groups Number of output groups. In multiclass classification the
 *                   algorithm outputs feature contributions per output class.
 * \param phis_begin Begin iterator for output phis.
 * \param phis_end   End iterator for output phis.
 */
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
          typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShapInteractions(DatasetT X, PathIteratorT begin, PathIteratorT end,
                             size_t num_groups, PhiIteratorT phis_begin,
                             PhiIteratorT phis_end) {
  // Enforce the documented "value type must be floating point" contract at
  // compile time, consistent with GPUTreeShapTaylorInteractions.
  using phis_type = typename std::iterator_traits<PhiIteratorT>::value_type;
  static_assert(std::is_floating_point<phis_type>::value,
                "Phis type must be floating point");

  if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;

  // Interactions use a square (feature + bias) x (feature + bias) layout.
  if (size_t(phis_end - phis_begin) <
      X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) {
    throw std::invalid_argument(
        "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
        "(X.NumCols() + 1) * "
        "num_groups");
  }

  using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
  using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
  using path_vector = detail::RebindVector<
      typename std::iterator_traits<PathIteratorT>::value_type,
      DeviceAllocatorT>;
  using split_condition =
      typename std::iterator_traits<PathIteratorT>::value_type::split_type;

  // Accumulate in double precision on the device, then copy to phis_begin.
  double_vector temp_phi(phis_end - phis_begin, 0.0);
  path_vector device_paths(begin, end);

  // Compute the global bias and add it to each row's (bias, bias) cell.
  double_vector bias(num_groups, 0.0);
  detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT,
                      split_condition>(device_paths, &bias);
  auto d_bias = bias.data().get();
  auto d_temp_phi = temp_phi.data().get();
  thrust::for_each_n(
      thrust::make_counting_iterator(0llu), X.NumRows() * num_groups,
      [=] __device__(size_t idx) {
        size_t group = idx % num_groups;
        size_t row_idx = idx / num_groups;
        d_temp_phi[IndexPhiInteractions(row_idx, num_groups, group, X.NumCols(),
                                        X.NumCols(), X.NumCols())] +=
            d_bias[group];
      });

  // Deduplicate features within paths and bin paths for the kernel launch.
  path_vector deduplicated_paths;
  size_vector device_bin_segments;
  detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
      &device_paths, &deduplicated_paths, &device_bin_segments);

  detail::ComputeShapInteractions(X, device_bin_segments, deduplicated_paths,
                                  num_groups, temp_phi.data().get());
  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
}
1377
+
1378
/*!
 * Compute feature interaction contributions using the Shapley Taylor index on
 * the GPU, given a set of unique paths through a tree ensemble and a dataset.
 * Uses device memory proportional to the tree ensemble size.
 *
 * \exception std::invalid_argument Thrown when an invalid argument error
 * condition occurs.
 * \tparam  PhiIteratorT   Thrust type iterator, may be thrust::device_ptr
 *                         for device memory, or stl iterator/raw pointer for
 *                         host memory. Value type must be floating point.
 * \tparam  PathIteratorT  Thrust type iterator, may be thrust::device_ptr
 *                         for device memory, or stl iterator/raw pointer for
 *                         host memory.
 * \tparam  DatasetT          User-specified dataset container.
 * \tparam  DeviceAllocatorT  Optional thrust style allocator.
 *
 * \param X          Thin wrapper over a dataset allocated in device memory. X
 *                   should be trivially copyable as a kernel parameter (i.e.
 *                   contain only pointers to actual data) and must implement
 *                   the methods NumRows()/NumCols()/GetElement(size_t row_idx,
 *                   size_t col_idx) as __device__ functions. GetElement may
 *                   return NaN where the feature value is missing.
 * \param begin      Iterator to paths, where separate paths are delineated by
 *                   PathElement.path_idx. Each unique path should contain 1
 *                   root with feature_idx = -1 and zero_fraction = 1.0. The
 *                   ordering of path elements inside a unique path does not
 *                   matter - the result will be the same. Paths may contain
 *                   duplicate features. See the PathElement class for more
 *                   information.
 * \param end        Path end iterator.
 * \param num_groups Number of output groups. In multiclass classification the
 *                   algorithm outputs feature contributions per output class.
 * \param phis_begin Begin iterator for output phis.
 * \param phis_end   End iterator for output phis.
 */
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
          typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShapTaylorInteractions(DatasetT X, PathIteratorT begin,
                                   PathIteratorT end, size_t num_groups,
                                   PhiIteratorT phis_begin,
                                   PhiIteratorT phis_end) {
  // Integral phi types would silently truncate accumulated contributions.
  using phis_type = typename std::iterator_traits<PhiIteratorT>::value_type;
  static_assert(std::is_floating_point<phis_type>::value,
                "Phis type must be floating point");

  // Nothing to do for an empty dataset or an empty path set.
  if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;

  // Interactions use a square (feature + bias) x (feature + bias) layout.
  if (size_t(phis_end - phis_begin) <
      X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) {
    throw std::invalid_argument(
        "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
        "(X.NumCols() + 1) * "
        "num_groups");
  }

  using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
  using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
  using path_vector = detail::RebindVector<
      typename std::iterator_traits<PathIteratorT>::value_type,
      DeviceAllocatorT>;
  using split_condition =
      typename std::iterator_traits<PathIteratorT>::value_type::split_type;

  // Compute the global bias; accumulate in double precision on the device.
  double_vector temp_phi(phis_end - phis_begin, 0.0);
  path_vector device_paths(begin, end);
  double_vector bias(num_groups, 0.0);
  detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT,
                      split_condition>(device_paths, &bias);
  auto d_bias = bias.data().get();
  auto d_temp_phi = temp_phi.data().get();
  // Scatter the per-group bias into every row's (bias, bias) cell.
  thrust::for_each_n(
      thrust::make_counting_iterator(0llu), X.NumRows() * num_groups,
      [=] __device__(size_t idx) {
        size_t group = idx % num_groups;
        size_t row_idx = idx / num_groups;
        d_temp_phi[IndexPhiInteractions(row_idx, num_groups, group, X.NumCols(),
                                        X.NumCols(), X.NumCols())] +=
            d_bias[group];
      });

  // Deduplicate features within paths and bin paths for the kernel launch.
  path_vector deduplicated_paths;
  size_vector device_bin_segments;
  detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
      &device_paths, &deduplicated_paths, &device_bin_segments);

  detail::ComputeShapTaylorInteractions(X, device_bin_segments,
                                        deduplicated_paths, num_groups,
                                        temp_phi.data().get());
  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
}
1469
+
1470
/*!
 * Compute feature contributions on the GPU given a set of unique paths through a tree ensemble
 * and a dataset. Uses device memory proportional to the tree ensemble size. This variant
 * implements the interventional tree shap algorithm described here:
 * https://drafts.distill.pub/HughChen/its_blog/
 *
 * It requires a background dataset R.
 *
 * \exception std::invalid_argument Thrown when an invalid argument error condition occurs.
 * \tparam  DeviceAllocatorT  Optional thrust style allocator.
 * \tparam  DatasetT          User-specified dataset container.
 * \tparam  PathIteratorT     Thrust type iterator, may be thrust::device_ptr for device memory, or
 *                            stl iterator/raw pointer for host memory.
 *
 * \param X          Thin wrapper over a dataset allocated in device memory. X should be trivially
 *                   copyable as a kernel parameter (i.e. contain only pointers to actual data) and
 *                   must implement the methods NumRows()/NumCols()/GetElement(size_t row_idx,
 *                   size_t col_idx) as __device__ functions. GetElement may return NaN where the
 *                   feature value is missing.
 * \param R          Background dataset.
 * \param begin      Iterator to paths, where separate paths are delineated by
 *                   PathElement.path_idx. Each unique path should contain 1 root with feature_idx =
 *                   -1 and zero_fraction = 1.0. The ordering of path elements inside a unique path
 *                   does not matter - the result will be the same. Paths may contain duplicate
 *                   features. See the PathElement class for more information.
 * \param end        Path end iterator.
 * \param num_groups Number of output groups. In multiclass classification the algorithm outputs
 *                   feature contributions per output class.
 * \param phis_begin Begin iterator for output phis.
 * \param phis_end   End iterator for output phis.
 */
template <typename DeviceAllocatorT = thrust::device_allocator<int>,
          typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
void GPUTreeShapInterventional(DatasetT X, DatasetT R, PathIteratorT begin,
                               PathIteratorT end, size_t num_groups,
                               PhiIteratorT phis_begin, PhiIteratorT phis_end) {
  // Enforce a floating-point phi type at compile time, consistent with
  // GPUTreeShapTaylorInteractions; integral phis would silently truncate.
  using phis_type = typename std::iterator_traits<PhiIteratorT>::value_type;
  static_assert(std::is_floating_point<phis_type>::value,
                "Phis type must be floating point");

  if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;

  // Output layout is one phi per (row, feature + bias slot, group).
  if (size_t(phis_end - phis_begin) <
      X.NumRows() * (X.NumCols() + 1) * num_groups) {
    throw std::invalid_argument(
        "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
        "num_groups");
  }

  using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
  using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
  using path_vector = detail::RebindVector<
      typename std::iterator_traits<PathIteratorT>::value_type,
      DeviceAllocatorT>;
  using split_condition =
      typename std::iterator_traits<PathIteratorT>::value_type::split_type;

  // Accumulate in double precision on the device, then copy to phis_begin.
  // Unlike the path-dependent variants, no explicit bias pass is needed here.
  double_vector temp_phi(phis_end - phis_begin, 0.0);
  path_vector device_paths(begin, end);

  path_vector deduplicated_paths;
  size_vector device_bin_segments;
  detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
      &device_paths, &deduplicated_paths, &device_bin_segments);
  detail::ComputeShapInterventional(X, R, device_bin_segments,
                                    deduplicated_paths, num_groups,
                                    temp_phi.data().get());
  thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
}
1535
+ } // namespace gpu_treeshap
lib/shap/cext/tree_shap.h ADDED
@@ -0,0 +1,1460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * Fast recursive computation of SHAP values in trees.
3
+ * See https://arxiv.org/abs/1802.03888 for details.
4
+ *
5
+ * Scott Lundberg, 2018 (independent algorithm courtesy of Hugh Chen 2018)
6
+ */
7
+
8
+ #include <algorithm>
9
+ #include <iostream>
10
+ #include <fstream>
11
+ #include <stdio.h>
12
+ #include <cmath>
13
+ #include <ctime>
14
+ #if defined(_WIN32) || defined(WIN32)
15
+ #include <malloc.h>
16
+ #elif defined(__MVS__)
17
+ #include <stdlib.h>
18
+ #else
19
+ #include <alloca.h>
20
+ #endif
21
// NOTE(review): `using namespace std;` in a header leaks into every
// translation unit that includes it; kept as-is to avoid changing name
// lookup for existing code.
using namespace std;

// Floating-point precision used throughout the tree SHAP routines.
typedef double tfloat;
// Signature shared by all output transforms: (margin, label) -> value.
typedef tfloat (* transform_f)(const tfloat margin, const tfloat y);

// Named constants selecting the feature-dependence assumption used when
// explaining a model.
namespace FEATURE_DEPENDENCE {
    const unsigned independent = 0;
    const unsigned tree_path_dependent = 1;
    const unsigned global_path_dependent = 2;
}
31
+
32
/**
 * Flat, array-based view of a tree ensemble.
 *
 * All per-node arrays are laid out as [tree][node] with `max_nodes` slots per
 * tree; `values` additionally has `num_outputs` entries per node. The struct
 * does not own its pointers unless they were created via allocate(), in which
 * case free() must be called explicitly (the destructor does not release
 * them).
 */
struct TreeEnsemble {
    int *children_left;         // left child index per node; < 0 marks a leaf
    int *children_right;        // right child index per node
    int *children_default;      // child taken when the split feature is missing
    int *features;              // split feature index per node
    tfloat *thresholds;         // split threshold per node (go left if x <= t)
    tfloat *values;             // per-node output values, num_outputs each
    tfloat *node_sample_weights; // training sample weight that reached each node
    unsigned max_depth;
    unsigned tree_limit;        // number of trees in this view
    tfloat *base_offset;        // per-output additive bias of the ensemble
    unsigned max_nodes;         // node capacity (stride) per tree
    unsigned num_outputs;

    TreeEnsemble() {}
    TreeEnsemble(int *children_left, int *children_right, int *children_default, int *features,
                 tfloat *thresholds, tfloat *values, tfloat *node_sample_weights,
                 unsigned max_depth, unsigned tree_limit, tfloat *base_offset,
                 unsigned max_nodes, unsigned num_outputs) :
        children_left(children_left), children_right(children_right),
        children_default(children_default), features(features), thresholds(thresholds),
        values(values), node_sample_weights(node_sample_weights),
        max_depth(max_depth), tree_limit(tree_limit),
        base_offset(base_offset), max_nodes(max_nodes), num_outputs(num_outputs) {}

    // Point `tree` at the i'th tree of this ensemble (non-owning view with
    // tree_limit = 1); no data is copied.
    void get_tree(TreeEnsemble &tree, const unsigned i) const {
        const unsigned d = i * max_nodes;

        tree.children_left = children_left + d;
        tree.children_right = children_right + d;
        tree.children_default = children_default + d;
        tree.features = features + d;
        tree.thresholds = thresholds + d;
        tree.values = values + d * num_outputs;
        tree.node_sample_weights = node_sample_weights + d;
        tree.max_depth = max_depth;
        tree.tree_limit = 1;
        tree.base_offset = base_offset;
        tree.max_nodes = max_nodes;
        tree.num_outputs = num_outputs;
    }

    // A node is a leaf iff its left-child index is negative.
    bool is_leaf(unsigned pos)const {
        return children_left[pos] < 0;
    }

    // Heap-allocate all per-node arrays for the given capacity. Does NOT
    // allocate base_offset; pair every allocate() with a free().
    void allocate(unsigned tree_limit_in, unsigned max_nodes_in, unsigned num_outputs_in) {
        tree_limit = tree_limit_in;
        max_nodes = max_nodes_in;
        num_outputs = num_outputs_in;
        children_left = new int[tree_limit * max_nodes];
        children_right = new int[tree_limit * max_nodes];
        children_default = new int[tree_limit * max_nodes];
        features = new int[tree_limit * max_nodes];
        thresholds = new tfloat[tree_limit * max_nodes];
        values = new tfloat[tree_limit * max_nodes * num_outputs];
        node_sample_weights = new tfloat[tree_limit * max_nodes];
    }

    // Release the arrays created by allocate(). Only call on an ensemble
    // that owns its storage.
    void free() {
        delete[] children_left;
        delete[] children_right;
        delete[] children_default;
        delete[] features;
        delete[] thresholds;
        delete[] values;
        delete[] node_sample_weights;
    }
};
101
+
102
/**
 * Non-owning view of the data being explained.
 *
 * X is a num_X x M row-major matrix of feature values with a parallel
 * boolean missingness mask X_missing; y holds optional per-row labels.
 * R / R_missing describe an optional background dataset of num_R rows
 * (used by the interventional variants).
 */
struct ExplanationDataset {
    tfloat *X;        // foreground samples, num_X rows of M features
    bool *X_missing;  // true where the corresponding X entry is missing
    tfloat *y;        // optional labels (may be NULL)
    tfloat *R;        // optional background samples, num_R rows of M features
    bool *R_missing;  // missingness mask for R
    unsigned num_X;   // number of foreground rows
    unsigned M;       // number of features per row
    unsigned num_R;   // number of background rows

    ExplanationDataset() {}
    ExplanationDataset(tfloat *X, bool *X_missing, tfloat *y, tfloat *R, bool *R_missing, unsigned num_X,
                       unsigned M, unsigned num_R) :
        X(X), X_missing(X_missing), y(y), R(R), R_missing(R_missing), num_X(num_X), M(M), num_R(num_R) {}

    // Point `instance` at row i of X as a single-row dataset (no copy).
    // Only X / X_missing / M / num_X are set; the other fields of
    // `instance` are left untouched.
    void get_x_instance(ExplanationDataset &instance, const unsigned i) const {
        instance.M = M;
        instance.X = X + i * M;
        instance.X_missing = X_missing + i * M;
        instance.num_X = 1;
    }
};
124
+
125
+
126
// data we keep about our decision path
// note that pweight is included for convenience and is not tied with the other attributes
// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
struct PathElement {
    int feature_index;     // feature split on at this step (-1 for the root)
    tfloat zero_fraction;  // fraction of samples that flow down this branch
    tfloat one_fraction;   // 1.0 if the explained sample follows this branch, else 0.0
    tfloat pweight;        // permutation weight (see note above)
    PathElement() {}
    PathElement(int i, tfloat z, tfloat o, tfloat w) :
        feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
138
+
139
+ inline tfloat logistic_transform(const tfloat margin, const tfloat y) {
140
+ return 1 / (1 + exp(-margin));
141
+ }
142
+
143
+ inline tfloat logistic_nlogloss_transform(const tfloat margin, const tfloat y) {
144
+ return log(1 + exp(margin)) - y * margin; // y is in {0, 1}
145
+ }
146
+
147
+ inline tfloat squared_loss_transform(const tfloat margin, const tfloat y) {
148
+ return (margin - y) * (margin - y);
149
+ }
150
+
151
// Named constants selecting the output transform applied to raw margins;
// see get_transform() for the mapping to functions.
namespace MODEL_TRANSFORM {
    const unsigned identity = 0;          // raw margin, no transform
    const unsigned logistic = 1;          // sigmoid(margin)
    const unsigned logistic_nlogloss = 2; // logistic negative log-loss vs. y
    const unsigned squared_loss = 3;      // (margin - y)^2
}
157
+
158
+ inline transform_f get_transform(unsigned model_transform) {
159
+ transform_f transform = NULL;
160
+ switch (model_transform) {
161
+ case MODEL_TRANSFORM::logistic:
162
+ transform = logistic_transform;
163
+ break;
164
+
165
+ case MODEL_TRANSFORM::logistic_nlogloss:
166
+ transform = logistic_nlogloss_transform;
167
+ break;
168
+
169
+ case MODEL_TRANSFORM::squared_loss:
170
+ transform = squared_loss_transform;
171
+ break;
172
+ }
173
+
174
+ return transform;
175
+ }
176
+
177
+ inline tfloat *tree_predict(unsigned i, const TreeEnsemble &trees, const tfloat *x, const bool *x_missing) {
178
+ const unsigned offset = i * trees.max_nodes;
179
+ unsigned node = 0;
180
+ while (true) {
181
+ const unsigned pos = offset + node;
182
+ const unsigned feature = trees.features[pos];
183
+
184
+ // we hit a leaf so return a pointer to the values
185
+ if (trees.is_leaf(pos)) {
186
+ return trees.values + pos * trees.num_outputs;
187
+ }
188
+
189
+ // otherwise we are at an internal node and need to recurse
190
+ if (x_missing[feature]) {
191
+ node = trees.children_default[pos];
192
+ } else if (x[feature] <= trees.thresholds[pos]) {
193
+ node = trees.children_left[pos];
194
+ } else {
195
+ node = trees.children_right[pos];
196
+ }
197
+ }
198
+ }
199
+
200
// Predict every row of `data` with the full ensemble, writing
// num_outputs values per row into `out` (which the caller must have
// zero-initialized — values are accumulated with +=). The optional
// model_transform is applied to the summed margins per row.
inline void dense_tree_predict(tfloat *out, const TreeEnsemble &trees, const ExplanationDataset &data, unsigned model_transform) {
    tfloat *row_out = out;
    const tfloat *x = data.X;
    const bool *x_missing = data.X_missing;

    // see what transform (if any) we have
    transform_f transform = get_transform(model_transform);

    for (unsigned i = 0; i < data.num_X; ++i) {

        // add the base offset
        for (unsigned k = 0; k < trees.num_outputs; ++k) {
            row_out[k] += trees.base_offset[k];
        }

        // add the leaf values from each tree
        for (unsigned j = 0; j < trees.tree_limit; ++j) {
            const tfloat *leaf_value = tree_predict(j, trees, x, x_missing);

            for (unsigned k = 0; k < trees.num_outputs; ++k) {
                row_out[k] += leaf_value[k];
            }
        }

        // apply any needed transform; a missing label array is treated as y = 0
        if (transform != NULL) {
            const tfloat y_i = data.y == NULL ? 0 : data.y[i];
            for (unsigned k = 0; k < trees.num_outputs; ++k) {
                row_out[k] = transform(row_out[k], y_i);
            }
        }

        // advance to the next row in X and in the output buffer
        x += data.M;
        x_missing += data.M;
        row_out += trees.num_outputs;
    }
}
237
+
238
+ inline void tree_update_weights(unsigned i, TreeEnsemble &trees, const tfloat *x, const bool *x_missing) {
239
+ const unsigned offset = i * trees.max_nodes;
240
+ unsigned node = 0;
241
+ while (true) {
242
+ const unsigned pos = offset + node;
243
+ const unsigned feature = trees.features[pos];
244
+
245
+ // Record that a sample passed through this node
246
+ trees.node_sample_weights[pos] += 1.0;
247
+
248
+ // we hit a leaf so return a pointer to the values
249
+ if (trees.children_left[pos] < 0) break;
250
+
251
+ // otherwise we are at an internal node and need to recurse
252
+ if (x_missing[feature]) {
253
+ node = trees.children_default[pos];
254
+ } else if (x[feature] <= trees.thresholds[pos]) {
255
+ node = trees.children_left[pos];
256
+ } else {
257
+ node = trees.children_right[pos];
258
+ }
259
+ }
260
+ }
261
+
262
+ inline void dense_tree_update_weights(TreeEnsemble &trees, const ExplanationDataset &data) {
263
+ const tfloat *x = data.X;
264
+ const bool *x_missing = data.X_missing;
265
+
266
+ for (unsigned i = 0; i < data.num_X; ++i) {
267
+
268
+ // add the leaf values from each tree
269
+ for (unsigned j = 0; j < trees.tree_limit; ++j) {
270
+ tree_update_weights(j, trees, x, x_missing);
271
+ }
272
+
273
+ x += data.M;
274
+ x_missing += data.M;
275
+ }
276
+ }
277
+
278
// Saabas-style attribution for a single instance on a single tree: walk the
// instance's decision path and, at every split, credit the split feature
// with the change in node value between the current node and the child
// taken. Credits are accumulated into `out`, laid out [feature][output].
// `data` must be a single-row view (see ExplanationDataset::get_x_instance).
inline void tree_saabas(tfloat *out, const TreeEnsemble &tree, const ExplanationDataset &data) {
    unsigned curr_node = 0;
    unsigned next_node = 0;
    while (true) {

        // we hit a leaf and are done
        if (tree.children_left[curr_node] < 0) return;

        // otherwise we are at an internal node and need to recurse
        const unsigned feature = tree.features[curr_node];
        if (data.X_missing[feature]) {
            next_node = tree.children_default[curr_node];
        } else if (data.X[feature] <= tree.thresholds[curr_node]) {
            next_node = tree.children_left[curr_node];
        } else {
            next_node = tree.children_right[curr_node];
        }

        // assign credit to this feature as the difference in values at the current node vs. the next node
        for (unsigned i = 0; i < tree.num_outputs; ++i) {
            out[feature * tree.num_outputs + i] += tree.values[next_node * tree.num_outputs + i] - tree.values[curr_node * tree.num_outputs + i];
        }

        curr_node = next_node;
    }
}
304
+
305
/**
 * Saabas attributions for every sample against every tree, summed across
 * trees by the linearity of the attribution.
 *
 * NOTE(review): the original comment described this as "Tree SHAP with a per
 * tree path conditional dependence assumption"; the per-tree routine called
 * here is tree_saabas (path value-difference credits), not the SHAP
 * recursion — confirm intended wording upstream.
 *
 * Output layout: one block of (M + 1) * num_outputs values per sample, the
 * extra slot holding the bias term (base_offset).
 */
inline void dense_tree_saabas(tfloat *out_contribs, const TreeEnsemble& trees, const ExplanationDataset &data) {
    tfloat *instance_out_contribs;
    TreeEnsemble tree;
    ExplanationDataset instance;

    // build explanation for each sample
    for (unsigned i = 0; i < data.num_X; ++i) {
        instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
        data.get_x_instance(instance, i);

        // aggregate the effect of explaining each tree
        // (this works because of the linearity property of Shapley values)
        for (unsigned j = 0; j < trees.tree_limit; ++j) {
            trees.get_tree(tree, j);
            tree_saabas(instance_out_contribs, tree, instance);
        }

        // apply the base offset to the bias term
        for (unsigned j = 0; j < trees.num_outputs; ++j) {
            instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
        }
    }
}
331
+
332
+
333
// extend our decision path with a fraction of one and zero extensions
//
// Appends a new element at unique_path[unique_depth] and updates the
// permutation weights of all earlier elements in place. The recurrence is
// order-sensitive: pweights are updated from the deepest element backwards
// so each step reads the not-yet-updated value of unique_path[i].
// (See the Tree SHAP paper referenced in the file header for the weights.)
inline void extend_path(PathElement *unique_path, unsigned unique_depth,
                        tfloat zero_fraction, tfloat one_fraction, int feature_index) {
    unique_path[unique_depth].feature_index = feature_index;
    unique_path[unique_depth].zero_fraction = zero_fraction;
    unique_path[unique_depth].one_fraction = one_fraction;
    // the first element starts with full weight; later appends start at 0
    unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f);
    for (int i = unique_depth - 1; i >= 0; i--) {
        unique_path[i + 1].pweight += one_fraction * unique_path[i].pweight * (i + 1)
                                      / static_cast<tfloat>(unique_depth + 1);
        unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i)
                                 / static_cast<tfloat>(unique_depth + 1);
    }
}
347
+
348
// undo a previous extension of the decision path
//
// Inverts the pweight recurrence of extend_path for the element at
// path_index, then shifts the remaining elements down one slot so the path
// array stays contiguous. The one_fraction == 0 case needs a separate
// formula because the general inversion divides by one_fraction.
inline void unwind_path(PathElement *unique_path, unsigned unique_depth, unsigned path_index) {
    const tfloat one_fraction = unique_path[path_index].one_fraction;
    const tfloat zero_fraction = unique_path[path_index].zero_fraction;
    tfloat next_one_portion = unique_path[unique_depth].pweight;

    for (int i = unique_depth - 1; i >= 0; --i) {
        if (one_fraction != 0) {
            const tfloat tmp = unique_path[i].pweight;
            unique_path[i].pweight = next_one_portion * (unique_depth + 1)
                                     / static_cast<tfloat>((i + 1) * one_fraction);
            next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i)
                               / static_cast<tfloat>(unique_depth + 1);
        } else {
            unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1))
                                     / static_cast<tfloat>(zero_fraction * (unique_depth - i));
        }
    }

    // close the gap left by the removed element
    for (unsigned i = path_index; i < unique_depth; ++i) {
        unique_path[i].feature_index = unique_path[i+1].feature_index;
        unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
        unique_path[i].one_fraction = unique_path[i+1].one_fraction;
    }
}
373
+
374
// determine what the total permutation weight would be if
// we unwound a previous extension in the decision path
//
// Like unwind_path, but read-only: accumulates the total weight without
// mutating the path. The one_fraction == 0 branch mirrors the special
// case in unwind_path (the general formula would divide by zero).
inline tfloat unwound_path_sum(const PathElement *unique_path, unsigned unique_depth,
                               unsigned path_index) {
    const tfloat one_fraction = unique_path[path_index].one_fraction;
    const tfloat zero_fraction = unique_path[path_index].zero_fraction;
    tfloat next_one_portion = unique_path[unique_depth].pweight;
    tfloat total = 0;

    if (one_fraction != 0) {
        for (int i = unique_depth - 1; i >= 0; --i) {
            const tfloat tmp = next_one_portion / static_cast<tfloat>((i + 1) * one_fraction);
            total += tmp;
            next_one_portion = unique_path[i].pweight - tmp * zero_fraction * (unique_depth - i);
        }
    } else {
        for (int i = unique_depth - 1; i >= 0; --i) {
            total += unique_path[i].pweight / (zero_fraction * (unique_depth - i));
        }
    }
    return total * (unique_depth + 1);
}
396
+
397
// recursive computation of SHAP values for a decision tree
//
// Implements the polynomial-time Tree SHAP recursion: a path of unique
// features is extended on the way down the tree and, at each leaf, the
// permutation weights accumulated along the path are unwound one feature at a
// time to attribute the leaf value to every feature on the path. phi is an
// accumulator with one slot per feature (plus a bias slot) per output.
// condition/condition_feature optionally hold one feature "on" (condition > 0)
// or "off" (condition < 0) for conditional explanations.
inline void tree_shap_recursive(const unsigned num_outputs, const int *children_left,
                                const int *children_right,
                                const int *children_default, const int *features,
                                const tfloat *thresholds, const tfloat *values,
                                const tfloat *node_sample_weight,
                                const tfloat *x, const bool *x_missing, tfloat *phi,
                                unsigned node_index, unsigned unique_depth,
                                PathElement *parent_unique_path, tfloat parent_zero_fraction,
                                tfloat parent_one_fraction, int parent_feature_index,
                                int condition, unsigned condition_feature,
                                tfloat condition_fraction) {

    // stop if we have no weight coming down to us
    if (condition_fraction == 0) return;

    // extend the unique path (each recursion depth owns its own copy of the
    // path inside the triangular buffer allocated by tree_shap)
    PathElement *unique_path = parent_unique_path + unique_depth + 1;
    std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);

    if (condition == 0 || condition_feature != static_cast<unsigned>(parent_feature_index)) {
        extend_path(unique_path, unique_depth, parent_zero_fraction,
                    parent_one_fraction, parent_feature_index);
    }
    const unsigned split_index = features[node_index];

    // leaf node
    if (children_right[node_index] < 0) {
        for (unsigned i = 1; i <= unique_depth; ++i) {
            const tfloat w = unwound_path_sum(unique_path, unique_depth, i);
            const PathElement &el = unique_path[i];
            const unsigned phi_offset = el.feature_index * num_outputs;
            const unsigned values_offset = node_index * num_outputs;
            const tfloat scale = w * (el.one_fraction - el.zero_fraction) * condition_fraction;
            for (unsigned j = 0; j < num_outputs; ++j) {
                phi[phi_offset + j] += scale * values[values_offset + j];
            }
        }

    // internal node
    } else {
        // find which branch is "hot" (meaning x would follow it)
        unsigned hot_index = 0;
        if (x_missing[split_index]) {
            hot_index = children_default[node_index];
        } else if (x[split_index] <= thresholds[node_index]) {
            hot_index = children_left[node_index];
        } else {
            hot_index = children_right[node_index];
        }
        const unsigned cold_index = (static_cast<int>(hot_index) == children_left[node_index] ?
                                     children_right[node_index] : children_left[node_index]);
        const tfloat w = node_sample_weight[node_index];
        const tfloat hot_zero_fraction = node_sample_weight[hot_index] / w;
        const tfloat cold_zero_fraction = node_sample_weight[cold_index] / w;
        tfloat incoming_zero_fraction = 1;
        tfloat incoming_one_fraction = 1;

        // see if we have already split on this feature,
        // if so we undo that split so we can redo it for this node
        unsigned path_index = 0;
        for (; path_index <= unique_depth; ++path_index) {
            if (static_cast<unsigned>(unique_path[path_index].feature_index) == split_index) break;
        }
        if (path_index != unique_depth + 1) {
            incoming_zero_fraction = unique_path[path_index].zero_fraction;
            incoming_one_fraction = unique_path[path_index].one_fraction;
            unwind_path(unique_path, unique_depth, path_index);
            unique_depth -= 1;
        }

        // divide up the condition_fraction among the recursive calls
        tfloat hot_condition_fraction = condition_fraction;
        tfloat cold_condition_fraction = condition_fraction;
        if (condition > 0 && split_index == condition_feature) {
            // conditioned "on": the cold branch gets no weight and the
            // conditioned feature is excluded from the path
            cold_condition_fraction = 0;
            unique_depth -= 1;
        } else if (condition < 0 && split_index == condition_feature) {
            // conditioned "off": weight both branches by their cover fractions
            hot_condition_fraction *= hot_zero_fraction;
            cold_condition_fraction *= cold_zero_fraction;
            unique_depth -= 1;
        }

        tree_shap_recursive(
            num_outputs, children_left, children_right, children_default, features, thresholds, values,
            node_sample_weight, x, x_missing, phi, hot_index, unique_depth + 1, unique_path,
            hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
            split_index, condition, condition_feature, hot_condition_fraction
        );

        tree_shap_recursive(
            num_outputs, children_left, children_right, children_default, features, thresholds, values,
            node_sample_weight, x, x_missing, phi, cold_index, unique_depth + 1, unique_path,
            cold_zero_fraction * incoming_zero_fraction, 0,
            split_index, condition, condition_feature, cold_condition_fraction
        );
    }
}
495
+
496
+ inline int compute_expectations(TreeEnsemble &tree, int i = 0, int depth = 0) {
497
+ unsigned max_depth = 0;
498
+
499
+ if (tree.children_right[i] >= 0) {
500
+ const unsigned li = tree.children_left[i];
501
+ const unsigned ri = tree.children_right[i];
502
+ const unsigned depth_left = compute_expectations(tree, li, depth + 1);
503
+ const unsigned depth_right = compute_expectations(tree, ri, depth + 1);
504
+ const tfloat left_weight = tree.node_sample_weights[li];
505
+ const tfloat right_weight = tree.node_sample_weights[ri];
506
+ const unsigned li_offset = li * tree.num_outputs;
507
+ const unsigned ri_offset = ri * tree.num_outputs;
508
+ const unsigned i_offset = i * tree.num_outputs;
509
+ for (unsigned j = 0; j < tree.num_outputs; ++j) {
510
+ if ((left_weight == 0) && (right_weight == 0)) {
511
+ tree.values[i_offset + j] = 0.0;
512
+ } else {
513
+ const tfloat v = (left_weight * tree.values[li_offset + j] + right_weight * tree.values[ri_offset + j]) / (left_weight + right_weight);
514
+ tree.values[i_offset + j] = v;
515
+ }
516
+ }
517
+ max_depth = std::max(depth_left, depth_right) + 1;
518
+ }
519
+
520
+ if (depth == 0) tree.max_depth = max_depth;
521
+
522
+ return max_depth;
523
+ }
524
+
525
+ inline void tree_shap(const TreeEnsemble& tree, const ExplanationDataset &data,
526
+ tfloat *out_contribs, int condition, unsigned condition_feature) {
527
+
528
+ // update the reference value with the expected value of the tree's predictions
529
+ if (condition == 0) {
530
+ for (unsigned j = 0; j < tree.num_outputs; ++j) {
531
+ out_contribs[data.M * tree.num_outputs + j] += tree.values[j];
532
+ }
533
+ }
534
+
535
+ // Pre-allocate space for the unique path data
536
+ const unsigned maxd = tree.max_depth + 2; // need a bit more space than the max depth
537
+ PathElement *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
538
+
539
+ tree_shap_recursive(
540
+ tree.num_outputs, tree.children_left, tree.children_right, tree.children_default,
541
+ tree.features, tree.thresholds, tree.values, tree.node_sample_weights, data.X,
542
+ data.X_missing, out_contribs, 0, 0, unique_path_data, 1, 1, -1, condition,
543
+ condition_feature, 1
544
+ );
545
+
546
+ delete[] unique_path_data;
547
+ }
548
+
549
+
550
// Recursively build one merged tree equivalent to running every tree of the
// ensemble in sequence, restricted to the data rows listed in data_inds.
// Rows are partitioned in place at each split; negative entries in data_inds
// mark rows that are not counted as background samples (see build_merged_tree).
// Leaf values accumulate across trees via leaf_value. Returns the index of the
// last node written so the caller can place subsequent nodes after the subtree.
inline unsigned build_merged_tree_recursive(TreeEnsemble &out_tree, const TreeEnsemble &trees,
                                            const tfloat *data, const bool *data_missing, int *data_inds,
                                            const unsigned num_background_data_inds, unsigned num_data_inds,
                                            unsigned M, unsigned row = 0, unsigned i = 0, unsigned pos = 0,
                                            tfloat *leaf_value = NULL) {
    //tfloat new_leaf_value[trees.num_outputs];
    // NOTE(review): alloca grows the stack on every recursion level; assumes
    // tree depth times num_outputs stays small — confirm for deep ensembles
    tfloat *new_leaf_value = (tfloat *) alloca(sizeof(tfloat) * trees.num_outputs); // allocate on the stack
    unsigned row_offset = row * trees.max_nodes;

    // we have hit a terminal leaf!!! (leaf of the last tree in the ensemble)
    if (trees.children_left[row_offset + i] < 0 && row + 1 == trees.tree_limit) {

        // create the leaf node
        const tfloat *vals = trees.values + (row * trees.max_nodes + i) * trees.num_outputs;
        if (leaf_value == NULL) {
            for (unsigned j = 0; j < trees.num_outputs; ++j) {
                out_tree.values[pos * trees.num_outputs + j] = vals[j];
            }
        } else {
            for (unsigned j = 0; j < trees.num_outputs; ++j) {
                out_tree.values[pos * trees.num_outputs + j] = leaf_value[j] + vals[j];
            }
        }
        out_tree.children_left[pos] = -1;
        out_tree.children_right[pos] = -1;
        out_tree.children_default[pos] = -1;
        out_tree.features[pos] = -1;
        out_tree.thresholds[pos] = 0;
        out_tree.node_sample_weights[pos] = num_background_data_inds;

        return pos;
    }

    // we hit an intermediate leaf (so just add the value to our accumulator and move to the next tree)
    if (trees.children_left[row_offset + i] < 0) {

        // accumulate the value of this original leaf so it will land on all eventual terminal leaves
        const tfloat *vals = trees.values + (row * trees.max_nodes + i) * trees.num_outputs;
        if (leaf_value == NULL) {
            for (unsigned j = 0; j < trees.num_outputs; ++j) {
                new_leaf_value[j] = vals[j];
            }
        } else {
            for (unsigned j = 0; j < trees.num_outputs; ++j) {
                new_leaf_value[j] = leaf_value[j] + vals[j];
            }
        }
        leaf_value = new_leaf_value;

        // move forward to the next tree
        row += 1;
        row_offset += trees.max_nodes;
        i = 0;
    }

    // split the data inds by this node's threshold
    // (Hoare-style in-place partition: rows going left end up at the front)
    const tfloat t = trees.thresholds[row_offset + i];
    const int f = trees.features[row_offset + i];
    const bool right_default = trees.children_default[row_offset + i] == trees.children_right[row_offset + i];
    int low_ptr = 0;
    int high_ptr = num_data_inds - 1;
    unsigned num_left_background_data_inds = 0;
    int low_data_ind;
    while (low_ptr <= high_ptr) {
        low_data_ind = data_inds[low_ptr];
        const int data_ind = std::abs(low_data_ind) * M + f;
        const bool is_missing = data_missing[data_ind];
        if ((!is_missing && data[data_ind] > t) || (right_default && is_missing)) {
            data_inds[low_ptr] = data_inds[high_ptr];
            data_inds[high_ptr] = low_data_ind;
            high_ptr -= 1;
        } else {
            if (low_data_ind >= 0) ++num_left_background_data_inds; // negative data_inds are not background samples
            low_ptr += 1;
        }
    }
    int *left_data_inds = data_inds;
    const unsigned num_left_data_inds = low_ptr;
    int *right_data_inds = data_inds + low_ptr;
    const unsigned num_right_data_inds = num_data_inds - num_left_data_inds;
    const unsigned num_right_background_data_inds = num_background_data_inds - num_left_background_data_inds;

    // all the data went right, so we skip creating this node and just recurse right
    if (num_left_data_inds == 0) {
        return build_merged_tree_recursive(
            out_tree, trees, data, data_missing, data_inds,
            num_background_data_inds, num_data_inds, M, row,
            trees.children_right[row_offset + i], pos, leaf_value
        );

    // all the data went left, so we skip creating this node and just recurse left
    } else if (num_right_data_inds == 0) {
        return build_merged_tree_recursive(
            out_tree, trees, data, data_missing, data_inds,
            num_background_data_inds, num_data_inds, M, row,
            trees.children_left[row_offset + i], pos, leaf_value
        );

    // data went both ways so we create this node and recurse down both paths
    } else {

        // build the left subtree
        const unsigned new_pos = build_merged_tree_recursive(
            out_tree, trees, data, data_missing, left_data_inds,
            num_left_background_data_inds, num_left_data_inds, M, row,
            trees.children_left[row_offset + i], pos + 1, leaf_value
        );

        // fill in the data for this node
        out_tree.children_left[pos] = pos + 1;
        out_tree.children_right[pos] = new_pos + 1;
        if (trees.children_left[row_offset + i] == trees.children_default[row_offset + i]) {
            out_tree.children_default[pos] = pos + 1;
        } else {
            out_tree.children_default[pos] = new_pos + 1;
        }

        out_tree.features[pos] = trees.features[row_offset + i];
        out_tree.thresholds[pos] = trees.thresholds[row_offset + i];
        out_tree.node_sample_weights[pos] = num_background_data_inds;

        // build the right subtree
        return build_merged_tree_recursive(
            out_tree, trees, data, data_missing, right_data_inds,
            num_right_background_data_inds, num_right_data_inds, M, row,
            trees.children_right[row_offset + i], new_pos + 1, leaf_value
        );
    }
}
679
+
680
+
681
+ inline void build_merged_tree(TreeEnsemble &out_tree, const ExplanationDataset &data, const TreeEnsemble &trees) {
682
+
683
+ // create a joint data matrix from both X and R matrices
684
+ tfloat *joined_data = new tfloat[(data.num_X + data.num_R) * data.M];
685
+ std::copy(data.X, data.X + data.num_X * data.M, joined_data);
686
+ std::copy(data.R, data.R + data.num_R * data.M, joined_data + data.num_X * data.M);
687
+ bool *joined_data_missing = new bool[(data.num_X + data.num_R) * data.M];
688
+ std::copy(data.X_missing, data.X_missing + data.num_X * data.M, joined_data_missing);
689
+ std::copy(data.R_missing, data.R_missing + data.num_R * data.M, joined_data_missing + data.num_X * data.M);
690
+
691
+ // create an starting array of data indexes we will recursively sort
692
+ int *data_inds = new int[data.num_X + data.num_R];
693
+ for (unsigned i = 0; i < data.num_X; ++i) data_inds[i] = i;
694
+ for (unsigned i = data.num_X; i < data.num_X + data.num_R; ++i) {
695
+ data_inds[i] = -i; // a negative index means it won't be recorded as a background sample
696
+ }
697
+
698
+ build_merged_tree_recursive(
699
+ out_tree, trees, joined_data, joined_data_missing, data_inds, data.num_R,
700
+ data.num_X + data.num_R, data.M
701
+ );
702
+
703
+ delete[] joined_data;
704
+ delete[] joined_data_missing;
705
+ delete[] data_inds;
706
+ }
707
+
708
+
709
+ // Independent Tree SHAP functions below here
710
+ // ------------------------------------------
711
// Compact per-node record used by the independence-assuming Tree SHAP code
// below (tree_shap_indep / dense_independent).
struct Node {
    // cl/cr/cd: left/right/default child indices; pnode: parent node index;
    // feat: this node's split feature; pfeat: the parent's split feature.
    // NOTE(review): shorts cap node/feature indices at 32767 — confirm trees
    // stay below that size
    short cl, cr, cd, pnode, feat, pfeat; // uint_16
    float thres, value; // split threshold and (single-output) node value
    char from_flag;     // FROM_NEITHER / FROM_X_NOT_R / FROM_R_NOT_X marker
};
716
+
717
+ #define FROM_NEITHER 0
718
+ #define FROM_X_NOT_R 1
719
+ #define FROM_R_NOT_X 2
720
+
721
// https://www.geeksforgeeks.org/space-and-time-efficient-binomial-coefficient/
// Binomial coefficient C(n, k) using the multiplicative formula; each partial
// product is exactly divisible by i, so integer division is safe. Returns 1
// when k <= 0 (including the bin_coeff(n-1, m) calls with n == 0 used below).
inline int bin_coeff(int n, int k) {
    // exploit symmetry C(n, k) == C(n, n-k) to shorten the loop
    if (k > n - k) {
        k = n - k;
    }
    int res = 1;
    for (int i = 1; i <= k; ++i) {
        res *= (n - (i - 1));
        res /= i;
    }
    return res;
}
732
+
733
// note this only handles single output models, so multi-output models get explained using multiple passes
//
// Iterative (explicit-stack) Independent Tree SHAP for one tree, one
// foreground sample x and one background reference r. Contributions are
// accumulated into out_contribs, which the caller must zero-initialize for the
// feature slots; slot num_feats receives the bias. pos_lst / neg_lst /
// feat_hist / node_stack are caller-provided scratch buffers sized per
// dense_independent; memoized_weights holds precomputed Shapley permutation
// weights indexed by [N + max_depth * M], where N counts features split
// differently by x and r on the current path and M counts those where the
// path followed x. num_nodes is currently unused.
inline void tree_shap_indep(const unsigned max_depth, const unsigned num_feats,
                            const unsigned num_nodes, const tfloat *x,
                            const bool *x_missing, const tfloat *r,
                            const bool *r_missing, tfloat *out_contribs,
                            float *pos_lst, float *neg_lst, signed short *feat_hist,
                            float *memoized_weights, int *node_stack, Node *mytree) {
    int ns_ctr = 0;
    std::fill_n(feat_hist, num_feats, 0);
    short node = 0, feat, cl, cr, cd, pnode, pfeat = -1;
    short next_xnode = -1, next_rnode = -1;
    short next_node = -1, from_child = -1;
    float thres, pos_x = 0, neg_x = 0, pos_r = 0, neg_r = 0;
    char from_flag;
    unsigned M = 0, N = 0;

    Node curr_node = mytree[node];
    feat = curr_node.feat;
    thres = curr_node.thres;
    cl = curr_node.cl;
    cr = curr_node.cr;
    cd = curr_node.cd;

    // short circuit when this is a stump tree (with no splits)
    if (cl < 0) {
        out_contribs[num_feats] += curr_node.value;
        return;
    }

    // route x through the root split (default child when missing)
    if (x_missing[feat]) {
        next_xnode = cd;
    } else if (x[feat] > thres) {
        next_xnode = cr;
    } else if (x[feat] <= thres) {
        next_xnode = cl;
    }

    // route r through the root split
    if (r_missing[feat]) {
        next_rnode = cd;
    } else if (r[feat] > thres) {
        next_rnode = cr;
    } else if (r[feat] <= thres) {
        next_rnode = cl;
    }

    // tag children with which sample(s) would reach them
    if (next_xnode != next_rnode) {
        mytree[next_xnode].from_flag = FROM_X_NOT_R;
        mytree[next_rnode].from_flag = FROM_R_NOT_X;
    } else {
        mytree[next_xnode].from_flag = FROM_NEITHER;
    }

    // Check if x and r go the same way
    if (next_xnode == next_rnode) {
        next_node = next_xnode;
    }

    // If not, go left (and record whether the left branch is the x- or r-path)
    if (next_node < 0) {
        next_node = cl;
        if (next_rnode == next_node) { // rpath
            N = N+1;
            feat_hist[feat] -= 1;
        } else if (next_xnode == next_node) { // xpath
            M = M+1;
            N = N+1;
            feat_hist[feat] += 1;
        }
    }
    node_stack[ns_ctr] = node;
    ns_ctr += 1;
    while (true) {
        node = next_node;
        curr_node = mytree[node];
        feat = curr_node.feat;
        thres = curr_node.thres;
        cl = curr_node.cl;
        cr = curr_node.cr;
        cd = curr_node.cd;
        pnode = curr_node.pnode;
        pfeat = curr_node.pfeat;
        from_flag = curr_node.from_flag;

        // At a leaf
        if (cl < 0) {
            // M == 0 means the whole path followed r, so the leaf value goes
            // straight to the bias term
            if (M == 0) {
                out_contribs[num_feats] += mytree[node].value;
            }

            // Currently assuming a single output
            if (N != 0) {
                if (M != 0) {
                    pos_lst[node] = mytree[node].value * memoized_weights[N + max_depth * (M-1)];
                }
                if (M != N) {
                    neg_lst[node] = -mytree[node].value * memoized_weights[N + max_depth * M];
                }
            }
            // Pop from node_stack
            ns_ctr -= 1;
            next_node = node_stack[ns_ctr];
            from_child = node;
            // Unwind the feature-history count for the parent's split feature
            if (feat_hist[pfeat] > 0) {
                feat_hist[pfeat] -= 1;
            } else if (feat_hist[pfeat] < 0) {
                feat_hist[pfeat] += 1;
            }
            if (feat_hist[pfeat] == 0) {
                if (from_flag == FROM_X_NOT_R) {
                    N = N-1;
                    M = M-1;
                } else if (from_flag == FROM_R_NOT_X) {
                    N = N-1;
                }
            }
            continue;
        }

        const bool x_right = x[feat] > thres;
        const bool r_right = r[feat] > thres;

        if (x_missing[feat]) {
            next_xnode = cd;
        } else if (x_right) {
            next_xnode = cr;
        } else if (!x_right) {
            next_xnode = cl;
        }

        if (r_missing[feat]) {
            next_rnode = cd;
        } else if (r_right) {
            next_rnode = cr;
        } else if (!r_right) {
            next_rnode = cl;
        }

        if (next_xnode >= 0) {
            if (next_xnode != next_rnode) {
                mytree[next_xnode].from_flag = FROM_X_NOT_R;
                mytree[next_rnode].from_flag = FROM_R_NOT_X;
            } else {
                mytree[next_xnode].from_flag = FROM_NEITHER;
            }
        }

        // Arriving at node from parent
        if (from_child == -1) {
            node_stack[ns_ctr] = node;
            ns_ctr += 1;
            next_node = -1;

            // Feature is set upstream: stay on whichever path set it
            if (feat_hist[feat] > 0) {
                next_node = next_xnode;
                feat_hist[feat] += 1;
            } else if (feat_hist[feat] < 0) {
                next_node = next_rnode;
                feat_hist[feat] -= 1;
            }

            // x and r go the same way
            if (next_node < 0) {
                if (next_xnode == next_rnode) {
                    next_node = next_xnode;
                }
            }

            // Go down one path
            if (next_node >= 0) {
                continue;
            }

            // Go down both paths, but go left first
            next_node = cl;
            if (next_rnode == next_node) {
                N = N+1;
                feat_hist[feat] -= 1;
            } else if (next_xnode == next_node) {
                M = M+1;
                N = N+1;
                feat_hist[feat] += 1;
            }
            from_child = -1;
            continue;
        }

        // Arriving at node from child
        if (from_child != -1) {
            next_node = -1;
            // Check if we should unroll immediately
            if ((next_rnode == next_xnode) || (feat_hist[feat] != 0)) {
                next_node = pnode;
            }

            // Came from a single path, so unroll
            if (next_node >= 0) {
                // At the root node
                if (node == 0) {
                    break;
                }
                // Update and unroll: propagate the child's partial sums up
                pos_lst[node] = pos_lst[from_child];
                neg_lst[node] = neg_lst[from_child];

                from_child = node;
                ns_ctr -= 1;

                // Unwind
                if (feat_hist[pfeat] > 0) {
                    feat_hist[pfeat] -= 1;
                } else if (feat_hist[pfeat] < 0) {
                    feat_hist[pfeat] += 1;
                }
                if (feat_hist[pfeat] == 0) {
                    if (from_flag == FROM_X_NOT_R) {
                        N = N-1;
                        M = M-1;
                    } else if (from_flag == FROM_R_NOT_X) {
                        N = N-1;
                    }
                }
                continue;
            // Go right - Arriving from the left child
            } else if (from_child == cl) {
                node_stack[ns_ctr] = node;
                ns_ctr += 1;
                next_node = cr;
                if (next_xnode == next_node) {
                    M = M+1;
                    N = N+1;
                    feat_hist[feat] += 1;
                } else if (next_rnode == next_node) {
                    N = N+1;
                    feat_hist[feat] -= 1;
                }
                from_child = -1;
                continue;
            // Compute stuff and unroll - Arriving from the right child
            } else if (from_child == cr) {
                pos_x = 0;
                neg_x = 0;
                pos_r = 0;
                neg_r = 0;
                // pick up the partial sums of whichever child is on the
                // x-path vs the r-path
                if ((next_xnode == cr) && (next_rnode == cl)) {
                    pos_x = pos_lst[cr];
                    neg_x = neg_lst[cr];
                    pos_r = pos_lst[cl];
                    neg_r = neg_lst[cl];
                } else if ((next_xnode == cl) && (next_rnode == cr)) {
                    pos_x = pos_lst[cl];
                    neg_x = neg_lst[cl];
                    pos_r = pos_lst[cr];
                    neg_r = neg_lst[cr];
                }
                // out_contribs needs to have been initialized as all zeros
                out_contribs[feat] += pos_x + neg_r;
                pos_lst[node] = pos_x + pos_r;
                neg_lst[node] = neg_x + neg_r;

                // Check if at root
                if (node == 0) {
                    break;
                }

                // Pop
                ns_ctr -= 1;
                next_node = node_stack[ns_ctr];
                from_child = node;

                // Unwind
                if (feat_hist[pfeat] > 0) {
                    feat_hist[pfeat] -= 1;
                } else if (feat_hist[pfeat] < 0) {
                    feat_hist[pfeat] += 1;
                }
                if (feat_hist[pfeat] == 0) {
                    if (from_flag == FROM_X_NOT_R) {
                        N = N-1;
                        M = M-1;
                    } else if (from_flag == FROM_R_NOT_X) {
                        N = N-1;
                    }
                }
                continue;
            }
        }
    }
}
1089
+
1090
+
1091
// Write a tqdm-style progress bar to Python's sys.stderr. Stays silent for the
// first 10 seconds of the computation and then refreshes at most every half
// second; last_print carries the timestamp of the previous refresh between
// calls. i/total_count give the current progress fraction.
// NOTE(review): the "=" fill and padding string literals may have been
// whitespace-mangled in this copy — confirm against upstream before editing.
inline void print_progress_bar(tfloat &last_print, tfloat start_time, unsigned i, unsigned total_count) {
    const tfloat elapsed_seconds = difftime(time(NULL), start_time);

    if (elapsed_seconds > 10 && elapsed_seconds - last_print > 0.5) {
        const tfloat fraction = static_cast<tfloat>(i) / total_count;
        // extrapolate total runtime from the completed fraction for the ETA
        const double total_seconds = elapsed_seconds / fraction;
        last_print = elapsed_seconds;

        PySys_WriteStderr(
            "\r%3.0f%%|%.*s%.*s| %d/%d [%02d:%02d<%02d:%02d] ",
            fraction * 100, int(0.5 + fraction*20), "===================",
            20-int(0.5 + fraction*20), " ",
            i, total_count,
            int(elapsed_seconds/60), int(elapsed_seconds) % 60,
            int((total_seconds - elapsed_seconds)/60), int(total_seconds - elapsed_seconds) % 60
        );

        // Get handle to python stderr file and flush it (https://mail.python.org/pipermail/python-list/2004-November/294912.html)
        PyObject *pyStderr = PySys_GetObject("stderr");
        if (pyStderr) {
            PyObject *result = PyObject_CallMethod(pyStderr, "flush", NULL);
            Py_XDECREF(result);
        }
    }
}
1116
+
1117
/**
 * Runs Tree SHAP with feature independence assumptions on dense data.
 *
 * For every output, every foreground sample x and every background reference
 * r, tree_shap_indep is run over all trees; the per-reference contributions
 * are optionally rescaled through `transform` (to explain a transformed
 * margin, e.g. a loss) and then averaged over the references into
 * out_contribs, which is laid out as [num_X, M + 1, num_outputs].
 */
inline void dense_independent(const TreeEnsemble& trees, const ExplanationDataset &data,
                              tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {

    // reformat the trees for faster access (parent links are needed by the
    // iterative traversal in tree_shap_indep)
    Node *node_trees = new Node[trees.tree_limit * trees.max_nodes];
    for (unsigned i = 0; i < trees.tree_limit; ++i) {
        Node *node_tree = node_trees + i * trees.max_nodes;
        for (unsigned j = 0; j < trees.max_nodes; ++j) {
            const unsigned en_ind = i * trees.max_nodes + j;
            node_tree[j].cl = trees.children_left[en_ind];
            node_tree[j].cr = trees.children_right[en_ind];
            node_tree[j].cd = trees.children_default[en_ind];
            if (j == 0) {
                node_tree[j].pnode = 0;
            }
            if (trees.children_left[en_ind] >= 0) { // relies on all unused entries having negative values in them
                node_tree[trees.children_left[en_ind]].pnode = j;
                node_tree[trees.children_left[en_ind]].pfeat = trees.features[en_ind];
            }
            if (trees.children_right[en_ind] >= 0) { // relies on all unused entries having negative values in them
                node_tree[trees.children_right[en_ind]].pnode = j;
                node_tree[trees.children_right[en_ind]].pfeat = trees.features[en_ind];
            }

            node_tree[j].thres = trees.thresholds[en_ind];
            node_tree[j].feat = trees.features[en_ind];
        }
    }

    // preallocate arrays needed by the algorithm
    float *pos_lst = new float[trees.max_nodes];
    float *neg_lst = new float[trees.max_nodes];
    int *node_stack = new int[(unsigned) trees.max_depth];
    signed short *feat_hist = new signed short[data.M];
    tfloat *tmp_out_contribs = new tfloat[(data.M + 1)];

    // precompute all the weight coefficients
    // NOTE(review): the n == 0 row divides by zero (inf), but those entries
    // are only read with N >= 1 in tree_shap_indep — confirm
    float *memoized_weights = new float[(trees.max_depth+1) * (trees.max_depth+1)];
    for (unsigned n = 0; n <= trees.max_depth; ++n) {
        for (unsigned m = 0; m <= trees.max_depth; ++m) {
            memoized_weights[n + trees.max_depth * m] = 1.0 / (n * bin_coeff(n-1, m));
        }
    }

    // compute the explanations for each sample
    tfloat *instance_out_contribs;
    tfloat rescale_factor = 1.0;
    tfloat margin_x = 0;
    tfloat margin_r = 0;
    time_t start_time = time(NULL);
    tfloat last_print = 0;
    for (unsigned oind = 0; oind < trees.num_outputs; ++oind) {
        // set the values in the reformatted tree to the current output index
        for (unsigned i = 0; i < trees.tree_limit; ++i) {
            Node *node_tree = node_trees + i * trees.max_nodes;
            for (unsigned j = 0; j < trees.max_nodes; ++j) {
                const unsigned en_ind = i * trees.max_nodes + j;
                node_tree[j].value = trees.values[en_ind * trees.num_outputs + oind];
            }
        }

        // loop over all the samples
        for (unsigned i = 0; i < data.num_X; ++i) {
            const tfloat *x = data.X + i * data.M;
            const bool *x_missing = data.X_missing + i * data.M;
            instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
            const tfloat y_i = data.y == NULL ? 0 : data.y[i];

            print_progress_bar(last_print, start_time, oind * data.num_X + i, data.num_X * trees.num_outputs);

            // compute the model's margin output for x
            if (transform != NULL) {
                margin_x = trees.base_offset[oind];
                for (unsigned k = 0; k < trees.tree_limit; ++k) {
                    margin_x += tree_predict(k, trees, x, x_missing)[oind];
                }
            }

            for (unsigned j = 0; j < data.num_R; ++j) {
                const tfloat *r = data.R + j * data.M;
                const bool *r_missing = data.R_missing + j * data.M;
                std::fill_n(tmp_out_contribs, (data.M + 1), 0);

                // compute the model's margin output for r
                if (transform != NULL) {
                    margin_r = trees.base_offset[oind];
                    for (unsigned k = 0; k < trees.tree_limit; ++k) {
                        margin_r += tree_predict(k, trees, r, r_missing)[oind];
                    }
                }

                // explain this (x, r) pair with every tree in the ensemble
                for (unsigned k = 0; k < trees.tree_limit; ++k) {
                    tree_shap_indep(
                        trees.max_depth, data.M, trees.max_nodes, x, x_missing, r, r_missing,
                        tmp_out_contribs, pos_lst, neg_lst, feat_hist, memoized_weights,
                        node_stack, node_trees + k * trees.max_nodes
                    );
                }

                // compute the rescale factor (finite-difference slope of the
                // transform between the two margins)
                if (transform != NULL) {
                    if (margin_x == margin_r) {
                        rescale_factor = 1.0;
                    } else {
                        rescale_factor = (*transform)(margin_x, y_i) - (*transform)(margin_r, y_i);
                        rescale_factor /= margin_x - margin_r;
                    }
                }

                // add the effect of the current reference to our running total
                // this is where we can do per reference scaling for non-linear transformations
                for (unsigned k = 0; k < data.M; ++k) {
                    instance_out_contribs[k * trees.num_outputs + oind] += tmp_out_contribs[k] * rescale_factor;
                }

                // Add the base offset
                if (transform != NULL) {
                    instance_out_contribs[data.M * trees.num_outputs + oind] += (*transform)(trees.base_offset[oind] + tmp_out_contribs[data.M], 0);
                } else {
                    instance_out_contribs[data.M * trees.num_outputs + oind] += trees.base_offset[oind] + tmp_out_contribs[data.M];
                }
            }

            // average the results over all the references.
            for (unsigned j = 0; j < (data.M + 1); ++j) {
                instance_out_contribs[j * trees.num_outputs + oind] /= data.num_R;
            }

            // apply the base offset to the bias term
            // for (unsigned j = 0; j < trees.num_outputs; ++j) {
            //     instance_out_contribs[data.M * trees.num_outputs + j] += (*transform)(trees.base_offset[j], 0);
            // }
        }
    }

    delete[] tmp_out_contribs;
    delete[] node_trees;
    delete[] pos_lst;
    delete[] neg_lst;
    delete[] node_stack;
    delete[] feat_hist;
    delete[] memoized_weights;
}
1263
+
1264
+
1265
+ /**
1266
+ * This runs Tree SHAP with a per tree path conditional dependence assumption.
1267
+ */
1268
+ inline void dense_tree_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
1269
+ tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
1270
+ tfloat *instance_out_contribs;
1271
+ TreeEnsemble tree;
1272
+ ExplanationDataset instance;
1273
+
1274
+ // build explanation for each sample
1275
+ for (unsigned i = 0; i < data.num_X; ++i) {
1276
+ instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
1277
+ data.get_x_instance(instance, i);
1278
+
1279
+ // aggregate the effect of explaining each tree
1280
+ // (this works because of the linearity property of Shapley values)
1281
+ for (unsigned j = 0; j < trees.tree_limit; ++j) {
1282
+ trees.get_tree(tree, j);
1283
+ tree_shap(tree, instance, instance_out_contribs, 0, 0);
1284
+ }
1285
+
1286
+ // apply the base offset to the bias term
1287
+ for (unsigned j = 0; j < trees.num_outputs; ++j) {
1288
+ instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
1289
+ }
1290
+ }
1291
+ }
1292
+
1293
+ // phi = np.zeros((self._current_X.shape[1] + 1, self._current_X.shape[1] + 1, self.n_outputs))
1294
+ // phi_diag = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
1295
+ // for t in range(self.tree_limit):
1296
+ // self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_diag)
1297
+ // for j in self.trees[t].unique_features:
1298
+ // phi_on = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
1299
+ // phi_off = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
1300
+ // self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_on, 1, j)
1301
+ // self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_off, -1, j)
1302
+ // phi[j] += np.true_divide(np.subtract(phi_on,phi_off),2.0)
1303
+ // phi_diag[j] -= np.sum(np.true_divide(np.subtract(phi_on,phi_off),2.0))
1304
+ // for j in range(self._current_X.shape[1]+1):
1305
+ // phi[j][j] = phi_diag[j]
1306
+ // phi /= self.tree_limit
1307
+ // return phi
1308
+
1309
/**
 * Tree SHAP interaction values under the per-tree path-dependent assumption.
 *
 * For each sample the output is an (M+1) x (M+1) x num_outputs tensor flattened into
 * out_contribs: row i holds the interaction of feature i with every feature (plus the
 * bias slot), and the diagonal holds each feature's main effect. The `transform`
 * argument is accepted for signature uniformity with the other handlers but is not
 * applied here (interaction values are computed on the raw margin).
 */
inline void dense_tree_interactions_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
                                                   tfloat *out_contribs,
                                                   tfloat transform(const tfloat, const tfloat)) {

    // build a list of all the unique features in each tree
    // (a tree can reference at most min(M, max_nodes) distinct features)
    int amount_of_unique_features = min(data.M, trees.max_nodes);
    int *unique_features = new int[trees.tree_limit * amount_of_unique_features];
    std::fill(unique_features, unique_features + trees.tree_limit * amount_of_unique_features, -1);
    for (unsigned j = 0; j < trees.tree_limit; ++j) {
        const int *features_row = trees.features + j * trees.max_nodes;
        int *unique_features_row = unique_features + j * amount_of_unique_features;
        for (unsigned k = 0; k < trees.max_nodes; ++k) {
            // linear scan of the per-tree list: stop on a duplicate, or append
            // into the first unused (-1) slot
            for (unsigned l = 0; l < amount_of_unique_features; ++l) {
                if (features_row[k] == unique_features_row[l]) break;
                if (unique_features_row[l] < 0) {
                    unique_features_row[l] = features_row[k];
                    break;
                }
            }
        }
    }

    // build an interaction explanation for each sample
    tfloat *instance_out_contribs;
    TreeEnsemble tree;
    ExplanationDataset instance;
    // one "row" of the interaction tensor: (M+1) feature slots x num_outputs
    const unsigned contrib_row_size = (data.M + 1) * trees.num_outputs;
    tfloat *diag_contribs = new tfloat[contrib_row_size];
    tfloat *on_contribs = new tfloat[contrib_row_size];
    tfloat *off_contribs = new tfloat[contrib_row_size];
    for (unsigned i = 0; i < data.num_X; ++i) {
        // each sample owns (M+1) rows of contrib_row_size entries
        instance_out_contribs = out_contribs + i * (data.M + 1) * contrib_row_size;
        data.get_x_instance(instance, i);

        // aggregate the effect of explaining each tree
        // (this works because of the linearity property of Shapley values)
        std::fill(diag_contribs, diag_contribs + contrib_row_size, 0);
        for (unsigned j = 0; j < trees.tree_limit; ++j) {
            trees.get_tree(tree, j);
            // unconditioned SHAP values accumulate into the (future) diagonal
            tree_shap(tree, instance, diag_contribs, 0, 0);

            const int *unique_features_row = unique_features + j * amount_of_unique_features;
            for (unsigned k = 0; k < amount_of_unique_features; ++k) {
                const int ind = unique_features_row[k];
                if (ind < 0) break; // < 0 means we have seen all the features for this tree

                // compute the shap value with this feature held on and off
                std::fill(on_contribs, on_contribs + contrib_row_size, 0);
                std::fill(off_contribs, off_contribs + contrib_row_size, 0);
                tree_shap(tree, instance, on_contribs, 1, ind);
                tree_shap(tree, instance, off_contribs, -1, ind);

                // save the difference between on and off as the interaction value;
                // subtracting it from the diagonal keeps the totals additive
                for (unsigned l = 0; l < contrib_row_size; ++l) {
                    const tfloat val = (on_contribs[l] - off_contribs[l]) / 2;
                    instance_out_contribs[ind * contrib_row_size + l] += val;
                    diag_contribs[l] -= val;
                }
            }
        }

        // set the diagonal (main effects left over after removing all interactions)
        for (unsigned j = 0; j < data.M + 1; ++j) {
            const unsigned offset = j * contrib_row_size + j * trees.num_outputs;
            for (unsigned k = 0; k < trees.num_outputs; ++k) {
                instance_out_contribs[offset + k] = diag_contribs[j * trees.num_outputs + k];
            }
        }

        // apply the base offset to the bias term (the [M][M] diagonal cell)
        const unsigned last_ind = (data.M * (data.M + 1) + data.M) * trees.num_outputs;
        for (unsigned j = 0; j < trees.num_outputs; ++j) {
            instance_out_contribs[last_ind + j] += trees.base_offset[j];
        }
    }

    delete[] diag_contribs;
    delete[] on_contribs;
    delete[] off_contribs;
    delete[] unique_features;
}
1390
+
1391
+ /**
1392
+ * This runs Tree SHAP with a global path conditional dependence assumption.
1393
+ *
1394
+ * By first merging all the trees in a tree ensemble into an equivalent single tree
1395
+ * this method allows arbitrary marginal transformations and also ensures that all the
1396
+ * evaluations of the model are consistent with some training data point.
1397
+ */
1398
+ inline void dense_global_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
1399
+ tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
1400
+
1401
+ // allocate space for our new merged tree (we save enough room to totally split all samples if need be)
1402
+ TreeEnsemble merged_tree;
1403
+ merged_tree.allocate(1, (data.num_X + data.num_R) * 2, trees.num_outputs);
1404
+
1405
+ // collapse the ensemble of trees into a single tree that has the same behavior
1406
+ // for all the X and R samples in the dataset
1407
+ build_merged_tree(merged_tree, data, trees);
1408
+
1409
+ // compute the expected value and depth of the new merged tree
1410
+ compute_expectations(merged_tree);
1411
+
1412
+ // explain each sample using our new merged tree
1413
+ ExplanationDataset instance;
1414
+ tfloat *instance_out_contribs;
1415
+ for (unsigned i = 0; i < data.num_X; ++i) {
1416
+ instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
1417
+ data.get_x_instance(instance, i);
1418
+
1419
+ // since we now just have a single merged tree we can just use the tree_path_dependent algorithm
1420
+ tree_shap(merged_tree, instance, instance_out_contribs, 0, 0);
1421
+
1422
+ // apply the base offset to the bias term
1423
+ for (unsigned j = 0; j < trees.num_outputs; ++j) {
1424
+ instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
1425
+ }
1426
+ }
1427
+
1428
+ merged_tree.free();
1429
+ }
1430
+
1431
+
1432
+ /**
1433
+ * The main method for computing Tree SHAP on models using dense data.
1434
+ */
1435
+ inline void dense_tree_shap(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
1436
+ const int feature_dependence, unsigned model_transform, bool interactions) {
1437
+
1438
+ // see what transform (if any) we have
1439
+ transform_f transform = get_transform(model_transform);
1440
+
1441
+ // dispatch to the correct algorithm handler
1442
+ switch (feature_dependence) {
1443
+ case FEATURE_DEPENDENCE::independent:
1444
+ if (interactions) {
1445
+ std::cerr << "FEATURE_DEPENDENCE::independent does not support interactions!\n";
1446
+ } else dense_independent(trees, data, out_contribs, transform);
1447
+ return;
1448
+
1449
+ case FEATURE_DEPENDENCE::tree_path_dependent:
1450
+ if (interactions) dense_tree_interactions_path_dependent(trees, data, out_contribs, transform);
1451
+ else dense_tree_path_dependent(trees, data, out_contribs, transform);
1452
+ return;
1453
+
1454
+ case FEATURE_DEPENDENCE::global_path_dependent:
1455
+ if (interactions) {
1456
+ std::cerr << "FEATURE_DEPENDENCE::global_path_dependent does not support interactions!\n";
1457
+ } else dense_global_path_dependent(trees, data, out_contribs, transform);
1458
+ return;
1459
+ }
1460
+ }
lib/shap/datasets.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from urllib.request import urlretrieve
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import sklearn.datasets
7
+
8
+ import shap
9
+
10
+ github_data_url = "https://github.com/shap/shap/raw/master/data/"
11
+
12
+
13
def imagenet50(display=False, resolution=224, n_points=None):
    """ This is a set of 50 images representative of ImageNet images.

    This dataset was collected by randomly finding a working ImageNet link and then pasting the
    original ImageNet image into Google image search restricted to images licensed for reuse. A
    similar image (now with rights to reuse) was downloaded as a rough replacement for the original
    ImageNet image. The point is to have a random sample of ImageNet for use as a background
    distribution for explaining models trained on ImageNet data.

    Note that because the images are only rough replacements the labels might no longer be correct.
    """
    prefix = github_data_url + "imagenet50_"
    images = np.load(cache(f"{prefix}{resolution}x{resolution}.npy")).astype(np.float32)
    labels = np.loadtxt(cache(f"{prefix}labels.csv"))

    # optionally subsample to a fixed number of points (deterministic)
    if n_points is not None:
        images = shap.utils.sample(images, n_points, random_state=0)
        labels = shap.utils.sample(labels, n_points, random_state=0)

    return images, labels
34
+
35
+
36
def california(display=False, n_points=None):
    """ Return the california housing data in a nice package. """
    bunch = sklearn.datasets.fetch_california_housing()
    features = pd.DataFrame(data=bunch.data, columns=bunch.feature_names)
    labels = bunch.target

    if n_points is None:
        return features, labels

    # deterministic subsample keeps X and y aligned (same random_state)
    return (shap.utils.sample(features, n_points, random_state=0),
            shap.utils.sample(labels, n_points, random_state=0))
48
+
49
+
50
def linnerud(display=False, n_points=None):
    """ Return the linnerud data in a nice package (multi-target regression). """
    bunch = sklearn.datasets.load_linnerud()
    features = pd.DataFrame(bunch.data, columns=bunch.feature_names)
    targets = pd.DataFrame(bunch.target, columns=bunch.target_names)

    if n_points is None:
        return features, targets

    # deterministic subsample keeps X and y aligned (same random_state)
    return (shap.utils.sample(features, n_points, random_state=0),
            shap.utils.sample(targets, n_points, random_state=0))
62
+
63
+
64
def imdb(display=False, n_points=None):
    """ Return the classic IMDB sentiment analysis training data in a nice package.

    Full data is at: http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
    Paper to cite when using the data is: http://www.aclweb.org/anthology/P11-1015
    """
    path = cache(github_data_url + "imdb_train.txt")
    with open(path, encoding="utf-8") as handle:
        corpus = handle.readlines()

    # first 12500 reviews are negative (label 0), the rest positive (label 1)
    labels = np.ones(25000, dtype=bool)
    labels[:12500] = 0

    if n_points is not None:
        corpus = shap.utils.sample(corpus, n_points, random_state=0)
        labels = shap.utils.sample(labels, n_points, random_state=0)

    return corpus, labels
81
+
82
+
83
def communitiesandcrime(display=False, n_points=None):
    """ Predict total number of non-violent crimes per 100K popuation.

    This dataset is from the classic UCI Machine Learning repository:
    https://archive.ics.uci.edu/ml/datasets/Communities+and+Crime+Unnormalized
    """
    raw = pd.read_csv(
        cache(github_data_url + "CommViolPredUnnormalizedData.txt"),
        na_values="?"
    )

    # keep only the rows where the target (second-to-last column) is known
    valid_inds = np.where(~np.isnan(raw.iloc[:, -2]))[0]

    if n_points is not None:
        valid_inds = shap.utils.sample(valid_inds, n_points, random_state=0)

    y = np.array(raw.iloc[valid_inds, -2], dtype=float)

    # predictive features live in columns 5:-18; drop any column with missing values
    X = raw.iloc[valid_inds, 5:-18]
    complete_cols = np.where(np.isnan(X.values).sum(0) == 0)[0]
    X = X.iloc[:, complete_cols]

    return X, y
109
+
110
+
111
def diabetes(display=False, n_points=None):
    """ Return the diabetes data in a nice package. """
    bunch = sklearn.datasets.load_diabetes()
    features = pd.DataFrame(data=bunch.data, columns=bunch.feature_names)
    labels = bunch.target

    if n_points is None:
        return features, labels

    # deterministic subsample keeps X and y aligned (same random_state)
    return (shap.utils.sample(features, n_points, random_state=0),
            shap.utils.sample(labels, n_points, random_state=0))
123
+
124
+
125
def iris(display=False, n_points=None):
    """ Return the classic iris data in a nice package. """
    bunch = sklearn.datasets.load_iris()
    features = pd.DataFrame(data=bunch.data, columns=bunch.feature_names)
    labels = bunch.target

    if n_points is not None:
        features = shap.utils.sample(features, n_points, random_state=0)
        labels = shap.utils.sample(labels, n_points, random_state=0)

    # in display mode return human-readable species names instead of integer codes
    if display:
        return features, [bunch.target_names[v] for v in labels]
    return features, labels
139
+
140
+
141
def adult(display=False, n_points=None):
    """ Return the Adult census data in a nice package.

    Returns (X, y): X is a numerically encoded DataFrame (categoricals mapped to
    integer codes, "Education" and "fnlwgt" dropped) and y is a boolean array for
    income > 50K. With display=True, X keeps its original human-readable values.
    """
    # (column name, pandas dtype) pairs driving both read_csv parsing and the
    # categorical re-encoding loop below
    dtypes = [
        ("Age", "float32"), ("Workclass", "category"), ("fnlwgt", "float32"),
        ("Education", "category"), ("Education-Num", "float32"), ("Marital Status", "category"),
        ("Occupation", "category"), ("Relationship", "category"), ("Race", "category"),
        ("Sex", "category"), ("Capital Gain", "float32"), ("Capital Loss", "float32"),
        ("Hours per week", "float32"), ("Country", "category"), ("Target", "category")
    ]
    raw_data = pd.read_csv(
        cache(github_data_url + "adult.data"),
        names=[d[0] for d in dtypes],
        na_values="?",
        dtype=dict(dtypes)
    )

    # subsample BEFORE encoding so display and encoded outputs stay row-aligned
    if n_points is not None:
        raw_data = shap.utils.sample(raw_data, n_points, random_state=0)

    data = raw_data.drop(["Education"], axis=1) # redundant with Education-Num
    filt_dtypes = list(filter(lambda x: x[0] not in ["Target", "Education"], dtypes))
    # NOTE: values in the raw file carry a leading space, hence " >50K"
    data["Target"] = data["Target"] == " >50K"
    # fixed code map for Relationship so its encoding is stable across samples
    rcode = {
        "Not-in-family": 0,
        "Unmarried": 1,
        "Other-relative": 2,
        "Own-child": 3,
        "Husband": 4,
        "Wife": 5
    }
    for k, dtype in filt_dtypes:
        if dtype == "category":
            if k == "Relationship":
                data[k] = np.array([rcode[v.strip()] for v in data[k]])
            else:
                # other categoricals use pandas' own category codes
                data[k] = data[k].cat.codes

    if display:
        # human-readable features, but the boolean encoded target
        return raw_data.drop(["Education", "Target", "fnlwgt"], axis=1), data["Target"].values
    return data.drop(["Target", "fnlwgt"], axis=1), data["Target"].values
181
+
182
+
183
def nhanesi(display=False, n_points=None):
    """ A nicely packaged version of NHANES I data with surivival times as labels.
    """
    X = pd.read_csv(cache(github_data_url + "NHANESI_X.csv"), index_col=0)
    y = pd.read_csv(cache(github_data_url + "NHANESI_y.csv"), index_col=0)["y"]

    if n_points is not None:
        X = shap.utils.sample(X, n_points, random_state=0)
        y = shap.utils.sample(y, n_points, random_state=0)

    labels = np.array(y)
    if display:
        # display mode currently returns an unmodified copy of the features
        return X.copy(), labels
    return X, labels
198
+
199
+
200
def corrgroups60(display=False, n_points=1_000):
    """ Correlated Groups 60

    A simulated dataset with tight correlations among distinct groups of features.

    Returns (X, y): X is an (n_points, 60) DataFrame whose first 30 columns form
    ten groups of 3 features with pairwise sample correlation 0.99, and y is a
    linear signal on one feature per group plus small Gaussian noise.
    """
    # BUGFIX: the old code did `old_seed = np.random.seed()` which stores None
    # (np.random.seed returns None), so "restoring" the seed actually re-seeded
    # the global RNG from OS entropy. Save and restore the full RNG state instead,
    # and guard with try/finally so the caller's state survives any exception.
    rng_state = np.random.get_state()
    np.random.seed(0)
    try:
        # generate dataset with known correlation
        N, M = n_points, 60

        # set one coefficient from each group of 3 to 1
        beta = np.zeros(M)
        beta[0:30:3] = 1

        # build a correlation matrix with groups of 3 tightly correlated features
        C = np.eye(M)
        for i in range(0, 30, 3):
            C[i, i + 1] = C[i + 1, i] = 0.99
            C[i, i + 2] = C[i + 2, i] = 0.99
            C[i + 1, i + 2] = C[i + 2, i + 1] = 0.99

        def f(X):
            return np.matmul(X, beta)

        # whiten the raw sample so its empirical correlation is exactly identity,
        # then color it with C so the SAMPLE correlation matches C perfectly
        X_start = np.random.randn(N, M)
        X_centered = X_start - X_start.mean(0)
        Sigma = np.matmul(X_centered.T, X_centered) / X_centered.shape[0]
        W = np.linalg.cholesky(np.linalg.inv(Sigma)).T
        X_white = np.matmul(X_centered, W.T)
        assert np.linalg.norm(np.corrcoef(np.matmul(X_centered, W.T).T) - np.eye(M)) < 1e-6 # ensure this decorrelates the data

        # create the final data
        X = np.matmul(X_white, np.linalg.cholesky(C).T)
        y = f(X) + np.random.randn(N) * 1e-2
    finally:
        # restore the caller's global RNG state
        np.random.set_state(rng_state)

    return pd.DataFrame(X), y
243
+
244
+
245
def independentlinear60(display=False, n_points=1_000):
    """ A simulated dataset with 60 independent features and a sparse linear signal.

    Returns (X, y): X is an (n_points, 60) centered DataFrame of i.i.d. Gaussians
    and y is a linear signal on one feature out of every group of 3 (first 30
    columns) plus small Gaussian noise.
    """
    # BUGFIX: the old code did `old_seed = np.random.seed()` which stores None
    # (np.random.seed returns None), so "restoring" the seed actually re-seeded
    # the global RNG from OS entropy. Save and restore the full RNG state instead,
    # and guard with try/finally so the caller's state survives any exception.
    rng_state = np.random.get_state()
    np.random.seed(0)
    try:
        # generate dataset with known correlation
        N, M = n_points, 60

        # set one coefficient from each group of 3 to 1
        beta = np.zeros(M)
        beta[0:30:3] = 1

        def f(X):
            return np.matmul(X, beta)

        # center the sample so each column has exactly zero mean
        X_start = np.random.randn(N, M)
        X = X_start - X_start.mean(0)
        y = f(X) + np.random.randn(N) * 1e-2
    finally:
        # restore the caller's global RNG state
        np.random.set_state(rng_state)

    return pd.DataFrame(X), y
271
+
272
+
273
def a1a(n_points=None):
    """ A sparse dataset in scipy csr matrix format.
    """
    data, target = sklearn.datasets.load_svmlight_file(cache(github_data_url + 'a1a.svmlight'))

    if n_points is None:
        return data, target

    # deterministic subsample keeps X and y aligned (same random_state)
    return (shap.utils.sample(data, n_points, random_state=0),
            shap.utils.sample(target, n_points, random_state=0))
283
+
284
+
285
def rank():
    """ Ranking datasets from lightgbm repository.

    Returns (x_train, y_train, x_test, y_test, q_train, q_test) where the q_*
    arrays give per-query group sizes for the lambdarank task.
    """
    base_url = 'https://raw.githubusercontent.com/Microsoft/LightGBM/master/examples/lambdarank/'
    x_train, y_train = sklearn.datasets.load_svmlight_file(cache(base_url + 'rank.train'))
    x_test, y_test = sklearn.datasets.load_svmlight_file(cache(base_url + 'rank.test'))
    q_train = np.loadtxt(cache(base_url + 'rank.train.query'))
    q_test = np.loadtxt(cache(base_url + 'rank.test.query'))
    return x_train, y_train, x_test, y_test, q_train, q_test
295
+
296
+
297
def cache(url, file_name=None):
    """ Loads a file from the URL and caches it locally.

    Downloads happen only on the first call for a given file name; later calls
    return the already-cached path.
    """
    target_name = os.path.basename(url) if file_name is None else file_name

    # cache files live next to this module in a "cached_data" directory
    data_dir = os.path.join(os.path.dirname(__file__), "cached_data")
    os.makedirs(data_dir, exist_ok=True)

    file_path = os.path.join(data_dir, target_name)
    if not os.path.isfile(file_path):
        urlretrieve(url, file_path)

    return file_path
lib/shap/explainers/__init__.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._additive import AdditiveExplainer
2
+ from ._deep import DeepExplainer
3
+ from ._exact import ExactExplainer
4
+ from ._gpu_tree import GPUTreeExplainer
5
+ from ._gradient import GradientExplainer
6
+ from ._kernel import KernelExplainer
7
+ from ._linear import LinearExplainer
8
+ from ._partition import PartitionExplainer
9
+ from ._permutation import PermutationExplainer
10
+ from ._sampling import SamplingExplainer
11
+ from ._tree import TreeExplainer
12
+
13
+ # Alternative legacy "short-form" aliases, which are kept here for backwards-compatibility
14
+ Additive = AdditiveExplainer
15
+ Deep = DeepExplainer
16
+ Exact = ExactExplainer
17
+ GPUTree = GPUTreeExplainer
18
+ Gradient = GradientExplainer
19
+ Kernel = KernelExplainer
20
+ Linear = LinearExplainer
21
+ Partition = PartitionExplainer
22
+ Permutation = PermutationExplainer
23
+ Sampling = SamplingExplainer
24
+ Tree = TreeExplainer
25
+
26
+ __all__ = [
27
+ "AdditiveExplainer",
28
+ "DeepExplainer",
29
+ "ExactExplainer",
30
+ "GPUTreeExplainer",
31
+ "GradientExplainer",
32
+ "KernelExplainer",
33
+ "LinearExplainer",
34
+ "PartitionExplainer",
35
+ "PermutationExplainer",
36
+ "SamplingExplainer",
37
+ "TreeExplainer",
38
+ ]
lib/shap/explainers/_additive.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from ..utils import MaskedModel, safe_isinstance
4
+ from ._explainer import Explainer
5
+
6
+
7
+ class AdditiveExplainer(Explainer):
8
+ """ Computes SHAP values for generalized additive models.
9
+
10
+ This assumes that the model only has first-order effects. Extending this to
11
+ second- and third-order effects is future work (if you apply this to those models right now
12
+ you will get incorrect answers that fail additivity).
13
+ """
14
+
15
+ def __init__(self, model, masker, link=None, feature_names=None, linearize_link=True):
16
+ """ Build an Additive explainer for the given model using the given masker object.
17
+
18
+ Parameters
19
+ ----------
20
+ model : function
21
+ A callable python object that executes the model given a set of input data samples.
22
+
23
+ masker : function or numpy.array or pandas.DataFrame
24
+ A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`.
25
+ It takes a single a binary mask and an input sample and returns a matrix of masked samples. These
26
+ masked samples are evaluated using the model function and the outputs are then averaged.
27
+ As a shortcut for the standard masking used by SHAP you can pass a background data matrix
28
+ instead of a function and that matrix will be used for masking. To use a clustering
29
+ game structure you can pass a shap.maskers.Tabular(data, hclustering=\"correlation\") object, but
30
+ note that this structure information has no effect on the explanations of additive models.
31
+ """
32
+ super().__init__(model, masker, feature_names=feature_names, linearize_link=linearize_link)
33
+
34
+ if safe_isinstance(model, "interpret.glassbox.ExplainableBoostingClassifier"):
35
+ self.model = model.decision_function
36
+
37
+ if self.masker is None:
38
+ self._expected_value = model.intercept_
39
+ # num_features = len(model.additive_terms_)
40
+
41
+ # fm = MaskedModel(self.model, self.masker, self.link, np.zeros(num_features))
42
+ # masks = np.ones((1, num_features), dtype=bool)
43
+ # outputs = fm(masks)
44
+ # self.model(np.zeros(num_features))
45
+ # self._zero_offset = self.model(np.zeros(num_features))#model.intercept_#outputs[0]
46
+ # self._input_offsets = np.zeros(num_features) #* self._zero_offset
47
+ raise NotImplementedError("Masker not given and we don't yet support pulling the distribution centering directly from the EBM model!")
48
+ return
49
+
50
+ # here we need to compute the offsets ourselves because we can't pull them directly from a model we know about
51
+ assert safe_isinstance(self.masker, "shap.maskers.Independent"), "The Additive explainer only supports the Tabular masker at the moment!"
52
+
53
+ # pre-compute per-feature offsets
54
+ fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, np.zeros(self.masker.shape[1]))
55
+ masks = np.ones((self.masker.shape[1]+1, self.masker.shape[1]), dtype=bool)
56
+ for i in range(1, self.masker.shape[1]+1):
57
+ masks[i,i-1] = False
58
+ outputs = fm(masks)
59
+ self._zero_offset = outputs[0]
60
+ self._input_offsets = np.zeros(masker.shape[1])
61
+ for i in range(1, self.masker.shape[1]+1):
62
+ self._input_offsets[i-1] = outputs[i] - self._zero_offset
63
+
64
+ self._expected_value = self._input_offsets.sum() + self._zero_offset
65
+
66
+ def __call__(self, *args, max_evals=None, silent=False):
67
+ """ Explains the output of model(*args), where args represents one or more parallel iterable args.
68
+ """
69
+
70
+ # we entirely rely on the general call implementation, we override just to remove **kwargs
71
+ # from the function signature
72
+ return super().__call__(*args, max_evals=max_evals, silent=silent)
73
+
74
+ @staticmethod
75
+ def supports_model_with_masker(model, masker):
76
+ """ Determines if this explainer can handle the given model.
77
+
78
+ This is an abstract static method meant to be implemented by each subclass.
79
+ """
80
+ if safe_isinstance(model, "interpret.glassbox.ExplainableBoostingClassifier"):
81
+ if model.interactions != 0:
82
+ raise NotImplementedError("Need to add support for interaction effects!")
83
+ return True
84
+
85
+ return False
86
+
87
+ def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
88
+ """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
89
+ """
90
+
91
+ x = row_args[0]
92
+ inputs = np.zeros((len(x), len(x)))
93
+ for i in range(len(x)):
94
+ inputs[i,i] = x[i]
95
+
96
+ phi = self.model(inputs) - self._zero_offset - self._input_offsets
97
+
98
+ return {
99
+ "values": phi,
100
+ "expected_values": self._expected_value,
101
+ "mask_shapes": [a.shape for a in row_args],
102
+ "main_effects": phi,
103
+ "clustering": getattr(self.masker, "clustering", None)
104
+ }
105
+
106
+ # class AdditiveExplainer(Explainer):
107
+ # """ Computes SHAP values for generalized additive models.
108
+
109
+ # This assumes that the model only has first order effects. Extending this to
110
+ # 2nd and third order effects is future work (if you apply this to those models right now
111
+ # you will get incorrect answers that fail additivity).
112
+
113
+ # Parameters
114
+ # ----------
115
+ # model : function or ExplainableBoostingRegressor
116
+ # User supplied additive model either as either a function or a model object.
117
+
118
+ # data : numpy.array, pandas.DataFrame
119
+ # The background dataset to use for computing conditional expectations.
120
+ # feature_perturbation : "interventional"
121
+ # Only the standard interventional SHAP values are supported by AdditiveExplainer right now.
122
+ # """
123
+
124
+ # def __init__(self, model, data, feature_perturbation="interventional"):
125
+ # if feature_perturbation != "interventional":
126
+ # raise Exception("Unsupported type of feature_perturbation provided: " + feature_perturbation)
127
+
128
+ # if safe_isinstance(model, "interpret.glassbox.ebm.ebm.ExplainableBoostingRegressor"):
129
+ # self.f = model.predict
130
+ # elif callable(model):
131
+ # self.f = model
132
+ # else:
133
+ # raise ValueError("The passed model must be a recognized object or a function!")
134
+
135
+ # # convert dataframes
136
+ # if isinstance(data, (pd.Series, pd.DataFrame)):
137
+ # data = data.values
138
+ # self.data = data
139
+
140
+ # # compute the expected value of the model output
141
+ # self.expected_value = self.f(data).mean()
142
+
143
+ # # pre-compute per-feature offsets
144
+ # tmp = np.zeros(data.shape)
145
+ # self._zero_offset = self.f(tmp).mean()
146
+ # self._feature_offset = np.zeros(data.shape[1])
147
+ # for i in range(data.shape[1]):
148
+ # tmp[:,i] = data[:,i]
149
+ # self._feature_offset[i] = self.f(tmp).mean() - self._zero_offset
150
+ # tmp[:,i] = 0
151
+
152
+
153
+ # def shap_values(self, X):
154
+ # """ Estimate the SHAP values for a set of samples.
155
+
156
+ # Parameters
157
+ # ----------
158
+ # X : numpy.array, pandas.DataFrame or scipy.csr_matrix
159
+ # A matrix of samples (# samples x # features) on which to explain the model's output.
160
+
161
+ # Returns
162
+ # -------
163
+ # For models with a single output this returns a matrix of SHAP values
164
+ # (# samples x # features). Each row sums to the difference between the model output for that
165
+ # sample and the expected value of the model output (which is stored as expected_value
166
+ # attribute of the explainer).
167
+ # """
168
+
169
+ # # convert dataframes
170
+ # if isinstance(X, (pd.Series, pd.DataFrame)):
171
+ # X = X.values
172
+
173
+ # # assert isinstance(X, np.ndarray), "Unknown instance type: " + str(type(X))
174
+ # assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"
175
+
176
+ # # convert dataframes
177
+ # if isinstance(X, (pd.Series, pd.DataFrame)):
178
+ # X = X.values
179
+
180
+ # phi = np.zeros(X.shape)
181
+ # tmp = np.zeros(X.shape)
182
+ # for i in range(X.shape[1]):
183
+ # tmp[:,i] = X[:,i]
184
+ # phi[:,i] = self.f(tmp) - self._zero_offset - self._feature_offset[i]
185
+ # tmp[:,i] = 0
186
+
187
+ # return phi
lib/shap/explainers/_deep/__init__.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .._explainer import Explainer
2
+ from .deep_pytorch import PyTorchDeep
3
+ from .deep_tf import TFDeep
4
+
5
+
6
class DeepExplainer(Explainer):
    """ Meant to approximate SHAP values for deep learning models.

    This is an enhanced version of the DeepLIFT algorithm (Deep SHAP) where, similar to Kernel SHAP, we
    approximate the conditional expectations of SHAP values using a selection of background samples.
    Lundberg and Lee, NIPS 2017 showed that the per node attribution rules in DeepLIFT (Shrikumar,
    Greenside, and Kundaje, arXiv 2017) can be chosen to approximate Shapley values. By integrating
    over many background samples Deep estimates approximate SHAP values such that they sum
    up to the difference between the expected model output on the passed background samples and the
    current model output (f(x) - E[f(x)]).

    Examples
    --------
    See :ref:`Deep Explainer Examples <deep_explainer_examples>`
    """

    def __init__(self, model, data, session=None, learning_phase_flags=None):
        """ An explainer object for a differentiable model using a given background dataset.

        Note that the complexity of the method scales linearly with the number of background data
        samples. Passing the entire training dataset as `data` will give very accurate expected
        values, but be unreasonably expensive. The variance of the expectation estimates scale by
        roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate,
        and 1000 samples a very good estimate of the expected values.

        Parameters
        ----------
        model : if framework == 'tensorflow', (input : [tf.Tensor], output : tf.Tensor)
            A pair of TensorFlow tensors (or a list and a tensor) that specifies the input and
            output of the model to be explained. Note that SHAP values are specific to a single
            output value, so the output tf.Tensor should be a single dimensional output (,1).

            if framework == 'pytorch', an nn.Module object (model), or a tuple (model, layer),
            where both are nn.Module objects
            The model is an nn.Module object which takes as input a tensor (or list of tensors) of
            shape data, and returns a single dimensional output.
            If the input is a tuple, the returned shap values will be for the input of the
            layer argument. layer must be a layer in the model, i.e. model.conv2

        data :
            if framework == 'tensorflow': [numpy.array] or [pandas.DataFrame]
            if framework == 'pytorch': [torch.tensor]
            The background dataset to use for integrating out features. Deep integrates
            over these samples. The data passed here must match the input tensors given in the
            first argument. Note that since these samples are integrated over for each sample you
            should only something like 100 or 1000 random background samples, not the whole training
            dataset.

        if framework == 'tensorflow':

        session : None or tensorflow.Session
            The TensorFlow session that has the model we are explaining. If None is passed then
            we do our best to find the right session, first looking for a keras session, then
            falling back to the default TensorFlow session.

        learning_phase_flags : None or list of tensors
            If you have your own custom learning phase flags pass them here. When explaining a prediction
            we need to ensure we are not in training mode, since this changes the behavior of ops like
            batch norm or dropout. If None is passed then we look for tensors in the graph that look like
            learning phase flags (this works for Keras models). Note that we assume all the flags should
            have a value of False during predictions (and hence explanations).
        """
        # first, we need to find the framework
        # NOTE(review): detection is duck-typed — anything exposing a working
        # named_parameters() is treated as PyTorch; everything else (including
        # any object that raises from that call) falls back to TensorFlow.
        if type(model) is tuple:
            a, b = model
            try:
                a.named_parameters()
                framework = 'pytorch'
            except Exception:
                framework = 'tensorflow'
        else:
            try:
                model.named_parameters()
                framework = 'pytorch'
            except Exception:
                framework = 'tensorflow'

        # delegate to the framework-specific implementation; session and
        # learning_phase_flags are only meaningful for TensorFlow models
        if framework == 'tensorflow':
            self.explainer = TFDeep(model, data, session, learning_phase_flags)
        elif framework == 'pytorch':
            self.explainer = PyTorchDeep(model, data)

        self.expected_value = self.explainer.expected_value
        self.explainer.framework = framework

    def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
        """ Return approximate SHAP values for the model applied to the data given by X.

        Parameters
        ----------
        X : list,
            if framework == 'tensorflow': numpy.array, or pandas.DataFrame
            if framework == 'pytorch': torch.tensor
            A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
            explain the model's output.

        ranked_outputs : None or int
            If ranked_outputs is None then we explain all the outputs in a multi-output model. If
            ranked_outputs is a positive integer then we only explain that many of the top model
            outputs (where "top" is determined by output_rank_order). Note that this causes a pair
            of values to be returned (shap_values, indexes), where shap_values is a list of numpy
            arrays for each of the output ranks, and indexes is a matrix that indicates for each sample
            which output indexes were choses as "top".

        output_rank_order : "max", "min", or "max_abs"
            How to order the model outputs when using ranked_outputs, either by maximum, minimum, or
            maximum absolute value.

        check_additivity : bool
            Whether to validate that the SHAP values sum to the model output delta.

        Returns
        -------
        array or list
            For a models with a single output this returns a tensor of SHAP values with the same shape
            as X. For a model with multiple outputs this returns a list of SHAP value tensors, each of
            which are the same shape as X. If ranked_outputs is None then this list of tensors matches
            the number of model outputs. If ranked_outputs is a positive integer a pair is returned
            (shap_values, indexes), where shap_values is a list of tensors with a length of
            ranked_outputs, and indexes is a matrix that indicates for each sample which output indexes
            were chosen as "top".
        """
        # pure delegation — all work happens in the framework-specific explainer
        return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
lib/shap/explainers/_deep/deep_pytorch.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+ from packaging import version
5
+
6
+ from .._explainer import Explainer
7
+ from .deep_utils import _check_additivity
8
+
9
+ torch = None
10
+
11
+
12
class PyTorchDeep(Explainer):
    """DeepLIFT-style (Deep SHAP) explainer for PyTorch models.

    For each explained sample, the sample is tiled against the background
    data, a backward pass is run with patched gradient hooks (see
    ``op_handler`` in this module), and the attribution is the mean of
    multiplier * (x - background) over the background set.
    """

    def __init__(self, model, data):
        """Build the explainer from an ``nn.Module`` (or ``(model, layer)``
        tuple for interim-layer explanations) and background data.

        Parameters
        ----------
        model : nn.Module or (nn.Module, nn.Module)
            When a tuple is passed, SHAP values are computed for the inputs
            of the given interim layer rather than the model inputs.
        data : torch.Tensor or [torch.Tensor]
            Background samples to integrate over.
        """
        # try and import pytorch (lazily, module-level `torch` starts as None)
        global torch
        if torch is None:
            import torch
            if version.parse(torch.__version__) < version.parse("0.4"):
                warnings.warn("Your PyTorch version is older than 0.4 and not supported.")

        # check if we have multiple inputs; normalize `data` to a list either way
        self.multi_input = False
        if isinstance(data, list):
            self.multi_input = True
        if not isinstance(data, list):
            data = [data]
        self.data = data
        self.layer = None
        self.input_handle = None
        self.interim = False
        self.interim_inputs_shape = None
        self.expected_value = None  # to keep the DeepExplainer base happy
        if type(model) == tuple:
            self.interim = True
            model, layer = model
            model = model.eval()
            self.layer = layer
            self.add_target_handle(self.layer)

            # if we are taking an interim layer, the 'data' is going to be the input
            # of the interim layer; we will capture this using a forward hook
            with torch.no_grad():
                _ = model(*data)
                interim_inputs = self.layer.target_input
                if type(interim_inputs) is tuple:
                    # this should always be true, but just to be safe
                    self.interim_inputs_shape = [i.shape for i in interim_inputs]
                else:
                    self.interim_inputs_shape = [interim_inputs.shape]
            # the shape-probing hook is no longer needed
            self.target_handle.remove()
            del self.layer.target_input
        self.model = model.eval()

        self.multi_output = False
        self.num_outputs = 1
        with torch.no_grad():
            outputs = model(*data)

            # also get the device everything is running on
            self.device = outputs.device
            if outputs.shape[1] > 1:
                self.multi_output = True
                self.num_outputs = outputs.shape[1]
            # E[f(x)] over the background data, one value per output
            self.expected_value = outputs.mean(0).cpu().numpy()

    def add_target_handle(self, layer):
        # forward hook that captures the interim layer's input (graph-attached)
        input_handle = layer.register_forward_hook(get_target_input)
        self.target_handle = input_handle

    def add_handles(self, model, forward_handle, backward_handle):
        """
        Add handles to all non-container layers in the model.
        Recursively for non-container layers
        """
        handles_list = []
        model_children = list(model.children())
        if model_children:
            for child in model_children:
                handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
        else:  # leaves
            handles_list.append(model.register_forward_hook(forward_handle))
            handles_list.append(model.register_full_backward_hook(backward_handle))
        return handles_list

    def remove_attributes(self, model):
        """
        Removes the x and y attributes which were added by the forward handles
        Recursively searches for non-container layers
        """
        for child in model.children():
            if 'nn.modules.container' in str(type(child)):
                self.remove_attributes(child)
            else:
                try:
                    del child.x
                except AttributeError:
                    pass
                try:
                    del child.y
                except AttributeError:
                    pass

    def gradient(self, idx, inputs):
        """Gradient of output column ``idx`` w.r.t. ``inputs`` (or w.r.t. the
        interim layer inputs when ``self.interim`` is set), as numpy arrays.
        """
        self.model.zero_grad()
        X = [x.requires_grad_() for x in inputs]
        outputs = self.model(*X)
        # list of per-sample scalars for the selected output column
        selected = [val for val in outputs[:, idx]]
        grads = []
        if self.interim:
            interim_inputs = self.layer.target_input
            for idx, input in enumerate(interim_inputs):
                # retain the graph while more inputs remain to differentiate
                grad = torch.autograd.grad(selected, input,
                                           retain_graph=True if idx + 1 < len(interim_inputs) else None,
                                           allow_unused=True)[0]
                if grad is not None:
                    grad = grad.cpu().numpy()
                else:
                    # unused input: zero gradient of matching shape
                    grad = torch.zeros_like(X[idx]).cpu().numpy()
                grads.append(grad)
            del self.layer.target_input
            return grads, [i.detach().cpu().numpy() for i in interim_inputs]
        else:
            for idx, x in enumerate(X):
                grad = torch.autograd.grad(selected, x,
                                           retain_graph=True if idx + 1 < len(X) else None,
                                           allow_unused=True)[0]
                if grad is not None:
                    grad = grad.cpu().numpy()
                else:
                    grad = torch.zeros_like(X[idx]).cpu().numpy()
                grads.append(grad)
            return grads

    def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True):
        """Compute SHAP values for the samples in ``X``.

        See ``DeepExplainer.shap_values`` for the parameter contract.
        """
        # X ~ self.model_input
        # X_data ~ self.data

        # check if we have multiple inputs
        if not self.multi_input:
            assert not isinstance(X, list), "Expected a single tensor model input!"
            X = [X]
        else:
            assert isinstance(X, list), "Expected a list of model inputs!"

        X = [x.detach().to(self.device) for x in X]

        model_output_values = None

        if ranked_outputs is not None and self.multi_output:
            with torch.no_grad():
                model_output_values = self.model(*X)
            # rank and determine the model outputs that we will explain
            if output_rank_order == "max":
                _, model_output_ranks = torch.sort(model_output_values, descending=True)
            elif output_rank_order == "min":
                _, model_output_ranks = torch.sort(model_output_values, descending=False)
            elif output_rank_order == "max_abs":
                _, model_output_ranks = torch.sort(torch.abs(model_output_values), descending=True)
            else:
                emsg = "output_rank_order must be max, min, or max_abs!"
                raise ValueError(emsg)
            model_output_ranks = model_output_ranks[:, :ranked_outputs]
        else:
            # explain every output in its natural order
            model_output_ranks = (torch.ones((X[0].shape[0], self.num_outputs)).int() *
                                  torch.arange(0, self.num_outputs).int())

        # add the gradient handles (forward saves x/y, backward applies DeepLIFT rules)
        handles = self.add_handles(self.model, add_interim_values, deeplift_grad)
        if self.interim:
            self.add_target_handle(self.layer)

        # compute the attributions
        output_phis = []
        for i in range(model_output_ranks.shape[1]):
            phis = []
            if self.interim:
                for k in range(len(self.interim_inputs_shape)):
                    phis.append(np.zeros((X[0].shape[0], ) + self.interim_inputs_shape[k][1: ]))
            else:
                for k in range(len(X)):
                    phis.append(np.zeros(X[k].shape))
            for j in range(X[0].shape[0]):
                # tile the inputs to line up with the background data samples
                tiled_X = [X[t][j:j + 1].repeat(
                                   (self.data[t].shape[0],) + tuple([1 for k in range(len(X[t].shape) - 1)])) for t
                           in range(len(X))]
                # stacked batch: first half = tiled sample, second half = background
                joint_x = [torch.cat((tiled_X[t], self.data[t]), dim=0) for t in range(len(X))]
                # run attribution computation graph
                feature_ind = model_output_ranks[j, i]
                sample_phis = self.gradient(feature_ind, joint_x)
                # assign the attributions to the right part of the output arrays
                if self.interim:
                    sample_phis, output = sample_phis
                    x, data = [], []
                    for k in range(len(output)):
                        x_temp, data_temp = np.split(output[k], 2)
                        x.append(x_temp)
                        data.append(data_temp)
                    for t in range(len(self.interim_inputs_shape)):
                        phis[t][j] = (sample_phis[t][self.data[t].shape[0]:] * (x[t] - data[t])).mean(0)
                else:
                    for t in range(len(X)):
                        phis[t][j] = (torch.from_numpy(sample_phis[t][self.data[t].shape[0]:]).to(self.device) * (X[t][j: j + 1] - self.data[t])).cpu().detach().numpy().mean(0)
            output_phis.append(phis[0] if not self.multi_input else phis)
        # cleanup; remove all gradient handles
        for handle in handles:
            handle.remove()
        self.remove_attributes(self.model)
        if self.interim:
            self.target_handle.remove()

        # check that the SHAP values sum up to the model output
        if check_additivity:
            if model_output_values is None:
                with torch.no_grad():
                    model_output_values = self.model(*X)

            _check_additivity(self, model_output_values.cpu(), output_phis)

        if not self.multi_output:
            return output_phis[0]
        elif ranked_outputs is not None:
            return output_phis, model_output_ranks
        else:
            return output_phis
227
+
228
+ # Module hooks
229
+
230
+
231
def deeplift_grad(module, grad_input, grad_output):
    """Backward hook that applies the DeepLIFT multiplier rule for ``module``.

    The rule is looked up by class name in ``op_handler``. Passthrough and
    linear ops keep their ordinary gradient; all other registered handlers
    replace it. Unregistered module types trigger a warning and keep their
    ordinary gradient unchanged.
    """
    mtype = module.__class__.__name__
    if mtype not in op_handler:
        warnings.warn(f'unrecognized nn.Module: {mtype}')
        return grad_input
    handler = op_handler[mtype]
    if handler.__name__ in ('passthrough', 'linear_1d'):
        # the ordinary gradient is already the correct multiplier
        return grad_input
    return handler(module, grad_input, grad_output)
244
+
245
+
246
def add_interim_values(module, input, output):
    """The forward hook used to save interim tensors, detached
    from the graph. Used to calculate the multipliers
    """
    # clear any stale tensors left over from a previous forward pass
    try:
        del module.x
    except AttributeError:
        pass
    try:
        del module.y
    except AttributeError:
        pass
    module_type = module.__class__.__name__
    if module_type in op_handler:
        func_name = op_handler[module_type].__name__
        # First, check for cases where we don't need to save the x and y tensors
        if func_name == 'passthrough':
            pass
        else:
            # check only the 0th input varies
            for i in range(len(input)):
                if i != 0 and type(output) is tuple:
                    assert input[i] == output[i], "Only the 0th input may vary!"
            # if a new method is added, it must be added here too. This ensures tensors
            # are only saved if necessary
            if func_name in ['maxpool', 'nonlinear_1d']:
                # only save tensors if necessary
                if type(input) is tuple:
                    setattr(module, 'x', torch.nn.Parameter(input[0].detach()))
                else:
                    setattr(module, 'x', torch.nn.Parameter(input.detach()))
                if type(output) is tuple:
                    setattr(module, 'y', torch.nn.Parameter(output[0].detach()))
                else:
                    setattr(module, 'y', torch.nn.Parameter(output.detach()))
281
+
282
+
283
def get_target_input(module, input, output):
    """Forward hook that stores the module's input — still attached to the
    autograd graph — as ``module.target_input``, replacing any stale value.
    Used if we want to explain the interim outputs of a model.
    """
    if hasattr(module, 'target_input'):
        del module.target_input
    module.target_input = input
292
+
293
+
294
def passthrough(module, grad_input, grad_output):
    """Identity DeepLIFT rule: returning None tells PyTorch to keep the
    module's ordinary gradient unchanged.
    """
    return None
297
+
298
+
299
def maxpool(module, grad_input, grad_output):
    """DeepLIFT backward rule for max-pooling layers.

    ``module.x`` / ``module.y`` are the pooling input/output saved by
    ``add_interim_values`` for the stacked batch (first half: explained
    sample copies, second half: background references — the halves come from
    the ``torch.cat`` in ``PyTorchDeep.shap_values``).
    """
    pool_to_unpool = {
        'MaxPool1d': torch.nn.functional.max_unpool1d,
        'MaxPool2d': torch.nn.functional.max_unpool2d,
        'MaxPool3d': torch.nn.functional.max_unpool3d
    }
    pool_to_function = {
        'MaxPool1d': torch.nn.functional.max_pool1d,
        'MaxPool2d': torch.nn.functional.max_pool2d,
        'MaxPool3d': torch.nn.functional.max_pool3d
    }
    # difference between the sample half and the reference half of the input
    delta_in = module.x[: int(module.x.shape[0] / 2)] - module.x[int(module.x.shape[0] / 2):]
    dup0 = [2] + [1 for i in delta_in.shape[1:]]
    # we also need to check if the output is a tuple
    y, ref_output = torch.chunk(module.y, 2)
    cross_max = torch.max(y, ref_output)
    diffs = torch.cat([cross_max - ref_output, y - cross_max], 0)

    # all of this just to unpool the outputs
    with torch.no_grad():
        # re-run pooling with return_indices=True to recover the argmax indices
        _, indices = pool_to_function[module.__class__.__name__](
            module.x, module.kernel_size, module.stride, module.padding,
            module.dilation, module.ceil_mode, True)
        xmax_pos, rmax_pos = torch.chunk(pool_to_unpool[module.__class__.__name__](
            grad_output[0] * diffs, indices, module.kernel_size, module.stride,
            module.padding, list(module.x.shape)), 2)

    grad_input = [None for _ in grad_input]
    # guard against division by near-zero delta_in (numerical stability)
    grad_input[0] = torch.where(torch.abs(delta_in) < 1e-7, torch.zeros_like(delta_in),
                                (xmax_pos + rmax_pos) / delta_in).repeat(dup0)

    return tuple(grad_input)
331
+
332
+
333
def linear_1d(module, grad_input, grad_output):
    """DeepLIFT rule for ops linear in their varying input: the ordinary
    gradient is already the exact multiplier, so return None to keep it.
    """
    return None
336
+
337
+
338
def nonlinear_1d(module, grad_input, grad_output):
    """DeepLIFT backward rule for elementwise nonlinearities: replaces the
    local gradient with the secant slope delta_out / delta_in between each
    sample and its background reference.

    ``module.x`` / ``module.y`` are the saved pre-/post-activation tensors
    for the stacked batch (first half: samples, second half: references),
    stored by ``add_interim_values``.
    """
    delta_out = module.y[: int(module.y.shape[0] / 2)] - module.y[int(module.y.shape[0] / 2):]

    delta_in = module.x[: int(module.x.shape[0] / 2)] - module.x[int(module.x.shape[0] / 2):]
    dup0 = [2] + [1 for i in delta_in.shape[1:]]
    # handles numerical instabilities where delta_in is very small by
    # just taking the gradient in those cases
    grads = [None for _ in grad_input]
    grads[0] = torch.where(torch.abs(delta_in.repeat(dup0)) < 1e-6, grad_input[0],
                           grad_output[0] * (delta_out / delta_in).repeat(dup0))
    return tuple(grads)
349
+
350
+
351
# Maps an nn.Module class name to the backward-hook rule applied by
# deeplift_grad (and consulted by add_interim_values to decide what to save).
op_handler = {}

# passthrough ops, where we make no change to the gradient
for _op in ('Dropout3d', 'Dropout2d', 'Dropout', 'AlphaDropout'):
    op_handler[_op] = passthrough

# ops that are linear in their varying input: the ordinary gradient is exact
for _op in ('Conv1d', 'Conv2d', 'Conv3d',
            'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
            'Linear',
            'AvgPool1d', 'AvgPool2d', 'AvgPool3d',
            'AdaptiveAvgPool1d', 'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d',
            'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'):
    op_handler[_op] = linear_1d

# elementwise nonlinearities, handled with the secant-slope rule
for _op in ('LeakyReLU', 'ReLU', 'ELU', 'Sigmoid', 'Tanh', 'Softplus', 'Softmax'):
    op_handler[_op] = nonlinear_1d

# max pooling needs its own unpooling-based rule
for _op in ('MaxPool1d', 'MaxPool2d', 'MaxPool3d'):
    op_handler[_op] = maxpool
lib/shap/explainers/_deep/deep_tf.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+ from packaging import version
5
+
6
+ from ...utils._exceptions import DimensionError
7
+ from .._explainer import Explainer
8
+ from ..tf_utils import _get_graph, _get_model_inputs, _get_model_output, _get_session
9
+ from .deep_utils import _check_additivity
10
+
11
+ tf = None
12
+ tf_ops = None
13
+ tf_backprop = None
14
+ tf_execute = None
15
+ tf_gradients_impl = None
16
+
17
def custom_record_gradient(op_name, inputs, attrs, results):
    """ This overrides tensorflow.python.eager.backprop._record_gradient.

    ResourceGather ops take an integer index tensor, which TensorFlow would
    prune from gradient backprop. We temporarily "lie" and mark that tensor
    as float32 so the node is kept, record the gradient under the
    "shap_"-prefixed op name, then restore the true integer dtype.
    """
    spoofed_dtype = op_name == "ResourceGather" and inputs[1].dtype == tf.int32
    if spoofed_dtype:
        inputs[1].__dict__["_dtype"] = tf.float32

    try:
        out = tf_backprop._record_gradient("shap_"+op_name, inputs, attrs, results)
    except AttributeError:
        # newer TF versions expose record_gradient without the underscore
        out = tf_backprop.record_gradient("shap_"+op_name, inputs, attrs, results)

    if spoofed_dtype:
        inputs[1].__dict__["_dtype"] = tf.int32

    return out
+ return out
39
+
40
+ class TFDeep(Explainer):
41
+ """
42
+ Using tf.gradients to implement the backpropagation was
43
+ inspired by the gradient-based implementation approach proposed by Ancona et al, ICLR 2018. Note
44
+ that this package does not currently use the reveal-cancel rule for ReLu units proposed in DeepLIFT.
45
+ """
46
+
47
+ def __init__(self, model, data, session=None, learning_phase_flags=None):
48
+ """ An explainer object for a deep model using a given background dataset.
49
+
50
+ Note that the complexity of the method scales linearly with the number of background data
51
+ samples. Passing the entire training dataset as `data` will give very accurate expected
52
+ values, but will be computationally expensive. The variance of the expectation estimates scales by
53
+ roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate,
54
+ and 1000 samples a very good estimate of the expected values.
55
+
56
+ Parameters
57
+ ----------
58
+ model : tf.keras.Model or (input : [tf.Operation], output : tf.Operation)
59
+ A keras model object or a pair of TensorFlow operations (or a list and an op) that
60
+ specifies the input and output of the model to be explained. Note that SHAP values
61
+ are specific to a single output value, so you get an explanation for each element of
62
+ the output tensor (which must be a flat rank one vector).
63
+
64
+ data : [numpy.array] or [pandas.DataFrame] or function
65
+ The background dataset to use for integrating out features. DeepExplainer integrates
66
+ over all these samples for each explanation. The data passed here must match the input
67
+ operations given to the model. If a function is supplied, it must be a function that
68
+ takes a particular input example and generates the background dataset for that example
69
+ session : None or tensorflow.Session
70
+ The TensorFlow session that has the model we are explaining. If None is passed then
71
+ we do our best to find the right session, first looking for a keras session, then
72
+ falling back to the default TensorFlow session.
73
+
74
+ learning_phase_flags : None or list of tensors
75
+ If you have your own custom learning phase flags pass them here. When explaining a prediction
76
+ we need to ensure we are not in training mode, since this changes the behavior of ops like
77
+ batch norm or dropout. If None is passed then we look for tensors in the graph that look like
78
+ learning phase flags (this works for Keras models). Note that we assume all the flags should
79
+ have a value of False during predictions (and hence explanations).
80
+
81
+ """
82
+ # try to import tensorflow
83
+ global tf, tf_ops, tf_backprop, tf_execute, tf_gradients_impl
84
+ if tf is None:
85
+ from tensorflow.python.eager import backprop as tf_backprop
86
+ from tensorflow.python.eager import execute as tf_execute
87
+ from tensorflow.python.framework import (
88
+ ops as tf_ops,
89
+ )
90
+ from tensorflow.python.ops import (
91
+ gradients_impl as tf_gradients_impl,
92
+ )
93
+ if not hasattr(tf_gradients_impl, "_IsBackpropagatable"):
94
+ from tensorflow.python.ops import gradients_util as tf_gradients_impl
95
+ import tensorflow as tf
96
+ if version.parse(tf.__version__) < version.parse("1.4.0"):
97
+ warnings.warn("Your TensorFlow version is older than 1.4.0 and not supported.")
98
+
99
+ if version.parse(tf.__version__) >= version.parse("2.4.0"):
100
+ warnings.warn("Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.")
101
+
102
+ # determine the model inputs and outputs
103
+ self.model_inputs = _get_model_inputs(model)
104
+ self.model_output = _get_model_output(model)
105
+ assert not isinstance(self.model_output, list), "The model output to be explained must be a single tensor!"
106
+ assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"
107
+ self.multi_output = True
108
+ if len(self.model_output.shape) == 1:
109
+ self.multi_output = False
110
+
111
+ if tf.executing_eagerly():
112
+ if isinstance(model, tuple) or isinstance(model, list):
113
+ assert len(model) == 2, "When a tuple is passed it must be of the form (inputs, outputs)"
114
+ from tensorflow.keras import Model
115
+ self.model = Model(model[0], model[1])
116
+ else:
117
+ self.model = model
118
+
119
+ # check if we have multiple inputs
120
+ self.multi_input = True
121
+ if not isinstance(self.model_inputs, list) or len(self.model_inputs) == 1:
122
+ self.multi_input = False
123
+ if not isinstance(self.model_inputs, list):
124
+ self.model_inputs = [self.model_inputs]
125
+ if not isinstance(data, list) and (hasattr(data, "__call__") is False):
126
+ data = [data]
127
+ self.data = data
128
+
129
+ self._vinputs = {} # used to track what op inputs depends on the model inputs
130
+ self.orig_grads = {}
131
+
132
+ if not tf.executing_eagerly():
133
+ self.session = _get_session(session)
134
+
135
+ self.graph = _get_graph(self)
136
+
137
+ # if no learning phase flags were given we go looking for them
138
+ # ...this will catch the one that keras uses
139
+ # we need to find them since we want to make sure learning phase flags are set to False
140
+ if learning_phase_flags is None:
141
+ self.learning_phase_ops = []
142
+ for op in self.graph.get_operations():
143
+ if 'learning_phase' in op.name and op.type == "Const" and len(op.outputs[0].shape) == 0:
144
+ if op.outputs[0].dtype == tf.bool:
145
+ self.learning_phase_ops.append(op)
146
+ self.learning_phase_flags = [op.outputs[0] for op in self.learning_phase_ops]
147
+ else:
148
+ self.learning_phase_ops = [t.op for t in learning_phase_flags]
149
+
150
+ # save the expected output of the model
151
+ # if self.data is a function, set self.expected_value to None
152
+ if (hasattr(self.data, '__call__')):
153
+ self.expected_value = None
154
+ else:
155
+ if self.data[0].shape[0] > 5000:
156
+ warnings.warn("You have provided over 5k background samples! For better performance consider using smaller random sample.")
157
+ if not tf.executing_eagerly():
158
+ self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0)
159
+ else:
160
+ #if type(self.model)is tuple:
161
+ # self.fModel(cnn.inputs, cnn.get_layer(theNameYouWant).outputs)
162
+ self.expected_value = tf.reduce_mean(self.model(self.data), 0)
163
+
164
+ if not tf.executing_eagerly():
165
+ self._init_between_tensors(self.model_output.op, self.model_inputs)
166
+
167
+ # make a blank array that will get lazily filled in with the SHAP value computation
168
+ # graphs for each output. Lazy is important since if there are 1000 outputs and we
169
+ # only explain the top 5 it would be a waste to build graphs for the other 995
170
+ if not self.multi_output:
171
+ self.phi_symbolics = [None]
172
+ else:
173
+ noutputs = self.model_output.shape.as_list()[1]
174
+ if noutputs is not None:
175
+ self.phi_symbolics = [None for i in range(noutputs)]
176
+ else:
177
+ raise DimensionError("The model output tensor to be explained cannot have a static shape in dim 1 of None!")
178
+
179
+ def _get_model_output(self, model):
180
+ if len(model.layers[-1]._inbound_nodes) == 0:
181
+ if len(model.outputs) > 1:
182
+ warnings.warn("Only one model output supported.")
183
+ return model.outputs[0]
184
+ else:
185
+ return model.layers[-1].output
186
+
187
+ def _init_between_tensors(self, out_op, model_inputs):
188
+ # find all the operations in the graph between our inputs and outputs
189
+ tensor_blacklist = tensors_blocked_by_false(self.learning_phase_ops) # don't follow learning phase branches
190
+ dependence_breakers = [k for k in op_handlers if op_handlers[k] == break_dependence]
191
+ back_ops = backward_walk_ops(
192
+ [out_op], tensor_blacklist,
193
+ dependence_breakers
194
+ )
195
+ start_ops = []
196
+ for minput in model_inputs:
197
+ for op in minput.consumers():
198
+ start_ops.append(op)
199
+ self.between_ops = forward_walk_ops(
200
+ start_ops,
201
+ tensor_blacklist, dependence_breakers,
202
+ within_ops=back_ops
203
+ )
204
+
205
+ # note all the tensors that are on the path between the inputs and the output
206
+ self.between_tensors = {}
207
+ for op in self.between_ops:
208
+ for t in op.outputs:
209
+ self.between_tensors[t.name] = True
210
+ for t in model_inputs:
211
+ self.between_tensors[t.name] = True
212
+
213
+ # save what types are being used
214
+ self.used_types = {}
215
+ for op in self.between_ops:
216
+ self.used_types[op.type] = True
217
+
218
+ def _variable_inputs(self, op):
219
+ """ Return which inputs of this operation are variable (i.e. depend on the model inputs).
220
+ """
221
+ if op not in self._vinputs:
222
+ out = np.zeros(len(op.inputs), dtype=bool)
223
+ for i,t in enumerate(op.inputs):
224
+ out[i] = t.name in self.between_tensors
225
+ self._vinputs[op] = out
226
+ return self._vinputs[op]
227
+
228
+ def phi_symbolic(self, i):
229
+ """ Get the SHAP value computation graph for a given model output.
230
+ """
231
+ if self.phi_symbolics[i] is None:
232
+
233
+ if not tf.executing_eagerly():
234
+ def anon():
235
+ out = self.model_output[:,i] if self.multi_output else self.model_output
236
+ return tf.gradients(out, self.model_inputs)
237
+
238
+ self.phi_symbolics[i] = self.execute_with_overridden_gradients(anon)
239
+ else:
240
+ @tf.function
241
+ def grad_graph(shap_rAnD):
242
+ phase = tf.keras.backend.learning_phase()
243
+ tf.keras.backend.set_learning_phase(0)
244
+
245
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
246
+ tape.watch(shap_rAnD)
247
+ out = self.model(shap_rAnD)
248
+ if self.multi_output:
249
+ out = out[:,i]
250
+
251
+ self._init_between_tensors(out.op, shap_rAnD)
252
+ x_grad = tape.gradient(out, shap_rAnD)
253
+ tf.keras.backend.set_learning_phase(phase)
254
+ return x_grad
255
+
256
+ self.phi_symbolics[i] = grad_graph
257
+
258
+ return self.phi_symbolics[i]
259
+
260
+ def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True):
261
+ # check if we have multiple inputs
262
+ if not self.multi_input:
263
+ if isinstance(X, list) and len(X) != 1:
264
+ raise ValueError("Expected a single tensor as model input!")
265
+ elif not isinstance(X, list):
266
+ X = [X]
267
+ else:
268
+ assert isinstance(X, list), "Expected a list of model inputs!"
269
+ assert len(self.model_inputs) == len(X), "Number of model inputs (%d) does not match the number given (%d)!" % (len(self.model_inputs), len(X))
270
+
271
+ # rank and determine the model outputs that we will explain
272
+ if ranked_outputs is not None and self.multi_output:
273
+ if not tf.executing_eagerly():
274
+ model_output_values = self.run(self.model_output, self.model_inputs, X)
275
+ else:
276
+ model_output_values = self.model(X)
277
+
278
+ if output_rank_order == "max":
279
+ model_output_ranks = np.argsort(-model_output_values)
280
+ elif output_rank_order == "min":
281
+ model_output_ranks = np.argsort(model_output_values)
282
+ elif output_rank_order == "max_abs":
283
+ model_output_ranks = np.argsort(np.abs(model_output_values))
284
+ else:
285
+ emsg = "output_rank_order must be max, min, or max_abs!"
286
+ raise ValueError(emsg)
287
+ model_output_ranks = model_output_ranks[:,:ranked_outputs]
288
+ else:
289
+ model_output_ranks = np.tile(np.arange(len(self.phi_symbolics)), (X[0].shape[0], 1))
290
+
291
+ # compute the attributions
292
+ output_phis = []
293
+ for i in range(model_output_ranks.shape[1]):
294
+ phis = []
295
+ for k in range(len(X)):
296
+ phis.append(np.zeros(X[k].shape))
297
+ for j in range(X[0].shape[0]):
298
+ if (hasattr(self.data, '__call__')):
299
+ bg_data = self.data([X[t][j] for t in range(len(X))])
300
+ if not isinstance(bg_data, list):
301
+ bg_data = [bg_data]
302
+ else:
303
+ bg_data = self.data
304
+
305
+ # tile the inputs to line up with the background data samples
306
+ tiled_X = [np.tile(X[t][j:j+1], (bg_data[t].shape[0],) + tuple([1 for k in range(len(X[t].shape)-1)])) for t in range(len(X))]
307
+
308
+ # we use the first sample for the current sample and the rest for the references
309
+ joint_input = [np.concatenate([tiled_X[t], bg_data[t]], 0) for t in range(len(X))]
310
+
311
+ # run attribution computation graph
312
+ feature_ind = model_output_ranks[j,i]
313
+ sample_phis = self.run(self.phi_symbolic(feature_ind), self.model_inputs, joint_input)
314
+
315
+ # assign the attributions to the right part of the output arrays
316
+ for t in range(len(X)):
317
+ phis[t][j] = (sample_phis[t][bg_data[t].shape[0]:] * (X[t][j] - bg_data[t])).mean(0)
318
+
319
+ output_phis.append(phis[0] if not self.multi_input else phis)
320
+
321
+ # check that the SHAP values sum up to the model output
322
+ if check_additivity:
323
+ if not tf.executing_eagerly():
324
+ model_output = self.run(self.model_output, self.model_inputs, X)
325
+ else:
326
+ model_output = self.model(X)
327
+
328
+ _check_additivity(self, model_output, output_phis)
329
+
330
+ if not self.multi_output:
331
+ return output_phis[0]
332
+ elif ranked_outputs is not None:
333
+ return output_phis, model_output_ranks
334
+ else:
335
+ return output_phis
336
+
337
+ def run(self, out, model_inputs, X):
338
+ """ Runs the model while also setting the learning phase flags to False.
339
+ """
340
+ if not tf.executing_eagerly():
341
+ feed_dict = dict(zip(model_inputs, X))
342
+ for t in self.learning_phase_flags:
343
+ feed_dict[t] = False
344
+ return self.session.run(out, feed_dict)
345
+ else:
346
+ def anon():
347
+ tf_execute.record_gradient = custom_record_gradient
348
+
349
+ # build inputs that are correctly shaped, typed, and tf-wrapped
350
+ inputs = []
351
+ for i in range(len(X)):
352
+ shape = list(self.model_inputs[i].shape)
353
+ shape[0] = -1
354
+ data = X[i].reshape(shape)
355
+ v = tf.constant(data, dtype=self.model_inputs[i].dtype)
356
+ inputs.append(v)
357
+ final_out = out(inputs)
358
+ try:
359
+ tf_execute.record_gradient = tf_backprop._record_gradient
360
+ except AttributeError:
361
+ tf_execute.record_gradient = tf_backprop.record_gradient
362
+
363
+ return final_out
364
+ return self.execute_with_overridden_gradients(anon)
365
+
366
+ def custom_grad(self, op, *grads):
367
+ """ Passes a gradient op creation request to the correct handler.
368
+ """
369
+ type_name = op.type[5:] if op.type.startswith("shap_") else op.type
370
+ out = op_handlers[type_name](self, op, *grads) # we cut off the shap_ prefix before the lookup
371
+ return out
372
+
373
+ def execute_with_overridden_gradients(self, f):
374
+ # replace the gradients for all the non-linear activations
375
+ # we do this by hacking our way into the registry (TODO: find a public API for this if it exists)
376
+ reg = tf_ops._gradient_registry._registry
377
+ ops_not_in_registry = ['TensorListReserve']
378
+ # NOTE: location_tag taken from tensorflow source for None type ops
379
+ location_tag = ("UNKNOWN", "UNKNOWN", "UNKNOWN", "UNKNOWN", "UNKNOWN")
380
+ # TODO: unclear why some ops are not in the registry with TF 2.0 like TensorListReserve
381
+ for non_reg_ops in ops_not_in_registry:
382
+ reg[non_reg_ops] = {'type': None, 'location': location_tag}
383
+ for n in op_handlers:
384
+ if n in reg:
385
+ self.orig_grads[n] = reg[n]["type"]
386
+ reg["shap_"+n] = {
387
+ "type": self.custom_grad,
388
+ "location": reg[n]["location"]
389
+ }
390
+ reg[n]["type"] = self.custom_grad
391
+
392
+ # In TensorFlow 1.10 they started pruning out nodes that they think can't be backpropped
393
+ # unfortunately that includes the index of embedding layers so we disable that check here
394
+ if hasattr(tf_gradients_impl, "_IsBackpropagatable"):
395
+ orig_IsBackpropagatable = tf_gradients_impl._IsBackpropagatable
396
+ tf_gradients_impl._IsBackpropagatable = lambda tensor: True
397
+
398
+ # define the computation graph for the attribution values using a custom gradient-like computation
399
+ try:
400
+ out = f()
401
+ finally:
402
+ # reinstate the backpropagatable check
403
+ if hasattr(tf_gradients_impl, "_IsBackpropagatable"):
404
+ tf_gradients_impl._IsBackpropagatable = orig_IsBackpropagatable
405
+
406
+ # restore the original gradient definitions
407
+ for n in op_handlers:
408
+ if n in reg:
409
+ del reg["shap_"+n]
410
+ reg[n]["type"] = self.orig_grads[n]
411
+ for non_reg_ops in ops_not_in_registry:
412
+ del reg[non_reg_ops]
413
+ if not tf.executing_eagerly():
414
+ return out
415
+ else:
416
+ return [v.numpy() for v in out]
417
+
418
+ def tensors_blocked_by_false(ops):
419
+ """ Follows a set of ops assuming their value is False and find blocked Switch paths.
420
+
421
+ This is used to prune away parts of the model graph that are only used during the training
422
+ phase (like dropout, batch norm, etc.).
423
+ """
424
+ blocked = []
425
+ def recurse(op):
426
+ if op.type == "Switch":
427
+ blocked.append(op.outputs[1]) # the true path is blocked since we assume the ops we trace are False
428
+ else:
429
+ for out in op.outputs:
430
+ for c in out.consumers():
431
+ recurse(c)
432
+ for op in ops:
433
+ recurse(op)
434
+
435
+ return blocked
436
+
437
+ def backward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist):
438
+ found_ops = []
439
+ op_stack = [op for op in start_ops]
440
+ while len(op_stack) > 0:
441
+ op = op_stack.pop()
442
+ if op.type not in op_type_blacklist and op not in found_ops:
443
+ found_ops.append(op)
444
+ for input in op.inputs:
445
+ if input not in tensor_blacklist:
446
+ op_stack.append(input.op)
447
+ return found_ops
448
+
449
+ def forward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist, within_ops):
450
+ found_ops = []
451
+ op_stack = [op for op in start_ops]
452
+ while len(op_stack) > 0:
453
+ op = op_stack.pop()
454
+ if op.type not in op_type_blacklist and op in within_ops and op not in found_ops:
455
+ found_ops.append(op)
456
+ for out in op.outputs:
457
+ if out not in tensor_blacklist:
458
+ for c in out.consumers():
459
+ op_stack.append(c)
460
+ return found_ops
461
+
462
+
463
+ def softmax(explainer, op, *grads):
464
+ """ Just decompose softmax into its components and recurse, we can handle all of them :)
465
+
466
+ We assume the 'axis' is the last dimension because the TF codebase swaps the 'axis' to
467
+ the last dimension before the softmax op if 'axis' is not already the last dimension.
468
+ We also don't subtract the max before tf.exp for numerical stability since that might
469
+ mess up the attributions and it seems like TensorFlow doesn't define softmax that way
470
+ (according to the docs)
471
+ """
472
+ in0 = op.inputs[0]
473
+ in0_max = tf.reduce_max(in0, axis=-1, keepdims=True, name="in0_max")
474
+ in0_centered = in0 - in0_max
475
+ evals = tf.exp(in0_centered, name="custom_exp")
476
+ rsum = tf.reduce_sum(evals, axis=-1, keepdims=True)
477
+ div = evals / rsum
478
+
479
+ # mark these as in-between the inputs and outputs
480
+ for op in [evals.op, rsum.op, div.op, in0_centered.op]:
481
+ for t in op.outputs:
482
+ if t.name not in explainer.between_tensors:
483
+ explainer.between_tensors[t.name] = False
484
+
485
+ out = tf.gradients(div, in0_centered, grad_ys=grads[0])[0]
486
+
487
+ # remove the names we just added
488
+ for op in [evals.op, rsum.op, div.op, in0_centered.op]:
489
+ for t in op.outputs:
490
+ if explainer.between_tensors[t.name] is False:
491
+ del explainer.between_tensors[t.name]
492
+
493
+ # rescale to account for our shift by in0_max (which we did for numerical stability)
494
+ xin0,rin0 = tf.split(in0, 2)
495
+ xin0_centered,rin0_centered = tf.split(in0_centered, 2)
496
+ delta_in0 = xin0 - rin0
497
+ dup0 = [2] + [1 for i in delta_in0.shape[1:]]
498
+ return tf.where(
499
+ tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
500
+ out,
501
+ out * tf.tile((xin0_centered - rin0_centered) / delta_in0, dup0)
502
+ )
503
+
504
+ def maxpool(explainer, op, *grads):
505
+ xin0,rin0 = tf.split(op.inputs[0], 2)
506
+ xout,rout = tf.split(op.outputs[0], 2)
507
+ delta_in0 = xin0 - rin0
508
+ dup0 = [2] + [1 for i in delta_in0.shape[1:]]
509
+ cross_max = tf.maximum(xout, rout)
510
+ diffs = tf.concat([cross_max - rout, xout - cross_max], 0)
511
+ if op.type.startswith("shap_"):
512
+ op.type = op.type[5:]
513
+ xmax_pos,rmax_pos = tf.split(explainer.orig_grads[op.type](op, grads[0] * diffs), 2)
514
+ return tf.tile(tf.where(
515
+ tf.abs(delta_in0) < 1e-7,
516
+ tf.zeros_like(delta_in0),
517
+ (xmax_pos + rmax_pos) / delta_in0
518
+ ), dup0)
519
+
520
+ def gather(explainer, op, *grads):
521
+ #params = op.inputs[0]
522
+ indices = op.inputs[1]
523
+ #axis = op.inputs[2]
524
+ var = explainer._variable_inputs(op)
525
+ if var[1] and not var[0]:
526
+ assert len(indices.shape) == 2, "Only scalar indices supported right now in GatherV2!"
527
+
528
+ xin1,rin1 = tf.split(tf.cast(op.inputs[1], tf.float32), 2)
529
+ xout,rout = tf.split(op.outputs[0], 2)
530
+ dup_in1 = [2] + [1 for i in xin1.shape[1:]]
531
+ dup_out = [2] + [1 for i in xout.shape[1:]]
532
+ delta_in1_t = tf.tile(xin1 - rin1, dup_in1)
533
+ out_sum = tf.reduce_sum(grads[0] * tf.tile(xout - rout, dup_out), list(range(len(indices.shape), len(grads[0].shape))))
534
+ if op.type == "ResourceGather":
535
+ return [None, tf.where(
536
+ tf.abs(delta_in1_t) < 1e-6,
537
+ tf.zeros_like(delta_in1_t),
538
+ out_sum / delta_in1_t
539
+ )]
540
+ return [None, tf.where(
541
+ tf.abs(delta_in1_t) < 1e-6,
542
+ tf.zeros_like(delta_in1_t),
543
+ out_sum / delta_in1_t
544
+ ), None]
545
+ elif var[0] and not var[1]:
546
+ if op.type.startswith("shap_"):
547
+ op.type = op.type[5:]
548
+ return [explainer.orig_grads[op.type](op, grads[0]), None] # linear in this case
549
+ else:
550
+ raise ValueError("Axis not yet supported to be varying for gather op!")
551
+
552
+
553
+ def linearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
554
+ def handler(explainer, op, *grads):
555
+ var = explainer._variable_inputs(op)
556
+ if var[input_ind0] and not var[input_ind1]:
557
+ return linearity_1d_handler(input_ind0, explainer, op, *grads)
558
+ elif var[input_ind1] and not var[input_ind0]:
559
+ return linearity_1d_handler(input_ind1, explainer, op, *grads)
560
+ elif var[input_ind0] and var[input_ind1]:
561
+ return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
562
+ else:
563
+ return [None for _ in op.inputs] # no inputs vary, we must be hidden by a switch function
564
+ return handler
565
+
566
+ def nonlinearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
567
+ def handler(explainer, op, *grads):
568
+ var = explainer._variable_inputs(op)
569
+ if var[input_ind0] and not var[input_ind1]:
570
+ return nonlinearity_1d_handler(input_ind0, explainer, op, *grads)
571
+ elif var[input_ind1] and not var[input_ind0]:
572
+ return nonlinearity_1d_handler(input_ind1, explainer, op, *grads)
573
+ elif var[input_ind0] and var[input_ind1]:
574
+ return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
575
+ else:
576
+ return [None for _ in op.inputs] # no inputs vary, we must be hidden by a switch function
577
+ return handler
578
+
579
+ def nonlinearity_1d(input_ind):
580
+ def handler(explainer, op, *grads):
581
+ return nonlinearity_1d_handler(input_ind, explainer, op, *grads)
582
+ return handler
583
+
584
+ def nonlinearity_1d_handler(input_ind, explainer, op, *grads):
585
+ # make sure only the given input varies
586
+ op_inputs = op.inputs
587
+ if op_inputs is None:
588
+ op_inputs = op.outputs[0].op.inputs
589
+
590
+ for i in range(len(op_inputs)):
591
+ if i != input_ind:
592
+ assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
593
+
594
+ xin0, rin0 = tf.split(op_inputs[input_ind], 2)
595
+ xout, rout = tf.split(op.outputs[input_ind], 2)
596
+ delta_in0 = xin0 - rin0
597
+ if delta_in0.shape is None:
598
+ dup0 = [2, 1]
599
+ else:
600
+ dup0 = [2] + [1 for i in delta_in0.shape[1:]]
601
+ out = [None for _ in op_inputs]
602
+ if op.type.startswith("shap_"):
603
+ op.type = op.type[5:]
604
+ orig_grad = explainer.orig_grads[op.type](op, grads[0])
605
+ out[input_ind] = tf.where(
606
+ tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
607
+ orig_grad[input_ind] if len(op_inputs) > 1 else orig_grad,
608
+ grads[0] * tf.tile((xout - rout) / delta_in0, dup0)
609
+ )
610
+ return out
611
+
612
+ def nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads):
613
+ if not (input_ind0 == 0 and input_ind1 == 1):
614
+ emsg = "TODO: Can't yet handle double inputs that are not first!"
615
+ raise Exception(emsg)
616
+ xout,rout = tf.split(op.outputs[0], 2)
617
+ in0 = op.inputs[input_ind0]
618
+ in1 = op.inputs[input_ind1]
619
+ xin0,rin0 = tf.split(in0, 2)
620
+ xin1,rin1 = tf.split(in1, 2)
621
+ delta_in0 = xin0 - rin0
622
+ delta_in1 = xin1 - rin1
623
+ dup0 = [2] + [1 for i in delta_in0.shape[1:]]
624
+ out10 = op_func(xin0, rin1)
625
+ out01 = op_func(rin0, xin1)
626
+ out11,out00 = xout,rout
627
+ out0 = 0.5 * (out11 - out01 + out10 - out00)
628
+ out0 = grads[0] * tf.tile(out0 / delta_in0, dup0)
629
+ out1 = 0.5 * (out11 - out10 + out01 - out00)
630
+ out1 = grads[0] * tf.tile(out1 / delta_in1, dup0)
631
+
632
+ # Avoid divide by zero nans
633
+ out0 = tf.where(tf.abs(tf.tile(delta_in0, dup0)) < 1e-7, tf.zeros_like(out0), out0)
634
+ out1 = tf.where(tf.abs(tf.tile(delta_in1, dup0)) < 1e-7, tf.zeros_like(out1), out1)
635
+
636
+ # see if due to broadcasting our gradient shapes don't match our input shapes
637
+ if (np.any(np.array(out1.shape) != np.array(in1.shape))):
638
+ broadcast_index = np.where(np.array(out1.shape) != np.array(in1.shape))[0][0]
639
+ out1 = tf.reduce_sum(out1, axis=broadcast_index, keepdims=True)
640
+ elif (np.any(np.array(out0.shape) != np.array(in0.shape))):
641
+ broadcast_index = np.where(np.array(out0.shape) != np.array(in0.shape))[0][0]
642
+ out0 = tf.reduce_sum(out0, axis=broadcast_index, keepdims=True)
643
+
644
+ return [out0, out1]
645
+
646
+ def linearity_1d(input_ind):
647
+ def handler(explainer, op, *grads):
648
+ return linearity_1d_handler(input_ind, explainer, op, *grads)
649
+ return handler
650
+
651
+ def linearity_1d_handler(input_ind, explainer, op, *grads):
652
+ # make sure only the given input varies (negative means only that input cannot vary, and is measured from the end of the list)
653
+ for i in range(len(op.inputs)):
654
+ if i != input_ind:
655
+ assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
656
+ if op.type.startswith("shap_"):
657
+ op.type = op.type[5:]
658
+ return explainer.orig_grads[op.type](op, *grads)
659
+
660
+ def linearity_with_excluded(input_inds):
661
+ def handler(explainer, op, *grads):
662
+ return linearity_with_excluded_handler(input_inds, explainer, op, *grads)
663
+ return handler
664
+
665
+ def linearity_with_excluded_handler(input_inds, explainer, op, *grads):
666
+ # make sure the given inputs don't vary (negative is measured from the end of the list)
667
+ for i in range(len(op.inputs)):
668
+ if i in input_inds or i - len(op.inputs) in input_inds:
669
+ assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
670
+ if op.type.startswith("shap_"):
671
+ op.type = op.type[5:]
672
+ return explainer.orig_grads[op.type](op, *grads)
673
+
674
+ def passthrough(explainer, op, *grads):
675
+ if op.type.startswith("shap_"):
676
+ op.type = op.type[5:]
677
+ return explainer.orig_grads[op.type](op, *grads)
678
+
679
+ def break_dependence(explainer, op, *grads):
680
+ """ This function name is used to break attribution dependence in the graph traversal.
681
+
682
+ These operation types may be connected above input data values in the graph but their outputs
683
+ don't depend on the input values (for example they just depend on the shape).
684
+ """
685
+ return [None for _ in op.inputs]
686
+
687
+
688
+ op_handlers = {}
689
+
690
+ # ops that are always linear
691
+ op_handlers["Identity"] = passthrough
692
+ op_handlers["StridedSlice"] = passthrough
693
+ op_handlers["Squeeze"] = passthrough
694
+ op_handlers["ExpandDims"] = passthrough
695
+ op_handlers["Pack"] = passthrough
696
+ op_handlers["BiasAdd"] = passthrough
697
+ op_handlers["Unpack"] = passthrough
698
+ op_handlers["Add"] = passthrough
699
+ op_handlers["Sub"] = passthrough
700
+ op_handlers["Merge"] = passthrough
701
+ op_handlers["Sum"] = passthrough
702
+ op_handlers["Mean"] = passthrough
703
+ op_handlers["Cast"] = passthrough
704
+ op_handlers["Transpose"] = passthrough
705
+ op_handlers["Enter"] = passthrough
706
+ op_handlers["Exit"] = passthrough
707
+ op_handlers["NextIteration"] = passthrough
708
+ op_handlers["Tile"] = passthrough
709
+ op_handlers["TensorArrayScatterV3"] = passthrough
710
+ op_handlers["TensorArrayReadV3"] = passthrough
711
+ op_handlers["TensorArrayWriteV3"] = passthrough
712
+
713
+
714
+ # ops that don't pass any attributions to their inputs
715
+ op_handlers["Shape"] = break_dependence
716
+ op_handlers["RandomUniform"] = break_dependence
717
+ op_handlers["ZerosLike"] = break_dependence
718
+ #op_handlers["StopGradient"] = break_dependence # this allows us to stop attributions when we want to (like softmax re-centering)
719
+
720
+ # ops that are linear and only allow a single input to vary
721
+ op_handlers["Reshape"] = linearity_1d(0)
722
+ op_handlers["Pad"] = linearity_1d(0)
723
+ op_handlers["ReverseV2"] = linearity_1d(0)
724
+ op_handlers["ConcatV2"] = linearity_with_excluded([-1])
725
+ op_handlers["Conv2D"] = linearity_1d(0)
726
+ op_handlers["Switch"] = linearity_1d(0)
727
+ op_handlers["AvgPool"] = linearity_1d(0)
728
+ op_handlers["FusedBatchNorm"] = linearity_1d(0)
729
+
730
+ # ops that are nonlinear and only allow a single input to vary
731
+ op_handlers["Relu"] = nonlinearity_1d(0)
732
+ op_handlers["Elu"] = nonlinearity_1d(0)
733
+ op_handlers["Sigmoid"] = nonlinearity_1d(0)
734
+ op_handlers["Tanh"] = nonlinearity_1d(0)
735
+ op_handlers["Softplus"] = nonlinearity_1d(0)
736
+ op_handlers["Exp"] = nonlinearity_1d(0)
737
+ op_handlers["ClipByValue"] = nonlinearity_1d(0)
738
+ op_handlers["Rsqrt"] = nonlinearity_1d(0)
739
+ op_handlers["Square"] = nonlinearity_1d(0)
740
+ op_handlers["Max"] = nonlinearity_1d(0)
741
+
742
+ # ops that are nonlinear and allow two inputs to vary
743
+ op_handlers["SquaredDifference"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: (x - y) * (x - y))
744
+ op_handlers["Minimum"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.minimum(x, y))
745
+ op_handlers["Maximum"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.maximum(x, y))
746
+
747
+ # ops that allow up to two inputs to vary are are linear when only one input varies
748
+ op_handlers["Mul"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x * y)
749
+ op_handlers["RealDiv"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x / y)
750
+ op_handlers["MatMul"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.matmul(x, y))
751
+
752
+ # ops that need their own custom attribution functions
753
+ op_handlers["GatherV2"] = gather
754
+ op_handlers["ResourceGather"] = gather
755
+ op_handlers["MaxPool"] = maxpool
756
+ op_handlers["Softmax"] = softmax
757
+
758
+
759
+ # TODO items
760
+ # TensorArrayGatherV3
761
+ # Max
762
+ # TensorArraySizeV3
763
+ # Range
lib/shap/explainers/_deep/deep_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def _check_additivity(explainer, model_output_values, output_phis):
5
+ TOLERANCE = 1e-2
6
+
7
+ assert len(explainer.expected_value) == model_output_values.shape[1], "Length of expected values and model outputs does not match."
8
+
9
+ for t in range(len(explainer.expected_value)):
10
+ if not explainer.multi_input:
11
+ diffs = model_output_values[:, t] - explainer.expected_value[t] - output_phis[t].sum(axis=tuple(range(1, output_phis[t].ndim)))
12
+ else:
13
+ diffs = model_output_values[:, t] - explainer.expected_value[t]
14
+
15
+ for i in range(len(output_phis[t])):
16
+ diffs -= output_phis[t][i].sum(axis=tuple(range(1, output_phis[t][i].ndim)))
17
+
18
+ maxdiff = np.abs(diffs).max()
19
+
20
+ assert maxdiff < TOLERANCE, "The SHAP explanations do not sum up to the model's output! This is either because of a " \
21
+ "rounding error or because an operator in your computation graph was not fully supported. If " \
22
+ "the sum difference of %f is significant compared to the scale of your model outputs, please post " \
23
+ f"as a github issue, with a reproducible example so we can debug it. Used framework: {explainer.framework} - Max. diff: {maxdiff} - Tolerance: {TOLERANCE}"
lib/shap/explainers/_exact.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ from numba import njit
5
+
6
+ from .. import links
7
+ from ..models import Model
8
+ from ..utils import (
9
+ MaskedModel,
10
+ delta_minimization_order,
11
+ make_masks,
12
+ shapley_coefficients,
13
+ )
14
+ from ._explainer import Explainer
15
+
16
+ log = logging.getLogger('shap')
17
+
18
+
19
+ class ExactExplainer(Explainer):
20
+ """ Computes SHAP values via an optimized exact enumeration.
21
+
22
+ This works well for standard Shapley value maskers for models with less than ~15 features that vary
23
+ from the background per sample. It also works well for Owen values from hclustering structured
24
+ maskers when there are less than ~100 features that vary from the background per sample. This
25
+ explainer minimizes the number of function evaluations needed by ordering the masking sets to
26
+ minimize sequential differences. This is done using gray codes for standard Shapley values
27
+ and a greedy sorting method for hclustering structured maskers.
28
+ """
29
+
30
+ def __init__(self, model, masker, link=links.identity, linearize_link=True, feature_names=None):
31
+ """ Build an explainers.Exact object for the given model using the given masker object.
32
+
33
+ Parameters
34
+ ----------
35
+ model : function
36
+ A callable python object that executes the model given a set of input data samples.
37
+
38
+ masker : function or numpy.array or pandas.DataFrame
39
+ A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`.
40
+ It takes a single a binary mask and an input sample and returns a matrix of masked samples. These
41
+ masked samples are evaluated using the model function and the outputs are then averaged.
42
+ As a shortcut for the standard masking used by SHAP you can pass a background data matrix
43
+ instead of a function and that matrix will be used for masking. To use a clustering
44
+ game structure you can pass a shap.maskers.TabularPartitions(data) object.
45
+
46
+ link : function
47
+ The link function used to map between the output units of the model and the SHAP value units. By
48
+ default it is shap.links.identity, but shap.links.logit can be useful so that expectations are
49
+ computed in probability units while explanations remain in the (more naturally additive) log-odds
50
+ units. For more details on how link functions work see any overview of link functions for generalized
51
+ linear models.
52
+
53
+ linearize_link : bool
54
+ If we use a non-linear link function to take expectations then models that are additive with respect to that
55
+ link function for a single background sample will no longer be additive when using a background masker with
56
+ many samples. This for example means that a linear logistic regression model would have interaction effects
57
+ that arise from the non-linear changes in expectation averaging. To retain the additively of the model with
58
+ still respecting the link function we linearize the link function by default.
59
+ """ # TODO link to the link linearization paper when done
60
+ super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names)
61
+
62
+ self.model = Model(model)
63
+
64
+ if getattr(masker, "clustering", None) is not None:
65
+ self._partition_masks, self._partition_masks_inds = partition_masks(masker.clustering)
66
+ self._partition_delta_indexes = partition_delta_indexes(masker.clustering, self._partition_masks)
67
+
68
+ self._gray_code_cache = {} # used to avoid regenerating the same gray code patterns
69
+
70
+ def __call__(self, *args, max_evals=100000, main_effects=False, error_bounds=False, batch_size="auto", interactions=1, silent=False):
71
+ """ Explains the output of model(*args), where args represents one or more parallel iterators.
72
+ """
73
+
74
+ # we entirely rely on the general call implementation, we override just to remove **kwargs
75
+ # from the function signature
76
+ return super().__call__(
77
+ *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
78
+ batch_size=batch_size, interactions=interactions, silent=silent
79
+ )
80
+
81
+ def _cached_gray_codes(self, n):
82
+ if n not in self._gray_code_cache:
83
+ self._gray_code_cache[n] = gray_code_indexes(n)
84
+ return self._gray_code_cache[n]
85
+
86
    def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, interactions, silent):
        """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).

        Despite the summary above, the method actually returns a dict with keys
        "values", "expected_values", "mask_shapes", "main_effects", and "clustering".
        Two regimes are used: exhaustive gray-code enumeration when the masker has no
        clustering, and a partition-tree constrained enumeration otherwise.
        """

        # build a masked version of the model for the current input sample
        fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)

        # do the standard Shapley values
        inds = None
        if getattr(self.masker, "clustering", None) is None:

            # see which elements we actually need to perturb
            inds = fm.varying_inputs()

            # make sure we have enough evals
            if max_evals is not None and max_evals != "auto" and max_evals < 2**len(inds):
                raise ValueError(
                    f"It takes {2**len(inds)} masked evaluations to run the Exact explainer on this instance, but max_evals={max_evals}!"
                )

            # generate the masks in gray code order (so that we change the inputs as little
            # as possible while we iterate to minimize the need to re-eval when the inputs
            # don't vary from the background)
            delta_indexes = self._cached_gray_codes(len(inds))

            # map to a larger mask that includes the invariant entries
            # (delta_indexes is expressed over the varying inputs only; translate each
            # flip position back to the full feature index space, keeping noop codes as-is)
            extended_delta_indexes = np.zeros(2**len(inds), dtype=int)
            for i in range(2**len(inds)):
                if delta_indexes[i] == MaskedModel.delta_mask_noop_value:
                    extended_delta_indexes[i] = delta_indexes[i]
                else:
                    extended_delta_indexes[i] = inds[delta_indexes[i]]

            # run the model
            # note this rebinds the `outputs` parameter with the per-mask model outputs
            outputs = fm(extended_delta_indexes, zero_index=0, batch_size=batch_size)

            # Shapley values
            # Care: Need to distinguish between `True` and `1`
            # NOTE(review): interactions values other than False/True/1/2/>2 (e.g. 0)
            # fall through every branch and leave row_values undefined -- confirm callers
            # can only pass the documented values.
            if interactions is False or (interactions == 1 and interactions is not True):

                # loop over all the outputs to update the rows
                coeff = shapley_coefficients(len(inds))
                row_values = np.zeros((len(fm),) + outputs.shape[1:])
                mask = np.zeros(len(fm), dtype=bool)
                _compute_grey_code_row_values(row_values, mask, inds, outputs, coeff, extended_delta_indexes, MaskedModel.delta_mask_noop_value)

            # Shapley-Taylor interaction values
            elif interactions is True or interactions == 2:

                # loop over all the outputs to update the rows
                coeff = shapley_coefficients(len(inds))
                row_values = np.zeros((len(fm), len(fm)) + outputs.shape[1:])
                mask = np.zeros(len(fm), dtype=bool)
                _compute_grey_code_row_values_st(row_values, mask, inds, outputs, coeff, extended_delta_indexes, MaskedModel.delta_mask_noop_value)

            elif interactions > 2:
                raise NotImplementedError("Currently the Exact explainer does not support interactions higher than order 2!")

        # do a partition tree constrained version of Shapley values
        else:

            # make sure we have enough evals
            if max_evals is not None and max_evals != "auto" and max_evals < len(fm)**2:
                raise ValueError(
                    f"It takes {len(fm)**2} masked evaluations to run the Exact explainer on this instance, but max_evals={max_evals}!"
                )

            # generate the masks in a hclust order (so that we change the inputs as little
            # as possible while we iterate to minimize the need to re-eval when the inputs
            # don't vary from the background)
            delta_indexes = self._partition_delta_indexes

            # run the model
            outputs = fm(delta_indexes, batch_size=batch_size)

            # loop over each output feature
            # each feature's value is the mean on/off output difference over the
            # coalitions recorded for it in _partition_masks_inds
            row_values = np.zeros((len(fm),) + outputs.shape[1:])
            for i in range(len(fm)):
                on_outputs = outputs[self._partition_masks_inds[i][1]]
                off_outputs = outputs[self._partition_masks_inds[i][0]]
                row_values[i] = (on_outputs - off_outputs).mean(0)

        # compute the main effects if we need to
        # (always needed for order-2 interactions, where they fill the diagonal)
        main_effect_values = None
        if main_effects or interactions is True or interactions == 2:
            if inds is None:
                inds = np.arange(len(fm))
            main_effect_values = fm.main_effects(inds)
            if interactions is True or interactions == 2:
                for i in range(len(fm)):
                    row_values[i, i] = main_effect_values[i]

        return {
            "values": row_values,
            # outputs[0] corresponds to the first (all-off) mask, i.e. the base value
            "expected_values": outputs[0],
            "mask_shapes": fm.mask_shapes,
            "main_effects": main_effect_values if main_effects else None,
            "clustering": getattr(self.masker, "clustering", None)
        }
185
+
186
@njit
def _compute_grey_code_row_values(row_values, mask, inds, outputs, shapley_coeff, extended_delta_indexes, noop_code):
    """ Accumulate Shapley values in-place from outputs evaluated in gray code mask order.

    Because the 2**M coalitions are visited in gray code order, at most one bit
    (extended_delta_indexes[i]) changes per step, so the current coalition `mask`
    and its size `set_size` are maintained incrementally. Compiled with numba @njit,
    hence the loop-based style.
    """
    set_size = 0
    M = len(inds)
    for i in range(2**M):

        # update the mask
        delta_ind = extended_delta_indexes[i]
        if delta_ind != noop_code:
            mask[delta_ind] = ~mask[delta_ind]
            if mask[delta_ind]:
                set_size += 1
            else:
                set_size -= 1

        # update the output row values
        # when set_size == 0 this wraps to shapley_coeff[-1], but the "on" branch
        # below is never taken in that case, so the value is unused
        on_coeff = shapley_coeff[set_size-1]
        if set_size < M:
            # when set_size == M every feature is on and off_coeff is never read,
            # so the stale value from the previous iteration is harmless
            off_coeff = shapley_coeff[set_size]
        out = outputs[i]
        for j in inds:
            if mask[j]:
                row_values[j] += out * on_coeff
            else:
                row_values[j] -= out * off_coeff
211
+
212
@njit
def _compute_grey_code_row_values_st(row_values, mask, inds, outputs, shapley_coeff, extended_delta_indexes, noop_code):
    """ Accumulate Shapley-Taylor (order 2) interaction values in-place, gray code order.

    Same incremental mask/set-size bookkeeping as _compute_grey_code_row_values, but
    each coalition's output is distributed over every pair (j, k), symmetrically into
    row_values[j, k] and row_values[k, j]. The diagonal is filled later from main
    effects by the caller. Compiled with numba @njit.
    """
    set_size = 0
    M = len(inds)
    for i in range(2**M):

        # update the mask
        delta_ind = extended_delta_indexes[i]
        if delta_ind != noop_code:
            mask[delta_ind] = ~mask[delta_ind]
            if mask[delta_ind]:
                set_size += 1
            else:
                set_size -= 1

        # distribute the effect of this mask set over all the terms it impacts
        out = outputs[i]
        for j in range(M):
            for k in range(j+1, M):
                # the coefficient index depends on how many of (j, k) are in the coalition
                if not mask[j] and not mask[k]:
                    delta = out * shapley_coeff[set_size] # * 2
                elif (not mask[j] and mask[k]) or (mask[j] and not mask[k]):
                    delta = -out * shapley_coeff[set_size - 1] # * 2
                else: # both true
                    delta = out * shapley_coeff[set_size - 2] # * 2
                row_values[j,k] += delta
                row_values[k,j] += delta
239
+
240
def partition_delta_indexes(partition_tree, all_masks):
    """ Return a delta index encoded array of all the masks possible while following the given partition tree.

    Each consecutive pair of masks is encoded as the indexes that flip between them:
    all but the last flip are stored as negative "more flips coming" markers (-j - 1),
    the final flip as the plain index, and an unchanged mask as the noop sentinel.
    """
    current = np.zeros(all_masks.shape[1], dtype=bool)
    encoded = []
    for row in all_masks:
        flipped = np.where(current ^ row)[0]
        if len(flipped) == 0:
            # nothing changed relative to the previous mask
            encoded.append(MaskedModel.delta_mask_noop_value)
        else:
            # negative markers signal that more flips follow for this mask
            encoded.extend(-j - 1 for j in flipped[:-1])
            encoded.append(flipped[-1])
        current = row

    return np.array(encoded)
259
+
260
def partition_masks(partition_tree):
    """ Return an array of all the masks possible while following the given partition tree.

    Returns
    -------
    partition_masks : bool array, shape (num_masks, M)
        All on/off coalitions allowed by the clustering, reordered so that
        consecutive masks differ in as few features as possible.
    partition_masks_inds : list of [off_inds, on_inds] array pairs, one per feature
        For each feature, the (reordered) mask row indexes used as its off/on
        evaluation sets. Pairs have different lengths, so this is a ragged structure.
    """

    M = partition_tree.shape[0] + 1
    mask_matrix = make_masks(partition_tree)
    all_masks = []
    m00 = np.zeros(M, dtype=bool)
    # seed with the all-off and all-on coalitions, then let the recursion fill the rest
    all_masks.append(m00)
    all_masks.append(~m00)
    #inds_stack = [0,1]
    inds_lists = [[[], []] for i in range(M)]
    _partition_masks_recurse(len(partition_tree)-1, m00, 0, 1, inds_lists, mask_matrix, partition_tree, M, all_masks)

    all_masks = np.array(all_masks)

    # we resort the clustering matrix to minimize the sequential difference between the masks
    # this minimizes the number of model evaluations we need to run when the background sometimes
    # matches the foreground. We seem to average about 1.5 feature changes per mask with this
    # approach. This is not as clean as the grey code ordering, but a perfect 1 feature change
    # ordering is not possible with a clustering tree
    order = delta_minimization_order(all_masks)
    inverse_order = np.arange(len(order))[np.argsort(order)]

    # remap every recorded mask index into the reordered mask numbering
    for inds_list0,inds_list1 in inds_lists:
        for i in range(len(inds_list0)):
            inds_list0[i] = inverse_order[inds_list0[i]]
            inds_list1[i] = inverse_order[inds_list1[i]]

    # Care: inds_lists have different lengths, so partition_masks_inds is a "ragged" array. See GH #3063
    partition_masks = all_masks[order]
    partition_masks_inds = [[np.array(on), np.array(off)] for on, off in inds_lists]
    return partition_masks, partition_masks_inds
293
+
294
# TODO: this should be a jit function... which would require preallocating the inds_lists (sizes are 2**depth of that ind)
# TODO: we could also probable avoid making the masks at all and just record the deltas if we want...
def _partition_masks_recurse(index, m00, ind00, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks):
    """ Recursively enumerate the coalitions allowed by the partition tree.

    `index` is a partition-tree node (negative values are leaves, offset by M);
    `m00` is the mask with this node's whole subtree off, and `ind00`/`ind11` are the
    positions in `all_masks` of that mask and of its all-on counterpart. At each leaf
    the (off, on) mask index pair is recorded into inds_lists. New masks created on
    the way down are appended to `all_masks`.
    """
    # leaf node: record which mask pairs evaluate this feature off vs. on
    if index < 0:
        inds_lists[index + M][0].append(ind00)
        inds_lists[index + M][1].append(ind11)
        return

    # get our children indexes
    left_index = int(partition_tree[index,0] - M)
    right_index = int(partition_tree[index,1] - M)

    # build more refined masks
    m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
    m10[:] += mask_matrix[left_index+M, :]
    m01 = m00.copy()
    m01[:] += mask_matrix[right_index+M, :]

    # record the new masks we made
    ind01 = len(all_masks)
    all_masks.append(m01)
    ind10 = len(all_masks)
    all_masks.append(m10)

    # inds_stack.append(len(all_masks) - 2)
    # inds_stack.append(len(all_masks) - 1)

    # recurse left and right with both 1 (True) and 0 (False) contexts
    _partition_masks_recurse(left_index, m00, ind00, ind10, inds_lists, mask_matrix, partition_tree, M, all_masks)
    _partition_masks_recurse(right_index, m10, ind10, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks)
    _partition_masks_recurse(left_index, m01, ind01, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks)
    _partition_masks_recurse(right_index, m00, ind00, ind01, inds_lists, mask_matrix, partition_tree, M, all_masks)
326
+
327
+
328
def gray_code_masks(nbits):
    """ Produces an array of all binary patterns of size nbits in gray code order.

    Row k holds the standard reflected binary gray code of k (g = k ^ (k >> 1)),
    most significant bit in column 0, so consecutive rows differ in exactly one bit.

    This is based on code from: http://code.activestate.com/recipes/576592-gray-code-generatoriterator/
    """
    codes = np.arange(2**nbits)
    gray = codes ^ (codes >> 1)
    # unpack each gray code into its nbits binary digits, MSB first
    shifts = np.arange(nbits - 1, -1, -1)
    return ((gray[:, None] >> shifts) & 1).astype(bool)
347
+
348
def gray_code_indexes(nbits):
    """ Produces an array of which bits flip at which position.

    We assume the masks start at all zero and the noop sentinel means don't do a flip.
    This is a more efficient representation of the gray_code_masks version.
    """
    flips = np.full(2**nbits, MaskedModel.delta_mask_noop_value, dtype=int)
    state = np.zeros(nbits, dtype=bool)
    for step in range((1 << nbits) - 1):
        if step % 2 == 0:
            # even steps always toggle the last (least significant) bit
            state[-1] ^= True
            flips[step + 1] = nbits - 1
        else:
            # odd steps toggle the bit just left of the rightmost set bit
            for i in range(-1, -nbits, -1):
                if state[i]:
                    state[i - 1] ^= True
                    flips[step + 1] = nbits + (i - 1)
                    break
    return flips
lib/shap/explainers/_explainer.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import time
3
+
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scipy.sparse
7
+
8
+ from .. import explainers, links, maskers, models
9
+ from .._explanation import Explanation
10
+ from .._serializable import Deserializer, Serializable, Serializer
11
+ from ..maskers import Masker
12
+ from ..models import Model
13
+ from ..utils import safe_isinstance, show_progress
14
+ from ..utils._exceptions import InvalidAlgorithmError
15
+ from ..utils.transformers import is_transformers_lm
16
+
17
+
18
class Explainer(Serializable):
    """ Uses Shapley values to explain any machine learning model or python function.

    This is the primary explainer interface for the SHAP library. It takes any combination
    of a model and masker and returns a callable subclass object that implements
    the particular estimation algorithm that was chosen.
    """

    def __init__(self, model, masker=None, link=links.identity, algorithm="auto", output_names=None, feature_names=None, linearize_link=True,
                 seed=None, **kwargs):
        """ Build a new explainer for the passed model.

        Parameters
        ----------
        model : object or function
            User supplied function or model object that takes a dataset of samples and
            computes the output of the model for those samples.

        masker : function, numpy.array, pandas.DataFrame, tokenizer, None, or a list of these for each model input
            The function used to "mask" out hidden features of the form `masked_args = masker(*model_args, mask=mask)`.
            It takes input in the same form as the model, but for just a single sample with a binary
            mask, then returns an iterable of masked samples. These
            masked samples will then be evaluated using the model function and the outputs averaged.
            As a shortcut for the standard masking using by SHAP you can pass a background data matrix
            instead of a function and that matrix will be used for masking. Domain specific masking
            functions are available in shap such as shap.ImageMasker for images and shap.TokenMasker
            for text. In addition to determining how to replace hidden features, the masker can also
            constrain the rules of the cooperative game used to explain the model. For example
            shap.TabularMasker(data, hclustering="correlation") will enforce a hierarchical clustering
            of coalitions for the game (in this special case the attributions are known as the Owen values).

        link : function
            The link function used to map between the output units of the model and the SHAP value units. By
            default it is shap.links.identity, but shap.links.logit can be useful so that expectations are
            computed in probability units while explanations remain in the (more naturally additive) log-odds
            units. For more details on how link functions work see any overview of link functions for generalized
            linear models.

        algorithm : "auto", "permutation", "partition", "tree", or "linear"
            The algorithm used to estimate the Shapley values. There are many different algorithms that
            can be used to estimate the Shapley values (and the related value for constrained games), each
            of these algorithms have various tradeoffs and are preferable in different situations. By
            default the "auto" options attempts to make the best choice given the passed model and masker,
            but this choice can always be overridden by passing the name of a specific algorithm. The type of
            algorithm used will determine what type of subclass object is returned by this constructor, and
            you can also build those subclasses directly if you prefer or need more fine grained control over
            their options.

        output_names : None or list of strings
            The names of the model outputs. For example if the model is an image classifier, then output_names would
            be the names of all the output classes. This parameter is optional. When output_names is None then
            the Explanation objects produced by this explainer will not have any output_names, which could effect
            downstream plots.

        seed: None or int
            seed for reproducibility

        """

        self.model = model
        self.output_names = output_names
        self.feature_names = feature_names

        # wrap the incoming masker object as a shap.Masker object
        if (
            isinstance(masker, pd.DataFrame)
            or ((isinstance(masker, np.ndarray) or scipy.sparse.issparse(masker)) and len(masker.shape) == 2)
        ):
            # raw 2D background data: pick the masker matching the requested algorithm
            if algorithm == "partition":
                self.masker = maskers.Partition(masker)
            else:
                self.masker = maskers.Independent(masker)
        elif safe_isinstance(masker, ["transformers.PreTrainedTokenizer", "transformers.tokenization_utils_base.PreTrainedTokenizerBase"]):
            if is_transformers_lm(self.model):
                # auto assign text infilling if model is a transformer model with lm head
                self.masker = maskers.Text(masker, mask_token="...", collapse_mask_token=True)
            else:
                self.masker = maskers.Text(masker)
        # NOTE(review): `masker is list` / `masker is tuple` / `masker is dict` compare against
        # the *type objects* themselves and are always False for actual instances; upstream shap
        # tests `issubclass(type(masker), (list, tuple))` here -- confirm whether these branches
        # are reachable in this vendored copy.
        elif (masker is list or masker is tuple) and masker[0] is not str:
            self.masker = maskers.Composite(*masker)
        elif (masker is dict) and ("mean" in masker):
            self.masker = maskers.Independent(masker)
        elif masker is None and isinstance(self.model, models.TransformersPipeline):
            # re-enter __init__ using the pipeline's tokenizer as the masker
            return self.__init__(
                self.model, self.model.inner_model.tokenizer,
                link=link, algorithm=algorithm, output_names=output_names, feature_names=feature_names, linearize_link=linearize_link, **kwargs
            )
        else:
            self.masker = masker

        # Check for transformer pipeline objects and wrap them
        if safe_isinstance(self.model, "transformers.pipelines.Pipeline"):
            if is_transformers_lm(self.model.model):
                return self.__init__(
                    self.model.model, self.model.tokenizer if self.masker is None else self.masker,
                    link=link, algorithm=algorithm, output_names=output_names, feature_names=feature_names, linearize_link=linearize_link, **kwargs
                )
            else:
                return self.__init__(
                    models.TransformersPipeline(self.model), self.masker,
                    link=link, algorithm=algorithm, output_names=output_names, feature_names=feature_names, linearize_link=linearize_link, **kwargs
                )

        # wrap self.masker and self.model for output text explanation algorithm
        if is_transformers_lm(self.model):
            self.model = models.TeacherForcing(self.model, self.masker.tokenizer)
            self.masker = maskers.OutputComposite(self.masker, self.model.text_generate)
        elif safe_isinstance(self.model, "shap.models.TeacherForcing") and safe_isinstance(self.masker, ["shap.maskers.Text", "shap.maskers.Image"]):
            self.masker = maskers.OutputComposite(self.masker, self.model.text_generate)
        elif safe_isinstance(self.model, "shap.models.TopKLM") and safe_isinstance(self.masker, "shap.maskers.Text"):
            self.masker = maskers.FixedComposite(self.masker)

        #self._brute_force_fallback = explainers.BruteForce(self.model, self.masker)

        # validate and save the link function
        if callable(link):
            self.link = link
        else:
            raise TypeError("The passed link function needs to be callable!")
        self.linearize_link = linearize_link

        # if we are called directly (as opposed to through super()) then we convert ourselves to the subclass
        # that implements the specific algorithm that was chosen
        if self.__class__ is Explainer:

            # do automatic algorithm selection
            #from .. import explainers
            if algorithm == "auto":

                # use implementation-aware methods if possible
                if explainers.LinearExplainer.supports_model_with_masker(model, self.masker):
                    algorithm = "linear"
                elif explainers.TreeExplainer.supports_model_with_masker(model, self.masker): # TODO: check for Partition?
                    algorithm = "tree"
                elif explainers.AdditiveExplainer.supports_model_with_masker(model, self.masker):
                    algorithm = "additive"

                # otherwise use a model agnostic method
                elif callable(self.model):
                    # small feature spaces get the exact explainer, larger ones the sampling-based one
                    if issubclass(type(self.masker), maskers.Independent):
                        if self.masker.shape[1] <= 10:
                            algorithm = "exact"
                        else:
                            algorithm = "permutation"
                    elif issubclass(type(self.masker), maskers.Partition):
                        if self.masker.shape[1] <= 32:
                            algorithm = "exact"
                        else:
                            algorithm = "permutation"
                    elif (getattr(self.masker, "text_data", False) or getattr(self.masker, "image_data", False)) and hasattr(self.masker, "clustering"):
                        algorithm = "partition"
                    else:
                        algorithm = "permutation"

                # if we get here then we don't know how to handle what was given to us
                else:
                    raise TypeError("The passed model is not callable and cannot be analyzed directly with the given masker! Model: " + str(model))

            # build the right subclass
            # (rebinding __class__ converts this generic Explainer instance into the
            # chosen algorithm-specific subclass in place, then runs its __init__)
            if algorithm == "exact":
                self.__class__ = explainers.ExactExplainer
                explainers.ExactExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
            elif algorithm == "permutation":
                self.__class__ = explainers.PermutationExplainer
                explainers.PermutationExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, seed=seed, **kwargs)
            elif algorithm == "partition":
                self.__class__ = explainers.PartitionExplainer
                explainers.PartitionExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, output_names=self.output_names, **kwargs)
            elif algorithm == "tree":
                self.__class__ = explainers.TreeExplainer
                explainers.TreeExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
            elif algorithm == "additive":
                self.__class__ = explainers.AdditiveExplainer
                explainers.AdditiveExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
            elif algorithm == "linear":
                self.__class__ = explainers.LinearExplainer
                explainers.LinearExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
            elif algorithm == "deep":
                self.__class__ = explainers.DeepExplainer
                explainers.DeepExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
            else:
                raise InvalidAlgorithmError("Unknown algorithm type passed: %s!" % algorithm)


    def __call__(self, *args, max_evals="auto", main_effects=False, error_bounds=False, batch_size="auto",
                 outputs=None, silent=False, **kwargs):
        """ Explains the output of model(*args), where args is a list of parallel iterable datasets.

        Note this default version could be an abstract method that is implemented by each algorithm-specific
        subclass of Explainer. Descriptions of each subclasses' __call__ arguments
        are available in their respective doc-strings.
        """

        # if max_evals == "auto":
        #     self._brute_force_fallback

        start_time = time.time()

        # when explaining (input, target) pairs the second arg carries the target sentences
        if issubclass(type(self.masker), maskers.OutputComposite) and len(args)==2:
            self.masker.model = models.TextGeneration(target_sentences=args[1])
            args = args[:1]
        # parse our incoming arguments
        num_rows = None
        args = list(args)
        if self.feature_names is None:
            feature_names = [None for _ in range(len(args))]
        elif issubclass(type(self.feature_names[0]), (list, tuple)):
            feature_names = copy.deepcopy(self.feature_names)
        else:
            feature_names = [copy.deepcopy(self.feature_names)]
        for i in range(len(args)):

            # try and see if we can get a length from any of the for our progress bar
            if num_rows is None:
                try:
                    num_rows = len(args[i])
                except Exception:
                    pass

            # convert DataFrames to numpy arrays
            if isinstance(args[i], pd.DataFrame):
                feature_names[i] = list(args[i].columns)
                args[i] = args[i].to_numpy()

            # convert nlp Dataset objects to lists
            if safe_isinstance(args[i], "nlp.arrow_dataset.Dataset"):
                args[i] = args[i]["text"]
            elif issubclass(type(args[i]), dict) and "text" in args[i]:
                args[i] = args[i]["text"]

        if batch_size == "auto":
            if hasattr(self.masker, "default_batch_size"):
                batch_size = self.masker.default_batch_size
            else:
                batch_size = 10

        # loop over each sample, filling in the values array
        values = []
        output_indices = []
        expected_values = []
        mask_shapes = []
        # NOTE(review): this list rebinds the boolean `main_effects` parameter, so the value
        # forwarded to explain_row below is the growing results list rather than the flag the
        # caller passed -- confirm this matches the intended upstream behavior.
        main_effects = []
        hierarchical_values = []
        clustering = []
        output_names = []
        error_std = []
        if callable(getattr(self.masker, "feature_names", None)):
            feature_names = [[] for _ in range(len(args))]
        for row_args in show_progress(zip(*args), num_rows, self.__class__.__name__+" explainer", silent):
            row_result = self.explain_row(
                *row_args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
                batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
            )
            values.append(row_result.get("values", None))
            output_indices.append(row_result.get("output_indices", None))
            expected_values.append(row_result.get("expected_values", None))
            mask_shapes.append(row_result["mask_shapes"])
            main_effects.append(row_result.get("main_effects", None))
            clustering.append(row_result.get("clustering", None))
            hierarchical_values.append(row_result.get("hierarchical_values", None))
            # output_names may be returned as a callable taking the row args
            tmp = row_result.get("output_names", None)
            output_names.append(tmp(*row_args) if callable(tmp) else tmp)
            error_std.append(row_result.get("error_std", None))
            if callable(getattr(self.masker, "feature_names", None)):
                row_feature_names = self.masker.feature_names(*row_args)
                for i in range(len(row_args)):
                    feature_names[i].append(row_feature_names[i])

        # split the values up according to each input
        # (row values are flattened across inputs; mask_shapes tells us where to cut)
        arg_values = [[] for a in args]
        for i, v in enumerate(values):
            pos = 0
            for j in range(len(args)):
                mask_length = np.prod(mask_shapes[i][j])
                arg_values[j].append(values[i][pos:pos+mask_length])
                pos += mask_length

        # collapse the arrays as possible
        expected_values = pack_values(expected_values)
        main_effects = pack_values(main_effects)
        output_indices = pack_values(output_indices)
        main_effects = pack_values(main_effects)  # NOTE(review): duplicate of the call two lines above
        hierarchical_values = pack_values(hierarchical_values)
        error_std = pack_values(error_std)
        clustering = pack_values(clustering)

        # getting output labels
        ragged_outputs = False
        if output_indices is not None:
            ragged_outputs = not all(len(x) == len(output_indices[0]) for x in output_indices)
        if self.output_names is None:
            if None not in output_names:
                if not ragged_outputs:
                    sliced_labels = np.array(output_names)
                else:
                    sliced_labels = [np.array(output_names[i])[index_list] for i,index_list in enumerate(output_indices)]
            else:
                sliced_labels = None
        else:
            assert output_indices is not None, "You have passed a list for output_names but the model seems to not have multiple outputs!"
            labels = np.array(self.output_names)
            sliced_labels = [labels[index_list] for index_list in output_indices]
            if not ragged_outputs:
                sliced_labels = np.array(sliced_labels)

        # if every row has the same labels, collapse them to a single label vector
        if isinstance(sliced_labels, np.ndarray) and len(sliced_labels.shape) == 2:
            if np.all(sliced_labels[0,:] == sliced_labels):
                sliced_labels = sliced_labels[0]

        # allow the masker to transform the input data to better match the masking pattern
        # (such as breaking text into token segments)
        if hasattr(self.masker, "data_transform"):
            new_args = []
            for row_args in zip(*args):
                new_args.append([pack_values(v) for v in self.masker.data_transform(*row_args)])
            args = list(zip(*new_args))

        # build the explanation objects
        out = []
        for j, data in enumerate(args):

            # reshape the attribution values using the mask_shapes
            tmp = []
            for i, v in enumerate(arg_values[j]):
                if np.prod(mask_shapes[i][j]) != np.prod(v.shape): # see if we have multiple outputs
                    tmp.append(v.reshape(*mask_shapes[i][j], -1))
                else:
                    tmp.append(v.reshape(*mask_shapes[i][j]))
            arg_values[j] = pack_values(tmp)

            if feature_names[j] is None:
                feature_names[j] = ["Feature " + str(i) for i in range(data.shape[1])]


            # build an explanation object for this input argument
            out.append(Explanation(
                arg_values[j], expected_values, data,
                feature_names=feature_names[j], main_effects=main_effects,
                clustering=clustering,
                hierarchical_values=hierarchical_values,
                output_names=sliced_labels, # self.output_names
                error_std=error_std,
                compute_time=time.time() - start_time
                # output_shape=output_shape,
                #lower_bounds=v_min, upper_bounds=v_max
            ))
        # single-input models get a bare Explanation rather than a list
        return out[0] if len(out) == 1 else out

    def explain_row(self, *row_args, max_evals, main_effects, error_bounds, outputs, silent, **kwargs):
        """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes, main_effects).

        This is an abstract method meant to be implemented by each subclass.

        Returns
        -------
        tuple
            A tuple of (row_values, row_expected_values, row_mask_shapes), where row_values is an array of the
            attribution values for each sample, row_expected_values is an array (or single value) representing
            the expected value of the model for each sample (which is the same for all samples unless there
            are fixed inputs present, like labels when explaining the loss), and row_mask_shapes is a list
            of all the input shapes (since the row_values is always flattened),
        """

        # abstract-by-convention: the base implementation returns an empty result
        return {}

    @staticmethod
    def supports_model_with_masker(model, masker):
        """ Determines if this explainer can handle the given model.

        This is an abstract static method meant to be implemented by each subclass.
        """
        return False

    @staticmethod
    def _compute_main_effects(fm, expected_value, inds):
        """ A utility method to compute the main effects from a MaskedModel.
        """

        # mask each input on in isolation
        # (delta-encoded: turn one input on, then off again before turning on the next)
        masks = np.zeros(2*len(inds)-1, dtype=int)
        last_ind = -1
        for i in range(len(inds)):
            if i > 0:
                masks[2*i - 1] = -last_ind - 1 # turn off the last input
            masks[2*i] = inds[i] # turn on this input
            last_ind = inds[i]

        # compute the main effects for the given indexes
        main_effects = fm(masks) - expected_value

        # expand the vector to the full input size
        expanded_main_effects = np.zeros(len(fm))
        for i, ind in enumerate(inds):
            expanded_main_effects[ind] = main_effects[i]

        return expanded_main_effects

    def save(self, out_file, model_saver=".save", masker_saver=".save"):
        """ Write the explainer to the given file stream.
        """
        super().save(out_file)
        with Serializer(out_file, "shap.Explainer", version=0) as s:
            s.save("model", self.model, model_saver)
            s.save("masker", self.masker, masker_saver)
            s.save("link", self.link)

    @classmethod
    def load(cls, in_file, model_loader=Model.load, masker_loader=Masker.load, instantiate=True):
        """ Load an Explainer from the given file stream.

        Parameters
        ----------
        in_file : The file stream to load objects from.
        """
        if instantiate:
            return cls._instantiated_load(in_file, model_loader=model_loader, masker_loader=masker_loader)

        kwargs = super().load(in_file, instantiate=False)
        with Deserializer(in_file, "shap.Explainer", min_version=0, max_version=0) as s:
            kwargs["model"] = s.load("model", model_loader)
            kwargs["masker"] = s.load("masker", masker_loader)
            kwargs["link"] = s.load("link")
        return kwargs
441
+
442
def pack_values(values):
    """ Used the clean up arrays before putting them into an Explanation object.

    Returns the input unchanged when it has no length (None, scalars), None when the
    per-row results were not computed (first element is None), a regular numpy array
    when the rows are uniform, and an object-dtype numpy array when they are ragged.
    """

    # pass through anything without a length (None, scalars, etc.)
    if not hasattr(values, "__len__"):
        return values

    # fix: an empty collection used to raise IndexError on the values[0] probe below
    if len(values) == 0:
        return np.array(values)

    # collapse the values if we didn't compute them
    if values[0] is None:
        return None

    # convert to a single numpy matrix when the array is not ragged
    elif np.issubdtype(type(values[0]), np.number) or len(np.unique([len(v) for v in values])) == 1:
        return np.array(values)
    else:
        return np.array(values, dtype=object)
lib/shap/explainers/_gpu_tree.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GPU accelerated tree explanations"""
2
+ import numpy as np
3
+
4
+ from ..utils import assert_import, record_import_error
5
+ from ._tree import TreeExplainer, feature_perturbation_codes, output_transform_codes
6
+
7
+ try:
8
+ from .. import _cext_gpu
9
+ except ImportError as e:
10
+ record_import_error("cext_gpu", "cuda extension was not built during install!", e)
11
+
12
+
13
+ class GPUTreeExplainer(TreeExplainer):
14
+ """
15
+ Experimental GPU accelerated version of TreeExplainer. Currently requires source build with
16
+ cuda available and 'CUDA_PATH' environment variable defined.
17
+
18
+ Parameters
19
+ ----------
20
+ model : model object
21
+ The tree based machine learning model that we want to explain. XGBoost, LightGBM,
22
+ CatBoost, Pyspark and most tree-based scikit-learn models are supported.
23
+
24
+ data : numpy.array or pandas.DataFrame
25
+ The background dataset to use for integrating out features. This argument is optional when
26
+ feature_perturbation="tree_path_dependent", since in that case we can use the number of
27
+ training samples that went down each tree path as our background dataset (this is recorded
28
+ in the model object).
29
+
30
+ feature_perturbation : "interventional" (default) or "tree_path_dependent" (default when data=None)
31
+ Since SHAP values rely on conditional expectations we need to decide how to handle correlated
32
+ (or otherwise dependent) input features. The "interventional" approach breaks the dependencies
33
+ between features according to the rules dictated by casual inference (Janzing et al. 2019). Note
34
+ that the "interventional" option requires a background dataset and its runtime scales linearly
35
+ with the size of the background dataset you use. Anywhere from 100 to 1000 random background samples
36
+ are good sizes to use. The "tree_path_dependent" approach is to just follow the trees and use the
37
+ number of training examples that went down each leaf to represent the background distribution.
38
+ This approach does not require a background dataset and so is used by default when no background
39
+ dataset is provided.
40
+
41
+ model_output : "raw", "probability", "log_loss", or model method name
42
+ What output of the model should be explained. If "raw" then we explain the raw output of the
43
+ trees, which varies by model. For regression models "raw" is the standard output, for binary
44
+ classification in XGBoost this is the log odds ratio. If model_output is the name of a
45
+ supported prediction method on the model object then we explain the output of that model
46
+ method name. For example model_output="predict_proba" explains the result of calling
47
+ model.predict_proba. If "probability" then we explain the output of the model transformed into
48
+ probability space (note that this means the SHAP values now sum to the probability output of the
49
+ model). If "logloss" then we explain the log base e of the model loss function, so that the SHAP
50
+ values sum up to the log loss of the model for each sample. This is helpful for breaking
51
+ down model performance by feature. Currently the probability and logloss options are only
52
+ supported when
53
+ feature_dependence="independent".
54
+
55
+ Examples
56
+ --------
57
+ See `GPUTree explainer examples <https://shap.readthedocs.io/en/latest/api_examples/explainers/GPUTreeExplainer.html>`_
58
+ """
59
+
60
+ def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_additivity=True,
61
+ from_call=False):
62
+ """ Estimate the SHAP values for a set of samples.
63
+
64
+ Parameters
65
+ ----------
66
+ X : numpy.array, pandas.DataFrame or catboost.Pool (for catboost)
67
+ A matrix of samples (# samples x # features) on which to explain the model's output.
68
+
69
+ y : numpy.array
70
+ An array of label values for each sample. Used when explaining loss functions.
71
+
72
+ tree_limit : None (default) or int
73
+ Limit the number of trees used by the model. By default None means no use the limit
74
+ of the
75
+ original model, and -1 means no limit.
76
+
77
+ approximate : bool
78
+ Not supported.
79
+
80
+ check_additivity : bool
81
+ Run a validation check that the sum of the SHAP values equals the output of the
82
+ model. This
83
+ check takes only a small amount of time, and will catch potential unforeseen errors.
84
+ Note that this check only runs right now when explaining the margin of the model.
85
+
86
+ Returns
87
+ -------
88
+ array or list
89
+ For models with a single output this returns a matrix of SHAP values
90
+ (# samples x # features). Each row sums to the difference between the model output
91
+ for that
92
+ sample and the expected value of the model output (which is stored in the expected_value
93
+ attribute of the explainer when it is constant). For models with vector outputs this
94
+ returns
95
+ a list of such matrices, one for each output.
96
+ """
97
+ assert not approximate, "approximate not supported"
98
+
99
+ X, y, X_missing, flat_output, tree_limit, check_additivity = \
100
+ self._validate_inputs(X, y,
101
+ tree_limit,
102
+ check_additivity)
103
+ transform = self.model.get_transform()
104
+
105
+ # run the core algorithm using the C extension
106
+ assert_import("cext_gpu")
107
+ phi = np.zeros((X.shape[0], X.shape[1] + 1, self.model.num_outputs))
108
+ _cext_gpu.dense_tree_shap(
109
+ self.model.children_left, self.model.children_right, self.model.children_default,
110
+ self.model.features, self.model.thresholds, self.model.values,
111
+ self.model.node_sample_weight,
112
+ self.model.max_depth, X, X_missing, y, self.data, self.data_missing, tree_limit,
113
+ self.model.base_offset, phi, feature_perturbation_codes[self.feature_perturbation],
114
+ output_transform_codes[transform], False
115
+ )
116
+
117
+ out = self._get_shap_output(phi, flat_output)
118
+ if check_additivity and self.model.model_output == "raw":
119
+ self.assert_additivity(out, self.model.predict(X))
120
+
121
+ return out
122
+
123
+ def shap_interaction_values(self, X, y=None, tree_limit=None):
124
+ """ Estimate the SHAP interaction values for a set of samples.
125
+
126
+ Parameters
127
+ ----------
128
+ X : numpy.array, pandas.DataFrame or catboost.Pool (for catboost)
129
+ A matrix of samples (# samples x # features) on which to explain the model's output.
130
+
131
+ y : numpy.array
132
+ An array of label values for each sample. Used when explaining loss functions (not
133
+ yet supported).
134
+
135
+ tree_limit : None (default) or int
136
+ Limit the number of trees used by the model. By default None means no use the limit
137
+ of the
138
+ original model, and -1 means no limit.
139
+
140
+ Returns
141
+ -------
142
+ array or list
143
+ For models with a single output this returns a tensor of SHAP values
144
+ (# samples x # features x # features). The matrix (# features x # features) for each
145
+ sample sums
146
+ to the difference between the model output for that sample and the expected value of
147
+ the model output
148
+ (which is stored in the expected_value attribute of the explainer). Each row of this
149
+ matrix sums to the
150
+ SHAP value for that feature for that sample. The diagonal entries of the matrix
151
+ represent the
152
+ "main effect" of that feature on the prediction and the symmetric off-diagonal
153
+ entries represent the
154
+ interaction effects between all pairs of features for that sample. For models with
155
+ vector outputs
156
+ this returns a list of tensors, one for each output.
157
+ """
158
+
159
+ assert self.model.model_output == "raw", "Only model_output = \"raw\" is supported for " \
160
+ "SHAP interaction values right now!"
161
+ assert self.feature_perturbation != "interventional", 'feature_perturbation="interventional" is not yet supported for ' + \
162
+ 'interaction values. Use feature_perturbation="tree_path_dependent" instead.'
163
+ transform = "identity"
164
+
165
+ X, y, X_missing, flat_output, tree_limit, _ = self._validate_inputs(X, y, tree_limit,
166
+ False)
167
+ # run the core algorithm using the C extension
168
+ assert_import("cext_gpu")
169
+ phi = np.zeros((X.shape[0], X.shape[1] + 1, X.shape[1] + 1, self.model.num_outputs))
170
+ _cext_gpu.dense_tree_shap(
171
+ self.model.children_left, self.model.children_right, self.model.children_default,
172
+ self.model.features, self.model.thresholds, self.model.values,
173
+ self.model.node_sample_weight,
174
+ self.model.max_depth, X, X_missing, y, self.data, self.data_missing, tree_limit,
175
+ self.model.base_offset, phi, feature_perturbation_codes[self.feature_perturbation],
176
+ output_transform_codes[transform], True
177
+ )
178
+
179
+ return self._get_shap_interactions_output(phi, flat_output)
lib/shap/explainers/_gradient.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from packaging import version
6
+
7
+ from .._explanation import Explanation
8
+ from ..explainers._explainer import Explainer
9
+ from ..explainers.tf_utils import (
10
+ _get_graph,
11
+ _get_model_inputs,
12
+ _get_model_output,
13
+ _get_session,
14
+ )
15
+
16
+ keras = None
17
+ tf = None
18
+ torch = None
19
+
20
+
21
+ class GradientExplainer(Explainer):
22
+ """ Explains a model using expected gradients (an extension of integrated gradients).
23
+
24
+ Expected gradients an extension of the integrated gradients method (Sundararajan et al. 2017), a
25
+ feature attribution method designed for differentiable models based on an extension of Shapley
26
+ values to infinite player games (Aumann-Shapley values). Integrated gradients values are a bit
27
+ different from SHAP values, and require a single reference value to integrate from. As an adaptation
28
+ to make them approximate SHAP values, expected gradients reformulates the integral as an expectation
29
+ and combines that expectation with sampling reference values from the background dataset. This leads
30
+ to a single combined expectation of gradients that converges to attributions that sum to the
31
+ difference between the expected model output and the current output.
32
+
33
+ Examples
34
+ --------
35
+ See :ref:`Gradient Explainer Examples <gradient_explainer_examples>`
36
+ """
37
+
38
+ def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0):
39
+ """ An explainer object for a differentiable model using a given background dataset.
40
+
41
+ Parameters
42
+ ----------
43
+ model : tf.keras.Model, (input : [tf.Tensor], output : tf.Tensor), torch.nn.Module, or a tuple
44
+ (model, layer), where both are torch.nn.Module objects
45
+
46
+ For TensorFlow this can be a model object, or a pair of TensorFlow tensors (or a list and
47
+ a tensor) that specifies the input and output of the model to be explained. Note that for
48
+ TensowFlow 2 you must pass a tensorflow function, not a tuple of input/output tensors).
49
+
50
+ For PyTorch this can be a nn.Module object (model), or a tuple (model, layer), where both
51
+ are nn.Module objects. The model is an nn.Module object which takes as input a tensor
52
+ (or list of tensors) of shape data, and returns a single dimensional output. If the input
53
+ is a tuple, the returned shap values will be for the input of the layer argument. layer must
54
+ be a layer in the model, i.e. model.conv2.
55
+
56
+ data : [numpy.array] or [pandas.DataFrame] or [torch.tensor]
57
+ The background dataset to use for integrating out features. Gradient explainer integrates
58
+ over these samples. The data passed here must match the input tensors given in the
59
+ first argument. Single element lists can be passed unwrapped.
60
+ """
61
+
62
+ # first, we need to find the framework
63
+ if type(model) is tuple:
64
+ a, b = model
65
+ try:
66
+ a.named_parameters()
67
+ framework = 'pytorch'
68
+ except Exception:
69
+ framework = 'tensorflow'
70
+ else:
71
+ try:
72
+ model.named_parameters()
73
+ framework = 'pytorch'
74
+ except Exception:
75
+ framework = 'tensorflow'
76
+
77
+ if isinstance(data, pd.DataFrame):
78
+ self.features = data.columns.values
79
+ else:
80
+ self.features = None
81
+
82
+ if framework == 'tensorflow':
83
+ self.explainer = _TFGradient(model, data, session, batch_size, local_smoothing)
84
+ elif framework == 'pytorch':
85
+ self.explainer = _PyTorchGradient(model, data, batch_size, local_smoothing)
86
+
87
+ def __call__(self, X, nsamples=200):
88
+ """ Return an explanation object for the model applied to X.
89
+
90
+ Parameters
91
+ ----------
92
+ X : list,
93
+ if framework == 'tensorflow': numpy.array, or pandas.DataFrame
94
+ if framework == 'pytorch': torch.tensor
95
+ A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
96
+ explain the model's output.
97
+ nsamples : int
98
+ number of background samples
99
+ Returns
100
+ -------
101
+ shap.Explanation:
102
+ """
103
+ shap_values = self.shap_values(X, nsamples)
104
+ return Explanation(values=shap_values, data=X, feature_names=self.features)
105
+
106
+ def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False):
107
+ """ Return the values for the model applied to X.
108
+
109
+ Parameters
110
+ ----------
111
+ X : list,
112
+ if framework == 'tensorflow': numpy.array, or pandas.DataFrame
113
+ if framework == 'pytorch': torch.tensor
114
+ A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
115
+ explain the model's output.
116
+
117
+ ranked_outputs : None or int
118
+ If ranked_outputs is None then we explain all the outputs in a multi-output model. If
119
+ ranked_outputs is a positive integer then we only explain that many of the top model
120
+ outputs (where "top" is determined by output_rank_order). Note that this causes a pair
121
+ of values to be returned (shap_values, indexes), where shap_values is a list of numpy arrays
122
+ for each of the output ranks, and indexes is a matrix that tells for each sample which output
123
+ indexes were chosen as "top".
124
+
125
+ output_rank_order : "max", "min", "max_abs", or "custom"
126
+ How to order the model outputs when using ranked_outputs, either by maximum, minimum, or
127
+ maximum absolute value. If "custom" Then "ranked_outputs" contains a list of output nodes.
128
+
129
+ rseed : None or int
130
+ Seeding the randomness in shap value computation (background example choice,
131
+ interpolation between current and background example, smoothing).
132
+
133
+ Returns
134
+ -------
135
+ array or list
136
+ For a models with a single output this returns a tensor of SHAP values with the same shape
137
+ as X. For a model with multiple outputs this returns a list of SHAP value tensors, each of
138
+ which are the same shape as X. If ranked_outputs is None then this list of tensors matches
139
+ the number of model outputs. If ranked_outputs is a positive integer a pair is returned
140
+ (shap_values, indexes), where shap_values is a list of tensors with a length of
141
+ ranked_outputs, and indexes is a matrix that tells for each sample which output indexes
142
+ were chosen as "top".
143
+ """
144
+ return self.explainer.shap_values(X, nsamples, ranked_outputs, output_rank_order, rseed, return_variances)
145
+
146
+
147
+ class _TFGradient(Explainer):
148
+
149
+ def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0):
150
+
151
+ # try and import keras and tensorflow
152
+ global tf, keras
153
+ if tf is None:
154
+ import tensorflow as tf
155
+ if version.parse(tf.__version__) < version.parse("1.4.0"):
156
+ warnings.warn("Your TensorFlow version is older than 1.4.0 and not supported.")
157
+ if keras is None:
158
+ try:
159
+ from tensorflow import keras
160
+ if version.parse(keras.__version__) < version.parse("2.1.0"):
161
+ warnings.warn("Your Keras version is older than 2.1.0 and not supported.")
162
+ except Exception:
163
+ pass
164
+
165
+ # determine the model inputs and outputs
166
+ self.model = model
167
+ self.model_inputs = _get_model_inputs(model)
168
+ self.model_output = _get_model_output(model)
169
+ assert not isinstance(self.model_output, list), "The model output to be explained must be a single tensor!"
170
+ assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"
171
+ self.multi_output = True
172
+ if len(self.model_output.shape) == 1:
173
+ self.multi_output = False
174
+
175
+ # check if we have multiple inputs
176
+ self.multi_input = True
177
+ if not isinstance(self.model_inputs, list):
178
+ self.model_inputs = [self.model_inputs]
179
+ self.multi_input = len(self.model_inputs) > 1
180
+ if isinstance(data, pd.DataFrame):
181
+ data = [data.values]
182
+ if not isinstance(data, list):
183
+ data = [data]
184
+
185
+ self.data = data
186
+ self._num_vinputs = {}
187
+ self.batch_size = batch_size
188
+ self.local_smoothing = local_smoothing
189
+
190
+ if not tf.executing_eagerly():
191
+ self.session = _get_session(session)
192
+ self.graph = _get_graph(self)
193
+ # see if there is a keras operation we need to save
194
+ self.keras_phase_placeholder = None
195
+ for op in self.graph.get_operations():
196
+ if 'keras_learning_phase' in op.name:
197
+ self.keras_phase_placeholder = op.outputs[0]
198
+
199
+ # save the expected output of the model (commented out because self.data could be huge for GradientExpliner)
200
+ #self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0)
201
+
202
+ if not self.multi_output:
203
+ self.gradients = [None]
204
+ else:
205
+ self.gradients = [None for i in range(self.model_output.shape[1])]
206
+
207
+ def gradient(self, i):
208
+ global tf, keras
209
+
210
+ if self.gradients[i] is None:
211
+ if not tf.executing_eagerly():
212
+ out = self.model_output[:,i] if self.multi_output else self.model_output
213
+ self.gradients[i] = tf.gradients(out, self.model_inputs)
214
+ else:
215
+ @tf.function
216
+ def grad_graph(x):
217
+ phase = tf.keras.backend.learning_phase()
218
+ tf.keras.backend.set_learning_phase(0)
219
+
220
+ with tf.GradientTape(watch_accessed_variables=False) as tape:
221
+ tape.watch(x)
222
+ out = self.model(x)
223
+ if self.multi_output:
224
+ out = out[:,i]
225
+
226
+ x_grad = tape.gradient(out, x)
227
+
228
+ tf.keras.backend.set_learning_phase(phase)
229
+
230
+ return x_grad
231
+
232
+ self.gradients[i] = grad_graph
233
+
234
+ return self.gradients[i]
235
+
236
+ def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False):
237
+ global tf, keras
238
+
239
+ import tensorflow as tf
240
+ import tensorflow.keras as keras
241
+
242
+ # check if we have multiple inputs
243
+ if not self.multi_input:
244
+ assert not isinstance(X, list), "Expected a single tensor model input!"
245
+ X = [X]
246
+ else:
247
+ assert isinstance(X, list), "Expected a list of model inputs!"
248
+ assert len(self.model_inputs) == len(X), "Number of model inputs does not match the number given!"
249
+
250
+ # rank and determine the model outputs that we will explain
251
+ if not tf.executing_eagerly():
252
+ model_output_values = self.run(self.model_output, self.model_inputs, X)
253
+ else:
254
+ model_output_values = self.run(self.model, self.model_inputs, X)
255
+ if ranked_outputs is not None and self.multi_output:
256
+ if output_rank_order == "max":
257
+ model_output_ranks = np.argsort(-model_output_values)
258
+ elif output_rank_order == "min":
259
+ model_output_ranks = np.argsort(model_output_values)
260
+ elif output_rank_order == "max_abs":
261
+ model_output_ranks = np.argsort(np.abs(model_output_values))
262
+ elif output_rank_order == "custom":
263
+ model_output_ranks = ranked_outputs
264
+ else:
265
+ emsg = "output_rank_order must be max, min, max_abs or custom!"
266
+ raise ValueError(emsg)
267
+
268
+ if output_rank_order in ["max", "min", "max_abs"]:
269
+ model_output_ranks = model_output_ranks[:,:ranked_outputs]
270
+ else:
271
+ model_output_ranks = np.tile(np.arange(len(self.gradients)), (X[0].shape[0], 1))
272
+
273
+ # compute the attributions
274
+ output_phis = []
275
+ output_phi_vars = []
276
+ samples_input = [np.zeros((nsamples,) + X[t].shape[1:], dtype=np.float32) for t in range(len(X))]
277
+ samples_delta = [np.zeros((nsamples,) + X[t].shape[1:], dtype=np.float32) for t in range(len(X))]
278
+ # use random seed if no argument given
279
+ if rseed is None:
280
+ rseed = np.random.randint(0, 1e6)
281
+
282
+ for i in range(model_output_ranks.shape[1]):
283
+ np.random.seed(rseed) # so we get the same noise patterns for each output class
284
+ phis = []
285
+ phi_vars = []
286
+ for k in range(len(X)):
287
+ phis.append(np.zeros(X[k].shape))
288
+ phi_vars.append(np.zeros(X[k].shape))
289
+ for j in range(X[0].shape[0]):
290
+
291
+ # fill in the samples arrays
292
+ for k in range(nsamples):
293
+ rind = np.random.choice(self.data[0].shape[0])
294
+ t = np.random.uniform()
295
+ for u in range(len(X)):
296
+ if self.local_smoothing > 0:
297
+ x = X[u][j] + np.random.randn(*X[u][j].shape) * self.local_smoothing
298
+ else:
299
+ x = X[u][j]
300
+ samples_input[u][k] = t * x + (1 - t) * self.data[u][rind]
301
+ samples_delta[u][k] = x - self.data[u][rind]
302
+
303
+ # compute the gradients at all the sample points
304
+ find = model_output_ranks[j,i]
305
+ grads = []
306
+ for b in range(0, nsamples, self.batch_size):
307
+ batch = [samples_input[a][b:min(b+self.batch_size,nsamples)] for a in range(len(X))]
308
+ grads.append(self.run(self.gradient(find), self.model_inputs, batch))
309
+ grad = [np.concatenate([g[a] for g in grads], 0) for a in range(len(X))]
310
+
311
+ # assign the attributions to the right part of the output arrays
312
+ for a in range(len(X)):
313
+ samples = grad[a] * samples_delta[a]
314
+ phis[a][j] = samples.mean(0)
315
+ phi_vars[a][j] = samples.var(0) / np.sqrt(samples.shape[0]) # estimate variance of means
316
+
317
+ # TODO: this could be avoided by integrating between endpoints if no local smoothing is used
318
+ # correct the sum of the values to equal the output of the model using a linear
319
+ # regression model with priors of the coefficients equal to the estimated variances for each
320
+ # value (note that 1e-6 is designed to increase the weight of the sample and so closely
321
+ # match the correct sum)
322
+ # if False and self.local_smoothing == 0: # disabled right now to make sure it doesn't mask problems
323
+ # phis_sum = np.sum([phis[l][j].sum() for l in range(len(X))])
324
+ # phi_vars_s = np.stack([phi_vars[l][j] for l in range(len(X))], 0).flatten()
325
+ # if self.multi_output:
326
+ # sum_error = model_output_values[j,find] - phis_sum - self.expected_value[find]
327
+ # else:
328
+ # sum_error = model_output_values[j] - phis_sum - self.expected_value
329
+
330
+ # # this is a ridge regression with one sample of all ones with sum_error as the label
331
+ # # and 1/v as the ridge penalties. This simplified (and stable) form comes from the
332
+ # # Sherman-Morrison formula
333
+ # v = (phi_vars_s / phi_vars_s.max()) * 1e6
334
+ # adj = sum_error * (v - (v * v.sum()) / (1 + v.sum()))
335
+
336
+ # # add the adjustment to the output so the sum matches
337
+ # offset = 0
338
+ # for l in range(len(X)):
339
+ # s = np.prod(phis[l][j].shape)
340
+ # phis[l][j] += adj[offset:offset+s].reshape(phis[l][j].shape)
341
+ # offset += s
342
+
343
+ output_phis.append(phis[0] if not self.multi_input else phis)
344
+ output_phi_vars.append(phi_vars[0] if not self.multi_input else phi_vars)
345
+ if not self.multi_output:
346
+ if return_variances:
347
+ return output_phis[0], output_phi_vars[0]
348
+ else:
349
+ return output_phis[0]
350
+ elif ranked_outputs is not None:
351
+ if return_variances:
352
+ return output_phis, output_phi_vars, model_output_ranks
353
+ else:
354
+ return output_phis, model_output_ranks
355
+ else:
356
+ if return_variances:
357
+ return output_phis, output_phi_vars
358
+ else:
359
+ return output_phis
360
+
361
+ def run(self, out, model_inputs, X):
362
+ global tf, keras
363
+
364
+ if not tf.executing_eagerly():
365
+ feed_dict = dict(zip(model_inputs, X))
366
+ if self.keras_phase_placeholder is not None:
367
+ feed_dict[self.keras_phase_placeholder] = 0
368
+ return self.session.run(out, feed_dict)
369
+ else:
370
+ # build inputs that are correctly shaped, typed, and tf-wrapped
371
+ inputs = []
372
+ for i in range(len(X)):
373
+ shape = list(self.model_inputs[i].shape)
374
+ shape[0] = -1
375
+ v = tf.constant(X[i].reshape(shape), dtype=self.model_inputs[i].dtype)
376
+ inputs.append(v)
377
+ return out(inputs)
378
+
379
+
380
+ class _PyTorchGradient(Explainer):
381
+
382
+ def __init__(self, model, data, batch_size=50, local_smoothing=0):
383
+
384
+ # try and import pytorch
385
+ global torch
386
+ if torch is None:
387
+ import torch
388
+ if version.parse(torch.__version__) < version.parse("0.4"):
389
+ warnings.warn("Your PyTorch version is older than 0.4 and not supported.")
390
+
391
+ # check if we have multiple inputs
392
+ self.multi_input = False
393
+ if isinstance(data, list):
394
+ self.multi_input = True
395
+ if not isinstance(data, list):
396
+ data = [data]
397
+
398
+ # for consistency, the method signature calls for data as the model input.
399
+ # However, within this class, self.model_inputs is the input (i.e. the data passed by the user)
400
+ # and self.data is the background data for the layer we want to assign importances to. If this layer is
401
+ # the input, then self.data = self.model_inputs
402
+ self.model_inputs = data
403
+ self.batch_size = batch_size
404
+ self.local_smoothing = local_smoothing
405
+
406
+ self.layer = None
407
+ self.input_handle = None
408
+ self.interim = False
409
+ if type(model) == tuple:
410
+ self.interim = True
411
+ model, layer = model
412
+ model = model.eval()
413
+ self.add_handles(layer)
414
+ self.layer = layer
415
+
416
+ # now, if we are taking an interim layer, the 'data' is going to be the input
417
+ # of the interim layer; we will capture this using a forward hook
418
+ with torch.no_grad():
419
+ _ = model(*data)
420
+ interim_inputs = self.layer.target_input
421
+ if type(interim_inputs) is tuple:
422
+ # this should always be true, but just to be safe
423
+ self.data = [i.clone().detach() for i in interim_inputs]
424
+ else:
425
+ self.data = [interim_inputs.clone().detach()]
426
+ else:
427
+ self.data = data
428
+ self.model = model.eval()
429
+
430
+ multi_output = False
431
+ outputs = self.model(*self.model_inputs)
432
+ if len(outputs.shape) > 1 and outputs.shape[1] > 1:
433
+ multi_output = True
434
+ self.multi_output = multi_output
435
+
436
+ if not self.multi_output:
437
+ self.gradients = [None]
438
+ else:
439
+ self.gradients = [None for i in range(outputs.shape[1])]
440
+
441
+ def gradient(self, idx, inputs):
442
+ self.model.zero_grad()
443
+ X = [x.requires_grad_() for x in inputs]
444
+ outputs = self.model(*X)
445
+ selected = [val for val in outputs[:, idx]]
446
+ if self.input_handle is not None:
447
+ interim_inputs = self.layer.target_input
448
+ grads = [torch.autograd.grad(selected, input,
449
+ retain_graph=True if idx + 1 < len(interim_inputs) else None)[0].cpu().numpy()
450
+ for idx, input in enumerate(interim_inputs)]
451
+ del self.layer.target_input
452
+ else:
453
+ grads = [torch.autograd.grad(selected, x,
454
+ retain_graph=True if idx + 1 < len(X) else None)[0].cpu().numpy()
455
+ for idx, x in enumerate(X)]
456
+ return grads
457
+
458
+ @staticmethod
459
+ def get_interim_input(self, input, output):
460
+ try:
461
+ del self.target_input
462
+ except AttributeError:
463
+ pass
464
+ setattr(self, 'target_input', input)
465
+
466
+ def add_handles(self, layer):
467
+ input_handle = layer.register_forward_hook(self.get_interim_input)
468
+ self.input_handle = input_handle
469
+
470
+ def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False):
471
+
472
+ # X ~ self.model_input
473
+ # X_data ~ self.data
474
+
475
+ # check if we have multiple inputs
476
+ if not self.multi_input:
477
+ assert not isinstance(X, list), "Expected a single tensor model input!"
478
+ X = [X]
479
+ else:
480
+ assert isinstance(X, list), "Expected a list of model inputs!"
481
+
482
+ if ranked_outputs is not None and self.multi_output:
483
+ with torch.no_grad():
484
+ model_output_values = self.model(*X)
485
+ # rank and determine the model outputs that we will explain
486
+ if output_rank_order == "max":
487
+ _, model_output_ranks = torch.sort(model_output_values, descending=True)
488
+ elif output_rank_order == "min":
489
+ _, model_output_ranks = torch.sort(model_output_values, descending=False)
490
+ elif output_rank_order == "max_abs":
491
+ _, model_output_ranks = torch.sort(torch.abs(model_output_values), descending=True)
492
+ else:
493
+ emsg = "output_rank_order must be max, min, or max_abs!"
494
+ raise ValueError(emsg)
495
+ model_output_ranks = model_output_ranks[:, :ranked_outputs]
496
+ else:
497
+ model_output_ranks = (torch.ones((X[0].shape[0], len(self.gradients))).int() *
498
+ torch.arange(0, len(self.gradients)).int())
499
+
500
+ # if a cleanup happened, we need to add the handles back
501
+ # this allows shap_values to be called multiple times, but the model to be
502
+ # 'clean' at the end of each run for other uses
503
+ if self.input_handle is None and self.interim is True:
504
+ self.add_handles(self.layer)
505
+
506
+ # compute the attributions
507
+ X_batches = X[0].shape[0]
508
+ output_phis = []
509
+ output_phi_vars = []
510
+ # samples_input = input to the model
511
+ # samples_delta = (x - x') for the input being explained - may be an interim input
512
+ samples_input = [torch.zeros((nsamples,) + X[t].shape[1:], device=X[t].device) for t in range(len(X))]
513
+ samples_delta = [np.zeros((nsamples, ) + self.data[t].shape[1:]) for t in range(len(self.data))]
514
+
515
+ # use random seed if no argument given
516
+ if rseed is None:
517
+ rseed = np.random.randint(0, 1e6)
518
+
519
+ for i in range(model_output_ranks.shape[1]):
520
+ np.random.seed(rseed) # so we get the same noise patterns for each output class
521
+ phis = []
522
+ phi_vars = []
523
+ for k in range(len(self.data)):
524
+ # for each of the inputs being explained - may be an interim input
525
+ phis.append(np.zeros((X_batches,) + self.data[k].shape[1:]))
526
+ phi_vars.append(np.zeros((X_batches, ) + self.data[k].shape[1:]))
527
+ for j in range(X[0].shape[0]):
528
+ # fill in the samples arrays
529
+ for k in range(nsamples):
530
+ rind = np.random.choice(self.data[0].shape[0])
531
+ t = np.random.uniform()
532
+ for a in range(len(X)):
533
+ if self.local_smoothing > 0:
534
+ # local smoothing is added to the base input, unlike in the TF gradient explainer
535
+ x = X[a][j].clone().detach() + torch.empty(X[a][j].shape, device=X[a].device).normal_() \
536
+ * self.local_smoothing
537
+ else:
538
+ x = X[a][j].clone().detach()
539
+ samples_input[a][k] = (t * x + (1 - t) * (self.model_inputs[a][rind]).clone().detach()).\
540
+ clone().detach()
541
+ if self.input_handle is None:
542
+ samples_delta[a][k] = (x - (self.data[a][rind]).clone().detach()).cpu().numpy()
543
+
544
+ if self.interim is True:
545
+ with torch.no_grad():
546
+ _ = self.model(*[samples_input[a][k].unsqueeze(0) for a in range(len(X))])
547
+ interim_inputs = self.layer.target_input
548
+ del self.layer.target_input
549
+ if type(interim_inputs) is tuple:
550
+ if type(interim_inputs) is tuple:
551
+ # this should always be true, but just to be safe
552
+ for a in range(len(interim_inputs)):
553
+ samples_delta[a][k] = interim_inputs[a].cpu().numpy()
554
+ else:
555
+ samples_delta[0][k] = interim_inputs.cpu().numpy()
556
+
557
+ # compute the gradients at all the sample points
558
+ find = model_output_ranks[j, i]
559
+ grads = []
560
+ for b in range(0, nsamples, self.batch_size):
561
+ batch = [samples_input[c][b:min(b+self.batch_size,nsamples)].clone().detach() for c in range(len(X))]
562
+ grads.append(self.gradient(find, batch))
563
+ grad = [np.concatenate([g[z] for g in grads], 0) for z in range(len(self.data))]
564
+ # assign the attributions to the right part of the output arrays
565
+ for t in range(len(self.data)):
566
+ samples = grad[t] * samples_delta[t]
567
+ phis[t][j] = samples.mean(0)
568
+ phi_vars[t][j] = samples.var(0) / np.sqrt(samples.shape[0]) # estimate variance of means
569
+
570
+ output_phis.append(phis[0] if len(self.data) == 1 else phis)
571
+ output_phi_vars.append(phi_vars[0] if not self.multi_input else phi_vars)
572
+ # cleanup: remove the handles, if they were added
573
+ if self.input_handle is not None:
574
+ self.input_handle.remove()
575
+ self.input_handle = None
576
+ # note: the target input attribute is deleted in the loop
577
+
578
+ if not self.multi_output:
579
+ if return_variances:
580
+ return output_phis[0], output_phi_vars[0]
581
+ else:
582
+ return output_phis[0]
583
+ elif ranked_outputs is not None:
584
+ if return_variances:
585
+ return output_phis, output_phi_vars, model_output_ranks
586
+ else:
587
+ return output_phis, model_output_ranks
588
+ else:
589
+ if return_variances:
590
+ return output_phis, output_phi_vars
591
+ else:
592
+ return output_phis
lib/shap/explainers/_kernel.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import gc
3
+ import itertools
4
+ import logging
5
+ import time
6
+ import warnings
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import scipy.sparse
11
+ import sklearn
12
+ from packaging import version
13
+ from scipy.special import binom
14
+ from sklearn.linear_model import Lasso, LassoLarsIC, lars_path
15
+ from sklearn.pipeline import make_pipeline
16
+ from sklearn.preprocessing import StandardScaler
17
+ from tqdm.auto import tqdm
18
+
19
+ from .._explanation import Explanation
20
+ from ..utils import safe_isinstance
21
+ from ..utils._exceptions import DimensionError
22
+ from ..utils._legacy import (
23
+ DenseData,
24
+ SparseData,
25
+ convert_to_data,
26
+ convert_to_instance,
27
+ convert_to_instance_with_index,
28
+ convert_to_link,
29
+ convert_to_model,
30
+ match_instance_to_data,
31
+ match_model_to_data,
32
+ )
33
+ from ._explainer import Explainer
34
+
35
+ log = logging.getLogger('shap')
36
+
37
+
38
+ class KernelExplainer(Explainer):
39
+ """Uses the Kernel SHAP method to explain the output of any function.
40
+
41
+ Kernel SHAP is a method that uses a special weighted linear regression
42
+ to compute the importance of each feature. The computed importance values
43
+ are Shapley values from game theory and also coefficients from a local linear
44
+ regression.
45
+
46
+ Parameters
47
+ ----------
48
+ model : function or iml.Model
49
+ User supplied function that takes a matrix of samples (# samples x # features) and
50
+ computes the output of the model for those samples. The output can be a vector
51
+ (# samples) or a matrix (# samples x # model outputs).
52
+
53
+ data : numpy.array or pandas.DataFrame or shap.common.DenseData or any scipy.sparse matrix
54
+ The background dataset to use for integrating out features. To determine the impact
55
+ of a feature, that feature is set to "missing" and the change in the model output
56
+ is observed. Since most models aren't designed to handle arbitrary missing data at test
57
+ time, we simulate "missing" by replacing the feature with the values it takes in the
58
+ background dataset. So if the background dataset is a simple sample of all zeros, then
59
+ we would approximate a feature being missing by setting it to zero. For small problems,
60
+ this background dataset can be the whole training set, but for larger problems consider
61
+ using a single reference value or using the ``kmeans`` function to summarize the dataset.
62
+ Note: for the sparse case, we accept any sparse matrix but convert to lil format for
63
+ performance.
64
+
65
+ feature_names : list
66
+ The names of the features in the background dataset. If the background dataset is
67
+ supplied as a pandas.DataFrame, then ``feature_names`` can be set to ``None`` (default),
68
+ and the feature names will be taken as the column names of the dataframe.
69
+
70
+ link : "identity" or "logit"
71
+ A generalized linear model link to connect the feature importance values to the model
72
+ output. Since the feature importance values, phi, sum up to the model output, it often makes
73
+ sense to connect them to the output with a link function where link(output) = sum(phi).
74
+ Default is "identity" (a no-op).
75
+ If the model output is a probability, then "logit" can be used to transform the SHAP values
76
+ into log-odds units.
77
+
78
+ Examples
79
+ --------
80
+ See :ref:`Kernel Explainer Examples <kernel_explainer_examples>`.
81
+ """
82
+
83
    def __init__(self, model, data, feature_names=None, link="identity", **kwargs):
        """Build a KernelExplainer for ``model`` using ``data`` as the background dataset.

        See the class docstring for parameter details. Supported ``**kwargs``:
        ``keep_index`` and ``keep_index_ordered`` (both default False) to carry a
        pandas index through to the model function.
        """
        # Resolve feature names: an explicit argument wins; otherwise fall back to
        # DataFrame column names. If neither applies the attribute is simply unset.
        if feature_names is not None:
            self.data_feature_names=feature_names
        elif isinstance(data, pd.DataFrame):
            self.data_feature_names = list(data.columns)

        # convert incoming inputs to standardized iml objects
        self.link = convert_to_link(link)
        self.keep_index = kwargs.get("keep_index", False)
        self.keep_index_ordered = kwargs.get("keep_index_ordered", False)
        self.model = convert_to_model(model, keep_index=self.keep_index)
        self.data = convert_to_data(data, keep_index=self.keep_index)
        # match_model_to_data evaluates the model on the background data; the result
        # becomes the baseline ("all features missing") output below.
        model_null = match_model_to_data(self.model, self.data)

        # enforce our current input type limitations
        if not isinstance(self.data, (DenseData, SparseData)):
            emsg = "Shap explainer only supports the DenseData and SparseData input currently."
            raise TypeError(emsg)
        if self.data.transposed:
            emsg = "Shap explainer does not support transposed DenseData or SparseData currently."
            raise DimensionError(emsg)

        # warn users about large background data sets
        if len(self.data.weights) > 100:
            log.warning("Using " + str(len(self.data.weights)) + " background data samples could cause " +
                        "slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to " +
                        "summarize the background as K samples.")

        # init our parameters
        self.N = self.data.data.shape[0]  # number of background samples
        self.P = self.data.data.shape[1]  # number of features
        self.linkfv = np.vectorize(self.link.f)  # elementwise link function
        self.nsamplesAdded = 0
        self.nsamplesRun = 0

        # find E_x[f(x)]: the background-weighted average model output
        if isinstance(model_null, (pd.DataFrame, pd.Series)):
            model_null = np.squeeze(model_null.values)
        if safe_isinstance(model_null, "tensorflow.python.framework.ops.EagerTensor"):
            model_null = model_null.numpy()
        self.fnull = np.sum((model_null.T * self.data.weights).T, 0)
        self.expected_value = self.linkfv(self.fnull)

        # see if we have a vector output
        self.vector_out = True
        if len(self.fnull.shape) == 0:
            # scalar output: normalize fnull to a length-1 array so downstream code
            # can always index it by output dimension
            self.vector_out = False
            self.fnull = np.array([self.fnull])
            self.D = 1
            self.expected_value = float(self.expected_value)
        else:
            self.D = self.fnull.shape[0]
+ def __call__(self, X):
138
+
139
+ start_time = time.time()
140
+
141
+ if isinstance(X, pd.DataFrame):
142
+ feature_names = list(X.columns)
143
+ else:
144
+ feature_names = getattr(self, "data_feature_names", None)
145
+
146
+ v = self.shap_values(X)
147
+ if isinstance(v, list):
148
+ v = np.stack(v, axis=-1) # put outputs at the end
149
+
150
+ # the explanation object expects an expected value for each row
151
+ if hasattr(self.expected_value, "__len__"):
152
+ ev_tiled = np.tile(self.expected_value, (v.shape[0],1))
153
+ else:
154
+ ev_tiled = np.tile(self.expected_value, v.shape[0])
155
+
156
+ return Explanation(
157
+ v,
158
+ base_values=ev_tiled,
159
+ data=X.to_numpy() if isinstance(X, pd.DataFrame) else X,
160
+ feature_names=feature_names,
161
+ compute_time=time.time() - start_time,
162
+ )
163
+
164
+ def shap_values(self, X, **kwargs):
165
+ """ Estimate the SHAP values for a set of samples.
166
+
167
+ Parameters
168
+ ----------
169
+ X : numpy.array or pandas.DataFrame or any scipy.sparse matrix
170
+ A matrix of samples (# samples x # features) on which to explain the model's output.
171
+
172
+ nsamples : "auto" or int
173
+ Number of times to re-evaluate the model when explaining each prediction. More samples
174
+ lead to lower variance estimates of the SHAP values. The "auto" setting uses
175
+ `nsamples = 2 * X.shape[1] + 2048`.
176
+
177
+ l1_reg : "num_features(int)", "auto" (default for now, but deprecated), "aic", "bic", or float
178
+ The l1 regularization to use for feature selection (the estimation procedure is based on
179
+ a debiased lasso). The auto option currently uses "aic" when less that 20% of the possible sample
180
+ space is enumerated, otherwise it uses no regularization. THE BEHAVIOR OF "auto" WILL CHANGE
181
+ in a future version to be based on num_features instead of AIC.
182
+ The "aic" and "bic" options use the AIC and BIC rules for regularization.
183
+ Using "num_features(int)" selects a fix number of top features. Passing a float directly sets the
184
+ "alpha" parameter of the sklearn.linear_model.Lasso model used for feature selection.
185
+
186
+ gc_collect : bool
187
+ Run garbage collection after each explanation round. Sometime needed for memory intensive explanations (default False).
188
+
189
+ Returns
190
+ -------
191
+ array or list
192
+ For models with a single output this returns a matrix of SHAP values
193
+ (# samples x # features). Each row sums to the difference between the model output for that
194
+ sample and the expected value of the model output (which is stored as expected_value
195
+ attribute of the explainer). For models with vector outputs this returns a list
196
+ of such matrices, one for each output.
197
+ """
198
+
199
+ # convert dataframes
200
+ if isinstance(X, pd.Series):
201
+ X = X.values
202
+ elif isinstance(X, pd.DataFrame):
203
+ if self.keep_index:
204
+ index_value = X.index.values
205
+ index_name = X.index.name
206
+ column_name = list(X.columns)
207
+ X = X.values
208
+
209
+ x_type = str(type(X))
210
+ arr_type = "'numpy.ndarray'>"
211
+ # if sparse, convert to lil for performance
212
+ if scipy.sparse.issparse(X) and not scipy.sparse.isspmatrix_lil(X):
213
+ X = X.tolil()
214
+ assert x_type.endswith(arr_type) or scipy.sparse.isspmatrix_lil(X), "Unknown instance type: " + x_type
215
+
216
+ # single instance
217
+ if len(X.shape) == 1:
218
+ data = X.reshape((1, X.shape[0]))
219
+ if self.keep_index:
220
+ data = convert_to_instance_with_index(data, column_name, index_name, index_value)
221
+ explanation = self.explain(data, **kwargs)
222
+
223
+ # vector-output
224
+ s = explanation.shape
225
+ if len(s) == 2:
226
+ outs = [np.zeros(s[0]) for j in range(s[1])]
227
+ for j in range(s[1]):
228
+ outs[j] = explanation[:, j]
229
+ return outs
230
+
231
+ # single-output
232
+ else:
233
+ out = np.zeros(s[0])
234
+ out[:] = explanation
235
+ return out
236
+
237
+ # explain the whole dataset
238
+ elif len(X.shape) == 2:
239
+ explanations = []
240
+ for i in tqdm(range(X.shape[0]), disable=kwargs.get("silent", False)):
241
+ data = X[i:i + 1, :]
242
+ if self.keep_index:
243
+ data = convert_to_instance_with_index(data, column_name, index_value[i:i + 1], index_name)
244
+ explanations.append(self.explain(data, **kwargs))
245
+ if kwargs.get("gc_collect", False):
246
+ gc.collect()
247
+
248
+ # vector-output
249
+ s = explanations[0].shape
250
+ if len(s) == 2:
251
+ outs = [np.zeros((X.shape[0], s[0])) for j in range(s[1])]
252
+ for i in range(X.shape[0]):
253
+ for j in range(s[1]):
254
+ outs[j][i] = explanations[i][:, j]
255
+ return outs
256
+
257
+ # single-output
258
+ else:
259
+ out = np.zeros((X.shape[0], s[0]))
260
+ for i in range(X.shape[0]):
261
+ out[i] = explanations[i]
262
+ return out
263
+
264
+ else:
265
+ emsg = "Instance must have 1 or 2 dimensions!"
266
+ raise DimensionError(emsg)
267
+
268
    def explain(self, incoming_instance, **kwargs):
        """Compute SHAP values for a single instance.

        Returns a (groups_size,) vector for scalar-output models, or a
        (groups_size, D) matrix for vector-output models.  Supported kwargs:
        ``l1_reg`` and ``nsamples`` (see :meth:`shap_values`).
        """
        # convert incoming input to a standardized iml object
        instance = convert_to_instance(incoming_instance)
        match_instance_to_data(instance, self.data)

        # find the feature groups we will test. If a feature does not change from its
        # current value then we know it doesn't impact the model
        self.varyingInds = self.varying_groups(instance.x)
        if self.data.groups is None:
            self.varyingFeatureGroups = np.array([i for i in self.varyingInds])
            self.M = self.varyingFeatureGroups.shape[0]
        else:
            self.varyingFeatureGroups = [self.data.groups[i] for i in self.varyingInds]
            self.M = len(self.varyingFeatureGroups)
            groups = self.data.groups
            # convert to numpy array as it is much faster if not jagged array (all groups of same length)
            if self.varyingFeatureGroups and all(len(groups[i]) == len(groups[0]) for i in self.varyingInds):
                self.varyingFeatureGroups = np.array(self.varyingFeatureGroups)
                # further performance optimization in case each group has a single value
                if self.varyingFeatureGroups.shape[1] == 1:
                    self.varyingFeatureGroups = self.varyingFeatureGroups.flatten()

        # find f(x): the model output for the instance being explained
        if self.keep_index:
            model_out = self.model.f(instance.convert_to_df())
        else:
            model_out = self.model.f(instance.x)
        if isinstance(model_out, (pd.DataFrame, pd.Series)):
            model_out = model_out.values
        self.fx = model_out[0]

        if not self.vector_out:
            self.fx = np.array([self.fx])

        # if no features vary then no feature has an effect
        if self.M == 0:
            phi = np.zeros((self.data.groups_size, self.D))
            phi_var = np.zeros((self.data.groups_size, self.D))

        # if only one feature varies then it has all the effect
        elif self.M == 1:
            phi = np.zeros((self.data.groups_size, self.D))
            phi_var = np.zeros((self.data.groups_size, self.D))
            diff = self.link.f(self.fx) - self.link.f(self.fnull)
            for d in range(self.D):
                phi[self.varyingInds[0],d] = diff[d]

        # if more than one feature varies then we have to do real work
        else:
            self.l1_reg = kwargs.get("l1_reg", "auto")

            # pick a reasonable number of samples if the user didn't specify how many they wanted
            self.nsamples = kwargs.get("nsamples", "auto")
            if self.nsamples == "auto":
                self.nsamples = 2 * self.M + 2**11

            # if we have enough samples to enumerate all subsets then ignore the unneeded samples
            self.max_samples = 2 ** 30
            if self.M <= 30:
                # 2^M - 2 excludes the empty and full coalitions, which carry no information
                self.max_samples = 2 ** self.M - 2
                if self.nsamples > self.max_samples:
                    self.nsamples = self.max_samples

            # reserve space for some of our computations
            self.allocate()

            # weight the different subset sizes by the Shapley kernel; paired sizes
            # (a size and its complement) are drawn together, hence the *2 factor
            num_subset_sizes = int(np.ceil((self.M - 1) / 2.0))
            num_paired_subset_sizes = int(np.floor((self.M - 1) / 2.0))
            weight_vector = np.array([(self.M - 1.0) / (i * (self.M - i)) for i in range(1, num_subset_sizes + 1)])
            weight_vector[:num_paired_subset_sizes] *= 2
            weight_vector /= np.sum(weight_vector)
            log.debug(f"{weight_vector = }")
            log.debug(f"{num_subset_sizes = }")
            log.debug(f"{num_paired_subset_sizes = }")
            log.debug(f"{self.M = }")

            # fill out all the subset sizes we can completely enumerate
            # given nsamples*remaining_weight_vector[subset_size]
            num_full_subsets = 0
            num_samples_left = self.nsamples
            group_inds = np.arange(self.M, dtype='int64')
            mask = np.zeros(self.M)
            remaining_weight_vector = copy.copy(weight_vector)
            for subset_size in range(1, num_subset_sizes + 1):

                # determine how many subsets (and their complements) are of the current size
                nsubsets = binom(self.M, subset_size)
                if subset_size <= num_paired_subset_sizes:
                    nsubsets *= 2
                log.debug(f"{subset_size = }")
                log.debug(f"{nsubsets = }")
                log.debug("self.nsamples*weight_vector[subset_size-1] = {}".format(
                    num_samples_left * remaining_weight_vector[subset_size - 1]))
                log.debug("self.nsamples*weight_vector[subset_size-1]/nsubsets = {}".format(
                    num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets))

                # see if we have enough samples to enumerate all subsets of this size
                if num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets >= 1.0 - 1e-8:
                    num_full_subsets += 1
                    num_samples_left -= nsubsets

                    # rescale what's left of the remaining weight vector to sum to 1
                    if remaining_weight_vector[subset_size - 1] < 1.0:
                        remaining_weight_vector /= (1 - remaining_weight_vector[subset_size - 1])

                    # add all the samples of the current subset size
                    w = weight_vector[subset_size - 1] / binom(self.M, subset_size)
                    if subset_size <= num_paired_subset_sizes:
                        w /= 2.0
                    for inds in itertools.combinations(group_inds, subset_size):
                        mask[:] = 0.0
                        mask[np.array(inds, dtype='int64')] = 1.0
                        self.addsample(instance.x, mask, w)
                        if subset_size <= num_paired_subset_sizes:
                            # also add the complement coalition (mask flipped)
                            mask[:] = np.abs(mask - 1)
                            self.addsample(instance.x, mask, w)
                else:
                    break
            log.info(f"{num_full_subsets = }")

            # add random samples from what is left of the subset space
            nfixed_samples = self.nsamplesAdded
            samples_left = self.nsamples - self.nsamplesAdded
            log.debug(f"{samples_left = }")
            if num_full_subsets != num_subset_sizes:
                remaining_weight_vector = copy.copy(weight_vector)
                remaining_weight_vector[:num_paired_subset_sizes] /= 2 # because we draw two samples each below
                remaining_weight_vector = remaining_weight_vector[num_full_subsets:]
                remaining_weight_vector /= np.sum(remaining_weight_vector)
                log.info(f"{remaining_weight_vector = }")
                log.info(f"{num_paired_subset_sizes = }")
                # 4x oversampling: duplicate masks below only bump a weight, so we may
                # need more draws than samples_left to fill the budget
                ind_set = np.random.choice(len(remaining_weight_vector), 4 * samples_left, p=remaining_weight_vector)
                ind_set_pos = 0
                used_masks = {}
                while samples_left > 0 and ind_set_pos < len(ind_set):
                    mask.fill(0.0)
                    ind = ind_set[ind_set_pos] # we call np.random.choice once to save time and then just read it here
                    ind_set_pos += 1
                    subset_size = ind + num_full_subsets + 1
                    mask[np.random.permutation(self.M)[:subset_size]] = 1.0

                    # only add the sample if we have not seen it before, otherwise just
                    # increment a previous sample's weight
                    mask_tuple = tuple(mask)
                    new_sample = False
                    if mask_tuple not in used_masks:
                        new_sample = True
                        used_masks[mask_tuple] = self.nsamplesAdded
                        samples_left -= 1
                        self.addsample(instance.x, mask, 1.0)
                    else:
                        self.kernelWeights[used_masks[mask_tuple]] += 1.0

                    # add the compliment sample
                    if samples_left > 0 and subset_size <= num_paired_subset_sizes:
                        mask[:] = np.abs(mask - 1)

                        # only add the sample if we have not seen it before, otherwise just
                        # increment a previous sample's weight
                        if new_sample:
                            samples_left -= 1
                            self.addsample(instance.x, mask, 1.0)
                        else:
                            # we know the compliment sample is the next one after the original sample, so + 1
                            self.kernelWeights[used_masks[mask_tuple] + 1] += 1.0

                # normalize the kernel weights for the random samples to equal the weight left after
                # the fixed enumerated samples have been already counted
                weight_left = np.sum(weight_vector[num_full_subsets:])
                log.info(f"{weight_left = }")
                self.kernelWeights[nfixed_samples:] *= weight_left / self.kernelWeights[nfixed_samples:].sum()

            # execute the model on the synthetic samples we have created
            self.run()

            # solve then expand the feature importance (Shapley value) vector to contain the non-varying features
            phi = np.zeros((self.data.groups_size, self.D))
            phi_var = np.zeros((self.data.groups_size, self.D))
            for d in range(self.D):
                vphi, vphi_var = self.solve(self.nsamples / self.max_samples, d)
                phi[self.varyingInds, d] = vphi
                phi_var[self.varyingInds, d] = vphi_var

        if not self.vector_out:
            phi = np.squeeze(phi, axis=1)
            phi_var = np.squeeze(phi_var, axis=1)

        return phi
+ @staticmethod
459
+ def not_equal(i, j):
460
+ number_types = (int, float, np.number)
461
+ if isinstance(i, number_types) and isinstance(j, number_types):
462
+ return 0 if np.isclose(i, j, equal_nan=True) else 1
463
+ else:
464
+ return 0 if i == j else 1
465
+
466
    def varying_groups(self, x):
        """Return the indices of feature groups whose values in ``x`` differ from
        the background data (only those groups can influence the model output).

        ``x`` is a single-row matrix (dense or sparse); the sparse branch avoids
        densifying the background data.
        """
        if not scipy.sparse.issparse(x):
            varying = np.zeros(self.data.groups_size)
            for i in range(0, self.data.groups_size):
                inds = self.data.groups[i]
                x_group = x[0, inds]
                if scipy.sparse.issparse(x_group):
                    # all-zero group in a sparse x matches an (assumed zero) background
                    if all(j not in x.nonzero()[1] for j in inds):
                        varying[i] = False
                        continue
                    x_group = x_group.todense()
                # count elementwise mismatches against every background row
                num_mismatches = np.sum(np.frompyfunc(self.not_equal, 2, 1)(x_group, self.data.data[:, inds]))
                varying[i] = num_mismatches > 0
            varying_indices = np.nonzero(varying)[0]
            return varying_indices
        else:
            varying_indices = []
            # go over all nonzero columns in background and evaluation data
            # if both background and evaluation are zero, the column does not vary
            varying_indices = np.unique(np.union1d(self.data.data.nonzero()[1], x.nonzero()[1]))
            remove_unvarying_indices = []
            for i in range(0, len(varying_indices)):
                varying_index = varying_indices[i]
                # now verify the nonzero values do vary
                data_rows = self.data.data[:, [varying_index]]
                nonzero_rows = data_rows.nonzero()[0]

                if nonzero_rows.size > 0:
                    background_data_rows = data_rows[nonzero_rows]
                    if scipy.sparse.issparse(background_data_rows):
                        background_data_rows = background_data_rows.toarray()
                    num_mismatches = np.sum(np.abs(background_data_rows - x[0, varying_index]) > 1e-7)
                    # Note: If feature column non-zero but some background zero, can't remove index
                    if num_mismatches == 0 and not \
                        (np.abs(x[0, [varying_index]][0, 0]) > 1e-7 and len(nonzero_rows) < data_rows.shape[0]):
                        remove_unvarying_indices.append(i)
            mask = np.ones(len(varying_indices), dtype=bool)
            mask[remove_unvarying_indices] = False
            varying_indices = varying_indices[mask]
            return varying_indices
    def allocate(self):
        """Allocate the buffers used while building and evaluating synthetic samples.

        Tiles the background data ``nsamples`` times into ``self.synth_data`` and
        zero-initializes the mask matrix, kernel weights, and model-output arrays.
        """
        if scipy.sparse.issparse(self.data.data):
            # We tile the sparse matrix in csr format but convert it to lil
            # for performance when adding samples
            shape = self.data.data.shape
            nnz = self.data.data.nnz
            data_rows, data_cols = shape
            rows = data_rows * self.nsamples
            shape = rows, data_cols
            if nnz == 0:
                self.synth_data = scipy.sparse.csr_matrix(shape, dtype=self.data.data.dtype).tolil()
            else:
                # manually tile the CSR arrays: data/indices repeat unchanged while
                # each copy's indptr is offset by the total nnz of the copies before it
                data = self.data.data.data
                indices = self.data.data.indices
                indptr = self.data.data.indptr
                last_indptr_idx = indptr[len(indptr) - 1]
                indptr_wo_last = indptr[:-1]
                new_indptrs = []
                for i in range(0, self.nsamples - 1):
                    new_indptrs.append(indptr_wo_last + (i * last_indptr_idx))
                # the final copy keeps its closing indptr entry
                new_indptrs.append(indptr + ((self.nsamples - 1) * last_indptr_idx))
                new_indptr = np.concatenate(new_indptrs)
                new_data = np.tile(data, self.nsamples)
                new_indices = np.tile(indices, self.nsamples)
                self.synth_data = scipy.sparse.csr_matrix((new_data, new_indices, new_indptr), shape=shape).tolil()
        else:
            self.synth_data = np.tile(self.data.data, (self.nsamples, 1))

        self.maskMatrix = np.zeros((self.nsamples, self.M))   # coalition masks, one row per sample
        self.kernelWeights = np.zeros(self.nsamples)          # Shapley-kernel weight per sample
        self.y = np.zeros((self.nsamples * self.N, self.D))   # raw model outputs
        self.ey = np.zeros((self.nsamples, self.D))           # background-averaged outputs
        self.lastMask = np.zeros(self.nsamples)
        self.nsamplesAdded = 0
        self.nsamplesRun = 0
        if self.keep_index:
            self.synth_data_index = np.tile(self.data.index_value, self.nsamples)
    def addsample(self, x, m, w):
        """Write one synthetic sample into the preallocated buffers.

        For each feature group selected by mask ``m`` (1.0 = "present"), overwrite
        the corresponding columns of the next N background rows in ``synth_data``
        with the instance's values from ``x``; record ``m`` and kernel weight ``w``.
        """
        offset = self.nsamplesAdded * self.N
        if isinstance(self.varyingFeatureGroups, (list,)):
            # jagged groups: fall back to a per-feature loop
            for j in range(self.M):
                for k in self.varyingFeatureGroups[j]:
                    if m[j] == 1.0:
                        self.synth_data[offset:offset+self.N, k] = x[0, k]
        else:
            # for non-jagged numpy array we can significantly boost performance
            mask = m == 1.0
            groups = self.varyingFeatureGroups[mask]
            if len(groups.shape) == 2:
                for group in groups:
                    self.synth_data[offset:offset+self.N, group] = x[0, group]
            else:
                # further performance optimization in case each group has a single feature
                evaluation_data = x[0, groups]
                # In edge case where background is all dense but evaluation data
                # is all sparse, make evaluation data dense
                if scipy.sparse.issparse(x) and not scipy.sparse.issparse(self.synth_data):
                    evaluation_data = evaluation_data.toarray()
                self.synth_data[offset:offset+self.N, groups] = evaluation_data
        self.maskMatrix[self.nsamplesAdded, :] = m
        self.kernelWeights[self.nsamplesAdded] = w
        self.nsamplesAdded += 1
+ def run(self):
572
+ num_to_run = self.nsamplesAdded * self.N - self.nsamplesRun * self.N
573
+ data = self.synth_data[self.nsamplesRun*self.N:self.nsamplesAdded*self.N,:]
574
+ if self.keep_index:
575
+ index = self.synth_data_index[self.nsamplesRun*self.N:self.nsamplesAdded*self.N]
576
+ index = pd.DataFrame(index, columns=[self.data.index_name])
577
+ data = pd.DataFrame(data, columns=self.data.group_names)
578
+ data = pd.concat([index, data], axis=1).set_index(self.data.index_name)
579
+ if self.keep_index_ordered:
580
+ data = data.sort_index()
581
+ modelOut = self.model.f(data)
582
+ if isinstance(modelOut, (pd.DataFrame, pd.Series)):
583
+ modelOut = modelOut.values
584
+ self.y[self.nsamplesRun * self.N:self.nsamplesAdded * self.N, :] = np.reshape(modelOut, (num_to_run, self.D))
585
+
586
+ # find the expected value of each output
587
+ for i in range(self.nsamplesRun, self.nsamplesAdded):
588
+ eyVal = np.zeros(self.D)
589
+ for j in range(0, self.N):
590
+ eyVal += self.y[i * self.N + j, :] * self.data.weights[j]
591
+
592
+ self.ey[i, :] = eyVal
593
+ self.nsamplesRun += 1
594
+
595
    def solve(self, fraction_evaluated, dim):
        """Solve the weighted linear regression for output dimension ``dim``.

        Returns ``(phi, phi_var)`` for the varying feature groups.  Optionally
        performs L1 feature selection first (per ``self.l1_reg``) when only a
        fraction of the coalition space was sampled.  Note: phi_var is currently
        a vector of ones, not a real variance estimate.
        """
        eyAdj = self.linkfv(self.ey[:, dim]) - self.link.f(self.fnull[dim])
        s = np.sum(self.maskMatrix, 1)  # coalition size of each sample

        # do feature selection if we have not well enumerated the space
        nonzero_inds = np.arange(self.M)
        log.debug(f"{fraction_evaluated = }")
        # if self.l1_reg == "auto":
        #     warnings.warn(
        #         "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \
        #         "conditional use of AIC to simply \"num_features(10)\"!"
        #     )
        if (self.l1_reg not in ["auto", False, 0]) or (fraction_evaluated < 0.2 and self.l1_reg == "auto"):
            # augment the data so the weighted problem includes each mask and its complement
            w_aug = np.hstack((self.kernelWeights * (self.M - s), self.kernelWeights * s))
            log.info(f"{np.sum(w_aug) = }")
            log.info(f"{np.sum(self.kernelWeights) = }")
            w_sqrt_aug = np.sqrt(w_aug)
            eyAdj_aug = np.hstack((eyAdj, eyAdj - (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))))
            eyAdj_aug *= w_sqrt_aug
            mask_aug = np.transpose(w_sqrt_aug * np.transpose(np.vstack((self.maskMatrix, self.maskMatrix - 1))))
            #var_norms = np.array([np.linalg.norm(mask_aug[:, i]) for i in range(mask_aug.shape[1])])

            # select a fixed number of top features
            if isinstance(self.l1_reg, str) and self.l1_reg.startswith("num_features("):
                r = int(self.l1_reg[len("num_features("):-1])
                nonzero_inds = lars_path(mask_aug, eyAdj_aug, max_iter=r)[1]

            # use an adaptive regularization method
            elif self.l1_reg == "auto" or self.l1_reg == "bic" or self.l1_reg == "aic":
                c = "aic" if self.l1_reg == "auto" else self.l1_reg

                # "Normalize" parameter of LassoLarsIC was deprecated in sklearn version 1.2
                if version.parse(sklearn.__version__) < version.parse("1.2.0"):
                    kwg = dict(normalize=False)
                else:
                    kwg = {}
                # pipeline[1] is the LassoLarsIC step; its coef_ gives the selected features
                model = make_pipeline(StandardScaler(with_mean=False), LassoLarsIC(criterion=c, **kwg))
                nonzero_inds = np.nonzero(model.fit(mask_aug, eyAdj_aug)[1].coef_)[0]

            # use a fixed regularization coefficient
            else:
                nonzero_inds = np.nonzero(Lasso(alpha=self.l1_reg).fit(mask_aug, eyAdj_aug).coef_)[0]

        if len(nonzero_inds) == 0:
            return np.zeros(self.M), np.ones(self.M)

        # eliminate one variable with the constraint that all features sum to the output
        eyAdj2 = eyAdj - self.maskMatrix[:, nonzero_inds[-1]] * (
            self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))
        etmp = np.transpose(np.transpose(self.maskMatrix[:, nonzero_inds[:-1]]) - self.maskMatrix[:, nonzero_inds[-1]])
        log.debug(f"{etmp[:4, :] = }")

        # solve a weighted least squares equation to estimate phi
        # least squares:
        #     phi = min_w ||W^(1/2) (y - X w)||^2
        # the corresponding normal equation:
        #     (X' W X) phi = X' W y
        # with
        #     X = etmp
        #     W = np.diag(self.kernelWeights)
        #     y = eyAdj2
        #
        # We could just rely on sciki-learn
        #     from sklearn.linear_model import LinearRegression
        #     lm = LinearRegression(fit_intercept=False).fit(etmp, eyAdj2, sample_weight=self.kernelWeights)
        # Under the hood, as of scikit-learn version 1.3, LinearRegression still uses np.linalg.lstsq and
        # there are more performant options. See https://github.com/scikit-learn/scikit-learn/issues/22855.
        y = eyAdj2
        X = etmp
        WX = self.kernelWeights[:, None] * X
        try:
            w = np.linalg.solve(X.T @ WX, WX.T @ y)
        except np.linalg.LinAlgError:
            warnings.warn(
                "Linear regression equation is singular, a least squares solutions is used instead.\n"
                "To avoid this situation and get a regular matrix do one of the following:\n"
                "1) turn up the number of samples,\n"
                "2) turn up the L1 regularization with num_features(N) where N is less than the number of samples,\n"
                "3) group features together to reduce the number of inputs that need to be explained."
            )
            # XWX = np.linalg.pinv(X.T @ WX)
            # w = np.dot(XWX, np.dot(np.transpose(WX), y))
            sqrt_W = np.sqrt(self.kernelWeights)
            w = np.linalg.lstsq(sqrt_W[:, None] * X, sqrt_W * y, rcond=None)[0]
        log.debug(f"{np.sum(w) = }")
        log.debug("self.link(self.fx) - self.link(self.fnull) = {}".format(
            self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim])))
        log.debug(f"self.fx = {self.fx[dim]}")
        log.debug(f"self.link(self.fx) = {self.link.f(self.fx[dim])}")
        log.debug(f"self.fnull = {self.fnull[dim]}")
        log.debug(f"self.link(self.fnull) = {self.link.f(self.fnull[dim])}")
        phi = np.zeros(self.M)
        phi[nonzero_inds[:-1]] = w
        # the eliminated variable absorbs the remainder so that phi sums exactly
        # to link(fx) - link(fnull)
        phi[nonzero_inds[-1]] = (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim])) - sum(w)
        log.info(f"{phi = }")

        # clean up any rounding errors
        for i in range(self.M):
            if np.abs(phi[i]) < 1e-10:
                phi[i] = 0

        return phi, np.ones(len(phi))
lib/shap/explainers/_linear.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+ from scipy.sparse import issparse
6
+ from tqdm.auto import tqdm
7
+
8
+ from .. import links, maskers
9
+ from ..utils._exceptions import (
10
+ DimensionError,
11
+ InvalidFeaturePerturbationError,
12
+ InvalidModelError,
13
+ )
14
+ from ._explainer import Explainer
15
+
16
+
17
+ class LinearExplainer(Explainer):
18
+ """ Computes SHAP values for a linear model, optionally accounting for inter-feature correlations.
19
+
20
+ This computes the SHAP values for a linear model and can account for the correlations among
21
+ the input features. Assuming features are independent leads to interventional SHAP values which
22
+ for a linear model are coef[i] * (x[i] - X.mean(0)[i]) for the ith feature. If instead we account
23
+ for correlations then we prevent any problems arising from collinearity and share credit among
24
+ correlated features. Accounting for correlations can be computationally challenging, but
25
+ LinearExplainer uses sampling to estimate a transform that can then be applied to explain
26
+ any prediction of the model.
27
+
28
+ Parameters
29
+ ----------
30
+ model : (coef, intercept) or sklearn.linear_model.*
31
+ User supplied linear model either as either a parameter pair or sklearn object.
32
+
33
+ data : (mean, cov), numpy.array, pandas.DataFrame, iml.DenseData or scipy.csr_matrix
34
+ The background dataset to use for computing conditional expectations. Note that only the
35
+ mean and covariance of the dataset are used. This means passing a raw data matrix is just
36
+ a convenient alternative to passing the mean and covariance directly.
37
+ nsamples : int
38
+ Number of samples to use when estimating the transformation matrix used to account for
39
+ feature correlations.
40
+ feature_perturbation : "interventional" (default) or "correlation_dependent"
41
+ There are two ways we might want to compute SHAP values, either the full conditional SHAP
42
+ values or the interventional SHAP values. For interventional SHAP values we break any
43
+ dependence structure between features in the model and so uncover how the model would behave if we
44
+ intervened and changed some of the inputs. For the full conditional SHAP values we respect
45
+ the correlations among the input features, so if the model depends on one input but that
46
+ input is correlated with another input, then both get some credit for the model's behavior. The
47
+ interventional option stays "true to the model" meaning it will only give credit to features that are
48
+ actually used by the model, while the correlation option stays "true to the data" in the sense that
49
+ it only considers how the model would behave when respecting the correlations in the input data.
50
+ For sparse case only interventional option is supported.
51
+
52
+ Examples
53
+ --------
54
+ See `Linear explainer examples <https://shap.readthedocs.io/en/latest/api_examples/explainers/LinearExplainer.html>`_
55
+ """
56
+
57
+ def __init__(self, model, masker, link=links.identity, nsamples=1000, feature_perturbation=None, **kwargs):
58
+ if 'feature_dependence' in kwargs:
59
+ warnings.warn('The option feature_dependence has been renamed to feature_perturbation!')
60
+ feature_perturbation = kwargs["feature_dependence"]
61
+ if feature_perturbation == "independent":
62
+ warnings.warn('The option feature_perturbation="independent" is has been renamed to feature_perturbation="interventional"!')
63
+ feature_perturbation = "interventional"
64
+ elif feature_perturbation == "correlation":
65
+ warnings.warn('The option feature_perturbation="correlation" is has been renamed to feature_perturbation="correlation_dependent"!')
66
+ feature_perturbation = "correlation_dependent"
67
+ if feature_perturbation is not None:
68
+ warnings.warn("The feature_perturbation option is now deprecated in favor of using the appropriate masker (maskers.Independent, or maskers.Impute)")
69
+ else:
70
+ feature_perturbation = "interventional"
71
+ self.feature_perturbation = feature_perturbation
72
+
73
+ # wrap the incoming masker object as a shap.Masker object before calling
74
+ # parent class constructor, which does the same but without respecting
75
+ # the user-provided feature_perturbation choice
76
+ if isinstance(masker, pd.DataFrame) or ((isinstance(masker, np.ndarray) or issparse(masker)) and len(masker.shape) == 2):
77
+ if self.feature_perturbation == "correlation_dependent":
78
+ masker = maskers.Impute(masker)
79
+ else:
80
+ masker = maskers.Independent(masker)
81
+ elif issubclass(type(masker), tuple) and len(masker) == 2:
82
+ if self.feature_perturbation == "correlation_dependent":
83
+ masker = maskers.Impute({"mean": masker[0], "cov": masker[1]}, method="linear")
84
+ else:
85
+ masker = maskers.Independent({"mean": masker[0], "cov": masker[1]})
86
+
87
+ super().__init__(model, masker, link=link, **kwargs)
88
+
89
+ self.nsamples = nsamples
90
+
91
+
92
+ # extract what we need from the given model object
93
+ self.coef, self.intercept = LinearExplainer._parse_model(model)
94
+
95
+ # extract the data
96
+ if issubclass(type(self.masker), (maskers.Independent, maskers.Partition)):
97
+ self.feature_perturbation = "interventional"
98
+ elif issubclass(type(self.masker), maskers.Impute):
99
+ self.feature_perturbation = "correlation_dependent"
100
+ else:
101
+ raise NotImplementedError("The Linear explainer only supports the Independent, Partition, and Impute maskers right now!")
102
+ data = getattr(self.masker, "data", None)
103
+
104
+ # convert DataFrame's to numpy arrays
105
+ if isinstance(data, pd.DataFrame):
106
+ data = data.values
107
+
108
+ # get the mean and covariance of the model
109
+ if getattr(self.masker, "mean", None) is not None:
110
+ self.mean = self.masker.mean
111
+ self.cov = self.masker.cov
112
+ elif isinstance(data, dict) and len(data) == 2:
113
+ self.mean = data["mean"]
114
+ if isinstance(self.mean, pd.Series):
115
+ self.mean = self.mean.values
116
+
117
+ self.cov = data["cov"]
118
+ if isinstance(self.cov, pd.DataFrame):
119
+ self.cov = self.cov.values
120
+ elif isinstance(data, tuple) and len(data) == 2:
121
+ self.mean = data[0]
122
+ if isinstance(self.mean, pd.Series):
123
+ self.mean = self.mean.values
124
+
125
+ self.cov = data[1]
126
+ if isinstance(self.cov, pd.DataFrame):
127
+ self.cov = self.cov.values
128
+ elif data is None:
129
+ raise ValueError("A background data distribution must be provided!")
130
+ else:
131
+ if issparse(data):
132
+ self.mean = np.array(np.mean(data, 0))[0]
133
+ if self.feature_perturbation != "interventional":
134
+ raise NotImplementedError("Only feature_perturbation = 'interventional' is supported for sparse data")
135
+ else:
136
+ self.mean = np.array(np.mean(data, 0)).flatten() # assumes it is an array
137
+ if self.feature_perturbation == "correlation_dependent":
138
+ self.cov = np.cov(data, rowvar=False)
139
+ #print(self.coef, self.mean.flatten(), self.intercept)
140
+ # Note: mean can be numpy.matrixlib.defmatrix.matrix or numpy.matrix type depending on numpy version
141
+ if issparse(self.mean) or str(type(self.mean)).endswith("matrix'>"):
142
+ # accept both sparse and dense coef
143
+ # if not issparse(self.coef):
144
+ # self.coef = np.asmatrix(self.coef)
145
+ self.expected_value = np.dot(self.coef, self.mean) + self.intercept
146
+
147
+ # unwrap the matrix form
148
+ if len(self.expected_value) == 1:
149
+ self.expected_value = self.expected_value[0,0]
150
+ else:
151
+ self.expected_value = np.array(self.expected_value)[0]
152
+ else:
153
+ self.expected_value = np.dot(self.coef, self.mean) + self.intercept
154
+
155
+ self.M = len(self.mean)
156
+
157
+ # if needed, estimate the transform matrices
158
+ if self.feature_perturbation == "correlation_dependent":
159
+ self.valid_inds = np.where(np.diag(self.cov) > 1e-8)[0]
160
+ self.mean = self.mean[self.valid_inds]
161
+ self.cov = self.cov[:,self.valid_inds][self.valid_inds,:]
162
+ self.coef = self.coef[self.valid_inds]
163
+
164
+ # group perfectly redundant variables together
165
+ self.avg_proj,sum_proj = duplicate_components(self.cov)
166
+ self.cov = np.matmul(np.matmul(self.avg_proj, self.cov), self.avg_proj.T)
167
+ self.mean = np.matmul(self.avg_proj, self.mean)
168
+ self.coef = np.matmul(sum_proj, self.coef)
169
+
170
+ # if we still have some multi-collinearity present then we just add regularization...
171
+ e,_ = np.linalg.eig(self.cov)
172
+ if e.min() < 1e-7:
173
+ self.cov = self.cov + np.eye(self.cov.shape[0]) * 1e-6
174
+
175
+ mean_transform, x_transform = self._estimate_transforms(nsamples)
176
+ self.mean_transformed = np.matmul(mean_transform, self.mean)
177
+ self.x_transform = x_transform
178
+ elif self.feature_perturbation == "interventional":
179
+ if nsamples != 1000:
180
+ warnings.warn("Setting nsamples has no effect when feature_perturbation = 'interventional'!")
181
+ else:
182
+ raise InvalidFeaturePerturbationError("Unknown type of feature_perturbation provided: " + self.feature_perturbation)
183
+
184
    def _estimate_transforms(self, nsamples):
        """ Uses block matrix inversion identities to quickly estimate transforms.

        After a bit of matrix math we can isolate a transform matrix (# features x # features)
        that is independent of any sample we are explaining. It is the result of averaging over
        all feature permutations, but we just use a fixed number of samples to estimate the value.

        TODO: Do a brute force enumeration when # feature subsets is less than nsamples. This could
        happen through a recursive method that uses the same block matrix inversion as below.
        """
        M = len(self.coef)

        mean_transform = np.zeros((M, M))
        x_transform = np.zeros((M, M))
        inds = np.arange(M, dtype=int)
        # each outer iteration samples one random feature permutation and accumulates
        # its contribution into the two transform matrices
        for _ in tqdm(range(nsamples), "Estimating transforms"):
            np.random.shuffle(inds)
            cov_inv_SiSi = np.zeros((0, 0))
            cov_Si = np.zeros((M, 0))
            for j in range(M):
                i = inds[j]

                # use the last Si as the new S
                cov_S = cov_Si
                cov_inv_SS = cov_inv_SiSi

                # get the new cov_Si
                cov_Si = self.cov[:, inds[:j+1]]

                # compute the new cov_inv_SiSi from cov_inv_SS using the bordered block
                # matrix inverse identity, so we never re-invert from scratch inside the loop
                d = cov_Si[i, :-1].T
                t = np.matmul(cov_inv_SS, d)
                Z = self.cov[i, i]
                u = Z - np.matmul(t.T, d)  # Schur complement of the new diagonal entry
                cov_inv_SiSi = np.zeros((j+1, j+1))
                if j > 0:
                    cov_inv_SiSi[:-1, :-1] = cov_inv_SS + np.outer(t, t) / u
                    cov_inv_SiSi[:-1, -1] = cov_inv_SiSi[-1, :-1] = -t / u
                cov_inv_SiSi[-1, -1] = 1 / u

                # + coef @ (Q(bar(Sui)) - Q(bar(S)))
                mean_transform[i, i] += self.coef[i]

                # + coef @ R(Sui)
                coef_R_Si = np.matmul(self.coef[inds[j+1:]], np.matmul(cov_Si, cov_inv_SiSi)[inds[j+1:]])
                mean_transform[i, inds[:j+1]] += coef_R_Si

                # - coef @ R(S)
                coef_R_S = np.matmul(self.coef[inds[j:]], np.matmul(cov_S, cov_inv_SS)[inds[j:]])
                mean_transform[i, inds[:j]] -= coef_R_S

                # - coef @ (Q(Sui) - Q(S))
                x_transform[i, i] += self.coef[i]

                # + coef @ R(Sui)
                x_transform[i, inds[:j+1]] += coef_R_Si

                # - coef @ R(S)
                x_transform[i, inds[:j]] -= coef_R_S

        # average the per-permutation contributions into the final estimates
        mean_transform /= nsamples
        x_transform /= nsamples
        return mean_transform, x_transform
247
+
248
+ @staticmethod
249
+ def _parse_model(model):
250
+ """ Attempt to pull out the coefficients and intercept from the given model object.
251
+ """
252
+ # raw coefficients
253
+ if type(model) == tuple and len(model) == 2:
254
+ coef = model[0]
255
+ intercept = model[1]
256
+
257
+ # sklearn style model
258
+ elif hasattr(model, "coef_") and hasattr(model, "intercept_"):
259
+ # work around for multi-class with a single class
260
+ if len(model.coef_.shape) > 1 and model.coef_.shape[0] == 1:
261
+ coef = model.coef_[0]
262
+ try:
263
+ intercept = model.intercept_[0]
264
+ except TypeError:
265
+ intercept = model.intercept_
266
+ else:
267
+ coef = model.coef_
268
+ intercept = model.intercept_
269
+ else:
270
+ raise InvalidModelError("An unknown model type was passed: " + str(type(model)))
271
+
272
+ return coef,intercept
273
+
274
+ @staticmethod
275
+ def supports_model_with_masker(model, masker):
276
+ """ Determines if we can parse the given model.
277
+ """
278
+
279
+ if not isinstance(masker, (maskers.Independent, maskers.Partition, maskers.Impute)):
280
+ return False
281
+
282
+ try:
283
+ LinearExplainer._parse_model(model)
284
+ except Exception:
285
+ return False
286
+ return True
287
+
288
+ def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
289
+ """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
290
+ """
291
+
292
+ assert len(row_args) == 1, "Only single-argument functions are supported by the Linear explainer!"
293
+
294
+ X = row_args[0]
295
+ if len(X.shape) == 1:
296
+ X = X.reshape(1, -1)
297
+
298
+ # convert dataframes
299
+ if isinstance(X, (pd.Series, pd.DataFrame)):
300
+ X = X.values
301
+
302
+ if len(X.shape) not in (1, 2):
303
+ raise DimensionError("Instance must have 1 or 2 dimensions! Not: %s" %len(X.shape))
304
+
305
+ if self.feature_perturbation == "correlation_dependent":
306
+ if issparse(X):
307
+ raise InvalidFeaturePerturbationError("Only feature_perturbation = 'interventional' is supported for sparse data")
308
+ phi = np.matmul(np.matmul(X[:,self.valid_inds], self.avg_proj.T), self.x_transform.T) - self.mean_transformed
309
+ phi = np.matmul(phi, self.avg_proj)
310
+
311
+ full_phi = np.zeros((phi.shape[0], self.M))
312
+ full_phi[:,self.valid_inds] = phi
313
+ phi = full_phi
314
+
315
+ elif self.feature_perturbation == "interventional":
316
+ if issparse(X):
317
+ phi = np.array(np.multiply(X - self.mean, self.coef))
318
+
319
+ # if len(self.coef.shape) == 1:
320
+ # return np.array(np.multiply(X - self.mean, self.coef))
321
+ # else:
322
+ # return [np.array(np.multiply(X - self.mean, self.coef[i])) for i in range(self.coef.shape[0])]
323
+ else:
324
+ phi = np.array(X - self.mean) * self.coef
325
+ # if len(self.coef.shape) == 1:
326
+ # phi = np.array(X - self.mean) * self.coef
327
+ # return np.array(X - self.mean) * self.coef
328
+ # else:
329
+ # return [np.array(X - self.mean) * self.coef[i] for i in range(self.coef.shape[0])]
330
+
331
+ return {
332
+ "values": phi.T,
333
+ "expected_values": self.expected_value,
334
+ "mask_shapes": (X.shape[1:],),
335
+ "main_effects": phi.T,
336
+ "clustering": None
337
+ }
338
+
339
+
340
+ def shap_values(self, X):
341
+ """ Estimate the SHAP values for a set of samples.
342
+
343
+ Parameters
344
+ ----------
345
+ X : numpy.array, pandas.DataFrame or scipy.csr_matrix
346
+ A matrix of samples (# samples x # features) on which to explain the model's output.
347
+
348
+ Returns
349
+ -------
350
+ array or list
351
+ For models with a single output this returns a matrix of SHAP values
352
+ (# samples x # features). Each row sums to the difference between the model output for that
353
+ sample and the expected value of the model output (which is stored as expected_value
354
+ attribute of the explainer).
355
+ """
356
+
357
+ # convert dataframes
358
+ if isinstance(X, (pd.Series, pd.DataFrame)):
359
+ X = X.values
360
+
361
+ # assert isinstance(X, np.ndarray), "Unknown instance type: " + str(type(X))
362
+ if len(X.shape) not in (1, 2):
363
+ raise DimensionError("Instance must have 1 or 2 dimensions! Not: %s" % len(X.shape))
364
+
365
+ if self.feature_perturbation == "correlation_dependent":
366
+ if issparse(X):
367
+ raise InvalidFeaturePerturbationError("Only feature_perturbation = 'interventional' is supported for sparse data")
368
+ phi = np.matmul(np.matmul(X[:,self.valid_inds], self.avg_proj.T), self.x_transform.T) - self.mean_transformed
369
+ phi = np.matmul(phi, self.avg_proj)
370
+
371
+ full_phi = np.zeros((phi.shape[0], self.M))
372
+ full_phi[:,self.valid_inds] = phi
373
+
374
+ return full_phi
375
+
376
+ elif self.feature_perturbation == "interventional":
377
+ if issparse(X):
378
+ if len(self.coef.shape) == 1:
379
+ return np.array(np.multiply(X - self.mean, self.coef))
380
+ else:
381
+ return [np.array(np.multiply(X - self.mean, self.coef[i])) for i in range(self.coef.shape[0])]
382
+ else:
383
+ if len(self.coef.shape) == 1:
384
+ return np.array(X - self.mean) * self.coef
385
+ else:
386
+ return [np.array(X - self.mean) * self.coef[i] for i in range(self.coef.shape[0])]
387
+
388
def duplicate_components(C):
    """Group perfectly correlated features of the covariance matrix ``C``.

    Returns a pair ``(avg_proj, sum_proj)`` of projection matrices with one row
    per group of duplicate features: ``avg_proj`` averages the members of each
    group and ``sum_proj`` sums them.
    """
    # rescale C to a correlation-like matrix so duplicates can be detected
    # by comparing covariances against variances directly
    inv_std = np.diag(1 / np.sqrt(np.diag(C)))
    corr = np.matmul(np.matmul(inv_std, C), inv_std)
    n = corr.shape[0]

    # assign each feature a component id; features i and j share a component
    # when 2*corr[i,j] == corr[i,i] + corr[j,j] (perfect correlation)
    components = -np.ones(n, dtype=int)
    num_groups = 0
    for i in range(n):
        started_group = False
        for j in range(n):
            if components[j] < 0 and np.abs(2 * corr[i, j] - corr[i, i] - corr[j, j]) < 1e-8:
                if not started_group:
                    num_groups += 1
                    started_group = True
                components[j] = num_groups - 1

    # build the (groups x features) indicator matrix, then derive the averaging form
    proj = np.zeros((len(np.unique(components)), n))
    proj[0, 0] = 1
    for i in range(1, n):
        proj[components[i], i] = 1
    return (proj.T / proj.sum(1)).T, proj
lib/shap/explainers/_partition.py ADDED
@@ -0,0 +1,681 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import queue
2
+ import time
3
+
4
+ import numpy as np
5
+ from numba import njit
6
+ from tqdm.auto import tqdm
7
+
8
+ from .. import Explanation, links
9
+ from ..models import Model
10
+ from ..utils import MaskedModel, OpChain, make_masks, safe_isinstance
11
+ from ._explainer import Explainer
12
+
13
+
14
+ class PartitionExplainer(Explainer):
15
+ """Uses the Partition SHAP method to explain the output of any function.
16
+
17
+ Partition SHAP computes Shapley values recursively through a hierarchy of features, this
18
+ hierarchy defines feature coalitions and results in the Owen values from game theory.
19
+
20
+ The PartitionExplainer has two particularly nice properties:
21
+
22
+ 1) PartitionExplainer is model-agnostic but when using a balanced partition tree only has
23
+ quadratic exact runtime (in term of the number of input features). This is in contrast to the
24
+ exponential exact runtime of KernelExplainer or SamplingExplainer.
25
+ 2) PartitionExplainer always assigns to groups of correlated features the credit that set of features
26
+ would have had if treated as a group. This means if the hierarchical clustering given to
27
+ PartitionExplainer groups correlated features together, then feature correlations are
28
+ "accounted for" in the sense that the total credit assigned to a group of tightly dependent features
29
+ does not depend on how they behave if their correlation structure was broken during the explanation's
30
+ perturbation process.
31
+ Note that for linear models the Owen values that PartitionExplainer returns are the same as the standard
32
+ non-hierarchical Shapley values.
33
+ """
34
+
35
    def __init__(self, model, masker, *, output_names=None, link=links.identity, linearize_link=True,
                 feature_names=None, **call_args):
        """Build a PartitionExplainer for the given model with the given masker.

        Parameters
        ----------
        model : function
            User supplied function that takes a matrix of samples (# samples x # features) and
            computes the output of the model for those samples.

        masker : function or numpy.array or pandas.DataFrame or tokenizer
            The function used to "mask" out hidden features of the form `masker(mask, x)`. It takes a
            single input sample and a binary mask and returns a matrix of masked samples. These
            masked samples will then be evaluated using the model function and the outputs averaged.
            As a shortcut for the standard masking used by SHAP you can pass a background data matrix
            instead of a function and that matrix will be used for masking. Domain specific masking
            functions are available in shap such as shap.maskers.Image for images and shap.maskers.Text
            for text.

        partition_tree : None or function or numpy.array
            A hierarchical clustering of the input features represented by a matrix that follows the format
            used by scipy.cluster.hierarchy (see the notebooks_html/partition_explainer directory an example).
            If this is a function then the function produces a clustering matrix when given a single input
            example. If you are using a standard SHAP masker object then you can pass masker.clustering
            to use that masker's built-in clustering of the features, or if partition_tree is None then
            masker.clustering will be used by default.

        Examples
        --------
        See `Partition explainer examples <https://shap.readthedocs.io/en/latest/api_examples/explainers/PartitionExplainer.html>`_
        """

        super().__init__(model, masker, link=link, linearize_link=linearize_link, algorithm="partition", \
                         output_names = output_names, feature_names=feature_names)

        # TODO: maybe? if we have a tabular masker then we build a PermutationExplainer that we
        # will use for sampling
        # masker.shape may be a callable on dynamic maskers, in which case we cannot
        # know the input shape up front
        self.input_shape = masker.shape[1:] if hasattr(masker, "shape") and not callable(masker.shape) else None
        # wrap plain callables so downstream code can rely on the shap Model interface
        if not safe_isinstance(self.model, "shap.models.Model"):
            self.model = Model(self.model)
        self.expected_value = None
        # cached base value; recomputed per row unless the masker has a fixed background
        self._curr_base_value = None
        if getattr(self.masker, "clustering", None) is None:
            raise ValueError("The passed masker must have a .clustering attribute defined! Try shap.maskers.Partition(data) for example.")

        # handle higher dimensional tensor inputs by flattening/reshaping around the model call
        if self.input_shape is not None and len(self.input_shape) > 1:
            self._reshaped_model = lambda x: self.model(x.reshape(x.shape[0], *self.input_shape))
        else:
            self._reshaped_model = self.model

        # if we don't have a dynamic clustering algorithm then we can precompute
        # the clustering and its mask matrix once here
        if not callable(self.masker.clustering):
            self._clustering = self.masker.clustering
            self._mask_matrix = make_masks(self._clustering)

        # if we have gotten default arguments for the call function we need to wrap ourselves in a new class that
        # has a call function with those new default arguments
        if len(call_args) > 0:
            # dynamically subclass so the overridden defaults live on the instance's
            # class without affecting other PartitionExplainer instances
            class PartitionExplainer(self.__class__):
                # this signature should match the __call__ signature of the class defined below
                def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto",
                             outputs=None, silent=False):
                    return super().__call__(
                        *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds,
                        batch_size=batch_size, outputs=outputs, silent=silent
                    )
            PartitionExplainer.__call__.__doc__ = self.__class__.__call__.__doc__
            self.__class__ = PartitionExplainer
            # splice the user-supplied defaults into the new __call__'s keyword defaults
            for k, v in call_args.items():
                self.__call__.__kwdefaults__[k] = v
123
+
124
+ # note that changes to this function signature should be copied to the default call argument wrapper above
125
+ def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto",
126
+ outputs=None, silent=False):
127
+ """ Explain the output of the model on the given arguments.
128
+ """
129
+ return super().__call__(
130
+ *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size,
131
+ outputs=outputs, silent=silent
132
+ )
133
+
134
    def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context = "auto"):
        """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
        """

        if fixed_context == "auto":
            # if isinstance(self.masker, maskers.Text):
            #     fixed_context = 1 # we err on the side of speed for text models
            # else:
            fixed_context = None
        elif fixed_context not in [0, 1, None]:
            raise ValueError("Unknown fixed_context value passed (must be 0, 1 or None): %s" %fixed_context)

        # build a masked version of the model for the current input sample
        fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)

        # make sure we have the base value and current value outputs
        M = len(fm)
        m00 = np.zeros(M, dtype=bool)  # the all-masked coalition
        # if not fixed background or no base value assigned then compute base value for a row
        if self._curr_base_value is None or not getattr(self.masker, "fixed_background", False):
            self._curr_base_value = fm(m00.reshape(1, -1), zero_index=0)[0] # the zero index param tells the masked model what the baseline is
        f11 = fm(~m00.reshape(1, -1))[0]  # model output with every feature unmasked

        # dynamic clustering: recompute the hierarchy and mask matrix for this row
        if callable(self.masker.clustering):
            self._clustering = self.masker.clustering(*row_args)
            self._mask_matrix = make_masks(self._clustering)

        # multi-output models produce a base value with a shape; resolve which
        # outputs to explain (all by default, or via an OpChain selector)
        if hasattr(self._curr_base_value, 'shape') and len(self._curr_base_value.shape) > 0:
            if outputs is None:
                outputs = np.arange(len(self._curr_base_value))
            elif isinstance(outputs, OpChain):
                outputs = outputs.apply(Explanation(f11)).values

            out_shape = (2*self._clustering.shape[0]+1, len(outputs))
        else:
            out_shape = (2*self._clustering.shape[0]+1,)

        if max_evals == "auto":
            max_evals = 500

        # per-node value buffers over the full hierarchy (leaves + internal nodes)
        self.values = np.zeros(out_shape)
        self.dvalues = np.zeros(out_shape)

        # recursively compute the Owen values (fills self.dvalues);
        # max_evals - 2 accounts for the two evaluations already spent above
        self.owen(fm, self._curr_base_value, f11, max_evals - 2, outputs, fixed_context, batch_size, silent)

        # drop the interaction terms down onto self.values
        self.values[:] = self.dvalues

        # push credit left on internal nodes down to the leaf features
        lower_credit(len(self.dvalues) - 1, 0, M, self.values, self._clustering)

        return {
            "values": self.values[:M].copy(),
            "expected_values": self._curr_base_value if outputs is None else self._curr_base_value[outputs],
            "mask_shapes": [s + out_shape[1:] for s in fm.mask_shapes],
            "main_effects": None,
            "hierarchical_values": self.dvalues.copy(),
            "clustering": self._clustering,
            "output_indices": outputs,
            "output_names": getattr(self.model, "output_names", None)
        }
200
+
201
+ def __str__(self):
202
+ return "shap.explainers.PartitionExplainer()"
203
+
204
+ def owen(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent):
205
+ """ Compute a nested set of recursive Owen values based on an ordering recursion.
206
+ """
207
+
208
+ #f = self._reshaped_model
209
+ #r = self.masker
210
+ #masks = np.zeros(2*len(inds)+1, dtype=int)
211
+ M = len(fm)
212
+ m00 = np.zeros(M, dtype=bool)
213
+ #f00 = fm(m00.reshape(1,-1))[0]
214
+ base_value = f00
215
+ #f11 = fm(~m00.reshape(1,-1))[0]
216
+ #f11 = self._reshaped_model(r(~m00, x)).mean(0)
217
+ ind = len(self.dvalues)-1
218
+
219
+ # make sure output_indexes is a list of indexes
220
+ if output_indexes is not None:
221
+ # assert self.multi_output, "output_indexes is only valid for multi-output models!"
222
+ # inds = output_indexes.apply(f11, 0)
223
+ # out_len = output_indexes_len(output_indexes)
224
+ # if output_indexes.startswith("max("):
225
+ # output_indexes = np.argsort(-f11)[:out_len]
226
+ # elif output_indexes.startswith("min("):
227
+ # output_indexes = np.argsort(f11)[:out_len]
228
+ # elif output_indexes.startswith("max(abs("):
229
+ # output_indexes = np.argsort(np.abs(f11))[:out_len]
230
+
231
+ f00 = f00[output_indexes]
232
+ f11 = f11[output_indexes]
233
+
234
+ q = queue.PriorityQueue()
235
+ q.put((0, 0, (m00, f00, f11, ind, 1.0)))
236
+ eval_count = 0
237
+ total_evals = min(max_evals, (M-1)*M) # TODO: (M-1)*M is only right for balanced clusterings, but this is just for plotting progress...
238
+ pbar = None
239
+ start_time = time.time()
240
+ while not q.empty():
241
+
242
+ # if we passed our execution limit then leave everything else on the internal nodes
243
+ if eval_count >= max_evals:
244
+ while not q.empty():
245
+ m00, f00, f11, ind, weight = q.get()[2]
246
+ self.dvalues[ind] += (f11 - f00) * weight
247
+ break
248
+
249
+ # create a batch of work to do
250
+ batch_args = []
251
+ batch_masks = []
252
+ while not q.empty() and len(batch_masks) < batch_size and eval_count + len(batch_masks) < max_evals:
253
+
254
+ # get our next set of arguments
255
+ m00, f00, f11, ind, weight = q.get()[2]
256
+
257
+ # get the left and right children of this cluster
258
+ lind = int(self._clustering[ind-M, 0]) if ind >= M else -1
259
+ rind = int(self._clustering[ind-M, 1]) if ind >= M else -1
260
+
261
+ # get the distance of this cluster's children
262
+ if ind < M:
263
+ distance = -1
264
+ else:
265
+ if self._clustering.shape[1] >= 3:
266
+ distance = self._clustering[ind-M, 2]
267
+ else:
268
+ distance = 1
269
+
270
+ # check if we are a leaf node (or other negative distance cluster) and so should terminate our decent
271
+ if distance < 0:
272
+ self.dvalues[ind] += (f11 - f00) * weight
273
+ continue
274
+
275
+ # build the masks
276
+ m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
277
+ m10[:] += self._mask_matrix[lind, :]
278
+ m01 = m00.copy()
279
+ m01[:] += self._mask_matrix[rind, :]
280
+
281
+ batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight))
282
+ batch_masks.append(m10)
283
+ batch_masks.append(m01)
284
+
285
+ batch_masks = np.array(batch_masks)
286
+
287
+ # run the batch
288
+ if len(batch_args) > 0:
289
+ fout = fm(batch_masks)
290
+ if output_indexes is not None:
291
+ fout = fout[:,output_indexes]
292
+
293
+ eval_count += len(batch_masks)
294
+
295
+ if pbar is None and time.time() - start_time > 5:
296
+ pbar = tqdm(total=total_evals, disable=silent, leave=False)
297
+ pbar.update(eval_count)
298
+ if pbar is not None:
299
+ pbar.update(len(batch_masks))
300
+
301
+ # use the results of the batch to add new nodes
302
+ for i in range(len(batch_args)):
303
+
304
+ m00, m10, m01, f00, f11, ind, lind, rind, weight = batch_args[i]
305
+
306
+ # get the evaluated model output on the two new masked inputs
307
+ f10 = fout[2*i]
308
+ f01 = fout[2*i+1]
309
+
310
+ new_weight = weight
311
+ if fixed_context is None:
312
+ new_weight /= 2
313
+ elif fixed_context == 0:
314
+ self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
315
+ elif fixed_context == 1:
316
+ self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
317
+
318
+ if fixed_context is None or fixed_context == 0:
319
+ # recurse on the left node with zero context
320
+ args = (m00, f00, f10, lind, new_weight)
321
+ q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
322
+
323
+ # recurse on the right node with zero context
324
+ args = (m00, f00, f01, rind, new_weight)
325
+ q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
326
+
327
+ if fixed_context is None or fixed_context == 1:
328
+ # recurse on the left node with one context
329
+ args = (m01, f01, f11, lind, new_weight)
330
+ q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))
331
+
332
+ # recurse on the right node with one context
333
+ args = (m10, f10, f11, rind, new_weight)
334
+ q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))
335
+
336
+ if pbar is not None:
337
+ pbar.close()
338
+
339
+ self.last_eval_count = eval_count
340
+
341
+ return output_indexes, base_value
342
+
343
    def owen3(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent):
        """ Compute a nested set of recursive Owen values based on an ordering recursion.

        Credit for each node of the feature hierarchy is accumulated into
        ``self.dvalues``. A max-first priority queue (keyed on the largest remaining
        attribution gap) decides which cluster to expand next, so the evaluation
        budget ``max_evals`` is spent on the clusters that matter most.

        Parameters
        ----------
        fm : callable (MaskedModel-like)
            ``fm(masks)`` evaluates the model on a batch of boolean masks and
            ``len(fm)`` gives the number of maskable inputs.
        f00 : numpy.ndarray
            Model output with everything masked (all-False mask); used as the base value.
        f11 : numpy.ndarray
            Model output with nothing masked.
        max_evals : int
            Budget of masked-model evaluations.
        output_indexes : array-like or None
            Optional subset of model outputs to explain; ``f00``/``f11``/batch outputs
            are sliced down to these indexes.
        fixed_context : None, 0 or 1
            Starting context for the Owen recursion; with remaining budget, a node may
            temporarily ignore its fixed context and explore both sides (see below).
        batch_size : int
            Number of masks evaluated per model call.
        silent : bool
            Disable the progress bar.

        Returns
        -------
        tuple
            ``(output_indexes, base_value)``; the attributions themselves are left in
            ``self.dvalues``.
        """

        #f = self._reshaped_model
        #r = self.masker
        #masks = np.zeros(2*len(inds)+1, dtype=int)
        M = len(fm)
        m00 = np.zeros(M, dtype=bool)
        #f00 = fm(m00.reshape(1,-1))[0]
        base_value = f00
        #f11 = fm(~m00.reshape(1,-1))[0]
        #f11 = self._reshaped_model(r(~m00, x)).mean(0)
        # the root of the clustering tree is the last entry of dvalues
        ind = len(self.dvalues)-1

        # make sure output_indexes is a list of indexes
        if output_indexes is not None:
            # assert self.multi_output, "output_indexes is only valid for multi-output models!"
            # inds = output_indexes.apply(f11, 0)
            # out_len = output_indexes_len(output_indexes)
            # if output_indexes.startswith("max("):
            #     output_indexes = np.argsort(-f11)[:out_len]
            # elif output_indexes.startswith("min("):
            #     output_indexes = np.argsort(f11)[:out_len]
            # elif output_indexes.startswith("max(abs("):
            #     output_indexes = np.argsort(np.abs(f11))[:out_len]

            f00 = f00[output_indexes]
            f11 = f11[output_indexes]

        # our starting plan is to evaluate all the nodes with a fixed_context
        evals_planned = M

        q = queue.PriorityQueue()
        # payload tuple: (mask, f(mask), f(mask | subtree), tree_index, weight, context)
        q.put((0, 0, (m00, f00, f11, ind, 1.0, fixed_context)))
        eval_count = 0
        total_evals = min(max_evals, (M-1)*M) # TODO: (M-1)*M is only right for balanced clusterings, but this is just for plotting progress...
        pbar = None
        start_time = time.time()
        while not q.empty():

            # if we passed our execution limit then leave everything else on the internal nodes
            if eval_count >= max_evals:
                while not q.empty():
                    m00, f00, f11, ind, weight, _ = q.get()[2]
                    self.dvalues[ind] += (f11 - f00) * weight
                break

            # create a batch of work to do
            batch_args = []
            batch_masks = []
            while not q.empty() and len(batch_masks) < batch_size and eval_count < max_evals:

                # get our next set of arguments
                m00, f00, f11, ind, weight, context = q.get()[2]

                # get the left and right children of this cluster
                lind = int(self._clustering[ind-M, 0]) if ind >= M else -1
                rind = int(self._clustering[ind-M, 1]) if ind >= M else -1

                # get the distance of this cluster's children
                if ind < M:
                    distance = -1
                else:
                    distance = self._clustering[ind-M, 2]

                # check if we are a leaf node (or other negative distance cluster) and so should terminate our decent
                if distance < 0:
                    self.dvalues[ind] += (f11 - f00) * weight
                    continue

                # build the masks
                m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
                m10[:] += self._mask_matrix[lind, :]
                m01 = m00.copy()
                m01[:] += self._mask_matrix[rind, :]

                batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight, context))
                batch_masks.append(m10)
                batch_masks.append(m01)

            batch_masks = np.array(batch_masks)

            # run the batch
            if len(batch_args) > 0:
                fout = fm(batch_masks)
                if output_indexes is not None:
                    fout = fout[:,output_indexes]

            eval_count += len(batch_masks)

            # lazily create the progress bar only for long-running computations (>5s)
            if pbar is None and time.time() - start_time > 5:
                pbar = tqdm(total=total_evals, disable=silent, leave=False)
                pbar.update(eval_count)
            if pbar is not None:
                pbar.update(len(batch_masks))

            # use the results of the batch to add new nodes
            for i in range(len(batch_args)):

                m00, m10, m01, f00, f11, ind, lind, rind, weight, context = batch_args[i]

                # get the the number of leaves in this cluster
                if ind < M:
                    num_leaves = 0
                else:
                    num_leaves = self._clustering[ind-M, 3]

                # get the evaluated model output on the two new masked inputs
                f10 = fout[2*i]
                f01 = fout[2*i+1]

                # see if we have enough evaluations left to get both sides of a fixed context
                if max_evals - evals_planned > num_leaves:
                    evals_planned += num_leaves
                    ignore_context = True
                else:
                    ignore_context = False

                new_weight = weight
                # exploring both contexts halves the weight of each branch
                if context is None or ignore_context:
                    new_weight /= 2

                if context is None or context == 0 or ignore_context:
                    self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node

                    # recurse on the left node with zero context, flip the context for all descendents if we are ignoring it
                    args = (m00, f00, f10, lind, new_weight, 0 if context == 1 else context)
                    # NOTE: priorities are negated max-abs gaps (max-first queue); the
                    # random second element breaks ties without comparing array payloads
                    q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))

                    # recurse on the right node with zero context, flip the context for all descendents if we are ignoring it
                    args = (m00, f00, f01, rind, new_weight, 0 if context == 1 else context)
                    q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))

                if context is None or context == 1 or ignore_context:
                    self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node

                    # recurse on the left node with one context, flip the context for all descendents if we are ignoring it
                    args = (m01, f01, f11, lind, new_weight, 1 if context == 0 else context)
                    q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))

                    # recurse on the right node with one context, flip the context for all descendents if we are ignoring it
                    args = (m10, f10, f11, rind, new_weight, 1 if context == 0 else context)
                    q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))

        if pbar is not None:
            pbar.close()

        self.last_eval_count = eval_count

        return output_indexes, base_value
494
+
495
+
496
+
497
+ # def owen2(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent):
498
+ # """ Compute a nested set of recursive Owen values based on an ordering recursion.
499
+ # """
500
+
501
+ # #f = self._reshaped_model
502
+ # #r = self.masker
503
+ # #masks = np.zeros(2*len(inds)+1, dtype=int)
504
+ # M = len(fm)
505
+ # m00 = np.zeros(M, dtype=bool)
506
+ # #f00 = fm(m00.reshape(1,-1))[0]
507
+ # base_value = f00
508
+ # #f11 = fm(~m00.reshape(1,-1))[0]
509
+ # #f11 = self._reshaped_model(r(~m00, x)).mean(0)
510
+ # ind = len(self.dvalues)-1
511
+
512
+ # # make sure output_indexes is a list of indexes
513
+ # if output_indexes is not None:
514
+ # # assert self.multi_output, "output_indexes is only valid for multi-output models!"
515
+ # # inds = output_indexes.apply(f11, 0)
516
+ # # out_len = output_indexes_len(output_indexes)
517
+ # # if output_indexes.startswith("max("):
518
+ # # output_indexes = np.argsort(-f11)[:out_len]
519
+ # # elif output_indexes.startswith("min("):
520
+ # # output_indexes = np.argsort(f11)[:out_len]
521
+ # # elif output_indexes.startswith("max(abs("):
522
+ # # output_indexes = np.argsort(np.abs(f11))[:out_len]
523
+
524
+ # f00 = f00[output_indexes]
525
+ # f11 = f11[output_indexes]
526
+
527
+ # fc_owen(m00, m11, 1)
528
+ # fc_owen(m00, m11, 0)
529
+
530
+ # def fc_owen(m00, m11, context):
531
+
532
+ # # recurse on the left node with zero context
533
+ # args = (m00, f00, f10, lind, new_weight)
534
+ # q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
535
+
536
+ # # recurse on the right node with zero context
537
+ # args = (m00, f00, f01, rind, new_weight)
538
+ # q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
539
+ # fc_owen(m00, m11, 1)
540
+ # m00 m11
541
+ # owen(fc=1)
542
+ # owen(fc=0)
543
+
544
+ # q = queue.PriorityQueue()
545
+ # q.put((0, 0, (m00, f00, f11, ind, 1.0, 1)))
546
+ # eval_count = 0
547
+ # total_evals = min(max_evals, (M-1)*M) # TODO: (M-1)*M is only right for balanced clusterings, but this is just for plotting progress...
548
+ # pbar = None
549
+ # start_time = time.time()
550
+ # while not q.empty():
551
+
552
+ # # if we passed our execution limit then leave everything else on the internal nodes
553
+ # if eval_count >= max_evals:
554
+ # while not q.empty():
555
+ # m00, f00, f11, ind, weight, _ = q.get()[2]
556
+ # self.dvalues[ind] += (f11 - f00) * weight
557
+ # break
558
+
559
+ # # create a batch of work to do
560
+ # batch_args = []
561
+ # batch_masks = []
562
+ # while not q.empty() and len(batch_masks) < batch_size and eval_count < max_evals:
563
+
564
+ # # get our next set of arguments
565
+ # m00, f00, f11, ind, weight, context = q.get()[2]
566
+
567
+ # # get the left and right children of this cluster
568
+ # lind = int(self._clustering[ind-M, 0]) if ind >= M else -1
569
+ # rind = int(self._clustering[ind-M, 1]) if ind >= M else -1
570
+
571
+ # # get the distance of this cluster's children
572
+ # if ind < M:
573
+ # distance = -1
574
+ # else:
575
+ # if self._clustering.shape[1] >= 3:
576
+ # distance = self._clustering[ind-M, 2]
577
+ # else:
578
+ # distance = 1
579
+
580
+ # # check if we are a leaf node (or other negative distance cluster) and so should terminate our decent
581
+ # if distance < 0:
582
+ # self.dvalues[ind] += (f11 - f00) * weight
583
+ # continue
584
+
585
+ # # build the masks
586
+ # m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
587
+ # m10[:] += self._mask_matrix[lind, :]
588
+ # m01 = m00.copy()
589
+ # m01[:] += self._mask_matrix[rind, :]
590
+
591
+ # batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight, context))
592
+ # batch_masks.append(m10)
593
+ # batch_masks.append(m01)
594
+
595
+ # batch_masks = np.array(batch_masks)
596
+
597
+ # # run the batch
598
+ # if len(batch_args) > 0:
599
+ # fout = fm(batch_masks)
600
+ # if output_indexes is not None:
601
+ # fout = fout[:,output_indexes]
602
+
603
+ # eval_count += len(batch_masks)
604
+
605
+ # if pbar is None and time.time() - start_time > 5:
606
+ # pbar = tqdm(total=total_evals, disable=silent, leave=False)
607
+ # pbar.update(eval_count)
608
+ # if pbar is not None:
609
+ # pbar.update(len(batch_masks))
610
+
611
+ # # use the results of the batch to add new nodes
612
+ # for i in range(len(batch_args)):
613
+
614
+ # m00, m10, m01, f00, f11, ind, lind, rind, weight, context = batch_args[i]
615
+
616
+ # # get the evaluated model output on the two new masked inputs
617
+ # f10 = fout[2*i]
618
+ # f01 = fout[2*i+1]
619
+
620
+ # new_weight = weight
621
+ # if fixed_context is None:
622
+ # new_weight /= 2
623
+ # elif fixed_context == 0:
624
+ # self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
625
+ # elif fixed_context == 1:
626
+ # self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
627
+
628
+ # if fixed_context is None or fixed_context == 0:
629
+ # self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
630
+
631
+
632
+ # # recurse on the left node with zero context
633
+ # args = (m00, f00, f10, lind, new_weight)
634
+ # q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
635
+
636
+ # # recurse on the right node with zero context
637
+ # args = (m00, f00, f01, rind, new_weight)
638
+ # q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
639
+
640
+ # if fixed_context is None or fixed_context == 1:
641
+ # self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
642
+
643
+
644
+ # # recurse on the left node with one context
645
+ # args = (m01, f01, f11, lind, new_weight)
646
+ # q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))
647
+
648
+ # # recurse on the right node with one context
649
+ # args = (m10, f10, f11, rind, new_weight)
650
+ # q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))
651
+
652
+ # if pbar is not None:
653
+ # pbar.close()
654
+
655
+ # return output_indexes, base_value
656
+
657
+
658
def output_indexes_len(output_indexes):
    """ Return how many model outputs an ``output_indexes`` specification selects.

    ``output_indexes`` is either a string of the form ``"max(k)"``, ``"min(k)"``
    or ``"max(abs(k))"`` (select the top/bottom ``k`` outputs), or a concrete
    sequence of output indexes (in which case its length is returned).

    Returns ``None`` for an unrecognized string specification (matching the
    original implicit fall-through behavior).
    """
    # Non-strings are concrete index collections. This check must come first:
    # calling .startswith on a list/array would raise AttributeError.
    if not isinstance(output_indexes, str):
        return len(output_indexes)
    # "max(abs(" must be tested before "max(" — both prefixes match "max(abs(k))",
    # and the shorter one would try int("abs(k)") and raise ValueError.
    if output_indexes.startswith("max(abs("):
        return int(output_indexes[8:-2])
    if output_indexes.startswith("max("):
        return int(output_indexes[4:-1])
    if output_indexes.startswith("min("):
        return int(output_indexes[4:-1])
    return None
667
+
668
@njit
def lower_credit(i, value, M, values, clustering):
    """ Recursively push credit assigned to node ``i`` down to the leaf features.

    ``clustering`` is a scipy-linkage-style matrix: rows are internal nodes,
    columns 0/1 are the child indexes and column 3 is the subtree leaf count.
    Leaves are indexes ``< M``; internal node ``i`` maps to row ``i - M``.
    Credit is split between children proportionally to their leaf counts.
    Numba-jitted (@njit), hence the plain-loop/assert style.
    """
    # leaf: deposit the credit and stop
    if i < M:
        values[i] += value
        return
    li = int(clustering[i-M,0])
    ri = int(clustering[i-M,1])
    group_size = int(clustering[i-M,3])
    lsize = int(clustering[li-M,3]) if li >= M else 1
    rsize = int(clustering[ri-M,3]) if ri >= M else 1
    assert lsize+rsize == group_size
    values[i] += value
    # NOTE(review): the recursion distributes the node's full accumulated total
    # values[i] (not just the newly added `value`) — presumably intentional since
    # each node is visited once per top-level call, but verify against callers.
    lower_credit(li, values[i] * lsize / group_size, M, values, clustering)
    lower_credit(ri, values[i] * rsize / group_size, M, values, clustering)
lib/shap/explainers/_permutation.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import numpy as np
4
+
5
+ from .. import links
6
+ from ..models import Model
7
+ from ..utils import MaskedModel, partition_tree_shuffle
8
+ from ._explainer import Explainer
9
+
10
+
11
class PermutationExplainer(Explainer):
    """ This method approximates the Shapley values by iterating through permutations of the inputs.

    This is a model agnostic explainer that guarantees local accuracy (additivity) by iterating completely
    through an entire permutation of the features in both forward and reverse directions (antithetic sampling).
    If we do this once, then we get the exact SHAP values for models with up to second order interaction effects.
    We can iterate this many times over many random permutations to get better SHAP value estimates for models
    with higher order interactions. This sequential ordering formulation also allows for easy reuse of
    model evaluations and the ability to efficiently avoid evaluating the model when the background values
    for a feature are the same as the current input value. We can also account for hierarchical data
    structures with partition trees, something not currently implemented for KernalExplainer or SamplingExplainer.
    """

    def __init__(self, model, masker, link=links.identity, feature_names=None, linearize_link=True, seed=None, **call_args):
        """ Build an explainers.Permutation object for the given model using the given masker object.

        Parameters
        ----------
        model : function
            A callable python object that executes the model given a set of input data samples.

        masker : function or numpy.array or pandas.DataFrame
            A callable python object used to "mask" out hidden features of the form `masker(binary_mask, x)`.
            It takes a single input sample and a binary mask and returns a matrix of masked samples. These
            masked samples are evaluated using the model function and the outputs are then averaged.
            As a shortcut for the standard masking using by SHAP you can pass a background data matrix
            instead of a function and that matrix will be used for masking. To use a clustering
            game structure you can pass a shap.maskers.Tabular(data, clustering=\"correlation\") object.

        seed: None or int
            Seed for reproducibility

        **call_args : valid argument to the __call__ method
            These arguments are saved and passed to the __call__ method as the new default values for these arguments.
        """

        # setting seed for random generation: if seed is not None, then shap values computation should be reproducible
        # NOTE(review): this seeds the *global* numpy RNG, affecting other numpy users in the process
        np.random.seed(seed)

        if masker is None:
            raise ValueError("masker cannot be None.")

        super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names)

        if not isinstance(self.model, Model):
            self.model = Model(self.model)

        # if we have gotten default arguments for the call function we need to wrap ourselves in a new class that
        # has a call function with those new default arguments
        if len(call_args) > 0:
            # this signature should match the __call__ signature of the class defined below
            class PermutationExplainer(self.__class__):
                def __call__(self, *args, max_evals=500, main_effects=False, error_bounds=False, batch_size="auto",
                             outputs=None, silent=False):
                    return super().__call__(
                        *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
                        batch_size=batch_size, outputs=outputs, silent=silent
                    )
            PermutationExplainer.__call__.__doc__ = self.__class__.__call__.__doc__
            # swap this instance onto the subclass, then bake the saved kwargs in as its new defaults
            self.__class__ = PermutationExplainer
            for k, v in call_args.items():
                self.__call__.__kwdefaults__[k] = v

    # note that changes to this function signature should be copied to the default call argument wrapper above
    def __call__(self, *args, max_evals=500, main_effects=False, error_bounds=False, batch_size="auto",
                 outputs=None, silent=False):
        """ Explain the output of the model on the given arguments.
        """
        return super().__call__(
            *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size,
            outputs=outputs, silent=silent
        )

    def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
        """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).

        NOTE(review): the incoming ``outputs`` parameter is rebound to the masked-model
        outputs below and is never read beforehand — the argument is effectively unused here.
        """

        # build a masked version of the model for the current input sample
        fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)

        # by default we run 10 permutations forward and backward
        if max_evals == "auto":
            max_evals = 10 * 2 * len(fm)

        # compute any custom clustering for this row
        row_clustering = None
        if getattr(self.masker, "clustering", None) is not None:
            if isinstance(self.masker.clustering, np.ndarray):
                row_clustering = self.masker.clustering
            elif callable(self.masker.clustering):
                row_clustering = self.masker.clustering(*row_args)
            else:
                raise NotImplementedError("The masker passed has a .clustering attribute that is not yet supported by the Permutation explainer!")

        # loop over many permutations
        inds = fm.varying_inputs()
        inds_mask = np.zeros(len(fm), dtype=bool)
        inds_mask[inds] = True
        # masks is a delta-encoded sequence: entry 0 is a no-op, then each varying
        # feature is toggled on (forward pass) and toggled off again (backward pass)
        masks = np.zeros(2*len(inds)+1, dtype=int)
        masks[0] = MaskedModel.delta_mask_noop_value
        npermutations = max_evals // (2*len(inds)+1)
        row_values = None
        row_values_history = None
        history_pos = 0
        main_effect_values = None
        if len(inds) > 0:
            for _ in range(npermutations):

                # shuffle the indexes so we get a random permutation ordering
                if row_clustering is not None:
                    # [TODO] This is shuffle does not work when inds is not a complete set of integers from 0 to M TODO: still true?
                    #assert len(inds) == len(fm), "Need to support partition shuffle when not all the inds vary!!"
                    partition_tree_shuffle(inds, inds_mask, row_clustering)
                else:
                    np.random.shuffle(inds)

                # create a large batch of masks to evaluate
                i = 1
                for ind in inds:
                    masks[i] = ind
                    i += 1
                for ind in inds:
                    masks[i] = ind
                    i += 1

                # evaluate the masked model
                outputs = fm(masks, zero_index=0, batch_size=batch_size)

                # allocate accumulators lazily once the output shape is known
                if row_values is None:
                    row_values = np.zeros((len(fm),) + outputs.shape[1:])

                    if error_bounds:
                        row_values_history = np.zeros((2 * npermutations, len(fm),) + outputs.shape[1:])

                # update our SHAP value estimates
                i = 0
                for ind in inds: # forward
                    row_values[ind] += outputs[i + 1] - outputs[i]
                    if error_bounds:
                        row_values_history[history_pos][ind] = outputs[i + 1] - outputs[i]
                    i += 1
                history_pos += 1
                for ind in inds: # backward
                    row_values[ind] += outputs[i] - outputs[i + 1]
                    if error_bounds:
                        row_values_history[history_pos][ind] = outputs[i] - outputs[i + 1]
                    i += 1
                history_pos += 1

            if npermutations == 0:
                raise ValueError(f"max_evals={max_evals} is too low for the Permutation explainer, it must be at least 2 * num_features + 1 = {2 * len(inds) + 1}!")

            expected_value = outputs[0]

            # compute the main effects if we need to
            if main_effects:
                main_effect_values = fm.main_effects(inds, batch_size=batch_size)
        else:
            # no varying inputs: a single no-op evaluation gives the expected value
            # and all attributions are zero
            masks = np.zeros(1, dtype=int)
            outputs = fm(masks, zero_index=0, batch_size=1)
            expected_value = outputs[0]
            row_values = np.zeros((len(fm),) + outputs.shape[1:])
            if error_bounds:
                row_values_history = np.zeros((2 * npermutations, len(fm),) + outputs.shape[1:])

        return {
            "values": row_values / (2 * npermutations),
            "expected_values": expected_value,
            "mask_shapes": fm.mask_shapes,
            "main_effects": main_effect_values,
            "clustering": row_clustering,
            "error_std": None if row_values_history is None else row_values_history.std(0),
            "output_names": self.model.output_names if hasattr(self.model, "output_names") else None
        }


    def shap_values(self, X, npermutations=10, main_effects=False, error_bounds=False, batch_evals=True, silent=False):
        """ Legacy interface to estimate the SHAP values for a set of samples.

        Parameters
        ----------
        X : numpy.array or pandas.DataFrame or any scipy.sparse matrix
            A matrix of samples (# samples x # features) on which to explain the model's output.

        npermutations : int
            Number of times to cycle through all the features, re-evaluating the model at each step.
            Each cycle evaluates the model function 2 * (# features + 1) times on a data matrix of
            (# background data samples) rows. An exception to this is when PermutationExplainer can
            avoid evaluating the model because a feature's value is the same in X and the background
            dataset (which is common for example with sparse features).

        Returns
        -------
        array or list
            For models with a single output this returns a matrix of SHAP values
            (# samples x # features). Each row sums to the difference between the model output for that
            sample and the expected value of the model output (which is stored as expected_value
            attribute of the explainer). For models with vector outputs this returns a list
            of such matrices, one for each output.
        """
        warnings.warn("shap_values() is deprecated; use __call__().", DeprecationWarning)

        explanation = self(X, max_evals=npermutations * X.shape[1], main_effects=main_effects)
        return explanation.values

    def __str__(self):
        return "shap.explainers.PermutationExplainer()"
lib/shap/explainers/_sampling.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import numpy as np
4
+ import pandas as pd
5
+
6
+ from .._explanation import Explanation
7
+ from ..utils._exceptions import ExplainerError
8
+ from ..utils._legacy import convert_to_instance, match_instance_to_data
9
+ from ._kernel import KernelExplainer
10
+
11
+ log = logging.getLogger('shap')
12
+
13
+
14
class SamplingExplainer(KernelExplainer):
    """Computes SHAP values using an extension of the Shapley sampling values explanation method
    (also known as IME).

    SamplingExplainer computes SHAP values under the assumption of feature independence and is an
    extension of the algorithm proposed in "An Efficient Explanation of Individual Classifications
    using Game Theory", Erik Strumbelj, Igor Kononenko, JMLR 2010. It is a good alternative to
    KernelExplainer when you want to use a large background set (as opposed to a single reference
    value for example).

    Parameters
    ----------
    model : function
        User supplied function that takes a matrix of samples (# samples x # features) and
        computes the output of the model for those samples. The output can be a vector
        (# samples) or a matrix (# samples x # model outputs).

    data : numpy.array or pandas.DataFrame
        The background dataset to use for integrating out features. To determine the impact
        of a feature, that feature is set to "missing" and the change in the model output
        is observed. Since most models aren't designed to handle arbitrary missing data at test
        time, we simulate "missing" by replacing the feature with the values it takes in the
        background dataset. So if the background dataset is a simple sample of all zeros, then
        we would approximate a feature being missing by setting it to zero. Unlike the
        KernelExplainer, this data can be the whole training set, even if that is a large set. This
        is because SamplingExplainer only samples from this background dataset.
    """

    def __init__(self, model, data, **kwargs):
        # silence warning about large datasets
        level = log.level
        log.setLevel(logging.ERROR)
        super().__init__(model, data, **kwargs)
        log.setLevel(level)

        if str(self.link) != "identity":
            emsg = f"SamplingExplainer only supports the identity link, not {self.link}"
            raise ValueError(emsg)

    def __call__(self, X, y=None, nsamples=2000):
        """ Explain the rows of ``X`` and return the result wrapped in an Explanation object.
        """

        if isinstance(X, pd.DataFrame):
            feature_names = list(X.columns)
            X = X.values
        else:
            feature_names = None # we can make self.feature_names from background data eventually if we have it

        v = self.shap_values(X, nsamples=nsamples)
        if isinstance(v, list):
            v = np.stack(v, axis=-1) # put outputs at the end
        e = Explanation(v, self.expected_value, X, feature_names=feature_names)
        return e

    def explain(self, incoming_instance, **kwargs):
        """ Estimate the SHAP values for a single instance via two rounds of sampling.

        Round 1 spreads ``min_samples_per_feature`` samples evenly across the varying
        features; round 2 re-allocates the remaining budget proportionally to each
        feature's observed variance.
        """
        # convert incoming input to a standardized iml object
        instance = convert_to_instance(incoming_instance)
        match_instance_to_data(instance, self.data)

        if len(self.data.groups) != self.P:
            emsg = "SamplingExplainer does not support feature groups!"
            raise ExplainerError(emsg)

        # find the feature groups we will test. If a feature does not change from its
        # current value then we know it doesn't impact the model
        self.varyingInds = self.varying_groups(instance.x)
        #self.varyingFeatureGroups = [self.data.groups[i] for i in self.varyingInds]
        self.M = len(self.varyingInds)

        # find f(x)
        if self.keep_index:
            model_out = self.model.f(instance.convert_to_df())
        else:
            model_out = self.model.f(instance.x)
        if isinstance(model_out, (pd.DataFrame, pd.Series)):
            model_out = model_out.values[0]
        self.fx = model_out[0]

        if not self.vector_out:
            self.fx = np.array([self.fx])

        # if no features vary then there no feature has an effect
        if self.M == 0:
            phi = np.zeros((len(self.data.groups), self.D))
            phi_var = np.zeros((len(self.data.groups), self.D))

        # if only one feature varies then it has all the effect
        elif self.M == 1:
            phi = np.zeros((len(self.data.groups), self.D))
            phi_var = np.zeros((len(self.data.groups), self.D))
            diff = self.fx - self.fnull
            for d in range(self.D):
                phi[self.varyingInds[0],d] = diff[d]

        # if more than one feature varies then we have to do real work
        else:

            # pick a reasonable number of samples if the user didn't specify how many they wanted
            self.nsamples = kwargs.get("nsamples", "auto")
            if self.nsamples == "auto":
                self.nsamples = 1000 * self.M

            min_samples_per_feature = kwargs.get("min_samples_per_feature", 100)
            round1_samples = self.nsamples
            round2_samples = 0
            if round1_samples > self.M * min_samples_per_feature:
                round2_samples = round1_samples - self.M * min_samples_per_feature
                round1_samples -= round2_samples

            # divide up the samples among the features for round 1 (keeping counts even
            # because sampling_estimate draws antithetic pairs)
            nsamples_each1 = np.ones(self.M, dtype=np.int64) * 2 * (round1_samples // (self.M * 2))
            for i in range((round1_samples % (self.M * 2)) // 2):
                nsamples_each1[i] += 2

            # explain every feature in round 1
            phi = np.zeros((self.P, self.D))
            phi_var = np.zeros((self.P, self.D))
            self.X_masked = np.zeros((nsamples_each1.max() * 2, self.data.data.shape[1]))
            for i,ind in enumerate(self.varyingInds):
                phi[ind,:],phi_var[ind,:] = self.sampling_estimate(ind, self.model.f, instance.x, self.data.data, nsamples=nsamples_each1[i])

            # optimally allocate samples according to the variance
            if phi_var.sum() == 0:
                phi_var += 1 # spread samples uniformally if we found no variability
            phi_var /= phi_var.sum(0)[np.newaxis, :]
            nsamples_each2 = (phi_var[self.varyingInds,:].mean(1) * round2_samples).astype(int)
            # round each allocation up to an even count, then nudge totals back toward the budget
            for i in range(len(nsamples_each2)):
                if nsamples_each2[i] % 2 == 1:
                    nsamples_each2[i] += 1
            for i in range(len(nsamples_each2)):
                if nsamples_each2.sum() > round2_samples:
                    nsamples_each2[i] -= 2
                elif nsamples_each2.sum() < round2_samples:
                    nsamples_each2[i] += 2
                else:
                    break

            self.X_masked = np.zeros((nsamples_each2.max() * 2, self.data.data.shape[1]))
            for i,ind in enumerate(self.varyingInds):
                if nsamples_each2[i] > 0:
                    val,var = self.sampling_estimate(ind, self.model.f, instance.x, self.data.data, nsamples=nsamples_each2[i])

                    # merge the two rounds as a sample-count-weighted average
                    total_samples = nsamples_each1[i] + nsamples_each2[i]
                    phi[ind,:] = (phi[ind,:] * nsamples_each1[i] + val * nsamples_each2[i]) / total_samples
                    phi_var[ind,:] = (phi_var[ind,:] * nsamples_each1[i] + var * nsamples_each2[i]) / total_samples

            # convert from the variance of the differences to the variance of the mean (phi)
            for i,ind in enumerate(self.varyingInds):
                phi_var[ind,:] /= np.sqrt(nsamples_each1[i] + nsamples_each2[i])

            # correct the sum of the SHAP values to equal the output of the model using a linear
            # regression model with priors of the coefficients equal to the estimated variances for each
            # SHAP value (note that 1e6 is designed to increase the weight of the sample and so closely
            # match the correct sum)
            sum_error = self.fx - phi.sum(0) - self.fnull
            for i in range(self.D):
                # this is a ridge regression with one sample of all ones with sum_error[i] as the label
                # and 1/v as the ridge penalties. This simplified (and stable) form comes from the
                # Sherman-Morrison formula
                v = (phi_var[:,i] / phi_var[:,i].max()) * 1e6
                adj = sum_error[i] * (v - (v * v.sum()) / (1 + v.sum()))
                phi[:,i] += adj

        if phi.shape[1] == 1:
            phi = phi[:,0]

        return phi

    def sampling_estimate(self, j, f, x, X, nsamples=10):
        """ Estimate feature ``j``'s contribution mean and variance with antithetic sampling.

        For each draw, a random permutation splits the features around ``j``; the "on"
        sample keeps ``j`` at its value in ``x`` while the paired "off" sample (stored
        from the end of the buffer) replaces ``j`` too, so their difference isolates
        ``j``'s effect against a random background row.
        """
        X_masked = self.X_masked[:nsamples * 2,:]
        inds = np.arange(X.shape[1])

        for i in range(0, nsamples):
            np.random.shuffle(inds)
            pos = np.where(inds == j)[0][0]
            rind = np.random.randint(X.shape[0])
            X_masked[i, :] = x
            X_masked[i, inds[pos+1:]] = X[rind, inds[pos+1:]]
            X_masked[-(i+1), :] = x
            X_masked[-(i+1), inds[pos:]] = X[rind, inds[pos:]]

        evals = f(X_masked)
        evals_on = evals[:nsamples]
        evals_off = evals[nsamples:][::-1]  # reverse so on/off pairs line up
        d = evals_on - evals_off

        return np.mean(d, 0), np.var(d, 0)
lib/shap/explainers/_tree.py ADDED
The diff for this file is too large to render. See raw diff
 
lib/shap/explainers/other/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ from ._coefficient import Coefficient
4
+ from ._lime import LimeTabular
5
+ from ._maple import Maple, TreeMaple
6
+ from ._random import Random
7
+ from ._treegain import TreeGain
8
+
9
# Public API of the shap.explainers.other subpackage.
__all__ = [
    "Coefficient",
    "LimeTabular",
    "Maple",
    "TreeMaple",
    "Random",
    "TreeGain",
]
17
+
18
+
19
# Deprecated class alias with incorrect spelling
def Coefficent(*args, **kwargs):  # noqa
    """Deprecated alias of :class:`Coefficient` (old misspelling).

    Emits a :class:`DeprecationWarning` and forwards all arguments to
    :class:`Coefficient`.
    """
    warnings.warn(
        "Coefficent has been renamed to Coefficient. "
        "The former is deprecated and will be removed in shap 0.45.",
        DeprecationWarning,
        # point the warning at the caller's line, not at this shim
        stacklevel=2,
    )
    return Coefficient(*args, **kwargs)
lib/shap/explainers/other/_coefficient.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from .._explainer import Explainer
4
+
5
+
6
class Coefficient(Explainer):
    """ Simply returns the model coefficients as the feature attributions.

    This is only for benchmark comparisons and does not approximate SHAP values in a
    meaningful way.
    """
    def __init__(self, model):
        # Only linear-style models exposing coef_ are supported.
        assert hasattr(model, "coef_"), "The passed model does not have a coef_ attribute!"
        self.model = model

    def attributions(self, X):
        # Every row gets the same attribution vector: the raw coefficients.
        coefs = self.model.coef_
        n_rows = X.shape[0]
        return np.tile(coefs, (n_rows, 1))
lib/shap/explainers/other/_lime.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from .._explainer import Explainer
5
+
6
+ try:
7
+ import lime
8
+ import lime.lime_tabular
9
+ except ImportError:
10
+ pass
11
+
12
+ class LimeTabular(Explainer):
13
+ """ Simply wrap of lime.lime_tabular.LimeTabularExplainer into the common shap interface.
14
+
15
+ Parameters
16
+ ----------
17
+ model : function or iml.Model
18
+ User supplied function that takes a matrix of samples (# samples x # features) and
19
+ computes the output of the model for those samples. The output can be a vector
20
+ (# samples) or a matrix (# samples x # model outputs).
21
+
22
+ data : numpy.array
23
+ The background dataset.
24
+
25
+ mode : "classification" or "regression"
26
+ Control the mode of LIME tabular.
27
+ """
28
+
29
+ def __init__(self, model, data, mode="classification"):
30
+ self.model = model
31
+ if mode not in ["classification", "regression"]:
32
+ emsg = f"Invalid mode {mode!r}, must be one of 'classification' or 'regression'"
33
+ raise ValueError(emsg)
34
+ self.mode = mode
35
+
36
+ if isinstance(data, pd.DataFrame):
37
+ data = data.values
38
+ self.data = data
39
+ self.explainer = lime.lime_tabular.LimeTabularExplainer(data, mode=mode)
40
+
41
+ out = self.model(data[0:1])
42
+ if len(out.shape) == 1:
43
+ self.out_dim = 1
44
+ self.flat_out = True
45
+ if mode == "classification":
46
+ def pred(X): # assume that 1d outputs are probabilities
47
+ preds = self.model(X).reshape(-1, 1)
48
+ p0 = 1 - preds
49
+ return np.hstack((p0, preds))
50
+ self.model = pred
51
+ else:
52
+ self.out_dim = self.model(data[0:1]).shape[1]
53
+ self.flat_out = False
54
+
55
+ def attributions(self, X, nsamples=5000, num_features=None):
56
+ num_features = X.shape[1] if num_features is None else num_features
57
+
58
+ if isinstance(X, pd.DataFrame):
59
+ X = X.values
60
+
61
+ out = [np.zeros(X.shape) for j in range(self.out_dim)]
62
+ for i in range(X.shape[0]):
63
+ exp = self.explainer.explain_instance(X[i], self.model, labels=range(self.out_dim), num_features=num_features)
64
+ for j in range(self.out_dim):
65
+ for k,v in exp.local_exp[j]:
66
+ out[j][i,k] = v
67
+
68
+ # because it output two results even for only one model output, and they are negated from what we expect
69
+ if self.mode == "regression":
70
+ for i in range(len(out)):
71
+ out[i] = -out[i]
72
+
73
+ return out[0] if self.flat_out else out
lib/shap/explainers/other/_maple.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ from sklearn.model_selection import train_test_split
4
+
5
+ from .._explainer import Explainer
6
+
7
+
8
class Maple(Explainer):
    """ Simply wraps MAPLE into the common SHAP interface.

    A MAPLE explainer is fit on an 80/20 split of the background data against the
    model's outputs, and per-row local linear coefficients are reported as the
    attributions.

    Parameters
    ----------
    model : function
        User supplied function that takes a matrix of samples (# samples x # features) and
        computes the output of the model for those samples. The output can be a vector
        (# samples) or a matrix (# samples x # model outputs).

    data : numpy.array
        The background dataset.
    """

    def __init__(self, model, data):
        self.model = model

        if isinstance(data, pd.DataFrame):
            data = data.values
        self.data = data
        self.data_mean = self.data.mean(0)

        # Probe the model output shape on the full background set.
        preds = self.model(data)
        self.flat_out = len(preds.shape) == 1
        self.out_dim = 1 if self.flat_out else preds.shape[1]

        # Fit MAPLE on a train/validation split of the background data.
        X_train, X_valid, y_train, y_valid = train_test_split(data, preds, test_size=0.2, random_state=0)
        self.explainer = MAPLE(X_train, y_train, X_valid, y_valid)

    def attributions(self, X, multiply_by_input=False):
        """ Compute the MAPLE coef attributions.

        Parameters
        ----------
        multiply_by_input : bool
            If true, this multiplies the learned coefficients by the mean-centered input. This makes these
            values roughly comparable to SHAP values.
        """
        if isinstance(X, pd.DataFrame):
            X = X.values

        out = [np.zeros(X.shape) for _ in range(self.out_dim)]
        for row in range(X.shape[0]):
            coefs = self.explainer.explain(X[row])["coefs"]
            # drop the intercept at index 0; only the first output slot is filled
            out[0][row, :] = coefs[1:]
            if multiply_by_input:
                out[0][row, :] = out[0][row, :] * (X[row] - self.data_mean)

        return out[0] if self.flat_out else out
61
+
62
+
63
class TreeMaple(Explainer):
    """ Wraps tree-based MAPLE into the common SHAP interface.

    Uses the supplied, already-fit tree ensemble directly as MAPLE's forest
    ensemble rather than fitting a new one.

    Parameters
    ----------
    model : sklearn ensemble regressor
        A fit GradientBoostingRegressor or RandomForestRegressor.

    data : numpy.array
        The background dataset.
    """

    def __init__(self, model, data):
        self.model = model

        # Map the sklearn model class to the forest-ensemble type string MAPLE expects.
        # BUGFIX: MAPLE only recognizes "rf" and "gbrt"; the previous value "gbdt" was
        # never matched, silently leaving all root-impurity feature scores at zero.
        # Both old (sklearn.ensemble.gradient_boosting/forest) and modern
        # (sklearn.ensemble._gb/_forest) module paths are accepted.
        model_type = str(type(model))
        if model_type.endswith("sklearn.ensemble.gradient_boosting.GradientBoostingRegressor'>") \
                or model_type.endswith("sklearn.ensemble._gb.GradientBoostingRegressor'>"):
            fe_type = "gbrt"
        # elif str(type(model)).endswith("sklearn.tree.tree.DecisionTreeClassifier'>"):
        #     pass
        elif model_type.endswith("sklearn.ensemble.forest.RandomForestRegressor'>") \
                or model_type.endswith("sklearn.ensemble._forest.RandomForestRegressor'>"):
            fe_type = "rf"
        # elif str(type(model)).endswith("sklearn.ensemble.forest.RandomForestClassifier'>"):
        #     pass
        # elif str(type(model)).endswith("xgboost.sklearn.XGBRegressor'>"):
        #     pass
        # elif str(type(model)).endswith("xgboost.sklearn.XGBClassifier'>"):
        #     pass
        else:
            raise NotImplementedError("The passed model is not yet supported by TreeMapleExplainer: " + str(type(model)))

        if isinstance(data, pd.DataFrame):
            data = data.values
        self.data = data
        self.data_mean = self.data.mean(0)

        # Probe the model output shape with a single background row.
        out = self.model.predict(data[0:1])
        if len(out.shape) == 1:
            self.out_dim = 1
            self.flat_out = True
        else:
            self.out_dim = self.model.predict(data[0:1]).shape[1]
            self.flat_out = False

        #_, X_valid, _, y_valid = train_test_split(data, self.model.predict(data), test_size=0.2, random_state=0)
        preds = self.model.predict(data)
        self.explainer = MAPLE(data, preds, data, preds, fe=self.model, fe_type=fe_type)

    def attributions(self, X, multiply_by_input=False):
        """ Compute the MAPLE coef attributions.

        Parameters
        ----------
        multiply_by_input : bool
            If true, this multiplies the learned coefficients by the mean-centered input. This makes these
            values roughly comparable to SHAP values.
        """
        if isinstance(X, pd.DataFrame):
            X = X.values

        out = [np.zeros(X.shape) for j in range(self.out_dim)]
        for i in range(X.shape[0]):
            # coefs[0] is the intercept, so the per-feature weights start at index 1
            exp = self.explainer.explain(X[i])["coefs"]
            out[0][i,:] = exp[1:]
            if multiply_by_input:
                out[0][i,:] = out[0][i,:] * (X[i] - self.data_mean)

        return out[0] if self.flat_out else out
132
+
133
+
134
+ #################################################
135
+ # The code below was authored by Gregory Plumb and is
136
+ # from: https://github.com/GDPlumb/MAPLE/blob/master/Code/MAPLE.py
137
+ # It has by copied here to allow for benchmark comparisons. Please see
138
+ # the original repo for the latest version, supporting material, and citations.
139
+ #################################################
140
+
141
+ # Notes:
142
+ # - Assumes any required data normalization has already been done
143
+ # - Can pass Y (desired response) instead of MR (model fit to Y) to make fitting MAPLE to datasets easy
144
+
145
+ import numpy as np
146
+ from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
147
+ from sklearn.linear_model import Ridge
148
+ from sklearn.metrics import mean_squared_error
149
+
150
+
151
class MAPLE:
    """ Model Agnostic suPervised Local Explanations (MAPLE).

    Third-party code by Gregory Plumb (see header comment above for provenance).
    Fits (or reuses) a forest ensemble over the model response MR, scores features
    by their impurity reduction at each tree root, selects how many top features to
    keep via validation RMSE, and then explains/predicts with a locally-weighted
    ridge regression whose sample weights come from leaf co-membership with the
    query point.

    NOTE(review): when a pre-fit ``fe`` is passed, ``fe_type`` must still be
    "rf" or "gbrt" — otherwise the root-impurity scores stay all zero.
    """

    def __init__(self, X_train, MR_train, X_val, MR_val, fe_type = "rf", fe=None, n_estimators = 200, max_features = 0.5, min_samples_leaf = 10, regularization = 0.001):

        # Features and the target model response
        self.X_train = X_train
        self.MR_train = MR_train
        self.X_val = X_val
        self.MR_val = MR_val

        # Forest Ensemble Parameters
        self.n_estimators = n_estimators
        self.max_features = max_features
        self.min_samples_leaf = min_samples_leaf

        # Local Linear Model Parameters
        self.regularization = regularization

        # Data parameters
        num_features = X_train.shape[1]
        self.num_features = num_features
        num_train = X_train.shape[0]
        self.num_train = num_train
        num_val = X_val.shape[0]

        # Fit a Forest Ensemble to the model response
        if fe is None:
            if fe_type == "rf":
                fe = RandomForestRegressor(n_estimators = n_estimators, min_samples_leaf = min_samples_leaf, max_features = max_features)
            elif fe_type == "gbrt":
                fe = GradientBoostingRegressor(n_estimators = n_estimators, min_samples_leaf = min_samples_leaf, max_features = max_features, max_depth = None)
            else:
                print("Unknown FE type ", fe)
                import sys
                sys.exit(0)
            fe.fit(X_train, MR_train)
        else:
            # A pre-fit ensemble was supplied: adopt its tree count.
            self.n_estimators = n_estimators = len(fe.estimators_)
        self.fe = fe

        # Leaf ids per (sample, tree): used to compute co-membership weights later.
        train_leaf_ids = fe.apply(X_train)
        self.train_leaf_ids = train_leaf_ids

        val_leaf_ids_list = fe.apply(X_val)

        # Compute the feature importances: Non-normalized @ Root
        scores = np.zeros(num_features)
        if fe_type == "rf":
            for i in range(n_estimators):
                splits = fe[i].tree_.feature #-2 indicates leaf, index 0 is root
                if splits[0] != -2:
                    scores[splits[0]] += fe[i].tree_.impurity[0] #impurity reduction not normalized per tree
        elif fe_type == "gbrt":
            for i in range(n_estimators):
                # gradient-boosting estimators_ is 2-D: (stage, output)
                splits = fe[i, 0].tree_.feature #-2 indicates leaf, index 0 is root
                if splits[0] != -2:
                    scores[splits[0]] += fe[i, 0].tree_.impurity[0] #impurity reduction not normalized per tree
        self.feature_scores = scores
        mostImpFeats = np.argsort(-scores)

        # Find the number of features to use for MAPLE
        retain_best = 0
        rmse_best = np.inf
        for retain in range(1, num_features + 1):

            # Drop less important features for local regression
            X_train_p = np.delete(X_train, mostImpFeats[retain:], axis = 1)
            X_val_p = np.delete(X_val, mostImpFeats[retain:], axis = 1)

            lr_predictions = np.empty([num_val], dtype=float)

            for i in range(num_val):

                weights = self.training_point_weights(val_leaf_ids_list[i])

                # Local linear model
                lr_model = Ridge(alpha=regularization)
                lr_model.fit(X_train_p, MR_train, weights)
                lr_predictions[i] = lr_model.predict(X_val_p[i].reshape(1, -1))

            rmse_curr = np.sqrt(mean_squared_error(lr_predictions, MR_val))

            if rmse_curr < rmse_best:
                rmse_best = rmse_curr
                retain_best = retain

        # Keep the best feature count and the reduced training matrix.
        self.retain = retain_best
        self.X = np.delete(X_train, mostImpFeats[retain_best:], axis = 1)

    def training_point_weights(self, instance_leaf_ids):
        """ Weight each training point by how often it shares a leaf with the instance.

        Each tree contributes 1/|leaf| to every training point in the same leaf as
        the instance, so weights sum to (roughly) the number of trees.
        """
        weights = np.zeros(self.num_train)
        for i in range(self.n_estimators):
            # Get the PNNs for each tree (ones with the same leaf_id)
            PNNs_Leaf_Node = np.where(self.train_leaf_ids[:, i] == instance_leaf_ids[i])[0]
            if len(PNNs_Leaf_Node) > 0: # SML: added this to fix degenerate cases
                weights[PNNs_Leaf_Node] += 1.0 / len(PNNs_Leaf_Node)
        return weights

    def explain(self, x):
        """ Explain a single instance x.

        Returns a dict with the training point "weights", the local linear "coefs"
        (index 0 is the intercept, dropped features keep coefficient 0), and the
        local model's "pred" at x.
        """

        x = x.reshape(1, -1)

        mostImpFeats = np.argsort(-self.feature_scores)
        x_p = np.delete(x, mostImpFeats[self.retain:], axis = 1)

        curr_leaf_ids = self.fe.apply(x)[0]
        weights = self.training_point_weights(curr_leaf_ids)

        # Local linear model
        lr_model = Ridge(alpha = self.regularization)
        lr_model.fit(self.X, self.MR_train, weights)

        # Get the model coefficients
        coefs = np.zeros(self.num_features + 1)
        coefs[0] = lr_model.intercept_
        # Scatter the reduced-model coefficients back to their original feature slots.
        coefs[np.sort(mostImpFeats[0:self.retain]) + 1] = lr_model.coef_

        # Get the prediction at this point
        prediction = lr_model.predict(x_p.reshape(1, -1))

        out = {}
        out["weights"] = weights
        out["coefs"] = coefs
        out["pred"] = prediction

        return out

    def predict(self, X):
        """ Predict each row of X with its own local linear model (MAPLE prediction). """
        n = X.shape[0]
        pred = np.zeros(n)
        for i in range(n):
            exp = self.explain(X[i, :])
            pred[i] = exp["pred"][0]
        return pred

    # Make the predictions based on the forest ensemble (either random forest or gradient boosted regression tree) instead of MAPLE
    def predict_fe(self, X):
        return self.fe.predict(X)

    # Make the predictions based on SILO (no feature selection) instead of MAPLE
    def predict_silo(self, X):
        n = X.shape[0]
        pred = np.zeros(n)
        for i in range(n): #The contents of this inner loop are similar to explain(): doesn't use the features selected by MAPLE or return as much information
            x = X[i, :].reshape(1, -1)

            curr_leaf_ids = self.fe.apply(x)[0]
            weights = self.training_point_weights(curr_leaf_ids)

            # Local linear model
            lr_model = Ridge(alpha = self.regularization)
            lr_model.fit(self.X_train, self.MR_train, weights)

            pred[i] = lr_model.predict(x)[0]

        return pred
lib/shap/explainers/other/_random.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from shap import links
4
+ from shap.models import Model
5
+ from shap.utils import MaskedModel
6
+
7
+ from .._explainer import Explainer
8
+
9
+
10
class Random(Explainer):
    """ Simply returns random (normally distributed) feature attributions.

    This is only for benchmark comparisons. It supports both fully random attributions and random
    attributions that are constant across all explanations.
    """
    def __init__(self, model, masker, link=links.identity, feature_names=None, linearize_link=True, constant=False, **call_args):
        """ Build a Random explainer.

        Parameters
        ----------
        model : callable or shap.models.Model
            The model being explained; wrapped in a Model if it is not one already.
        masker : object
            The masker defining background/masking behavior, as for other explainers.
        constant : bool
            Stored on the instance; `explain_row` itself does not read it (the
            constant-attribution code paths are commented out below).
        **call_args : dict
            Keyword defaults to bake into `__call__` at construction time.
        """
        super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names)

        if not isinstance(model, Model):
            self.model = Model(model)

        # Overwrite __call__'s keyword defaults so callers can preset arguments
        # at construction time (same pattern as the Permutation explainer).
        for arg in call_args:
            self.__call__.__kwdefaults__[arg] = call_args[arg]

        self.constant = constant
        self.constant_attributions = None

    def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
        """ Explains a single row.
        """

        # build a masked version of the model for the current input sample
        fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)

        # compute any custom clustering for this row
        row_clustering = None
        if getattr(self.masker, "clustering", None) is not None:
            if isinstance(self.masker.clustering, np.ndarray):
                row_clustering = self.masker.clustering
            elif callable(self.masker.clustering):
                row_clustering = self.masker.clustering(*row_args)
            else:
                # NOTE(review): this message mentions the Permutation explainer —
                # it appears copied from there; confirm intended wording.
                raise NotImplementedError("The masker passed has a .clustering attribute that is not yet supported by the Permutation explainer!")

        # compute the correct expected value
        # (evaluate the model with everything masked off: the all-zeros mask)
        masks = np.zeros(1, dtype=int)
        outputs = fm(masks, zero_index=0, batch_size=1)
        expected_value = outputs[0]

        # generate random feature attributions
        # we produce small values so our explanation errors are similar to a constant function
        row_values = np.random.randn(*((len(fm),) + outputs.shape[1:])) * 0.001

        return {
            "values": row_values,
            "expected_values": expected_value,
            "mask_shapes": fm.mask_shapes,
            "main_effects": None,
            "clustering": row_clustering,
            "error_std": None,
            "output_names": self.model.output_names if hasattr(self.model, "output_names") else None
        }

    # def __call__(self, X):
    #     start_time = time.time()
    #     if self.constant:
    #         if self.constant_attributions is None:
    #             self.constant_attributions = np.random.randn(X.shape[1])
    #         return Explanation(np.tile(self.constant_attributions, (X.shape[0],1)), X, compute_time=time.time() - start_time)
    #     else:
    #         return Explanation(np.random.randn(*X.shape), X, compute_time=time.time() - start_time)

    # def attributions(self, X):
    #     if self.constant:
    #         if self.constant_attributions is None:
    #             self.constant_attributions = np.random.randn(X.shape[1])
    #         return np.tile(self.constant_attributions, (X.shape[0],1))
    #     else:
    #         return np.random.randn(*X.shape)
+ # return np.random.randn(*X.shape)