Commit 0136ac6 · 1 Parent(s): c95b9af
LLH committed on 2024/02/14 12:17

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .idea/.gitignore +0 -8
  2. .idea/EasyMachineLearningDemo.iml +0 -12
  3. .idea/inspectionProfiles/Project_Default.xml +0 -12
  4. .idea/inspectionProfiles/profiles_settings.xml +0 -6
  5. .idea/modules.xml +0 -8
  6. .idea/vcs.xml +0 -6
  7. analysis/shap_model.py +4 -3
  8. app.py +1 -1
  9. lib/__init__.py +0 -0
  10. lib/shap/__init__.py +0 -144
  11. lib/shap/_cext.cp310-win_amd64.pyd +0 -0
  12. lib/shap/_explanation.py +0 -901
  13. lib/shap/_serializable.py +0 -204
  14. lib/shap/_version.py +0 -16
  15. lib/shap/actions/__init__.py +0 -3
  16. lib/shap/actions/_action.py +0 -8
  17. lib/shap/actions/_optimizer.py +0 -92
  18. lib/shap/benchmark/__init__.py +0 -9
  19. lib/shap/benchmark/_compute.py +0 -9
  20. lib/shap/benchmark/_explanation_error.py +0 -181
  21. lib/shap/benchmark/_result.py +0 -34
  22. lib/shap/benchmark/_sequential.py +0 -332
  23. lib/shap/benchmark/experiments.py +0 -414
  24. lib/shap/benchmark/framework.py +0 -113
  25. lib/shap/benchmark/measures.py +0 -424
  26. lib/shap/benchmark/methods.py +0 -148
  27. lib/shap/benchmark/metrics.py +0 -824
  28. lib/shap/benchmark/models.py +0 -230
  29. lib/shap/benchmark/plots.py +0 -566
  30. lib/shap/cext/_cext.cc +0 -560
  31. lib/shap/cext/_cext_gpu.cc +0 -187
  32. lib/shap/cext/_cext_gpu.cu +0 -353
  33. lib/shap/cext/gpu_treeshap.h +0 -1535
  34. lib/shap/cext/tree_shap.h +0 -1460
  35. lib/shap/datasets.py +0 -309
  36. lib/shap/explainers/__init__.py +0 -38
  37. lib/shap/explainers/_additive.py +0 -187
  38. lib/shap/explainers/_deep/__init__.py +0 -125
  39. lib/shap/explainers/_deep/deep_pytorch.py +0 -386
  40. lib/shap/explainers/_deep/deep_tf.py +0 -763
  41. lib/shap/explainers/_deep/deep_utils.py +0 -23
  42. lib/shap/explainers/_exact.py +0 -366
  43. lib/shap/explainers/_explainer.py +0 -457
  44. lib/shap/explainers/_gpu_tree.py +0 -179
  45. lib/shap/explainers/_gradient.py +0 -592
  46. lib/shap/explainers/_kernel.py +0 -696
  47. lib/shap/explainers/_linear.py +0 -406
  48. lib/shap/explainers/_partition.py +0 -681
  49. lib/shap/explainers/_permutation.py +0 -217
  50. lib/shap/explainers/_sampling.py +0 -199
.idea/.gitignore DELETED
@@ -1,8 +0,0 @@
-# Default ignored files
-/shelf/
-/workspace.xml
-# Editor-based HTTP Client requests
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml

.idea/EasyMachineLearningDemo.iml DELETED
@@ -1,12 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<module type="PYTHON_MODULE" version="4">
-  <component name="NewModuleRootManager">
-    <content url="file://$MODULE_DIR$" />
-    <orderEntry type="inheritedJdk" />
-    <orderEntry type="sourceFolder" forTests="false" />
-  </component>
-  <component name="PyDocumentationSettings">
-    <option name="format" value="PLAIN" />
-    <option name="myDocStringFormat" value="Plain" />
-  </component>
-</module>

.idea/inspectionProfiles/Project_Default.xml DELETED
@@ -1,12 +0,0 @@
-<component name="InspectionProjectProfileManager">
-  <profile version="1.0">
-    <option name="myName" value="Project Default" />
-    <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
-      <option name="ignoredIdentifiers">
-        <list>
-          <option value="object.pop" />
-        </list>
-      </option>
-    </inspection_tool>
-  </profile>
-</component>

.idea/inspectionProfiles/profiles_settings.xml DELETED
@@ -1,6 +0,0 @@
-<component name="InspectionProjectProfileManager">
-  <settings>
-    <option name="USE_PROJECT_PROFILE" value="false" />
-    <version value="1.0" />
-  </settings>
-</component>

.idea/modules.xml DELETED
@@ -1,8 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project version="4">
-  <component name="ProjectModuleManager">
-    <modules>
-      <module fileurl="file://$PROJECT_DIR$/.idea/EasyMachineLearningDemo.iml" filepath="$PROJECT_DIR$/.idea/EasyMachineLearningDemo.iml" />
-    </modules>
-  </component>
-</project>

.idea/vcs.xml DELETED
@@ -1,6 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project version="4">
-  <component name="VcsDirectoryMappings">
-    <mapping directory="" vcs="Git" />
-  </component>
-</project>

analysis/shap_model.py CHANGED
@@ -1,16 +1,17 @@
 import matplotlib.pyplot as plt

-import lib.shap as shap
+import shap


 def shap_calculate(model, x, feature_names):
     explainer = shap.Explainer(model.predict, x)
     shap_values = explainer(x)

-    return shap.summary_plot(shap_values, x, feature_names=feature_names)
+    shap.summary_plot(shap_values, x, feature_names=feature_names, show=False)
+
+    return plt

 # title = "shap"
-# cur_plt.savefig("./diagram/{}.png".format(title), dpi=300)

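For context, a minimal usage sketch of the updated shap_calculate (a hypothetical caller, not part of this commit; model, x_test, and feature_names are assumed to exist): with show=False the beeswarm figure stays open on the returned pyplot module, so the caller can decide whether to save or display it.

# Hypothetical caller: save the summary plot instead of showing it interactively.
from analysis.shap_model import shap_calculate

cur_plt = shap_calculate(model, x_test, feature_names)        # assumed inputs
cur_plt.savefig("./diagram/shap_beeswarm_plot.png", dpi=300)  # same pattern as FilePath below
cur_plt.close()
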
app.py CHANGED
@@ -69,7 +69,7 @@ class Container:


 class FilePath:
-    base = "../diagram/{}.png"
+    base = "./diagram/{}.png"
     shap_beeswarm_plot = "shap_beeswarm_plot"

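The only change here is the save location: plots are now written to a diagram folder resolved against the app's working directory rather than its parent directory. A hedged illustration of how the two constants combine (the surrounding save logic is assumed, not shown in this hunk):

# Illustration only (assumed usage of the constants shown above):
path = FilePath.base.format(FilePath.shap_beeswarm_plot)
# path == "./diagram/shap_beeswarm_plot.png"
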
lib/__init__.py DELETED
File without changes
lib/shap/__init__.py DELETED
@@ -1,144 +0,0 @@
1
- from ._explanation import Cohorts, Explanation
2
-
3
- # explainers
4
- from .explainers import other
5
- from .explainers._additive import AdditiveExplainer
6
- from .explainers._deep import DeepExplainer
7
- from .explainers._exact import ExactExplainer
8
- from .explainers._explainer import Explainer
9
- from .explainers._gpu_tree import GPUTreeExplainer
10
- from .explainers._gradient import GradientExplainer
11
- from .explainers._kernel import KernelExplainer
12
- from .explainers._linear import LinearExplainer
13
- from .explainers._partition import PartitionExplainer
14
- from .explainers._permutation import PermutationExplainer
15
- from .explainers._sampling import SamplingExplainer
16
- from .explainers._tree import TreeExplainer
17
-
18
- try:
19
- # Version from setuptools-scm
20
- from ._version import version as __version__
21
- except ImportError:
22
- # Expected when running locally without build
23
- __version__ = "0.0.0-not-built"
24
-
25
- _no_matplotlib_warning = "matplotlib is not installed so plotting is not available! Run `pip install matplotlib` " \
26
- "to fix this."
27
-
28
-
29
- # plotting (only loaded if matplotlib is present)
30
- def unsupported(*args, **kwargs):
31
- raise ImportError(_no_matplotlib_warning)
32
-
33
-
34
- class UnsupportedModule:
35
- def __getattribute__(self, item):
36
- raise ImportError(_no_matplotlib_warning)
37
-
38
-
39
- try:
40
- import matplotlib # noqa: F401
41
- have_matplotlib = True
42
- except ImportError:
43
- have_matplotlib = False
44
- if have_matplotlib:
45
- from . import plots
46
- from .plots._bar import bar_legacy as bar_plot
47
- from .plots._beeswarm import summary_legacy as summary_plot
48
- from .plots._decision import decision as decision_plot
49
- from .plots._decision import multioutput_decision as multioutput_decision_plot
50
- from .plots._embedding import embedding as embedding_plot
51
- from .plots._force import force as force_plot
52
- from .plots._force import getjs, initjs, save_html
53
- from .plots._group_difference import group_difference as group_difference_plot
54
- from .plots._heatmap import heatmap as heatmap_plot
55
- from .plots._image import image as image_plot
56
- from .plots._monitoring import monitoring as monitoring_plot
57
- from .plots._partial_dependence import partial_dependence as partial_dependence_plot
58
- from .plots._scatter import dependence_legacy as dependence_plot
59
- from .plots._text import text as text_plot
60
- from .plots._violin import violin as violin_plot
61
- from .plots._waterfall import waterfall as waterfall_plot
62
- else:
63
- bar_plot = unsupported
64
- summary_plot = unsupported
65
- decision_plot = unsupported
66
- multioutput_decision_plot = unsupported
67
- embedding_plot = unsupported
68
- force_plot = unsupported
69
- getjs = unsupported
70
- initjs = unsupported
71
- save_html = unsupported
72
- group_difference_plot = unsupported
73
- heatmap_plot = unsupported
74
- image_plot = unsupported
75
- monitoring_plot = unsupported
76
- partial_dependence_plot = unsupported
77
- dependence_plot = unsupported
78
- text_plot = unsupported
79
- violin_plot = unsupported
80
- waterfall_plot = unsupported
81
- # If matplotlib is available, then the plots submodule will be directly available.
82
- # If not, we need to define something that will issue a meaningful warning message
83
- # (rather than ModuleNotFound).
84
- plots = UnsupportedModule()
85
-
86
-
87
- # other stuff :)
88
- from . import datasets, links, utils # noqa: E402
89
- from .actions._optimizer import ActionOptimizer # noqa: E402
90
- from .utils import approximate_interactions, sample # noqa: E402
91
-
92
- #from . import benchmark
93
- from .utils._legacy import kmeans # noqa: E402
94
-
95
- # Use __all__ to let type checkers know what is part of the public API.
96
- __all__ = [
97
- "Cohorts",
98
- "Explanation",
99
-
100
- # Explainers
101
- "other",
102
- "AdditiveExplainer",
103
- "DeepExplainer",
104
- "ExactExplainer",
105
- "Explainer",
106
- "GPUTreeExplainer",
107
- "GradientExplainer",
108
- "KernelExplainer",
109
- "LinearExplainer",
110
- "PartitionExplainer",
111
- "PermutationExplainer",
112
- "SamplingExplainer",
113
- "TreeExplainer",
114
-
115
- # Plots
116
- "plots",
117
- "bar_plot",
118
- "summary_plot",
119
- "decision_plot",
120
- "multioutput_decision_plot",
121
- "embedding_plot",
122
- "force_plot",
123
- "getjs",
124
- "initjs",
125
- "save_html",
126
- "group_difference_plot",
127
- "heatmap_plot",
128
- "image_plot",
129
- "monitoring_plot",
130
- "partial_dependence_plot",
131
- "dependence_plot",
132
- "text_plot",
133
- "violin_plot",
134
- "waterfall_plot",
135
-
136
- # Other stuff
137
- "datasets",
138
- "links",
139
- "utils",
140
- "ActionOptimizer",
141
- "approximate_interactions",
142
- "sample",
143
- "kmeans",
144
- ]
 
lib/shap/_cext.cp310-win_amd64.pyd DELETED
Binary file (44 kB)
 
lib/shap/_explanation.py DELETED
@@ -1,901 +0,0 @@
1
-
2
- import copy
3
- import operator
4
-
5
- import numpy as np
6
- import pandas as pd
7
- import scipy.cluster
8
- import scipy.sparse
9
- import scipy.spatial
10
- import sklearn
11
- from slicer import Alias, Obj, Slicer
12
-
13
- from .utils._exceptions import DimensionError
14
- from .utils._general import OpChain
15
-
16
- op_chain_root = OpChain("shap.Explanation")
17
- class MetaExplanation(type):
18
- """ This metaclass exposes the Explanation object's methods for creating template op chains.
19
- """
20
-
21
- def __getitem__(cls, item):
22
- return op_chain_root.__getitem__(item)
23
-
24
- @property
25
- def abs(cls):
26
- """ Element-wise absolute value op.
27
- """
28
- return op_chain_root.abs
29
-
30
- @property
31
- def identity(cls):
32
- """ A no-op.
33
- """
34
- return op_chain_root.identity
35
-
36
- @property
37
- def argsort(cls):
38
- """ Numpy style argsort.
39
- """
40
- return op_chain_root.argsort
41
-
42
- @property
43
- def sum(cls):
44
- """ Numpy style sum.
45
- """
46
- return op_chain_root.sum
47
-
48
- @property
49
- def max(cls):
50
- """ Numpy style max.
51
- """
52
- return op_chain_root.max
53
-
54
- @property
55
- def min(cls):
56
- """ Numpy style min.
57
- """
58
- return op_chain_root.min
59
-
60
- @property
61
- def mean(cls):
62
- """ Numpy style mean.
63
- """
64
- return op_chain_root.mean
65
-
66
- @property
67
- def sample(cls):
68
- """ Numpy style sample.
69
- """
70
- return op_chain_root.sample
71
-
72
- @property
73
- def hclust(cls):
74
- """ Hierarchical clustering op.
75
- """
76
- return op_chain_root.hclust
77
-
78
-
79
- class Explanation(metaclass=MetaExplanation):
80
- """ A sliceable set of parallel arrays representing a SHAP explanation.
81
- """
82
- def __init__(
83
- self,
84
- values,
85
- base_values=None,
86
- data=None,
87
- display_data=None,
88
- instance_names=None,
89
- feature_names=None,
90
- output_names=None,
91
- output_indexes=None,
92
- lower_bounds=None,
93
- upper_bounds=None,
94
- error_std=None,
95
- main_effects=None,
96
- hierarchical_values=None,
97
- clustering=None,
98
- compute_time=None
99
- ):
100
- self.op_history = []
101
-
102
- self.compute_time = compute_time
103
-
104
- # cloning. TODOsomeday: better cloning :)
105
- if issubclass(type(values), Explanation):
106
- e = values
107
- values = e.values
108
- base_values = e.base_values
109
- data = e.data
110
-
111
- self.output_dims = compute_output_dims(values, base_values, data, output_names)
112
- values_shape = _compute_shape(values)
113
-
114
- if output_names is None and len(self.output_dims) == 1:
115
- output_names = [f"Output {i}" for i in range(values_shape[self.output_dims[0]])]
116
-
117
- if len(_compute_shape(feature_names)) == 1: # TODO: should always be an alias once slicer supports per-row aliases
118
- if len(values_shape) >= 2 and len(feature_names) == values_shape[1]:
119
- feature_names = Alias(list(feature_names), 1)
120
- elif len(values_shape) >= 1 and len(feature_names) == values_shape[0]:
121
- feature_names = Alias(list(feature_names), 0)
122
-
123
- if len(_compute_shape(output_names)) == 1: # TODO: should always be an alias once slicer supports per-row aliases
124
- output_names = Alias(list(output_names), self.output_dims[0])
125
- # if len(values_shape) >= 1 and len(output_names) == values_shape[0]:
126
- # output_names = Alias(list(output_names), 0)
127
- # elif len(values_shape) >= 2 and len(output_names) == values_shape[1]:
128
- # output_names = Alias(list(output_names), 1)
129
-
130
- if output_names is not None and not isinstance(output_names, Alias):
131
- output_names_order = len(_compute_shape(output_names))
132
- if output_names_order == 0:
133
- pass
134
- elif output_names_order == 1:
135
- output_names = Obj(output_names, self.output_dims)
136
- elif output_names_order == 2:
137
- output_names = Obj(output_names, [0] + list(self.output_dims))
138
- else:
139
- raise ValueError("shap.Explanation does not yet support output_names of order greater than 3!")
140
-
141
- if not hasattr(base_values, "__len__") or len(base_values) == 0:
142
- pass
143
- elif len(_compute_shape(base_values)) == len(self.output_dims):
144
- base_values = Obj(base_values, list(self.output_dims))
145
- else:
146
- base_values = Obj(base_values, [0] + list(self.output_dims))
147
-
148
- self._s = Slicer(
149
- values=values,
150
- base_values=base_values,
151
- data=list_wrap(data),
152
- display_data=list_wrap(display_data),
153
- instance_names=None if instance_names is None else Alias(instance_names, 0),
154
- feature_names=feature_names,
155
- output_names=output_names,
156
- output_indexes=None if output_indexes is None else (self.output_dims, output_indexes),
157
- lower_bounds=list_wrap(lower_bounds),
158
- upper_bounds=list_wrap(upper_bounds),
159
- error_std=list_wrap(error_std),
160
- main_effects=list_wrap(main_effects),
161
- hierarchical_values=list_wrap(hierarchical_values),
162
- clustering=None if clustering is None else Obj(clustering, [0])
163
- )
164
-
165
- @property
166
- def shape(self):
167
- """ Compute the shape over potentially complex data nesting.
168
- """
169
- return _compute_shape(self._s.values)
170
-
171
- @property
172
- def values(self):
173
- """ Pass-through from the underlying slicer object.
174
- """
175
- return self._s.values
176
- @values.setter
177
- def values(self, new_values):
178
- self._s.values = new_values
179
-
180
- @property
181
- def base_values(self):
182
- """ Pass-through from the underlying slicer object.
183
- """
184
- return self._s.base_values
185
- @base_values.setter
186
- def base_values(self, new_base_values):
187
- self._s.base_values = new_base_values
188
-
189
- @property
190
- def data(self):
191
- """ Pass-through from the underlying slicer object.
192
- """
193
- return self._s.data
194
- @data.setter
195
- def data(self, new_data):
196
- self._s.data = new_data
197
-
198
- @property
199
- def display_data(self):
200
- """ Pass-through from the underlying slicer object.
201
- """
202
- return self._s.display_data
203
- @display_data.setter
204
- def display_data(self, new_display_data):
205
- if issubclass(type(new_display_data), pd.DataFrame):
206
- new_display_data = new_display_data.values
207
- self._s.display_data = new_display_data
208
-
209
- @property
210
- def instance_names(self):
211
- """ Pass-through from the underlying slicer object.
212
- """
213
- return self._s.instance_names
214
-
215
- @property
216
- def output_names(self):
217
- """ Pass-through from the underlying slicer object.
218
- """
219
- return self._s.output_names
220
- @output_names.setter
221
- def output_names(self, new_output_names):
222
- self._s.output_names = new_output_names
223
-
224
- @property
225
- def output_indexes(self):
226
- """ Pass-through from the underlying slicer object.
227
- """
228
- return self._s.output_indexes
229
-
230
- @property
231
- def feature_names(self):
232
- """ Pass-through from the underlying slicer object.
233
- """
234
- return self._s.feature_names
235
- @feature_names.setter
236
- def feature_names(self, new_feature_names):
237
- self._s.feature_names = new_feature_names
238
-
239
- @property
240
- def lower_bounds(self):
241
- """ Pass-through from the underlying slicer object.
242
- """
243
- return self._s.lower_bounds
244
-
245
- @property
246
- def upper_bounds(self):
247
- """ Pass-through from the underlying slicer object.
248
- """
249
- return self._s.upper_bounds
250
-
251
- @property
252
- def error_std(self):
253
- """ Pass-through from the underlying slicer object.
254
- """
255
- return self._s.error_std
256
-
257
- @property
258
- def main_effects(self):
259
- """ Pass-through from the underlying slicer object.
260
- """
261
- return self._s.main_effects
262
- @main_effects.setter
263
- def main_effects(self, new_main_effects):
264
- self._s.main_effects = new_main_effects
265
-
266
- @property
267
- def hierarchical_values(self):
268
- """ Pass-through from the underlying slicer object.
269
- """
270
- return self._s.hierarchical_values
271
- @hierarchical_values.setter
272
- def hierarchical_values(self, new_hierarchical_values):
273
- self._s.hierarchical_values = new_hierarchical_values
274
-
275
- @property
276
- def clustering(self):
277
- """ Pass-through from the underlying slicer object.
278
- """
279
- return self._s.clustering
280
- @clustering.setter
281
- def clustering(self, new_clustering):
282
- self._s.clustering = new_clustering
283
-
284
- def cohorts(self, cohorts):
285
- """ Split this explanation into several cohorts.
286
-
287
- Parameters
288
- ----------
289
- cohorts : int or array
290
- If this is an integer then we auto build that many cohorts using a decision tree. If this is
291
- an array then we treat that as an array of cohort names/ids for each instance.
292
- """
293
-
294
- if isinstance(cohorts, int):
295
- return _auto_cohorts(self, max_cohorts=cohorts)
296
- if isinstance(cohorts, (list, tuple, np.ndarray)):
297
- cohorts = np.array(cohorts)
298
- return Cohorts(**{name: self[cohorts == name] for name in np.unique(cohorts)})
299
- raise TypeError("The given set of cohort indicators is not recognized! Please give an array or int.")
300
-
301
- def __repr__(self):
302
- """ Display some basic printable info, but not everything.
303
- """
304
- out = ".values =\n"+self.values.__repr__()
305
- if self.base_values is not None:
306
- out += "\n\n.base_values =\n"+self.base_values.__repr__()
307
- if self.data is not None:
308
- out += "\n\n.data =\n"+self.data.__repr__()
309
- return out
310
-
311
- def __getitem__(self, item):
312
- """ This adds support for OpChain indexing.
313
- """
314
- new_self = None
315
- if not isinstance(item, tuple):
316
- item = (item,)
317
-
318
- # convert any OpChains or magic strings
319
- pos = -1
320
- for t in item:
321
- pos += 1
322
-
323
- # skip over Ellipsis
324
- if t is Ellipsis:
325
- pos += len(self.shape) - len(item)
326
- continue
327
-
328
- orig_t = t
329
- if issubclass(type(t), OpChain):
330
- t = t.apply(self)
331
- if issubclass(type(t), (np.int64, np.int32)): # because slicer does not like numpy indexes
332
- t = int(t)
333
- elif issubclass(type(t), np.ndarray):
334
- t = [int(v) for v in t] # slicer wants lists not numpy arrays for indexing
335
- elif issubclass(type(t), Explanation):
336
- t = t.values
337
- elif isinstance(t, str):
338
-
339
- # work around for 2D output_names since they are not yet slicer supported
340
- output_names_dims = []
341
- if "output_names" in self._s._objects:
342
- output_names_dims = self._s._objects["output_names"].dim
343
- elif "output_names" in self._s._aliases:
344
- output_names_dims = self._s._aliases["output_names"].dim
345
- if pos != 0 and pos in output_names_dims:
346
- if len(output_names_dims) == 1:
347
- t = np.argwhere(np.array(self.output_names) == t)[0][0]
348
- elif len(output_names_dims) == 2:
349
- new_values = []
350
- new_base_values = []
351
- new_data = []
352
- new_self = copy.deepcopy(self)
353
- for i, v in enumerate(self.values):
354
- for j, s in enumerate(self.output_names[i]):
355
- if s == t:
356
- new_values.append(np.array(v[:,j]))
357
- new_data.append(np.array(self.data[i]))
358
- new_base_values.append(self.base_values[i][j])
359
-
360
- new_self = Explanation(
361
- np.array(new_values),
362
- np.array(new_base_values),
363
- np.array(new_data),
364
- self.display_data,
365
- self.instance_names,
366
- np.array(new_data),
367
- t, # output_names
368
- self.output_indexes,
369
- self.lower_bounds,
370
- self.upper_bounds,
371
- self.error_std,
372
- self.main_effects,
373
- self.hierarchical_values,
374
- self.clustering
375
- )
376
- new_self.op_history = copy.copy(self.op_history)
377
- # new_self = copy.deepcopy(self)
378
- # new_self.values = np.array(new_values)
379
- # new_self.base_values = np.array(new_base_values)
380
- # new_self.data = np.array(new_data)
381
- # new_self.output_names = t
382
- # new_self.feature_names = np.array(new_data)
383
- # new_self.clustering = None
384
-
385
- # work around for 2D feature_names since they are not yet slicer supported
386
- feature_names_dims = []
387
- if "feature_names" in self._s._objects:
388
- feature_names_dims = self._s._objects["feature_names"].dim
389
- if pos != 0 and pos in feature_names_dims and len(feature_names_dims) == 2:
390
- new_values = []
391
- new_data = []
392
- for i, val_i in enumerate(self.values):
393
- for s,v,d in zip(self.feature_names[i], val_i, self.data[i]):
394
- if s == t:
395
- new_values.append(v)
396
- new_data.append(d)
397
- new_self = copy.deepcopy(self)
398
- new_self.values = new_values
399
- new_self.data = new_data
400
- new_self.feature_names = t
401
- new_self.clustering = None
402
- # return new_self
403
-
404
- if issubclass(type(t), (np.int8, np.int16, np.int32, np.int64)):
405
- t = int(t)
406
-
407
- if t is not orig_t:
408
- tmp = list(item)
409
- tmp[pos] = t
410
- item = tuple(tmp)
411
-
412
- # call slicer for the real work
413
- item = tuple(v for v in item) # SML I cut out: `if not isinstance(v, str)`
414
- if len(item) == 0:
415
- return new_self
416
- if new_self is None:
417
- new_self = copy.copy(self)
418
- new_self._s = new_self._s.__getitem__(item)
419
- new_self.op_history.append({
420
- "name": "__getitem__",
421
- "args": (item,),
422
- "prev_shape": self.shape
423
- })
424
-
425
- return new_self
426
-
427
- def __len__(self):
428
- return self.shape[0]
429
-
430
- def __copy__(self):
431
- new_exp = Explanation(
432
- self.values,
433
- self.base_values,
434
- self.data,
435
- self.display_data,
436
- self.instance_names,
437
- self.feature_names,
438
- self.output_names,
439
- self.output_indexes,
440
- self.lower_bounds,
441
- self.upper_bounds,
442
- self.error_std,
443
- self.main_effects,
444
- self.hierarchical_values,
445
- self.clustering
446
- )
447
- new_exp.op_history = copy.copy(self.op_history)
448
- return new_exp
449
-
450
- def _apply_binary_operator(self, other, binary_op, op_name):
451
- new_exp = self.__copy__()
452
- new_exp.op_history = copy.copy(self.op_history)
453
- new_exp.op_history.append({
454
- "name": op_name,
455
- "args": (other,),
456
- "prev_shape": self.shape
457
- })
458
- if isinstance(other, Explanation):
459
- new_exp.values = binary_op(new_exp.values, other.values)
460
- if new_exp.data is not None:
461
- new_exp.data = binary_op(new_exp.data, other.data)
462
- if new_exp.base_values is not None:
463
- new_exp.base_values = binary_op(new_exp.base_values, other.base_values)
464
- else:
465
- new_exp.values = binary_op(new_exp.values, other)
466
- if new_exp.data is not None:
467
- new_exp.data = binary_op(new_exp.data, other)
468
- if new_exp.base_values is not None:
469
- new_exp.base_values = binary_op(new_exp.base_values, other)
470
- return new_exp
471
-
472
- def __add__(self, other):
473
- return self._apply_binary_operator(other, operator.add, "__add__")
474
-
475
- def __radd__(self, other):
476
- return self._apply_binary_operator(other, operator.add, "__add__")
477
-
478
- def __sub__(self, other):
479
- return self._apply_binary_operator(other, operator.sub, "__sub__")
480
-
481
- def __rsub__(self, other):
482
- return self._apply_binary_operator(other, operator.sub, "__sub__")
483
-
484
- def __mul__(self, other):
485
- return self._apply_binary_operator(other, operator.mul, "__mul__")
486
-
487
- def __rmul__(self, other):
488
- return self._apply_binary_operator(other, operator.mul, "__mul__")
489
-
490
- def __truediv__(self, other):
491
- return self._apply_binary_operator(other, operator.truediv, "__truediv__")
492
-
493
- # @property
494
- # def abs(self):
495
- # """ Element-size absolute value operator.
496
- # """
497
- # new_self = copy.copy(self)
498
- # new_self.values = np.abs(new_self.values)
499
- # new_self.op_history.append({
500
- # "name": "abs",
501
- # "prev_shape": self.shape
502
- # })
503
- # return new_self
504
-
505
- def _numpy_func(self, fname, **kwargs):
506
- """ Apply a numpy-style function to this Explanation.
507
- """
508
- new_self = copy.copy(self)
509
- axis = kwargs.get("axis", None)
510
-
511
- # collapse the slicer to right shape
512
- if axis == 0:
513
- new_self = new_self[0]
514
- elif axis == 1:
515
- new_self = new_self[1]
516
- elif axis == 2:
517
- new_self = new_self[2]
518
- if axis in [0,1,2]:
519
- new_self.op_history = new_self.op_history[:-1] # pop off the slicing operation we just used
520
-
521
- if self.feature_names is not None and not is_1d(self.feature_names) and axis == 0:
522
- new_values = self._flatten_feature_names()
523
- new_self.feature_names = np.array(list(new_values.keys()))
524
- new_self.values = np.array([getattr(np, fname)(v,0) for v in new_values.values()])
525
- new_self.clustering = None
526
- else:
527
- new_self.values = getattr(np, fname)(np.array(self.values), **kwargs)
528
- if new_self.data is not None:
529
- try:
530
- new_self.data = getattr(np, fname)(np.array(self.data), **kwargs)
531
- except Exception:
532
- new_self.data = None
533
- if new_self.base_values is not None and issubclass(type(axis), int) and len(self.base_values.shape) > axis:
534
- new_self.base_values = getattr(np, fname)(self.base_values, **kwargs)
535
- elif issubclass(type(axis), int):
536
- new_self.base_values = None
537
-
538
- if axis == 0 and self.clustering is not None and len(self.clustering.shape) == 3:
539
- if self.clustering.std(0).sum() < 1e-8:
540
- new_self.clustering = self.clustering[0]
541
- else:
542
- new_self.clustering = None
543
-
544
- new_self.op_history.append({
545
- "name": fname,
546
- "kwargs": kwargs,
547
- "prev_shape": self.shape,
548
- "collapsed_instances": axis == 0
549
- })
550
-
551
- return new_self
552
-
553
- def mean(self, axis):
554
- """ Numpy-style mean function.
555
- """
556
- return self._numpy_func("mean", axis=axis)
557
-
558
- def max(self, axis):
559
- """ Numpy-style mean function.
560
- """
561
- return self._numpy_func("max", axis=axis)
562
-
563
- def min(self, axis):
564
- """ Numpy-style mean function.
565
- """
566
- return self._numpy_func("min", axis=axis)
567
-
568
- def sum(self, axis=None, grouping=None):
569
- """ Numpy-style mean function.
570
- """
571
- if grouping is None:
572
- return self._numpy_func("sum", axis=axis)
573
- elif axis == 1 or len(self.shape) == 1:
574
- return group_features(self, grouping)
575
- else:
576
- raise DimensionError("Only axis = 1 is supported for grouping right now...")
577
-
578
- def hstack(self, other):
579
- """ Stack two explanations column-wise.
580
- """
581
- assert self.shape[0] == other.shape[0], "Can't hstack explanations with different numbers of rows!"
582
- assert np.max(np.abs(self.base_values - other.base_values)) < 1e-6, "Can't hstack explanations with different base values!"
583
-
584
- new_exp = Explanation(
585
- values=np.hstack([self.values, other.values]),
586
- base_values=self.base_values,
587
- data=self.data,
588
- display_data=self.display_data,
589
- instance_names=self.instance_names,
590
- feature_names=self.feature_names,
591
- output_names=self.output_names,
592
- output_indexes=self.output_indexes,
593
- lower_bounds=self.lower_bounds,
594
- upper_bounds=self.upper_bounds,
595
- error_std=self.error_std,
596
- main_effects=self.main_effects,
597
- hierarchical_values=self.hierarchical_values,
598
- clustering=self.clustering,
599
- )
600
- return new_exp
601
-
602
- # def reshape(self, *args):
603
- # return self._numpy_func("reshape", newshape=args)
604
-
605
- @property
606
- def abs(self):
607
- return self._numpy_func("abs")
608
-
609
- @property
610
- def identity(self):
611
- return self
612
-
613
- @property
614
- def argsort(self):
615
- return self._numpy_func("argsort")
616
-
617
- @property
618
- def flip(self):
619
- return self._numpy_func("flip")
620
-
621
-
622
- def hclust(self, metric="sqeuclidean", axis=0):
623
- """ Computes an optimal leaf ordering sort order using hclustering.
624
-
625
- hclust(metric="sqeuclidean")
626
-
627
- Parameters
628
- ----------
629
- metric : string
630
- A metric supported by scipy clustering.
631
-
632
- axis : int
633
- The axis to cluster along.
634
- """
635
- values = self.values
636
-
637
- if len(values.shape) != 2:
638
- raise DimensionError("The hclust order only supports 2D arrays right now!")
639
-
640
- if axis == 1:
641
- values = values.T
642
-
643
- # compute a hierarchical clustering and return the optimal leaf ordering
644
- D = scipy.spatial.distance.pdist(values, metric)
645
- cluster_matrix = scipy.cluster.hierarchy.complete(D)
646
- inds = scipy.cluster.hierarchy.leaves_list(scipy.cluster.hierarchy.optimal_leaf_ordering(cluster_matrix, D))
647
- return inds
648
-
649
- def sample(self, max_samples, replace=False, random_state=0):
650
- """ Randomly samples the instances (rows) of the Explanation object.
651
-
652
- Parameters
653
- ----------
654
- max_samples : int
655
- The number of rows to sample. Note that if replace=False then less than
656
- fewer than max_samples will be drawn if explanation.shape[0] < max_samples.
657
-
658
- replace : bool
659
- Sample with or without replacement.
660
- """
661
- prev_seed = np.random.seed(random_state)
662
- inds = np.random.choice(self.shape[0], min(max_samples, self.shape[0]), replace=replace)
663
- np.random.seed(prev_seed)
664
- return self[list(inds)]
665
-
666
- def _flatten_feature_names(self):
667
- new_values = {}
668
- for i in range(len(self.values)):
669
- for s,v in zip(self.feature_names[i], self.values[i]):
670
- if s not in new_values:
671
- new_values[s] = []
672
- new_values[s].append(v)
673
- return new_values
674
-
675
- def _use_data_as_feature_names(self):
676
- new_values = {}
677
- for i in range(len(self.values)):
678
- for s,v in zip(self.data[i], self.values[i]):
679
- if s not in new_values:
680
- new_values[s] = []
681
- new_values[s].append(v)
682
- return new_values
683
-
684
- def percentile(self, q, axis=None):
685
- new_self = copy.deepcopy(self)
686
- if self.feature_names is not None and not is_1d(self.feature_names) and axis == 0:
687
- new_values = self._flatten_feature_names()
688
- new_self.feature_names = np.array(list(new_values.keys()))
689
- new_self.values = np.array([np.percentile(v, q) for v in new_values.values()])
690
- new_self.clustering = None
691
- else:
692
- new_self.values = np.percentile(new_self.values, q, axis)
693
- new_self.data = np.percentile(new_self.data, q, axis)
694
- #new_self.data = None
695
- new_self.op_history.append({
696
- "name": "percentile",
697
- "args": (axis,),
698
- "prev_shape": self.shape,
699
- "collapsed_instances": axis == 0
700
- })
701
- return new_self
702
-
703
- def group_features(shap_values, feature_map):
704
- # TODOsomeday: support and deal with clusterings
705
- reverse_map = {}
706
- for name in feature_map:
707
- reverse_map[feature_map[name]] = reverse_map.get(feature_map[name], []) + [name]
708
-
709
- curr_names = shap_values.feature_names
710
- sv_new = copy.deepcopy(shap_values)
711
- found = {}
712
- i = 0
713
- rank1 = len(shap_values.shape) == 1
714
- for name in curr_names:
715
- new_name = feature_map.get(name, name)
716
- if new_name in found:
717
- continue
718
- found[new_name] = True
719
-
720
- new_name = feature_map.get(name, name)
721
- cols_to_sum = reverse_map.get(new_name, [new_name])
722
- old_inds = [curr_names.index(v) for v in cols_to_sum]
723
-
724
- if rank1:
725
- sv_new.values[i] = shap_values.values[old_inds].sum()
726
- sv_new.data[i] = shap_values.data[old_inds].sum()
727
- else:
728
- sv_new.values[:,i] = shap_values.values[:,old_inds].sum(1)
729
- sv_new.data[:,i] = shap_values.data[:,old_inds].sum(1)
730
- sv_new.feature_names[i] = new_name
731
- i += 1
732
-
733
- return Explanation(
734
- sv_new.values[:i] if rank1 else sv_new.values[:,:i],
735
- base_values = sv_new.base_values,
736
- data = sv_new.data[:i] if rank1 else sv_new.data[:,:i],
737
- display_data = None if sv_new.display_data is None else (sv_new.display_data[:,:i] if rank1 else sv_new.display_data[:,:i]),
738
- instance_names = None,
739
- feature_names = None if sv_new.feature_names is None else sv_new.feature_names[:i],
740
- output_names = None,
741
- output_indexes = None,
742
- lower_bounds = None,
743
- upper_bounds = None,
744
- error_std = None,
745
- main_effects = None,
746
- hierarchical_values = None,
747
- clustering = None
748
- )
749
-
750
- def compute_output_dims(values, base_values, data, output_names):
751
- """ Uses the passed data to infer which dimensions correspond to the model's output.
752
- """
753
- values_shape = _compute_shape(values)
754
-
755
- # input shape matches the data shape
756
- if data is not None:
757
- data_shape = _compute_shape(data)
758
-
759
- # if we are not given any data we assume it would be the same shape as the given values
760
- else:
761
- data_shape = values_shape
762
-
763
- # output shape is known from the base values or output names
764
- if output_names is not None:
765
- output_shape = _compute_shape(output_names)
766
-
767
- # if our output_names are per sample then we need to drop the sample dimension here
768
- if values_shape[-len(output_shape):] != output_shape and \
769
- values_shape[-len(output_shape)+1:] == output_shape[1:] and values_shape[0] == output_shape[0]:
770
- output_shape = output_shape[1:]
771
-
772
- elif base_values is not None:
773
- output_shape = _compute_shape(base_values)[1:]
774
- else:
775
- output_shape = tuple()
776
-
777
- interaction_order = len(values_shape) - len(data_shape) - len(output_shape)
778
- output_dims = range(len(data_shape) + interaction_order, len(values_shape))
779
- return tuple(output_dims)
780
-
781
- def is_1d(val):
782
- return not (isinstance(val[0], list) or isinstance(val[0], np.ndarray))
783
-
784
- class Op:
785
- pass
786
-
787
- class Percentile(Op):
788
- def __init__(self, percentile):
789
- self.percentile = percentile
790
-
791
- def add_repr(self, s, verbose=False):
792
- return "percentile("+s+", "+str(self.percentile)+")"
793
-
794
- def _first_item(x):
795
- for item in x:
796
- return item
797
- return None
798
-
799
- def _compute_shape(x):
800
- if not hasattr(x, "__len__") or isinstance(x, str):
801
- return tuple()
802
- elif not scipy.sparse.issparse(x) and len(x) > 0 and isinstance(_first_item(x), str):
803
- return (None,)
804
- else:
805
- if isinstance(x, dict):
806
- return (len(x),) + _compute_shape(x[next(iter(x))])
807
-
808
- # 2D arrays we just take their shape as-is
809
- if len(getattr(x, "shape", tuple())) > 1:
810
- return x.shape
811
-
812
- # 1D arrays we need to look inside
813
- if len(x) == 0:
814
- return (0,)
815
- elif len(x) == 1:
816
- return (1,) + _compute_shape(_first_item(x))
817
- else:
818
- first_shape = _compute_shape(_first_item(x))
819
- if first_shape == tuple():
820
- return (len(x),)
821
- else: # we have an array of arrays...
822
- matches = np.ones(len(first_shape), dtype=bool)
823
- for i in range(1, len(x)):
824
- shape = _compute_shape(x[i])
825
- assert len(shape) == len(first_shape), "Arrays in Explanation objects must have consistent inner dimensions!"
826
- for j in range(0, len(shape)):
827
- matches[j] &= shape[j] == first_shape[j]
828
- return (len(x),) + tuple(first_shape[j] if match else None for j, match in enumerate(matches))
829
-
830
- class Cohorts:
831
- def __init__(self, **kwargs):
832
- self.cohorts = kwargs
833
- for k in self.cohorts:
834
- assert isinstance(self.cohorts[k], Explanation), "All the arguments to a Cohorts set must be Explanation objects!"
835
-
836
- def __getitem__(self, item):
837
- new_cohorts = Cohorts()
838
- for k in self.cohorts:
839
- new_cohorts.cohorts[k] = self.cohorts[k].__getitem__(item)
840
- return new_cohorts
841
-
842
- def __getattr__(self, name):
843
- new_cohorts = Cohorts()
844
- for k in self.cohorts:
845
- new_cohorts.cohorts[k] = getattr(self.cohorts[k], name)
846
- return new_cohorts
847
-
848
- def __call__(self, *args, **kwargs):
849
- new_cohorts = Cohorts()
850
- for k in self.cohorts:
851
- new_cohorts.cohorts[k] = self.cohorts[k].__call__(*args, **kwargs)
852
- return new_cohorts
853
-
854
- def __repr__(self):
855
- return f"<shap._explanation.Cohorts object with {len(self.cohorts)} cohorts of sizes: {[v.shape for v in self.cohorts.values()]}>"
856
-
857
-
858
- def _auto_cohorts(shap_values, max_cohorts):
859
- """ This uses a DecisionTreeRegressor to build a group of cohorts with similar SHAP values.
860
- """
861
-
862
- # fit a decision tree that well separates the SHAP values
863
- m = sklearn.tree.DecisionTreeRegressor(max_leaf_nodes=max_cohorts)
864
- m.fit(shap_values.data, shap_values.values)
865
-
866
- # group instances by their decision paths
867
- paths = m.decision_path(shap_values.data).toarray()
868
- path_names = []
869
-
870
- # mark each instance with a path name
871
- for i in range(shap_values.shape[0]):
872
- name = ""
873
- for j in range(len(paths[i])):
874
- if paths[i,j] > 0:
875
- feature = m.tree_.feature[j]
876
- threshold = m.tree_.threshold[j]
877
- val = shap_values.data[i,feature]
878
- if feature >= 0:
879
- name += str(shap_values.feature_names[feature])
880
- if val < threshold:
881
- name += " < "
882
- else:
883
- name += " >= "
884
- name += str(threshold) + " & "
885
- path_names.append(name[:-3]) # the -3 strips off the last unneeded ' & '
886
- path_names = np.array(path_names)
887
-
888
- # split the instances into cohorts by their path names
889
- cohorts = {}
890
- for name in np.unique(path_names):
891
- cohorts[name] = shap_values[path_names == name]
892
-
893
- return Cohorts(**cohorts)
894
-
895
- def list_wrap(x):
896
- """ A helper to patch things since slicer doesn't handle arrays of arrays (it does handle lists of arrays)
897
- """
898
- if isinstance(x, np.ndarray) and len(x.shape) == 1 and isinstance(x[0], np.ndarray):
899
- return [v for v in x]
900
- else:
901
- return x
 
lib/shap/_serializable.py DELETED
@@ -1,204 +0,0 @@
1
-
2
- import inspect
3
- import logging
4
- import pickle
5
-
6
- import cloudpickle
7
- import numpy as np
8
-
9
- log = logging.getLogger('shap')
10
-
11
- class Serializable:
12
- """ This is the superclass of all serializable objects.
13
- """
14
-
15
- def save(self, out_file):
16
- """ Save the model to the given file stream.
17
- """
18
- pickle.dump(type(self), out_file)
19
-
20
- @classmethod
21
- def load(cls, in_file, instantiate=True):
22
- """ This is meant to be overridden by subclasses and called with super.
23
-
24
- We return constructor argument values when not being instantiated. Since there are no
25
- constructor arguments for the Serializable class we just return an empty dictionary.
26
- """
27
- if instantiate:
28
- return cls._instantiated_load(in_file)
29
- return {}
30
-
31
- @classmethod
32
- def _instantiated_load(cls, in_file, **kwargs):
33
- """ This is meant to be overridden by subclasses and called with super.
34
-
35
- We return constructor argument values (we have no values to load in this abstract class).
36
- """
37
- obj_type = pickle.load(in_file)
38
- if obj_type is None:
39
- return None
40
-
41
- if not inspect.isclass(obj_type) or (not issubclass(obj_type, cls) and (obj_type is not cls)):
42
- raise Exception(f"Invalid object type loaded from file. {obj_type} is not a subclass of {cls}.")
43
-
44
- # here we call the constructor with all the arguments we have loaded
45
- constructor_args = obj_type.load(in_file, instantiate=False, **kwargs)
46
- used_args = inspect.getfullargspec(obj_type.__init__)[0]
47
- return obj_type(**{k: constructor_args[k] for k in constructor_args if k in used_args})
48
-
49
-
50
- class Serializer:
51
- """ Save data items to an input stream.
52
- """
53
- def __init__(self, out_stream, block_name, version):
54
- self.out_stream = out_stream
55
- self.block_name = block_name
56
- self.block_version = version
57
- self.serializer_version = 0 # update this when the serializer changes
58
-
59
- def __enter__(self):
60
- log.debug("serializer_version = %d", self.serializer_version)
61
- pickle.dump(self.serializer_version, self.out_stream)
62
- log.debug("block_name = %s", self.block_name)
63
- pickle.dump(self.block_name, self.out_stream)
64
- log.debug("block_version = %d", self.block_version)
65
- pickle.dump(self.block_version, self.out_stream)
66
- return self
67
-
68
- def __exit__(self, exception_type, exception_value, traceback):
69
- log.debug("END_BLOCK___")
70
- pickle.dump("END_BLOCK___", self.out_stream)
71
-
72
- def save(self, name, value, encoder="auto"):
73
- """ Dump a data item to the current input stream.
74
- """
75
- log.debug("name = %s", name)
76
- pickle.dump(name, self.out_stream)
77
- if encoder is None or encoder is False:
78
- log.debug("encoder_name = %s", "no_encoder")
79
- pickle.dump("no_encoder", self.out_stream)
80
- elif callable(encoder):
81
- log.debug("encoder_name = %s", "custom_encoder")
82
- pickle.dump("custom_encoder", self.out_stream)
83
- encoder(value, self.out_stream)
84
- elif encoder == ".save" or (isinstance(value, Serializable) and encoder == "auto"):
85
- log.debug("encoder_name = %s", "serializable.save")
86
- pickle.dump("serializable.save", self.out_stream)
87
- if len(inspect.getfullargspec(value.save)[0]) == 3: # backward compat for MLflow, can remove 4/1/2021
88
- value.save(self.out_stream, value)
89
- else:
90
- value.save(self.out_stream)
91
- elif encoder == "auto":
92
- if isinstance(value, (int, float, str)):
93
- log.debug("encoder_name = %s", "pickle.dump")
94
- pickle.dump("pickle.dump", self.out_stream)
95
- pickle.dump(value, self.out_stream)
96
- else:
97
- log.debug("encoder_name = %s", "cloudpickle.dump")
98
- pickle.dump("cloudpickle.dump", self.out_stream)
99
- cloudpickle.dump(value, self.out_stream)
100
- else:
101
- raise ValueError(f"Unknown encoder type '{encoder}' given for serialization!")
102
- log.debug("value = %s", str(value))
103
-
104
- class Deserializer:
105
- """ Load data items from an input stream.
106
- """
107
-
108
- def __init__(self, in_stream, block_name, min_version, max_version):
109
- self.in_stream = in_stream
110
- self.block_name = block_name
111
- self.block_min_version = min_version
112
- self.block_max_version = max_version
113
-
114
- # update these when the serializer changes
115
- self.serializer_min_version = 0
116
- self.serializer_max_version = 0
117
-
118
- def __enter__(self):
119
-
120
- # confirm the serializer version
121
- serializer_version = pickle.load(self.in_stream)
122
- log.debug("serializer_version = %d", serializer_version)
123
- if serializer_version < self.serializer_min_version:
124
- raise ValueError(
125
- f"The file being loaded was saved with a serializer version of {serializer_version}, " + \
126
- f"but the current deserializer in SHAP requires at least version {self.serializer_min_version}."
127
- )
128
- if serializer_version > self.serializer_max_version:
129
- raise ValueError(
130
- f"The file being loaded was saved with a serializer version of {serializer_version}, " + \
131
- f"but the current deserializer in SHAP only support up to version {self.serializer_max_version}."
132
- )
133
-
134
- # confirm the block name
135
- block_name = pickle.load(self.in_stream)
136
- log.debug("block_name = %s", block_name)
137
- if block_name != self.block_name:
138
- raise ValueError(
139
- f"The next data block in the file being loaded was supposed to be {self.block_name}, " + \
140
- f"but the next block found was {block_name}."
141
- )
142
-
143
- # confirm the block version
144
- block_version = pickle.load(self.in_stream)
145
- log.debug("block_version = %d", block_version)
146
- if block_version < self.block_min_version:
147
- raise ValueError(
148
- f"The file being loaded was saved with a block version of {block_version}, " + \
149
- f"but the current deserializer in SHAP requires at least version {self.block_min_version}."
150
- )
151
- if block_version > self.block_max_version:
152
- raise ValueError(
153
- f"The file being loaded was saved with a block version of {block_version}, " + \
154
- f"but the current deserializer in SHAP only support up to version {self.block_max_version}."
155
- )
156
- return self
157
-
158
- def __exit__(self, exception_type, exception_value, traceback):
159
- # confirm the block end token
160
- for _ in range(100):
161
- end_token = pickle.load(self.in_stream)
162
- log.debug("end_token = %s", end_token)
163
- if end_token == "END_BLOCK___":
164
- return
165
- self._load_data_value()
166
- raise ValueError(
167
- f"The data block end token wsa not found for the block {self.block_name}."
168
- )
169
-
170
- def load(self, name, decoder=None):
171
- """ Load a data item from the current input stream.
172
- """
173
- # confirm the block name
174
- loaded_name = pickle.load(self.in_stream)
175
- log.debug("loaded_name = %s", loaded_name)
176
- print("loaded_name", loaded_name)
177
- if loaded_name != name:
178
- raise ValueError(
179
- f"The next data item in the file being loaded was supposed to be {name}, " + \
180
- f"but the next block found was {loaded_name}."
181
- ) # We should eventually add support for skipping over unused data items in old formats...
182
-
183
- value = self._load_data_value(decoder)
184
- log.debug("value = %s", str(value))
185
- return value
186
-
187
- def _load_data_value(self, decoder=None):
188
- encoder_name = pickle.load(self.in_stream)
189
- log.debug("encoder_name = %s", encoder_name)
190
- if encoder_name == "custom_encoder" or callable(decoder):
191
- assert callable(decoder), "You must provide a callable custom decoder for the data item {name}!"
192
- return decoder(self.in_stream)
193
- if encoder_name == "no_encoder":
194
- return None
195
- if encoder_name == "serializable.save":
196
- return Serializable.load(self.in_stream)
197
- if encoder_name == "numpy.save":
198
- return np.load(self.in_stream)
199
- if encoder_name == "pickle.dump":
200
- return pickle.load(self.in_stream)
201
- if encoder_name == "cloudpickle.dump":
202
- return cloudpickle.load(self.in_stream)
203
-
204
- raise ValueError(f"Unsupported encoder type found: {encoder_name}")
 
lib/shap/_version.py DELETED
@@ -1,16 +0,0 @@
-# file generated by setuptools_scm
-# don't change, don't track in version control
-TYPE_CHECKING = False
-if TYPE_CHECKING:
-    from typing import Tuple, Union
-    VERSION_TUPLE = Tuple[Union[int, str], ...]
-else:
-    VERSION_TUPLE = object
-
-version: str
-__version__: str
-__version_tuple__: VERSION_TUPLE
-version_tuple: VERSION_TUPLE
-
-__version__ = version = '0.44.1'
-__version_tuple__ = version_tuple = (0, 44, 1)

lib/shap/actions/__init__.py DELETED
@@ -1,3 +0,0 @@
-from ._action import Action
-
-__all__ = ["Action"]

lib/shap/actions/_action.py DELETED
@@ -1,8 +0,0 @@
-class Action:
-    """ Abstract action class.
-    """
-    def __lt__(self, other_action):
-        return self.cost < other_action.cost
-
-    def __repr__(self):
-        return f"<Action '{self.__str__()}'>"

lib/shap/actions/_optimizer.py DELETED
@@ -1,92 +0,0 @@
1
- import copy
2
- import queue
3
- import warnings
4
-
5
- from ..utils._exceptions import ConvergenceError, InvalidAction
6
- from ._action import Action
7
-
8
-
9
- class ActionOptimizer:
10
- def __init__(self, model, actions):
11
- self.model = model
12
- warnings.warn(
13
- "Note that ActionOptimizer is still in an alpha state and is subjust to API changes."
14
- )
15
- # actions go into mutually exclusive groups
16
- self.action_groups = []
17
- for group in actions:
18
-
19
- if issubclass(type(group), Action):
20
- group._group_index = len(self.action_groups)
21
- group._grouped_index = 0
22
- self.action_groups.append([copy.copy(group)])
23
- elif issubclass(type(group), list):
24
- group = sorted([copy.copy(v) for v in group], key=lambda a: a.cost)
25
- for i, v in enumerate(group):
26
- v._group_index = len(self.action_groups)
27
- v._grouped_index = i
28
- self.action_groups.append(group)
29
- else:
30
- raise InvalidAction(
31
- "A passed action was not an Action or list of actions!"
32
- )
33
-
34
- def __call__(self, *args, max_evals=10000):
35
-
36
- # init our queue with all the least costly actions
37
- q = queue.PriorityQueue()
38
- for i in range(len(self.action_groups)):
39
- group = self.action_groups[i]
40
- q.put((group[0].cost, [group[0]]))
41
-
42
- nevals = 0
43
- while not q.empty():
44
-
45
- # see if we have exceeded our runtime budget
46
- nevals += 1
47
- if nevals > max_evals:
48
- raise ConvergenceError(
49
- f"Failed to find a solution with max_evals={max_evals}! Try reducing the number of actions or increasing max_evals."
50
- )
51
-
52
- # get the next cheapest set of actions we can do
53
- cost, actions = q.get()
54
-
55
- # apply those actions
56
- args_tmp = copy.deepcopy(args)
57
- for a in actions:
58
- a(*args_tmp)
59
-
60
- # if the model is now satisfied we are done!!
61
- v = self.model(*args_tmp)
62
- if v:
63
- return actions
64
-
65
- # if not then we add all possible follow-on actions to our queue
66
- else:
67
- for i in range(len(self.action_groups)):
68
- group = self.action_groups[i]
69
-
70
- # look to to see if we already have a action from this group, if so we need to
71
- # move to a more expensive action in the same group
72
- next_ind = 0
73
- prev_in_group = -1
74
- for j, a in enumerate(actions):
75
- if a._group_index == i:
76
- next_ind = max(next_ind, a._grouped_index + 1)
77
- prev_in_group = j
78
-
79
- # we are adding a new action type
80
- if prev_in_group == -1:
81
- new_actions = actions + [group[next_ind]]
82
- # we are moving from one action to a more expensive one in the same group
83
- elif next_ind < len(group):
84
- new_actions = copy.copy(actions)
85
- new_actions[prev_in_group] = group[next_ind]
86
- # we don't have a more expensive action left in this group
87
- else:
88
- new_actions = None
89
-
90
- # add the new option to our queue
91
- if new_actions is not None:
92
- q.put((sum([a.cost for a in new_actions]), new_actions))
 
lib/shap/benchmark/__init__.py DELETED
@@ -1,9 +0,0 @@
-from ._compute import ComputeTime
-from ._explanation_error import ExplanationError
-from ._result import BenchmarkResult
-from ._sequential import SequentialMasker
-
-# from . import framework
-# from .. import datasets
-
-__all__ = ["ComputeTime", "ExplanationError", "BenchmarkResult", "SequentialMasker"]

lib/shap/benchmark/_compute.py DELETED
@@ -1,9 +0,0 @@
-from ._result import BenchmarkResult
-
-
-class ComputeTime:
-    """ Extracts a runtime benchmark result from the passed Explanation.
-    """
-
-    def __call__(self, explanation, name):
-        return BenchmarkResult("compute time", name, value=explanation.compute_time / explanation.shape[0])

lib/shap/benchmark/_explanation_error.py DELETED
@@ -1,181 +0,0 @@
1
- import time
2
-
3
- import numpy as np
4
- from tqdm.auto import tqdm
5
-
6
- from shap import Explanation, links
7
- from shap.maskers import FixedComposite, Image, Text
8
- from shap.utils import MaskedModel, partition_tree_shuffle
9
- from shap.utils._exceptions import DimensionError
10
-
11
- from ._result import BenchmarkResult
12
-
13
-
14
- class ExplanationError:
15
- """ A measure of the explanation error relative to a model's actual output.
16
-
17
- This benchmark metric measures the discrepancy between the output of the model predicted by an
18
- attribution explanation vs. the actual output of the model. This discrepancy is measured over
19
- many masking patterns drawn from permutations of the input features.
20
-
21
- For explanations (like Shapley values) that explain the difference between one alternative and another
22
- (for example a current sample and typical background feature values) there is possible explanation error
23
- for every pattern of mixing foreground and background, or other words every possible masking pattern.
24
- In this class we compute the standard deviation over these explanation errors where masking patterns
25
- are drawn from prefixes of random feature permutations. This seems natural, and aligns with Shapley value
26
- computations, but of course you could choose to summarize explanation errors in others ways as well.
27
- """
28
-
29
- def __init__(self, masker, model, *model_args, batch_size=500, num_permutations=10, link=links.identity, linearize_link=True, seed=38923):
30
- """ Build a new explanation error benchmarker with the given masker, model, and model args.
31
-
32
- Parameters
33
- ----------
34
- masker : function or shap.Masker
35
- The masker defines how we hide features during the perturbation process.
36
-
37
- model : function or shap.Model
38
- The model we want to evaluate explanations against.
39
-
40
- model_args : ...
41
- The list of arguments we will give to the model that we will have explained. When we later call this benchmark
42
- object we should pass explanations that have been computed on this same data.
43
-
44
- batch_size : int
45
- The maximum batch size we should use when calling the model. For some large NLP models this needs to be set
46
- lower (at say 1) to avoid running out of GPU memory.
47
-
48
- num_permutations : int
49
- How many permutations we will use to estimate the average explanation error for each sample. If you are running
50
- this benchmark on a large dataset with many samples then you can reduce this value since the final result is
51
- averaged over samples as well and the averages of both directly combine to reduce variance. So for 10k samples
52
- num_permutations=1 is appropreiate.
53
-
54
- link : function
55
- Allows for a non-linear link function to be used to bringe between the model output space and the explanation
56
- space.
57
-
58
- linearize_link : bool
59
- Non-linear links can destroy additive separation in generalized linear models, so by linearizing the link we can
60
- retain additive separation. See upcoming paper/doc for details.
61
- """
62
-
63
- self.masker = masker
64
- self.model = model
65
- self.model_args = model_args
66
- self.num_permutations = num_permutations
67
- self.link = link
68
- self.linearize_link = linearize_link
69
- self.model_args = model_args
70
- self.batch_size = batch_size
71
- self.seed = seed
72
-
73
- # user must give valid masker
74
- underlying_masker = masker.masker if isinstance(masker, FixedComposite) else masker
75
- if isinstance(underlying_masker, Text):
76
- self.data_type = "text"
77
- elif isinstance(underlying_masker, Image):
78
- self.data_type = "image"
79
- else:
80
- self.data_type = "tabular"
81
-
82
- def __call__(self, explanation, name, step_fraction=0.01, indices=[], silent=False):
83
- """ Run this benchmark on the given explanation.
84
- """
85
-
86
- if isinstance(explanation, np.ndarray):
87
- attributions = explanation
88
- elif isinstance(explanation, Explanation):
89
- attributions = explanation.values
90
- else:
91
- raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!")
92
-
93
- if len(attributions) != len(self.model_args[0]):
94
- emsg = (
95
- "The explanation passed must have the same number of rows as "
96
- "the self.model_args that were passed!"
97
- )
98
- raise DimensionError(emsg)
99
-
100
- # it is important that we choose the same permutations for the different explanations we are comparing
101
- # so as to avoid needless noise
102
- old_seed = np.random.seed()
103
- np.random.seed(self.seed)
104
-
105
- pbar = None
106
- start_time = time.time()
107
- svals = []
108
- mask_vals = []
109
-
110
- for i, args in enumerate(zip(*self.model_args)):
111
-
112
- if len(args[0].shape) != len(attributions[i].shape):
113
- raise ValueError("The passed explanation must have the same dim as the model_args and must not have a vector output!")
114
-
115
- feature_size = np.prod(attributions[i].shape)
116
- sample_attributions = attributions[i].flatten()
117
-
118
- # compute any custom clustering for this row
119
- row_clustering = None
120
- if getattr(self.masker, "clustering", None) is not None:
121
- if isinstance(self.masker.clustering, np.ndarray):
122
- row_clustering = self.masker.clustering
123
- elif callable(self.masker.clustering):
124
- row_clustering = self.masker.clustering(*args)
125
- else:
126
- raise NotImplementedError("The masker passed has a .clustering attribute that is not yet supported by the ExplanationError benchmark!")
127
-
128
- masked_model = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *args)
129
-
130
- total_values = None
131
- for _ in range(self.num_permutations):
132
- masks = []
133
- mask = np.zeros(feature_size, dtype=bool)
134
- masks.append(mask.copy())
135
- ordered_inds = np.arange(feature_size)
136
-
137
- # shuffle the indexes so we get a random permutation ordering
138
- if row_clustering is not None:
139
- inds_mask = np.ones(feature_size, dtype=bool)
140
- partition_tree_shuffle(ordered_inds, inds_mask, row_clustering)
141
- else:
142
- np.random.shuffle(ordered_inds)
143
-
144
- increment = max(1, int(feature_size * step_fraction))
145
- for j in range(0, feature_size, increment):
146
- mask[ordered_inds[np.arange(j, min(feature_size, j+increment))]] = True
147
- masks.append(mask.copy())
148
- mask_vals.append(masks)
149
-
150
- values = []
151
- masks_arr = np.array(masks)
152
- for j in range(0, len(masks_arr), self.batch_size):
153
- values.append(masked_model(masks_arr[j:j + self.batch_size]))
154
- values = np.concatenate(values)
155
- base_value = values[0]
156
- for j, v in enumerate(values):
157
- values[j] = (v - (base_value + np.sum(sample_attributions[masks_arr[j]])))**2
158
-
159
- if total_values is None:
160
- total_values = values
161
- else:
162
- total_values += values
163
- total_values /= self.num_permutations
164
-
165
- svals.append(total_values)
166
-
167
- if pbar is None and time.time() - start_time > 5:
168
- pbar = tqdm(total=len(self.model_args[0]), disable=silent, leave=False, desc=f"ExplanationError for {name}")
169
- pbar.update(i+1)
170
- if pbar is not None:
171
- pbar.update(1)
172
-
173
- if pbar is not None:
174
- pbar.close()
175
-
176
- svals = np.array(svals)
177
-
178
- # reset the random seed so we don't mess up the caller
179
- np.random.seed(old_seed)
180
-
181
- return BenchmarkResult("explanation error", name, value=np.sqrt(np.sum(total_values)/len(total_values)))
 
lib/shap/benchmark/_result.py DELETED
@@ -1,34 +0,0 @@
1
- import numpy as np
2
- import sklearn
3
-
4
- sign_defaults = {
5
- "keep positive": 1,
6
- "keep negative": -1,
7
- "remove positive": -1,
8
- "remove negative": 1,
9
- "compute time": -1,
10
- "keep absolute": -1, # the absolute signs are defaults that make sense when scoring losses
11
- "remove absolute": 1,
12
- "explanation error": -1
13
- }
14
-
15
- class BenchmarkResult:
16
- """ The result of a benchmark run.
17
- """
18
-
19
- def __init__(self, metric, method, value=None, curve_x=None, curve_y=None, curve_y_std=None, value_sign=None):
20
- self.metric = metric
21
- self.method = method
22
- self.value = value
23
- self.curve_x = curve_x
24
- self.curve_y = curve_y
25
- self.curve_y_std = curve_y_std
26
- self.value_sign = value_sign
27
- if self.value_sign is None and self.metric in sign_defaults:
28
- self.value_sign = sign_defaults[self.metric]
29
- if self.value is None:
30
- self.value = sklearn.metrics.auc(curve_x, (np.array(curve_y) - curve_y[0]))
31
-
32
- @property
33
- def full_name(self):
34
- return self.method + " " + self.metric
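
When value is not supplied, the constructor above collapses a curve to a single score by shifting it to start at zero and taking its area with sklearn.metrics.auc. A small sketch of that step with made-up curve values:

import numpy as np
from sklearn.metrics import auc

curve_x = np.linspace(0, 1, 5)
curve_y = np.array([0.50, 0.58, 0.66, 0.70, 0.71])  # e.g. model output as features are kept

value = auc(curve_x, curve_y - curve_y[0])           # area of the zero-shifted curve
print(value)                                         # ~0.136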
 
lib/shap/benchmark/_sequential.py DELETED
@@ -1,332 +0,0 @@
1
- import time
2
-
3
- import matplotlib.pyplot as pl
4
- import numpy as np
5
- import pandas as pd
6
- import sklearn
7
- from tqdm.auto import tqdm
8
-
9
- from shap import Explanation, links
10
- from shap.maskers import FixedComposite, Image, Text
11
- from shap.utils import MaskedModel
12
-
13
- from ._result import BenchmarkResult
14
-
15
-
16
- class SequentialMasker:
17
- def __init__(self, mask_type, sort_order, masker, model, *model_args, batch_size=500):
18
-
19
- for arg in model_args:
20
- if isinstance(arg, pd.DataFrame):
21
- raise TypeError("DataFrame arguments dont iterate correctly, pass numpy arrays instead!")
22
-
23
- # convert any DataFrames to numpy arrays
24
- # self.model_arg_cols = []
25
- # self.model_args = []
26
- # self.has_df = False
27
- # for arg in model_args:
28
- # if isinstance(arg, pd.DataFrame):
29
- # self.model_arg_cols.append(arg.columns)
30
- # self.model_args.append(arg.values)
31
- # self.has_df = True
32
- # else:
33
- # self.model_arg_cols.append(None)
34
- # self.model_args.append(arg)
35
-
36
- # if self.has_df:
37
- # given_model = model
38
- # def new_model(*args):
39
- # df_args = []
40
- # for i, arg in enumerate(args):
41
- # if self.model_arg_cols[i] is not None:
42
- # df_args.append(pd.DataFrame(arg, columns=self.model_arg_cols[i]))
43
- # else:
44
- # df_args.append(arg)
45
- # return given_model(*df_args)
46
- # model = new_model
47
-
48
- self.inner = SequentialPerturbation(
49
- model, masker, sort_order, mask_type
50
- )
51
- self.model_args = model_args
52
- self.batch_size = batch_size
53
-
54
- def __call__(self, explanation, name, **kwargs):
55
- return self.inner(name, explanation, *self.model_args, batch_size=self.batch_size, **kwargs)
56
-
57
- class SequentialPerturbation:
58
- def __init__(self, model, masker, sort_order, perturbation, linearize_link=False):
59
- # self.f = lambda masked, x, index: model.predict(masked)
60
- self.model = model if callable(model) else model.predict
61
- self.masker = masker
62
- self.sort_order = sort_order
63
- self.perturbation = perturbation
64
- self.linearize_link = linearize_link
65
-
66
- # define our sort order
67
- if self.sort_order == "positive":
68
- self.sort_order_map = lambda x: np.argsort(-x)
69
- elif self.sort_order == "negative":
70
- self.sort_order_map = lambda x: np.argsort(x)
71
- elif self.sort_order == "absolute":
72
- self.sort_order_map = lambda x: np.argsort(-abs(x))
73
- else:
74
- raise ValueError("sort_order must be either \"positive\", \"negative\", or \"absolute\"!")
75
-
76
- # user must give valid masker
77
- underlying_masker = masker.masker if isinstance(masker, FixedComposite) else masker
78
- if isinstance(underlying_masker, Text):
79
- self.data_type = "text"
80
- elif isinstance(underlying_masker, Image):
81
- self.data_type = "image"
82
- else:
83
- self.data_type = "tabular"
84
- #raise ValueError("masker must be for \"tabular\", \"text\", or \"image\"!")
85
-
86
- self.score_values = []
87
- self.score_aucs = []
88
- self.labels = []
89
-
90
- def __call__(self, name, explanation, *model_args, percent=0.01, indices=[], y=None, label=None, silent=False, debug_mode=False, batch_size=10):
91
- # if explainer is already the attributions
92
- if isinstance(explanation, np.ndarray):
93
- attributions = explanation
94
- elif isinstance(explanation, Explanation):
95
- attributions = explanation.values
96
- else:
97
- raise ValueError("The passed explanation must be either of type numpy.ndarray or shap.Explanation!")
98
-
99
- assert len(attributions) == len(model_args[0]), "The explanation passed must have the same number of rows as the model_args that were passed!"
100
-
101
- if label is None:
102
- label = "Score %d" % len(self.score_values)
103
-
104
- # convert dataframes
105
- # if isinstance(X, (pd.Series, pd.DataFrame)):
106
- # X = X.values
107
-
108
- # convert all single-sample vectors to matrices
109
- # if not hasattr(attributions[0], "__len__"):
110
- # attributions = np.array([attributions])
111
- # if not hasattr(X[0], "__len__") and self.data_type == "tabular":
112
- # X = np.array([X])
113
-
114
- pbar = None
115
- start_time = time.time()
116
- svals = []
117
- mask_vals = []
118
-
119
- for i, args in enumerate(zip(*model_args)):
120
- # if self.data_type == "image":
121
- # x_shape, y_shape = attributions[i].shape[0], attributions[i].shape[1]
122
- # feature_size = np.prod([x_shape, y_shape])
123
- # sample_attributions = attributions[i].mean(2).reshape(feature_size, -1)
124
- # data = X[i].flatten()
125
- # mask_shape = X[i].shape
126
- # else:
127
- feature_size = np.prod(attributions[i].shape)
128
- sample_attributions = attributions[i].flatten()
129
- # data = X[i]
130
- # mask_shape = feature_size
131
-
132
- self.masked_model = MaskedModel(self.model, self.masker, links.identity, self.linearize_link, *args)
133
-
134
- masks = []
135
-
136
- mask = np.ones(feature_size, dtype=bool) * (self.perturbation == "remove")
137
- masks.append(mask.copy())
138
-
139
- ordered_inds = self.sort_order_map(sample_attributions)
140
- increment = max(1,int(feature_size*percent))
141
- for j in range(0, feature_size, increment):
142
- oind_list = [ordered_inds[t] for t in range(j, min(feature_size, j+increment))]
143
-
144
- for oind in oind_list:
145
- if not ((self.sort_order == "positive" and sample_attributions[oind] <= 0) or \
146
- (self.sort_order == "negative" and sample_attributions[oind] >= 0)):
147
- mask[oind] = self.perturbation == "keep"
148
-
149
- masks.append(mask.copy())
150
-
151
- mask_vals.append(masks)
152
-
153
- # mask_size = len(range(0, feature_size, increment)) + 1
154
- values = []
155
- masks_arr = np.array(masks)
156
- for j in range(0, len(masks_arr), batch_size):
157
- values.append(self.masked_model(masks_arr[j:j + batch_size]))
158
- values = np.concatenate(values)
159
-
160
- svals.append(values)
161
-
162
- if pbar is None and time.time() - start_time > 5:
163
- pbar = tqdm(total=len(model_args[0]), disable=silent, leave=False, desc="SequentialMasker")
164
- pbar.update(i+1)
165
- if pbar is not None:
166
- pbar.update(1)
167
-
168
- if pbar is not None:
169
- pbar.close()
170
-
171
- self.score_values.append(np.array(svals))
172
-
173
- # if self.sort_order == "negative":
174
- # curve_sign = -1
175
- # else:
176
- curve_sign = 1
177
-
178
- self.labels.append(label)
179
-
180
- xs = np.linspace(0, 1, 100)
181
- curves = np.zeros((len(self.score_values[-1]), len(xs)))
182
- for j in range(len(self.score_values[-1])):
183
- xp = np.linspace(0, 1, len(self.score_values[-1][j]))
184
- yp = self.score_values[-1][j]
185
- curves[j,:] = np.interp(xs, xp, yp)
186
- ys = curves.mean(0)
187
- std = curves.std(0) / np.sqrt(curves.shape[0])
188
- auc = sklearn.metrics.auc(np.linspace(0, 1, len(ys)), curve_sign*(ys-ys[0]))
189
-
190
- if not debug_mode:
191
- return BenchmarkResult(self.perturbation + " " + self.sort_order, name, curve_x=xs, curve_y=ys, curve_y_std=std)
192
- else:
193
- aucs = []
194
- for j in range(len(self.score_values[-1])):
195
- curve = curves[j,:]
196
- auc = sklearn.metrics.auc(np.linspace(0, 1, len(curve)), curve_sign*(curve-curve[0]))
197
- aucs.append(auc)
198
- return mask_vals, curves, aucs
199
-
200
- def score(self, explanation, X, percent=0.01, y=None, label=None, silent=False, debug_mode=False):
201
- '''
202
- Will be deprecated once MaskedModel is fully supported
203
- '''
204
- # if explainer is already the attributions
205
- if isinstance(explanation, np.ndarray):
206
- attributions = explanation
207
- elif isinstance(explanation, Explanation):
208
- attributions = explanation.values
209
-
210
- if label is None:
211
- label = "Score %d" % len(self.score_values)
212
-
213
- # convert dataframes
214
- if isinstance(X, (pd.Series, pd.DataFrame)):
215
- X = X.values
216
-
217
- # convert all single-sample vectors to matrices
218
- if not hasattr(attributions[0], "__len__"):
219
- attributions = np.array([attributions])
220
- if not hasattr(X[0], "__len__") and self.data_type == "tabular":
221
- X = np.array([X])
222
-
223
- pbar = None
224
- start_time = time.time()
225
- svals = []
226
- mask_vals = []
227
-
228
- for i in range(len(X)):
229
- if self.data_type == "image":
230
- x_shape, y_shape = attributions[i].shape[0], attributions[i].shape[1]
231
- feature_size = np.prod([x_shape, y_shape])
232
- sample_attributions = attributions[i].mean(2).reshape(feature_size, -1)
233
- else:
234
- feature_size = attributions[i].shape[0]
235
- sample_attributions = attributions[i]
236
-
237
- if len(attributions[i].shape) == 1 or self.data_type == "tabular":
238
- output_size = 1
239
- else:
240
- output_size = attributions[i].shape[-1]
241
-
242
- for k in range(output_size):
243
- if self.data_type == "image":
244
- mask_shape = X[i].shape
245
- else:
246
- mask_shape = feature_size
247
-
248
- mask = np.ones(mask_shape, dtype=bool) * (self.perturbation == "remove")
249
- masks = [mask.copy()]
250
-
251
- values = np.zeros(feature_size+1)
252
- # masked, data = self.masker(mask, X[i])
253
- masked = self.masker(mask, X[i])
254
- data = None
255
- curr_val = self.f(masked, data, k).mean(0)
256
-
257
- values[0] = curr_val
258
-
259
- if output_size != 1:
260
- test_attributions = sample_attributions[:,k]
261
- else:
262
- test_attributions = sample_attributions
263
-
264
- ordered_inds = self.sort_order_map(test_attributions)
265
- increment = max(1,int(feature_size*percent))
266
- for j in range(0, feature_size, increment):
267
- oind_list = [ordered_inds[t] for t in range(j, min(feature_size, j+increment))]
268
-
269
- for oind in oind_list:
270
- if not ((self.sort_order == "positive" and test_attributions[oind] <= 0) or \
271
- (self.sort_order == "negative" and test_attributions[oind] >= 0)):
272
- if self.data_type == "image":
273
- xoind, yoind = oind // attributions[i].shape[1], oind % attributions[i].shape[1]
274
- mask[xoind][yoind] = self.perturbation == "keep"
275
- else:
276
- mask[oind] = self.perturbation == "keep"
277
-
278
- masks.append(mask.copy())
279
- # masked, data = self.masker(mask, X[i])
280
- masked = self.masker(mask, X[i])
281
- curr_val = self.f(masked, data, k).mean(0)
282
-
283
- for t in range(j, min(feature_size, j+increment)):
284
- values[t+1] = curr_val
285
-
286
- svals.append(values)
287
- mask_vals.append(masks)
288
-
289
- if pbar is None and time.time() - start_time > 5:
290
- pbar = tqdm(total=len(X), disable=silent, leave=False)
291
- pbar.update(i+1)
292
- if pbar is not None:
293
- pbar.update(1)
294
-
295
- if pbar is not None:
296
- pbar.close()
297
-
298
- self.score_values.append(np.array(svals))
299
-
300
- if self.sort_order == "negative":
301
- curve_sign = -1
302
- else:
303
- curve_sign = 1
304
-
305
- self.labels.append(label)
306
-
307
- xs = np.linspace(0, 1, 100)
308
- curves = np.zeros((len(self.score_values[-1]), len(xs)))
309
- for j in range(len(self.score_values[-1])):
310
- xp = np.linspace(0, 1, len(self.score_values[-1][j]))
311
- yp = self.score_values[-1][j]
312
- curves[j,:] = np.interp(xs, xp, yp)
313
- ys = curves.mean(0)
314
-
315
- if debug_mode:
316
- aucs = []
317
- for j in range(len(self.score_values[-1])):
318
- curve = curves[j,:]
319
- auc = sklearn.metrics.auc(np.linspace(0, 1, len(curve)), curve_sign*(curve-curve[0]))
320
- aucs.append(auc)
321
- return mask_vals, curves, aucs
322
- else:
323
- auc = sklearn.metrics.auc(np.linspace(0, 1, len(ys)), curve_sign*(ys-ys[0]))
324
- return xs, ys, auc
325
-
326
- def plot(self, xs, ys, auc):
327
- pl.plot(xs, ys, label="AUC %0.4f" % auc)
328
- pl.legend()
329
- xlabel = "Percent Unmasked" if self.perturbation == "keep" else "Percent Masked"
330
- pl.xlabel(xlabel)
331
- pl.ylabel("Model Output")
332
- pl.show()
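
The core loop above builds one curve per sample: rank features by attribution, hide them step by step, record the model output after each step, then average curves and take an AUC. A standalone sketch of a single "remove absolute" curve, using a hypothetical linear model and mean imputation rather than the shap masker classes:

import numpy as np
from sklearn.metrics import auc

rng = np.random.default_rng(0)
w = rng.normal(size=8)              # weights of a toy linear model
x = rng.normal(size=8)              # one sample to explain
mean_vals = np.zeros(8)             # hidden features are imputed with this baseline

f = lambda z: float(w @ z)
attributions = w * (x - mean_vals)

order = np.argsort(-np.abs(attributions))   # most important first ("absolute" order)
hidden = np.zeros(8, dtype=bool)            # nothing hidden yet ("remove" direction)
ys = [f(x)]
for j in order:
    hidden[j] = True                        # remove the next most important feature
    ys.append(f(np.where(hidden, mean_vals, x)))

xs = np.linspace(0, 1, len(ys))
print("curve:", np.round(ys, 3), "auc:", round(auc(xs, np.array(ys) - ys[0]), 3))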
 
lib/shap/benchmark/experiments.py DELETED
@@ -1,414 +0,0 @@
1
- import copy
2
- import itertools
3
- import os
4
- import pickle
5
- import random
6
- import subprocess
7
- import sys
8
- import time
9
- from multiprocessing import Pool
10
-
11
- from .. import __version__, datasets
12
- from . import metrics, models
13
-
14
- try:
15
- from queue import Queue
16
- except ImportError:
17
- from Queue import Queue
18
- from threading import Lock, Thread
19
-
20
- regression_metrics = [
21
- "local_accuracy",
22
- "consistency_guarantees",
23
- "keep_positive_mask",
24
- "keep_positive_resample",
25
- #"keep_positive_impute",
26
- "keep_negative_mask",
27
- "keep_negative_resample",
28
- #"keep_negative_impute",
29
- "keep_absolute_mask__r2",
30
- "keep_absolute_resample__r2",
31
- #"keep_absolute_impute__r2",
32
- "remove_positive_mask",
33
- "remove_positive_resample",
34
- #"remove_positive_impute",
35
- "remove_negative_mask",
36
- "remove_negative_resample",
37
- #"remove_negative_impute",
38
- "remove_absolute_mask__r2",
39
- "remove_absolute_resample__r2",
40
- #"remove_absolute_impute__r2"
41
- "runtime",
42
- ]
43
-
44
- binary_classification_metrics = [
45
- "local_accuracy",
46
- "consistency_guarantees",
47
- "keep_positive_mask",
48
- "keep_positive_resample",
49
- #"keep_positive_impute",
50
- "keep_negative_mask",
51
- "keep_negative_resample",
52
- #"keep_negative_impute",
53
- "keep_absolute_mask__roc_auc",
54
- "keep_absolute_resample__roc_auc",
55
- #"keep_absolute_impute__roc_auc",
56
- "remove_positive_mask",
57
- "remove_positive_resample",
58
- #"remove_positive_impute",
59
- "remove_negative_mask",
60
- "remove_negative_resample",
61
- #"remove_negative_impute",
62
- "remove_absolute_mask__roc_auc",
63
- "remove_absolute_resample__roc_auc",
64
- #"remove_absolute_impute__roc_auc"
65
- "runtime",
66
- ]
67
-
68
- human_metrics = [
69
- "human_and_00",
70
- "human_and_01",
71
- "human_and_11",
72
- "human_or_00",
73
- "human_or_01",
74
- "human_or_11",
75
- "human_xor_00",
76
- "human_xor_01",
77
- "human_xor_11",
78
- "human_sum_00",
79
- "human_sum_01",
80
- "human_sum_11"
81
- ]
82
-
83
- linear_regress_methods = [
84
- "linear_shap_corr",
85
- "linear_shap_ind",
86
- "coef",
87
- "random",
88
- "kernel_shap_1000_meanref",
89
- #"kernel_shap_100_meanref",
90
- #"sampling_shap_10000",
91
- "sampling_shap_1000",
92
- "lime_tabular_regression_1000"
93
- #"sampling_shap_100"
94
- ]
95
-
96
- linear_classify_methods = [
97
- # NEED LIME
98
- "linear_shap_corr",
99
- "linear_shap_ind",
100
- "coef",
101
- "random",
102
- "kernel_shap_1000_meanref",
103
- #"kernel_shap_100_meanref",
104
- #"sampling_shap_10000",
105
- "sampling_shap_1000",
106
- #"lime_tabular_regression_1000"
107
- #"sampling_shap_100"
108
- ]
109
-
110
- tree_regress_methods = [
111
- # NEED tree_shap_ind
112
- # NEED split_count?
113
- "tree_shap_tree_path_dependent",
114
- "tree_shap_independent_200",
115
- "saabas",
116
- "random",
117
- "tree_gain",
118
- "kernel_shap_1000_meanref",
119
- "mean_abs_tree_shap",
120
- #"kernel_shap_100_meanref",
121
- #"sampling_shap_10000",
122
- "sampling_shap_1000",
123
- "lime_tabular_regression_1000",
124
- "maple"
125
- #"sampling_shap_100"
126
- ]
127
-
128
- rf_regress_methods = [ # methods that only support random forest models
129
- "tree_maple"
130
- ]
131
-
132
- tree_classify_methods = [
133
- # NEED tree_shap_ind
134
- # NEED split_count?
135
- "tree_shap_tree_path_dependent",
136
- "tree_shap_independent_200",
137
- "saabas",
138
- "random",
139
- "tree_gain",
140
- "kernel_shap_1000_meanref",
141
- "mean_abs_tree_shap",
142
- #"kernel_shap_100_meanref",
143
- #"sampling_shap_10000",
144
- "sampling_shap_1000",
145
- "lime_tabular_classification_1000",
146
- "maple"
147
- #"sampling_shap_100"
148
- ]
149
-
150
- deep_regress_methods = [
151
- "deep_shap",
152
- "expected_gradients",
153
- "random",
154
- "kernel_shap_1000_meanref",
155
- "sampling_shap_1000",
156
- #"lime_tabular_regression_1000"
157
- ]
158
-
159
- deep_classify_methods = [
160
- "deep_shap",
161
- "expected_gradients",
162
- "random",
163
- "kernel_shap_1000_meanref",
164
- "sampling_shap_1000",
165
- #"lime_tabular_regression_1000"
166
- ]
167
-
168
- _experiments = []
169
- _experiments += [["corrgroups60", "lasso", m, s] for s in regression_metrics for m in linear_regress_methods]
170
- _experiments += [["corrgroups60", "ridge", m, s] for s in regression_metrics for m in linear_regress_methods]
171
- _experiments += [["corrgroups60", "decision_tree", m, s] for s in regression_metrics for m in tree_regress_methods]
172
- _experiments += [["corrgroups60", "random_forest", m, s] for s in regression_metrics for m in (tree_regress_methods + rf_regress_methods)]
173
- _experiments += [["corrgroups60", "gbm", m, s] for s in regression_metrics for m in tree_regress_methods]
174
- _experiments += [["corrgroups60", "ffnn", m, s] for s in regression_metrics for m in deep_regress_methods]
175
-
176
- _experiments += [["independentlinear60", "lasso", m, s] for s in regression_metrics for m in linear_regress_methods]
177
- _experiments += [["independentlinear60", "ridge", m, s] for s in regression_metrics for m in linear_regress_methods]
178
- _experiments += [["independentlinear60", "decision_tree", m, s] for s in regression_metrics for m in tree_regress_methods]
179
- _experiments += [["independentlinear60", "random_forest", m, s] for s in regression_metrics for m in (tree_regress_methods + rf_regress_methods)]
180
- _experiments += [["independentlinear60", "gbm", m, s] for s in regression_metrics for m in tree_regress_methods]
181
- _experiments += [["independentlinear60", "ffnn", m, s] for s in regression_metrics for m in deep_regress_methods]
182
-
183
- _experiments += [["cric", "lasso", m, s] for s in binary_classification_metrics for m in linear_classify_methods]
184
- _experiments += [["cric", "ridge", m, s] for s in binary_classification_metrics for m in linear_classify_methods]
185
- _experiments += [["cric", "decision_tree", m, s] for s in binary_classification_metrics for m in tree_classify_methods]
186
- _experiments += [["cric", "random_forest", m, s] for s in binary_classification_metrics for m in tree_classify_methods]
187
- _experiments += [["cric", "gbm", m, s] for s in binary_classification_metrics for m in tree_classify_methods]
188
- _experiments += [["cric", "ffnn", m, s] for s in binary_classification_metrics for m in deep_classify_methods]
189
-
190
- _experiments += [["human", "decision_tree", m, s] for s in human_metrics for m in tree_regress_methods]
191
-
192
-
193
- def experiments(dataset=None, model=None, method=None, metric=None):
194
- for experiment in _experiments:
195
- if dataset is not None and dataset != experiment[0]:
196
- continue
197
- if model is not None and model != experiment[1]:
198
- continue
199
- if method is not None and method != experiment[2]:
200
- continue
201
- if metric is not None and metric != experiment[3]:
202
- continue
203
- yield experiment
204
-
205
- def run_experiment(experiment, use_cache=True, cache_dir="/tmp"):
206
- dataset_name, model_name, method_name, metric_name = experiment
207
-
208
- # see if we have a cached version
209
- cache_id = __gen_cache_id(experiment)
210
- cache_file = os.path.join(cache_dir, cache_id + ".pickle")
211
- if use_cache and os.path.isfile(cache_file):
212
- with open(cache_file, "rb") as f:
213
- #print(cache_id.replace("__", " ") + " ...loaded from cache.")
214
- return pickle.load(f)
215
-
216
- # compute the scores
217
- print(cache_id.replace("__", " ", 4) + " ...")
218
- sys.stdout.flush()
219
- start = time.time()
220
- X,y = getattr(datasets, dataset_name)()
221
- score = getattr(metrics, metric_name)(
222
- X, y,
223
- getattr(models, dataset_name+"__"+model_name),
224
- method_name
225
- )
226
- print("...took %f seconds.\n" % (time.time() - start))
227
-
228
- # cache the scores
229
- with open(cache_file, "wb") as f:
230
- pickle.dump(score, f)
231
-
232
- return score
233
-
234
-
235
- def run_experiments_helper(args):
236
- experiment, cache_dir = args
237
- return run_experiment(experiment, cache_dir=cache_dir)
238
-
239
- def run_experiments(dataset=None, model=None, method=None, metric=None, cache_dir="/tmp", nworkers=1):
240
- experiments_arr = list(experiments(dataset=dataset, model=model, method=method, metric=metric))
241
- if nworkers == 1:
242
- out = list(map(run_experiments_helper, zip(experiments_arr, itertools.repeat(cache_dir))))
243
- else:
244
- with Pool(nworkers) as pool:
245
- out = pool.map(run_experiments_helper, zip(experiments_arr, itertools.repeat(cache_dir)))
246
- return list(zip(experiments_arr, out))
247
-
248
-
249
- nexperiments = 0
250
- total_sent = 0
251
- total_done = 0
252
- total_failed = 0
253
- host_records = {}
254
- worker_lock = Lock()
255
- ssh_conn_per_min_limit = 0 # set as an argument to run_remote_experiments
256
- def __thread_worker(q, host):
257
- global total_sent, total_done
258
- hostname, python_binary = host.split(":")
259
- while True:
260
-
261
- # make sure we are not sending too many ssh connections to the host
262
- # (if we send too many connections ssh throttling will lock us out)
263
- while True:
264
- all_clear = False
265
-
266
- worker_lock.acquire()
267
- try:
268
- if hostname not in host_records:
269
- host_records[hostname] = []
270
-
271
- if len(host_records[hostname]) < ssh_conn_per_min_limit:
272
- all_clear = True
273
- elif time.time() - host_records[hostname][-ssh_conn_per_min_limit] > 61:
274
- all_clear = True
275
- finally:
276
- worker_lock.release()
277
-
278
- # if we are clear to send a new ssh connection then break
279
- if all_clear:
280
- break
281
-
282
- # if we are not clear then we sleep and try again
283
- time.sleep(5)
284
-
285
- experiment = q.get()
286
-
287
- # if we are not loading from the cache then we note that we have called the host
288
- cache_dir = "/tmp"
289
- cache_file = os.path.join(cache_dir, __gen_cache_id(experiment) + ".pickle")
290
- if not os.path.isfile(cache_file):
291
- worker_lock.acquire()
292
- try:
293
- host_records[hostname].append(time.time())
294
- finally:
295
- worker_lock.release()
296
-
297
- # record how many we have sent off for execution
298
- worker_lock.acquire()
299
- try:
300
- total_sent += 1
301
- __print_status()
302
- finally:
303
- worker_lock.release()
304
-
305
- __run_remote_experiment(experiment, hostname, cache_dir=cache_dir, python_binary=python_binary)
306
-
307
- # record how many are finished
308
- worker_lock.acquire()
309
- try:
310
- total_done += 1
311
- __print_status()
312
- finally:
313
- worker_lock.release()
314
-
315
- q.task_done()
316
-
317
- def __print_status():
318
- print("Benchmark task %d of %d done (%d failed, %d running)" % (total_done, nexperiments, total_failed, total_sent - total_done), end="\r")
319
- sys.stdout.flush()
320
-
321
-
322
- def run_remote_experiments(experiments, thread_hosts, rate_limit=10):
323
- """ Use ssh to run the experiments on remote machines in parallel.
324
-
325
- Parameters
326
- ----------
327
- experiments : iterable
328
- Output of shap.benchmark.experiments(...).
329
-
330
- thread_hosts : list of strings
331
- Each host has the format "host_name:path_to_python_binary" and can appear multiple times
332
- in the list (one for each parallel execution you want on that machine).
333
-
334
- rate_limit : int
335
- How many ssh connections we make per minute to each host (to avoid throttling issues).
336
- """
337
-
338
- global ssh_conn_per_min_limit
339
- ssh_conn_per_min_limit = rate_limit
340
-
341
- # first we kill any remaining workers from previous runs
342
- # note we don't check_call because pkill kills our ssh call as well
343
- thread_hosts = copy.copy(thread_hosts)
344
- random.shuffle(thread_hosts)
345
- for host in set(thread_hosts):
346
- hostname,_ = host.split(":")
347
- try:
348
- subprocess.run(["ssh", hostname, "pkill -f shap.benchmark.run_experiment"], timeout=15)
349
- except subprocess.TimeoutExpired:
350
- print("Failed to connect to", hostname, "after 15 seconds! Exiting.")
351
- return
352
-
353
- experiments = copy.copy(list(experiments))
354
- random.shuffle(experiments) # this way all the hard experiments don't get put on one machine
355
- global nexperiments, total_sent, total_done, total_failed, host_records
356
- nexperiments = len(experiments)
357
- total_sent = 0
358
- total_done = 0
359
- total_failed = 0
360
- host_records = {}
361
-
362
- q = Queue()
363
-
364
- for host in thread_hosts:
365
- worker = Thread(target=__thread_worker, args=(q, host))
366
- worker.setDaemon(True)
367
- worker.start()
368
-
369
- for experiment in experiments:
370
- q.put(experiment)
371
-
372
- q.join()
373
-
374
- def __run_remote_experiment(experiment, remote, cache_dir="/tmp", python_binary="python"):
375
- global total_failed
376
- dataset_name, model_name, method_name, metric_name = experiment
377
-
378
- # see if we have a cached version
379
- cache_id = __gen_cache_id(experiment)
380
- cache_file = os.path.join(cache_dir, cache_id + ".pickle")
381
- if os.path.isfile(cache_file):
382
- with open(cache_file, "rb") as f:
383
- return pickle.load(f)
384
-
385
- # this is just so we don't dump everything at once on a machine
386
- time.sleep(random.uniform(0,5))
387
-
388
- # run the benchmark on the remote machine
389
- #start = time.time()
390
- cmd = "CUDA_VISIBLE_DEVICES=\"\" "+python_binary+" -c \"import shap; shap.benchmark.run_experiment(['{}', '{}', '{}', '{}'], cache_dir='{}')\" &> {}/{}.output".format(
391
- dataset_name, model_name, method_name, metric_name, cache_dir, cache_dir, cache_id
392
- )
393
- try:
394
- subprocess.check_output(["ssh", remote, cmd])
395
- except subprocess.CalledProcessError as e:
396
- print("The following command failed on %s:" % remote, file=sys.stderr)
397
- print(cmd, file=sys.stderr)
398
- total_failed += 1
399
- print(e)
400
- return
401
-
402
- # copy the results back
403
- subprocess.check_output(["scp", remote+":"+cache_file, cache_file])
404
-
405
- if os.path.isfile(cache_file):
406
- with open(cache_file, "rb") as f:
407
- #print(cache_id.replace("__", " ") + " ...loaded from remote after %f seconds" % (time.time() - start))
408
- return pickle.load(f)
409
- else:
410
- raise FileNotFoundError("Remote benchmark call finished but no local file was found!")
411
-
412
- def __gen_cache_id(experiment):
413
- dataset_name, model_name, method_name, metric_name = experiment
414
- return "v" + "__".join([__version__, dataset_name, model_name, method_name, metric_name])
 
lib/shap/benchmark/framework.py DELETED
@@ -1,113 +0,0 @@
1
- import itertools as it
2
-
3
- import matplotlib.pyplot as plt
4
- import numpy as np
5
- import pandas as pd
6
-
7
- from . import perturbation
8
-
9
-
10
- def update(model, attributions, X, y, masker, sort_order, perturbation_method, scores):
11
- metric = perturbation_method + ' ' + sort_order
12
- sp = perturbation.SequentialPerturbation(model, masker, sort_order, perturbation_method)
13
- xs, ys, auc = sp.model_score(attributions, X, y=y)
14
- scores['metrics'].append(metric)
15
- scores['values'][metric] = [xs, ys, auc]
16
-
17
- def get_benchmark(model, attributions, X, y, masker, metrics):
18
- # convert dataframes
19
- if isinstance(X, (pd.Series, pd.DataFrame)):
20
- X = X.values
21
- if isinstance(masker, (pd.Series, pd.DataFrame)):
22
- masker = masker.values
23
-
24
- # record scores per metric
25
- scores = {'metrics': list(), 'values': dict()}
26
- for sort_order, perturbation_method in list(it.product(metrics['sort_order'], metrics['perturbation'])):
27
- update(model, attributions, X, y, masker, sort_order, perturbation_method, scores)
28
-
29
- return scores
30
-
31
- def get_metrics(benchmarks, selection):
32
- # select metrics to plot using selection function
33
- explainer_metrics = set()
34
- for explainer in benchmarks:
35
- scores = benchmarks[explainer]
36
- if len(explainer_metrics) == 0:
37
- explainer_metrics = set(scores['metrics'])
38
- else:
39
- explainer_metrics = selection(explainer_metrics, set(scores['metrics']))
40
-
41
- return list(explainer_metrics)
42
-
43
- def trend_plot(benchmarks):
44
- explainer_metrics = get_metrics(benchmarks, lambda x, y: x.union(y))
45
-
46
- # plot all curves if metric exists
47
- for metric in explainer_metrics:
48
- plt.clf()
49
-
50
- for explainer in benchmarks:
51
- scores = benchmarks[explainer]
52
- if metric in scores['values']:
53
- x, y, auc = scores['values'][metric]
54
- plt.plot(x, y, label=f'{round(auc, 3)} - {explainer}')
55
-
56
- if 'keep' in metric:
57
- xlabel = 'Percent Unmasked'
58
- if 'remove' in metric:
59
- xlabel = 'Percent Masked'
60
-
61
- plt.ylabel('Model Output')
62
- plt.xlabel(xlabel)
63
- plt.title(metric)
64
- plt.legend()
65
- plt.show()
66
-
67
- def compare_plot(benchmarks):
68
- explainer_metrics = get_metrics(benchmarks, lambda x, y: x.intersection(y))
69
- explainers = list(benchmarks.keys())
70
- num_explainers = len(explainers)
71
- num_metrics = len(explainer_metrics)
72
-
73
- # dummy start to evenly distribute explainers on the left
74
- # can later be replaced by boolean metrics
75
- aucs = dict()
76
- for i in range(num_explainers):
77
- explainer = explainers[i]
78
- aucs[explainer] = [i/(num_explainers-1)]
79
-
80
- # normalize per metric
81
- for metric in explainer_metrics:
82
- max_auc, min_auc = -float('inf'), float('inf')
83
-
84
- for explainer in explainers:
85
- scores = benchmarks[explainer]
86
- _, _, auc = scores['values'][metric]
87
- min_auc = min(auc, min_auc)
88
- max_auc = max(auc, max_auc)
89
-
90
- for explainer in explainers:
91
- scores = benchmarks[explainer]
92
- _, _, auc = scores['values'][metric]
93
- aucs[explainer].append((auc-min_auc)/(max_auc-min_auc))
94
-
95
- # plot common curves
96
- ax = plt.gca()
97
- for explainer in explainers:
98
- plt.plot(np.linspace(0, 1, len(explainer_metrics)+1), aucs[explainer], '--o')
99
-
100
- ax.tick_params(which='major', axis='both', labelsize=8)
101
-
102
- ax.set_yticks([i/(num_explainers-1) for i in range(0, num_explainers)])
103
- ax.set_yticklabels(explainers, rotation=0)
104
-
105
- ax.set_xticks(np.linspace(0, 1, num_metrics+1))
106
- ax.set_xticklabels([' '] + explainer_metrics, rotation=45, ha='right')
107
-
108
- plt.grid(which='major', axis='x', linestyle='--')
109
- plt.tight_layout()
110
- plt.ylabel('Relative Performance of Each Explanation Method')
111
- plt.xlabel('Evaluation Metrics')
112
- plt.title('Explanation Method Performance Across Metrics')
113
- plt.show()
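
compare_plot rescales each metric's AUCs so that, per metric, the worst explainer maps to 0 and the best to 1 before plotting. A minimal sketch of that min-max normalization with made-up AUC numbers:

aucs = {
    "permutation": {"keep positive": 0.42, "remove positive": -0.31},
    "random":      {"keep positive": 0.05, "remove positive": -0.02},
}

metrics = ["keep positive", "remove positive"]
normalized = {name: [] for name in aucs}
for metric in metrics:
    vals = [aucs[name][metric] for name in aucs]
    lo, hi = min(vals), max(vals)
    for name in aucs:
        # worst explainer -> 0, best explainer -> 1 for this metric
        normalized[name].append((aucs[name][metric] - lo) / (hi - lo))

print(normalized)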
 
lib/shap/benchmark/measures.py DELETED
@@ -1,424 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- import pandas as pd
5
- import sklearn.utils
6
- from tqdm.auto import tqdm
7
-
8
- _remove_cache = {}
9
- def remove_retrain(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
10
- """ The model is retrained for each test sample with the important features set to a constant.
11
-
12
- If you want to know how important a set of features is you can ask how the model would be
13
- different if those features had never existed. To determine this we can mask those features
14
- across the entire training and test datasets, then retrain the model. If we then compare the
15
- output of this retrained model to the original model we can see the effect produced by knowing
16
- the features we masked. Since for individualized explanation methods each test sample has a
17
- different set of most important features we need to retrain the model for every test sample
18
- to get the change in model performance when a specified fraction of the most important features
19
- are withheld.
20
- """
21
-
22
- warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")
23
-
24
- # see if we match the last cached call
25
- global _remove_cache
26
- args = (X_train, y_train, X_test, y_test, model_generator, metric)
27
- cache_match = False
28
- if "args" in _remove_cache:
29
- if all(a is b for a,b in zip(_remove_cache["args"], args)) and np.all(_remove_cache["attr_test"] == attr_test):
30
- cache_match = True
31
-
32
- X_train, X_test = to_array(X_train, X_test)
33
-
34
- # how many features to mask
35
- assert X_train.shape[1] == X_test.shape[1]
36
-
37
- # this is the model we will retrain many times
38
- model_masked = model_generator()
39
-
40
- # mask nmask top features and re-train the model for each test explanation
41
- X_train_tmp = np.zeros(X_train.shape)
42
- X_test_tmp = np.zeros(X_test.shape)
43
- yp_masked_test = np.zeros(y_test.shape)
44
- tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
45
- last_nmask = _remove_cache.get("nmask", None)
46
- last_yp_masked_test = _remove_cache.get("yp_masked_test", None)
47
- for i in tqdm(range(len(y_test)), "Retraining for the 'remove' metric"):
48
- if cache_match and last_nmask[i] == nmask[i]:
49
- yp_masked_test[i] = last_yp_masked_test[i]
50
- elif nmask[i] == 0:
51
- yp_masked_test[i] = trained_model.predict(X_test[i:i+1])[0]
52
- else:
53
- # mask out the most important features for this test instance
54
- X_train_tmp[:] = X_train
55
- X_test_tmp[:] = X_test
56
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
57
- X_train_tmp[:,ordering[:nmask[i]]] = X_train[:,ordering[:nmask[i]]].mean()
58
- X_test_tmp[i,ordering[:nmask[i]]] = X_train[:,ordering[:nmask[i]]].mean()
59
-
60
- # retrain the model and make a prediction
61
- model_masked.fit(X_train_tmp, y_train)
62
- yp_masked_test[i] = model_masked.predict(X_test_tmp[i:i+1])[0]
63
-
64
- # save our results so the next call to us can be faster when there is redundancy
65
- _remove_cache["nmask"] = nmask
66
- _remove_cache["yp_masked_test"] = yp_masked_test
67
- _remove_cache["attr_test"] = attr_test
68
- _remove_cache["args"] = args
69
-
70
- return metric(y_test, yp_masked_test)
71
-
72
- def remove_mask(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
73
- """ Each test sample is masked by setting the important features to a constant.
74
- """
75
-
76
- X_train, X_test = to_array(X_train, X_test)
77
-
78
- # how many features to mask
79
- assert X_train.shape[1] == X_test.shape[1]
80
-
81
- # mask nmask top features for each test explanation
82
- X_test_tmp = X_test.copy()
83
- tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
84
- mean_vals = X_train.mean(0)
85
- for i in range(len(y_test)):
86
- if nmask[i] > 0:
87
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
88
- X_test_tmp[i,ordering[:nmask[i]]] = mean_vals[ordering[:nmask[i]]]
89
-
90
- yp_masked_test = trained_model.predict(X_test_tmp)
91
-
92
- return metric(y_test, yp_masked_test)
93
-
94
- def remove_impute(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
95
- """ The model is reevaluated for each test sample with the important features set to an imputed value.
96
-
97
- Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on
98
- being able to estimate the full data covariance matrix (and inverse) accurately. So X_train.shape[0] should
99
- be significantly bigger than X_train.shape[1].
100
- """
101
-
102
- X_train, X_test = to_array(X_train, X_test)
103
-
104
- # how many features to mask
105
- assert X_train.shape[1] == X_test.shape[1]
106
-
107
- # keep nkeep top features for each test explanation
108
- C = np.cov(X_train.T)
109
- C += np.eye(C.shape[0]) * 1e-6
110
- X_test_tmp = X_test.copy()
111
- yp_masked_test = np.zeros(y_test.shape)
112
- tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
113
- mean_vals = X_train.mean(0)
114
- for i in range(len(y_test)):
115
- if nmask[i] > 0:
116
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
117
- observe_inds = ordering[nmask[i]:]
118
- impute_inds = ordering[:nmask[i]]
119
-
120
- # impute missing data assuming it follows a multivariate normal distribution
121
- Coo_inv = np.linalg.inv(C[observe_inds,:][:,observe_inds])
122
- Cio = C[impute_inds,:][:,observe_inds]
123
- impute = mean_vals[impute_inds] + Cio @ Coo_inv @ (X_test[i, observe_inds] - mean_vals[observe_inds])
124
-
125
- X_test_tmp[i, impute_inds] = impute
126
-
127
- yp_masked_test = trained_model.predict(X_test_tmp)
128
-
129
- return metric(y_test, yp_masked_test)
130
-
131
- def remove_resample(nmask, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
132
- """ The model is reevaluated for each test sample with the important features set to resample background values.
133
- """
134
-
135
- X_train, X_test = to_array(X_train, X_test)
136
-
137
- # how many features to mask
138
- assert X_train.shape[1] == X_test.shape[1]
139
-
140
- # how many samples to take
141
- nsamples = 100
142
-
143
- # keep nkeep top features for each test explanation
144
- N,M = X_test.shape
145
- X_test_tmp = np.tile(X_test, [1, nsamples]).reshape(nsamples * N, M)
146
- tie_breaking_noise = const_rand(M) * 1e-6
147
- inds = sklearn.utils.resample(np.arange(N), n_samples=nsamples, random_state=random_state)
148
- for i in range(N):
149
- if nmask[i] > 0:
150
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
151
- X_test_tmp[i*nsamples:(i+1)*nsamples, ordering[:nmask[i]]] = X_train[inds, :][:, ordering[:nmask[i]]]
152
-
153
- yp_masked_test = trained_model.predict(X_test_tmp)
154
- yp_masked_test = np.reshape(yp_masked_test, (N, nsamples)).mean(1) # take the mean output over all samples
155
-
156
- return metric(y_test, yp_masked_test)
157
-
158
- def batch_remove_retrain(nmask_train, nmask_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric):
159
- """ An approximation of holdout that only retraines the model once.
160
-
161
- This is also called ROAR (RemOve And Retrain) in work by Google. It is much more computationally
162
- efficient than the holdout method because it masks the most important features in every sample
163
- and then retrains the model once, instead of retraining the model for every test sample like
164
- the holdout metric.
165
- """
166
-
167
- warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")
168
-
169
- X_train, X_test = to_array(X_train, X_test)
170
-
171
- # how many features to mask
172
- assert X_train.shape[1] == X_test.shape[1]
173
-
174
- # mask nmask top features for each explanation
175
- X_train_tmp = X_train.copy()
176
- X_train_mean = X_train.mean(0)
177
- tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
178
- for i in range(len(y_train)):
179
- if nmask_train[i] > 0:
180
- ordering = np.argsort(-attr_train[i, :] + tie_breaking_noise)
181
- X_train_tmp[i, ordering[:nmask_train[i]]] = X_train_mean[ordering[:nmask_train[i]]]
182
- X_test_tmp = X_test.copy()
183
- for i in range(len(y_test)):
184
- if nmask_test[i] > 0:
185
- ordering = np.argsort(-attr_test[i, :] + tie_breaking_noise)
186
- X_test_tmp[i, ordering[:nmask_test[i]]] = X_train_mean[ordering[:nmask_test[i]]]
187
-
188
- # train the model with all the given features masked
189
- model_masked = model_generator()
190
- model_masked.fit(X_train_tmp, y_train)
191
- yp_test_masked = model_masked.predict(X_test_tmp)
192
-
193
- return metric(y_test, yp_test_masked)
194
-
195
- _keep_cache = {}
196
- def keep_retrain(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
197
- """ The model is retrained for each test sample with the non-important features set to a constant.
198
-
199
- If you want to know how important a set of features is you can ask how the model would be
200
- different if only those features had existed. To determine this we can mask the other features
201
- across the entire training and test datasets, then retrain the model. If we apply compare the
202
- output of this retrained model to the original model we can see the effect produced by only
203
- knowning the important features. Since for individualized explanation methods each test sample
204
- has a different set of most important features we need to retrain the model for every test sample
205
- to get the change in model performance when a specified fraction of the most important features
206
- are retained.
207
- """
208
-
209
- warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")
210
-
211
- # see if we match the last cached call
212
- global _keep_cache
213
- args = (X_train, y_train, X_test, y_test, model_generator, metric)
214
- cache_match = False
215
- if "args" in _keep_cache:
216
- if all(a is b for a,b in zip(_keep_cache["args"], args)) and np.all(_keep_cache["attr_test"] == attr_test):
217
- cache_match = True
218
-
219
- X_train, X_test = to_array(X_train, X_test)
220
-
221
- # how many features to mask
222
- assert X_train.shape[1] == X_test.shape[1]
223
-
224
- # this is the model we will retrain many times
225
- model_masked = model_generator()
226
-
227
- # keep nkeep top features and re-train the model for each test explanation
228
- X_train_tmp = np.zeros(X_train.shape)
229
- X_test_tmp = np.zeros(X_test.shape)
230
- yp_masked_test = np.zeros(y_test.shape)
231
- tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
232
- last_nkeep = _keep_cache.get("nkeep", None)
233
- last_yp_masked_test = _keep_cache.get("yp_masked_test", None)
234
- for i in tqdm(range(len(y_test)), "Retraining for the 'keep' metric"):
235
- if cache_match and last_nkeep[i] == nkeep[i]:
236
- yp_masked_test[i] = last_yp_masked_test[i]
237
- elif nkeep[i] == attr_test.shape[1]:
238
- yp_masked_test[i] = trained_model.predict(X_test[i:i+1])[0]
239
- else:
240
-
241
- # mask out the most important features for this test instance
242
- X_train_tmp[:] = X_train
243
- X_test_tmp[:] = X_test
244
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
245
- X_train_tmp[:,ordering[nkeep[i]:]] = X_train[:,ordering[nkeep[i]:]].mean()
246
- X_test_tmp[i,ordering[nkeep[i]:]] = X_train[:,ordering[nkeep[i]:]].mean()
247
-
248
- # retrain the model and make a prediction
249
- model_masked.fit(X_train_tmp, y_train)
250
- yp_masked_test[i] = model_masked.predict(X_test_tmp[i:i+1])[0]
251
-
252
- # save our results so the next call to us can be faster when there is redundancy
253
- _keep_cache["nkeep"] = nkeep
254
- _keep_cache["yp_masked_test"] = yp_masked_test
255
- _keep_cache["attr_test"] = attr_test
256
- _keep_cache["args"] = args
257
-
258
- return metric(y_test, yp_masked_test)
259
-
260
- def keep_mask(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
261
- """ The model is reevaluated for each test sample with the non-important features set to their mean.
262
- """
263
-
264
- X_train, X_test = to_array(X_train, X_test)
265
-
266
- # how many features to mask
267
- assert X_train.shape[1] == X_test.shape[1]
268
-
269
- # keep nkeep top features for each test explanation
270
- X_test_tmp = X_test.copy()
271
- yp_masked_test = np.zeros(y_test.shape)
272
- tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
273
- mean_vals = X_train.mean(0)
274
- for i in range(len(y_test)):
275
- if nkeep[i] < X_test.shape[1]:
276
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
277
- X_test_tmp[i,ordering[nkeep[i]:]] = mean_vals[ordering[nkeep[i]:]]
278
-
279
- yp_masked_test = trained_model.predict(X_test_tmp)
280
-
281
- return metric(y_test, yp_masked_test)
282
-
283
- def keep_impute(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
284
- """ The model is reevaluated for each test sample with the non-important features set to an imputed value.
285
-
286
- Note that the imputation is done using a multivariate normality assumption on the dataset. This depends on
287
- being able to estimate the full data covariance matrix (and inverse) accurately. So X_train.shape[0] should
288
- be significantly bigger than X_train.shape[1].
289
- """
290
-
291
- X_train, X_test = to_array(X_train, X_test)
292
-
293
- # how many features to mask
294
- assert X_train.shape[1] == X_test.shape[1]
295
-
296
- # keep nkeep top features for each test explanation
297
- C = np.cov(X_train.T)
298
- C += np.eye(C.shape[0]) * 1e-6
299
- X_test_tmp = X_test.copy()
300
- yp_masked_test = np.zeros(y_test.shape)
301
- tie_breaking_noise = const_rand(X_train.shape[1], random_state) * 1e-6
302
- mean_vals = X_train.mean(0)
303
- for i in range(len(y_test)):
304
- if nkeep[i] < X_test.shape[1]:
305
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
306
- observe_inds = ordering[:nkeep[i]]
307
- impute_inds = ordering[nkeep[i]:]
308
-
309
- # impute missing data assuming it follows a multivariate normal distribution
310
- Coo_inv = np.linalg.inv(C[observe_inds,:][:,observe_inds])
311
- Cio = C[impute_inds,:][:,observe_inds]
312
- impute = mean_vals[impute_inds] + Cio @ Coo_inv @ (X_test[i, observe_inds] - mean_vals[observe_inds])
313
-
314
- X_test_tmp[i, impute_inds] = impute
315
-
316
- yp_masked_test = trained_model.predict(X_test_tmp)
317
-
318
- return metric(y_test, yp_masked_test)
319
-
320
- def keep_resample(nkeep, X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model, random_state):
321
- """ The model is reevaluated for each test sample with the non-important features set to resample background values.
322
- """ # why broken? overwriting?
323
-
324
- X_train, X_test = to_array(X_train, X_test)
325
-
326
- # how many features to mask
327
- assert X_train.shape[1] == X_test.shape[1]
328
-
329
- # how many samples to take
330
- nsamples = 100
331
-
332
- # keep nkeep top features for each test explanation
333
- N,M = X_test.shape
334
- X_test_tmp = np.tile(X_test, [1, nsamples]).reshape(nsamples * N, M)
335
- tie_breaking_noise = const_rand(M) * 1e-6
336
- inds = sklearn.utils.resample(np.arange(N), n_samples=nsamples, random_state=random_state)
337
- for i in range(N):
338
- if nkeep[i] < M:
339
- ordering = np.argsort(-attr_test[i,:] + tie_breaking_noise)
340
- X_test_tmp[i*nsamples:(i+1)*nsamples, ordering[nkeep[i]:]] = X_train[inds, :][:, ordering[nkeep[i]:]]
341
-
342
- yp_masked_test = trained_model.predict(X_test_tmp)
343
- yp_masked_test = np.reshape(yp_masked_test, (N, nsamples)).mean(1) # take the mean output over all samples
344
-
345
- return metric(y_test, yp_masked_test)
346
-
347
- def batch_keep_retrain(nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, attr_train, attr_test, model_generator, metric):
348
- """ An approximation of keep that only retraines the model once.
349
-
350
- This is also called KAR (Keep And Retrain) in work by Google. It is much more computationally
351
- efficient than the keep method because it masks the unimportant features in every sample
352
- and then retrains the model once, instead of retraining the model for every test sample like
353
- the keep metric.
354
- """
355
-
356
- warnings.warn("The retrain based measures can incorrectly evaluate models in some cases!")
357
-
358
- X_train, X_test = to_array(X_train, X_test)
359
-
360
- # how many features to mask
361
- assert X_train.shape[1] == X_test.shape[1]
362
-
363
- # mask nkeep top features for each explanation
364
- X_train_tmp = X_train.copy()
365
- X_train_mean = X_train.mean(0)
366
- tie_breaking_noise = const_rand(X_train.shape[1]) * 1e-6
367
- for i in range(len(y_train)):
368
- if nkeep_train[i] < X_train.shape[1]:
369
- ordering = np.argsort(-attr_train[i, :] + tie_breaking_noise)
370
- X_train_tmp[i, ordering[nkeep_train[i]:]] = X_train_mean[ordering[nkeep_train[i]:]]
371
- X_test_tmp = X_test.copy()
372
- for i in range(len(y_test)):
373
- if nkeep_test[i] < X_test.shape[1]:
374
- ordering = np.argsort(-attr_test[i, :] + tie_breaking_noise)
375
- X_test_tmp[i, ordering[nkeep_test[i]:]] = X_train_mean[ordering[nkeep_test[i]:]]
376
-
377
- # train the model with all the features not given masked
378
- model_masked = model_generator()
379
- model_masked.fit(X_train_tmp, y_train)
380
- yp_test_masked = model_masked.predict(X_test_tmp)
381
-
382
- return metric(y_test, yp_test_masked)
383
-
384
- def local_accuracy(X_train, y_train, X_test, y_test, attr_test, model_generator, metric, trained_model):
385
- """ The how well do the features plus a constant base rate sum up to the model output.
386
- """
387
-
388
- X_train, X_test = to_array(X_train, X_test)
389
-
390
- # how many features to mask
391
- assert X_train.shape[1] == X_test.shape[1]
392
-
393
- # keep nkeep top features and re-train the model for each test explanation
394
- yp_test = trained_model.predict(X_test)
395
-
396
- return metric(yp_test, strip_list(attr_test).sum(1))
397
-
398
- def to_array(*args):
399
- return [a.values if isinstance(a, pd.DataFrame) else a for a in args]
400
-
401
- def const_rand(size, seed=23980):
402
- """ Generate a random array with a fixed seed.
403
- """
404
- old_seed = np.random.seed()
405
- np.random.seed(seed)
406
- out = np.random.rand(size)
407
- np.random.seed(old_seed)
408
- return out
409
-
410
- def const_shuffle(arr, seed=23980):
411
- """ Shuffle an array in-place with a fixed seed.
412
- """
413
- old_seed = np.random.seed()
414
- np.random.seed(seed)
415
- np.random.shuffle(arr)
416
- np.random.seed(old_seed)
417
-
418
- def strip_list(attrs):
419
- """ This assumes that if you have a list of outputs you just want the second one (the second class is the '1' class).
420
- """
421
- if isinstance(attrs, list):
422
- return attrs[1]
423
- else:
424
- return attrs
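
Note: the keep_impute measure above fills the masked features with their conditional expectation under a multivariate normal fitted to X_train, which is why its docstring asks for X_train.shape[0] to be much larger than X_train.shape[1]. A minimal, self-contained sketch of that conditional-Gaussian step on synthetic data (the arrays, indices, and sizes below are illustrative, not part of the benchmark code):

import numpy as np

# illustrative synthetic training data; the benchmark uses the dataset being explained
rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.6, 0.2],
                [0.6, 1.0, 0.3],
                [0.2, 0.3, 1.0]])
X_train = rng.multivariate_normal(mean=np.zeros(3), cov=cov, size=5000)

x = np.array([1.5, -0.7, 0.4])    # one test sample
observe_inds = np.array([0])      # features we keep (the "important" ones)
impute_inds = np.array([1, 2])    # features we mask and impute

mean_vals = X_train.mean(0)
C = np.cov(X_train.T) + np.eye(3) * 1e-6   # same diagonal regularization as keep_impute

# conditional mean E[x_impute | x_observe] for a multivariate normal
Coo_inv = np.linalg.inv(C[np.ix_(observe_inds, observe_inds)])
Cio = C[np.ix_(impute_inds, observe_inds)]
imputed = mean_vals[impute_inds] + Cio @ Coo_inv @ (x[observe_inds] - mean_vals[observe_inds])

x_masked = x.copy()
x_masked[impute_inds] = imputed   # this row is what trained_model.predict would then see
print(x_masked)

The 1e-6 added to the diagonal of C keeps the observed-block inverse well conditioned when features are strongly correlated, which is the same reason the covariance estimate needs many more rows than columns.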
lib/shap/benchmark/methods.py DELETED
@@ -1,148 +0,0 @@
1
- import numpy as np
2
- import sklearn
3
-
4
- from .. import (
5
- DeepExplainer,
6
- GradientExplainer,
7
- KernelExplainer,
8
- LinearExplainer,
9
- SamplingExplainer,
10
- TreeExplainer,
11
- kmeans,
12
- )
13
- from ..explainers import other
14
- from .models import KerasWrap
15
-
16
-
17
- def linear_shap_corr(model, data):
18
- """ Linear SHAP (corr 1000)
19
- """
20
- return LinearExplainer(model, data, feature_dependence="correlation", nsamples=1000).shap_values
21
-
22
- def linear_shap_ind(model, data):
23
- """ Linear SHAP (ind)
24
- """
25
- return LinearExplainer(model, data, feature_dependence="independent").shap_values
26
-
27
- def coef(model, data):
28
- """ Coefficients
29
- """
30
- return other.CoefficentExplainer(model).attributions
31
-
32
- def random(model, data):
33
- """ Random
34
- color = #777777
35
- linestyle = solid
36
- """
37
- return other.RandomExplainer().attributions
38
-
39
- def kernel_shap_1000_meanref(model, data):
40
- """ Kernel SHAP 1000 mean ref.
41
- color = red_blue_circle(0.5)
42
- linestyle = solid
43
- """
44
- return lambda X: KernelExplainer(model.predict, kmeans(data, 1)).shap_values(X, nsamples=1000, l1_reg=0)
45
-
46
- def sampling_shap_1000(model, data):
47
- """ IME 1000
48
- color = red_blue_circle(0.5)
49
- linestyle = dashed
50
- """
51
- return lambda X: SamplingExplainer(model.predict, data).shap_values(X, nsamples=1000)
52
-
53
- def tree_shap_tree_path_dependent(model, data):
54
- """ TreeExplainer
55
- color = red_blue_circle(0)
56
- linestyle = solid
57
- """
58
- return TreeExplainer(model, feature_dependence="tree_path_dependent").shap_values
59
-
60
- def tree_shap_independent_200(model, data):
61
- """ TreeExplainer (independent)
62
- color = red_blue_circle(0)
63
- linestyle = dashed
64
- """
65
- data_subsample = sklearn.utils.resample(data, replace=False, n_samples=min(200, data.shape[0]), random_state=0)
66
- return TreeExplainer(model, data_subsample, feature_dependence="independent").shap_values
67
-
68
- def mean_abs_tree_shap(model, data):
69
- """ mean(|TreeExplainer|)
70
- color = red_blue_circle(0.25)
71
- linestyle = solid
72
- """
73
- def f(X):
74
- v = TreeExplainer(model).shap_values(X)
75
- if isinstance(v, list):
76
- return [np.tile(np.abs(sv).mean(0), (X.shape[0], 1)) for sv in v]
77
- else:
78
- return np.tile(np.abs(v).mean(0), (X.shape[0], 1))
79
- return f
80
-
81
- def saabas(model, data):
82
- """ Saabas
83
- color = red_blue_circle(0)
84
- linestyle = dotted
85
- """
86
- return lambda X: TreeExplainer(model).shap_values(X, approximate=True)
87
-
88
- def tree_gain(model, data):
89
- """ Gain/Gini Importance
90
- color = red_blue_circle(0.25)
91
- linestyle = dotted
92
- """
93
- return other.TreeGainExplainer(model).attributions
94
-
95
- def lime_tabular_regression_1000(model, data):
96
- """ LIME Tabular 1000
97
- color = red_blue_circle(0.75)
98
- """
99
- return lambda X: other.LimeTabularExplainer(model.predict, data, mode="regression").attributions(X, nsamples=1000)
100
-
101
- def lime_tabular_classification_1000(model, data):
102
- """ LIME Tabular 1000
103
- color = red_blue_circle(0.75)
104
- """
105
- return lambda X: other.LimeTabularExplainer(model.predict_proba, data, mode="classification").attributions(X, nsamples=1000)[1]
106
-
107
- def maple(model, data):
108
- """ MAPLE
109
- color = red_blue_circle(0.6)
110
- """
111
- return lambda X: other.MapleExplainer(model.predict, data).attributions(X, multiply_by_input=False)
112
-
113
- def tree_maple(model, data):
114
- """ Tree MAPLE
115
- color = red_blue_circle(0.6)
116
- linestyle = dashed
117
- """
118
- return lambda X: other.TreeMapleExplainer(model, data).attributions(X, multiply_by_input=False)
119
-
120
- def deep_shap(model, data):
121
- """ Deep SHAP (DeepLIFT)
122
- """
123
- if isinstance(model, KerasWrap):
124
- model = model.model
125
- explainer = DeepExplainer(model, kmeans(data, 1).data)
126
- def f(X):
127
- phi = explainer.shap_values(X)
128
- if isinstance(phi, list) and len(phi) == 1:
129
- return phi[0]
130
- else:
131
- return phi
132
-
133
- return f
134
-
135
- def expected_gradients(model, data):
136
- """ Expected Gradients
137
- """
138
- if isinstance(model, KerasWrap):
139
- model = model.model
140
- explainer = GradientExplainer(model, data)
141
- def f(X):
142
- phi = explainer.shap_values(X)
143
- if isinstance(phi, list) and len(phi) == 1:
144
- return phi[0]
145
- else:
146
- return phi
147
-
148
- return f
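
Each factory in methods.py maps a trained model plus background data to a callable that takes X and returns an attribution matrix, which the benchmark then evaluates on the test split. A rough usage sketch on a synthetic regression problem (the model, data, and shapes here are made up, and the LinearExplainer call is a simplified stand-in for linear_shap_ind above, omitting its feature_dependence argument); it assumes shap and scikit-learn are installed:

import numpy as np
import sklearn.linear_model
import shap

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X @ np.array([1.0, -2.0, 0.5, 0.0, 0.0]) + 0.1 * rng.normal(size=200)

model = sklearn.linear_model.Ridge().fit(X, y)

# roughly what linear_shap_ind(model, X) hands back: a function X -> attributions
attr_function = shap.LinearExplainer(model, X).shap_values

A = attr_function(X[:10])
print(A.shape)   # (10, 5): one attribution per feature for each explained sample

The color = and linestyle = lines embedded in the docstrings above are not just notes for the reader: plots.py further down parses them (get_method_color, get_method_linestyle) to style each method's benchmark curve.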
lib/shap/benchmark/metrics.py DELETED
@@ -1,824 +0,0 @@
1
- import hashlib
2
- import os
3
- import time
4
-
5
- import numpy as np
6
- import sklearn
7
-
8
- from .. import __version__
9
- from . import measures, methods
10
-
11
- try:
12
- import dill as pickle
13
- except Exception:
14
- pass
15
-
16
- try:
17
- from sklearn.model_selection import train_test_split
18
- except Exception:
19
- from sklearn.cross_validation import train_test_split
20
-
21
-
22
- def runtime(X, y, model_generator, method_name):
23
- """ Runtime (sec / 1k samples)
24
- transform = "negate_log"
25
- sort_order = 2
26
- """
27
-
28
- old_seed = np.random.seed()
29
- np.random.seed(3293)
30
-
31
- # average the method scores over several train/test splits
32
- method_reps = []
33
- for i in range(3):
34
- X_train, X_test, y_train, _ = train_test_split(__toarray(X), y, test_size=100, random_state=i)
35
-
36
- # define the model we are going to explain
37
- model = model_generator()
38
- model.fit(X_train, y_train)
39
-
40
- # evaluate each method
41
- start = time.time()
42
- explainer = getattr(methods, method_name)(model, X_train)
43
- build_time = time.time() - start
44
-
45
- start = time.time()
46
- explainer(X_test)
47
- explain_time = time.time() - start
48
-
49
- # we always normalize the explain time as though we were explaining 1000 samples
50
- # even if we actually explain fewer (like just 100) to reduce the benchmark's runtime
51
- method_reps.append(build_time + explain_time * 1000.0 / X_test.shape[0])
52
- np.random.seed(old_seed)
53
-
54
- return None, np.mean(method_reps)
55
-
56
- def local_accuracy(X, y, model_generator, method_name):
57
- """ Local Accuracy
58
- transform = "identity"
59
- sort_order = 0
60
- """
61
-
62
- def score_map(true, pred):
63
- """ Computes local accuracy as the normalized standard deviation of numerical scores.
64
- """
65
- return np.std(pred - true) / (np.std(true) + 1e-6)
66
-
67
- def score_function(X_train, X_test, y_train, y_test, attr_function, trained_model, random_state):
68
- return measures.local_accuracy(
69
- X_train, y_train, X_test, y_test, attr_function(X_test),
70
- model_generator, score_map, trained_model
71
- )
72
- return None, __score_method(X, y, None, model_generator, score_function, method_name)
73
-
74
- def consistency_guarantees(X, y, model_generator, method_name):
75
- """ Consistency Guarantees
76
- transform = "identity"
77
- sort_order = 1
78
- """
79
-
80
- # 1.0 - perfect consistency
81
- # 0.8 - guarantees depend on sampling
82
- # 0.6 - guarantees depend on approximation
83
- # 0.0 - no guarantees
84
- guarantees = {
85
- "linear_shap_corr": 1.0,
86
- "linear_shap_ind": 1.0,
87
- "coef": 0.0,
88
- "kernel_shap_1000_meanref": 0.8,
89
- "sampling_shap_1000": 0.8,
90
- "random": 0.0,
91
- "saabas": 0.0,
92
- "tree_gain": 0.0,
93
- "tree_shap_tree_path_dependent": 1.0,
94
- "tree_shap_independent_200": 1.0,
95
- "mean_abs_tree_shap": 1.0,
96
- "lime_tabular_regression_1000": 0.8,
97
- "lime_tabular_classification_1000": 0.8,
98
- "maple": 0.8,
99
- "tree_maple": 0.8,
100
- "deep_shap": 0.6,
101
- "expected_gradients": 0.6
102
- }
103
-
104
- return None, guarantees[method_name]
105
-
106
- def __mean_pred(true, pred):
107
- """ A trivial metric that is just is the output of the model.
108
- """
109
- return np.mean(pred)
110
-
111
- def keep_positive_mask(X, y, model_generator, method_name, num_fcounts=11):
112
- """ Keep Positive (mask)
113
- xlabel = "Max fraction of features kept"
114
- ylabel = "Mean model output"
115
- transform = "identity"
116
- sort_order = 4
117
- """
118
- return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
119
-
120
- def keep_negative_mask(X, y, model_generator, method_name, num_fcounts=11):
121
- """ Keep Negative (mask)
122
- xlabel = "Max fraction of features kept"
123
- ylabel = "Negative mean model output"
124
- transform = "negate"
125
- sort_order = 5
126
- """
127
- return __run_measure(measures.keep_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
128
-
129
- def keep_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11):
130
- """ Keep Absolute (mask)
131
- xlabel = "Max fraction of features kept"
132
- ylabel = "R^2"
133
- transform = "identity"
134
- sort_order = 6
135
- """
136
- return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
137
-
138
- def keep_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
139
- """ Keep Absolute (mask)
140
- xlabel = "Max fraction of features kept"
141
- ylabel = "ROC AUC"
142
- transform = "identity"
143
- sort_order = 6
144
- """
145
- return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
146
-
147
- def remove_positive_mask(X, y, model_generator, method_name, num_fcounts=11):
148
- """ Remove Positive (mask)
149
- xlabel = "Max fraction of features removed"
150
- ylabel = "Negative mean model output"
151
- transform = "negate"
152
- sort_order = 7
153
- """
154
- return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
155
-
156
- def remove_negative_mask(X, y, model_generator, method_name, num_fcounts=11):
157
- """ Remove Negative (mask)
158
- xlabel = "Max fraction of features removed"
159
- ylabel = "Mean model output"
160
- transform = "identity"
161
- sort_order = 8
162
- """
163
- return __run_measure(measures.remove_mask, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
164
-
165
- def remove_absolute_mask__r2(X, y, model_generator, method_name, num_fcounts=11):
166
- """ Remove Absolute (mask)
167
- xlabel = "Max fraction of features removed"
168
- ylabel = "1 - R^2"
169
- transform = "one_minus"
170
- sort_order = 9
171
- """
172
- return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
173
-
174
- def remove_absolute_mask__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
175
- """ Remove Absolute (mask)
176
- xlabel = "Max fraction of features removed"
177
- ylabel = "1 - ROC AUC"
178
- transform = "one_minus"
179
- sort_order = 9
180
- """
181
- return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
182
-
183
- def keep_positive_resample(X, y, model_generator, method_name, num_fcounts=11):
184
- """ Keep Positive (resample)
185
- xlabel = "Max fraction of features kept"
186
- ylabel = "Mean model output"
187
- transform = "identity"
188
- sort_order = 10
189
- """
190
- return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
191
-
192
- def keep_negative_resample(X, y, model_generator, method_name, num_fcounts=11):
193
- """ Keep Negative (resample)
194
- xlabel = "Max fraction of features kept"
195
- ylabel = "Negative mean model output"
196
- transform = "negate"
197
- sort_order = 11
198
- """
199
- return __run_measure(measures.keep_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
200
-
201
- def keep_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11):
202
- """ Keep Absolute (resample)
203
- xlabel = "Max fraction of features kept"
204
- ylabel = "R^2"
205
- transform = "identity"
206
- sort_order = 12
207
- """
208
- return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
209
-
210
- def keep_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
211
- """ Keep Absolute (resample)
212
- xlabel = "Max fraction of features kept"
213
- ylabel = "ROC AUC"
214
- transform = "identity"
215
- sort_order = 12
216
- """
217
- return __run_measure(measures.keep_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
218
-
219
- def remove_positive_resample(X, y, model_generator, method_name, num_fcounts=11):
220
- """ Remove Positive (resample)
221
- xlabel = "Max fraction of features removed"
222
- ylabel = "Negative mean model output"
223
- transform = "negate"
224
- sort_order = 13
225
- """
226
- return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
227
-
228
- def remove_negative_resample(X, y, model_generator, method_name, num_fcounts=11):
229
- """ Remove Negative (resample)
230
- xlabel = "Max fraction of features removed"
231
- ylabel = "Mean model output"
232
- transform = "identity"
233
- sort_order = 14
234
- """
235
- return __run_measure(measures.remove_resample, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
236
-
237
- def remove_absolute_resample__r2(X, y, model_generator, method_name, num_fcounts=11):
238
- """ Remove Absolute (resample)
239
- xlabel = "Max fraction of features removed"
240
- ylabel = "1 - R^2"
241
- transform = "one_minus"
242
- sort_order = 15
243
- """
244
- return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
245
-
246
- def remove_absolute_resample__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
247
- """ Remove Absolute (resample)
248
- xlabel = "Max fraction of features removed"
249
- ylabel = "1 - ROC AUC"
250
- transform = "one_minus"
251
- sort_order = 15
252
- """
253
- return __run_measure(measures.remove_resample, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
254
-
255
- def keep_positive_impute(X, y, model_generator, method_name, num_fcounts=11):
256
- """ Keep Positive (impute)
257
- xlabel = "Max fraction of features kept"
258
- ylabel = "Mean model output"
259
- transform = "identity"
260
- sort_order = 16
261
- """
262
- return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
263
-
264
- def keep_negative_impute(X, y, model_generator, method_name, num_fcounts=11):
265
- """ Keep Negative (impute)
266
- xlabel = "Max fraction of features kept"
267
- ylabel = "Negative mean model output"
268
- transform = "negate"
269
- sort_order = 17
270
- """
271
- return __run_measure(measures.keep_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
272
-
273
- def keep_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11):
274
- """ Keep Absolute (impute)
275
- xlabel = "Max fraction of features kept"
276
- ylabel = "R^2"
277
- transform = "identity"
278
- sort_order = 18
279
- """
280
- return __run_measure(measures.keep_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
281
-
282
- def keep_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
283
- """ Keep Absolute (impute)
284
- xlabel = "Max fraction of features kept"
285
- ylabel = "ROC AUC"
286
- transform = "identity"
287
- sort_order = 19
288
- """
289
- return __run_measure(measures.keep_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
290
-
291
- def remove_positive_impute(X, y, model_generator, method_name, num_fcounts=11):
292
- """ Remove Positive (impute)
293
- xlabel = "Max fraction of features removed"
294
- ylabel = "Negative mean model output"
295
- transform = "negate"
296
- sort_order = 7
297
- """
298
- return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
299
-
300
- def remove_negative_impute(X, y, model_generator, method_name, num_fcounts=11):
301
- """ Remove Negative (impute)
302
- xlabel = "Max fraction of features removed"
303
- ylabel = "Mean model output"
304
- transform = "identity"
305
- sort_order = 8
306
- """
307
- return __run_measure(measures.remove_impute, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
308
-
309
- def remove_absolute_impute__r2(X, y, model_generator, method_name, num_fcounts=11):
310
- """ Remove Absolute (impute)
311
- xlabel = "Max fraction of features removed"
312
- ylabel = "1 - R^2"
313
- transform = "one_minus"
314
- sort_order = 9
315
- """
316
- return __run_measure(measures.remove_impute, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.r2_score)
317
-
318
- def remove_absolute_impute__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
319
- """ Remove Absolute (impute)
320
- xlabel = "Max fraction of features removed"
321
- ylabel = "1 - ROC AUC"
322
- transform = "one_minus"
323
- sort_order = 9
324
- """
325
- return __run_measure(measures.remove_mask, X, y, model_generator, method_name, 0, num_fcounts, sklearn.metrics.roc_auc_score)
326
-
327
- def keep_positive_retrain(X, y, model_generator, method_name, num_fcounts=11):
328
- """ Keep Positive (retrain)
329
- xlabel = "Max fraction of features kept"
330
- ylabel = "Mean model output"
331
- transform = "identity"
332
- sort_order = 6
333
- """
334
- return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
335
-
336
- def keep_negative_retrain(X, y, model_generator, method_name, num_fcounts=11):
337
- """ Keep Negative (retrain)
338
- xlabel = "Max fraction of features kept"
339
- ylabel = "Negative mean model output"
340
- transform = "negate"
341
- sort_order = 7
342
- """
343
- return __run_measure(measures.keep_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
344
-
345
- def remove_positive_retrain(X, y, model_generator, method_name, num_fcounts=11):
346
- """ Remove Positive (retrain)
347
- xlabel = "Max fraction of features removed"
348
- ylabel = "Negative mean model output"
349
- transform = "negate"
350
- sort_order = 11
351
- """
352
- return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, 1, num_fcounts, __mean_pred)
353
-
354
- def remove_negative_retrain(X, y, model_generator, method_name, num_fcounts=11):
355
- """ Remove Negative (retrain)
356
- xlabel = "Max fraction of features removed"
357
- ylabel = "Mean model output"
358
- transform = "identity"
359
- sort_order = 12
360
- """
361
- return __run_measure(measures.remove_retrain, X, y, model_generator, method_name, -1, num_fcounts, __mean_pred)
362
-
363
- def __run_measure(measure, X, y, model_generator, method_name, attribution_sign, num_fcounts, summary_function):
364
-
365
- def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trained_model, random_state):
366
- if attribution_sign == 0:
367
- A = np.abs(__strip_list(attr_function(X_test)))
368
- else:
369
- A = attribution_sign * __strip_list(attr_function(X_test))
370
- nmask = np.ones(len(y_test)) * fcount
371
- nmask = np.minimum(nmask, np.array(A >= 0).sum(1)).astype(int)
372
- return measure(
373
- nmask, X_train, y_train, X_test, y_test, A,
374
- model_generator, summary_function, trained_model, random_state
375
- )
376
- fcounts = __intlogspace(0, X.shape[1], num_fcounts)
377
- return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name)
378
-
379
- def batch_remove_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11):
380
- """ Batch Remove Absolute (retrain)
381
- xlabel = "Fraction of features removed"
382
- ylabel = "1 - R^2"
383
- transform = "one_minus"
384
- sort_order = 13
385
- """
386
- return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts)
387
-
388
- def batch_keep_absolute_retrain__r2(X, y, model_generator, method_name, num_fcounts=11):
389
- """ Batch Keep Absolute (retrain)
390
- xlabel = "Fraction of features kept"
391
- ylabel = "R^2"
392
- transform = "identity"
393
- sort_order = 13
394
- """
395
- return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.r2_score, num_fcounts)
396
-
397
- def batch_remove_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
398
- """ Batch Remove Absolute (retrain)
399
- xlabel = "Fraction of features removed"
400
- ylabel = "1 - ROC AUC"
401
- transform = "one_minus"
402
- sort_order = 13
403
- """
404
- return __run_batch_abs_metric(measures.batch_remove_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts)
405
-
406
- def batch_keep_absolute_retrain__roc_auc(X, y, model_generator, method_name, num_fcounts=11):
407
- """ Batch Keep Absolute (retrain)
408
- xlabel = "Fraction of features kept"
409
- ylabel = "ROC AUC"
410
- transform = "identity"
411
- sort_order = 13
412
- """
413
- return __run_batch_abs_metric(measures.batch_keep_retrain, X, y, model_generator, method_name, sklearn.metrics.roc_auc_score, num_fcounts)
414
-
415
- def __run_batch_abs_metric(metric, X, y, model_generator, method_name, loss, num_fcounts):
416
- def score_function(fcount, X_train, X_test, y_train, y_test, attr_function, trained_model):
417
- A_train = np.abs(__strip_list(attr_function(X_train)))
418
- nkeep_train = (np.ones(len(y_train)) * fcount).astype(int)
419
- #nkeep_train = np.minimum(nkeep_train, np.array(A_train > 0).sum(1)).astype(int)
420
- A_test = np.abs(__strip_list(attr_function(X_test)))
421
- nkeep_test = (np.ones(len(y_test)) * fcount).astype(int)
422
- #nkeep_test = np.minimum(nkeep_test, np.array(A_test >= 0).sum(1)).astype(int)
423
- return metric(
424
- nkeep_train, nkeep_test, X_train, y_train, X_test, y_test, A_train, A_test,
425
- model_generator, loss
426
- )
427
- fcounts = __intlogspace(0, X.shape[1], num_fcounts)
428
- return fcounts, __score_method(X, y, fcounts, model_generator, score_function, method_name)
429
-
430
- _attribution_cache = {}
431
- def __score_method(X, y, fcounts, model_generator, score_function, method_name, nreps=10, test_size=100, cache_dir="/tmp"):
432
- """ Test an explanation method.
433
- """
434
-
435
- try:
436
- pickle
437
- except NameError:
438
- raise ImportError("The 'dill' package could not be loaded and is needed for the benchmark!")
439
-
440
- old_seed = np.random.seed()
441
- np.random.seed(3293)
442
-
443
- # average the method scores over several train/test splits
444
- method_reps = []
445
-
446
- data_hash = hashlib.sha256(__toarray(X).flatten()).hexdigest() + hashlib.sha256(__toarray(y)).hexdigest()
447
- for i in range(nreps):
448
- X_train, X_test, y_train, y_test = train_test_split(__toarray(X), y, test_size=test_size, random_state=i)
449
-
450
- # define the model we are going to explain, caching it so we only build it once
451
- model_id = "model_cache__v" + "__".join([__version__, data_hash, model_generator.__name__])+".pickle"
452
- cache_file = os.path.join(cache_dir, model_id + ".pickle")
453
- if os.path.isfile(cache_file):
454
- with open(cache_file, "rb") as f:
455
- model = pickle.load(f)
456
- else:
457
- model = model_generator()
458
- model.fit(X_train, y_train)
459
- with open(cache_file, "wb") as f:
460
- pickle.dump(model, f)
461
-
462
- attr_key = "_".join([model_generator.__name__, method_name, str(test_size), str(nreps), str(i), data_hash])
463
- def score(attr_function):
464
- def cached_attr_function(X_inner):
465
- if attr_key not in _attribution_cache:
466
- _attribution_cache[attr_key] = attr_function(X_inner)
467
- return _attribution_cache[attr_key]
468
-
469
- #cached_attr_function = lambda X: __check_cache(attr_function, X)
470
- if fcounts is None:
471
- return score_function(X_train, X_test, y_train, y_test, cached_attr_function, model, i)
472
- else:
473
- scores = []
474
- for f in fcounts:
475
- scores.append(score_function(f, X_train, X_test, y_train, y_test, cached_attr_function, model, i))
476
- return np.array(scores)
477
-
478
- # evaluate the method (only building the attribution function if we need to)
479
- if attr_key not in _attribution_cache:
480
- method_reps.append(score(getattr(methods, method_name)(model, X_train)))
481
- else:
482
- method_reps.append(score(None))
483
-
484
- np.random.seed(old_seed)
485
- return np.array(method_reps).mean(0)
486
-
487
-
488
- # used to memoize explainer functions so we don't waste time re-explaining the same object
489
- __cache0 = None
490
- __cache_X0 = None
491
- __cache_f0 = None
492
- __cache1 = None
493
- __cache_X1 = None
494
- __cache_f1 = None
495
- def __check_cache(f, X):
496
- global __cache0, __cache_X0, __cache_f0
497
- global __cache1, __cache_X1, __cache_f1
498
- if X is __cache_X0 and f is __cache_f0:
499
- return __cache0
500
- elif X is __cache_X1 and f is __cache_f1:
501
- return __cache1
502
- else:
503
- __cache_f1 = __cache_f0
504
- __cache_X1 = __cache_X0
505
- __cache1 = __cache0
506
- __cache_f0 = f
507
- __cache_X0 = X
508
- __cache0 = f(X)
509
- return __cache0
510
-
511
- def __intlogspace(start, end, count):
512
- return np.unique(np.round(start + (end-start) * (np.logspace(0, 1, count, endpoint=True) - 1) / 9).astype(int))
513
-
514
- def __toarray(X):
515
- """ Converts DataFrames to numpy arrays.
516
- """
517
- if hasattr(X, "values"):
518
- X = X.values
519
- return X
520
-
521
- def __strip_list(attrs):
522
- """ This assumes that if you have a list of outputs you just want the second one (the second class).
523
- """
524
- if isinstance(attrs, list):
525
- return attrs[1]
526
- else:
527
- return attrs
528
-
529
- def _fit_human(model_generator, val00, val01, val11):
530
- # force the model to fit a function with almost entirely zero background
531
- N = 1000000
532
- M = 3
533
- X = np.zeros((N,M))
534
- X.shape
535
- y = np.ones(N) * val00
536
- X[0:1000, 0] = 1
537
- y[0:1000] = val01
538
- for i in range(0,1000000,1000):
539
- X[i, 1] = 1
540
- y[i] = val01
541
- y[0] = val11
542
- model = model_generator()
543
- model.fit(X, y)
544
- return model
545
-
546
- def _human_and(X, model_generator, method_name, fever, cough):
547
- assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"
548
-
549
- # these are from the sickness_score mturk user study experiment
550
- X_test = np.zeros((100,3))
551
- if not fever and not cough:
552
- human_consensus = np.array([0., 0., 0.])
553
- X_test[0,:] = np.array([[0., 0., 1.]])
554
- elif not fever and cough:
555
- human_consensus = np.array([0., 2., 0.])
556
- X_test[0,:] = np.array([[0., 1., 1.]])
557
- elif fever and cough:
558
- human_consensus = np.array([5., 5., 0.])
559
- X_test[0,:] = np.array([[1., 1., 1.]])
560
-
561
- # force the model to fit an AND function with almost entirely zero background
562
- model = _fit_human(model_generator, 0, 2, 10)
563
-
564
- attr_function = getattr(methods, method_name)(model, X)
565
- methods_attrs = attr_function(X_test)
566
- return "human", (human_consensus, methods_attrs[0,:])
567
-
568
- def human_and_00(X, y, model_generator, method_name):
569
- """ AND (false/false)
570
-
571
- This tests how well a feature attribution method agrees with human intuition
572
- for an AND operation combined with linear effects. This metric deals
573
- specifically with the question of credit allocation for the following function
574
- when all three inputs are true:
575
- if fever: +2 points
576
- if cough: +2 points
577
- if fever and cough: +6 points
578
-
579
- transform = "identity"
580
- sort_order = 0
581
- """
582
- return _human_and(X, model_generator, method_name, False, False)
583
-
584
- def human_and_01(X, y, model_generator, method_name):
585
- """ AND (false/true)
586
-
587
- This tests how well a feature attribution method agrees with human intuition
588
- for an AND operation combined with linear effects. This metric deals
589
- specifically with the question of credit allocation for the following function
590
- when all three inputs are true:
591
- if fever: +2 points
592
- if cough: +2 points
593
- if fever and cough: +6 points
594
-
595
- transform = "identity"
596
- sort_order = 1
597
- """
598
- return _human_and(X, model_generator, method_name, False, True)
599
-
600
- def human_and_11(X, y, model_generator, method_name):
601
- """ AND (true/true)
602
-
603
- This tests how well a feature attribution method agrees with human intuition
604
- for an AND operation combined with linear effects. This metric deals
605
- specifically with the question of credit allocation for the following function
606
- when all three inputs are true:
607
- if fever: +2 points
608
- if cough: +2 points
609
- if fever and cough: +6 points
610
-
611
- transform = "identity"
612
- sort_order = 2
613
- """
614
- return _human_and(X, model_generator, method_name, True, True)
615
-
616
-
617
- def _human_or(X, model_generator, method_name, fever, cough):
618
- assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"
619
-
620
- # these are from the sickness_score mturk user study experiment
621
- X_test = np.zeros((100,3))
622
- if not fever and not cough:
623
- human_consensus = np.array([0., 0., 0.])
624
- X_test[0,:] = np.array([[0., 0., 1.]])
625
- elif not fever and cough:
626
- human_consensus = np.array([0., 8., 0.])
627
- X_test[0,:] = np.array([[0., 1., 1.]])
628
- elif fever and cough:
629
- human_consensus = np.array([5., 5., 0.])
630
- X_test[0,:] = np.array([[1., 1., 1.]])
631
-
632
- # force the model to fit an OR function with almost entirely zero background
633
- model = _fit_human(model_generator, 0, 8, 10)
634
-
635
- attr_function = getattr(methods, method_name)(model, X)
636
- methods_attrs = attr_function(X_test)
637
- return "human", (human_consensus, methods_attrs[0,:])
638
-
639
- def human_or_00(X, y, model_generator, method_name):
640
- """ OR (false/false)
641
-
642
- This tests how well a feature attribution method agrees with human intuition
643
- for an OR operation combined with linear effects. This metric deals
644
- specifically with the question of credit allocation for the following function
645
- when all three inputs are true:
646
- if fever: +2 points
647
- if cough: +2 points
648
- if fever or cough: +6 points
649
-
650
- transform = "identity"
651
- sort_order = 0
652
- """
653
- return _human_or(X, model_generator, method_name, False, False)
654
-
655
- def human_or_01(X, y, model_generator, method_name):
656
- """ OR (false/true)
657
-
658
- This tests how well a feature attribution method agrees with human intuition
659
- for an OR operation combined with linear effects. This metric deals
660
- specifically with the question of credit allocation for the following function
661
- when all three inputs are true:
662
- if fever: +2 points
663
- if cough: +2 points
664
- if fever or cough: +6 points
665
-
666
- transform = "identity"
667
- sort_order = 1
668
- """
669
- return _human_or(X, model_generator, method_name, False, True)
670
-
671
- def human_or_11(X, y, model_generator, method_name):
672
- """ OR (true/true)
673
-
674
- This tests how well a feature attribution method agrees with human intuition
675
- for an OR operation combined with linear effects. This metric deals
676
- specifically with the question of credit allocation for the following function
677
- when all three inputs are true:
678
- if fever: +2 points
679
- if cough: +2 points
680
- if fever or cough: +6 points
681
-
682
- transform = "identity"
683
- sort_order = 2
684
- """
685
- return _human_or(X, model_generator, method_name, True, True)
686
-
687
-
688
- def _human_xor(X, model_generator, method_name, fever, cough):
689
- assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"
690
-
691
- # these are from the sickness_score mturk user study experiment
692
- X_test = np.zeros((100,3))
693
- if not fever and not cough:
694
- human_consensus = np.array([0., 0., 0.])
695
- X_test[0,:] = np.array([[0., 0., 1.]])
696
- elif not fever and cough:
697
- human_consensus = np.array([0., 8., 0.])
698
- X_test[0,:] = np.array([[0., 1., 1.]])
699
- elif fever and cough:
700
- human_consensus = np.array([2., 2., 0.])
701
- X_test[0,:] = np.array([[1., 1., 1.]])
702
-
703
- # force the model to fit an XOR function with almost entirely zero background
704
- model = _fit_human(model_generator, 0, 8, 4)
705
-
706
- attr_function = getattr(methods, method_name)(model, X)
707
- methods_attrs = attr_function(X_test)
708
- return "human", (human_consensus, methods_attrs[0,:])
709
-
710
- def human_xor_00(X, y, model_generator, method_name):
711
- """ XOR (false/false)
712
-
713
- This tests how well a feature attribution method agrees with human intuition
714
- for an eXclusive OR operation combined with linear effects. This metric deals
715
- specifically with the question of credit allocation for the following function
716
- when all three inputs are true:
717
- if fever: +2 points
718
- if cough: +2 points
719
- if fever or cough but not both: +6 points
720
-
721
- transform = "identity"
722
- sort_order = 3
723
- """
724
- return _human_xor(X, model_generator, method_name, False, False)
725
-
726
- def human_xor_01(X, y, model_generator, method_name):
727
- """ XOR (false/true)
728
-
729
- This tests how well a feature attribution method agrees with human intuition
730
- for an eXclusive OR operation combined with linear effects. This metric deals
731
- specifically with the question of credit allocation for the following function
732
- when all three inputs are true:
733
- if fever: +2 points
734
- if cough: +2 points
735
- if fever or cough but not both: +6 points
736
-
737
- transform = "identity"
738
- sort_order = 4
739
- """
740
- return _human_xor(X, model_generator, method_name, False, True)
741
-
742
- def human_xor_11(X, y, model_generator, method_name):
743
- """ XOR (true/true)
744
-
745
- This tests how well a feature attribution method agrees with human intuition
746
- for an eXclusive OR operation combined with linear effects. This metric deals
747
- specifically with the question of credit allocation for the following function
748
- when all three inputs are true:
749
- if fever: +2 points
750
- if cough: +2 points
751
- if fever or cough but not both: +6 points
752
-
753
- transform = "identity"
754
- sort_order = 5
755
- """
756
- return _human_xor(X, model_generator, method_name, True, True)
757
-
758
-
759
- def _human_sum(X, model_generator, method_name, fever, cough):
760
- assert np.abs(X).max() == 0, "Human agreement metrics are only for use with the human_agreement dataset!"
761
-
762
- # these are from the sickness_score mturk user study experiment
763
- X_test = np.zeros((100,3))
764
- if not fever and not cough:
765
- human_consensus = np.array([0., 0., 0.])
766
- X_test[0,:] = np.array([[0., 0., 1.]])
767
- elif not fever and cough:
768
- human_consensus = np.array([0., 2., 0.])
769
- X_test[0,:] = np.array([[0., 1., 1.]])
770
- elif fever and cough:
771
- human_consensus = np.array([2., 2., 0.])
772
- X_test[0,:] = np.array([[1., 1., 1.]])
773
-
774
- # force the model to fit a SUM function with almost entirely zero background
775
- model = _fit_human(model_generator, 0, 2, 4)
776
-
777
- attr_function = getattr(methods, method_name)(model, X)
778
- methods_attrs = attr_function(X_test)
779
- return "human", (human_consensus, methods_attrs[0,:])
780
-
781
- def human_sum_00(X, y, model_generator, method_name):
782
- """ SUM (false/false)
783
-
784
- This tests how well a feature attribution method agrees with human intuition
785
- for a SUM operation. This metric deals
786
- specifically with the question of credit allocation for the following function
787
- when all three inputs are true:
788
- if fever: +2 points
789
- if cough: +2 points
790
-
791
- transform = "identity"
792
- sort_order = 0
793
- """
794
- return _human_sum(X, model_generator, method_name, False, False)
795
-
796
- def human_sum_01(X, y, model_generator, method_name):
797
- """ SUM (false/true)
798
-
799
- This tests how well a feature attribution method agrees with human intuition
800
- for a SUM operation. This metric deals
801
- specifically with the question of credit allocation for the following function
802
- when all three inputs are true:
803
- if fever: +2 points
804
- if cough: +2 points
805
-
806
- transform = "identity"
807
- sort_order = 1
808
- """
809
- return _human_sum(X, model_generator, method_name, False, True)
810
-
811
- def human_sum_11(X, y, model_generator, method_name):
812
- """ SUM (true/true)
813
-
814
- This tests how well a feature attribution method agrees with human intuition
815
- for a SUM operation. This metric deals
816
- specifically with the question of credit allocation for the following function
817
- when all three inputs are true:
818
- if fever: +2 points
819
- if cough: +2 points
820
-
821
- transform = "identity"
822
- sort_order = 2
823
- """
824
- return _human_sum(X, model_generator, method_name, True, True)
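
The keep/remove metrics above sweep a grid of feature counts produced by __intlogspace, which spaces the counts logarithmically so that small feature budgets are sampled more densely than large ones. A quick standalone check of what that grid looks like for a 60-feature dataset (the helper is copied here so the snippet runs on its own):

import numpy as np

def intlogspace(start, end, count):
    # same formula as __intlogspace above
    return np.unique(np.round(
        start + (end - start) * (np.logspace(0, 1, count, endpoint=True) - 1) / 9
    ).astype(int))

print(intlogspace(0, 60, 11))
# [ 0  2  4  7 10 14 20 27 35 46 60]  -- denser near zero, sparser near the full feature count

np.unique also collapses any duplicate counts, so for datasets with only a handful of features the grid can come back shorter than num_fcounts.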
lib/shap/benchmark/models.py DELETED
@@ -1,230 +0,0 @@
1
- import numpy as np
2
- import sklearn
3
- import sklearn.ensemble
4
- from sklearn.preprocessing import StandardScaler
5
-
6
-
7
- class KerasWrap:
8
- """ A wrapper that allows us to set parameters in the constructor and do a reset before fitting.
9
- """
10
- def __init__(self, model, epochs, flatten_output=False):
11
- self.model = model
12
- self.epochs = epochs
13
- self.flatten_output = flatten_output
14
- self.init_weights = None
15
- self.scaler = StandardScaler()
16
-
17
- def fit(self, X, y, verbose=0):
18
- if self.init_weights is None:
19
- self.init_weights = self.model.get_weights()
20
- else:
21
- self.model.set_weights(self.init_weights)
22
- self.scaler.fit(X)
23
- return self.model.fit(X, y, epochs=self.epochs, verbose=verbose)
24
-
25
- def predict(self, X):
26
- X = self.scaler.transform(X)
27
- if self.flatten_output:
28
- return self.model.predict(X).flatten()
29
- else:
30
- return self.model.predict(X)
31
-
32
-
33
- # This models are all tuned for the corrgroups60 dataset
34
-
35
- def corrgroups60__lasso():
36
- """ Lasso Regression
37
- """
38
- return sklearn.linear_model.Lasso(alpha=0.1)
39
-
40
- def corrgroups60__ridge():
41
- """ Ridge Regression
42
- """
43
- return sklearn.linear_model.Ridge(alpha=1.0)
44
-
45
- def corrgroups60__decision_tree():
46
- """ Decision Tree
47
- """
48
-
49
- # max_depth was chosen to minimise test error
50
- return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=6)
51
-
52
- def corrgroups60__random_forest():
53
- """ Random Forest
54
- """
55
- return sklearn.ensemble.RandomForestRegressor(100, random_state=0)
56
-
57
- def corrgroups60__gbm():
58
- """ Gradient Boosted Trees
59
- """
60
- import xgboost
61
-
62
- # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split
63
- return xgboost.XGBRegressor(max_depth=6, n_estimators=50, learning_rate=0.1, n_jobs=8, random_state=0)
64
-
65
- def corrgroups60__ffnn():
66
- """ 4-Layer Neural Network
67
- """
68
- from tensorflow.keras.layers import Dense
69
- from tensorflow.keras.models import Sequential
70
-
71
- model = Sequential()
72
- model.add(Dense(32, activation='relu', input_dim=60))
73
- model.add(Dense(20, activation='relu'))
74
- model.add(Dense(20, activation='relu'))
75
- model.add(Dense(1))
76
-
77
- model.compile(optimizer='adam',
78
- loss='mean_squared_error',
79
- metrics=['mean_squared_error'])
80
-
81
- return KerasWrap(model, 30, flatten_output=True)
82
-
83
-
84
- def independentlinear60__lasso():
85
- """ Lasso Regression
86
- """
87
- return sklearn.linear_model.Lasso(alpha=0.1)
88
-
89
- def independentlinear60__ridge():
90
- """ Ridge Regression
91
- """
92
- return sklearn.linear_model.Ridge(alpha=1.0)
93
-
94
- def independentlinear60__decision_tree():
95
- """ Decision Tree
96
- """
97
-
98
- # max_depth was chosen to minimise test error
99
- return sklearn.tree.DecisionTreeRegressor(random_state=0, max_depth=4)
100
-
101
- def independentlinear60__random_forest():
102
- """ Random Forest
103
- """
104
- return sklearn.ensemble.RandomForestRegressor(100, random_state=0)
105
-
106
- def independentlinear60__gbm():
107
- """ Gradient Boosted Trees
108
- """
109
- import xgboost
110
-
111
- # max_depth and learning_rate were fixed then n_estimators was chosen using a train/test split
112
- return xgboost.XGBRegressor(max_depth=6, n_estimators=100, learning_rate=0.1, n_jobs=8, random_state=0)
113
-
114
- def independentlinear60__ffnn():
115
- """ 4-Layer Neural Network
116
- """
117
- from tensorflow.keras.layers import Dense
118
- from tensorflow.keras.models import Sequential
119
-
120
- model = Sequential()
121
- model.add(Dense(32, activation='relu', input_dim=60))
122
- model.add(Dense(20, activation='relu'))
123
- model.add(Dense(20, activation='relu'))
124
- model.add(Dense(1))
125
-
126
- model.compile(optimizer='adam',
127
- loss='mean_squared_error',
128
- metrics=['mean_squared_error'])
129
-
130
- return KerasWrap(model, 30, flatten_output=True)
131
-
132
-
133
- def cric__lasso():
134
- """ Lasso Regression
135
- """
136
- model = sklearn.linear_model.LogisticRegression(penalty="l1", C=0.002)
137
-
138
- # we want to explain the raw probability outputs of the trees
139
- model.predict = lambda X: model.predict_proba(X)[:,1]
140
-
141
- return model
142
-
143
- def cric__ridge():
144
- """ Ridge Regression
145
- """
146
- model = sklearn.linear_model.LogisticRegression(penalty="l2")
147
-
148
- # we want to explain the raw probability outputs of the trees
149
- model.predict = lambda X: model.predict_proba(X)[:,1]
150
-
151
- return model
152
-
153
- def cric__decision_tree():
154
- """ Decision Tree
155
- """
156
- model = sklearn.tree.DecisionTreeClassifier(random_state=0, max_depth=4)
157
-
158
- # we want to explain the raw probability outputs of the trees
159
- model.predict = lambda X: model.predict_proba(X)[:,1]
160
-
161
- return model
162
-
163
- def cric__random_forest():
164
- """ Random Forest
165
- """
166
- model = sklearn.ensemble.RandomForestClassifier(100, random_state=0)
167
-
168
- # we want to explain the raw probability outputs of the trees
169
- model.predict = lambda X: model.predict_proba(X)[:,1]
170
-
171
- return model
172
-
173
- def cric__gbm():
174
- """ Gradient Boosted Trees
175
- """
176
- import xgboost
177
-
178
- # max_depth and subsample match the params used for the full cric data in the paper
179
- # learning_rate was set a bit higher to allow for faster runtimes
180
- # n_estimators was chosen based on a train/test split of the data
181
- model = xgboost.XGBClassifier(max_depth=5, n_estimators=400, learning_rate=0.01, subsample=0.2, n_jobs=8, random_state=0)
182
-
183
- # we want to explain the margin, not the transformed probability outputs
184
- model.__orig_predict = model.predict
185
- model.predict = lambda X: model.__orig_predict(X, output_margin=True)
186
-
187
- return model
188
-
189
- def cric__ffnn():
190
- """ 4-Layer Neural Network
191
- """
192
- from tensorflow.keras.layers import Dense, Dropout
193
- from tensorflow.keras.models import Sequential
194
-
195
- model = Sequential()
196
- model.add(Dense(10, activation='relu', input_dim=336))
197
- model.add(Dropout(0.5))
198
- model.add(Dense(10, activation='relu'))
199
- model.add(Dropout(0.5))
200
- model.add(Dense(1, activation='sigmoid'))
201
-
202
- model.compile(optimizer='adam',
203
- loss='binary_crossentropy',
204
- metrics=['accuracy'])
205
-
206
- return KerasWrap(model, 30, flatten_output=True)
207
-
208
-
209
- def human__decision_tree():
210
- """ Decision Tree
211
- """
212
-
213
- # build data
214
- N = 1000000
215
- M = 3
216
- X = np.zeros((N,M))
217
- X.shape
218
- y = np.zeros(N)
219
- X[0, 0] = 1
220
- y[0] = 8
221
- X[1, 1] = 1
222
- y[1] = 8
223
- X[2, 0:2] = 1
224
- y[2] = 4
225
-
226
- # fit model
227
- xor_model = sklearn.tree.DecisionTreeRegressor(max_depth=2)
228
- xor_model.fit(X, y)
229
-
230
- return xor_model
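
Several of the cric__* factories above attach a lambda to model.predict so that the benchmark explains the positive-class probability (or, for cric__gbm, the raw margin) instead of the hard 0/1 label. A minimal sketch of the same pattern on a toy classifier (the data and model are stand-ins, not the CRIC setup):

import numpy as np
import sklearn.linear_model

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 4))
y = (X[:, 0] + X[:, 1] > 0).astype(int)

model = sklearn.linear_model.LogisticRegression().fit(X, y)

# explain P(y = 1) rather than the 0/1 prediction, as the cric__* factories do
model.predict = lambda X: model.predict_proba(X)[:, 1]

print(model.predict(X[:3]))   # three probabilities in [0, 1]

Because the override lives on the instance it also survives a later fit() call, which is why the factories above can attach the lambda before the benchmark trains the model: anything that calls model.predict afterwards sees the probability output.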
lib/shap/benchmark/plots.py DELETED
@@ -1,566 +0,0 @@
1
- import base64
2
- import io
3
- import os
4
-
5
- import numpy as np
6
- import sklearn
7
- from matplotlib.colors import LinearSegmentedColormap
8
-
9
- from .. import __version__
10
- from ..plots import colors
11
- from . import methods, metrics, models
12
- from .experiments import run_experiments
13
-
14
- try:
15
- import matplotlib
16
- import matplotlib.pyplot as pl
17
- from IPython.display import HTML
18
- except ImportError:
19
- pass
20
-
21
-
22
- metadata = {
23
- # "runtime": {
24
- # "title": "Runtime",
25
- # "sort_order": 1
26
- # },
27
- # "local_accuracy": {
28
- # "title": "Local Accuracy",
29
- # "sort_order": 2
30
- # },
31
- # "consistency_guarantees": {
32
- # "title": "Consistency Guarantees",
33
- # "sort_order": 3
34
- # },
35
- # "keep_positive_mask": {
36
- # "title": "Keep Positive (mask)",
37
- # "xlabel": "Max fraction of features kept",
38
- # "ylabel": "Mean model output",
39
- # "sort_order": 4
40
- # },
41
- # "keep_negative_mask": {
42
- # "title": "Keep Negative (mask)",
43
- # "xlabel": "Max fraction of features kept",
44
- # "ylabel": "Negative mean model output",
45
- # "sort_order": 5
46
- # },
47
- # "keep_absolute_mask__r2": {
48
- # "title": "Keep Absolute (mask)",
49
- # "xlabel": "Max fraction of features kept",
50
- # "ylabel": "R^2",
51
- # "sort_order": 6
52
- # },
53
- # "keep_absolute_mask__roc_auc": {
54
- # "title": "Keep Absolute (mask)",
55
- # "xlabel": "Max fraction of features kept",
56
- # "ylabel": "ROC AUC",
57
- # "sort_order": 6
58
- # },
59
- # "remove_positive_mask": {
60
- # "title": "Remove Positive (mask)",
61
- # "xlabel": "Max fraction of features removed",
62
- # "ylabel": "Negative mean model output",
63
- # "sort_order": 7
64
- # },
65
- # "remove_negative_mask": {
66
- # "title": "Remove Negative (mask)",
67
- # "xlabel": "Max fraction of features removed",
68
- # "ylabel": "Mean model output",
69
- # "sort_order": 8
70
- # },
71
- # "remove_absolute_mask__r2": {
72
- # "title": "Remove Absolute (mask)",
73
- # "xlabel": "Max fraction of features removed",
74
- # "ylabel": "1 - R^2",
75
- # "sort_order": 9
76
- # },
77
- # "remove_absolute_mask__roc_auc": {
78
- # "title": "Remove Absolute (mask)",
79
- # "xlabel": "Max fraction of features removed",
80
- # "ylabel": "1 - ROC AUC",
81
- # "sort_order": 9
82
- # },
83
- # "keep_positive_resample": {
84
- # "title": "Keep Positive (resample)",
85
- # "xlabel": "Max fraction of features kept",
86
- # "ylabel": "Mean model output",
87
- # "sort_order": 10
88
- # },
89
- # "keep_negative_resample": {
90
- # "title": "Keep Negative (resample)",
91
- # "xlabel": "Max fraction of features kept",
92
- # "ylabel": "Negative mean model output",
93
- # "sort_order": 11
94
- # },
95
- # "keep_absolute_resample__r2": {
96
- # "title": "Keep Absolute (resample)",
97
- # "xlabel": "Max fraction of features kept",
98
- # "ylabel": "R^2",
99
- # "sort_order": 12
100
- # },
101
- # "keep_absolute_resample__roc_auc": {
102
- # "title": "Keep Absolute (resample)",
103
- # "xlabel": "Max fraction of features kept",
104
- # "ylabel": "ROC AUC",
105
- # "sort_order": 12
106
- # },
107
- # "remove_positive_resample": {
108
- # "title": "Remove Positive (resample)",
109
- # "xlabel": "Max fraction of features removed",
110
- # "ylabel": "Negative mean model output",
111
- # "sort_order": 13
112
- # },
113
- # "remove_negative_resample": {
114
- # "title": "Remove Negative (resample)",
115
- # "xlabel": "Max fraction of features removed",
116
- # "ylabel": "Mean model output",
117
- # "sort_order": 14
118
- # },
119
- # "remove_absolute_resample__r2": {
120
- # "title": "Remove Absolute (resample)",
121
- # "xlabel": "Max fraction of features removed",
122
- # "ylabel": "1 - R^2",
123
- # "sort_order": 15
124
- # },
125
- # "remove_absolute_resample__roc_auc": {
126
- # "title": "Remove Absolute (resample)",
127
- # "xlabel": "Max fraction of features removed",
128
- # "ylabel": "1 - ROC AUC",
129
- # "sort_order": 15
130
- # },
131
- # "remove_positive_retrain": {
132
- # "title": "Remove Positive (retrain)",
133
- # "xlabel": "Max fraction of features removed",
134
- # "ylabel": "Negative mean model output",
135
- # "sort_order": 11
136
- # },
137
- # "remove_negative_retrain": {
138
- # "title": "Remove Negative (retrain)",
139
- # "xlabel": "Max fraction of features removed",
140
- # "ylabel": "Mean model output",
141
- # "sort_order": 12
142
- # },
143
- # "keep_positive_retrain": {
144
- # "title": "Keep Positive (retrain)",
145
- # "xlabel": "Max fraction of features kept",
146
- # "ylabel": "Mean model output",
147
- # "sort_order": 6
148
- # },
149
- # "keep_negative_retrain": {
150
- # "title": "Keep Negative (retrain)",
151
- # "xlabel": "Max fraction of features kept",
152
- # "ylabel": "Negative mean model output",
153
- # "sort_order": 7
154
- # },
155
- # "batch_remove_absolute__r2": {
156
- # "title": "Batch Remove Absolute",
157
- # "xlabel": "Fraction of features removed",
158
- # "ylabel": "1 - R^2",
159
- # "sort_order": 13
160
- # },
161
- # "batch_keep_absolute__r2": {
162
- # "title": "Batch Keep Absolute",
163
- # "xlabel": "Fraction of features kept",
164
- # "ylabel": "R^2",
165
- # "sort_order": 8
166
- # },
167
- # "batch_remove_absolute__roc_auc": {
168
- # "title": "Batch Remove Absolute",
169
- # "xlabel": "Fraction of features removed",
170
- # "ylabel": "1 - ROC AUC",
171
- # "sort_order": 13
172
- # },
173
- # "batch_keep_absolute__roc_auc": {
174
- # "title": "Batch Keep Absolute",
175
- # "xlabel": "Fraction of features kept",
176
- # "ylabel": "ROC AUC",
177
- # "sort_order": 8
178
- # },
179
-
180
- # "linear_shap_corr": {
181
- # "title": "Linear SHAP (corr)"
182
- # },
183
- # "linear_shap_ind": {
184
- # "title": "Linear SHAP (ind)"
185
- # },
186
- # "coef": {
187
- # "title": "Coefficients"
188
- # },
189
- # "random": {
190
- # "title": "Random"
191
- # },
192
- # "kernel_shap_1000_meanref": {
193
- # "title": "Kernel SHAP 1000 mean ref."
194
- # },
195
- # "sampling_shap_1000": {
196
- # "title": "Sampling SHAP 1000"
197
- # },
198
- # "tree_shap_tree_path_dependent": {
199
- # "title": "Tree SHAP"
200
- # },
201
- # "saabas": {
202
- # "title": "Saabas"
203
- # },
204
- # "tree_gain": {
205
- # "title": "Gain/Gini Importance"
206
- # },
207
- # "mean_abs_tree_shap": {
208
- # "title": "mean(|Tree SHAP|)"
209
- # },
210
- # "lasso_regression": {
211
- # "title": "Lasso Regression"
212
- # },
213
- # "ridge_regression": {
214
- # "title": "Ridge Regression"
215
- # },
216
- # "gbm_regression": {
217
- # "title": "Gradient Boosting Regression"
218
- # }
219
- }
220
-
221
- benchmark_color_map = {
222
- "tree_shap": "#1E88E5",
223
- "deep_shap": "#1E88E5",
224
- "linear_shap_corr": "#1E88E5",
225
- "linear_shap_ind": "#ff0d57",
226
- "coef": "#13B755",
227
- "random": "#999999",
228
- "const_random": "#666666",
229
- "kernel_shap_1000_meanref": "#7C52FF"
230
- }
231
-
232
- # negated_metrics = [
233
- # "runtime",
234
- # "remove_positive_retrain",
235
- # "remove_positive_mask",
236
- # "remove_positive_resample",
237
- # "keep_negative_retrain",
238
- # "keep_negative_mask",
239
- # "keep_negative_resample"
240
- # ]
241
-
242
- # one_minus_metrics = [
243
- # "remove_absolute_mask__r2",
244
- # "remove_absolute_mask__roc_auc",
245
- # "remove_absolute_resample__r2",
246
- # "remove_absolute_resample__roc_auc"
247
- # ]
248
-
249
- def get_method_color(method):
250
- for line in getattr(methods, method).__doc__.split("\n"):
251
- line = line.strip()
252
- if line.startswith("color = "):
253
- v = line.split("=")[1].strip()
254
- if v.startswith("red_blue_circle("):
255
- return colors.red_blue_circle(float(v[16:-1]))
256
- else:
257
- return v
258
- return "#000000"
259
-
260
- def get_method_linestyle(method):
261
- for line in getattr(methods, method).__doc__.split("\n"):
262
- line = line.strip()
263
- if line.startswith("linestyle = "):
264
- return line.split("=")[1].strip()
265
- return "solid"
266
-
267
- def get_metric_attr(metric, attr):
268
- for line in getattr(metrics, metric).__doc__.split("\n"):
269
- line = line.strip()
270
-
271
- # string
272
- prefix = attr+" = \""
273
- suffix = "\""
274
- if line.startswith(prefix) and line.endswith(suffix):
275
- return line[len(prefix):-len(suffix)]
276
-
277
- # number
278
- prefix = attr+" = "
279
- if line.startswith(prefix):
280
- return float(line[len(prefix):])
281
- return ""
282
-
283
- def plot_curve(dataset, model, metric, cmap=benchmark_color_map):
284
- experiments = run_experiments(dataset=dataset, model=model, metric=metric)
285
- pl.figure()
286
- method_arr = []
287
- for (name,(fcounts,scores)) in experiments:
288
- _,_,method,_ = name
289
- transform = get_metric_attr(metric, "transform")
290
- if transform == "negate":
291
- scores = -scores
292
- elif transform == "one_minus":
293
- scores = 1 - scores
294
- auc = sklearn.metrics.auc(fcounts, scores) / fcounts[-1]
295
- method_arr.append((auc, method, scores))
296
- for (auc,method,scores) in sorted(method_arr):
297
- method_title = getattr(methods, method).__doc__.split("\n")[0].strip()
298
- label = f"{auc:6.3f} - " + method_title
299
- pl.plot(
300
- fcounts / fcounts[-1], scores, label=label,
301
- color=get_method_color(method), linewidth=2,
302
- linestyle=get_method_linestyle(method)
303
- )
304
- metric_title = getattr(metrics, metric).__doc__.split("\n")[0].strip()
305
- pl.xlabel(get_metric_attr(metric, "xlabel"))
306
- pl.ylabel(get_metric_attr(metric, "ylabel"))
307
- model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip()
308
- pl.title(metric_title + " - " + model_title)
309
- pl.gca().xaxis.set_ticks_position('bottom')
310
- pl.gca().yaxis.set_ticks_position('left')
311
- pl.gca().spines['right'].set_visible(False)
312
- pl.gca().spines['top'].set_visible(False)
313
- ahandles, alabels = pl.gca().get_legend_handles_labels()
314
- pl.legend(reversed(ahandles), reversed(alabels))
315
- return pl.gcf()
316
-
317
- def plot_human(dataset, model, metric, cmap=benchmark_color_map):
318
- experiments = run_experiments(dataset=dataset, model=model, metric=metric)
319
- pl.figure()
320
- method_arr = []
321
- for (name,(fcounts,scores)) in experiments:
322
- _,_,method,_ = name
323
- diff_sum = np.sum(np.abs(scores[1] - scores[0]))
324
- method_arr.append((diff_sum, method, scores[0], scores[1]))
325
-
326
- inds = np.arange(3) # the x locations for the groups
327
- inc_width = (1.0 / len(method_arr)) * 0.8
328
- width = inc_width * 0.9
329
- pl.bar(inds, method_arr[0][2], width, label="Human Consensus", color="black", edgecolor="white")
330
- i = 1
331
- line_style_to_hatch = {
332
- "dashed": "///",
333
- "dotted": "..."
334
- }
335
- for (diff_sum, method, _, methods_attrs) in sorted(method_arr):
336
- method_title = getattr(methods, method).__doc__.split("\n")[0].strip()
337
- label = f"{diff_sum:.2f} - " + method_title
338
- pl.bar(
339
- inds + inc_width * i, methods_attrs.flatten(), width, label=label, edgecolor="white",
340
- color=get_method_color(method), hatch=line_style_to_hatch.get(get_method_linestyle(method), None)
341
- )
342
- i += 1
343
- metric_title = getattr(metrics, metric).__doc__.split("\n")[0].strip()
344
- pl.xlabel("Features in the model")
345
- pl.ylabel("Feature attribution value")
346
- model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip()
347
- pl.title(metric_title + " - " + model_title)
348
- pl.gca().xaxis.set_ticks_position('bottom')
349
- pl.gca().yaxis.set_ticks_position('left')
350
- pl.gca().spines['right'].set_visible(False)
351
- pl.gca().spines['top'].set_visible(False)
352
- ahandles, alabels = pl.gca().get_legend_handles_labels()
353
- #pl.legend(ahandles, alabels)
354
- pl.xticks(np.array([0, 1, 2, 3]) - (inc_width + width)/2, ["", "", "", ""])
355
-
356
- pl.gca().xaxis.set_minor_locator(matplotlib.ticker.FixedLocator([0.4, 1.4, 2.4]))
357
- pl.gca().xaxis.set_minor_formatter(matplotlib.ticker.FixedFormatter(["Fever", "Cough", "Headache"]))
358
- pl.gca().tick_params(which='minor', length=0)
359
-
360
- pl.axhline(0, color="#aaaaaa", linewidth=0.5)
361
-
362
- box = pl.gca().get_position()
363
- pl.gca().set_position([
364
- box.x0, box.y0 + box.height * 0.3,
365
- box.width, box.height * 0.7
366
- ])
367
-
368
- # Put a legend below current axis
369
- pl.gca().legend(ahandles, alabels, loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=2)
370
-
371
- return pl.gcf()
372
-
373
- def _human_score_map(human_consensus, methods_attrs):
374
- """ Converts human agreement differences to numerical scores for coloring.
375
- """
376
-
377
- v = 1 - min(np.sum(np.abs(methods_attrs - human_consensus)) / (np.abs(human_consensus).sum() + 1), 1.0)
378
- return v
379
-
380
- def make_grid(scores, dataset, model, normalize=True, transform=True):
381
- color_vals = {}
382
- metric_sort_order = {}
383
- for (_,_,method,metric),(fcounts,score) in filter(lambda x: x[0][0] == dataset and x[0][1] == model, scores):
384
- metric_sort_order[metric] = get_metric_attr(metric, "sort_order")
385
- if metric not in color_vals:
386
- color_vals[metric] = {}
387
-
388
- if transform:
389
- transform_type = get_metric_attr(metric, "transform")
390
- if transform_type == "negate":
391
- score = -score
392
- elif transform_type == "one_minus":
393
- score = 1 - score
394
- elif transform_type == "negate_log":
395
- score = -np.log10(score)
396
-
397
- if fcounts is None:
398
- color_vals[metric][method] = score
399
- elif fcounts == "human":
400
- color_vals[metric][method] = _human_score_map(*score)
401
- else:
402
- auc = sklearn.metrics.auc(fcounts, score) / fcounts[-1]
403
- color_vals[metric][method] = auc
404
- # print(metric_sort_order)
405
- # col_keys = sorted(list(color_vals.keys()), key=lambda v: metric_sort_order[v])
406
- # print(col_keys)
407
- col_keys = list(color_vals.keys())
408
- row_keys = list({v for k in col_keys for v in color_vals[k].keys()})
409
-
410
- data = -28567 * np.ones((len(row_keys), len(col_keys)))
411
-
412
- for i in range(len(row_keys)):
413
- for j in range(len(col_keys)):
414
- data[i,j] = color_vals[col_keys[j]][row_keys[i]]
415
-
416
- assert np.sum(data == -28567) == 0, "There are missing data values!"
417
-
418
- if normalize:
419
- data = (data - data.min(0)) / (data.max(0) - data.min(0) + 1e-8)
420
-
421
- # sort by performance
422
- inds = np.argsort(-data.mean(1))
423
- row_keys = [row_keys[i] for i in inds]
424
- data = data[inds,:]
425
-
426
- return row_keys, col_keys, data
427
-
428
-
429
-
430
- red_blue_solid = LinearSegmentedColormap('red_blue_solid', {
431
- 'red': ((0.0, 198./255, 198./255),
432
- (1.0, 5./255, 5./255)),
433
-
434
- 'green': ((0.0, 34./255, 34./255),
435
- (1.0, 198./255, 198./255)),
436
-
437
- 'blue': ((0.0, 5./255, 5./255),
438
- (1.0, 24./255, 24./255)),
439
-
440
- 'alpha': ((0.0, 1, 1),
441
- (1.0, 1, 1))
442
- })
443
- def plot_grids(dataset, model_names, out_dir=None):
444
-
445
- if out_dir is not None:
446
- os.mkdir(out_dir)
447
-
448
- scores = []
449
- for model in model_names:
450
- scores.extend(run_experiments(dataset=dataset, model=model))
451
-
452
- prefix = "<style type='text/css'> .shap_benchmark__select:focus { outline-width: 0 }</style>"
453
- out = "" # background: rgb(30, 136, 229)
454
-
455
- # out += "<div style='font-weight: regular; font-size: 24px; text-align: center; background: #f8f8f8; color: #000; padding: 20px;'>SHAP Benchmark</div>\n"
456
- # out += "<div style='height: 1px; background: #ddd;'></div>\n"
457
- #out += "<div style='height: 7px; background-image: linear-gradient(to right, rgb(30, 136, 229), rgb(255, 13, 87));'></div>"
458
-
459
- out += "<div style='position: fixed; left: 0px; top: 0px; right: 0px; height: 230px; background: #fff;'>\n" # box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2), 0 6px 20px 0 rgba(0, 0, 0, 0.19);
460
- out += "<div style='position: absolute; bottom: 0px; left: 0px; right: 0px;' align='center'><table style='border-width: 1px; margin-right: 100px'>\n"
461
- for ind,model in enumerate(model_names):
462
- row_keys, col_keys, data = make_grid(scores, dataset, model)
463
- # print(data)
464
- # print(colors.red_blue_solid(0.))
465
- # print(colors.red_blue_solid(1.))
466
- # return
467
- for metric in col_keys:
468
- save_plot = False
469
- if metric.startswith("human_"):
470
- plot_human(dataset, model, metric)
471
- save_plot = True
472
- elif metric not in ["local_accuracy", "runtime", "consistency_guarantees"]:
473
- plot_curve(dataset, model, metric)
474
- save_plot = True
475
-
476
- if save_plot:
477
- buf = io.BytesIO()
478
- pl.gcf().set_size_inches(1200.0/175,1000.0/175)
479
- pl.savefig(buf, format='png', dpi=175)
480
- if out_dir is not None:
481
- pl.savefig(f"{out_dir}/plot_{dataset}_{model}_{metric}.pdf", format='pdf')
482
- pl.close()
483
- buf.seek(0)
484
- data_uri = base64.b64encode(buf.read()).decode('utf-8').replace('\n', '')
485
- plot_id = "plot__"+dataset+"__"+model+"__"+metric
486
- prefix += f"<div onclick='document.getElementById(\"{plot_id}\").style.display = \"none\"' style='display: none; position: fixed; z-index: 10000; left: 0px; right: 0px; top: 0px; bottom: 0px; background: rgba(255,255,255,0.9);' id='{plot_id}'>"
487
- prefix += "<img width='600' height='500' style='margin-left: auto; margin-right: auto; margin-top: 230px; box-shadow: 0 4px 8px 0 rgba(0, 0, 0, 0.2), 0 6px 20px 0 rgba(0, 0, 0, 0.19);' src='data:image/png;base64,%s'>" % data_uri
488
- prefix += "</div>"
489
-
490
- model_title = getattr(models, dataset+"__"+model).__doc__.split("\n")[0].strip()
491
-
492
- if ind == 0:
493
- out += "<tr><td style='background: #fff; width: 250px'></td></td>"
494
- for j in range(data.shape[1]):
495
- metric_title = getattr(metrics, col_keys[j]).__doc__.split("\n")[0].strip()
496
- out += "<td style='width: 40px; min-width: 40px; background: #fff; text-align: right;'><div style='margin-left: 10px; margin-bottom: -5px; white-space: nowrap; transform: rotate(-45deg); transform-origin: left top 0; width: 1.5em; margin-top: 8em'>" + metric_title + "</div></td>"
497
- out += "</tr>\n"
498
- out += "</table></div></div>\n"
499
- out += "<table style='border-width: 1px; margin-right: 100px; margin-top: 230px;'>\n"
500
- out += "<tr><td style='background: #fff'></td><td colspan='%d' style='background: #fff; font-weight: bold; text-align: center; margin-top: 10px;'>%s</td></tr>\n" % (data.shape[1], model_title)
501
- for i in range(data.shape[0]):
502
- out += "<tr>"
503
- # if i == 0:
504
- # out += "<td rowspan='%d' style='background: #fff; text-align: center; white-space: nowrap; vertical-align: middle; '><div style='font-weight: bold; transform: rotate(-90deg); transform-origin: left top 0; width: 1.5em; margin-top: 8em'>%s</div></td>" % (data.shape[0], model_name)
505
- method_title = getattr(methods, row_keys[i]).__doc__.split("\n")[0].strip()
506
- out += "<td style='background: #ffffff; text-align: right; width: 250px' title='shap.LinearExplainer(model)'>" + method_title + "</td>\n"
507
- for j in range(data.shape[1]):
508
- plot_id = "plot__"+dataset+"__"+model+"__"+col_keys[j]
509
- out += "<td onclick='document.getElementById(\"%s\").style.display = \"block\"' style='padding: 0px; padding-left: 0px; padding-right: 0px; border-left: 0px solid #999; width: 42px; min-width: 42px; height: 34px; background-color: #fff'>" % plot_id
510
- #out += "<div style='opacity: "+str(2*(max(1-data[i,j], data[i,j])-0.5))+"; background-color: rgb" + str(tuple(v*255 for v in colors.red_blue_solid(0. if data[i,j] < 0.5 else 1.)[:-1])) + "; height: "+str((30*max(1-data[i,j], data[i,j])))+"px; margin-left: auto; margin-right: auto; width:"+str((30*max(1-data[i,j], data[i,j])))+"px'></div>"
511
- out += "<div style='opacity: "+str(1)+"; background-color: rgb" + str(tuple(int(v*255) for v in colors.red_blue_no_bounds(5*(data[i,j]-0.8))[:-1])) + "; height: "+str(30*data[i,j])+"px; margin-left: auto; margin-right: auto; width:"+str(30*data[i,j])+"px'></div>"
512
- #out += "<div style='float: left; background-color: #eee; height: 10px; width: "+str((40*(1-data[i,j])))+"px'></div>"
513
- out += "</td>\n"
514
- out += "</tr>\n" #
515
-
516
- out += "<tr><td colspan='%d' style='background: #fff'></td></tr>" % (data.shape[1] + 1)
517
- out += "</table>"
518
-
519
- out += "<div style='position: fixed; left: 0px; top: 0px; right: 0px; text-align: left; padding: 20px; text-align: right'>\n"
520
- out += "<div style='float: left; font-weight: regular; font-size: 24px; color: #000;'>SHAP Benchmark <span style='font-size: 14px; color: #777777;'>v"+__version__+"</span></div>\n"
521
- # select {
522
- # margin: 50px;
523
- # width: 150px;
524
- # padding: 5px 35px 5px 5px;
525
- # font-size: 16px;
526
- # border: 1px solid #ccc;
527
- # height: 34px;
528
- # -webkit-appearance: none;
529
- # -moz-appearance: none;
530
- # appearance: none;
531
- # background: url(http://www.stackoverflow.com/favicon.ico) 96% / 15% no-repeat #eee;
532
- # }
533
- #out += "<div style='display: inline-block; margin-right: 20px; font-weight: normal; text-decoration: none; font-size: 18px; color: #000;'>Dataset:</div>\n"
534
-
535
- out += "<select id='shap_benchmark__select' onchange=\"document.location = '../' + this.value + '/index.html'\"dir='rtl' class='shap_benchmark__select' style='font-weight: normal; font-size: 20px; color: #000; padding: 10px; background: #fff; border: 1px solid #fff; -webkit-appearance: none; appearance: none;'>\n"
536
- out += "<option value='human' "+("selected" if dataset == "human" else "")+">Agreement with Human Intuition</option>\n"
537
- out += "<option value='corrgroups60' "+("selected" if dataset == "corrgroups60" else "")+">Correlated Groups 60 Dataset</option>\n"
538
- out += "<option value='independentlinear60' "+("selected" if dataset == "independentlinear60" else "")+">Independent Linear 60 Dataset</option>\n"
539
- #out += "<option>CRIC</option>\n"
540
- out += "</select>\n"
541
- #out += "<script> document.onload = function() { document.getElementById('shap_benchmark__select').value = '"+dataset+"'; }</script>"
542
- #out += "<div style='display: inline-block; margin-left: 20px; font-weight: normal; text-decoration: none; font-size: 18px; color: #000;'>CRIC</div>\n"
543
- out += "</div>\n"
544
-
545
- # output the legend
546
- out += "<table style='border-width: 0px; width: 100px; position: fixed; right: 50px; top: 200px; background: rgba(255, 255, 255, 0.9)'>\n"
547
- out += "<tr><td style='background: #fff; font-weight: normal; text-align: center'>Higher score</td></tr>\n"
548
- legend_size = 21
549
- for i in range(legend_size-9):
550
- out += "<tr>"
551
- out += "<td style='padding: 0px; padding-left: 0px; padding-right: 0px; border-left: 0px solid #999; height: 34px'>"
552
- val = (legend_size-i-1) / (legend_size-1)
553
- out += "<div style='opacity: 1; background-color: rgb" + str(tuple(int(v*255) for v in colors.red_blue_no_bounds(5*(val-0.8)))[:-1]) + "; height: "+str(30*val)+"px; margin-left: auto; margin-right: auto; width:"+str(30*val)+"px'></div>"
554
- out += "</td>"
555
- out += "</tr>\n" #
556
- out += "<tr><td style='background: #fff; font-weight: normal; text-align: center'>Lower score</td></tr>\n"
557
- out += "</table>\n"
558
-
559
- if out_dir is not None:
560
- with open(out_dir + "/index.html", "w") as f:
561
- f.write("<html><body style='margin: 0px; font-size: 16px; font-family: \"Myriad Pro\", Arial, sans-serif;'><center>")
562
- f.write(prefix)
563
- f.write(out)
564
- f.write("</center></body></html>")
565
- else:
566
- return HTML(prefix + out)
 
 
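Note on the removed plotting helpers above: get_method_color, get_method_linestyle and get_metric_attr all pull their display metadata (title, axis labels, color, sort order) out of the docstrings of the benchmark method and metric functions rather than from a config dict. A minimal sketch of that docstring-parsing pattern follows; the keep_positive_mask metric function here is hypothetical, and only the parsing logic mirrors the removed code.

def keep_positive_mask():
    """ Keep Positive (mask)
    xlabel = "Max fraction of features kept"
    ylabel = "Mean model output"
    sort_order = 4
    """

def get_attr(func, attr):
    # scan the docstring for lines of the form  attr = "text"  or  attr = number
    for line in (func.__doc__ or "").split("\n"):
        line = line.strip()
        str_prefix = attr + ' = "'
        if line.startswith(str_prefix) and line.endswith('"'):
            return line[len(str_prefix):-1]          # string attribute
        num_prefix = attr + " = "
        if line.startswith(num_prefix):
            return float(line[len(num_prefix):])     # numeric attribute
    return ""

print(get_attr(keep_positive_mask, "xlabel"))      # -> Max fraction of features kept
print(get_attr(keep_positive_mask, "sort_order"))  # -> 4.0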
lib/shap/cext/_cext.cc DELETED
@@ -1,560 +0,0 @@
1
- #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
2
-
3
- #include <Python.h>
4
- #include <numpy/arrayobject.h>
5
- #include "tree_shap.h"
6
- #include <iostream>
7
-
8
- static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args);
9
- static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args);
10
- static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args);
11
- static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args);
12
- static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args);
13
-
14
- static PyMethodDef module_methods[] = {
15
- {"dense_tree_shap", _cext_dense_tree_shap, METH_VARARGS, "C implementation of Tree SHAP for dense."},
16
- {"dense_tree_predict", _cext_dense_tree_predict, METH_VARARGS, "C implementation of tree predictions."},
17
- {"dense_tree_update_weights", _cext_dense_tree_update_weights, METH_VARARGS, "C implementation of tree node weight computations."},
18
- {"dense_tree_saabas", _cext_dense_tree_saabas, METH_VARARGS, "C implementation of Saabas (rough fast approximation to Tree SHAP)."},
19
- {"compute_expectations", _cext_compute_expectations, METH_VARARGS, "Compute expectations of internal nodes."},
20
- {NULL, NULL, 0, NULL}
21
- };
22
-
23
- #if PY_MAJOR_VERSION >= 3
24
- static struct PyModuleDef moduledef = {
25
- PyModuleDef_HEAD_INIT,
26
- "_cext",
27
- "This module provides an interface for a fast Tree SHAP implementation.",
28
- -1,
29
- module_methods,
30
- NULL,
31
- NULL,
32
- NULL,
33
- NULL
34
- };
35
- #endif
36
-
37
- #if PY_MAJOR_VERSION >= 3
38
- PyMODINIT_FUNC PyInit__cext(void)
39
- #else
40
- PyMODINIT_FUNC init_cext(void)
41
- #endif
42
- {
43
- #if PY_MAJOR_VERSION >= 3
44
- PyObject *module = PyModule_Create(&moduledef);
45
- if (!module) return NULL;
46
- #else
47
- PyObject *module = Py_InitModule("_cext", module_methods);
48
- if (!module) return;
49
- #endif
50
-
51
- /* Load `numpy` functionality. */
52
- import_array();
53
-
54
- #if PY_MAJOR_VERSION >= 3
55
- return module;
56
- #endif
57
- }
58
-
59
- static PyObject *_cext_compute_expectations(PyObject *self, PyObject *args)
60
- {
61
- PyObject *children_left_obj;
62
- PyObject *children_right_obj;
63
- PyObject *node_sample_weight_obj;
64
- PyObject *values_obj;
65
-
66
- /* Parse the input tuple */
67
- if (!PyArg_ParseTuple(
68
- args, "OOOO", &children_left_obj, &children_right_obj, &node_sample_weight_obj, &values_obj
69
- )) return NULL;
70
-
71
- /* Interpret the input objects as numpy arrays. */
72
- PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
73
- PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
74
- PyArrayObject *node_sample_weight_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weight_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
75
- PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
76
-
77
- /* If that didn't work, throw an exception. */
78
- if (children_left_array == NULL || children_right_array == NULL ||
79
- values_array == NULL || node_sample_weight_array == NULL) {
80
- Py_XDECREF(children_left_array);
81
- Py_XDECREF(children_right_array);
82
- //PyArray_ResolveWritebackIfCopy(values_array);
83
- Py_XDECREF(values_array);
84
- Py_XDECREF(node_sample_weight_array);
85
- return NULL;
86
- }
87
-
88
- TreeEnsemble tree;
89
-
90
- // number of outputs
91
- tree.num_outputs = PyArray_DIM(values_array, 1);
92
-
93
- /* Get pointers to the data as C-types. */
94
- tree.children_left = (int*)PyArray_DATA(children_left_array);
95
- tree.children_right = (int*)PyArray_DATA(children_right_array);
96
- tree.values = (tfloat*)PyArray_DATA(values_array);
97
- tree.node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weight_array);
98
-
99
- const int max_depth = compute_expectations(tree);
100
-
101
- // clean up the created python objects
102
- Py_XDECREF(children_left_array);
103
- Py_XDECREF(children_right_array);
104
- //PyArray_ResolveWritebackIfCopy(values_array);
105
- Py_XDECREF(values_array);
106
- Py_XDECREF(node_sample_weight_array);
107
-
108
- PyObject *ret = Py_BuildValue("i", max_depth);
109
- return ret;
110
- }
111
-
112
-
113
- static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
114
- {
115
- PyObject *children_left_obj;
116
- PyObject *children_right_obj;
117
- PyObject *children_default_obj;
118
- PyObject *features_obj;
119
- PyObject *thresholds_obj;
120
- PyObject *values_obj;
121
- PyObject *node_sample_weights_obj;
122
- int max_depth;
123
- PyObject *X_obj;
124
- PyObject *X_missing_obj;
125
- PyObject *y_obj;
126
- PyObject *R_obj;
127
- PyObject *R_missing_obj;
128
- int tree_limit;
129
- PyObject *out_contribs_obj;
130
- int feature_dependence;
131
- int model_output;
132
- PyObject *base_offset_obj;
133
- bool interactions;
134
-
135
- /* Parse the input tuple */
136
- if (!PyArg_ParseTuple(
137
- args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj,
138
- &features_obj, &thresholds_obj, &values_obj, &node_sample_weights_obj,
139
- &max_depth, &X_obj, &X_missing_obj, &y_obj, &R_obj, &R_missing_obj, &tree_limit, &base_offset_obj,
140
- &out_contribs_obj, &feature_dependence, &model_output, &interactions
141
- )) return NULL;
142
-
143
- /* Interpret the input objects as numpy arrays. */
144
- PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
145
- PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
146
- PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
147
- PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
148
- PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
149
- PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
150
- PyArrayObject *node_sample_weights_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weights_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
151
- PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
152
- PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
153
- PyArrayObject *y_array = NULL;
154
- if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
155
- PyArrayObject *R_array = NULL;
156
- if (R_obj != Py_None) R_array = (PyArrayObject*)PyArray_FROM_OTF(R_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
157
- PyArrayObject *R_missing_array = NULL;
158
- if (R_missing_obj != Py_None) R_missing_array = (PyArrayObject*)PyArray_FROM_OTF(R_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
159
- PyArrayObject *out_contribs_array = (PyArrayObject*)PyArray_FROM_OTF(out_contribs_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
160
- PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
161
-
162
- /* If that didn't work, throw an exception. Note that R and y are optional. */
163
- if (children_left_array == NULL || children_right_array == NULL ||
164
- children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
165
- values_array == NULL || node_sample_weights_array == NULL || X_array == NULL ||
166
- X_missing_array == NULL || out_contribs_array == NULL) {
167
- Py_XDECREF(children_left_array);
168
- Py_XDECREF(children_right_array);
169
- Py_XDECREF(children_default_array);
170
- Py_XDECREF(features_array);
171
- Py_XDECREF(thresholds_array);
172
- Py_XDECREF(values_array);
173
- Py_XDECREF(node_sample_weights_array);
174
- Py_XDECREF(X_array);
175
- Py_XDECREF(X_missing_array);
176
- if (y_array != NULL) Py_XDECREF(y_array);
177
- if (R_array != NULL) Py_XDECREF(R_array);
178
- if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
179
- //PyArray_ResolveWritebackIfCopy(out_contribs_array);
180
- Py_XDECREF(out_contribs_array);
181
- Py_XDECREF(base_offset_array);
182
- return NULL;
183
- }
184
-
185
- const unsigned num_X = PyArray_DIM(X_array, 0);
186
- const unsigned M = PyArray_DIM(X_array, 1);
187
- const unsigned max_nodes = PyArray_DIM(values_array, 1);
188
- const unsigned num_outputs = PyArray_DIM(values_array, 2);
189
- unsigned num_R = 0;
190
- if (R_array != NULL) num_R = PyArray_DIM(R_array, 0);
191
-
192
- // Get pointers to the data as C-types
193
- int *children_left = (int*)PyArray_DATA(children_left_array);
194
- int *children_right = (int*)PyArray_DATA(children_right_array);
195
- int *children_default = (int*)PyArray_DATA(children_default_array);
196
- int *features = (int*)PyArray_DATA(features_array);
197
- tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
198
- tfloat *values = (tfloat*)PyArray_DATA(values_array);
199
- tfloat *node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weights_array);
200
- tfloat *X = (tfloat*)PyArray_DATA(X_array);
201
- bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
202
- tfloat *y = NULL;
203
- if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
204
- tfloat *R = NULL;
205
- if (R_array != NULL) R = (tfloat*)PyArray_DATA(R_array);
206
- bool *R_missing = NULL;
207
- if (R_missing_array != NULL) R_missing = (bool*)PyArray_DATA(R_missing_array);
208
- tfloat *out_contribs = (tfloat*)PyArray_DATA(out_contribs_array);
209
- tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
210
-
211
- // these are just wrapper objects for all the pointers and numbers associated with
212
- // the ensemble tree model and the dataset we are explaining
213
- TreeEnsemble trees = TreeEnsemble(
214
- children_left, children_right, children_default, features, thresholds, values,
215
- node_sample_weights, max_depth, tree_limit, base_offset,
216
- max_nodes, num_outputs
217
- );
218
- ExplanationDataset data = ExplanationDataset(X, X_missing, y, R, R_missing, num_X, M, num_R);
219
-
220
- dense_tree_shap(trees, data, out_contribs, feature_dependence, model_output, interactions);
221
-
222
- // retrieve return value before python cleanup of objects
223
- tfloat ret_value = (double)values[0];
224
-
225
- // clean up the created python objects
226
- Py_XDECREF(children_left_array);
227
- Py_XDECREF(children_right_array);
228
- Py_XDECREF(children_default_array);
229
- Py_XDECREF(features_array);
230
- Py_XDECREF(thresholds_array);
231
- Py_XDECREF(values_array);
232
- Py_XDECREF(node_sample_weights_array);
233
- Py_XDECREF(X_array);
234
- Py_XDECREF(X_missing_array);
235
- if (y_array != NULL) Py_XDECREF(y_array);
236
- if (R_array != NULL) Py_XDECREF(R_array);
237
- if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
238
- //PyArray_ResolveWritebackIfCopy(out_contribs_array);
239
- Py_XDECREF(out_contribs_array);
240
- Py_XDECREF(base_offset_array);
241
-
242
- /* Build the output tuple */
243
- PyObject *ret = Py_BuildValue("d", ret_value);
244
- return ret;
245
- }
246
-
247
-
248
- static PyObject *_cext_dense_tree_predict(PyObject *self, PyObject *args)
249
- {
250
- PyObject *children_left_obj;
251
- PyObject *children_right_obj;
252
- PyObject *children_default_obj;
253
- PyObject *features_obj;
254
- PyObject *thresholds_obj;
255
- PyObject *values_obj;
256
- int max_depth;
257
- int tree_limit;
258
- PyObject *base_offset_obj;
259
- int model_output;
260
- PyObject *X_obj;
261
- PyObject *X_missing_obj;
262
- PyObject *y_obj;
263
- PyObject *out_pred_obj;
264
-
265
- /* Parse the input tuple */
266
- if (!PyArg_ParseTuple(
267
- args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
268
- &features_obj, &thresholds_obj, &values_obj, &max_depth, &tree_limit, &base_offset_obj, &model_output,
269
- &X_obj, &X_missing_obj, &y_obj, &out_pred_obj
270
- )) return NULL;
271
-
272
- /* Interpret the input objects as numpy arrays. */
273
- PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
274
- PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
275
- PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
276
- PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
277
- PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
278
- PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
279
- PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
280
- PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
281
- PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
282
- PyArrayObject *y_array = NULL;
283
- if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
284
- PyArrayObject *out_pred_array = (PyArrayObject*)PyArray_FROM_OTF(out_pred_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
285
-
286
- /* If that didn't work, throw an exception. Note that R and y are optional. */
287
- if (children_left_array == NULL || children_right_array == NULL ||
288
- children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
289
- values_array == NULL || X_array == NULL ||
290
- X_missing_array == NULL || out_pred_array == NULL) {
291
- Py_XDECREF(children_left_array);
292
- Py_XDECREF(children_right_array);
293
- Py_XDECREF(children_default_array);
294
- Py_XDECREF(features_array);
295
- Py_XDECREF(thresholds_array);
296
- Py_XDECREF(values_array);
297
- Py_XDECREF(base_offset_array);
298
- Py_XDECREF(X_array);
299
- Py_XDECREF(X_missing_array);
300
- if (y_array != NULL) Py_XDECREF(y_array);
301
- //PyArray_ResolveWritebackIfCopy(out_pred_array);
302
- Py_XDECREF(out_pred_array);
303
- return NULL;
304
- }
305
-
306
- const unsigned num_X = PyArray_DIM(X_array, 0);
307
- const unsigned M = PyArray_DIM(X_array, 1);
308
- const unsigned max_nodes = PyArray_DIM(values_array, 1);
309
- const unsigned num_outputs = PyArray_DIM(values_array, 2);
310
-
311
- const unsigned num_offsets = PyArray_DIM(base_offset_array, 0);
312
- if (num_offsets != num_outputs) {
313
- std::cerr << "The passed base_offset array does not have the same number of outputs as the values array: " << num_offsets << " vs. " << num_outputs << std::endl;
314
- return NULL;
315
- }
316
-
317
- // Get pointers to the data as C-types
318
- int *children_left = (int*)PyArray_DATA(children_left_array);
319
- int *children_right = (int*)PyArray_DATA(children_right_array);
320
- int *children_default = (int*)PyArray_DATA(children_default_array);
321
- int *features = (int*)PyArray_DATA(features_array);
322
- tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
323
- tfloat *values = (tfloat*)PyArray_DATA(values_array);
324
- tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
325
- tfloat *X = (tfloat*)PyArray_DATA(X_array);
326
- bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
327
- tfloat *y = NULL;
328
- if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
329
- tfloat *out_pred = (tfloat*)PyArray_DATA(out_pred_array);
330
-
331
- // these are just wrapper objects for all the pointers and numbers associated with
332
- // the ensemble tree model and the dataset we are explaining
333
- TreeEnsemble trees = TreeEnsemble(
334
- children_left, children_right, children_default, features, thresholds, values,
335
- NULL, max_depth, tree_limit, base_offset,
336
- max_nodes, num_outputs
337
- );
338
- ExplanationDataset data = ExplanationDataset(X, X_missing, y, NULL, NULL, num_X, M, 0);
339
-
340
- dense_tree_predict(out_pred, trees, data, model_output);
341
-
342
- // clean up the created python objects
343
- Py_XDECREF(children_left_array);
344
- Py_XDECREF(children_right_array);
345
- Py_XDECREF(children_default_array);
346
- Py_XDECREF(features_array);
347
- Py_XDECREF(thresholds_array);
348
- Py_XDECREF(values_array);
349
- Py_XDECREF(base_offset_array);
350
- Py_XDECREF(X_array);
351
- Py_XDECREF(X_missing_array);
352
- if (y_array != NULL) Py_XDECREF(y_array);
353
- //PyArray_ResolveWritebackIfCopy(out_pred_array);
354
- Py_XDECREF(out_pred_array);
355
-
356
- /* Build the output tuple */
357
- PyObject *ret = Py_BuildValue("d", (double)values[0]);
358
- return ret;
359
- }
360
-
361
-
362
- static PyObject *_cext_dense_tree_update_weights(PyObject *self, PyObject *args)
363
- {
364
- PyObject *children_left_obj;
365
- PyObject *children_right_obj;
366
- PyObject *children_default_obj;
367
- PyObject *features_obj;
368
- PyObject *thresholds_obj;
369
- PyObject *values_obj;
370
- int tree_limit;
371
- PyObject *node_sample_weight_obj;
372
- PyObject *X_obj;
373
- PyObject *X_missing_obj;
374
-
375
- /* Parse the input tuple */
376
- if (!PyArg_ParseTuple(
377
- args, "OOOOOOiOOO", &children_left_obj, &children_right_obj, &children_default_obj,
378
- &features_obj, &thresholds_obj, &values_obj, &tree_limit, &node_sample_weight_obj, &X_obj, &X_missing_obj
379
- )) return NULL;
380
-
381
- /* Interpret the input objects as numpy arrays. */
382
- PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
383
- PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
384
- PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
385
- PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
386
- PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
387
- PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
388
- PyArrayObject *node_sample_weight_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weight_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
389
- PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
390
- PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
391
-
392
- /* If that didn't work, throw an exception. */
393
- if (children_left_array == NULL || children_right_array == NULL ||
394
- children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
395
- values_array == NULL || node_sample_weight_array == NULL || X_array == NULL ||
396
- X_missing_array == NULL) {
397
- Py_XDECREF(children_left_array);
398
- Py_XDECREF(children_right_array);
399
- Py_XDECREF(children_default_array);
400
- Py_XDECREF(features_array);
401
- Py_XDECREF(thresholds_array);
402
- Py_XDECREF(values_array);
403
- //PyArray_ResolveWritebackIfCopy(node_sample_weight_array);
404
- Py_XDECREF(node_sample_weight_array);
405
- Py_XDECREF(X_array);
406
- Py_XDECREF(X_missing_array);
407
- std::cerr << "Found a NULL input array in _cext_dense_tree_update_weights!\n";
408
- return NULL;
409
- }
410
-
411
- const unsigned num_X = PyArray_DIM(X_array, 0);
412
- const unsigned M = PyArray_DIM(X_array, 1);
413
- const unsigned max_nodes = PyArray_DIM(values_array, 1);
414
-
415
- // Get pointers to the data as C-types
416
- int *children_left = (int*)PyArray_DATA(children_left_array);
417
- int *children_right = (int*)PyArray_DATA(children_right_array);
418
- int *children_default = (int*)PyArray_DATA(children_default_array);
419
- int *features = (int*)PyArray_DATA(features_array);
420
- tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
421
- tfloat *values = (tfloat*)PyArray_DATA(values_array);
422
- tfloat *node_sample_weight = (tfloat*)PyArray_DATA(node_sample_weight_array);
423
- tfloat *X = (tfloat*)PyArray_DATA(X_array);
424
- bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
425
-
426
- // these are just wrapper objects for all the pointers and numbers associated with
427
- // the ensemble tree model and the dataset we are explaining
428
- TreeEnsemble trees = TreeEnsemble(
429
- children_left, children_right, children_default, features, thresholds, values,
430
- node_sample_weight, 0, tree_limit, 0, max_nodes, 0
431
- );
432
- ExplanationDataset data = ExplanationDataset(X, X_missing, NULL, NULL, NULL, num_X, M, 0);
433
-
434
- dense_tree_update_weights(trees, data);
435
-
436
- // clean up the created python objects
437
- Py_XDECREF(children_left_array);
438
- Py_XDECREF(children_right_array);
439
- Py_XDECREF(children_default_array);
440
- Py_XDECREF(features_array);
441
- Py_XDECREF(thresholds_array);
442
- Py_XDECREF(values_array);
443
- // PyArray_ResolveWritebackIfCopy(node_sample_weight_array);
444
- Py_XDECREF(node_sample_weight_array);
445
- Py_XDECREF(X_array);
446
- Py_XDECREF(X_missing_array);
447
-
448
- /* Build the output tuple */
449
- PyObject *ret = Py_BuildValue("d", 1);
450
- return ret;
451
- }
452
-
453
-
454
- static PyObject *_cext_dense_tree_saabas(PyObject *self, PyObject *args)
455
- {
456
- PyObject *children_left_obj;
457
- PyObject *children_right_obj;
458
- PyObject *children_default_obj;
459
- PyObject *features_obj;
460
- PyObject *thresholds_obj;
461
- PyObject *values_obj;
462
- int max_depth;
463
- int tree_limit;
464
- PyObject *base_offset_obj;
465
- int model_output;
466
- PyObject *X_obj;
467
- PyObject *X_missing_obj;
468
- PyObject *y_obj;
469
- PyObject *out_pred_obj;
470
-
471
-
472
- /* Parse the input tuple */
473
- if (!PyArg_ParseTuple(
474
- args, "OOOOOOiiOiOOOO", &children_left_obj, &children_right_obj, &children_default_obj,
475
- &features_obj, &thresholds_obj, &values_obj, &max_depth, &tree_limit, &base_offset_obj, &model_output,
476
- &X_obj, &X_missing_obj, &y_obj, &out_pred_obj
477
- )) return NULL;
478
-
479
- /* Interpret the input objects as numpy arrays. */
480
- PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
481
- PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
482
- PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
483
- PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
484
- PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
485
- PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
486
- PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
487
- PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
488
- PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
489
- PyArrayObject *y_array = NULL;
490
- if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
491
- PyArrayObject *out_pred_array = (PyArrayObject*)PyArray_FROM_OTF(out_pred_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
492
-
493
- /* If that didn't work, throw an exception. Note that R and y are optional. */
494
- if (children_left_array == NULL || children_right_array == NULL ||
495
- children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
496
- values_array == NULL || X_array == NULL ||
497
- X_missing_array == NULL || out_pred_array == NULL) {
498
- Py_XDECREF(children_left_array);
499
- Py_XDECREF(children_right_array);
500
- Py_XDECREF(children_default_array);
501
- Py_XDECREF(features_array);
502
- Py_XDECREF(thresholds_array);
503
- Py_XDECREF(values_array);
504
- Py_XDECREF(base_offset_array);
505
- Py_XDECREF(X_array);
506
- Py_XDECREF(X_missing_array);
507
- if (y_array != NULL) Py_XDECREF(y_array);
508
- //PyArray_ResolveWritebackIfCopy(out_pred_array);
509
- Py_XDECREF(out_pred_array);
510
- return NULL;
511
- }
512
-
513
- const unsigned num_X = PyArray_DIM(X_array, 0);
514
- const unsigned M = PyArray_DIM(X_array, 1);
515
- const unsigned max_nodes = PyArray_DIM(values_array, 1);
516
- const unsigned num_outputs = PyArray_DIM(values_array, 2);
517
-
518
- // Get pointers to the data as C-types
519
- int *children_left = (int*)PyArray_DATA(children_left_array);
520
- int *children_right = (int*)PyArray_DATA(children_right_array);
521
- int *children_default = (int*)PyArray_DATA(children_default_array);
522
- int *features = (int*)PyArray_DATA(features_array);
523
- tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
524
- tfloat *values = (tfloat*)PyArray_DATA(values_array);
525
- tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
526
- tfloat *X = (tfloat*)PyArray_DATA(X_array);
527
- bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
528
- tfloat *y = NULL;
529
- if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
530
- tfloat *out_pred = (tfloat*)PyArray_DATA(out_pred_array);
531
-
532
- // these are just wrapper objects for all the pointers and numbers associated with
533
- // the ensemble tree model and the dataset we are explaining
534
- TreeEnsemble trees = TreeEnsemble(
535
- children_left, children_right, children_default, features, thresholds, values,
536
- NULL, max_depth, tree_limit, base_offset,
537
- max_nodes, num_outputs
538
- );
539
- ExplanationDataset data = ExplanationDataset(X, X_missing, y, NULL, NULL, num_X, M, 0);
540
-
541
- dense_tree_saabas(out_pred, trees, data);
542
-
543
- // clean up the created python objects
544
- Py_XDECREF(children_left_array);
545
- Py_XDECREF(children_right_array);
546
- Py_XDECREF(children_default_array);
547
- Py_XDECREF(features_array);
548
- Py_XDECREF(thresholds_array);
549
- Py_XDECREF(values_array);
550
- Py_XDECREF(base_offset_array);
551
- Py_XDECREF(X_array);
552
- Py_XDECREF(X_missing_array);
553
- if (y_array != NULL) Py_XDECREF(y_array);
554
- //PyArray_ResolveWritebackIfCopy(out_pred_array);
555
- Py_XDECREF(out_pred_array);
556
-
557
- /* Build the output tuple */
558
- PyObject *ret = Py_BuildValue("d", (double)values[0]);
559
- return ret;
560
- }
 
 
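The deleted _cext.cc above registers dense_tree_saabas as a "rough fast approximation to Tree SHAP". For reference, the Saabas idea it implements in C can be sketched in pure Python roughly as follows; the dict-based tree layout is illustrative only, not the array layout the extension actually consumes.

def saabas_contributions(tree, x):
    """Walk x's decision path and credit each split feature with the change in
    expected node value; returns (root expectation, {feature: contribution})."""
    contribs = {}
    node = 0
    while tree["children_left"][node] != -1:              # -1 marks a leaf
        feat = tree["features"][node]
        child = (tree["children_left"][node]
                 if x[feat] <= tree["thresholds"][node]
                 else tree["children_right"][node])
        contribs[feat] = contribs.get(feat, 0.0) + tree["values"][child] - tree["values"][node]
        node = child
    return tree["values"][0], contribs

# toy tree: one split on feature 0 at threshold 0.5, equally weighted leaves
tree = {
    "children_left":  [1, -1, -1],
    "children_right": [2, -1, -1],
    "features":       [0,  0,  0],
    "thresholds":     [0.5, 0.0, 0.0],
    "values":         [0.5, 0.25, 0.75],   # expected model output at each node
}
print(saabas_contributions(tree, [0.9]))   # (0.5, {0: 0.25}); 0.5 + 0.25 equals the leaf value 0.75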
lib/shap/cext/_cext_gpu.cc DELETED
@@ -1,187 +0,0 @@
1
- #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
2
-
3
- #include <Python.h>
4
- #include <numpy/arrayobject.h>
5
- #include "tree_shap.h"
6
- #include <iostream>
7
-
8
- static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args);
9
-
10
- static PyMethodDef module_methods[] = {
11
- {"dense_tree_shap", _cext_dense_tree_shap, METH_VARARGS, "C implementation of Tree SHAP for dense."},
12
- {NULL, NULL, 0, NULL}
13
- };
14
-
15
- #if PY_MAJOR_VERSION >= 3
16
- static struct PyModuleDef moduledef = {
17
- PyModuleDef_HEAD_INIT,
18
- "_cext_gpu",
19
- "This module provides an interface for a fast Tree SHAP implementation.",
20
- -1,
21
- module_methods,
22
- NULL,
23
- NULL,
24
- NULL,
25
- NULL
26
- };
27
- #endif
28
-
29
- #if PY_MAJOR_VERSION >= 3
30
- PyMODINIT_FUNC PyInit__cext_gpu(void)
31
- #else
32
- PyMODINIT_FUNC init_cext(void)
33
- #endif
34
- {
35
- #if PY_MAJOR_VERSION >= 3
36
- PyObject *module = PyModule_Create(&moduledef);
37
- if (!module) return NULL;
38
- #else
39
- PyObject *module = Py_InitModule("_cext", module_methods);
40
- if (!module) return;
41
- #endif
42
-
43
- /* Load `numpy` functionality. */
44
- import_array();
45
-
46
- #if PY_MAJOR_VERSION >= 3
47
- return module;
48
- #endif
49
- }
50
-
51
- void dense_tree_shap_gpu(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
52
- const int feature_dependence, unsigned model_transform, bool interactions);
53
-
54
- static PyObject *_cext_dense_tree_shap(PyObject *self, PyObject *args)
55
- {
56
- PyObject *children_left_obj;
57
- PyObject *children_right_obj;
58
- PyObject *children_default_obj;
59
- PyObject *features_obj;
60
- PyObject *thresholds_obj;
61
- PyObject *values_obj;
62
- PyObject *node_sample_weights_obj;
63
- int max_depth;
64
- PyObject *X_obj;
65
- PyObject *X_missing_obj;
66
- PyObject *y_obj;
67
- PyObject *R_obj;
68
- PyObject *R_missing_obj;
69
- int tree_limit;
70
- PyObject *out_contribs_obj;
71
- int feature_dependence;
72
- int model_output;
73
- PyObject *base_offset_obj;
74
- bool interactions;
75
-
76
- /* Parse the input tuple */
77
- if (!PyArg_ParseTuple(
78
- args, "OOOOOOOiOOOOOiOOiib", &children_left_obj, &children_right_obj, &children_default_obj,
79
- &features_obj, &thresholds_obj, &values_obj, &node_sample_weights_obj,
80
- &max_depth, &X_obj, &X_missing_obj, &y_obj, &R_obj, &R_missing_obj, &tree_limit, &base_offset_obj,
81
- &out_contribs_obj, &feature_dependence, &model_output, &interactions
82
- )) return NULL;
83
-
84
- /* Interpret the input objects as numpy arrays. */
85
- PyArrayObject *children_left_array = (PyArrayObject*)PyArray_FROM_OTF(children_left_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
86
- PyArrayObject *children_right_array = (PyArrayObject*)PyArray_FROM_OTF(children_right_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
87
- PyArrayObject *children_default_array = (PyArrayObject*)PyArray_FROM_OTF(children_default_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
88
- PyArrayObject *features_array = (PyArrayObject*)PyArray_FROM_OTF(features_obj, NPY_INT, NPY_ARRAY_IN_ARRAY);
89
- PyArrayObject *thresholds_array = (PyArrayObject*)PyArray_FROM_OTF(thresholds_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
90
- PyArrayObject *values_array = (PyArrayObject*)PyArray_FROM_OTF(values_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
91
- PyArrayObject *node_sample_weights_array = (PyArrayObject*)PyArray_FROM_OTF(node_sample_weights_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
92
- PyArrayObject *X_array = (PyArrayObject*)PyArray_FROM_OTF(X_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
93
- PyArrayObject *X_missing_array = (PyArrayObject*)PyArray_FROM_OTF(X_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
94
- PyArrayObject *y_array = NULL;
95
- if (y_obj != Py_None) y_array = (PyArrayObject*)PyArray_FROM_OTF(y_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
96
- PyArrayObject *R_array = NULL;
97
- if (R_obj != Py_None) R_array = (PyArrayObject*)PyArray_FROM_OTF(R_obj, NPY_DOUBLE, NPY_ARRAY_IN_ARRAY);
98
- PyArrayObject *R_missing_array = NULL;
99
- if (R_missing_obj != Py_None) R_missing_array = (PyArrayObject*)PyArray_FROM_OTF(R_missing_obj, NPY_BOOL, NPY_ARRAY_IN_ARRAY);
100
- PyArrayObject *out_contribs_array = (PyArrayObject*)PyArray_FROM_OTF(out_contribs_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
101
- PyArrayObject *base_offset_array = (PyArrayObject*)PyArray_FROM_OTF(base_offset_obj, NPY_DOUBLE, NPY_ARRAY_INOUT_ARRAY);
102
-
103
- /* If that didn't work, throw an exception. Note that R and y are optional. */
104
- if (children_left_array == NULL || children_right_array == NULL ||
105
- children_default_array == NULL || features_array == NULL || thresholds_array == NULL ||
106
- values_array == NULL || node_sample_weights_array == NULL || X_array == NULL ||
107
- X_missing_array == NULL || out_contribs_array == NULL) {
108
- Py_XDECREF(children_left_array);
109
- Py_XDECREF(children_right_array);
110
- Py_XDECREF(children_default_array);
111
- Py_XDECREF(features_array);
112
- Py_XDECREF(thresholds_array);
113
- Py_XDECREF(values_array);
114
- Py_XDECREF(node_sample_weights_array);
115
- Py_XDECREF(X_array);
116
- Py_XDECREF(X_missing_array);
117
- if (y_array != NULL) Py_XDECREF(y_array);
118
- if (R_array != NULL) Py_XDECREF(R_array);
119
- if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
120
- //PyArray_ResolveWritebackIfCopy(out_contribs_array);
121
- Py_XDECREF(out_contribs_array);
122
- Py_XDECREF(base_offset_array);
123
- return NULL;
124
- }
125
-
126
- const unsigned num_X = PyArray_DIM(X_array, 0);
127
- const unsigned M = PyArray_DIM(X_array, 1);
128
- const unsigned max_nodes = PyArray_DIM(values_array, 1);
129
- const unsigned num_outputs = PyArray_DIM(values_array, 2);
130
- unsigned num_R = 0;
131
- if (R_array != NULL) num_R = PyArray_DIM(R_array, 0);
132
-
133
- // Get pointers to the data as C-types
134
- int *children_left = (int*)PyArray_DATA(children_left_array);
135
- int *children_right = (int*)PyArray_DATA(children_right_array);
136
- int *children_default = (int*)PyArray_DATA(children_default_array);
137
- int *features = (int*)PyArray_DATA(features_array);
138
- tfloat *thresholds = (tfloat*)PyArray_DATA(thresholds_array);
139
- tfloat *values = (tfloat*)PyArray_DATA(values_array);
140
- tfloat *node_sample_weights = (tfloat*)PyArray_DATA(node_sample_weights_array);
141
- tfloat *X = (tfloat*)PyArray_DATA(X_array);
142
- bool *X_missing = (bool*)PyArray_DATA(X_missing_array);
143
- tfloat *y = NULL;
144
- if (y_array != NULL) y = (tfloat*)PyArray_DATA(y_array);
145
- tfloat *R = NULL;
146
- if (R_array != NULL) R = (tfloat*)PyArray_DATA(R_array);
147
- bool *R_missing = NULL;
148
- if (R_missing_array != NULL) R_missing = (bool*)PyArray_DATA(R_missing_array);
149
- tfloat *out_contribs = (tfloat*)PyArray_DATA(out_contribs_array);
150
- tfloat *base_offset = (tfloat*)PyArray_DATA(base_offset_array);
151
-
152
- // these are just wrapper objects for all the pointers and numbers associated with
153
- // the ensemble tree model and the dataset we are explaining
154
- TreeEnsemble trees = TreeEnsemble(
155
- children_left, children_right, children_default, features, thresholds, values,
156
- node_sample_weights, max_depth, tree_limit, base_offset,
157
- max_nodes, num_outputs
158
- );
159
- ExplanationDataset data = ExplanationDataset(X, X_missing, y, R, R_missing, num_X, M, num_R);
160
-
161
- dense_tree_shap_gpu(trees, data, out_contribs, feature_dependence, model_output, interactions);
162
-
163
-
164
- // retrieve return value before python cleanup of objects
165
- tfloat ret_value = (double)values[0];
166
-
167
- // clean up the created python objects
168
- Py_XDECREF(children_left_array);
169
- Py_XDECREF(children_right_array);
170
- Py_XDECREF(children_default_array);
171
- Py_XDECREF(features_array);
172
- Py_XDECREF(thresholds_array);
173
- Py_XDECREF(values_array);
174
- Py_XDECREF(node_sample_weights_array);
175
- Py_XDECREF(X_array);
176
- Py_XDECREF(X_missing_array);
177
- if (y_array != NULL) Py_XDECREF(y_array);
178
- if (R_array != NULL) Py_XDECREF(R_array);
179
- if (R_missing_array != NULL) Py_XDECREF(R_missing_array);
180
- //PyArray_ResolveWritebackIfCopy(out_contribs_array);
181
- Py_XDECREF(out_contribs_array);
182
- Py_XDECREF(base_offset_array);
183
-
184
- /* Build the output tuple */
185
- PyObject *ret = Py_BuildValue("d", ret_value);
186
- return ret;
187
- }
 
 
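The GPU wrapper above hands the same tree arrays to dense_tree_shap_gpu; the CUDA side (the _cext_gpu.cu diff below) first flattens every tree into root-to-leaf paths via RecurseTree, where each path element records the split feature, the half-open interval of feature values that follows the branch, and the fraction of training samples taking it. A small Python sketch of that path extraction, using an illustrative dict-based tree (hypothetical layout, not the extension's):

import math

def extract_paths(tree, node=0, prefix=None, out=None):
    prefix = [] if prefix is None else prefix
    out = [] if out is None else out
    left, right = tree["children_left"][node], tree["children_right"][node]
    if left == -1:                                    # leaf: emit the finished path
        out.append((tree["values"][node], list(prefix)))
        return out
    feat, thr = tree["features"][node], tree["thresholds"][node]
    frac_left = tree["weights"][left] / tree["weights"][node]
    # left branch covers feature values in (-inf, thr], right branch covers (thr, +inf)
    extract_paths(tree, left,  prefix + [(feat, -math.inf, thr, frac_left)], out)
    extract_paths(tree, right, prefix + [(feat, thr, math.inf, 1.0 - frac_left)], out)
    return out

tree = {
    "children_left":  [1, -1, -1],
    "children_right": [2, -1, -1],
    "features":       [0, 0, 0],
    "thresholds":     [0.5, 0.0, 0.0],
    "values":         [0.5, 0.25, 0.75],
    "weights":        [10, 6, 4],          # training samples reaching each node
}
print(extract_paths(tree))
# [(0.25, [(0, -inf, 0.5, 0.6)]), (0.75, [(0, 0.5, inf, 0.4)])]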
lib/shap/cext/_cext_gpu.cu DELETED
@@ -1,353 +0,0 @@
1
- #include <Python.h>
2
-
3
- #include "gpu_treeshap.h"
4
- #include "tree_shap.h"
5
-
6
- const float inf = std::numeric_limits<tfloat>::infinity();
7
-
8
- struct ShapSplitCondition {
9
- ShapSplitCondition() = default;
10
- ShapSplitCondition(tfloat feature_lower_bound, tfloat feature_upper_bound,
11
- bool is_missing_branch)
12
- : feature_lower_bound(feature_lower_bound),
13
- feature_upper_bound(feature_upper_bound),
14
- is_missing_branch(is_missing_branch) {
15
- assert(feature_lower_bound <= feature_upper_bound);
16
- }
17
-
18
- /*! Feature values > lower and <= upper flow down this path. */
19
- tfloat feature_lower_bound;
20
- tfloat feature_upper_bound;
21
- /*! Do missing values flow down this path? */
22
- bool is_missing_branch;
23
-
24
- // Does this instance flow down this path?
25
- __host__ __device__ bool EvaluateSplit(float x) const {
26
- // is nan
27
- if (isnan(x)) {
28
- return is_missing_branch;
29
- }
30
- return x > feature_lower_bound && x <= feature_upper_bound;
31
- }
32
-
33
- // Combine two split conditions on the same feature
34
- __host__ __device__ void
35
- Merge(const ShapSplitCondition &other) { // Combine duplicate features
36
- feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
37
- feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
38
- is_missing_branch = is_missing_branch && other.is_missing_branch;
39
- }
40
- };
41
-
42
-
43
- // Inspired by: https://en.cppreference.com/w/cpp/iterator/size
44
- // Limited implementation of std::size for arrays
45
- template <class T, size_t N>
46
- constexpr size_t array_size(const T (&array)[N]) noexcept
47
- {
48
- return N;
49
- }
50
-
51
- void RecurseTree(
52
- unsigned pos, const TreeEnsemble &tree,
53
- std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> *tmp_path,
54
- std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
55
- size_t *path_idx, int num_outputs) {
56
- if (tree.is_leaf(pos)) {
57
- for (auto j = 0ull; j < num_outputs; j++) {
58
- auto v = tree.values[pos * num_outputs + j];
59
- if (v == 0.0) {
60
- // The tree has no output for this class, don't bother adding the path
61
- continue;
62
- }
63
- // Go back over path, setting v, path_idx
64
- for (auto &e : *tmp_path) {
65
- e.v = v;
66
- e.group = j;
67
- e.path_idx = *path_idx;
68
- }
69
-
70
- paths->insert(paths->end(), tmp_path->begin(), tmp_path->end());
71
- // Increment path index
72
- (*path_idx)++;
73
- }
74
- return;
75
- }
76
-
77
- // Add left split to the path
78
- unsigned left_child = tree.children_left[pos];
79
- double left_zero_fraction =
80
- tree.node_sample_weights[left_child] / tree.node_sample_weights[pos];
81
- // Encode the range of feature values that flow down this path
82
- tmp_path->emplace_back(0, tree.features[pos], 0,
83
- ShapSplitCondition{-inf, tree.thresholds[pos], false},
84
- left_zero_fraction, 0.0f);
85
-
86
- RecurseTree(left_child, tree, tmp_path, paths, path_idx, num_outputs);
87
-
88
- // Add right split to the path
89
- tmp_path->back() = gpu_treeshap::PathElement<ShapSplitCondition>(
90
- 0, tree.features[pos], 0,
91
- ShapSplitCondition{tree.thresholds[pos], inf, false},
92
- 1.0 - left_zero_fraction, 0.0f);
93
-
94
- RecurseTree(tree.children_right[pos], tree, tmp_path, paths, path_idx,
95
- num_outputs);
96
-
97
- tmp_path->pop_back();
98
- }
99
-
100
- std::vector<gpu_treeshap::PathElement<ShapSplitCondition>>
101
- ExtractPaths(const TreeEnsemble &trees) {
102
- std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> paths;
103
- size_t path_idx = 0;
104
- for (auto i = 0; i < trees.tree_limit; i++) {
105
- TreeEnsemble tree;
106
- trees.get_tree(tree, i);
107
- std::vector<gpu_treeshap::PathElement<ShapSplitCondition>> tmp_path;
108
- tmp_path.reserve(tree.max_depth);
109
- tmp_path.emplace_back(0, -1, 0, ShapSplitCondition{-inf, inf, false}, 1.0,
110
- 0.0f);
111
- RecurseTree(0, tree, &tmp_path, &paths, &path_idx, tree.num_outputs);
112
- }
113
- return paths;
114
- }
115
-
116
- class DeviceExplanationDataset {
117
- thrust::device_vector<tfloat> data;
118
- thrust::device_vector<bool> missing;
119
- size_t num_features;
120
- size_t num_rows;
121
-
122
- public:
123
- DeviceExplanationDataset(const ExplanationDataset &host_data,
124
- bool background_dataset = false) {
125
- num_features = host_data.M;
126
- if (background_dataset) {
127
- num_rows = host_data.num_R;
128
- data = thrust::device_vector<tfloat>(
129
- host_data.R, host_data.R + host_data.num_R * host_data.M);
130
- missing = thrust::device_vector<bool>(host_data.R_missing,
131
- host_data.R_missing +
132
- host_data.num_R * host_data.M);
133
-
134
- } else {
135
- num_rows = host_data.num_X;
136
- data = thrust::device_vector<tfloat>(
137
- host_data.X, host_data.X + host_data.num_X * host_data.M);
138
- missing = thrust::device_vector<bool>(host_data.X_missing,
139
- host_data.X_missing +
140
- host_data.num_X * host_data.M);
141
- }
142
- }
143
-
144
- class DenseDatasetWrapper {
145
- const tfloat *data;
146
- const bool *missing;
147
- int num_rows;
148
- int num_cols;
149
-
150
- public:
151
- DenseDatasetWrapper() = default;
152
- DenseDatasetWrapper(const tfloat *data, const bool *missing, int num_rows,
153
- int num_cols)
154
- : data(data), missing(missing), num_rows(num_rows), num_cols(num_cols) {
155
- }
156
- __device__ tfloat GetElement(size_t row_idx, size_t col_idx) const {
157
- auto idx = row_idx * num_cols + col_idx;
158
- if (missing[idx]) {
159
- return std::numeric_limits<tfloat>::quiet_NaN();
160
- }
161
- return data[idx];
162
- }
163
- __host__ __device__ size_t NumRows() const { return num_rows; }
164
- __host__ __device__ size_t NumCols() const { return num_cols; }
165
- };
166
-
167
- DenseDatasetWrapper GetDeviceAccessor() {
168
- return DenseDatasetWrapper(data.data().get(), missing.data().get(),
169
- num_rows, num_features);
170
- }
171
- };
172
-
173
- inline void dense_tree_path_dependent_gpu(
174
- const TreeEnsemble &trees, const ExplanationDataset &data,
175
- tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
176
- auto paths = ExtractPaths(trees);
177
- DeviceExplanationDataset device_data(data);
178
- DeviceExplanationDataset::DenseDatasetWrapper X =
179
- device_data.GetDeviceAccessor();
180
-
181
- thrust::device_vector<float> phis((X.NumCols() + 1) * X.NumRows() *
182
- trees.num_outputs);
183
- gpu_treeshap::GPUTreeShap(X, paths.begin(), paths.end(), trees.num_outputs,
184
- phis.begin(), phis.end());
185
- // Add the base offset term to bias
186
- thrust::device_vector<double> base_offset(
187
- trees.base_offset, trees.base_offset + trees.num_outputs);
188
- auto counting = thrust::make_counting_iterator(size_t(0));
189
- auto d_phis = phis.data().get();
190
- auto d_base_offset = base_offset.data().get();
191
- size_t num_groups = trees.num_outputs;
192
- thrust::for_each(counting, counting + X.NumRows() * trees.num_outputs,
193
- [=] __device__(size_t idx) {
194
- size_t row_idx = idx / num_groups;
195
- size_t group = idx % num_groups;
196
- auto phi_idx = gpu_treeshap::IndexPhi(
197
- row_idx, num_groups, group, X.NumCols(), X.NumCols());
198
- d_phis[phi_idx] += d_base_offset[group];
199
- });
200
-
201
- // Shap uses a slightly different layout for multiclass
202
- thrust::device_vector<float> transposed_phis(phis.size());
203
- auto d_transposed_phis = transposed_phis.data();
204
- thrust::for_each(
205
- counting, counting + phis.size(), [=] __device__(size_t idx) {
206
- size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1)};
207
- size_t old_idx[array_size(old_shape)];
208
- gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
209
- // Define new tensor format, switch num_groups axis to end
210
- size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), num_groups};
211
- size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[1]};
212
- size_t transposed_idx =
213
- gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
214
- d_transposed_phis[transposed_idx] = d_phis[idx];
215
- });
216
- thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
217
- }
218
-
219
- inline void
220
- dense_tree_independent_gpu(const TreeEnsemble &trees,
221
- const ExplanationDataset &data, tfloat *out_contribs,
222
- tfloat transform(const tfloat, const tfloat)) {
223
- auto paths = ExtractPaths(trees);
224
- DeviceExplanationDataset device_data(data);
225
- DeviceExplanationDataset::DenseDatasetWrapper X =
226
- device_data.GetDeviceAccessor();
227
- DeviceExplanationDataset background_device_data(data, true);
228
- DeviceExplanationDataset::DenseDatasetWrapper R =
229
- background_device_data.GetDeviceAccessor();
230
-
231
- thrust::device_vector<float> phis((X.NumCols() + 1) * X.NumRows() *
232
- trees.num_outputs);
233
- gpu_treeshap::GPUTreeShapInterventional(X, R, paths.begin(), paths.end(),
234
- trees.num_outputs, phis.begin(),
235
- phis.end());
236
- // Add the base offset term to bias
237
- thrust::device_vector<double> base_offset(
238
- trees.base_offset, trees.base_offset + trees.num_outputs);
239
- auto counting = thrust::make_counting_iterator(size_t(0));
240
- auto d_phis = phis.data().get();
241
- auto d_base_offset = base_offset.data().get();
242
- size_t num_groups = trees.num_outputs;
243
- thrust::for_each(counting, counting + X.NumRows() * trees.num_outputs,
244
- [=] __device__(size_t idx) {
245
- size_t row_idx = idx / num_groups;
246
- size_t group = idx % num_groups;
247
- auto phi_idx = gpu_treeshap::IndexPhi(
248
- row_idx, num_groups, group, X.NumCols(), X.NumCols());
249
- d_phis[phi_idx] += d_base_offset[group];
250
- });
251
-
252
- // Shap uses a slightly different layout for multiclass
253
- thrust::device_vector<float> transposed_phis(phis.size());
254
- auto d_transposed_phis = transposed_phis.data();
255
- thrust::for_each(
256
- counting, counting + phis.size(), [=] __device__(size_t idx) {
257
- size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1)};
258
- size_t old_idx[array_size(old_shape)];
259
- gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
260
- // Define new tensor format, switch num_groups axis to end
261
- size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), num_groups};
262
- size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[1]};
263
- size_t transposed_idx =
264
- gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
265
- d_transposed_phis[transposed_idx] = d_phis[idx];
266
- });
267
- thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
268
- }
269
-
270
- inline void dense_tree_path_dependent_interactions_gpu(
271
- const TreeEnsemble &trees, const ExplanationDataset &data,
272
- tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
273
- auto paths = ExtractPaths(trees);
274
- DeviceExplanationDataset device_data(data);
275
- DeviceExplanationDataset::DenseDatasetWrapper X =
276
- device_data.GetDeviceAccessor();
277
-
278
- thrust::device_vector<float> phis((X.NumCols() + 1) * (X.NumCols() + 1) *
279
- X.NumRows() * trees.num_outputs);
280
- gpu_treeshap::GPUTreeShapInteractions(X, paths.begin(), paths.end(),
281
- trees.num_outputs, phis.begin(),
282
- phis.end());
283
- // Add the base offset term to bias
284
- thrust::device_vector<double> base_offset(
285
- trees.base_offset, trees.base_offset + trees.num_outputs);
286
- auto counting = thrust::make_counting_iterator(size_t(0));
287
- auto d_phis = phis.data().get();
288
- auto d_base_offset = base_offset.data().get();
289
- size_t num_groups = trees.num_outputs;
290
- thrust::for_each(counting, counting + X.NumRows() * num_groups,
291
- [=] __device__(size_t idx) {
292
- size_t row_idx = idx / num_groups;
293
- size_t group = idx % num_groups;
294
- auto phi_idx = gpu_treeshap::IndexPhiInteractions(
295
- row_idx, num_groups, group, X.NumCols(), X.NumCols(),
296
- X.NumCols());
297
- d_phis[phi_idx] += d_base_offset[group];
298
- });
299
- // Shap uses a slightly different layout for multiclass
300
- thrust::device_vector<float> transposed_phis(phis.size());
301
- auto d_transposed_phis = transposed_phis.data();
302
- thrust::for_each(
303
- counting, counting + phis.size(), [=] __device__(size_t idx) {
304
- size_t old_shape[] = {X.NumRows(), num_groups, (X.NumCols() + 1),
305
- (X.NumCols() + 1)};
306
- size_t old_idx[array_size(old_shape)];
307
- gpu_treeshap::FlatIdxToTensorIdx(idx, old_shape, old_idx);
308
- // Define new tensor format, switch num_groups axis to end
309
- size_t new_shape[] = {X.NumRows(), (X.NumCols() + 1), (X.NumCols() + 1),
310
- num_groups};
311
- size_t new_idx[] = {old_idx[0], old_idx[2], old_idx[3], old_idx[1]};
312
- size_t transposed_idx =
313
- gpu_treeshap::TensorIdxToFlatIdx(new_shape, new_idx);
314
- d_transposed_phis[transposed_idx] = d_phis[idx];
315
- });
316
- thrust::copy(transposed_phis.begin(), transposed_phis.end(), out_contribs);
317
- }
318
-
319
- void dense_tree_shap_gpu(const TreeEnsemble &trees,
320
- const ExplanationDataset &data, tfloat *out_contribs,
321
- const int feature_dependence, unsigned model_transform,
322
- bool interactions) {
323
- // see what transform (if any) we have
324
- transform_f transform = get_transform(model_transform);
325
-
326
- // dispatch to the correct algorithm handler
327
- switch (feature_dependence) {
328
- case FEATURE_DEPENDENCE::independent:
329
- if (interactions) {
330
- std::cerr << "FEATURE_DEPENDENCE::independent with interactions not yet "
331
- "supported\n";
332
- } else {
333
- dense_tree_independent_gpu(trees, data, out_contribs, transform);
334
- }
335
- return;
336
-
337
- case FEATURE_DEPENDENCE::tree_path_dependent:
338
- if (interactions) {
339
- dense_tree_path_dependent_interactions_gpu(trees, data, out_contribs,
340
- transform);
341
- } else {
342
- dense_tree_path_dependent_gpu(trees, data, out_contribs, transform);
343
- }
344
- return;
345
-
346
- case FEATURE_DEPENDENCE::global_path_dependent:
347
- std::cerr << "FEATURE_DEPENDENCE::global_path_dependent not supported\n";
348
- return;
349
- default:
350
- std::cerr << "Unknown feature dependence option\n";
351
- return;
352
- }
353
- }
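Each of the *_gpu routines above finishes by re-ordering the phi buffer: GPUTreeShap writes values in a (row, group, column) layout where the last column holds the bias, while the shap package expects (row, column, group). The host-only C++ sketch below restates that index arithmetic without thrust, for clarity only; the function name and signature are illustrative and not part of the deleted file.

// Plain C++ stand-in for the device-side transposition done above.
#include <cstddef>
#include <vector>

std::vector<float> TransposeGroupsToLast(const std::vector<float> &phis,
                                         std::size_t num_rows,
                                         std::size_t num_groups,
                                         std::size_t num_cols_plus_one) {
  std::vector<float> out(phis.size());
  for (std::size_t row = 0; row < num_rows; row++) {
    for (std::size_t group = 0; group < num_groups; group++) {
      for (std::size_t col = 0; col < num_cols_plus_one; col++) {
        // old layout: [num_rows][num_groups][num_cols + 1]
        std::size_t old_idx =
            (row * num_groups + group) * num_cols_plus_one + col;
        // new layout: [num_rows][num_cols + 1][num_groups]
        std::size_t new_idx =
            (row * num_cols_plus_one + col) * num_groups + group;
        out[new_idx] = phis[old_idx];
      }
    }
  }
  return out;
}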
lib/shap/cext/gpu_treeshap.h DELETED
@@ -1,1535 +0,0 @@
1
- /*
2
- * Copyright (c) 2020, NVIDIA CORPORATION.
3
- *
4
- * Licensed under the Apache License, Version 2.0 (the "License");
5
- * you may not use this file except in compliance with the License.
6
- * You may obtain a copy of the License at
7
- *
8
- * http://www.apache.org/licenses/LICENSE-2.0
9
- *
10
- * Unless required by applicable law or agreed to in writing, software
11
- * distributed under the License is distributed on an "AS IS" BASIS,
12
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- * See the License for the specific language governing permissions and
14
- * limitations under the License.
15
- */
16
-
17
- #pragma once
18
- #include <thrust/device_allocator.h>
19
- #include <thrust/device_vector.h>
20
- #include <thrust/iterator/discard_iterator.h>
21
- #include <thrust/logical.h>
22
- #include <thrust/reduce.h>
23
- #include <thrust/host_vector.h>
24
- #if (CUDART_VERSION >= 11000)
25
- #include <cub/cub.cuh>
26
- #else
27
- // Hack to get cub device reduce on older toolkits
28
- #include <thrust/system/cuda/detail/cub/device/device_reduce.cuh>
29
- using namespace thrust::cuda_cub;
30
- #endif
31
- #include <algorithm>
32
- #include <functional>
33
- #include <set>
34
- #include <stdexcept>
35
- #include <utility>
36
- #include <vector>
37
-
38
- namespace gpu_treeshap {
39
-
40
- struct XgboostSplitCondition {
41
- XgboostSplitCondition() = default;
42
- XgboostSplitCondition(float feature_lower_bound, float feature_upper_bound,
43
- bool is_missing_branch)
44
- : feature_lower_bound(feature_lower_bound),
45
- feature_upper_bound(feature_upper_bound),
46
- is_missing_branch(is_missing_branch) {
47
- assert(feature_lower_bound <= feature_upper_bound);
48
- }
49
-
50
- /*! Feature values >= lower and < upper flow down this path. */
51
- float feature_lower_bound;
52
- float feature_upper_bound;
53
- /*! Do missing values flow down this path? */
54
- bool is_missing_branch;
55
-
56
- // Does this instance flow down this path?
57
- __host__ __device__ bool EvaluateSplit(float x) const {
58
- // is nan
59
- if (isnan(x)) {
60
- return is_missing_branch;
61
- }
62
- return x >= feature_lower_bound && x < feature_upper_bound;
63
- }
64
-
65
- // Combine two split conditions on the same feature
66
- __host__ __device__ void Merge(
67
- const XgboostSplitCondition& other) { // Combine duplicate features
68
- feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
69
- feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
70
- is_missing_branch = is_missing_branch && other.is_missing_branch;
71
- }
72
- };
73
-
74
- /*!
75
- * An element of a unique path through a decision tree. Can implement various
76
- * types of splits via the templated SplitConditionT. Some decision tree
77
- * implementations may wish to use double precision or single precision, some
78
- * may use < or <= as the threshold, missing values can be handled differently,
79
- * categoricals may be supported.
80
- *
81
- * \tparam SplitConditionT A split condition implementing the methods
82
- * EvaluateSplit and Merge.
83
- */
84
- template <typename SplitConditionT>
85
- struct PathElement {
86
- using split_type = SplitConditionT;
87
- __host__ __device__ PathElement(size_t path_idx, int64_t feature_idx,
88
- int group, SplitConditionT split_condition,
89
- double zero_fraction, float v)
90
- : path_idx(path_idx),
91
- feature_idx(feature_idx),
92
- group(group),
93
- split_condition(split_condition),
94
- zero_fraction(zero_fraction),
95
- v(v) {}
96
-
97
- PathElement() = default;
98
- __host__ __device__ bool IsRoot() const { return feature_idx == -1; }
99
-
100
- template <typename DatasetT>
101
- __host__ __device__ bool EvaluateSplit(DatasetT X, size_t row_idx) const {
102
- if (this->IsRoot()) {
103
- return 1.0;
104
- }
105
- return split_condition.EvaluateSplit(X.GetElement(row_idx, feature_idx));
106
- }
107
-
108
- /*! Unique path index. */
109
- size_t path_idx;
110
- /*! Feature of this split, -1 indicates bias term. */
111
- int64_t feature_idx;
112
- /*! Indicates class for multiclass problems. */
113
- int group;
114
- SplitConditionT split_condition;
115
- /*! Probability of following this path when feature_idx is not in the active
116
- * set. */
117
- double zero_fraction;
118
- float v; // Leaf weight at the end of the path
119
- };
120
-
121
- // Helper function that accepts an index into a flat contiguous array and the
122
- // dimensions of a tensor and returns the indices with respect to the tensor
123
- template <typename T, size_t N>
124
- __device__ void FlatIdxToTensorIdx(T flat_idx, const T (&shape)[N],
125
- T (&out_idx)[N]) {
126
- T current_size = shape[0];
127
- for (auto i = 1ull; i < N; i++) {
128
- current_size *= shape[i];
129
- }
130
- for (auto i = 0ull; i < N; i++) {
131
- current_size /= shape[i];
132
- out_idx[i] = flat_idx / current_size;
133
- flat_idx -= current_size * out_idx[i];
134
- }
135
- }
136
-
137
- // Given a shape and coordinates into a tensor, return the index into the
138
- // backing storage one-dimensional array
139
- template <typename T, size_t N>
140
- __device__ T TensorIdxToFlatIdx(const T (&shape)[N], const T (&tensor_idx)[N]) {
141
- T current_size = shape[0];
142
- for (auto i = 1ull; i < N; i++) {
143
- current_size *= shape[i];
144
- }
145
- T idx = 0;
146
- for (auto i = 0ull; i < N; i++) {
147
- current_size /= shape[i];
148
- idx += tensor_idx[i] * current_size;
149
- }
150
- return idx;
151
- }
152
-
153
- // Maps values to the phi array according to row, group and column
154
- __host__ __device__ inline size_t IndexPhi(size_t row_idx, size_t num_groups,
155
- size_t group, size_t num_columns,
156
- size_t column_idx) {
157
- return (row_idx * num_groups + group) * (num_columns + 1) + column_idx;
158
- }
159
-
160
- __host__ __device__ inline size_t IndexPhiInteractions(size_t row_idx,
161
- size_t num_groups,
162
- size_t group,
163
- size_t num_columns,
164
- size_t i, size_t j) {
165
- size_t matrix_size = (num_columns + 1) * (num_columns + 1);
166
- size_t matrix_offset = (row_idx * num_groups + group) * matrix_size;
167
- return matrix_offset + i * (num_columns + 1) + j;
168
- }
169
-
170
- namespace detail {
171
-
172
- // Shorthand for creating a device vector with an appropriate allocator type
173
- template <class T, class DeviceAllocatorT>
174
- using RebindVector =
175
- thrust::device_vector<T,
176
- typename DeviceAllocatorT::template rebind<T>::other>;
177
-
178
- #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 || defined(__clang__)
179
- __device__ __forceinline__ double atomicAddDouble(double* address, double val) {
180
- return atomicAdd(address, val);
181
- }
182
- #else // In device code and CUDA < 600
183
- __device__ __forceinline__ double atomicAddDouble(double* address,
184
- double val) { // NOLINT
185
- unsigned long long int* address_as_ull = // NOLINT
186
- (unsigned long long int*)address; // NOLINT
187
- unsigned long long int old = *address_as_ull, assumed; // NOLINT
188
-
189
- do {
190
- assumed = old;
191
- old = atomicCAS(address_as_ull, assumed,
192
- __double_as_longlong(val + __longlong_as_double(assumed)));
193
-
194
- // Note: uses integer comparison to avoid hang in case of NaN (since NaN !=
195
- // NaN)
196
- } while (assumed != old);
197
-
198
- return __longlong_as_double(old);
199
- }
200
- #endif
201
-
202
- __forceinline__ __device__ unsigned int lanemask32_lt() {
203
- unsigned int lanemask32_lt;
204
- asm volatile("mov.u32 %0, %%lanemask_lt;" : "=r"(lanemask32_lt));
205
- return (lanemask32_lt);
206
- }
207
-
208
- // Like a coalesced group, except we can make the assumption that all threads in
209
- // a group are next to each other. This makes shuffle operations much cheaper.
210
- class ContiguousGroup {
211
- public:
212
- __device__ ContiguousGroup(uint32_t mask) : mask_(mask) {}
213
-
214
- __device__ uint32_t size() const { return __popc(mask_); }
215
- __device__ uint32_t thread_rank() const {
216
- return __popc(mask_ & lanemask32_lt());
217
- }
218
- template <typename T>
219
- __device__ T shfl(T val, uint32_t src) const {
220
- return __shfl_sync(mask_, val, src + __ffs(mask_) - 1);
221
- }
222
- template <typename T>
223
- __device__ T shfl_up(T val, uint32_t delta) const {
224
- return __shfl_up_sync(mask_, val, delta);
225
- }
226
- __device__ uint32_t ballot(int predicate) const {
227
- return __ballot_sync(mask_, predicate) >> (__ffs(mask_) - 1);
228
- }
229
-
230
- template <typename T, typename OpT>
231
- __device__ T reduce(T val, OpT op) {
232
- for (int i = 1; i < this->size(); i *= 2) {
233
- T shfl = shfl_up(val, i);
234
- if (static_cast<int>(thread_rank()) - i >= 0) {
235
- val = op(val, shfl);
236
- }
237
- }
238
- return shfl(val, size() - 1);
239
- }
240
- uint32_t mask_;
241
- };
242
-
243
- // Separate the active threads by labels
244
- // This functionality is available in cuda 11.0 on cc >=7.0
245
- // We reimplement for backwards compatibility
246
- // Assumes partitions are contiguous
247
- inline __device__ ContiguousGroup active_labeled_partition(uint32_t mask,
248
- int label) {
249
- #if __CUDA_ARCH__ >= 700
250
- uint32_t subgroup_mask = __match_any_sync(mask, label);
251
- #else
252
- uint32_t subgroup_mask = 0;
253
- for (int i = 0; i < 32;) {
254
- int current_label = __shfl_sync(mask, label, i);
255
- uint32_t ballot = __ballot_sync(mask, label == current_label);
256
- if (label == current_label) {
257
- subgroup_mask = ballot;
258
- }
259
- uint32_t completed_mask =
260
- (1 << (32 - __clz(ballot))) - 1; // Threads that have finished
261
- // Find the start of the next group, mask off completed threads from active
262
- // threads Then use ffs - 1 to find the position of the next group
263
- int next_i = __ffs(mask & ~completed_mask) - 1;
264
- if (next_i == -1) break; // -1 indicates all finished
265
- assert(next_i > i); // Prevent infinite loops when the constraints are not met
266
- i = next_i;
267
- }
268
- #endif
269
- return ContiguousGroup(subgroup_mask);
270
- }
271
-
272
- // Group of threads where each thread holds a path element
273
- class GroupPath {
274
- protected:
275
- const ContiguousGroup& g_;
276
- // These are combined so we can communicate them in a single 64 bit shuffle
277
- // instruction
278
- float zero_one_fraction_[2];
279
- float pweight_;
280
- int unique_depth_;
281
-
282
- public:
283
- __device__ GroupPath(const ContiguousGroup& g, float zero_fraction,
284
- float one_fraction)
285
- : g_(g),
286
- zero_one_fraction_{zero_fraction, one_fraction},
287
- pweight_(g.thread_rank() == 0 ? 1.0f : 0.0f),
288
- unique_depth_(0) {}
289
-
290
- // Cooperatively extend the path with a group of threads
291
- // Each thread maintains pweight for its path element in register
292
- __device__ void Extend() {
293
- unique_depth_++;
294
-
295
- // Broadcast the zero and one fraction from the newly added path element
296
- // Combine 2 shuffle operations into 64 bit word
297
- const size_t rank = g_.thread_rank();
298
- const float inv_unique_depth =
299
- __fdividef(1.0f, static_cast<float>(unique_depth_ + 1));
300
- uint64_t res = g_.shfl(*reinterpret_cast<uint64_t*>(&zero_one_fraction_),
301
- unique_depth_);
302
- const float new_zero_fraction = reinterpret_cast<float*>(&res)[0];
303
- const float new_one_fraction = reinterpret_cast<float*>(&res)[1];
304
- float left_pweight = g_.shfl_up(pweight_, 1);
305
-
306
- // pweight of threads with rank < unique_depth_ is 0
307
- // We use max(x,0) to avoid using a branch
308
- // pweight_ *=
309
- // new_zero_fraction * max(unique_depth_ - rank, 0llu) * inv_unique_depth;
310
- pweight_ = __fmul_rn(
311
- __fmul_rn(pweight_, new_zero_fraction),
312
- __fmul_rn(max(unique_depth_ - rank, size_t(0)), inv_unique_depth));
313
-
314
- // pweight_ += new_one_fraction * left_pweight * rank * inv_unique_depth;
315
- pweight_ = __fmaf_rn(__fmul_rn(new_one_fraction, left_pweight),
316
- __fmul_rn(rank, inv_unique_depth), pweight_);
317
- }
318
-
319
- // Each thread unwinds the path for its feature and returns the sum
320
- __device__ float UnwoundPathSum() {
321
- float next_one_portion = g_.shfl(pweight_, unique_depth_);
322
- float total = 0.0f;
323
- const float zero_frac_div_unique_depth = __fdividef(
324
- zero_one_fraction_[0], static_cast<float>(unique_depth_ + 1));
325
- for (int i = unique_depth_ - 1; i >= 0; i--) {
326
- float ith_pweight = g_.shfl(pweight_, i);
327
- float precomputed =
328
- __fmul_rn((unique_depth_ - i), zero_frac_div_unique_depth);
329
- const float tmp =
330
- __fdividef(__fmul_rn(next_one_portion, unique_depth_ + 1), i + 1);
331
- total = __fmaf_rn(tmp, zero_one_fraction_[1], total);
332
- next_one_portion = __fmaf_rn(-tmp, precomputed, ith_pweight);
333
- float numerator =
334
- __fmul_rn(__fsub_rn(1.0f, zero_one_fraction_[1]), ith_pweight);
335
- if (precomputed > 0.0f) {
336
- total += __fdividef(numerator, precomputed);
337
- }
338
- }
339
-
340
- return total;
341
- }
342
- };
343
-
344
- // Has different permutation weightings to the above
345
- // Used in Taylor Shapley interaction index
346
- class TaylorGroupPath : GroupPath {
347
- public:
348
- __device__ TaylorGroupPath(const ContiguousGroup& g, float zero_fraction,
349
- float one_fraction)
350
- : GroupPath(g, zero_fraction, one_fraction) {}
351
-
352
- // Extend the path as normal; all reweighting can happen in UnwoundPathSum
353
- __device__ void Extend() { GroupPath::Extend(); }
354
-
355
- // Each thread unwinds the path for its feature and returns the sum
356
- // We use a different permutation weighting for Taylor interactions
357
- // As if the total number of features was one larger
358
- __device__ float UnwoundPathSum() {
359
- float one_fraction = zero_one_fraction_[1];
360
- float zero_fraction = zero_one_fraction_[0];
361
- float next_one_portion = g_.shfl(pweight_, unique_depth_) /
362
- static_cast<float>(unique_depth_ + 2);
363
-
364
- float total = 0.0f;
365
- for (int i = unique_depth_ - 1; i >= 0; i--) {
366
- float ith_pweight =
367
- g_.shfl(pweight_, i) * (static_cast<float>(unique_depth_ - i + 1) /
368
- static_cast<float>(unique_depth_ + 2));
369
- if (one_fraction > 0.0f) {
370
- const float tmp =
371
- next_one_portion * (unique_depth_ + 2) / ((i + 1) * one_fraction);
372
-
373
- total += tmp;
374
- next_one_portion =
375
- ith_pweight - tmp * zero_fraction *
376
- ((unique_depth_ - i + 1) /
377
- static_cast<float>(unique_depth_ + 2));
378
- } else if (zero_fraction > 0.0f) {
379
- total +=
380
- (ith_pweight / zero_fraction) /
381
- ((unique_depth_ - i + 1) / static_cast<float>(unique_depth_ + 2));
382
- }
383
- }
384
-
385
- return 2 * total;
386
- }
387
- };
388
-
389
- template <typename DatasetT, typename SplitConditionT>
390
- __device__ float ComputePhi(const PathElement<SplitConditionT>& e,
391
- size_t row_idx, const DatasetT& X,
392
- const ContiguousGroup& group, float zero_fraction) {
393
- float one_fraction =
394
- e.EvaluateSplit(X, row_idx);
395
- GroupPath path(group, zero_fraction, one_fraction);
396
- size_t unique_path_length = group.size();
397
-
398
- // Extend the path
399
- for (auto unique_depth = 1ull; unique_depth < unique_path_length;
400
- unique_depth++) {
401
- path.Extend();
402
- }
403
-
404
- float sum = path.UnwoundPathSum();
405
- return sum * (one_fraction - zero_fraction) * e.v;
406
- }
407
-
408
- inline __host__ __device__ size_t DivRoundUp(size_t a, size_t b) {
409
- return (a + b - 1) / b;
410
- }
411
-
412
- template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
413
- typename SplitConditionT>
414
- void __device__
415
- ConfigureThread(const DatasetT& X, const size_t bins_per_row,
416
- const PathElement<SplitConditionT>* path_elements,
417
- const size_t* bin_segments, size_t* start_row, size_t* end_row,
418
- PathElement<SplitConditionT>* e, bool* thread_active) {
419
- // Partition work
420
- // Each warp processes a set of training instances applied to a path
421
- size_t tid = kBlockSize * blockIdx.x + threadIdx.x;
422
- const size_t warp_size = 32;
423
- size_t warp_rank = tid / warp_size;
424
- if (warp_rank >= bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp)) {
425
- *thread_active = false;
426
- return;
427
- }
428
- size_t bin_idx = warp_rank % bins_per_row;
429
- size_t bank = warp_rank / bins_per_row;
430
- size_t path_start = bin_segments[bin_idx];
431
- size_t path_end = bin_segments[bin_idx + 1];
432
- uint32_t thread_rank = threadIdx.x % warp_size;
433
- if (thread_rank >= path_end - path_start) {
434
- *thread_active = false;
435
- } else {
436
- *e = path_elements[path_start + thread_rank];
437
- *start_row = bank * kRowsPerWarp;
438
- *end_row = min((bank + 1) * kRowsPerWarp, X.NumRows());
439
- *thread_active = true;
440
- }
441
- }
442
-
443
- #define GPUTREESHAP_MAX_THREADS_PER_BLOCK 256
444
- #define FULL_MASK 0xffffffff
445
-
446
- template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
447
- typename SplitConditionT>
448
- __global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
449
- ShapKernel(DatasetT X, size_t bins_per_row,
450
- const PathElement<SplitConditionT>* path_elements,
451
- const size_t* bin_segments, size_t num_groups, double* phis) {
452
- // Use shared memory for structs, otherwise nvcc puts in local memory
453
- __shared__ DatasetT s_X;
454
- s_X = X;
455
- __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
456
- PathElement<SplitConditionT>& e = s_elements[threadIdx.x];
457
-
458
- size_t start_row, end_row;
459
- bool thread_active;
460
- ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
461
- s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e,
462
- &thread_active);
463
- uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
464
- if (!thread_active) return;
465
-
466
- float zero_fraction = e.zero_fraction;
467
- auto labelled_group = active_labeled_partition(mask, e.path_idx);
468
-
469
- for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
470
- float phi = ComputePhi(e, row_idx, X, labelled_group, zero_fraction);
471
-
472
- if (!e.IsRoot()) {
473
- atomicAddDouble(&phis[IndexPhi(row_idx, num_groups, e.group, X.NumCols(),
474
- e.feature_idx)],
475
- phi);
476
- }
477
- }
478
- }
479
-
480
- template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
481
- typename SplitConditionT>
482
- void ComputeShap(
483
- DatasetT X,
484
- const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
485
- const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
486
- path_elements,
487
- size_t num_groups, double* phis) {
488
- size_t bins_per_row = bin_segments.size() - 1;
489
- const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
490
- const int warps_per_block = kBlockThreads / 32;
491
- const int kRowsPerWarp = 1024;
492
- size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
493
-
494
- const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
495
-
496
- ShapKernel<DatasetT, kBlockThreads, kRowsPerWarp>
497
- <<<grid_size, kBlockThreads>>>(
498
- X, bins_per_row, path_elements.data().get(),
499
- bin_segments.data().get(), num_groups, phis);
500
- }
501
-
502
- template <typename PathT, typename DatasetT, typename SplitConditionT>
503
- __device__ float ComputePhiCondition(const PathElement<SplitConditionT>& e,
504
- size_t row_idx, const DatasetT& X,
505
- const ContiguousGroup& group,
506
- int64_t condition_feature) {
507
- float one_fraction = e.EvaluateSplit(X, row_idx);
508
- PathT path(group, e.zero_fraction, one_fraction);
509
- size_t unique_path_length = group.size();
510
- float condition_on_fraction = 1.0f;
511
- float condition_off_fraction = 1.0f;
512
-
513
- // Extend the path
514
- for (auto i = 1ull; i < unique_path_length; i++) {
515
- bool is_condition_feature =
516
- group.shfl(e.feature_idx, i) == condition_feature;
517
- float o_i = group.shfl(one_fraction, i);
518
- float z_i = group.shfl(e.zero_fraction, i);
519
-
520
- if (is_condition_feature) {
521
- condition_on_fraction = o_i;
522
- condition_off_fraction = z_i;
523
- } else {
524
- path.Extend();
525
- }
526
- }
527
- float sum = path.UnwoundPathSum();
528
- if (e.feature_idx == condition_feature) {
529
- return 0.0f;
530
- }
531
- float phi = sum * (one_fraction - e.zero_fraction) * e.v;
532
- return phi * (condition_on_fraction - condition_off_fraction) * 0.5f;
533
- }
534
-
535
- // If there is a feature in the path we are conditioning on, swap it to the end
536
- // of the path
537
- template <typename SplitConditionT>
538
- inline __device__ void SwapConditionedElement(
539
- PathElement<SplitConditionT>** e, PathElement<SplitConditionT>* s_elements,
540
- uint32_t condition_rank, const ContiguousGroup& group) {
541
- auto last_rank = group.size() - 1;
542
- auto this_rank = group.thread_rank();
543
- if (this_rank == last_rank) {
544
- *e = &s_elements[(threadIdx.x - this_rank) + condition_rank];
545
- } else if (this_rank == condition_rank) {
546
- *e = &s_elements[(threadIdx.x - this_rank) + last_rank];
547
- }
548
- }
549
-
550
- template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
551
- typename SplitConditionT>
552
- __global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
553
- ShapInteractionsKernel(DatasetT X, size_t bins_per_row,
554
- const PathElement<SplitConditionT>* path_elements,
555
- const size_t* bin_segments, size_t num_groups,
556
- double* phis_interactions) {
557
- // Use shared memory for structs, otherwise nvcc puts in local memory
558
- __shared__ DatasetT s_X;
559
- s_X = X;
560
- __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
561
- PathElement<SplitConditionT>* e = &s_elements[threadIdx.x];
562
-
563
- size_t start_row, end_row;
564
- bool thread_active;
565
- ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
566
- s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e,
567
- &thread_active);
568
- uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
569
- if (!thread_active) return;
570
-
571
- auto labelled_group = active_labeled_partition(mask, e->path_idx);
572
-
573
- for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
574
- float phi = ComputePhi(*e, row_idx, X, labelled_group, e->zero_fraction);
575
- if (!e->IsRoot()) {
576
- auto phi_offset =
577
- IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
578
- e->feature_idx, e->feature_idx);
579
- atomicAddDouble(phis_interactions + phi_offset, phi);
580
- }
581
-
582
- for (auto condition_rank = 1ull; condition_rank < labelled_group.size();
583
- condition_rank++) {
584
- e = &s_elements[threadIdx.x];
585
- int64_t condition_feature =
586
- labelled_group.shfl(e->feature_idx, condition_rank);
587
- SwapConditionedElement(&e, s_elements, condition_rank, labelled_group);
588
- float x = ComputePhiCondition<GroupPath>(*e, row_idx, X, labelled_group,
589
- condition_feature);
590
- if (!e->IsRoot()) {
591
- auto phi_offset =
592
- IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
593
- e->feature_idx, condition_feature);
594
- atomicAddDouble(phis_interactions + phi_offset, x);
595
- // Subtract effect from diagonal
596
- auto phi_diag =
597
- IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
598
- e->feature_idx, e->feature_idx);
599
- atomicAddDouble(phis_interactions + phi_diag, -x);
600
- }
601
- }
602
- }
603
- }
604
-
605
- template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
606
- typename SplitConditionT>
607
- void ComputeShapInteractions(
608
- DatasetT X,
609
- const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
610
- const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
611
- path_elements,
612
- size_t num_groups, double* phis) {
613
- size_t bins_per_row = bin_segments.size() - 1;
614
- const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
615
- const int warps_per_block = kBlockThreads / 32;
616
- const int kRowsPerWarp = 100;
617
- size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
618
-
619
- const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
620
-
621
- ShapInteractionsKernel<DatasetT, kBlockThreads, kRowsPerWarp>
622
- <<<grid_size, kBlockThreads>>>(
623
- X, bins_per_row, path_elements.data().get(),
624
- bin_segments.data().get(), num_groups, phis);
625
- }
626
-
627
- template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
628
- typename SplitConditionT>
629
- __global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
630
- ShapTaylorInteractionsKernel(
631
- DatasetT X, size_t bins_per_row,
632
- const PathElement<SplitConditionT>* path_elements,
633
- const size_t* bin_segments, size_t num_groups,
634
- double* phis_interactions) {
635
- // Use shared memory for structs, otherwise nvcc puts in local memory
636
- __shared__ DatasetT s_X;
637
- if (threadIdx.x == 0) {
638
- s_X = X;
639
- }
640
- __syncthreads();
641
- __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
642
- PathElement<SplitConditionT>* e = &s_elements[threadIdx.x];
643
-
644
- size_t start_row, end_row;
645
- bool thread_active;
646
- ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
647
- s_X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, e,
648
- &thread_active);
649
- uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
650
- if (!thread_active) return;
651
-
652
- auto labelled_group = active_labeled_partition(mask, e->path_idx);
653
-
654
- for (int64_t row_idx = start_row; row_idx < end_row; row_idx++) {
655
- for (auto condition_rank = 1ull; condition_rank < labelled_group.size();
656
- condition_rank++) {
657
- e = &s_elements[threadIdx.x];
658
- // Compute the diagonal terms
659
- // TODO(Rory): this can be more efficient
660
- float reduce_input =
661
- e->IsRoot() || labelled_group.thread_rank() == condition_rank
662
- ? 1.0f
663
- : e->zero_fraction;
664
- float reduce =
665
- labelled_group.reduce(reduce_input, thrust::multiplies<float>());
666
- if (labelled_group.thread_rank() == condition_rank) {
667
- float one_fraction = e->split_condition.EvaluateSplit(
668
- X.GetElement(row_idx, e->feature_idx));
669
- auto phi_offset =
670
- IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
671
- e->feature_idx, e->feature_idx);
672
- atomicAddDouble(phis_interactions + phi_offset,
673
- reduce * (one_fraction - e->zero_fraction) * e->v);
674
- }
675
-
676
- int64_t condition_feature =
677
- labelled_group.shfl(e->feature_idx, condition_rank);
678
-
679
- SwapConditionedElement(&e, s_elements, condition_rank, labelled_group);
680
-
681
- float x = ComputePhiCondition<TaylorGroupPath>(
682
- *e, row_idx, X, labelled_group, condition_feature);
683
- if (!e->IsRoot()) {
684
- auto phi_offset =
685
- IndexPhiInteractions(row_idx, num_groups, e->group, X.NumCols(),
686
- e->feature_idx, condition_feature);
687
- atomicAddDouble(phis_interactions + phi_offset, x);
688
- }
689
- }
690
- }
691
- }
692
-
693
- template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
694
- typename SplitConditionT>
695
- void ComputeShapTaylorInteractions(
696
- DatasetT X,
697
- const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
698
- const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
699
- path_elements,
700
- size_t num_groups, double* phis) {
701
- size_t bins_per_row = bin_segments.size() - 1;
702
- const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
703
- const int warps_per_block = kBlockThreads / 32;
704
- const int kRowsPerWarp = 100;
705
- size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
706
-
707
- const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
708
-
709
- ShapTaylorInteractionsKernel<DatasetT, kBlockThreads, kRowsPerWarp>
710
- <<<grid_size, kBlockThreads>>>(
711
- X, bins_per_row, path_elements.data().get(),
712
- bin_segments.data().get(), num_groups, phis);
713
- }
714
-
715
-
716
- inline __host__ __device__ int64_t Factorial(int64_t x) {
717
- int64_t y = 1;
718
- for (auto i = 2; i <= x; i++) {
719
- y *= i;
720
- }
721
- return y;
722
- }
723
-
724
- // Compute factorials in log space using lgamma to avoid overflow
725
- inline __host__ __device__ double W(double s, double n) {
726
- assert(n - s - 1 >= 0);
727
- return exp(lgamma(s + 1) - lgamma(n + 1) + lgamma(n - s));
728
- }
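// (Editor's note, not part of the deleted file.) W(s, n) is the Shapley
// permutation weight s! * (n - s - 1)! / n!, evaluated in log space:
//   exp(lgamma(s + 1) - lgamma(n + 1) + lgamma(n - s)) = s! * (n - s - 1)! / n!
// because lgamma(k + 1) = ln(k!). Working with lgamma keeps the intermediate
// factorials from overflowing a double when the number of features grows.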
729
-
730
- template <typename DatasetT, size_t kBlockSize, size_t kRowsPerWarp,
731
- typename SplitConditionT>
732
- __global__ void __launch_bounds__(GPUTREESHAP_MAX_THREADS_PER_BLOCK)
733
- ShapInterventionalKernel(DatasetT X, DatasetT R, size_t bins_per_row,
734
- const PathElement<SplitConditionT>* path_elements,
735
- const size_t* bin_segments, size_t num_groups,
736
- double* phis) {
737
- // Cache W coefficients
738
- __shared__ float s_W[33][33];
739
- for (int i = threadIdx.x; i < 33 * 33; i += kBlockSize) {
740
- auto s = i % 33;
741
- auto n = i / 33;
742
- if (n - s - 1 >= 0) {
743
- s_W[s][n] = W(s, n);
744
- } else {
745
- s_W[s][n] = 0.0;
746
- }
747
- }
748
-
749
- __syncthreads();
750
-
751
- __shared__ PathElement<SplitConditionT> s_elements[kBlockSize];
752
- PathElement<SplitConditionT>& e = s_elements[threadIdx.x];
753
-
754
- size_t start_row, end_row;
755
- bool thread_active;
756
- ConfigureThread<DatasetT, kBlockSize, kRowsPerWarp>(
757
- X, bins_per_row, path_elements, bin_segments, &start_row, &end_row, &e,
758
- &thread_active);
759
-
760
- uint32_t mask = __ballot_sync(FULL_MASK, thread_active);
761
- if (!thread_active) return;
762
-
763
- auto labelled_group = active_labeled_partition(mask, e.path_idx);
764
-
765
- for (int64_t x_idx = start_row; x_idx < end_row; x_idx++) {
766
- float result = 0.0f;
767
- bool x_cond = e.EvaluateSplit(X, x_idx);
768
- uint32_t x_ballot = labelled_group.ballot(x_cond);
769
- for (int64_t r_idx = 0; r_idx < R.NumRows(); r_idx++) {
770
- bool r_cond = e.EvaluateSplit(R, r_idx);
771
- uint32_t r_ballot = labelled_group.ballot(r_cond);
772
- assert(!e.IsRoot() ||
773
- (x_cond == r_cond)); // These should be the same for the root
774
- uint32_t s = __popc(x_ballot & ~r_ballot);
775
- uint32_t n = __popc(x_ballot ^ r_ballot);
776
- float tmp = 0.0f;
777
- // Theorem 1
778
- if (x_cond && !r_cond) {
779
- tmp += s_W[s - 1][n];
780
- }
781
- tmp -= s_W[s][n] * (r_cond && !x_cond);
782
-
783
- // No foreground samples make it to this leaf, increment bias
784
- if (e.IsRoot() && s == 0) {
785
- tmp += 1.0f;
786
- }
787
- // If neither foreground nor background goes down this path, ignore this path
788
- bool reached_leaf = !labelled_group.ballot(!x_cond && !r_cond);
789
- tmp *= reached_leaf;
790
- result += tmp;
791
- }
792
-
793
- if (result != 0.0) {
794
- result /= R.NumRows();
795
-
796
- // Root writes bias
797
- auto feature = e.IsRoot() ? X.NumCols() : e.feature_idx;
798
- atomicAddDouble(
799
- &phis[IndexPhi(x_idx, num_groups, e.group, X.NumCols(), feature)],
800
- result * e.v);
801
- }
802
- }
803
- }
804
-
805
- template <typename DatasetT, typename SizeTAllocatorT, typename PathAllocatorT,
806
- typename SplitConditionT>
807
- void ComputeShapInterventional(
808
- DatasetT X, DatasetT R,
809
- const thrust::device_vector<size_t, SizeTAllocatorT>& bin_segments,
810
- const thrust::device_vector<PathElement<SplitConditionT>, PathAllocatorT>&
811
- path_elements,
812
- size_t num_groups, double* phis) {
813
- size_t bins_per_row = bin_segments.size() - 1;
814
- const int kBlockThreads = GPUTREESHAP_MAX_THREADS_PER_BLOCK;
815
- const int warps_per_block = kBlockThreads / 32;
816
- const int kRowsPerWarp = 100;
817
- size_t warps_needed = bins_per_row * DivRoundUp(X.NumRows(), kRowsPerWarp);
818
-
819
- const uint32_t grid_size = DivRoundUp(warps_needed, warps_per_block);
820
-
821
- ShapInterventionalKernel<DatasetT, kBlockThreads, kRowsPerWarp>
822
- <<<grid_size, kBlockThreads>>>(
823
- X, R, bins_per_row, path_elements.data().get(),
824
- bin_segments.data().get(), num_groups, phis);
825
- }
826
-
827
- template <typename PathVectorT, typename SizeVectorT, typename DeviceAllocatorT>
828
- void GetBinSegments(const PathVectorT& paths, const SizeVectorT& bin_map,
829
- SizeVectorT* bin_segments) {
830
- DeviceAllocatorT alloc;
831
- size_t num_bins =
832
- thrust::reduce(thrust::cuda::par(alloc), bin_map.begin(), bin_map.end(),
833
- size_t(0), thrust::maximum<size_t>()) +
834
- 1;
835
- bin_segments->resize(num_bins + 1, 0);
836
- auto counting = thrust::make_counting_iterator(0llu);
837
- auto d_paths = paths.data().get();
838
- auto d_bin_segments = bin_segments->data().get();
839
- auto d_bin_map = bin_map.data();
840
- thrust::for_each_n(counting, paths.size(), [=] __device__(size_t idx) {
841
- auto path_idx = d_paths[idx].path_idx;
842
- atomicAdd(reinterpret_cast<unsigned long long*>(d_bin_segments) + // NOLINT
843
- d_bin_map[path_idx],
844
- 1);
845
- });
846
- thrust::exclusive_scan(thrust::cuda::par(alloc), bin_segments->begin(),
847
- bin_segments->end(), bin_segments->begin());
848
- }
849
-
850
- struct DeduplicateKeyTransformOp {
851
- template <typename SplitConditionT>
852
- __device__ thrust::pair<size_t, int64_t> operator()(
853
- const PathElement<SplitConditionT>& e) {
854
- return {e.path_idx, e.feature_idx};
855
- }
856
- };
857
-
858
- inline void CheckCuda(cudaError_t err) {
859
- if (err != cudaSuccess) {
860
- throw thrust::system_error(err, thrust::cuda_category());
861
- }
862
- }
863
-
864
- template <typename Return>
865
- class DiscardOverload : public thrust::discard_iterator<Return> {
866
- public:
867
- using value_type = Return; // NOLINT
868
- };
869
-
870
- template <typename PathVectorT, typename DeviceAllocatorT,
871
- typename SplitConditionT>
872
- void DeduplicatePaths(PathVectorT* device_paths,
873
- PathVectorT* deduplicated_paths) {
874
- DeviceAllocatorT alloc;
875
- // Sort by feature
876
- thrust::sort(thrust::cuda::par(alloc), device_paths->begin(),
877
- device_paths->end(),
878
- [=] __device__(const PathElement<SplitConditionT>& a,
879
- const PathElement<SplitConditionT>& b) {
880
- if (a.path_idx < b.path_idx) return true;
881
- if (b.path_idx < a.path_idx) return false;
882
-
883
- if (a.feature_idx < b.feature_idx) return true;
884
- if (b.feature_idx < a.feature_idx) return false;
885
- return false;
886
- });
887
-
888
- deduplicated_paths->resize(device_paths->size());
889
-
890
- using Pair = thrust::pair<size_t, int64_t>;
891
- auto key_transform = thrust::make_transform_iterator(
892
- device_paths->begin(), DeduplicateKeyTransformOp());
893
-
894
- thrust::device_vector<size_t> d_num_runs_out(1);
895
- size_t* h_num_runs_out;
896
- CheckCuda(cudaMallocHost(&h_num_runs_out, sizeof(size_t)));
897
-
898
- auto combine = [] __device__(PathElement<SplitConditionT> a,
899
- PathElement<SplitConditionT> b) {
900
- // Combine duplicate features
901
- a.split_condition.Merge(b.split_condition);
902
- a.zero_fraction *= b.zero_fraction;
903
- return a;
904
- }; // NOLINT
905
- size_t temp_size = 0;
906
- CheckCuda(cub::DeviceReduce::ReduceByKey(
907
- nullptr, temp_size, key_transform, DiscardOverload<Pair>(),
908
- device_paths->begin(), deduplicated_paths->begin(),
909
- d_num_runs_out.begin(), combine, device_paths->size()));
910
- using TempAlloc = RebindVector<char, DeviceAllocatorT>;
911
- TempAlloc tmp(temp_size);
912
- CheckCuda(cub::DeviceReduce::ReduceByKey(
913
- tmp.data().get(), temp_size, key_transform, DiscardOverload<Pair>(),
914
- device_paths->begin(), deduplicated_paths->begin(),
915
- d_num_runs_out.begin(), combine, device_paths->size()));
916
-
917
- CheckCuda(cudaMemcpy(h_num_runs_out, d_num_runs_out.data().get(),
918
- sizeof(size_t), cudaMemcpyDeviceToHost));
919
- deduplicated_paths->resize(*h_num_runs_out);
920
- CheckCuda(cudaFreeHost(h_num_runs_out));
921
- }
922
-
923
- template <typename PathVectorT, typename SplitConditionT, typename SizeVectorT,
924
- typename DeviceAllocatorT>
925
- void SortPaths(PathVectorT* paths, const SizeVectorT& bin_map) {
926
- auto d_bin_map = bin_map.data();
927
- DeviceAllocatorT alloc;
928
- thrust::sort(thrust::cuda::par(alloc), paths->begin(), paths->end(),
929
- [=] __device__(const PathElement<SplitConditionT>& a,
930
- const PathElement<SplitConditionT>& b) {
931
- size_t a_bin = d_bin_map[a.path_idx];
932
- size_t b_bin = d_bin_map[b.path_idx];
933
- if (a_bin < b_bin) return true;
934
- if (b_bin < a_bin) return false;
935
-
936
- if (a.path_idx < b.path_idx) return true;
937
- if (b.path_idx < a.path_idx) return false;
938
-
939
- if (a.feature_idx < b.feature_idx) return true;
940
- if (b.feature_idx < a.feature_idx) return false;
941
- return false;
942
- });
943
- }
944
-
945
- using kv = std::pair<size_t, int>;
946
-
947
- struct BFDCompare {
948
- bool operator()(const kv& lhs, const kv& rhs) const {
949
- if (lhs.second == rhs.second) {
950
- return lhs.first < rhs.first;
951
- }
952
- return lhs.second < rhs.second;
953
- }
954
- };
955
-
956
- // Best Fit Decreasing bin packing
957
- // Efficient O(nlogn) implementation with balanced tree using std::set
958
- template <typename IntVectorT>
959
- std::vector<size_t> BFDBinPacking(const IntVectorT& counts,
960
- int bin_limit = 32) {
961
- thrust::host_vector<int> counts_host(counts);
962
- std::vector<kv> path_lengths(counts_host.size());
963
- for (auto i = 0ull; i < counts_host.size(); i++) {
964
- path_lengths[i] = {i, counts_host[i]};
965
- }
966
-
967
- std::sort(path_lengths.begin(), path_lengths.end(),
968
- [&](const kv& a, const kv& b) {
969
- std::greater<> op;
970
- return op(a.second, b.second);
971
- });
972
-
973
- // map unique_id -> bin
974
- std::vector<size_t> bin_map(counts_host.size());
975
- std::set<kv, BFDCompare> bin_capacities;
976
- bin_capacities.insert({bin_capacities.size(), bin_limit});
977
- for (auto pair : path_lengths) {
978
- int new_size = pair.second;
979
- auto itr = bin_capacities.lower_bound({0, new_size});
980
- // Does not fit in any bin
981
- if (itr == bin_capacities.end()) {
982
- size_t new_bin_idx = bin_capacities.size();
983
- bin_capacities.insert({new_bin_idx, bin_limit - new_size});
984
- bin_map[pair.first] = new_bin_idx;
985
- } else {
986
- kv entry = *itr;
987
- entry.second -= new_size;
988
- bin_map[pair.first] = entry.first;
989
- bin_capacities.erase(itr);
990
- bin_capacities.insert(entry);
991
- }
992
- }
993
-
994
- return bin_map;
995
- }
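// (Editor's sketch, not part of the deleted file.) A host-only, thrust-free
// stand-in for BFDBinPacking above, included to make the packing rule concrete:
// paths are sorted by decreasing length, and each path goes into the fullest
// bin that can still hold it, so every bin stays within the 32-thread warp
// limit. The BestFitDecreasing name is illustrative.
#include <algorithm>
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

std::vector<std::size_t> BestFitDecreasing(const std::vector<int> &counts,
                                           int bin_limit = 32) {
  using kv = std::pair<std::size_t, int>;  // (bin id, remaining capacity)
  auto by_capacity = [](const kv &a, const kv &b) {
    return a.second == b.second ? a.first < b.first : a.second < b.second;
  };
  std::vector<kv> lengths;
  for (std::size_t i = 0; i < counts.size(); i++) lengths.push_back({i, counts[i]});
  std::sort(lengths.begin(), lengths.end(),
            [](const kv &a, const kv &b) { return a.second > b.second; });
  std::set<kv, decltype(by_capacity)> bins(by_capacity);
  bins.insert({0, bin_limit});
  std::vector<std::size_t> bin_map(counts.size());
  for (const kv &p : lengths) {
    auto itr = bins.lower_bound({0, p.second});  // smallest capacity that still fits
    if (itr == bins.end()) {                     // fits nowhere: open a new bin
      bin_map[p.first] = bins.size();
      bins.insert({bins.size(), bin_limit - p.second});
    } else {
      kv bin = *itr;
      bins.erase(itr);
      bin.second -= p.second;
      bin_map[p.first] = bin.first;
      bins.insert(bin);
    }
  }
  return bin_map;
}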
996
-
997
- // First Fit Decreasing bin packing
998
- // Inefficient O(n^2) implementation
999
- template <typename IntVectorT>
1000
- std::vector<size_t> FFDBinPacking(const IntVectorT& counts,
1001
- int bin_limit = 32) {
1002
- thrust::host_vector<int> counts_host(counts);
1003
- std::vector<kv> path_lengths(counts_host.size());
1004
- for (auto i = 0ull; i < counts_host.size(); i++) {
1005
- path_lengths[i] = {i, counts_host[i]};
1006
- }
1007
- std::sort(path_lengths.begin(), path_lengths.end(),
1008
- [&](const kv& a, const kv& b) {
1009
- std::greater<> op;
1010
- return op(a.second, b.second);
1011
- });
1012
-
1013
- // map unique_id -> bin
1014
- std::vector<size_t> bin_map(counts_host.size());
1015
- std::vector<int> bin_capacities(path_lengths.size(), bin_limit);
1016
- for (auto pair : path_lengths) {
1017
- int new_size = pair.second;
1018
- for (auto j = 0ull; j < bin_capacities.size(); j++) {
1019
- int& capacity = bin_capacities[j];
1020
-
1021
- if (capacity >= new_size) {
1022
- capacity -= new_size;
1023
- bin_map[pair.first] = j;
1024
- break;
1025
- }
1026
- }
1027
- }
1028
-
1029
- return bin_map;
1030
- }
1031
-
1032
- // Next Fit bin packing
1033
- // O(n) implementation
1034
- template <typename IntVectorT>
1035
- std::vector<size_t> NFBinPacking(const IntVectorT& counts, int bin_limit = 32) {
1036
- thrust::host_vector<int> counts_host(counts);
1037
- std::vector<size_t> bin_map(counts_host.size());
1038
- size_t current_bin = 0;
1039
- int current_capacity = bin_limit;
1040
- for (auto i = 0ull; i < counts_host.size(); i++) {
1041
- int new_size = counts_host[i];
1042
- size_t path_idx = i;
1043
- if (new_size <= current_capacity) {
1044
- current_capacity -= new_size;
1045
- bin_map[path_idx] = current_bin;
1046
- } else {
1047
- current_capacity = bin_limit - new_size;
1048
- bin_map[path_idx] = ++current_bin;
1049
- }
1050
- }
1051
- return bin_map;
1052
- }
1053
-
1054
- template <typename DeviceAllocatorT, typename SplitConditionT,
1055
- typename PathVectorT, typename LengthVectorT>
1056
- void GetPathLengths(const PathVectorT& device_paths,
1057
- LengthVectorT* path_lengths) {
1058
- path_lengths->resize(
1059
- static_cast<PathElement<SplitConditionT>>(device_paths.back()).path_idx +
1060
- 1,
1061
- 0);
1062
- auto counting = thrust::make_counting_iterator(0llu);
1063
- auto d_paths = device_paths.data().get();
1064
- auto d_lengths = path_lengths->data().get();
1065
- thrust::for_each_n(counting, device_paths.size(), [=] __device__(size_t idx) {
1066
- auto path_idx = d_paths[idx].path_idx;
1067
- atomicAdd(d_lengths + path_idx, 1ull);
1068
- });
1069
- }
1070
-
1071
- struct PathTooLongOp {
1072
- __device__ size_t operator()(size_t length) { return length > 32; }
1073
- };
1074
-
1075
- template <typename SplitConditionT>
1076
- struct IncorrectVOp {
1077
- const PathElement<SplitConditionT>* paths;
1078
- __device__ size_t operator()(size_t idx) {
1079
- auto a = paths[idx - 1];
1080
- auto b = paths[idx];
1081
- return a.path_idx == b.path_idx && a.v != b.v;
1082
- }
1083
- };
1084
-
1085
- template <typename DeviceAllocatorT, typename SplitConditionT,
1086
- typename PathVectorT, typename LengthVectorT>
1087
- void ValidatePaths(const PathVectorT& device_paths,
1088
- const LengthVectorT& path_lengths) {
1089
- DeviceAllocatorT alloc;
1090
- PathTooLongOp too_long_op;
1091
- auto invalid_length =
1092
- thrust::any_of(thrust::cuda::par(alloc), path_lengths.begin(),
1093
- path_lengths.end(), too_long_op);
1094
-
1095
- if (invalid_length) {
1096
- throw std::invalid_argument("Tree depth must be < 32");
1097
- }
1098
-
1099
- IncorrectVOp<SplitConditionT> incorrect_v_op{device_paths.data().get()};
1100
- auto counting = thrust::counting_iterator<size_t>(0);
1101
- auto incorrect_v =
1102
- thrust::any_of(thrust::cuda::par(alloc), counting + 1,
1103
- counting + device_paths.size(), incorrect_v_op);
1104
-
1105
- if (incorrect_v) {
1106
- throw std::invalid_argument(
1107
- "Leaf value v should be the same across a single path");
1108
- }
1109
- }
1110
-
1111
- template <typename DeviceAllocatorT, typename SplitConditionT,
1112
- typename PathVectorT, typename SizeVectorT>
1113
- void PreprocessPaths(PathVectorT* device_paths, PathVectorT* deduplicated_paths,
1114
- SizeVectorT* bin_segments) {
1115
- // Sort paths by length and feature
1116
- detail::DeduplicatePaths<PathVectorT, DeviceAllocatorT, SplitConditionT>(
1117
- device_paths, deduplicated_paths);
1118
- using int_vector = RebindVector<int, DeviceAllocatorT>;
1119
- int_vector path_lengths;
1120
- detail::GetPathLengths<DeviceAllocatorT, SplitConditionT>(*deduplicated_paths,
1121
- &path_lengths);
1122
- SizeVectorT device_bin_map = detail::BFDBinPacking(path_lengths);
1123
- ValidatePaths<DeviceAllocatorT, SplitConditionT>(*deduplicated_paths,
1124
- path_lengths);
1125
- detail::SortPaths<PathVectorT, SplitConditionT, SizeVectorT,
1126
- DeviceAllocatorT>(deduplicated_paths, device_bin_map);
1127
- detail::GetBinSegments<PathVectorT, SizeVectorT, DeviceAllocatorT>(
1128
- *deduplicated_paths, device_bin_map, bin_segments);
1129
- }
1130
-
1131
- struct PathIdxTransformOp {
1132
- template <typename SplitConditionT>
1133
- __device__ size_t operator()(const PathElement<SplitConditionT>& e) {
1134
- return e.path_idx;
1135
- }
1136
- };
1137
-
1138
- struct GroupIdxTransformOp {
1139
- template <typename SplitConditionT>
1140
- __device__ size_t operator()(const PathElement<SplitConditionT>& e) {
1141
- return e.group;
1142
- }
1143
- };
1144
-
1145
- struct BiasTransformOp {
1146
- template <typename SplitConditionT>
1147
- __device__ double operator()(const PathElement<SplitConditionT>& e) {
1148
- return e.zero_fraction * e.v;
1149
- }
1150
- };
1151
-
1152
- // While it is possible to compute bias in the primary kernel, we do it here
1153
- // using double precision to avoid numerical stability issues
1154
- template <typename PathVectorT, typename DoubleVectorT,
1155
- typename DeviceAllocatorT, typename SplitConditionT>
1156
- void ComputeBias(const PathVectorT& device_paths, DoubleVectorT* bias) {
1157
- using double_vector = thrust::device_vector<
1158
- double, typename DeviceAllocatorT::template rebind<double>::other>;
1159
- PathVectorT sorted_paths(device_paths);
1160
- DeviceAllocatorT alloc;
1161
- // Make sure groups are contiguous
1162
- thrust::sort(thrust::cuda::par(alloc), sorted_paths.begin(),
1163
- sorted_paths.end(),
1164
- [=] __device__(const PathElement<SplitConditionT>& a,
1165
- const PathElement<SplitConditionT>& b) {
1166
- if (a.group < b.group) return true;
1167
- if (b.group < a.group) return false;
1168
-
1169
- if (a.path_idx < b.path_idx) return true;
1170
- if (b.path_idx < a.path_idx) return false;
1171
-
1172
- return false;
1173
- });
1174
- // Combine zero fraction for all paths
1175
- auto path_key = thrust::make_transform_iterator(sorted_paths.begin(),
1176
- PathIdxTransformOp());
1177
- PathVectorT combined(sorted_paths.size());
1178
- auto combined_out = thrust::reduce_by_key(
1179
- thrust::cuda::par(alloc), path_key, path_key + sorted_paths.size(),
1180
- sorted_paths.begin(), thrust::make_discard_iterator(), combined.begin(),
1181
- thrust::equal_to<size_t>(),
1182
- [=] __device__(PathElement<SplitConditionT> a,
1183
- const PathElement<SplitConditionT>& b) {
1184
- a.zero_fraction *= b.zero_fraction;
1185
- return a;
1186
- });
1187
- size_t num_paths = combined_out.second - combined.begin();
1188
- // Combine bias for each path, over each group
1189
- using size_vector = thrust::device_vector<
1190
- size_t, typename DeviceAllocatorT::template rebind<size_t>::other>;
1191
- size_vector keys_out(num_paths);
1192
- double_vector values_out(num_paths);
1193
- auto group_key =
1194
- thrust::make_transform_iterator(combined.begin(), GroupIdxTransformOp());
1195
- auto values =
1196
- thrust::make_transform_iterator(combined.begin(), BiasTransformOp());
1197
-
1198
- auto out_itr = thrust::reduce_by_key(thrust::cuda::par(alloc), group_key,
1199
- group_key + num_paths, values,
1200
- keys_out.begin(), values_out.begin());
1201
-
1202
- // Write result
1203
- size_t n = out_itr.first - keys_out.begin();
1204
- auto counting = thrust::make_counting_iterator(0llu);
1205
- auto d_keys_out = keys_out.data().get();
1206
- auto d_values_out = values_out.data().get();
1207
- auto d_bias = bias->data().get();
1208
- thrust::for_each_n(counting, n, [=] __device__(size_t idx) {
1209
- d_bias[d_keys_out[idx]] = d_values_out[idx];
1210
- });
1211
- }
1212
-
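ComputeBias above reduces the per-group bias in double precision: elements of each path are combined by multiplying their zero_fractions, the product is scaled by the path's leaf value v, and the results are summed per output group. A host-side sketch of the same arithmetic on a toy element list (the ToyElement layout is an illustration, not the PathElement layout used by this header):

#include <cstddef>
#include <iostream>
#include <vector>

// One toy path element: owning path, output group of that path, its
// zero_fraction, and the leaf value v shared by the whole path.
struct ToyElement {
  size_t path_idx;
  size_t group;
  double zero_fraction;
  double v;
};

// bias[g] = sum over paths feeding group g of (product of zero_fractions) * v.
// Assumes elements of a path are contiguous; the device version sorts first.
std::vector<double> ComputeBiasHost(const std::vector<ToyElement>& elems,
                                    size_t num_groups) {
  std::vector<double> bias(num_groups, 0.0);
  size_t i = 0;
  while (i < elems.size()) {
    size_t j = i;
    double zero_product = 1.0;
    while (j < elems.size() && elems[j].path_idx == elems[i].path_idx) {
      zero_product *= elems[j].zero_fraction;
      ++j;
    }
    bias[elems[i].group] += zero_product * elems[i].v;
    i = j;
  }
  return bias;
}

int main() {
  // Two paths feeding a single output group.
  std::vector<ToyElement> elems{
      {0, 0, 1.0, 2.0}, {0, 0, 0.5, 2.0},  // path 0: 1.0 * 0.5 * 2.0 = 1.0
      {1, 0, 0.25, 4.0},                   // path 1: 0.25 * 4.0 = 1.0
  };
  std::cout << ComputeBiasHost(elems, 1)[0] << "\n";  // prints 2
}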
1213
- } // namespace detail
1214
-
1215
- /*!
1216
- * Compute feature contributions on the GPU given a set of unique paths through
1217
- * a tree ensemble and a dataset. Uses device memory proportional to the tree
1218
- * ensemble size.
1219
- *
1220
- * \exception std::invalid_argument Thrown when an invalid argument error
1221
- * condition occurs. \tparam PathIteratorT Thrust type iterator, may be
1222
- * thrust::device_ptr for device memory, or stl iterator/raw pointer for host
1223
- * memory. \tparam PhiIteratorT Thrust type iterator, may be
1224
- * thrust::device_ptr for device memory, or stl iterator/raw pointer for host
1225
- * memory. Value type must be floating point. \tparam DatasetT User-specified
1226
- * dataset container. \tparam DeviceAllocatorT Optional thrust style
1227
- * allocator.
1228
- *
1229
- * \param X Thin wrapper over a dataset allocated in device memory. X
1230
- * should be trivially copyable as a kernel parameter (i.e. contain only
1231
- * pointers to actual data) and must implement the methods
1232
- * NumRows()/NumCols()/GetElement(size_t row_idx, size_t col_idx) as __device__
1233
- * functions. GetElement may return NaN where the feature value is missing.
1234
- * \param begin Iterator to paths, where separate paths are delineated by
1235
- * PathElement.path_idx. Each unique path should contain 1
1236
- * root with feature_idx = -1 and zero_fraction = 1.0. The ordering of path
1237
- * elements inside a unique path does not matter - the result will be the same.
1238
- * Paths may contain duplicate features. See the PathElement class for more
1239
- * information. \param end Path end iterator. \param num_groups Number
1240
- * of output groups. In multiclass classification the algorithm outputs feature
1241
- * contributions per output class. \param phis_begin Begin iterator for output
1242
- * phis. \param phis_end End iterator for output phis.
1243
- */
1244
- template <typename DeviceAllocatorT = thrust::device_allocator<int>,
1245
- typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
1246
- void GPUTreeShap(DatasetT X, PathIteratorT begin, PathIteratorT end,
1247
- size_t num_groups, PhiIteratorT phis_begin,
1248
- PhiIteratorT phis_end) {
1249
- if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
1250
-
1251
- if (size_t(phis_end - phis_begin) <
1252
- X.NumRows() * (X.NumCols() + 1) * num_groups) {
1253
- throw std::invalid_argument(
1254
- "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
1255
- "num_groups");
1256
- }
1257
-
1258
- using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
1259
- using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
1260
- using path_vector = detail::RebindVector<
1261
- typename std::iterator_traits<PathIteratorT>::value_type,
1262
- DeviceAllocatorT>;
1263
- using split_condition =
1264
- typename std::iterator_traits<PathIteratorT>::value_type::split_type;
1265
-
1266
- // Compute the global bias
1267
- double_vector temp_phi(phis_end - phis_begin, 0.0);
1268
- path_vector device_paths(begin, end);
1269
- double_vector bias(num_groups, 0.0);
1270
- detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT,
1271
- split_condition>(device_paths, &bias);
1272
- auto d_bias = bias.data().get();
1273
- auto d_temp_phi = temp_phi.data().get();
1274
- thrust::for_each_n(thrust::make_counting_iterator(0llu),
1275
- X.NumRows() * num_groups, [=] __device__(size_t idx) {
1276
- size_t group = idx % num_groups;
1277
- size_t row_idx = idx / num_groups;
1278
- d_temp_phi[IndexPhi(row_idx, num_groups, group,
1279
- X.NumCols(), X.NumCols())] +=
1280
- d_bias[group];
1281
- });
1282
-
1283
- path_vector deduplicated_paths;
1284
- size_vector device_bin_segments;
1285
- detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
1286
- &device_paths, &deduplicated_paths, &device_bin_segments);
1287
-
1288
- detail::ComputeShap(X, device_bin_segments, deduplicated_paths, num_groups,
1289
- temp_phi.data().get());
1290
- thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
1291
- }
1292
-
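The size check above is the sizing contract GPUTreeShap states for callers: the output range must hold at least NumRows() * (NumCols() + 1) * num_groups values, and the interaction variants further below require NumRows() * (NumCols() + 1) * (NumCols() + 1) * num_groups. A small host-side helper that mirrors those checks before allocating an output buffer (the helper names are ours, not part of this header):

#include <cstddef>
#include <iostream>

// Minimum phi length for GPUTreeShap: one slot per (row, feature-or-bias,
// group) triple; the extra column holds the bias term.
size_t RequiredPhiSize(size_t num_rows, size_t num_cols, size_t num_groups) {
  return num_rows * (num_cols + 1) * num_groups;
}

// Minimum phi length for the interaction variants: one slot per
// (row, feature-or-bias, feature-or-bias, group) tuple.
size_t RequiredInteractionPhiSize(size_t num_rows, size_t num_cols,
                                  size_t num_groups) {
  return num_rows * (num_cols + 1) * (num_cols + 1) * num_groups;
}

int main() {
  // e.g. 1000 rows, 20 features, a single output group.
  std::cout << RequiredPhiSize(1000, 20, 1) << "\n";             // 21000
  std::cout << RequiredInteractionPhiSize(1000, 20, 1) << "\n";  // 441000
}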
1293
- /*!
1294
- * Compute feature interaction contributions on the GPU given a set of unique
1295
- * paths through a tree ensemble and a dataset. Uses device memory
1296
- * proportional to the tree ensemble size.
1297
- *
1298
- * \exception std::invalid_argument Thrown when an invalid argument error
1299
- * condition occurs.
1300
- * \tparam DeviceAllocatorT Optional thrust style allocator.
1301
- * \tparam DatasetT User-specified dataset container.
1302
- * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr
1303
- * for device memory, or stl iterator/raw pointer for
1304
- * host memory.
1305
- * \tparam PhiIteratorT Thrust type iterator, may be thrust::device_ptr
1306
- * for device memory, or stl iterator/raw pointer for
1307
- * host memory. Value type must be floating point.
1308
- *
1309
- * \param X Thin wrapper over a dataset allocated in device memory. X
1310
- * should be trivially copyable as a kernel parameter (i.e.
1311
- * contain only pointers to actual data) and must implement
1312
- * the methods NumRows()/NumCols()/GetElement(size_t row_idx,
1313
- * size_t col_idx) as __device__ functions. GetElement may
1314
- * return NaN where the feature value is missing.
1315
- * \param begin Iterator to paths, where separate paths are delineated by
1316
- * PathElement.path_idx. Each unique path should contain 1
1317
- * root with feature_idx = -1 and zero_fraction = 1.0. The
1318
- * ordering of path elements inside a unique path does not
1319
- * matter - the result will be the same. Paths may contain
1320
- * duplicate features. See the PathElement class for more
1321
- * information.
1322
- * \param end Path end iterator.
1323
- * \param num_groups Number of output groups. In multiclass classification the
1324
- * algorithm outputs feature contributions per output class.
1325
- * \param phis_begin Begin iterator for output phis.
1326
- * \param phis_end End iterator for output phis.
1327
- */
1328
- template <typename DeviceAllocatorT = thrust::device_allocator<int>,
1329
- typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
1330
- void GPUTreeShapInteractions(DatasetT X, PathIteratorT begin, PathIteratorT end,
1331
- size_t num_groups, PhiIteratorT phis_begin,
1332
- PhiIteratorT phis_end) {
1333
- if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
1334
- if (size_t(phis_end - phis_begin) <
1335
- X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) {
1336
- throw std::invalid_argument(
1337
- "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
1338
- "(X.NumCols() + 1) * "
1339
- "num_groups");
1340
- }
1341
-
1342
- using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
1343
- using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
1344
- using path_vector = detail::RebindVector<
1345
- typename std::iterator_traits<PathIteratorT>::value_type,
1346
- DeviceAllocatorT>;
1347
- using split_condition =
1348
- typename std::iterator_traits<PathIteratorT>::value_type::split_type;
1349
-
1350
- // Compute the global bias
1351
- double_vector temp_phi(phis_end - phis_begin, 0.0);
1352
- path_vector device_paths(begin, end);
1353
- double_vector bias(num_groups, 0.0);
1354
- detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT,
1355
- split_condition>(device_paths, &bias);
1356
- auto d_bias = bias.data().get();
1357
- auto d_temp_phi = temp_phi.data().get();
1358
- thrust::for_each_n(
1359
- thrust::make_counting_iterator(0llu), X.NumRows() * num_groups,
1360
- [=] __device__(size_t idx) {
1361
- size_t group = idx % num_groups;
1362
- size_t row_idx = idx / num_groups;
1363
- d_temp_phi[IndexPhiInteractions(row_idx, num_groups, group, X.NumCols(),
1364
- X.NumCols(), X.NumCols())] +=
1365
- d_bias[group];
1366
- });
1367
-
1368
- path_vector deduplicated_paths;
1369
- size_vector device_bin_segments;
1370
- detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
1371
- &device_paths, &deduplicated_paths, &device_bin_segments);
1372
-
1373
- detail::ComputeShapInteractions(X, device_bin_segments, deduplicated_paths,
1374
- num_groups, temp_phi.data().get());
1375
- thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
1376
- }
1377
-
1378
- /*!
1379
- * Compute feature interaction contributions using the Shapley Taylor index on
1380
- * the GPU, given a set of unique paths through a tree ensemble and a dataset.
1381
- * Uses device memory proportional to the tree ensemble size.
1382
- *
1383
- * \exception std::invalid_argument Thrown when an invalid argument error
1384
- * condition occurs.
1385
- * \tparam PhiIteratorT Thrust type iterator, may be thrust::device_ptr
1386
- * for device memory, or stl iterator/raw pointer for
1387
- * host memory. Value type must be floating point.
1388
- * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr
1389
- * for device memory, or stl iterator/raw pointer for
1390
- * host memory.
1391
- * \tparam DatasetT User-specified dataset container.
1392
- * \tparam DeviceAllocatorT Optional thrust style allocator.
1393
- *
1394
- * \param X Thin wrapper over a dataset allocated in device memory. X
1395
- * should be trivially copyable as a kernel parameter (i.e.
1396
- * contain only pointers to actual data) and must implement
1397
- * the methods NumRows()/NumCols()/GetElement(size_t row_idx,
1398
- * size_t col_idx) as __device__ functions. GetElement may
1399
- * return NaN where the feature value is missing.
1400
- * \param begin Iterator to paths, where separate paths are delineated by
1401
- * PathElement.path_idx. Each unique path should contain 1
1402
- * root with feature_idx = -1 and zero_fraction = 1.0. The
1403
- * ordering of path elements inside a unique path does not
1404
- * matter - the result will be the same. Paths may contain
1405
- * duplicate features. See the PathElement class for more
1406
- * information.
1407
- * \param end Path end iterator.
1408
- * \param num_groups Number of output groups. In multiclass classification the
1409
- * algorithm outputs feature contributions per output class.
1410
- * \param phis_begin Begin iterator for output phis.
1411
- * \param phis_end End iterator for output phis.
1412
- */
1413
- template <typename DeviceAllocatorT = thrust::device_allocator<int>,
1414
- typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
1415
- void GPUTreeShapTaylorInteractions(DatasetT X, PathIteratorT begin,
1416
- PathIteratorT end, size_t num_groups,
1417
- PhiIteratorT phis_begin,
1418
- PhiIteratorT phis_end) {
1419
- using phis_type = typename std::iterator_traits<PhiIteratorT>::value_type;
1420
- static_assert(std::is_floating_point<phis_type>::value,
1421
- "Phis type must be floating point");
1422
-
1423
- if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
1424
-
1425
- if (size_t(phis_end - phis_begin) <
1426
- X.NumRows() * (X.NumCols() + 1) * (X.NumCols() + 1) * num_groups) {
1427
- throw std::invalid_argument(
1428
- "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
1429
- "(X.NumCols() + 1) * "
1430
- "num_groups");
1431
- }
1432
-
1433
- using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
1434
- using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
1435
- using path_vector = detail::RebindVector<
1436
- typename std::iterator_traits<PathIteratorT>::value_type,
1437
- DeviceAllocatorT>;
1438
- using split_condition =
1439
- typename std::iterator_traits<PathIteratorT>::value_type::split_type;
1440
-
1441
- // Compute the global bias
1442
- double_vector temp_phi(phis_end - phis_begin, 0.0);
1443
- path_vector device_paths(begin, end);
1444
- double_vector bias(num_groups, 0.0);
1445
- detail::ComputeBias<path_vector, double_vector, DeviceAllocatorT,
1446
- split_condition>(device_paths, &bias);
1447
- auto d_bias = bias.data().get();
1448
- auto d_temp_phi = temp_phi.data().get();
1449
- thrust::for_each_n(
1450
- thrust::make_counting_iterator(0llu), X.NumRows() * num_groups,
1451
- [=] __device__(size_t idx) {
1452
- size_t group = idx % num_groups;
1453
- size_t row_idx = idx / num_groups;
1454
- d_temp_phi[IndexPhiInteractions(row_idx, num_groups, group, X.NumCols(),
1455
- X.NumCols(), X.NumCols())] +=
1456
- d_bias[group];
1457
- });
1458
-
1459
- path_vector deduplicated_paths;
1460
- size_vector device_bin_segments;
1461
- detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
1462
- &device_paths, &deduplicated_paths, &device_bin_segments);
1463
-
1464
- detail::ComputeShapTaylorInteractions(X, device_bin_segments,
1465
- deduplicated_paths, num_groups,
1466
- temp_phi.data().get());
1467
- thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
1468
- }
1469
-
1470
- /*!
1471
- * Compute feature contributions on the GPU given a set of unique paths through a tree ensemble
1472
- * and a dataset. Uses device memory proportional to the tree ensemble size. This variant
1473
- * implements the interventional tree shap algorithm described here:
1474
- * https://drafts.distill.pub/HughChen/its_blog/
1475
- *
1476
- * It requires a background dataset R.
1477
- *
1478
- * \exception std::invalid_argument Thrown when an invalid argument error condition occurs.
1479
- * \tparam DeviceAllocatorT Optional thrust style allocator.
1480
- * \tparam DatasetT User-specified dataset container.
1481
- * \tparam PathIteratorT Thrust type iterator, may be thrust::device_ptr for device memory, or
1482
- * stl iterator/raw pointer for host memory.
1483
- *
1484
- * \param X Thin wrapper over a dataset allocated in device memory. X should be trivially
1485
- * copyable as a kernel parameter (i.e. contain only pointers to actual data) and
1486
- * must implement the methods NumRows()/NumCols()/GetElement(size_t row_idx,
1487
- * size_t col_idx) as __device__ functions. GetElement may return NaN where the
1488
- * feature value is missing.
1489
- * \param R Background dataset.
1490
- * \param begin Iterator to paths, where separate paths are delineated by
1491
- * PathElement.path_idx. Each unique path should contain 1 root with feature_idx =
1492
- * -1 and zero_fraction = 1.0. The ordering of path elements inside a unique path
1493
- * does not matter - the result will be the same. Paths may contain duplicate
1494
- * features. See the PathElement class for more information.
1495
- * \param end Path end iterator.
1496
- * \param num_groups Number of output groups. In multiclass classification the algorithm outputs
1497
- * feature contributions per output class.
1498
- * \param phis_begin Begin iterator for output phis.
1499
- * \param phis_end End iterator for output phis.
1500
- */
1501
- template <typename DeviceAllocatorT = thrust::device_allocator<int>,
1502
- typename DatasetT, typename PathIteratorT, typename PhiIteratorT>
1503
- void GPUTreeShapInterventional(DatasetT X, DatasetT R, PathIteratorT begin,
1504
- PathIteratorT end, size_t num_groups,
1505
- PhiIteratorT phis_begin, PhiIteratorT phis_end) {
1506
- if (X.NumRows() == 0 || X.NumCols() == 0 || end - begin <= 0) return;
1507
-
1508
- if (size_t(phis_end - phis_begin) <
1509
- X.NumRows() * (X.NumCols() + 1) * num_groups) {
1510
- throw std::invalid_argument(
1511
- "phis_out must be at least of size X.NumRows() * (X.NumCols() + 1) * "
1512
- "num_groups");
1513
- }
1514
-
1515
- using size_vector = detail::RebindVector<size_t, DeviceAllocatorT>;
1516
- using double_vector = detail::RebindVector<double, DeviceAllocatorT>;
1517
- using path_vector = detail::RebindVector<
1518
- typename std::iterator_traits<PathIteratorT>::value_type,
1519
- DeviceAllocatorT>;
1520
- using split_condition =
1521
- typename std::iterator_traits<PathIteratorT>::value_type::split_type;
1522
-
1523
- double_vector temp_phi(phis_end - phis_begin, 0.0);
1524
- path_vector device_paths(begin, end);
1525
-
1526
- path_vector deduplicated_paths;
1527
- size_vector device_bin_segments;
1528
- detail::PreprocessPaths<DeviceAllocatorT, split_condition>(
1529
- &device_paths, &deduplicated_paths, &device_bin_segments);
1530
- detail::ComputeShapInterventional(X, R, device_bin_segments,
1531
- deduplicated_paths, num_groups,
1532
- temp_phi.data().get());
1533
- thrust::copy(temp_phi.begin(), temp_phi.end(), phis_begin);
1534
- }
1535
- } // namespace gpu_treeshap
lib/shap/cext/tree_shap.h DELETED
@@ -1,1460 +0,0 @@
1
- /**
2
- * Fast recursive computation of SHAP values in trees.
3
- * See https://arxiv.org/abs/1802.03888 for details.
4
- *
5
- * Scott Lundberg, 2018 (independent algorithm courtesy of Hugh Chen 2018)
6
- */
7
-
8
- #include <algorithm>
9
- #include <iostream>
10
- #include <fstream>
11
- #include <stdio.h>
12
- #include <cmath>
13
- #include <ctime>
14
- #if defined(_WIN32) || defined(WIN32)
15
- #include <malloc.h>
16
- #elif defined(__MVS__)
17
- #include <stdlib.h>
18
- #else
19
- #include <alloca.h>
20
- #endif
21
- using namespace std;
22
-
23
- typedef double tfloat;
24
- typedef tfloat (* transform_f)(const tfloat margin, const tfloat y);
25
-
26
- namespace FEATURE_DEPENDENCE {
27
- const unsigned independent = 0;
28
- const unsigned tree_path_dependent = 1;
29
- const unsigned global_path_dependent = 2;
30
- }
31
-
32
- struct TreeEnsemble {
33
- int *children_left;
34
- int *children_right;
35
- int *children_default;
36
- int *features;
37
- tfloat *thresholds;
38
- tfloat *values;
39
- tfloat *node_sample_weights;
40
- unsigned max_depth;
41
- unsigned tree_limit;
42
- tfloat *base_offset;
43
- unsigned max_nodes;
44
- unsigned num_outputs;
45
-
46
- TreeEnsemble() {}
47
- TreeEnsemble(int *children_left, int *children_right, int *children_default, int *features,
48
- tfloat *thresholds, tfloat *values, tfloat *node_sample_weights,
49
- unsigned max_depth, unsigned tree_limit, tfloat *base_offset,
50
- unsigned max_nodes, unsigned num_outputs) :
51
- children_left(children_left), children_right(children_right),
52
- children_default(children_default), features(features), thresholds(thresholds),
53
- values(values), node_sample_weights(node_sample_weights),
54
- max_depth(max_depth), tree_limit(tree_limit),
55
- base_offset(base_offset), max_nodes(max_nodes), num_outputs(num_outputs) {}
56
-
57
- void get_tree(TreeEnsemble &tree, const unsigned i) const {
58
- const unsigned d = i * max_nodes;
59
-
60
- tree.children_left = children_left + d;
61
- tree.children_right = children_right + d;
62
- tree.children_default = children_default + d;
63
- tree.features = features + d;
64
- tree.thresholds = thresholds + d;
65
- tree.values = values + d * num_outputs;
66
- tree.node_sample_weights = node_sample_weights + d;
67
- tree.max_depth = max_depth;
68
- tree.tree_limit = 1;
69
- tree.base_offset = base_offset;
70
- tree.max_nodes = max_nodes;
71
- tree.num_outputs = num_outputs;
72
- }
73
-
74
- bool is_leaf(unsigned pos)const {
75
- return children_left[pos] < 0;
76
- }
77
-
78
- void allocate(unsigned tree_limit_in, unsigned max_nodes_in, unsigned num_outputs_in) {
79
- tree_limit = tree_limit_in;
80
- max_nodes = max_nodes_in;
81
- num_outputs = num_outputs_in;
82
- children_left = new int[tree_limit * max_nodes];
83
- children_right = new int[tree_limit * max_nodes];
84
- children_default = new int[tree_limit * max_nodes];
85
- features = new int[tree_limit * max_nodes];
86
- thresholds = new tfloat[tree_limit * max_nodes];
87
- values = new tfloat[tree_limit * max_nodes * num_outputs];
88
- node_sample_weights = new tfloat[tree_limit * max_nodes];
89
- }
90
-
91
- void free() {
92
- delete[] children_left;
93
- delete[] children_right;
94
- delete[] children_default;
95
- delete[] features;
96
- delete[] thresholds;
97
- delete[] values;
98
- delete[] node_sample_weights;
99
- }
100
- };
101
-
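TreeEnsemble is a plain view over flat arrays of length tree_limit * max_nodes (values additionally times num_outputs); allocate() news those arrays, free() deletes them, and base_offset and max_depth are left untouched by allocate(), so the caller pairs the two calls and fills the fields itself. A minimal usage sketch under those assumptions:

// Sketch only: assumes this header (and its TreeEnsemble struct) is included.
void ensemble_buffers_example() {
  TreeEnsemble trees;
  trees.allocate(/*tree_limit_in=*/10, /*max_nodes_in=*/63, /*num_outputs_in=*/1);
  // ... fill children_left / children_right / children_default / features /
  //     thresholds / values / node_sample_weights, plus base_offset, here ...
  trees.free();  // allocate() uses new[], so free() must be called exactly once
}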
102
- struct ExplanationDataset {
103
- tfloat *X;
104
- bool *X_missing;
105
- tfloat *y;
106
- tfloat *R;
107
- bool *R_missing;
108
- unsigned num_X;
109
- unsigned M;
110
- unsigned num_R;
111
-
112
- ExplanationDataset() {}
113
- ExplanationDataset(tfloat *X, bool *X_missing, tfloat *y, tfloat *R, bool *R_missing, unsigned num_X,
114
- unsigned M, unsigned num_R) :
115
- X(X), X_missing(X_missing), y(y), R(R), R_missing(R_missing), num_X(num_X), M(M), num_R(num_R) {}
116
-
117
- void get_x_instance(ExplanationDataset &instance, const unsigned i) const {
118
- instance.M = M;
119
- instance.X = X + i * M;
120
- instance.X_missing = X_missing + i * M;
121
- instance.num_X = 1;
122
- }
123
- };
124
-
125
-
126
- // data we keep about our decision path
127
- // note that pweight is included for convenience and is not tied with the other attributes
128
- // the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
129
- struct PathElement {
130
- int feature_index;
131
- tfloat zero_fraction;
132
- tfloat one_fraction;
133
- tfloat pweight;
134
- PathElement() {}
135
- PathElement(int i, tfloat z, tfloat o, tfloat w) :
136
- feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
137
- };
138
-
139
- inline tfloat logistic_transform(const tfloat margin, const tfloat y) {
140
- return 1 / (1 + exp(-margin));
141
- }
142
-
143
- inline tfloat logistic_nlogloss_transform(const tfloat margin, const tfloat y) {
144
- return log(1 + exp(margin)) - y * margin; // y is in {0, 1}
145
- }
146
-
147
- inline tfloat squared_loss_transform(const tfloat margin, const tfloat y) {
148
- return (margin - y) * (margin - y);
149
- }
150
-
151
- namespace MODEL_TRANSFORM {
152
- const unsigned identity = 0;
153
- const unsigned logistic = 1;
154
- const unsigned logistic_nlogloss = 2;
155
- const unsigned squared_loss = 3;
156
- }
157
-
158
- inline transform_f get_transform(unsigned model_transform) {
159
- transform_f transform = NULL;
160
- switch (model_transform) {
161
- case MODEL_TRANSFORM::logistic:
162
- transform = logistic_transform;
163
- break;
164
-
165
- case MODEL_TRANSFORM::logistic_nlogloss:
166
- transform = logistic_nlogloss_transform;
167
- break;
168
-
169
- case MODEL_TRANSFORM::squared_loss:
170
- transform = squared_loss_transform;
171
- break;
172
- }
173
-
174
- return transform;
175
- }
176
-
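get_transform maps a MODEL_TRANSFORM code to a function pointer and returns NULL for the identity case (and any unknown code), so callers guard before invoking it. A short sketch of that pattern, assuming this header is in scope and using made-up margin and label values:

tfloat apply_optional_transform(unsigned model_transform, tfloat margin, tfloat y) {
  transform_f transform = get_transform(model_transform);
  // NULL means "no transform", i.e. report the raw margin.
  return transform == NULL ? margin : transform(margin, y);
}

void transform_example() {
  std::cout << apply_optional_transform(MODEL_TRANSFORM::identity, 0.3, 0.0) << "\n";  // 0.3
  std::cout << apply_optional_transform(MODEL_TRANSFORM::logistic, 0.0, 0.0) << "\n";  // 0.5
}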
177
- inline tfloat *tree_predict(unsigned i, const TreeEnsemble &trees, const tfloat *x, const bool *x_missing) {
178
- const unsigned offset = i * trees.max_nodes;
179
- unsigned node = 0;
180
- while (true) {
181
- const unsigned pos = offset + node;
182
- const unsigned feature = trees.features[pos];
183
-
184
- // we hit a leaf so return a pointer to the values
185
- if (trees.is_leaf(pos)) {
186
- return trees.values + pos * trees.num_outputs;
187
- }
188
-
189
- // otherwise we are at an internal node and need to recurse
190
- if (x_missing[feature]) {
191
- node = trees.children_default[pos];
192
- } else if (x[feature] <= trees.thresholds[pos]) {
193
- node = trees.children_left[pos];
194
- } else {
195
- node = trees.children_right[pos];
196
- }
197
- }
198
- }
199
-
200
- inline void dense_tree_predict(tfloat *out, const TreeEnsemble &trees, const ExplanationDataset &data, unsigned model_transform) {
201
- tfloat *row_out = out;
202
- const tfloat *x = data.X;
203
- const bool *x_missing = data.X_missing;
204
-
205
- // see what transform (if any) we have
206
- transform_f transform = get_transform(model_transform);
207
-
208
- for (unsigned i = 0; i < data.num_X; ++i) {
209
-
210
- // add the base offset
211
- for (unsigned k = 0; k < trees.num_outputs; ++k) {
212
- row_out[k] += trees.base_offset[k];
213
- }
214
-
215
- // add the leaf values from each tree
216
- for (unsigned j = 0; j < trees.tree_limit; ++j) {
217
- const tfloat *leaf_value = tree_predict(j, trees, x, x_missing);
218
-
219
- for (unsigned k = 0; k < trees.num_outputs; ++k) {
220
- row_out[k] += leaf_value[k];
221
- }
222
- }
223
-
224
- // apply any needed transform
225
- if (transform != NULL) {
226
- const tfloat y_i = data.y == NULL ? 0 : data.y[i];
227
- for (unsigned k = 0; k < trees.num_outputs; ++k) {
228
- row_out[k] = transform(row_out[k], y_i);
229
- }
230
- }
231
-
232
- x += data.M;
233
- x_missing += data.M;
234
- row_out += trees.num_outputs;
235
- }
236
- }
237
-
238
- inline void tree_update_weights(unsigned i, TreeEnsemble &trees, const tfloat *x, const bool *x_missing) {
239
- const unsigned offset = i * trees.max_nodes;
240
- unsigned node = 0;
241
- while (true) {
242
- const unsigned pos = offset + node;
243
- const unsigned feature = trees.features[pos];
244
-
245
- // Record that a sample passed through this node
246
- trees.node_sample_weights[pos] += 1.0;
247
-
248
- // we hit a leaf, so stop descending
249
- if (trees.children_left[pos] < 0) break;
250
-
251
- // otherwise we are at an internal node and need to recurse
252
- if (x_missing[feature]) {
253
- node = trees.children_default[pos];
254
- } else if (x[feature] <= trees.thresholds[pos]) {
255
- node = trees.children_left[pos];
256
- } else {
257
- node = trees.children_right[pos];
258
- }
259
- }
260
- }
261
-
262
- inline void dense_tree_update_weights(TreeEnsemble &trees, const ExplanationDataset &data) {
263
- const tfloat *x = data.X;
264
- const bool *x_missing = data.X_missing;
265
-
266
- for (unsigned i = 0; i < data.num_X; ++i) {
267
-
268
- // add the leaf values from each tree
269
- for (unsigned j = 0; j < trees.tree_limit; ++j) {
270
- tree_update_weights(j, trees, x, x_missing);
271
- }
272
-
273
- x += data.M;
274
- x_missing += data.M;
275
- }
276
- }
277
-
278
- inline void tree_saabas(tfloat *out, const TreeEnsemble &tree, const ExplanationDataset &data) {
279
- unsigned curr_node = 0;
280
- unsigned next_node = 0;
281
- while (true) {
282
-
283
- // we hit a leaf and are done
284
- if (tree.children_left[curr_node] < 0) return;
285
-
286
- // otherwise we are at an internal node and need to recurse
287
- const unsigned feature = tree.features[curr_node];
288
- if (data.X_missing[feature]) {
289
- next_node = tree.children_default[curr_node];
290
- } else if (data.X[feature] <= tree.thresholds[curr_node]) {
291
- next_node = tree.children_left[curr_node];
292
- } else {
293
- next_node = tree.children_right[curr_node];
294
- }
295
-
296
- // assign credit to this feature as the difference in values at the current node vs. the next node
297
- for (unsigned i = 0; i < tree.num_outputs; ++i) {
298
- out[feature * tree.num_outputs + i] += tree.values[next_node * tree.num_outputs + i] - tree.values[curr_node * tree.num_outputs + i];
299
- }
300
-
301
- curr_node = next_node;
302
- }
303
- }
304
-
305
- /**
306
- * This runs Tree SHAP with a per tree path conditional dependence assumption.
307
- */
308
- inline void dense_tree_saabas(tfloat *out_contribs, const TreeEnsemble& trees, const ExplanationDataset &data) {
309
- tfloat *instance_out_contribs;
310
- TreeEnsemble tree;
311
- ExplanationDataset instance;
312
-
313
- // build explanation for each sample
314
- for (unsigned i = 0; i < data.num_X; ++i) {
315
- instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
316
- data.get_x_instance(instance, i);
317
-
318
- // aggregate the effect of explaining each tree
319
- // (this works because of the linearity property of Shapley values)
320
- for (unsigned j = 0; j < trees.tree_limit; ++j) {
321
- trees.get_tree(tree, j);
322
- tree_saabas(instance_out_contribs, tree, instance);
323
- }
324
-
325
- // apply the base offset to the bias term
326
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
327
- instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
328
- }
329
- }
330
- }
331
-
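The per-tree aggregation above leans on the linearity (additivity) of Shapley values: the attribution for a sum of models is the sum of the per-model attributions, which is why each tree can be explained separately into the same output buffer. In symbols (a standard property, not stated as a formula in this header):

\phi_i\!\left(\sum_{t=1}^{T} f_t,\; x\right) = \sum_{t=1}^{T} \phi_i(f_t, x), \qquad i = 1, \dots, M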
332
-
333
- // extend our decision path with a fraction of one and zero extensions
334
- inline void extend_path(PathElement *unique_path, unsigned unique_depth,
335
- tfloat zero_fraction, tfloat one_fraction, int feature_index) {
336
- unique_path[unique_depth].feature_index = feature_index;
337
- unique_path[unique_depth].zero_fraction = zero_fraction;
338
- unique_path[unique_depth].one_fraction = one_fraction;
339
- unique_path[unique_depth].pweight = (unique_depth == 0 ? 1.0f : 0.0f);
340
- for (int i = unique_depth - 1; i >= 0; i--) {
341
- unique_path[i + 1].pweight += one_fraction * unique_path[i].pweight * (i + 1)
342
- / static_cast<tfloat>(unique_depth + 1);
343
- unique_path[i].pweight = zero_fraction * unique_path[i].pweight * (unique_depth - i)
344
- / static_cast<tfloat>(unique_depth + 1);
345
- }
346
- }
347
-
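extend_path appends one (zero_fraction, one_fraction, feature_index) triple to the decision path and redistributes the permutation weights held in pweight; tree_shap below seeds the path with a dummy root element (fractions 1 and 1, feature index -1) before any real split is added. A tiny driver showing that calling convention, with a made-up buffer size and split fractions:

// Sketch only: assumes the PathElement struct and extend_path from this header.
void extend_path_example() {
  PathElement path[4];
  // Root element, as seeded by tree_shap(): fractions 1/1, feature index -1.
  extend_path(path, /*unique_depth=*/0, /*zero_fraction=*/1.0,
              /*one_fraction=*/1.0, /*feature_index=*/-1);
  // First real split: half the training samples follow this branch, and x does too.
  extend_path(path, /*unique_depth=*/1, /*zero_fraction=*/0.5,
              /*one_fraction=*/1.0, /*feature_index=*/2);
  for (int i = 0; i <= 1; ++i) {
    std::cout << path[i].feature_index << " " << path[i].pweight << "\n";
  }
}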
348
- // undo a previous extension of the decision path
349
- inline void unwind_path(PathElement *unique_path, unsigned unique_depth, unsigned path_index) {
350
- const tfloat one_fraction = unique_path[path_index].one_fraction;
351
- const tfloat zero_fraction = unique_path[path_index].zero_fraction;
352
- tfloat next_one_portion = unique_path[unique_depth].pweight;
353
-
354
- for (int i = unique_depth - 1; i >= 0; --i) {
355
- if (one_fraction != 0) {
356
- const tfloat tmp = unique_path[i].pweight;
357
- unique_path[i].pweight = next_one_portion * (unique_depth + 1)
358
- / static_cast<tfloat>((i + 1) * one_fraction);
359
- next_one_portion = tmp - unique_path[i].pweight * zero_fraction * (unique_depth - i)
360
- / static_cast<tfloat>(unique_depth + 1);
361
- } else {
362
- unique_path[i].pweight = (unique_path[i].pweight * (unique_depth + 1))
363
- / static_cast<tfloat>(zero_fraction * (unique_depth - i));
364
- }
365
- }
366
-
367
- for (unsigned i = path_index; i < unique_depth; ++i) {
368
- unique_path[i].feature_index = unique_path[i+1].feature_index;
369
- unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
370
- unique_path[i].one_fraction = unique_path[i+1].one_fraction;
371
- }
372
- }
373
-
374
- // determine what the total permutation weight would be if
375
- // we unwound a previous extension in the decision path
376
- inline tfloat unwound_path_sum(const PathElement *unique_path, unsigned unique_depth,
377
- unsigned path_index) {
378
- const tfloat one_fraction = unique_path[path_index].one_fraction;
379
- const tfloat zero_fraction = unique_path[path_index].zero_fraction;
380
- tfloat next_one_portion = unique_path[unique_depth].pweight;
381
- tfloat total = 0;
382
-
383
- if (one_fraction != 0) {
384
- for (int i = unique_depth - 1; i >= 0; --i) {
385
- const tfloat tmp = next_one_portion / static_cast<tfloat>((i + 1) * one_fraction);
386
- total += tmp;
387
- next_one_portion = unique_path[i].pweight - tmp * zero_fraction * (unique_depth - i);
388
- }
389
- } else {
390
- for (int i = unique_depth - 1; i >= 0; --i) {
391
- total += unique_path[i].pweight / (zero_fraction * (unique_depth - i));
392
- }
393
- }
394
- return total * (unique_depth + 1);
395
- }
396
-
397
- // recursive computation of SHAP values for a decision tree
398
- inline void tree_shap_recursive(const unsigned num_outputs, const int *children_left,
399
- const int *children_right,
400
- const int *children_default, const int *features,
401
- const tfloat *thresholds, const tfloat *values,
402
- const tfloat *node_sample_weight,
403
- const tfloat *x, const bool *x_missing, tfloat *phi,
404
- unsigned node_index, unsigned unique_depth,
405
- PathElement *parent_unique_path, tfloat parent_zero_fraction,
406
- tfloat parent_one_fraction, int parent_feature_index,
407
- int condition, unsigned condition_feature,
408
- tfloat condition_fraction) {
409
-
410
- // stop if we have no weight coming down to us
411
- if (condition_fraction == 0) return;
412
-
413
- // extend the unique path
414
- PathElement *unique_path = parent_unique_path + unique_depth + 1;
415
- std::copy(parent_unique_path, parent_unique_path + unique_depth + 1, unique_path);
416
-
417
- if (condition == 0 || condition_feature != static_cast<unsigned>(parent_feature_index)) {
418
- extend_path(unique_path, unique_depth, parent_zero_fraction,
419
- parent_one_fraction, parent_feature_index);
420
- }
421
- const unsigned split_index = features[node_index];
422
-
423
- // leaf node
424
- if (children_right[node_index] < 0) {
425
- for (unsigned i = 1; i <= unique_depth; ++i) {
426
- const tfloat w = unwound_path_sum(unique_path, unique_depth, i);
427
- const PathElement &el = unique_path[i];
428
- const unsigned phi_offset = el.feature_index * num_outputs;
429
- const unsigned values_offset = node_index * num_outputs;
430
- const tfloat scale = w * (el.one_fraction - el.zero_fraction) * condition_fraction;
431
- for (unsigned j = 0; j < num_outputs; ++j) {
432
- phi[phi_offset + j] += scale * values[values_offset + j];
433
- }
434
- }
435
-
436
- // internal node
437
- } else {
438
- // find which branch is "hot" (meaning x would follow it)
439
- unsigned hot_index = 0;
440
- if (x_missing[split_index]) {
441
- hot_index = children_default[node_index];
442
- } else if (x[split_index] <= thresholds[node_index]) {
443
- hot_index = children_left[node_index];
444
- } else {
445
- hot_index = children_right[node_index];
446
- }
447
- const unsigned cold_index = (static_cast<int>(hot_index) == children_left[node_index] ?
448
- children_right[node_index] : children_left[node_index]);
449
- const tfloat w = node_sample_weight[node_index];
450
- const tfloat hot_zero_fraction = node_sample_weight[hot_index] / w;
451
- const tfloat cold_zero_fraction = node_sample_weight[cold_index] / w;
452
- tfloat incoming_zero_fraction = 1;
453
- tfloat incoming_one_fraction = 1;
454
-
455
- // see if we have already split on this feature,
456
- // if so we undo that split so we can redo it for this node
457
- unsigned path_index = 0;
458
- for (; path_index <= unique_depth; ++path_index) {
459
- if (static_cast<unsigned>(unique_path[path_index].feature_index) == split_index) break;
460
- }
461
- if (path_index != unique_depth + 1) {
462
- incoming_zero_fraction = unique_path[path_index].zero_fraction;
463
- incoming_one_fraction = unique_path[path_index].one_fraction;
464
- unwind_path(unique_path, unique_depth, path_index);
465
- unique_depth -= 1;
466
- }
467
-
468
- // divide up the condition_fraction among the recursive calls
469
- tfloat hot_condition_fraction = condition_fraction;
470
- tfloat cold_condition_fraction = condition_fraction;
471
- if (condition > 0 && split_index == condition_feature) {
472
- cold_condition_fraction = 0;
473
- unique_depth -= 1;
474
- } else if (condition < 0 && split_index == condition_feature) {
475
- hot_condition_fraction *= hot_zero_fraction;
476
- cold_condition_fraction *= cold_zero_fraction;
477
- unique_depth -= 1;
478
- }
479
-
480
- tree_shap_recursive(
481
- num_outputs, children_left, children_right, children_default, features, thresholds, values,
482
- node_sample_weight, x, x_missing, phi, hot_index, unique_depth + 1, unique_path,
483
- hot_zero_fraction * incoming_zero_fraction, incoming_one_fraction,
484
- split_index, condition, condition_feature, hot_condition_fraction
485
- );
486
-
487
- tree_shap_recursive(
488
- num_outputs, children_left, children_right, children_default, features, thresholds, values,
489
- node_sample_weight, x, x_missing, phi, cold_index, unique_depth + 1, unique_path,
490
- cold_zero_fraction * incoming_zero_fraction, 0,
491
- split_index, condition, condition_feature, cold_condition_fraction
492
- );
493
- }
494
- }
495
-
496
- inline int compute_expectations(TreeEnsemble &tree, int i = 0, int depth = 0) {
497
- unsigned max_depth = 0;
498
-
499
- if (tree.children_right[i] >= 0) {
500
- const unsigned li = tree.children_left[i];
501
- const unsigned ri = tree.children_right[i];
502
- const unsigned depth_left = compute_expectations(tree, li, depth + 1);
503
- const unsigned depth_right = compute_expectations(tree, ri, depth + 1);
504
- const tfloat left_weight = tree.node_sample_weights[li];
505
- const tfloat right_weight = tree.node_sample_weights[ri];
506
- const unsigned li_offset = li * tree.num_outputs;
507
- const unsigned ri_offset = ri * tree.num_outputs;
508
- const unsigned i_offset = i * tree.num_outputs;
509
- for (unsigned j = 0; j < tree.num_outputs; ++j) {
510
- if ((left_weight == 0) && (right_weight == 0)) {
511
- tree.values[i_offset + j] = 0.0;
512
- } else {
513
- const tfloat v = (left_weight * tree.values[li_offset + j] + right_weight * tree.values[ri_offset + j]) / (left_weight + right_weight);
514
- tree.values[i_offset + j] = v;
515
- }
516
- }
517
- max_depth = std::max(depth_left, depth_right) + 1;
518
- }
519
-
520
- if (depth == 0) tree.max_depth = max_depth;
521
-
522
- return max_depth;
523
- }
524
-
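compute_expectations overwrites each internal node's value with the sample-weight-weighted average of its children's values, so after the pass every node stores the expected leaf value of its subtree, and the maximum depth is returned as a side product. The averaging step in isolation, with hypothetical numbers:

#include <iostream>

// Expected value of a parent node from its children's expected values and the
// number of training samples that reached each child.
double subtree_expectation(double left_value, double left_weight,
                           double right_value, double right_weight) {
  if (left_weight == 0 && right_weight == 0) return 0.0;  // same guard as above
  return (left_weight * left_value + right_weight * right_value) /
         (left_weight + right_weight);
}

int main() {
  // 30 samples reached the left child (value 2.0), 10 the right (value 5.0).
  std::cout << subtree_expectation(2.0, 30, 5.0, 10) << "\n";  // 2.75
}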
525
- inline void tree_shap(const TreeEnsemble& tree, const ExplanationDataset &data,
526
- tfloat *out_contribs, int condition, unsigned condition_feature) {
527
-
528
- // update the reference value with the expected value of the tree's predictions
529
- if (condition == 0) {
530
- for (unsigned j = 0; j < tree.num_outputs; ++j) {
531
- out_contribs[data.M * tree.num_outputs + j] += tree.values[j];
532
- }
533
- }
534
-
535
- // Pre-allocate space for the unique path data
536
- const unsigned maxd = tree.max_depth + 2; // need a bit more space than the max depth
537
- PathElement *unique_path_data = new PathElement[(maxd * (maxd + 1)) / 2];
538
-
539
- tree_shap_recursive(
540
- tree.num_outputs, tree.children_left, tree.children_right, tree.children_default,
541
- tree.features, tree.thresholds, tree.values, tree.node_sample_weights, data.X,
542
- data.X_missing, out_contribs, 0, 0, unique_path_data, 1, 1, -1, condition,
543
- condition_feature, 1
544
- );
545
-
546
- delete[] unique_path_data;
547
- }
548
-
549
-
550
- inline unsigned build_merged_tree_recursive(TreeEnsemble &out_tree, const TreeEnsemble &trees,
551
- const tfloat *data, const bool *data_missing, int *data_inds,
552
- const unsigned num_background_data_inds, unsigned num_data_inds,
553
- unsigned M, unsigned row = 0, unsigned i = 0, unsigned pos = 0,
554
- tfloat *leaf_value = NULL) {
555
- //tfloat new_leaf_value[trees.num_outputs];
556
- tfloat *new_leaf_value = (tfloat *) alloca(sizeof(tfloat) * trees.num_outputs); // allocate on the stack
557
- unsigned row_offset = row * trees.max_nodes;
558
-
559
- // we have hit a terminal leaf!!!
560
- if (trees.children_left[row_offset + i] < 0 && row + 1 == trees.tree_limit) {
561
-
562
- // create the leaf node
563
- const tfloat *vals = trees.values + (row * trees.max_nodes + i) * trees.num_outputs;
564
- if (leaf_value == NULL) {
565
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
566
- out_tree.values[pos * trees.num_outputs + j] = vals[j];
567
- }
568
- } else {
569
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
570
- out_tree.values[pos * trees.num_outputs + j] = leaf_value[j] + vals[j];
571
- }
572
- }
573
- out_tree.children_left[pos] = -1;
574
- out_tree.children_right[pos] = -1;
575
- out_tree.children_default[pos] = -1;
576
- out_tree.features[pos] = -1;
577
- out_tree.thresholds[pos] = 0;
578
- out_tree.node_sample_weights[pos] = num_background_data_inds;
579
-
580
- return pos;
581
- }
582
-
583
- // we hit an intermediate leaf (so just add the value to our accumulator and move to the next tree)
584
- if (trees.children_left[row_offset + i] < 0) {
585
-
586
- // accumulate the value of this original leaf so it will land on all eventual terminal leaves
587
- const tfloat *vals = trees.values + (row * trees.max_nodes + i) * trees.num_outputs;
588
- if (leaf_value == NULL) {
589
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
590
- new_leaf_value[j] = vals[j];
591
- }
592
- } else {
593
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
594
- new_leaf_value[j] = leaf_value[j] + vals[j];
595
- }
596
- }
597
- leaf_value = new_leaf_value;
598
-
599
- // move forward to the next tree
600
- row += 1;
601
- row_offset += trees.max_nodes;
602
- i = 0;
603
- }
604
-
605
- // split the data inds by this node's threshold
606
- const tfloat t = trees.thresholds[row_offset + i];
607
- const int f = trees.features[row_offset + i];
608
- const bool right_default = trees.children_default[row_offset + i] == trees.children_right[row_offset + i];
609
- int low_ptr = 0;
610
- int high_ptr = num_data_inds - 1;
611
- unsigned num_left_background_data_inds = 0;
612
- int low_data_ind;
613
- while (low_ptr <= high_ptr) {
614
- low_data_ind = data_inds[low_ptr];
615
- const int data_ind = std::abs(low_data_ind) * M + f;
616
- const bool is_missing = data_missing[data_ind];
617
- if ((!is_missing && data[data_ind] > t) || (right_default && is_missing)) {
618
- data_inds[low_ptr] = data_inds[high_ptr];
619
- data_inds[high_ptr] = low_data_ind;
620
- high_ptr -= 1;
621
- } else {
622
- if (low_data_ind >= 0) ++num_left_background_data_inds; // negative data_inds are not background samples
623
- low_ptr += 1;
624
- }
625
- }
626
- int *left_data_inds = data_inds;
627
- const unsigned num_left_data_inds = low_ptr;
628
- int *right_data_inds = data_inds + low_ptr;
629
- const unsigned num_right_data_inds = num_data_inds - num_left_data_inds;
630
- const unsigned num_right_background_data_inds = num_background_data_inds - num_left_background_data_inds;
631
-
632
- // all the data went right, so we skip creating this node and just recurse right
633
- if (num_left_data_inds == 0) {
634
- return build_merged_tree_recursive(
635
- out_tree, trees, data, data_missing, data_inds,
636
- num_background_data_inds, num_data_inds, M, row,
637
- trees.children_right[row_offset + i], pos, leaf_value
638
- );
639
-
640
- // all the data went left, so we skip creating this node and just recurse left
641
- } else if (num_right_data_inds == 0) {
642
- return build_merged_tree_recursive(
643
- out_tree, trees, data, data_missing, data_inds,
644
- num_background_data_inds, num_data_inds, M, row,
645
- trees.children_left[row_offset + i], pos, leaf_value
646
- );
647
-
648
- // data went both ways so we create this node and recurse down both paths
649
- } else {
650
-
651
- // build the left subtree
652
- const unsigned new_pos = build_merged_tree_recursive(
653
- out_tree, trees, data, data_missing, left_data_inds,
654
- num_left_background_data_inds, num_left_data_inds, M, row,
655
- trees.children_left[row_offset + i], pos + 1, leaf_value
656
- );
657
-
658
- // fill in the data for this node
659
- out_tree.children_left[pos] = pos + 1;
660
- out_tree.children_right[pos] = new_pos + 1;
661
- if (trees.children_left[row_offset + i] == trees.children_default[row_offset + i]) {
662
- out_tree.children_default[pos] = pos + 1;
663
- } else {
664
- out_tree.children_default[pos] = new_pos + 1;
665
- }
666
-
667
- out_tree.features[pos] = trees.features[row_offset + i];
668
- out_tree.thresholds[pos] = trees.thresholds[row_offset + i];
669
- out_tree.node_sample_weights[pos] = num_background_data_inds;
670
-
671
- // build the right subtree
672
- return build_merged_tree_recursive(
673
- out_tree, trees, data, data_missing, right_data_inds,
674
- num_right_background_data_inds, num_right_data_inds, M, row,
675
- trees.children_right[row_offset + i], new_pos + 1, leaf_value
676
- );
677
- }
678
- }
679
-
680
-
681
- inline void build_merged_tree(TreeEnsemble &out_tree, const ExplanationDataset &data, const TreeEnsemble &trees) {
682
-
683
- // create a joint data matrix from both X and R matrices
684
- tfloat *joined_data = new tfloat[(data.num_X + data.num_R) * data.M];
685
- std::copy(data.X, data.X + data.num_X * data.M, joined_data);
686
- std::copy(data.R, data.R + data.num_R * data.M, joined_data + data.num_X * data.M);
687
- bool *joined_data_missing = new bool[(data.num_X + data.num_R) * data.M];
688
- std::copy(data.X_missing, data.X_missing + data.num_X * data.M, joined_data_missing);
689
- std::copy(data.R_missing, data.R_missing + data.num_R * data.M, joined_data_missing + data.num_X * data.M);
690
-
691
- // create a starting array of data indexes we will recursively sort
692
- int *data_inds = new int[data.num_X + data.num_R];
693
- for (unsigned i = 0; i < data.num_X; ++i) data_inds[i] = i;
694
- for (unsigned i = data.num_X; i < data.num_X + data.num_R; ++i) {
695
- data_inds[i] = -i; // a negative index means it won't be recorded as a background sample
696
- }
697
-
698
- build_merged_tree_recursive(
699
- out_tree, trees, joined_data, joined_data_missing, data_inds, data.num_R,
700
- data.num_X + data.num_R, data.M
701
- );
702
-
703
- delete[] joined_data;
704
- delete[] joined_data_missing;
705
- delete[] data_inds;
706
- }
707
-
708
-
709
- // Independent Tree SHAP functions below here
710
- // ------------------------------------------
711
- struct Node {
712
- short cl, cr, cd, pnode, feat, pfeat; // 16-bit (signed short) node fields
713
- float thres, value;
714
- char from_flag;
715
- };
716
-
717
- #define FROM_NEITHER 0
718
- #define FROM_X_NOT_R 1
719
- #define FROM_R_NOT_X 2
720
-
721
- // https://www.geeksforgeeks.org/space-and-time-efficient-binomial-coefficient/
722
- inline int bin_coeff(int n, int k) {
723
- int res = 1;
724
- if (k > n - k)
725
- k = n - k;
726
- for (int i = 0; i < k; ++i) {
727
- res *= (n - i);
728
- res /= (i + 1);
729
- }
730
- return res;
731
- }
732
-
733
- // note this only handles single output models, so multi-output models get explained using multiple passes
734
- inline void tree_shap_indep(const unsigned max_depth, const unsigned num_feats,
735
- const unsigned num_nodes, const tfloat *x,
736
- const bool *x_missing, const tfloat *r,
737
- const bool *r_missing, tfloat *out_contribs,
738
- float *pos_lst, float *neg_lst, signed short *feat_hist,
739
- float *memoized_weights, int *node_stack, Node *mytree) {
740
-
741
- // const bool DEBUG = true;
742
- // ofstream myfile;
743
- // if (DEBUG) {
744
- // myfile.open ("/homes/gws/hughchen/shap/out.txt",fstream::app);
745
- // myfile << "Entering tree_shap_indep\n";
746
- // }
747
- int ns_ctr = 0;
748
- std::fill_n(feat_hist, num_feats, 0);
749
- short node = 0, feat, cl, cr, cd, pnode, pfeat = -1;
750
- short next_xnode = -1, next_rnode = -1;
751
- short next_node = -1, from_child = -1;
752
- float thres, pos_x = 0, neg_x = 0, pos_r = 0, neg_r = 0;
753
- char from_flag;
754
- unsigned M = 0, N = 0;
755
-
756
- Node curr_node = mytree[node];
757
- feat = curr_node.feat;
758
- thres = curr_node.thres;
759
- cl = curr_node.cl;
760
- cr = curr_node.cr;
761
- cd = curr_node.cd;
762
-
763
- // short circuit when this is a stump tree (with no splits)
764
- if (cl < 0) {
765
- out_contribs[num_feats] += curr_node.value;
766
- return;
767
- }
768
-
769
- // if (DEBUG) {
770
- // myfile << "\nNode: " << node << "\n";
771
- // myfile << "x[feat]: " << x[feat] << ", r[feat]: " << r[feat] << "\n";
772
- // myfile << "thres: " << thres << "\n";
773
- // }
774
-
775
- if (x_missing[feat]) {
776
- next_xnode = cd;
777
- } else if (x[feat] > thres) {
778
- next_xnode = cr;
779
- } else if (x[feat] <= thres) {
780
- next_xnode = cl;
781
- }
782
-
783
- if (r_missing[feat]) {
784
- next_rnode = cd;
785
- } else if (r[feat] > thres) {
786
- next_rnode = cr;
787
- } else if (r[feat] <= thres) {
788
- next_rnode = cl;
789
- }
790
-
791
- if (next_xnode != next_rnode) {
792
- mytree[next_xnode].from_flag = FROM_X_NOT_R;
793
- mytree[next_rnode].from_flag = FROM_R_NOT_X;
794
- } else {
795
- mytree[next_xnode].from_flag = FROM_NEITHER;
796
- }
797
-
798
- // Check if x and r go the same way
799
- if (next_xnode == next_rnode) {
800
- next_node = next_xnode;
801
- }
802
-
803
- // If not, go left
804
- if (next_node < 0) {
805
- next_node = cl;
806
- if (next_rnode == next_node) { // rpath
807
- N = N+1;
808
- feat_hist[feat] -= 1;
809
- } else if (next_xnode == next_node) { // xpath
810
- M = M+1;
811
- N = N+1;
812
- feat_hist[feat] += 1;
813
- }
814
- }
815
- node_stack[ns_ctr] = node;
816
- ns_ctr += 1;
817
- while (true) {
818
- node = next_node;
819
- curr_node = mytree[node];
820
- feat = curr_node.feat;
821
- thres = curr_node.thres;
822
- cl = curr_node.cl;
823
- cr = curr_node.cr;
824
- cd = curr_node.cd;
825
- pnode = curr_node.pnode;
826
- pfeat = curr_node.pfeat;
827
- from_flag = curr_node.from_flag;
828
-
829
-
830
-
831
- // if (DEBUG) {
832
- // myfile << "\nNode: " << node << "\n";
833
- // myfile << "N: " << N << ", M: " << M << "\n";
834
- // myfile << "from_flag==FROM_X_NOT_R: " << (from_flag==FROM_X_NOT_R) << "\n";
835
- // myfile << "from_flag==FROM_R_NOT_X: " << (from_flag==FROM_R_NOT_X) << "\n";
836
- // myfile << "from_flag==FROM_NEITHER: " << (from_flag==FROM_NEITHER) << "\n";
837
- // myfile << "feat_hist[feat]: " << feat_hist[feat] << "\n";
838
- // }
839
-
840
- // At a leaf
841
- if (cl < 0) {
842
- // if (DEBUG) {
843
- // myfile << "At a leaf\n";
844
- // }
845
-
846
- if (M == 0) {
847
- out_contribs[num_feats] += mytree[node].value;
848
- }
849
-
850
- // Currently assuming a single output
851
- if (N != 0) {
852
- if (M != 0) {
853
- pos_lst[node] = mytree[node].value * memoized_weights[N + max_depth * (M-1)];
854
- }
855
- if (M != N) {
856
- neg_lst[node] = -mytree[node].value * memoized_weights[N + max_depth * M];
857
- }
858
- }
859
- // if (DEBUG) {
860
- // myfile << "pos_lst[node]: " << pos_lst[node] << "\n";
861
- // myfile << "neg_lst[node]: " << neg_lst[node] << "\n";
862
- // }
863
- // Pop from node_stack
864
- ns_ctr -= 1;
865
- next_node = node_stack[ns_ctr];
866
- from_child = node;
867
- // Unwind
868
- if (feat_hist[pfeat] > 0) {
869
- feat_hist[pfeat] -= 1;
870
- } else if (feat_hist[pfeat] < 0) {
871
- feat_hist[pfeat] += 1;
872
- }
873
- if (feat_hist[pfeat] == 0) {
874
- if (from_flag == FROM_X_NOT_R) {
875
- N = N-1;
876
- M = M-1;
877
- } else if (from_flag == FROM_R_NOT_X) {
878
- N = N-1;
879
- }
880
- }
881
- continue;
882
- }
883
-
884
- const bool x_right = x[feat] > thres;
885
- const bool r_right = r[feat] > thres;
886
-
887
- if (x_missing[feat]) {
888
- next_xnode = cd;
889
- } else if (x_right) {
890
- next_xnode = cr;
891
- } else if (!x_right) {
892
- next_xnode = cl;
893
- }
894
-
895
- if (r_missing[feat]) {
896
- next_rnode = cd;
897
- } else if (r_right) {
898
- next_rnode = cr;
899
- } else if (!r_right) {
900
- next_rnode = cl;
901
- }
902
-
903
- if (next_xnode >= 0) {
904
- if (next_xnode != next_rnode) {
905
- mytree[next_xnode].from_flag = FROM_X_NOT_R;
906
- mytree[next_rnode].from_flag = FROM_R_NOT_X;
907
- } else {
908
- mytree[next_xnode].from_flag = FROM_NEITHER;
909
- }
910
- }
911
-
912
- // Arriving at node from parent
913
- if (from_child == -1) {
914
- // if (DEBUG) {
915
- // myfile << "Arriving at node from parent\n";
916
- // }
917
- node_stack[ns_ctr] = node;
918
- ns_ctr += 1;
919
- next_node = -1;
920
-
921
- // if (DEBUG) {
922
- // myfile << "feat_hist[feat]" << feat_hist[feat] << "\n";
923
- // }
924
- // Feature is set upstream
925
- if (feat_hist[feat] > 0) {
926
- next_node = next_xnode;
927
- feat_hist[feat] += 1;
928
- } else if (feat_hist[feat] < 0) {
929
- next_node = next_rnode;
930
- feat_hist[feat] -= 1;
931
- }
932
-
933
- // x and r go the same way
934
- if (next_node < 0) {
935
- if (next_xnode == next_rnode) {
936
- next_node = next_xnode;
937
- }
938
- }
939
-
940
- // Go down one path
941
- if (next_node >= 0) {
942
- continue;
943
- }
944
-
945
- // Go down both paths, but go left first
946
- next_node = cl;
947
- if (next_rnode == next_node) {
948
- N = N+1;
949
- feat_hist[feat] -= 1;
950
- } else if (next_xnode == next_node) {
951
- M = M+1;
952
- N = N+1;
953
- feat_hist[feat] += 1;
954
- }
955
- from_child = -1;
956
- continue;
957
- }
958
-
959
- // Arriving at node from child
960
- if (from_child != -1) {
961
- // if (DEBUG) {
962
- // myfile << "Arriving at node from child\n";
963
- // }
964
- next_node = -1;
965
- // Check if we should unroll immediately
966
- if ((next_rnode == next_xnode) || (feat_hist[feat] != 0)) {
967
- next_node = pnode;
968
- }
969
-
970
- // Came from a single path, so unroll
971
- if (next_node >= 0) {
972
- // if (DEBUG) {
973
- // myfile << "Came from a single path, so unroll\n";
974
- // }
975
- // At the root node
976
- if (node == 0) {
977
- break;
978
- }
979
- // Update and unroll
980
- pos_lst[node] = pos_lst[from_child];
981
- neg_lst[node] = neg_lst[from_child];
982
-
983
- // if (DEBUG) {
984
- // myfile << "pos_lst[node]: " << pos_lst[node] << "\n";
985
- // myfile << "neg_lst[node]: " << neg_lst[node] << "\n";
986
- // }
987
- from_child = node;
988
- ns_ctr -= 1;
989
-
990
- // Unwind
991
- if (feat_hist[pfeat] > 0) {
992
- feat_hist[pfeat] -= 1;
993
- } else if (feat_hist[pfeat] < 0) {
994
- feat_hist[pfeat] += 1;
995
- }
996
- if (feat_hist[pfeat] == 0) {
997
- if (from_flag == FROM_X_NOT_R) {
998
- N = N-1;
999
- M = M-1;
1000
- } else if (from_flag == FROM_R_NOT_X) {
1001
- N = N-1;
1002
- }
1003
- }
1004
- continue;
1005
- // Go right - Arriving from the left child
1006
- } else if (from_child == cl) {
1007
- // if (DEBUG) {
1008
- // myfile << "Go right - Arriving from the left child\n";
1009
- // }
1010
- node_stack[ns_ctr] = node;
1011
- ns_ctr += 1;
1012
- next_node = cr;
1013
- if (next_xnode == next_node) {
1014
- M = M+1;
1015
- N = N+1;
1016
- feat_hist[feat] += 1;
1017
- } else if (next_rnode == next_node) {
1018
- N = N+1;
1019
- feat_hist[feat] -= 1;
1020
- }
1021
- from_child = -1;
1022
- continue;
1023
- // Compute stuff and unroll - Arriving from the right child
1024
- } else if (from_child == cr) {
1025
- // if (DEBUG) {
1026
- // myfile << "Compute stuff and unroll - Arriving from the right child\n";
1027
- // }
1028
- pos_x = 0;
1029
- neg_x = 0;
1030
- pos_r = 0;
1031
- neg_r = 0;
1032
- if ((next_xnode == cr) && (next_rnode == cl)) {
1033
- pos_x = pos_lst[cr];
1034
- neg_x = neg_lst[cr];
1035
- pos_r = pos_lst[cl];
1036
- neg_r = neg_lst[cl];
1037
- } else if ((next_xnode == cl) && (next_rnode == cr)) {
1038
- pos_x = pos_lst[cl];
1039
- neg_x = neg_lst[cl];
1040
- pos_r = pos_lst[cr];
1041
- neg_r = neg_lst[cr];
1042
- }
1043
- // out_contribs needs to have been initialized as all zeros
1044
- // if (pos_x + neg_r != 0) {
1045
- // std::cout << "val " << pos_x + neg_r << "\n";
1046
- // }
1047
- out_contribs[feat] += pos_x + neg_r;
1048
- pos_lst[node] = pos_x + pos_r;
1049
- neg_lst[node] = neg_x + neg_r;
1050
-
1051
- // if (DEBUG) {
1052
- // myfile << "out_contribs[feat]: " << out_contribs[feat] << "\n";
1053
- // myfile << "pos_lst[node]: " << pos_lst[node] << "\n";
1054
- // myfile << "neg_lst[node]: " << neg_lst[node] << "\n";
1055
- // }
1056
-
1057
- // Check if at root
1058
- if (node == 0) {
1059
- break;
1060
- }
1061
-
1062
- // Pop
1063
- ns_ctr -= 1;
1064
- next_node = node_stack[ns_ctr];
1065
- from_child = node;
1066
-
1067
- // Unwind
1068
- if (feat_hist[pfeat] > 0) {
1069
- feat_hist[pfeat] -= 1;
1070
- } else if (feat_hist[pfeat] < 0) {
1071
- feat_hist[pfeat] += 1;
1072
- }
1073
- if (feat_hist[pfeat] == 0) {
1074
- if (from_flag == FROM_X_NOT_R) {
1075
- N = N-1;
1076
- M = M-1;
1077
- } else if (from_flag == FROM_R_NOT_X) {
1078
- N = N-1;
1079
- }
1080
- }
1081
- continue;
1082
- }
1083
- }
1084
- }
1085
- // if (DEBUG) {
1086
- // myfile.close();
1087
- // }
1088
- }
1089
-
1090
-
1091
- inline void print_progress_bar(tfloat &last_print, tfloat start_time, unsigned i, unsigned total_count) {
1092
- const tfloat elapsed_seconds = difftime(time(NULL), start_time);
1093
-
1094
- if (elapsed_seconds > 10 && elapsed_seconds - last_print > 0.5) {
1095
- const tfloat fraction = static_cast<tfloat>(i) / total_count;
1096
- const double total_seconds = elapsed_seconds / fraction;
1097
- last_print = elapsed_seconds;
1098
-
1099
- PySys_WriteStderr(
1100
- "\r%3.0f%%|%.*s%.*s| %d/%d [%02d:%02d<%02d:%02d] ",
1101
- fraction * 100, int(0.5 + fraction*20), "===================",
1102
- 20-int(0.5 + fraction*20), " ",
1103
- i, total_count,
1104
- int(elapsed_seconds/60), int(elapsed_seconds) % 60,
1105
- int((total_seconds - elapsed_seconds)/60), int(total_seconds - elapsed_seconds) % 60
1106
- );
1107
-
1108
- // Get handle to python stderr file and flush it (https://mail.python.org/pipermail/python-list/2004-November/294912.html)
1109
- PyObject *pyStderr = PySys_GetObject("stderr");
1110
- if (pyStderr) {
1111
- PyObject *result = PyObject_CallMethod(pyStderr, "flush", NULL);
1112
- Py_XDECREF(result);
1113
- }
1114
- }
1115
- }
1116
-
1117
- /**
1118
- * Runs Tree SHAP with feature independence assumptions on dense data.
1119
- */
1120
- inline void dense_independent(const TreeEnsemble& trees, const ExplanationDataset &data,
1121
- tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
1122
-
1123
- // reformat the trees for faster access
1124
- Node *node_trees = new Node[trees.tree_limit * trees.max_nodes];
1125
- for (unsigned i = 0; i < trees.tree_limit; ++i) {
1126
- Node *node_tree = node_trees + i * trees.max_nodes;
1127
- for (unsigned j = 0; j < trees.max_nodes; ++j) {
1128
- const unsigned en_ind = i * trees.max_nodes + j;
1129
- node_tree[j].cl = trees.children_left[en_ind];
1130
- node_tree[j].cr = trees.children_right[en_ind];
1131
- node_tree[j].cd = trees.children_default[en_ind];
1132
- if (j == 0) {
1133
- node_tree[j].pnode = 0;
1134
- }
1135
- if (trees.children_left[en_ind] >= 0) { // relies on all unused entries having negative values in them
1136
- node_tree[trees.children_left[en_ind]].pnode = j;
1137
- node_tree[trees.children_left[en_ind]].pfeat = trees.features[en_ind];
1138
- }
1139
- if (trees.children_right[en_ind] >= 0) { // relies on all unused entries having negative values in them
1140
- node_tree[trees.children_right[en_ind]].pnode = j;
1141
- node_tree[trees.children_right[en_ind]].pfeat = trees.features[en_ind];
1142
- }
1143
-
1144
- node_tree[j].thres = trees.thresholds[en_ind];
1145
- node_tree[j].feat = trees.features[en_ind];
1146
- }
1147
- }
1148
-
1149
- // preallocate arrays needed by the algorithm
1150
- float *pos_lst = new float[trees.max_nodes];
1151
- float *neg_lst = new float[trees.max_nodes];
1152
- int *node_stack = new int[(unsigned) trees.max_depth];
1153
- signed short *feat_hist = new signed short[data.M];
1154
- tfloat *tmp_out_contribs = new tfloat[(data.M + 1)];
1155
-
1156
- // precompute all the weight coefficients
1157
- float *memoized_weights = new float[(trees.max_depth+1) * (trees.max_depth+1)];
1158
- for (unsigned n = 0; n <= trees.max_depth; ++n) {
1159
- for (unsigned m = 0; m <= trees.max_depth; ++m) {
1160
- memoized_weights[n + trees.max_depth * m] = 1.0 / (n * bin_coeff(n-1, m));
1161
- }
1162
- }
1163
-
1164
- // compute the explanations for each sample
1165
- tfloat *instance_out_contribs;
1166
- tfloat rescale_factor = 1.0;
1167
- tfloat margin_x = 0;
1168
- tfloat margin_r = 0;
1169
- time_t start_time = time(NULL);
1170
- tfloat last_print = 0;
1171
- for (unsigned oind = 0; oind < trees.num_outputs; ++oind) {
1172
- // set the values in the reformatted tree to the current output index
1173
- for (unsigned i = 0; i < trees.tree_limit; ++i) {
1174
- Node *node_tree = node_trees + i * trees.max_nodes;
1175
- for (unsigned j = 0; j < trees.max_nodes; ++j) {
1176
- const unsigned en_ind = i * trees.max_nodes + j;
1177
- node_tree[j].value = trees.values[en_ind * trees.num_outputs + oind];
1178
- }
1179
- }
1180
-
1181
- // loop over all the samples
1182
- for (unsigned i = 0; i < data.num_X; ++i) {
1183
- const tfloat *x = data.X + i * data.M;
1184
- const bool *x_missing = data.X_missing + i * data.M;
1185
- instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
1186
- const tfloat y_i = data.y == NULL ? 0 : data.y[i];
1187
-
1188
- print_progress_bar(last_print, start_time, oind * data.num_X + i, data.num_X * trees.num_outputs);
1189
-
1190
- // compute the model's margin output for x
1191
- if (transform != NULL) {
1192
- margin_x = trees.base_offset[oind];
1193
- for (unsigned k = 0; k < trees.tree_limit; ++k) {
1194
- margin_x += tree_predict(k, trees, x, x_missing)[oind];
1195
- }
1196
- }
1197
-
1198
- for (unsigned j = 0; j < data.num_R; ++j) {
1199
- const tfloat *r = data.R + j * data.M;
1200
- const bool *r_missing = data.R_missing + j * data.M;
1201
- std::fill_n(tmp_out_contribs, (data.M + 1), 0);
1202
-
1203
- // compute the model's margin output for r
1204
- if (transform != NULL) {
1205
- margin_r = trees.base_offset[oind];
1206
- for (unsigned k = 0; k < trees.tree_limit; ++k) {
1207
- margin_r += tree_predict(k, trees, r, r_missing)[oind];
1208
- }
1209
- }
1210
-
1211
- for (unsigned k = 0; k < trees.tree_limit; ++k) {
1212
- tree_shap_indep(
1213
- trees.max_depth, data.M, trees.max_nodes, x, x_missing, r, r_missing,
1214
- tmp_out_contribs, pos_lst, neg_lst, feat_hist, memoized_weights,
1215
- node_stack, node_trees + k * trees.max_nodes
1216
- );
1217
- }
1218
-
1219
- // compute the rescale factor
1220
- if (transform != NULL) {
1221
- if (margin_x == margin_r) {
1222
- rescale_factor = 1.0;
1223
- } else {
1224
- rescale_factor = (*transform)(margin_x, y_i) - (*transform)(margin_r, y_i);
1225
- rescale_factor /= margin_x - margin_r;
1226
- }
1227
- }
1228
-
1229
- // add the effect of the current reference to our running total
1230
- // this is where we can do per reference scaling for non-linear transformations
1231
- for (unsigned k = 0; k < data.M; ++k) {
1232
- instance_out_contribs[k * trees.num_outputs + oind] += tmp_out_contribs[k] * rescale_factor;
1233
- }
1234
-
1235
- // Add the base offset
1236
- if (transform != NULL) {
1237
- instance_out_contribs[data.M * trees.num_outputs + oind] += (*transform)(trees.base_offset[oind] + tmp_out_contribs[data.M], 0);
1238
- } else {
1239
- instance_out_contribs[data.M * trees.num_outputs + oind] += trees.base_offset[oind] + tmp_out_contribs[data.M];
1240
- }
1241
- }
1242
-
1243
- // average the results over all the references.
1244
- for (unsigned j = 0; j < (data.M + 1); ++j) {
1245
- instance_out_contribs[j * trees.num_outputs + oind] /= data.num_R;
1246
- }
1247
-
1248
- // apply the base offset to the bias term
1249
- // for (unsigned j = 0; j < trees.num_outputs; ++j) {
1250
- // instance_out_contribs[data.M * trees.num_outputs + j] += (*transform)(trees.base_offset[j], 0);
1251
- // }
1252
- }
1253
- }
1254
-
1255
- delete[] tmp_out_contribs;
1256
- delete[] node_trees;
1257
- delete[] pos_lst;
1258
- delete[] neg_lst;
1259
- delete[] node_stack;
1260
- delete[] feat_hist;
1261
- delete[] memoized_weights;
1262
- }
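For context on the hunk above: dense_independent explains one row x against every background reference row r, rescales each reference's per-feature contributions when a nonlinear link is supplied, and averages over the reference set. A minimal NumPy sketch of that outer loop, where `margin` and `explain_one_reference` are hypothetical placeholders standing in for the C++ `tree_predict` and `tree_shap_indep` calls (none of these names are part of the shap API):

    import numpy as np

    def interventional_shap(x, R, trees, transform=None):
        # one slot per feature plus a final bias slot
        total = np.zeros(len(x) + 1)
        margin_x = margin(trees, x)                        # placeholder: raw margin for x
        for r in R:                                        # loop over background references
            contribs = explain_one_reference(trees, x, r)  # placeholder: per-reference SHAP values
            scale = 1.0
            if transform is not None:
                margin_r = margin(trees, r)                # placeholder: raw margin for r
                if margin_x != margin_r:
                    # ratio of transformed to raw margin difference, as in the C++ above
                    scale = (transform(margin_x) - transform(margin_r)) / (margin_x - margin_r)
            total[:-1] += contribs[:-1] * scale            # feature effects are rescaled
            total[-1] += contribs[-1]                      # bias slot accumulates unscaled
        return total / len(R)                              # average over all references
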
1263
-
1264
-
1265
- /**
1266
- * This runs Tree SHAP with a per tree path conditional dependence assumption.
1267
- */
1268
- inline void dense_tree_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
1269
- tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
1270
- tfloat *instance_out_contribs;
1271
- TreeEnsemble tree;
1272
- ExplanationDataset instance;
1273
-
1274
- // build explanation for each sample
1275
- for (unsigned i = 0; i < data.num_X; ++i) {
1276
- instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
1277
- data.get_x_instance(instance, i);
1278
-
1279
- // aggregate the effect of explaining each tree
1280
- // (this works because of the linearity property of Shapley values)
1281
- for (unsigned j = 0; j < trees.tree_limit; ++j) {
1282
- trees.get_tree(tree, j);
1283
- tree_shap(tree, instance, instance_out_contribs, 0, 0);
1284
- }
1285
-
1286
- // apply the base offset to the bias term
1287
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
1288
- instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
1289
- }
1290
- }
1291
- }
1292
-
1293
- // phi = np.zeros((self._current_X.shape[1] + 1, self._current_X.shape[1] + 1, self.n_outputs))
1294
- // phi_diag = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
1295
- // for t in range(self.tree_limit):
1296
- // self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_diag)
1297
- // for j in self.trees[t].unique_features:
1298
- // phi_on = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
1299
- // phi_off = np.zeros((self._current_X.shape[1] + 1, self.n_outputs))
1300
- // self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_on, 1, j)
1301
- // self.tree_shap(self.trees[t], self._current_X[i,:], self._current_x_missing, phi_off, -1, j)
1302
- // phi[j] += np.true_divide(np.subtract(phi_on,phi_off),2.0)
1303
- // phi_diag[j] -= np.sum(np.true_divide(np.subtract(phi_on,phi_off),2.0))
1304
- // for j in range(self._current_X.shape[1]+1):
1305
- // phi[j][j] = phi_diag[j]
1306
- // phi /= self.tree_limit
1307
- // return phi
1308
-
1309
- inline void dense_tree_interactions_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
1310
- tfloat *out_contribs,
1311
- tfloat transform(const tfloat, const tfloat)) {
1312
-
1313
- // build a list of all the unique features in each tree
1314
- int amount_of_unique_features = min(data.M, trees.max_nodes);
1315
- int *unique_features = new int[trees.tree_limit * amount_of_unique_features];
1316
- std::fill(unique_features, unique_features + trees.tree_limit * amount_of_unique_features, -1);
1317
- for (unsigned j = 0; j < trees.tree_limit; ++j) {
1318
- const int *features_row = trees.features + j * trees.max_nodes;
1319
- int *unique_features_row = unique_features + j * amount_of_unique_features;
1320
- for (unsigned k = 0; k < trees.max_nodes; ++k) {
1321
- for (unsigned l = 0; l < amount_of_unique_features; ++l) {
1322
- if (features_row[k] == unique_features_row[l]) break;
1323
- if (unique_features_row[l] < 0) {
1324
- unique_features_row[l] = features_row[k];
1325
- break;
1326
- }
1327
- }
1328
- }
1329
- }
1330
-
1331
- // build an interaction explanation for each sample
1332
- tfloat *instance_out_contribs;
1333
- TreeEnsemble tree;
1334
- ExplanationDataset instance;
1335
- const unsigned contrib_row_size = (data.M + 1) * trees.num_outputs;
1336
- tfloat *diag_contribs = new tfloat[contrib_row_size];
1337
- tfloat *on_contribs = new tfloat[contrib_row_size];
1338
- tfloat *off_contribs = new tfloat[contrib_row_size];
1339
- for (unsigned i = 0; i < data.num_X; ++i) {
1340
- instance_out_contribs = out_contribs + i * (data.M + 1) * contrib_row_size;
1341
- data.get_x_instance(instance, i);
1342
-
1343
- // aggregate the effect of explaining each tree
1344
- // (this works because of the linearity property of Shapley values)
1345
- std::fill(diag_contribs, diag_contribs + contrib_row_size, 0);
1346
- for (unsigned j = 0; j < trees.tree_limit; ++j) {
1347
- trees.get_tree(tree, j);
1348
- tree_shap(tree, instance, diag_contribs, 0, 0);
1349
-
1350
- const int *unique_features_row = unique_features + j * amount_of_unique_features;
1351
- for (unsigned k = 0; k < amount_of_unique_features; ++k) {
1352
- const int ind = unique_features_row[k];
1353
- if (ind < 0) break; // < 0 means we have seen all the features for this tree
1354
-
1355
- // compute the shap value with this feature held on and off
1356
- std::fill(on_contribs, on_contribs + contrib_row_size, 0);
1357
- std::fill(off_contribs, off_contribs + contrib_row_size, 0);
1358
- tree_shap(tree, instance, on_contribs, 1, ind);
1359
- tree_shap(tree, instance, off_contribs, -1, ind);
1360
-
1361
- // save the difference between on and off as the interaction value
1362
- for (unsigned l = 0; l < contrib_row_size; ++l) {
1363
- const tfloat val = (on_contribs[l] - off_contribs[l]) / 2;
1364
- instance_out_contribs[ind * contrib_row_size + l] += val;
1365
- diag_contribs[l] -= val;
1366
- }
1367
- }
1368
- }
1369
-
1370
- // set the diagonal
1371
- for (unsigned j = 0; j < data.M + 1; ++j) {
1372
- const unsigned offset = j * contrib_row_size + j * trees.num_outputs;
1373
- for (unsigned k = 0; k < trees.num_outputs; ++k) {
1374
- instance_out_contribs[offset + k] = diag_contribs[j * trees.num_outputs + k];
1375
- }
1376
- }
1377
-
1378
- // apply the base offset to the bias term
1379
- const unsigned last_ind = (data.M * (data.M + 1) + data.M) * trees.num_outputs;
1380
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
1381
- instance_out_contribs[last_ind + j] += trees.base_offset[j];
1382
- }
1383
- }
1384
-
1385
- delete[] diag_contribs;
1386
- delete[] on_contribs;
1387
- delete[] off_contribs;
1388
- delete[] unique_features;
1389
- }
1390
-
1391
- /**
1392
- * This runs Tree SHAP with a global path conditional dependence assumption.
1393
- *
1394
- * By first merging all the trees in a tree ensemble into an equivalent single tree
1395
- * this method allows arbitrary marginal transformations and also ensures that all the
1396
- * evaluations of the model are consistent with some training data point.
1397
- */
1398
- inline void dense_global_path_dependent(const TreeEnsemble& trees, const ExplanationDataset &data,
1399
- tfloat *out_contribs, tfloat transform(const tfloat, const tfloat)) {
1400
-
1401
- // allocate space for our new merged tree (we save enough room to totally split all samples if need be)
1402
- TreeEnsemble merged_tree;
1403
- merged_tree.allocate(1, (data.num_X + data.num_R) * 2, trees.num_outputs);
1404
-
1405
- // collapse the ensemble of trees into a single tree that has the same behavior
1406
- // for all the X and R samples in the dataset
1407
- build_merged_tree(merged_tree, data, trees);
1408
-
1409
- // compute the expected value and depth of the new merged tree
1410
- compute_expectations(merged_tree);
1411
-
1412
- // explain each sample using our new merged tree
1413
- ExplanationDataset instance;
1414
- tfloat *instance_out_contribs;
1415
- for (unsigned i = 0; i < data.num_X; ++i) {
1416
- instance_out_contribs = out_contribs + i * (data.M + 1) * trees.num_outputs;
1417
- data.get_x_instance(instance, i);
1418
-
1419
- // since we now just have a single merged tree we can just use the tree_path_dependent algorithm
1420
- tree_shap(merged_tree, instance, instance_out_contribs, 0, 0);
1421
-
1422
- // apply the base offset to the bias term
1423
- for (unsigned j = 0; j < trees.num_outputs; ++j) {
1424
- instance_out_contribs[data.M * trees.num_outputs + j] += trees.base_offset[j];
1425
- }
1426
- }
1427
-
1428
- merged_tree.free();
1429
- }
1430
-
1431
-
1432
- /**
1433
- * The main method for computing Tree SHAP on models using dense data.
1434
- */
1435
- inline void dense_tree_shap(const TreeEnsemble& trees, const ExplanationDataset &data, tfloat *out_contribs,
1436
- const int feature_dependence, unsigned model_transform, bool interactions) {
1437
-
1438
- // see what transform (if any) we have
1439
- transform_f transform = get_transform(model_transform);
1440
-
1441
- // dispatch to the correct algorithm handler
1442
- switch (feature_dependence) {
1443
- case FEATURE_DEPENDENCE::independent:
1444
- if (interactions) {
1445
- std::cerr << "FEATURE_DEPENDENCE::independent does not support interactions!\n";
1446
- } else dense_independent(trees, data, out_contribs, transform);
1447
- return;
1448
-
1449
- case FEATURE_DEPENDENCE::tree_path_dependent:
1450
- if (interactions) dense_tree_interactions_path_dependent(trees, data, out_contribs, transform);
1451
- else dense_tree_path_dependent(trees, data, out_contribs, transform);
1452
- return;
1453
-
1454
- case FEATURE_DEPENDENCE::global_path_dependent:
1455
- if (interactions) {
1456
- std::cerr << "FEATURE_DEPENDENCE::global_path_dependent does not support interactions!\n";
1457
- } else dense_global_path_dependent(trees, data, out_contribs, transform);
1458
- return;
1459
- }
1460
- }
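The dispatcher above is the C entry point behind shap's TreeExplainer: the independent branch needs a background dataset, and interaction values are only wired up for tree_path_dependent, as the error messages note. An illustrative sketch of how these branches are reached from the Python side, assuming xgboost is installed (the model choice is incidental):

    import shap
    import xgboost

    X, y = shap.datasets.adult()
    model = xgboost.XGBClassifier(n_estimators=50, max_depth=3).fit(X, y.astype(int))

    # feature_perturbation="interventional" -> dense_independent (background data required)
    ex_ind = shap.TreeExplainer(model, data=X.iloc[:100], feature_perturbation="interventional")
    phi_ind = ex_ind.shap_values(X.iloc[:10])

    # feature_perturbation="tree_path_dependent" -> dense_tree_path_dependent,
    # the only mode that also supports interaction values
    ex_path = shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")
    phi_path = ex_path.shap_values(X.iloc[:10])
    phi_inter = ex_path.shap_interaction_values(X.iloc[:10])
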
lib/shap/datasets.py DELETED
@@ -1,309 +0,0 @@
1
- import os
2
- from urllib.request import urlretrieve
3
-
4
- import numpy as np
5
- import pandas as pd
6
- import sklearn.datasets
7
-
8
- import shap
9
-
10
- github_data_url = "https://github.com/shap/shap/raw/master/data/"
11
-
12
-
13
- def imagenet50(display=False, resolution=224, n_points=None):
14
- """ This is a set of 50 images representative of ImageNet images.
15
-
16
- This dataset was collected by randomly finding a working ImageNet link and then pasting the
17
- original ImageNet image into Google image search restricted to images licensed for reuse. A
18
- similar image (now with rights to reuse) was downloaded as a rough replacement for the original
19
- ImageNet image. The point is to have a random sample of ImageNet for use as a background
20
- distribution for explaining models trained on ImageNet data.
21
-
22
- Note that because the images are only rough replacements the labels might no longer be correct.
23
- """
24
-
25
- prefix = github_data_url + "imagenet50_"
26
- X = np.load(cache(f"{prefix}{resolution}x{resolution}.npy")).astype(np.float32)
27
- y = np.loadtxt(cache(f"{prefix}labels.csv"))
28
-
29
- if n_points is not None:
30
- X = shap.utils.sample(X, n_points, random_state=0)
31
- y = shap.utils.sample(y, n_points, random_state=0)
32
-
33
- return X, y
34
-
35
-
36
- def california(display=False, n_points=None):
37
- """ Return the california housing data in a nice package. """
38
-
39
- d = sklearn.datasets.fetch_california_housing()
40
- df = pd.DataFrame(data=d.data, columns=d.feature_names)
41
- target = d.target
42
-
43
- if n_points is not None:
44
- df = shap.utils.sample(df, n_points, random_state=0)
45
- target = shap.utils.sample(target, n_points, random_state=0)
46
-
47
- return df, target
48
-
49
-
50
- def linnerud(display=False, n_points=None):
51
- """ Return the linnerud data in a nice package (multi-target regression). """
52
-
53
- d = sklearn.datasets.load_linnerud()
54
- X = pd.DataFrame(d.data, columns=d.feature_names)
55
- y = pd.DataFrame(d.target, columns=d.target_names)
56
-
57
- if n_points is not None:
58
- X = shap.utils.sample(X, n_points, random_state=0)
59
- y = shap.utils.sample(y, n_points, random_state=0)
60
-
61
- return X, y
62
-
63
-
64
- def imdb(display=False, n_points=None):
65
- """ Return the classic IMDB sentiment analysis training data in a nice package.
66
-
67
- Full data is at: http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
68
- Paper to cite when using the data is: http://www.aclweb.org/anthology/P11-1015
69
- """
70
-
71
- with open(cache(github_data_url + "imdb_train.txt"), encoding="utf-8") as f:
72
- data = f.readlines()
73
- y = np.ones(25000, dtype=bool)
74
- y[:12500] = 0
75
-
76
- if n_points is not None:
77
- data = shap.utils.sample(data, n_points, random_state=0)
78
- y = shap.utils.sample(y, n_points, random_state=0)
79
-
80
- return data, y
81
-
82
-
83
- def communitiesandcrime(display=False, n_points=None):
84
- """ Predict total number of non-violent crimes per 100K popuation.
85
-
86
- This dataset is from the classic UCI Machine Learning repository:
87
- https://archive.ics.uci.edu/ml/datasets/Communities+and+Crime+Unnormalized
88
- """
89
-
90
- raw_data = pd.read_csv(
91
- cache(github_data_url + "CommViolPredUnnormalizedData.txt"),
92
- na_values="?"
93
- )
94
-
95
- # find the indices where the total violent crimes are known
96
- valid_inds = np.where(np.invert(np.isnan(raw_data.iloc[:,-2])))[0]
97
-
98
- if n_points is not None:
99
- valid_inds = shap.utils.sample(valid_inds, n_points, random_state=0)
100
-
101
- y = np.array(raw_data.iloc[valid_inds,-2], dtype=float)
102
-
103
- # extract the predictive features and remove columns with missing values
104
- X = raw_data.iloc[valid_inds,5:-18]
105
- valid_cols = np.where(np.isnan(X.values).sum(0) == 0)[0]
106
- X = X.iloc[:,valid_cols]
107
-
108
- return X, y
109
-
110
-
111
- def diabetes(display=False, n_points=None):
112
- """ Return the diabetes data in a nice package. """
113
-
114
- d = sklearn.datasets.load_diabetes()
115
- df = pd.DataFrame(data=d.data, columns=d.feature_names)
116
- target = d.target
117
-
118
- if n_points is not None:
119
- df = shap.utils.sample(df, n_points, random_state=0)
120
- target = shap.utils.sample(target, n_points, random_state=0)
121
-
122
- return df, target
123
-
124
-
125
- def iris(display=False, n_points=None):
126
- """ Return the classic iris data in a nice package. """
127
-
128
- d = sklearn.datasets.load_iris()
129
- df = pd.DataFrame(data=d.data, columns=d.feature_names)
130
- target = d.target
131
-
132
- if n_points is not None:
133
- df = shap.utils.sample(df, n_points, random_state=0)
134
- target = shap.utils.sample(target, n_points, random_state=0)
135
-
136
- if display:
137
- return df, [d.target_names[v] for v in target]
138
- return df, target
139
-
140
-
141
- def adult(display=False, n_points=None):
142
- """ Return the Adult census data in a nice package. """
143
- dtypes = [
144
- ("Age", "float32"), ("Workclass", "category"), ("fnlwgt", "float32"),
145
- ("Education", "category"), ("Education-Num", "float32"), ("Marital Status", "category"),
146
- ("Occupation", "category"), ("Relationship", "category"), ("Race", "category"),
147
- ("Sex", "category"), ("Capital Gain", "float32"), ("Capital Loss", "float32"),
148
- ("Hours per week", "float32"), ("Country", "category"), ("Target", "category")
149
- ]
150
- raw_data = pd.read_csv(
151
- cache(github_data_url + "adult.data"),
152
- names=[d[0] for d in dtypes],
153
- na_values="?",
154
- dtype=dict(dtypes)
155
- )
156
-
157
- if n_points is not None:
158
- raw_data = shap.utils.sample(raw_data, n_points, random_state=0)
159
-
160
- data = raw_data.drop(["Education"], axis=1) # redundant with Education-Num
161
- filt_dtypes = list(filter(lambda x: x[0] not in ["Target", "Education"], dtypes))
162
- data["Target"] = data["Target"] == " >50K"
163
- rcode = {
164
- "Not-in-family": 0,
165
- "Unmarried": 1,
166
- "Other-relative": 2,
167
- "Own-child": 3,
168
- "Husband": 4,
169
- "Wife": 5
170
- }
171
- for k, dtype in filt_dtypes:
172
- if dtype == "category":
173
- if k == "Relationship":
174
- data[k] = np.array([rcode[v.strip()] for v in data[k]])
175
- else:
176
- data[k] = data[k].cat.codes
177
-
178
- if display:
179
- return raw_data.drop(["Education", "Target", "fnlwgt"], axis=1), data["Target"].values
180
- return data.drop(["Target", "fnlwgt"], axis=1), data["Target"].values
181
-
182
-
183
- def nhanesi(display=False, n_points=None):
184
- """ A nicely packaged version of NHANES I data with surivival times as labels.
185
- """
186
- X = pd.read_csv(cache(github_data_url + "NHANESI_X.csv"), index_col=0)
187
- y = pd.read_csv(cache(github_data_url + "NHANESI_y.csv"), index_col=0)["y"]
188
-
189
- if n_points is not None:
190
- X = shap.utils.sample(X, n_points, random_state=0)
191
- y = shap.utils.sample(y, n_points, random_state=0)
192
-
193
- if display:
194
- X_display = X.copy()
195
- # X_display["sex_isFemale"] = ["Female" if v else "Male" for v in X["sex_isFemale"]]
196
- return X_display, np.array(y)
197
- return X, np.array(y)
198
-
199
-
200
- def corrgroups60(display=False, n_points=1_000):
201
- """ Correlated Groups 60
202
-
203
- A simulated dataset with tight correlations among distinct groups of features.
204
- """
205
-
206
- # set a constant seed
207
- old_seed = np.random.seed()
208
- np.random.seed(0)
209
-
210
- # generate dataset with known correlation
211
- N, M = n_points, 60
212
-
213
- # set one coefficient from each group of 3 to 1
214
- beta = np.zeros(M)
215
- beta[0:30:3] = 1
216
-
217
- # build a correlation matrix with groups of 3 tightly correlated features
218
- C = np.eye(M)
219
- for i in range(0,30,3):
220
- C[i,i+1] = C[i+1,i] = 0.99
221
- C[i,i+2] = C[i+2,i] = 0.99
222
- C[i+1,i+2] = C[i+2,i+1] = 0.99
223
- def f(X):
224
- return np.matmul(X, beta)
225
-
226
- # Make sure the sample correlation is a perfect match
227
- X_start = np.random.randn(N, M)
228
- X_centered = X_start - X_start.mean(0)
229
- Sigma = np.matmul(X_centered.T, X_centered) / X_centered.shape[0]
230
- W = np.linalg.cholesky(np.linalg.inv(Sigma)).T
231
- X_white = np.matmul(X_centered, W.T)
232
- assert np.linalg.norm(np.corrcoef(np.matmul(X_centered, W.T).T) - np.eye(M)) < 1e-6 # ensure this decorrelates the data
233
-
234
- # create the final data
235
- X_final = np.matmul(X_white, np.linalg.cholesky(C).T)
236
- X = X_final
237
- y = f(X) + np.random.randn(N) * 1e-2
238
-
239
- # restore the previous numpy random seed
240
- np.random.seed(old_seed)
241
-
242
- return pd.DataFrame(X), y
243
-
244
-
245
- def independentlinear60(display=False, n_points=1_000):
246
- """ A simulated dataset with tight correlations among distinct groups of features.
247
- """
248
-
249
- # set a constant seed
250
- old_seed = np.random.seed()
251
- np.random.seed(0)
252
-
253
- # generate dataset with known correlation
254
- N, M = n_points, 60
255
-
256
- # set one coefficient from each group of 3 to 1
257
- beta = np.zeros(M)
258
- beta[0:30:3] = 1
259
- def f(X):
260
- return np.matmul(X, beta)
261
-
262
- # Make sure the sample correlation is a perfect match
263
- X_start = np.random.randn(N, M)
264
- X = X_start - X_start.mean(0)
265
- y = f(X) + np.random.randn(N) * 1e-2
266
-
267
- # restore the previous numpy random seed
268
- np.random.seed(old_seed)
269
-
270
- return pd.DataFrame(X), y
271
-
272
-
273
- def a1a(n_points=None):
274
- """ A sparse dataset in scipy csr matrix format.
275
- """
276
- data, target = sklearn.datasets.load_svmlight_file(cache(github_data_url + 'a1a.svmlight'))
277
-
278
- if n_points is not None:
279
- data = shap.utils.sample(data, n_points, random_state=0)
280
- target = shap.utils.sample(target, n_points, random_state=0)
281
-
282
- return data, target
283
-
284
-
285
- def rank():
286
- """ Ranking datasets from lightgbm repository.
287
- """
288
- rank_data_url = 'https://raw.githubusercontent.com/Microsoft/LightGBM/master/examples/lambdarank/'
289
- x_train, y_train = sklearn.datasets.load_svmlight_file(cache(rank_data_url + 'rank.train'))
290
- x_test, y_test = sklearn.datasets.load_svmlight_file(cache(rank_data_url + 'rank.test'))
291
- q_train = np.loadtxt(cache(rank_data_url + 'rank.train.query'))
292
- q_test = np.loadtxt(cache(rank_data_url + 'rank.test.query'))
293
-
294
- return x_train, y_train, x_test, y_test, q_train, q_test
295
-
296
-
297
- def cache(url, file_name=None):
298
- """ Loads a file from the URL and caches it locally.
299
- """
300
- if file_name is None:
301
- file_name = os.path.basename(url)
302
- data_dir = os.path.join(os.path.dirname(__file__), "cached_data")
303
- os.makedirs(data_dir, exist_ok=True)
304
-
305
- file_path = os.path.join(data_dir, file_name)
306
- if not os.path.isfile(file_path):
307
- urlretrieve(url, file_path)
308
-
309
- return file_path
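Each loader above downloads its file(s) once via cache() into a cached_data directory next to the module and returns an (X, y) pair. Illustrative usage (network access is assumed on first call):

    import shap

    X, y = shap.datasets.adult()                                 # DataFrame of features, boolean labels
    X_disp, _ = shap.datasets.adult(display=True)                # human-readable categorical columns
    X_small, y_small = shap.datasets.california(n_points=500)    # subsample via n_points
    print(X.shape, X_disp.shape, X_small.shape)
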
lib/shap/explainers/__init__.py DELETED
@@ -1,38 +0,0 @@
1
- from ._additive import AdditiveExplainer
2
- from ._deep import DeepExplainer
3
- from ._exact import ExactExplainer
4
- from ._gpu_tree import GPUTreeExplainer
5
- from ._gradient import GradientExplainer
6
- from ._kernel import KernelExplainer
7
- from ._linear import LinearExplainer
8
- from ._partition import PartitionExplainer
9
- from ._permutation import PermutationExplainer
10
- from ._sampling import SamplingExplainer
11
- from ._tree import TreeExplainer
12
-
13
- # Alternative legacy "short-form" aliases, which are kept here for backwards-compatibility
14
- Additive = AdditiveExplainer
15
- Deep = DeepExplainer
16
- Exact = ExactExplainer
17
- GPUTree = GPUTreeExplainer
18
- Gradient = GradientExplainer
19
- Kernel = KernelExplainer
20
- Linear = LinearExplainer
21
- Partition = PartitionExplainer
22
- Permutation = PermutationExplainer
23
- Sampling = SamplingExplainer
24
- Tree = TreeExplainer
25
-
26
- __all__ = [
27
- "AdditiveExplainer",
28
- "DeepExplainer",
29
- "ExactExplainer",
30
- "GPUTreeExplainer",
31
- "GradientExplainer",
32
- "KernelExplainer",
33
- "LinearExplainer",
34
- "PartitionExplainer",
35
- "PermutationExplainer",
36
- "SamplingExplainer",
37
- "TreeExplainer",
38
- ]
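The short-form names are plain aliases of the long-form classes, so code written against either spelling resolves to the same object; for example:

    import shap

    assert shap.explainers.Tree is shap.explainers.TreeExplainer
    assert shap.explainers.Kernel is shap.explainers.KernelExplainer
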
lib/shap/explainers/_additive.py DELETED
@@ -1,187 +0,0 @@
1
- import numpy as np
2
-
3
- from ..utils import MaskedModel, safe_isinstance
4
- from ._explainer import Explainer
5
-
6
-
7
- class AdditiveExplainer(Explainer):
8
- """ Computes SHAP values for generalized additive models.
9
-
10
- This assumes that the model only has first-order effects. Extending this to
11
- second- and third-order effects is future work (if you apply this to those models right now
12
- you will get incorrect answers that fail additivity).
13
- """
14
-
15
- def __init__(self, model, masker, link=None, feature_names=None, linearize_link=True):
16
- """ Build an Additive explainer for the given model using the given masker object.
17
-
18
- Parameters
19
- ----------
20
- model : function
21
- A callable python object that executes the model given a set of input data samples.
22
-
23
- masker : function or numpy.array or pandas.DataFrame
24
- A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`.
25
- It takes a single binary mask and an input sample and returns a matrix of masked samples. These
26
- masked samples are evaluated using the model function and the outputs are then averaged.
27
- As a shortcut for the standard masking used by SHAP you can pass a background data matrix
28
- instead of a function and that matrix will be used for masking. To use a clustering
29
- game structure you can pass a shap.maskers.Tabular(data, hclustering=\"correlation\") object, but
30
- note that this structure information has no effect on the explanations of additive models.
31
- """
32
- super().__init__(model, masker, feature_names=feature_names, linearize_link=linearize_link)
33
-
34
- if safe_isinstance(model, "interpret.glassbox.ExplainableBoostingClassifier"):
35
- self.model = model.decision_function
36
-
37
- if self.masker is None:
38
- self._expected_value = model.intercept_
39
- # num_features = len(model.additive_terms_)
40
-
41
- # fm = MaskedModel(self.model, self.masker, self.link, np.zeros(num_features))
42
- # masks = np.ones((1, num_features), dtype=bool)
43
- # outputs = fm(masks)
44
- # self.model(np.zeros(num_features))
45
- # self._zero_offset = self.model(np.zeros(num_features))#model.intercept_#outputs[0]
46
- # self._input_offsets = np.zeros(num_features) #* self._zero_offset
47
- raise NotImplementedError("Masker not given and we don't yet support pulling the distribution centering directly from the EBM model!")
48
- return
49
-
50
- # here we need to compute the offsets ourselves because we can't pull them directly from a model we know about
51
- assert safe_isinstance(self.masker, "shap.maskers.Independent"), "The Additive explainer only supports the Tabular masker at the moment!"
52
-
53
- # pre-compute per-feature offsets
54
- fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, np.zeros(self.masker.shape[1]))
55
- masks = np.ones((self.masker.shape[1]+1, self.masker.shape[1]), dtype=bool)
56
- for i in range(1, self.masker.shape[1]+1):
57
- masks[i,i-1] = False
58
- outputs = fm(masks)
59
- self._zero_offset = outputs[0]
60
- self._input_offsets = np.zeros(masker.shape[1])
61
- for i in range(1, self.masker.shape[1]+1):
62
- self._input_offsets[i-1] = outputs[i] - self._zero_offset
63
-
64
- self._expected_value = self._input_offsets.sum() + self._zero_offset
65
-
66
- def __call__(self, *args, max_evals=None, silent=False):
67
- """ Explains the output of model(*args), where args represents one or more parallel iterable args.
68
- """
69
-
70
- # we entirely rely on the general call implementation, we override just to remove **kwargs
71
- # from the function signature
72
- return super().__call__(*args, max_evals=max_evals, silent=silent)
73
-
74
- @staticmethod
75
- def supports_model_with_masker(model, masker):
76
- """ Determines if this explainer can handle the given model.
77
-
78
- This is an abstract static method meant to be implemented by each subclass.
79
- """
80
- if safe_isinstance(model, "interpret.glassbox.ExplainableBoostingClassifier"):
81
- if model.interactions != 0:
82
- raise NotImplementedError("Need to add support for interaction effects!")
83
- return True
84
-
85
- return False
86
-
87
- def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
88
- """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
89
- """
90
-
91
- x = row_args[0]
92
- inputs = np.zeros((len(x), len(x)))
93
- for i in range(len(x)):
94
- inputs[i,i] = x[i]
95
-
96
- phi = self.model(inputs) - self._zero_offset - self._input_offsets
97
-
98
- return {
99
- "values": phi,
100
- "expected_values": self._expected_value,
101
- "mask_shapes": [a.shape for a in row_args],
102
- "main_effects": phi,
103
- "clustering": getattr(self.masker, "clustering", None)
104
- }
105
-
106
- # class AdditiveExplainer(Explainer):
107
- # """ Computes SHAP values for generalized additive models.
108
-
109
- # This assumes that the model only has first order effects. Extending this to
110
- # 2nd and third order effects is future work (if you apply this to those models right now
111
- # you will get incorrect answers that fail additivity).
112
-
113
- # Parameters
114
- # ----------
115
- # model : function or ExplainableBoostingRegressor
116
- # User supplied additive model either as either a function or a model object.
117
-
118
- # data : numpy.array, pandas.DataFrame
119
- # The background dataset to use for computing conditional expectations.
120
- # feature_perturbation : "interventional"
121
- # Only the standard interventional SHAP values are supported by AdditiveExplainer right now.
122
- # """
123
-
124
- # def __init__(self, model, data, feature_perturbation="interventional"):
125
- # if feature_perturbation != "interventional":
126
- # raise Exception("Unsupported type of feature_perturbation provided: " + feature_perturbation)
127
-
128
- # if safe_isinstance(model, "interpret.glassbox.ebm.ebm.ExplainableBoostingRegressor"):
129
- # self.f = model.predict
130
- # elif callable(model):
131
- # self.f = model
132
- # else:
133
- # raise ValueError("The passed model must be a recognized object or a function!")
134
-
135
- # # convert dataframes
136
- # if isinstance(data, (pd.Series, pd.DataFrame)):
137
- # data = data.values
138
- # self.data = data
139
-
140
- # # compute the expected value of the model output
141
- # self.expected_value = self.f(data).mean()
142
-
143
- # # pre-compute per-feature offsets
144
- # tmp = np.zeros(data.shape)
145
- # self._zero_offset = self.f(tmp).mean()
146
- # self._feature_offset = np.zeros(data.shape[1])
147
- # for i in range(data.shape[1]):
148
- # tmp[:,i] = data[:,i]
149
- # self._feature_offset[i] = self.f(tmp).mean() - self._zero_offset
150
- # tmp[:,i] = 0
151
-
152
-
153
- # def shap_values(self, X):
154
- # """ Estimate the SHAP values for a set of samples.
155
-
156
- # Parameters
157
- # ----------
158
- # X : numpy.array, pandas.DataFrame or scipy.csr_matrix
159
- # A matrix of samples (# samples x # features) on which to explain the model's output.
160
-
161
- # Returns
162
- # -------
163
- # For models with a single output this returns a matrix of SHAP values
164
- # (# samples x # features). Each row sums to the difference between the model output for that
165
- # sample and the expected value of the model output (which is stored as expected_value
166
- # attribute of the explainer).
167
- # """
168
-
169
- # # convert dataframes
170
- # if isinstance(X, (pd.Series, pd.DataFrame)):
171
- # X = X.values
172
-
173
- # # assert isinstance(X, np.ndarray), "Unknown instance type: " + str(type(X))
174
- # assert len(X.shape) == 1 or len(X.shape) == 2, "Instance must have 1 or 2 dimensions!"
175
-
176
- # # convert dataframes
177
- # if isinstance(X, (pd.Series, pd.DataFrame)):
178
- # X = X.values
179
-
180
- # phi = np.zeros(X.shape)
181
- # tmp = np.zeros(X.shape)
182
- # for i in range(X.shape[1]):
183
- # tmp[:,i] = X[:,i]
184
- # phi[:,i] = self.f(tmp) - self._zero_offset - self._feature_offset[i]
185
- # tmp[:,i] = 0
186
-
187
- # return phi
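A minimal sketch of how the explainer above is typically constructed, assuming the interpret package is installed; interactions=0 is needed because the explainer does not support EBM interaction terms (see supports_model_with_masker above), and only the Independent masker is accepted:

    import shap
    from interpret.glassbox import ExplainableBoostingClassifier

    X, y = shap.datasets.adult(n_points=2000)
    ebm = ExplainableBoostingClassifier(interactions=0).fit(X, y)

    masker = shap.maskers.Independent(X, max_samples=100)   # the only masker supported above
    explainer = shap.explainers.Additive(ebm, masker)
    explanation = explainer(X.iloc[:5])                     # first-order per-feature effects
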
lib/shap/explainers/_deep/__init__.py DELETED
@@ -1,125 +0,0 @@
1
- from .._explainer import Explainer
2
- from .deep_pytorch import PyTorchDeep
3
- from .deep_tf import TFDeep
4
-
5
-
6
- class DeepExplainer(Explainer):
7
- """ Meant to approximate SHAP values for deep learning models.
8
-
9
- This is an enhanced version of the DeepLIFT algorithm (Deep SHAP) where, similar to Kernel SHAP, we
10
- approximate the conditional expectations of SHAP values using a selection of background samples.
11
- Lundberg and Lee, NIPS 2017 showed that the per node attribution rules in DeepLIFT (Shrikumar,
12
- Greenside, and Kundaje, arXiv 2017) can be chosen to approximate Shapley values. By integrating
13
- over many background samples Deep estimates approximate SHAP values such that they sum
14
- up to the difference between the expected model output on the passed background samples and the
15
- current model output (f(x) - E[f(x)]).
16
-
17
- Examples
18
- --------
19
- See :ref:`Deep Explainer Examples <deep_explainer_examples>`
20
- """
21
-
22
- def __init__(self, model, data, session=None, learning_phase_flags=None):
23
- """ An explainer object for a differentiable model using a given background dataset.
24
-
25
- Note that the complexity of the method scales linearly with the number of background data
26
- samples. Passing the entire training dataset as `data` will give very accurate expected
27
- values, but be unreasonably expensive. The variance of the expectation estimates scale by
28
- roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate,
29
- and 1000 samples a very good estimate of the expected values.
30
-
31
- Parameters
32
- ----------
33
- model : if framework == 'tensorflow', (input : [tf.Tensor], output : tf.Tensor)
34
- A pair of TensorFlow tensors (or a list and a tensor) that specifies the input and
35
- output of the model to be explained. Note that SHAP values are specific to a single
36
- output value, so the output tf.Tensor should be a single dimensional output (,1).
37
-
38
- if framework == 'pytorch', an nn.Module object (model), or a tuple (model, layer),
39
- where both are nn.Module objects
40
- The model is an nn.Module object which takes as input a tensor (or list of tensors) of
41
- shape data, and returns a single dimensional output.
42
- If the input is a tuple, the returned shap values will be for the input of the
43
- layer argument. layer must be a layer in the model, i.e. model.conv2
44
-
45
- data :
46
- if framework == 'tensorflow': [numpy.array] or [pandas.DataFrame]
47
- if framework == 'pytorch': [torch.tensor]
48
- The background dataset to use for integrating out features. Deep integrates
49
- over these samples. The data passed here must match the input tensors given in the
50
- first argument. Note that since these samples are integrated over for each sample you
51
- should use only something like 100 or 1000 random background samples, not the whole training
52
- dataset.
53
-
54
- if framework == 'tensorflow':
55
-
56
- session : None or tensorflow.Session
57
- The TensorFlow session that has the model we are explaining. If None is passed then
58
- we do our best to find the right session, first looking for a keras session, then
59
- falling back to the default TensorFlow session.
60
-
61
- learning_phase_flags : None or list of tensors
62
- If you have your own custom learning phase flags pass them here. When explaining a prediction
63
- we need to ensure we are not in training mode, since this changes the behavior of ops like
64
- batch norm or dropout. If None is passed then we look for tensors in the graph that look like
65
- learning phase flags (this works for Keras models). Note that we assume all the flags should
66
- have a value of False during predictions (and hence explanations).
67
- """
68
- # first, we need to find the framework
69
- if type(model) is tuple:
70
- a, b = model
71
- try:
72
- a.named_parameters()
73
- framework = 'pytorch'
74
- except Exception:
75
- framework = 'tensorflow'
76
- else:
77
- try:
78
- model.named_parameters()
79
- framework = 'pytorch'
80
- except Exception:
81
- framework = 'tensorflow'
82
-
83
- if framework == 'tensorflow':
84
- self.explainer = TFDeep(model, data, session, learning_phase_flags)
85
- elif framework == 'pytorch':
86
- self.explainer = PyTorchDeep(model, data)
87
-
88
- self.expected_value = self.explainer.expected_value
89
- self.explainer.framework = framework
90
-
91
- def shap_values(self, X, ranked_outputs=None, output_rank_order='max', check_additivity=True):
92
- """ Return approximate SHAP values for the model applied to the data given by X.
93
-
94
- Parameters
95
- ----------
96
- X : list,
97
- if framework == 'tensorflow': numpy.array, or pandas.DataFrame
98
- if framework == 'pytorch': torch.tensor
99
- A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
100
- explain the model's output.
101
-
102
- ranked_outputs : None or int
103
- If ranked_outputs is None then we explain all the outputs in a multi-output model. If
104
- ranked_outputs is a positive integer then we only explain that many of the top model
105
- outputs (where "top" is determined by output_rank_order). Note that this causes a pair
106
- of values to be returned (shap_values, indexes), where shap_values is a list of numpy
107
- arrays for each of the output ranks, and indexes is a matrix that indicates for each sample
108
- which output indexes were chosen as "top".
109
-
110
- output_rank_order : "max", "min", or "max_abs"
111
- How to order the model outputs when using ranked_outputs, either by maximum, minimum, or
112
- maximum absolute value.
113
-
114
- Returns
115
- -------
116
- array or list
117
- For models with a single output this returns a tensor of SHAP values with the same shape
118
- as X. For a model with multiple outputs this returns a list of SHAP value tensors, each of
119
- which are the same shape as X. If ranked_outputs is None then this list of tensors matches
120
- the number of model outputs. If ranked_outputs is a positive integer a pair is returned
121
- (shap_values, indexes), where shap_values is a list of tensors with a length of
122
- ranked_outputs, and indexes is a matrix that indicates for each sample which output indexes
123
- were chosen as "top".
124
- """
125
- return self.explainer.shap_values(X, ranked_outputs, output_rank_order, check_additivity=check_additivity)
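Typical usage of the wrapper above, as a sketch assuming PyTorch is installed (the framework is detected by probing model.named_parameters(), as shown in __init__):

    import torch
    import torch.nn as nn
    import shap

    model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
    background = torch.randn(100, 10)     # ~100 background samples usually suffices
    explainer = shap.DeepExplainer(model, background)
    shap_values = explainer.shap_values(torch.randn(5, 10))
    # for this single-output model the result has the same shape as the explained batch
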
lib/shap/explainers/_deep/deep_pytorch.py DELETED
@@ -1,386 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- from packaging import version
5
-
6
- from .._explainer import Explainer
7
- from .deep_utils import _check_additivity
8
-
9
- torch = None
10
-
11
-
12
- class PyTorchDeep(Explainer):
13
-
14
- def __init__(self, model, data):
15
- # try and import pytorch
16
- global torch
17
- if torch is None:
18
- import torch
19
- if version.parse(torch.__version__) < version.parse("0.4"):
20
- warnings.warn("Your PyTorch version is older than 0.4 and not supported.")
21
-
22
- # check if we have multiple inputs
23
- self.multi_input = False
24
- if isinstance(data, list):
25
- self.multi_input = True
26
- if not isinstance(data, list):
27
- data = [data]
28
- self.data = data
29
- self.layer = None
30
- self.input_handle = None
31
- self.interim = False
32
- self.interim_inputs_shape = None
33
- self.expected_value = None # to keep the DeepExplainer base happy
34
- if type(model) == tuple:
35
- self.interim = True
36
- model, layer = model
37
- model = model.eval()
38
- self.layer = layer
39
- self.add_target_handle(self.layer)
40
-
41
- # if we are taking an interim layer, the 'data' is going to be the input
42
- # of the interim layer; we will capture this using a forward hook
43
- with torch.no_grad():
44
- _ = model(*data)
45
- interim_inputs = self.layer.target_input
46
- if type(interim_inputs) is tuple:
47
- # this should always be true, but just to be safe
48
- self.interim_inputs_shape = [i.shape for i in interim_inputs]
49
- else:
50
- self.interim_inputs_shape = [interim_inputs.shape]
51
- self.target_handle.remove()
52
- del self.layer.target_input
53
- self.model = model.eval()
54
-
55
- self.multi_output = False
56
- self.num_outputs = 1
57
- with torch.no_grad():
58
- outputs = model(*data)
59
-
60
- # also get the device everything is running on
61
- self.device = outputs.device
62
- if outputs.shape[1] > 1:
63
- self.multi_output = True
64
- self.num_outputs = outputs.shape[1]
65
- self.expected_value = outputs.mean(0).cpu().numpy()
66
-
67
- def add_target_handle(self, layer):
68
- input_handle = layer.register_forward_hook(get_target_input)
69
- self.target_handle = input_handle
70
-
71
- def add_handles(self, model, forward_handle, backward_handle):
72
- """
73
- Add handles to all non-container layers in the model.
74
- The hooks are registered recursively, on non-container (leaf) layers only.
75
- """
76
- handles_list = []
77
- model_children = list(model.children())
78
- if model_children:
79
- for child in model_children:
80
- handles_list.extend(self.add_handles(child, forward_handle, backward_handle))
81
- else: # leaves
82
- handles_list.append(model.register_forward_hook(forward_handle))
83
- handles_list.append(model.register_full_backward_hook(backward_handle))
84
- return handles_list
85
-
86
- def remove_attributes(self, model):
87
- """
88
- Removes the x and y attributes which were added by the forward handles
89
- Recursively searches for non-container layers
90
- """
91
- for child in model.children():
92
- if 'nn.modules.container' in str(type(child)):
93
- self.remove_attributes(child)
94
- else:
95
- try:
96
- del child.x
97
- except AttributeError:
98
- pass
99
- try:
100
- del child.y
101
- except AttributeError:
102
- pass
103
-
104
- def gradient(self, idx, inputs):
105
- self.model.zero_grad()
106
- X = [x.requires_grad_() for x in inputs]
107
- outputs = self.model(*X)
108
- selected = [val for val in outputs[:, idx]]
109
- grads = []
110
- if self.interim:
111
- interim_inputs = self.layer.target_input
112
- for idx, input in enumerate(interim_inputs):
113
- grad = torch.autograd.grad(selected, input,
114
- retain_graph=True if idx + 1 < len(interim_inputs) else None,
115
- allow_unused=True)[0]
116
- if grad is not None:
117
- grad = grad.cpu().numpy()
118
- else:
119
- grad = torch.zeros_like(X[idx]).cpu().numpy()
120
- grads.append(grad)
121
- del self.layer.target_input
122
- return grads, [i.detach().cpu().numpy() for i in interim_inputs]
123
- else:
124
- for idx, x in enumerate(X):
125
- grad = torch.autograd.grad(selected, x,
126
- retain_graph=True if idx + 1 < len(X) else None,
127
- allow_unused=True)[0]
128
- if grad is not None:
129
- grad = grad.cpu().numpy()
130
- else:
131
- grad = torch.zeros_like(X[idx]).cpu().numpy()
132
- grads.append(grad)
133
- return grads
134
-
135
- def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True):
136
- # X ~ self.model_input
137
- # X_data ~ self.data
138
-
139
- # check if we have multiple inputs
140
- if not self.multi_input:
141
- assert not isinstance(X, list), "Expected a single tensor model input!"
142
- X = [X]
143
- else:
144
- assert isinstance(X, list), "Expected a list of model inputs!"
145
-
146
- X = [x.detach().to(self.device) for x in X]
147
-
148
- model_output_values = None
149
-
150
- if ranked_outputs is not None and self.multi_output:
151
- with torch.no_grad():
152
- model_output_values = self.model(*X)
153
- # rank and determine the model outputs that we will explain
154
- if output_rank_order == "max":
155
- _, model_output_ranks = torch.sort(model_output_values, descending=True)
156
- elif output_rank_order == "min":
157
- _, model_output_ranks = torch.sort(model_output_values, descending=False)
158
- elif output_rank_order == "max_abs":
159
- _, model_output_ranks = torch.sort(torch.abs(model_output_values), descending=True)
160
- else:
161
- emsg = "output_rank_order must be max, min, or max_abs!"
162
- raise ValueError(emsg)
163
- model_output_ranks = model_output_ranks[:, :ranked_outputs]
164
- else:
165
- model_output_ranks = (torch.ones((X[0].shape[0], self.num_outputs)).int() *
166
- torch.arange(0, self.num_outputs).int())
167
-
168
- # add the gradient handles
169
- handles = self.add_handles(self.model, add_interim_values, deeplift_grad)
170
- if self.interim:
171
- self.add_target_handle(self.layer)
172
-
173
- # compute the attributions
174
- output_phis = []
175
- for i in range(model_output_ranks.shape[1]):
176
- phis = []
177
- if self.interim:
178
- for k in range(len(self.interim_inputs_shape)):
179
- phis.append(np.zeros((X[0].shape[0], ) + self.interim_inputs_shape[k][1: ]))
180
- else:
181
- for k in range(len(X)):
182
- phis.append(np.zeros(X[k].shape))
183
- for j in range(X[0].shape[0]):
184
- # tile the inputs to line up with the background data samples
185
- tiled_X = [X[t][j:j + 1].repeat(
186
- (self.data[t].shape[0],) + tuple([1 for k in range(len(X[t].shape) - 1)])) for t
187
- in range(len(X))]
188
- joint_x = [torch.cat((tiled_X[t], self.data[t]), dim=0) for t in range(len(X))]
189
- # run attribution computation graph
190
- feature_ind = model_output_ranks[j, i]
191
- sample_phis = self.gradient(feature_ind, joint_x)
192
- # assign the attributions to the right part of the output arrays
193
- if self.interim:
194
- sample_phis, output = sample_phis
195
- x, data = [], []
196
- for k in range(len(output)):
197
- x_temp, data_temp = np.split(output[k], 2)
198
- x.append(x_temp)
199
- data.append(data_temp)
200
- for t in range(len(self.interim_inputs_shape)):
201
- phis[t][j] = (sample_phis[t][self.data[t].shape[0]:] * (x[t] - data[t])).mean(0)
202
- else:
203
- for t in range(len(X)):
204
- phis[t][j] = (torch.from_numpy(sample_phis[t][self.data[t].shape[0]:]).to(self.device) * (X[t][j: j + 1] - self.data[t])).cpu().detach().numpy().mean(0)
205
- output_phis.append(phis[0] if not self.multi_input else phis)
206
- # cleanup; remove all gradient handles
207
- for handle in handles:
208
- handle.remove()
209
- self.remove_attributes(self.model)
210
- if self.interim:
211
- self.target_handle.remove()
212
-
213
- # check that the SHAP values sum up to the model output
214
- if check_additivity:
215
- if model_output_values is None:
216
- with torch.no_grad():
217
- model_output_values = self.model(*X)
218
-
219
- _check_additivity(self, model_output_values.cpu(), output_phis)
220
-
221
- if not self.multi_output:
222
- return output_phis[0]
223
- elif ranked_outputs is not None:
224
- return output_phis, model_output_ranks
225
- else:
226
- return output_phis
227
-
228
- # Module hooks
229
-
230
-
231
- def deeplift_grad(module, grad_input, grad_output):
232
- """The backward hook which computes the deeplift
233
- gradient for an nn.Module
234
- """
235
- # first, get the module type
236
- module_type = module.__class__.__name__
237
- # first, check the module is supported
238
- if module_type in op_handler:
239
- if op_handler[module_type].__name__ not in ['passthrough', 'linear_1d']:
240
- return op_handler[module_type](module, grad_input, grad_output)
241
- else:
242
- warnings.warn(f'unrecognized nn.Module: {module_type}')
243
- return grad_input
244
-
245
-
246
- def add_interim_values(module, input, output):
247
- """The forward hook used to save interim tensors, detached
248
- from the graph. Used to calculate the multipliers
249
- """
250
- try:
251
- del module.x
252
- except AttributeError:
253
- pass
254
- try:
255
- del module.y
256
- except AttributeError:
257
- pass
258
- module_type = module.__class__.__name__
259
- if module_type in op_handler:
260
- func_name = op_handler[module_type].__name__
261
- # First, check for cases where we don't need to save the x and y tensors
262
- if func_name == 'passthrough':
263
- pass
264
- else:
265
- # check only the 0th input varies
266
- for i in range(len(input)):
267
- if i != 0 and type(output) is tuple:
268
- assert input[i] == output[i], "Only the 0th input may vary!"
269
- # if a new method is added, it must be added here too. This ensures tensors
270
- # are only saved if necessary
271
- if func_name in ['maxpool', 'nonlinear_1d']:
272
- # only save tensors if necessary
273
- if type(input) is tuple:
274
- setattr(module, 'x', torch.nn.Parameter(input[0].detach()))
275
- else:
276
- setattr(module, 'x', torch.nn.Parameter(input.detach()))
277
- if type(output) is tuple:
278
- setattr(module, 'y', torch.nn.Parameter(output[0].detach()))
279
- else:
280
- setattr(module, 'y', torch.nn.Parameter(output.detach()))
281
-
282
-
283
- def get_target_input(module, input, output):
284
- """A forward hook which saves the tensor - attached to its graph.
285
- Used if we want to explain the interim outputs of a model
286
- """
287
- try:
288
- del module.target_input
289
- except AttributeError:
290
- pass
291
- setattr(module, 'target_input', input)
292
-
293
-
294
- def passthrough(module, grad_input, grad_output):
295
- """No change made to gradients"""
296
- return None
297
-
298
-
299
- def maxpool(module, grad_input, grad_output):
300
- pool_to_unpool = {
301
- 'MaxPool1d': torch.nn.functional.max_unpool1d,
302
- 'MaxPool2d': torch.nn.functional.max_unpool2d,
303
- 'MaxPool3d': torch.nn.functional.max_unpool3d
304
- }
305
- pool_to_function = {
306
- 'MaxPool1d': torch.nn.functional.max_pool1d,
307
- 'MaxPool2d': torch.nn.functional.max_pool2d,
308
- 'MaxPool3d': torch.nn.functional.max_pool3d
309
- }
310
- delta_in = module.x[: int(module.x.shape[0] / 2)] - module.x[int(module.x.shape[0] / 2):]
311
- dup0 = [2] + [1 for i in delta_in.shape[1:]]
312
- # we also need to check if the output is a tuple
313
- y, ref_output = torch.chunk(module.y, 2)
314
- cross_max = torch.max(y, ref_output)
315
- diffs = torch.cat([cross_max - ref_output, y - cross_max], 0)
316
-
317
- # all of this just to unpool the outputs
318
- with torch.no_grad():
319
- _, indices = pool_to_function[module.__class__.__name__](
320
- module.x, module.kernel_size, module.stride, module.padding,
321
- module.dilation, module.ceil_mode, True)
322
- xmax_pos, rmax_pos = torch.chunk(pool_to_unpool[module.__class__.__name__](
323
- grad_output[0] * diffs, indices, module.kernel_size, module.stride,
324
- module.padding, list(module.x.shape)), 2)
325
-
326
- grad_input = [None for _ in grad_input]
327
- grad_input[0] = torch.where(torch.abs(delta_in) < 1e-7, torch.zeros_like(delta_in),
328
- (xmax_pos + rmax_pos) / delta_in).repeat(dup0)
329
-
330
- return tuple(grad_input)
331
-
332
-
333
- def linear_1d(module, grad_input, grad_output):
334
- """No change made to gradients."""
335
- return None
336
-
337
-
338
- def nonlinear_1d(module, grad_input, grad_output):
339
- delta_out = module.y[: int(module.y.shape[0] / 2)] - module.y[int(module.y.shape[0] / 2):]
340
-
341
- delta_in = module.x[: int(module.x.shape[0] / 2)] - module.x[int(module.x.shape[0] / 2):]
342
- dup0 = [2] + [1 for i in delta_in.shape[1:]]
343
- # handles numerical instabilities where delta_in is very small by
344
- # just taking the gradient in those cases
345
- grads = [None for _ in grad_input]
346
- grads[0] = torch.where(torch.abs(delta_in.repeat(dup0)) < 1e-6, grad_input[0],
347
- grad_output[0] * (delta_out / delta_in).repeat(dup0))
348
- return tuple(grads)
349
-
350
-
351
- op_handler = {}
352
-
353
- # passthrough ops, where we make no change to the gradient
354
- op_handler['Dropout3d'] = passthrough
355
- op_handler['Dropout2d'] = passthrough
356
- op_handler['Dropout'] = passthrough
357
- op_handler['AlphaDropout'] = passthrough
358
-
359
- op_handler['Conv1d'] = linear_1d
360
- op_handler['Conv2d'] = linear_1d
361
- op_handler['Conv3d'] = linear_1d
362
- op_handler['ConvTranspose1d'] = linear_1d
363
- op_handler['ConvTranspose2d'] = linear_1d
364
- op_handler['ConvTranspose3d'] = linear_1d
365
- op_handler['Linear'] = linear_1d
366
- op_handler['AvgPool1d'] = linear_1d
367
- op_handler['AvgPool2d'] = linear_1d
368
- op_handler['AvgPool3d'] = linear_1d
369
- op_handler['AdaptiveAvgPool1d'] = linear_1d
370
- op_handler['AdaptiveAvgPool2d'] = linear_1d
371
- op_handler['AdaptiveAvgPool3d'] = linear_1d
372
- op_handler['BatchNorm1d'] = linear_1d
373
- op_handler['BatchNorm2d'] = linear_1d
374
- op_handler['BatchNorm3d'] = linear_1d
375
-
376
- op_handler['LeakyReLU'] = nonlinear_1d
377
- op_handler['ReLU'] = nonlinear_1d
378
- op_handler['ELU'] = nonlinear_1d
379
- op_handler['Sigmoid'] = nonlinear_1d
380
- op_handler["Tanh"] = nonlinear_1d
381
- op_handler["Softplus"] = nonlinear_1d
382
- op_handler['Softmax'] = nonlinear_1d
383
-
384
- op_handler['MaxPool1d'] = maxpool
385
- op_handler['MaxPool2d'] = maxpool
386
- op_handler['MaxPool3d'] = maxpool
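For orientation, a minimal usage sketch of how this PyTorch backend is normally reached through the public shap.DeepExplainer API; the model, background batch, and shapes below are illustrative assumptions, not taken from this repository:

import torch
import torch.nn as nn
import shap

# toy multi-output model and a random background batch (illustrative only)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3)).eval()
background = torch.randn(100, 10)      # samples used to integrate out features
to_explain = torch.randn(5, 10)        # rows we want attributions for

explainer = shap.DeepExplainer(model, background)    # dispatches to the PyTorch backend above
shap_values = explainer.shap_values(to_explain)      # one attribution array per model output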
lib/shap/explainers/_deep/deep_tf.py DELETED
@@ -1,763 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- from packaging import version
5
-
6
- from ...utils._exceptions import DimensionError
7
- from .._explainer import Explainer
8
- from ..tf_utils import _get_graph, _get_model_inputs, _get_model_output, _get_session
9
- from .deep_utils import _check_additivity
10
-
11
- tf = None
12
- tf_ops = None
13
- tf_backprop = None
14
- tf_execute = None
15
- tf_gradients_impl = None
16
-
17
- def custom_record_gradient(op_name, inputs, attrs, results):
18
- """ This overrides tensorflow.python.eager.backprop._record_gradient.
19
-
20
- We need to override _record_gradient in order to get gradient backprop to
21
- get called for ResourceGather operations. In order to make this work we
22
- temporarily "lie" about the input type to prevent the node from getting
23
- pruned from the gradient backprop process. We then reset the type directly
24
- afterwards back to what it was (an integer type).
25
- """
26
- reset_input = False
27
- if op_name == "ResourceGather" and inputs[1].dtype == tf.int32:
28
- inputs[1].__dict__["_dtype"] = tf.float32
29
- reset_input = True
30
- try:
31
- out = tf_backprop._record_gradient("shap_"+op_name, inputs, attrs, results)
32
- except AttributeError:
33
- out = tf_backprop.record_gradient("shap_"+op_name, inputs, attrs, results)
34
-
35
- if reset_input:
36
- inputs[1].__dict__["_dtype"] = tf.int32
37
-
38
- return out
39
-
40
- class TFDeep(Explainer):
41
- """
42
- Using tf.gradients to implement the backpropagation was
43
- inspired by the gradient-based implementation approach proposed by Ancona et al, ICLR 2018. Note
44
- that this package does not currently use the reveal-cancel rule for ReLu units proposed in DeepLIFT.
45
- """
46
-
47
- def __init__(self, model, data, session=None, learning_phase_flags=None):
48
- """ An explainer object for a deep model using a given background dataset.
49
-
50
- Note that the complexity of the method scales linearly with the number of background data
51
- samples. Passing the entire training dataset as `data` will give very accurate expected
52
- values, but will be computationally expensive. The variance of the expectation estimates scales by
53
- roughly 1/sqrt(N) for N background data samples. So 100 samples will give a good estimate,
54
- and 1000 samples a very good estimate of the expected values.
55
-
56
- Parameters
57
- ----------
58
- model : tf.keras.Model or (input : [tf.Operation], output : tf.Operation)
59
- A keras model object or a pair of TensorFlow operations (or a list and an op) that
60
- specifies the input and output of the model to be explained. Note that SHAP values
61
- are specific to a single output value, so you get an explanation for each element of
62
- the output tensor (which must be a flat rank one vector).
63
-
64
- data : [numpy.array] or [pandas.DataFrame] or function
65
- The background dataset to use for integrating out features. DeepExplainer integrates
66
- over all these samples for each explanation. The data passed here must match the input
67
- operations given to the model. If a function is supplied, it must be a function that
68
- takes a particular input example and generates the background dataset for that example
69
- session : None or tensorflow.Session
70
- The TensorFlow session that has the model we are explaining. If None is passed then
71
- we do our best to find the right session, first looking for a keras session, then
72
- falling back to the default TensorFlow session.
73
-
74
- learning_phase_flags : None or list of tensors
75
- If you have your own custom learning phase flags pass them here. When explaining a prediction
76
- we need to ensure we are not in training mode, since this changes the behavior of ops like
77
- batch norm or dropout. If None is passed then we look for tensors in the graph that look like
78
- learning phase flags (this works for Keras models). Note that we assume all the flags should
79
- have a value of False during predictions (and hence explanations).
80
-
81
- """
82
- # try to import tensorflow
83
- global tf, tf_ops, tf_backprop, tf_execute, tf_gradients_impl
84
- if tf is None:
85
- from tensorflow.python.eager import backprop as tf_backprop
86
- from tensorflow.python.eager import execute as tf_execute
87
- from tensorflow.python.framework import (
88
- ops as tf_ops,
89
- )
90
- from tensorflow.python.ops import (
91
- gradients_impl as tf_gradients_impl,
92
- )
93
- if not hasattr(tf_gradients_impl, "_IsBackpropagatable"):
94
- from tensorflow.python.ops import gradients_util as tf_gradients_impl
95
- import tensorflow as tf
96
- if version.parse(tf.__version__) < version.parse("1.4.0"):
97
- warnings.warn("Your TensorFlow version is older than 1.4.0 and not supported.")
98
-
99
- if version.parse(tf.__version__) >= version.parse("2.4.0"):
100
- warnings.warn("Your TensorFlow version is newer than 2.4.0 and so graph support has been removed in eager mode and some static graphs may not be supported. See PR #1483 for discussion.")
101
-
102
- # determine the model inputs and outputs
103
- self.model_inputs = _get_model_inputs(model)
104
- self.model_output = _get_model_output(model)
105
- assert not isinstance(self.model_output, list), "The model output to be explained must be a single tensor!"
106
- assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"
107
- self.multi_output = True
108
- if len(self.model_output.shape) == 1:
109
- self.multi_output = False
110
-
111
- if tf.executing_eagerly():
112
- if isinstance(model, tuple) or isinstance(model, list):
113
- assert len(model) == 2, "When a tuple is passed it must be of the form (inputs, outputs)"
114
- from tensorflow.keras import Model
115
- self.model = Model(model[0], model[1])
116
- else:
117
- self.model = model
118
-
119
- # check if we have multiple inputs
120
- self.multi_input = True
121
- if not isinstance(self.model_inputs, list) or len(self.model_inputs) == 1:
122
- self.multi_input = False
123
- if not isinstance(self.model_inputs, list):
124
- self.model_inputs = [self.model_inputs]
125
- if not isinstance(data, list) and (hasattr(data, "__call__") is False):
126
- data = [data]
127
- self.data = data
128
-
129
- self._vinputs = {} # used to track what op inputs depends on the model inputs
130
- self.orig_grads = {}
131
-
132
- if not tf.executing_eagerly():
133
- self.session = _get_session(session)
134
-
135
- self.graph = _get_graph(self)
136
-
137
- # if no learning phase flags were given we go looking for them
138
- # ...this will catch the one that keras uses
139
- # we need to find them since we want to make sure learning phase flags are set to False
140
- if learning_phase_flags is None:
141
- self.learning_phase_ops = []
142
- for op in self.graph.get_operations():
143
- if 'learning_phase' in op.name and op.type == "Const" and len(op.outputs[0].shape) == 0:
144
- if op.outputs[0].dtype == tf.bool:
145
- self.learning_phase_ops.append(op)
146
- self.learning_phase_flags = [op.outputs[0] for op in self.learning_phase_ops]
147
- else:
148
- self.learning_phase_ops = [t.op for t in learning_phase_flags]
149
-
150
- # save the expected output of the model
151
- # if self.data is a function, set self.expected_value to None
152
- if (hasattr(self.data, '__call__')):
153
- self.expected_value = None
154
- else:
155
- if self.data[0].shape[0] > 5000:
156
- warnings.warn("You have provided over 5k background samples! For better performance consider using a smaller random sample.")
157
- if not tf.executing_eagerly():
158
- self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0)
159
- else:
160
- #if type(self.model)is tuple:
161
- # self.fModel(cnn.inputs, cnn.get_layer(theNameYouWant).outputs)
162
- self.expected_value = tf.reduce_mean(self.model(self.data), 0)
163
-
164
- if not tf.executing_eagerly():
165
- self._init_between_tensors(self.model_output.op, self.model_inputs)
166
-
167
- # make a blank array that will get lazily filled in with the SHAP value computation
168
- # graphs for each output. Lazy is important since if there are 1000 outputs and we
169
- # only explain the top 5 it would be a waste to build graphs for the other 995
170
- if not self.multi_output:
171
- self.phi_symbolics = [None]
172
- else:
173
- noutputs = self.model_output.shape.as_list()[1]
174
- if noutputs is not None:
175
- self.phi_symbolics = [None for i in range(noutputs)]
176
- else:
177
- raise DimensionError("The model output tensor to be explained cannot have a static shape in dim 1 of None!")
178
-
179
- def _get_model_output(self, model):
180
- if len(model.layers[-1]._inbound_nodes) == 0:
181
- if len(model.outputs) > 1:
182
- warnings.warn("Only one model output supported.")
183
- return model.outputs[0]
184
- else:
185
- return model.layers[-1].output
186
-
187
- def _init_between_tensors(self, out_op, model_inputs):
188
- # find all the operations in the graph between our inputs and outputs
189
- tensor_blacklist = tensors_blocked_by_false(self.learning_phase_ops) # don't follow learning phase branches
190
- dependence_breakers = [k for k in op_handlers if op_handlers[k] == break_dependence]
191
- back_ops = backward_walk_ops(
192
- [out_op], tensor_blacklist,
193
- dependence_breakers
194
- )
195
- start_ops = []
196
- for minput in model_inputs:
197
- for op in minput.consumers():
198
- start_ops.append(op)
199
- self.between_ops = forward_walk_ops(
200
- start_ops,
201
- tensor_blacklist, dependence_breakers,
202
- within_ops=back_ops
203
- )
204
-
205
- # note all the tensors that are on the path between the inputs and the output
206
- self.between_tensors = {}
207
- for op in self.between_ops:
208
- for t in op.outputs:
209
- self.between_tensors[t.name] = True
210
- for t in model_inputs:
211
- self.between_tensors[t.name] = True
212
-
213
- # save what types are being used
214
- self.used_types = {}
215
- for op in self.between_ops:
216
- self.used_types[op.type] = True
217
-
218
- def _variable_inputs(self, op):
219
- """ Return which inputs of this operation are variable (i.e. depend on the model inputs).
220
- """
221
- if op not in self._vinputs:
222
- out = np.zeros(len(op.inputs), dtype=bool)
223
- for i,t in enumerate(op.inputs):
224
- out[i] = t.name in self.between_tensors
225
- self._vinputs[op] = out
226
- return self._vinputs[op]
227
-
228
- def phi_symbolic(self, i):
229
- """ Get the SHAP value computation graph for a given model output.
230
- """
231
- if self.phi_symbolics[i] is None:
232
-
233
- if not tf.executing_eagerly():
234
- def anon():
235
- out = self.model_output[:,i] if self.multi_output else self.model_output
236
- return tf.gradients(out, self.model_inputs)
237
-
238
- self.phi_symbolics[i] = self.execute_with_overridden_gradients(anon)
239
- else:
240
- @tf.function
241
- def grad_graph(shap_rAnD):
242
- phase = tf.keras.backend.learning_phase()
243
- tf.keras.backend.set_learning_phase(0)
244
-
245
- with tf.GradientTape(watch_accessed_variables=False) as tape:
246
- tape.watch(shap_rAnD)
247
- out = self.model(shap_rAnD)
248
- if self.multi_output:
249
- out = out[:,i]
250
-
251
- self._init_between_tensors(out.op, shap_rAnD)
252
- x_grad = tape.gradient(out, shap_rAnD)
253
- tf.keras.backend.set_learning_phase(phase)
254
- return x_grad
255
-
256
- self.phi_symbolics[i] = grad_graph
257
-
258
- return self.phi_symbolics[i]
259
-
260
- def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_additivity=True):
261
- # check if we have multiple inputs
262
- if not self.multi_input:
263
- if isinstance(X, list) and len(X) != 1:
264
- raise ValueError("Expected a single tensor as model input!")
265
- elif not isinstance(X, list):
266
- X = [X]
267
- else:
268
- assert isinstance(X, list), "Expected a list of model inputs!"
269
- assert len(self.model_inputs) == len(X), "Number of model inputs (%d) does not match the number given (%d)!" % (len(self.model_inputs), len(X))
270
-
271
- # rank and determine the model outputs that we will explain
272
- if ranked_outputs is not None and self.multi_output:
273
- if not tf.executing_eagerly():
274
- model_output_values = self.run(self.model_output, self.model_inputs, X)
275
- else:
276
- model_output_values = self.model(X)
277
-
278
- if output_rank_order == "max":
279
- model_output_ranks = np.argsort(-model_output_values)
280
- elif output_rank_order == "min":
281
- model_output_ranks = np.argsort(model_output_values)
282
- elif output_rank_order == "max_abs":
283
- model_output_ranks = np.argsort(np.abs(model_output_values))
284
- else:
285
- emsg = "output_rank_order must be max, min, or max_abs!"
286
- raise ValueError(emsg)
287
- model_output_ranks = model_output_ranks[:,:ranked_outputs]
288
- else:
289
- model_output_ranks = np.tile(np.arange(len(self.phi_symbolics)), (X[0].shape[0], 1))
290
-
291
- # compute the attributions
292
- output_phis = []
293
- for i in range(model_output_ranks.shape[1]):
294
- phis = []
295
- for k in range(len(X)):
296
- phis.append(np.zeros(X[k].shape))
297
- for j in range(X[0].shape[0]):
298
- if (hasattr(self.data, '__call__')):
299
- bg_data = self.data([X[t][j] for t in range(len(X))])
300
- if not isinstance(bg_data, list):
301
- bg_data = [bg_data]
302
- else:
303
- bg_data = self.data
304
-
305
- # tile the inputs to line up with the background data samples
306
- tiled_X = [np.tile(X[t][j:j+1], (bg_data[t].shape[0],) + tuple([1 for k in range(len(X[t].shape)-1)])) for t in range(len(X))]
307
-
308
- # we use the first sample for the current sample and the rest for the references
309
- joint_input = [np.concatenate([tiled_X[t], bg_data[t]], 0) for t in range(len(X))]
310
-
311
- # run attribution computation graph
312
- feature_ind = model_output_ranks[j,i]
313
- sample_phis = self.run(self.phi_symbolic(feature_ind), self.model_inputs, joint_input)
314
-
315
- # assign the attributions to the right part of the output arrays
316
- for t in range(len(X)):
317
- phis[t][j] = (sample_phis[t][bg_data[t].shape[0]:] * (X[t][j] - bg_data[t])).mean(0)
318
-
319
- output_phis.append(phis[0] if not self.multi_input else phis)
320
-
321
- # check that the SHAP values sum up to the model output
322
- if check_additivity:
323
- if not tf.executing_eagerly():
324
- model_output = self.run(self.model_output, self.model_inputs, X)
325
- else:
326
- model_output = self.model(X)
327
-
328
- _check_additivity(self, model_output, output_phis)
329
-
330
- if not self.multi_output:
331
- return output_phis[0]
332
- elif ranked_outputs is not None:
333
- return output_phis, model_output_ranks
334
- else:
335
- return output_phis
336
-
337
- def run(self, out, model_inputs, X):
338
- """ Runs the model while also setting the learning phase flags to False.
339
- """
340
- if not tf.executing_eagerly():
341
- feed_dict = dict(zip(model_inputs, X))
342
- for t in self.learning_phase_flags:
343
- feed_dict[t] = False
344
- return self.session.run(out, feed_dict)
345
- else:
346
- def anon():
347
- tf_execute.record_gradient = custom_record_gradient
348
-
349
- # build inputs that are correctly shaped, typed, and tf-wrapped
350
- inputs = []
351
- for i in range(len(X)):
352
- shape = list(self.model_inputs[i].shape)
353
- shape[0] = -1
354
- data = X[i].reshape(shape)
355
- v = tf.constant(data, dtype=self.model_inputs[i].dtype)
356
- inputs.append(v)
357
- final_out = out(inputs)
358
- try:
359
- tf_execute.record_gradient = tf_backprop._record_gradient
360
- except AttributeError:
361
- tf_execute.record_gradient = tf_backprop.record_gradient
362
-
363
- return final_out
364
- return self.execute_with_overridden_gradients(anon)
365
-
366
- def custom_grad(self, op, *grads):
367
- """ Passes a gradient op creation request to the correct handler.
368
- """
369
- type_name = op.type[5:] if op.type.startswith("shap_") else op.type
370
- out = op_handlers[type_name](self, op, *grads) # we cut off the shap_ prefix before the lookup
371
- return out
372
-
373
- def execute_with_overridden_gradients(self, f):
374
- # replace the gradients for all the non-linear activations
375
- # we do this by hacking our way into the registry (TODO: find a public API for this if it exists)
376
- reg = tf_ops._gradient_registry._registry
377
- ops_not_in_registry = ['TensorListReserve']
378
- # NOTE: location_tag taken from tensorflow source for None type ops
379
- location_tag = ("UNKNOWN", "UNKNOWN", "UNKNOWN", "UNKNOWN", "UNKNOWN")
380
- # TODO: unclear why some ops are not in the registry with TF 2.0 like TensorListReserve
381
- for non_reg_ops in ops_not_in_registry:
382
- reg[non_reg_ops] = {'type': None, 'location': location_tag}
383
- for n in op_handlers:
384
- if n in reg:
385
- self.orig_grads[n] = reg[n]["type"]
386
- reg["shap_"+n] = {
387
- "type": self.custom_grad,
388
- "location": reg[n]["location"]
389
- }
390
- reg[n]["type"] = self.custom_grad
391
-
392
- # In TensorFlow 1.10 they started pruning out nodes that they think can't be backpropped
393
- # unfortunately that includes the index of embedding layers so we disable that check here
394
- if hasattr(tf_gradients_impl, "_IsBackpropagatable"):
395
- orig_IsBackpropagatable = tf_gradients_impl._IsBackpropagatable
396
- tf_gradients_impl._IsBackpropagatable = lambda tensor: True
397
-
398
- # define the computation graph for the attribution values using a custom gradient-like computation
399
- try:
400
- out = f()
401
- finally:
402
- # reinstate the backpropagatable check
403
- if hasattr(tf_gradients_impl, "_IsBackpropagatable"):
404
- tf_gradients_impl._IsBackpropagatable = orig_IsBackpropagatable
405
-
406
- # restore the original gradient definitions
407
- for n in op_handlers:
408
- if n in reg:
409
- del reg["shap_"+n]
410
- reg[n]["type"] = self.orig_grads[n]
411
- for non_reg_ops in ops_not_in_registry:
412
- del reg[non_reg_ops]
413
- if not tf.executing_eagerly():
414
- return out
415
- else:
416
- return [v.numpy() for v in out]
417
-
418
- def tensors_blocked_by_false(ops):
419
- """ Follows a set of ops assuming their value is False and find blocked Switch paths.
420
-
421
- This is used to prune away parts of the model graph that are only used during the training
422
- phase (like dropout, batch norm, etc.).
423
- """
424
- blocked = []
425
- def recurse(op):
426
- if op.type == "Switch":
427
- blocked.append(op.outputs[1]) # the true path is blocked since we assume the ops we trace are False
428
- else:
429
- for out in op.outputs:
430
- for c in out.consumers():
431
- recurse(c)
432
- for op in ops:
433
- recurse(op)
434
-
435
- return blocked
436
-
437
- def backward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist):
438
- found_ops = []
439
- op_stack = [op for op in start_ops]
440
- while len(op_stack) > 0:
441
- op = op_stack.pop()
442
- if op.type not in op_type_blacklist and op not in found_ops:
443
- found_ops.append(op)
444
- for input in op.inputs:
445
- if input not in tensor_blacklist:
446
- op_stack.append(input.op)
447
- return found_ops
448
-
449
- def forward_walk_ops(start_ops, tensor_blacklist, op_type_blacklist, within_ops):
450
- found_ops = []
451
- op_stack = [op for op in start_ops]
452
- while len(op_stack) > 0:
453
- op = op_stack.pop()
454
- if op.type not in op_type_blacklist and op in within_ops and op not in found_ops:
455
- found_ops.append(op)
456
- for out in op.outputs:
457
- if out not in tensor_blacklist:
458
- for c in out.consumers():
459
- op_stack.append(c)
460
- return found_ops
461
-
462
-
463
- def softmax(explainer, op, *grads):
464
- """ Just decompose softmax into its components and recurse, we can handle all of them :)
465
-
466
- We assume the 'axis' is the last dimension because the TF codebase swaps the 'axis' to
467
- the last dimension before the softmax op if 'axis' is not already the last dimension.
468
- We also don't subtract the max before tf.exp for numerical stability since that might
469
- mess up the attributions and it seems like TensorFlow doesn't define softmax that way
470
- (according to the docs)
471
- """
472
- in0 = op.inputs[0]
473
- in0_max = tf.reduce_max(in0, axis=-1, keepdims=True, name="in0_max")
474
- in0_centered = in0 - in0_max
475
- evals = tf.exp(in0_centered, name="custom_exp")
476
- rsum = tf.reduce_sum(evals, axis=-1, keepdims=True)
477
- div = evals / rsum
478
-
479
- # mark these as in-between the inputs and outputs
480
- for op in [evals.op, rsum.op, div.op, in0_centered.op]:
481
- for t in op.outputs:
482
- if t.name not in explainer.between_tensors:
483
- explainer.between_tensors[t.name] = False
484
-
485
- out = tf.gradients(div, in0_centered, grad_ys=grads[0])[0]
486
-
487
- # remove the names we just added
488
- for op in [evals.op, rsum.op, div.op, in0_centered.op]:
489
- for t in op.outputs:
490
- if explainer.between_tensors[t.name] is False:
491
- del explainer.between_tensors[t.name]
492
-
493
- # rescale to account for our shift by in0_max (which we did for numerical stability)
494
- xin0,rin0 = tf.split(in0, 2)
495
- xin0_centered,rin0_centered = tf.split(in0_centered, 2)
496
- delta_in0 = xin0 - rin0
497
- dup0 = [2] + [1 for i in delta_in0.shape[1:]]
498
- return tf.where(
499
- tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
500
- out,
501
- out * tf.tile((xin0_centered - rin0_centered) / delta_in0, dup0)
502
- )
503
-
504
- def maxpool(explainer, op, *grads):
505
- xin0,rin0 = tf.split(op.inputs[0], 2)
506
- xout,rout = tf.split(op.outputs[0], 2)
507
- delta_in0 = xin0 - rin0
508
- dup0 = [2] + [1 for i in delta_in0.shape[1:]]
509
- cross_max = tf.maximum(xout, rout)
510
- diffs = tf.concat([cross_max - rout, xout - cross_max], 0)
511
- if op.type.startswith("shap_"):
512
- op.type = op.type[5:]
513
- xmax_pos,rmax_pos = tf.split(explainer.orig_grads[op.type](op, grads[0] * diffs), 2)
514
- return tf.tile(tf.where(
515
- tf.abs(delta_in0) < 1e-7,
516
- tf.zeros_like(delta_in0),
517
- (xmax_pos + rmax_pos) / delta_in0
518
- ), dup0)
519
-
520
- def gather(explainer, op, *grads):
521
- #params = op.inputs[0]
522
- indices = op.inputs[1]
523
- #axis = op.inputs[2]
524
- var = explainer._variable_inputs(op)
525
- if var[1] and not var[0]:
526
- assert len(indices.shape) == 2, "Only scalar indices supported right now in GatherV2!"
527
-
528
- xin1,rin1 = tf.split(tf.cast(op.inputs[1], tf.float32), 2)
529
- xout,rout = tf.split(op.outputs[0], 2)
530
- dup_in1 = [2] + [1 for i in xin1.shape[1:]]
531
- dup_out = [2] + [1 for i in xout.shape[1:]]
532
- delta_in1_t = tf.tile(xin1 - rin1, dup_in1)
533
- out_sum = tf.reduce_sum(grads[0] * tf.tile(xout - rout, dup_out), list(range(len(indices.shape), len(grads[0].shape))))
534
- if op.type == "ResourceGather":
535
- return [None, tf.where(
536
- tf.abs(delta_in1_t) < 1e-6,
537
- tf.zeros_like(delta_in1_t),
538
- out_sum / delta_in1_t
539
- )]
540
- return [None, tf.where(
541
- tf.abs(delta_in1_t) < 1e-6,
542
- tf.zeros_like(delta_in1_t),
543
- out_sum / delta_in1_t
544
- ), None]
545
- elif var[0] and not var[1]:
546
- if op.type.startswith("shap_"):
547
- op.type = op.type[5:]
548
- return [explainer.orig_grads[op.type](op, grads[0]), None] # linear in this case
549
- else:
550
- raise ValueError("Axis not yet supported to be varying for gather op!")
551
-
552
-
553
- def linearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
554
- def handler(explainer, op, *grads):
555
- var = explainer._variable_inputs(op)
556
- if var[input_ind0] and not var[input_ind1]:
557
- return linearity_1d_handler(input_ind0, explainer, op, *grads)
558
- elif var[input_ind1] and not var[input_ind0]:
559
- return linearity_1d_handler(input_ind1, explainer, op, *grads)
560
- elif var[input_ind0] and var[input_ind1]:
561
- return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
562
- else:
563
- return [None for _ in op.inputs] # no inputs vary, we must be hidden by a switch function
564
- return handler
565
-
566
- def nonlinearity_1d_nonlinearity_2d(input_ind0, input_ind1, op_func):
567
- def handler(explainer, op, *grads):
568
- var = explainer._variable_inputs(op)
569
- if var[input_ind0] and not var[input_ind1]:
570
- return nonlinearity_1d_handler(input_ind0, explainer, op, *grads)
571
- elif var[input_ind1] and not var[input_ind0]:
572
- return nonlinearity_1d_handler(input_ind1, explainer, op, *grads)
573
- elif var[input_ind0] and var[input_ind1]:
574
- return nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads)
575
- else:
576
- return [None for _ in op.inputs] # no inputs vary, we must be hidden by a switch function
577
- return handler
578
-
579
- def nonlinearity_1d(input_ind):
580
- def handler(explainer, op, *grads):
581
- return nonlinearity_1d_handler(input_ind, explainer, op, *grads)
582
- return handler
583
-
584
- def nonlinearity_1d_handler(input_ind, explainer, op, *grads):
585
- # make sure only the given input varies
586
- op_inputs = op.inputs
587
- if op_inputs is None:
588
- op_inputs = op.outputs[0].op.inputs
589
-
590
- for i in range(len(op_inputs)):
591
- if i != input_ind:
592
- assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
593
-
594
- xin0, rin0 = tf.split(op_inputs[input_ind], 2)
595
- xout, rout = tf.split(op.outputs[input_ind], 2)
596
- delta_in0 = xin0 - rin0
597
- if delta_in0.shape is None:
598
- dup0 = [2, 1]
599
- else:
600
- dup0 = [2] + [1 for i in delta_in0.shape[1:]]
601
- out = [None for _ in op_inputs]
602
- if op.type.startswith("shap_"):
603
- op.type = op.type[5:]
604
- orig_grad = explainer.orig_grads[op.type](op, grads[0])
605
- out[input_ind] = tf.where(
606
- tf.tile(tf.abs(delta_in0), dup0) < 1e-6,
607
- orig_grad[input_ind] if len(op_inputs) > 1 else orig_grad,
608
- grads[0] * tf.tile((xout - rout) / delta_in0, dup0)
609
- )
610
- return out
611
-
612
- def nonlinearity_2d_handler(input_ind0, input_ind1, op_func, explainer, op, *grads):
613
- if not (input_ind0 == 0 and input_ind1 == 1):
614
- emsg = "TODO: Can't yet handle double inputs that are not first!"
615
- raise Exception(emsg)
616
- xout,rout = tf.split(op.outputs[0], 2)
617
- in0 = op.inputs[input_ind0]
618
- in1 = op.inputs[input_ind1]
619
- xin0,rin0 = tf.split(in0, 2)
620
- xin1,rin1 = tf.split(in1, 2)
621
- delta_in0 = xin0 - rin0
622
- delta_in1 = xin1 - rin1
623
- dup0 = [2] + [1 for i in delta_in0.shape[1:]]
624
- out10 = op_func(xin0, rin1)
625
- out01 = op_func(rin0, xin1)
626
- out11,out00 = xout,rout
627
- out0 = 0.5 * (out11 - out01 + out10 - out00)
628
- out0 = grads[0] * tf.tile(out0 / delta_in0, dup0)
629
- out1 = 0.5 * (out11 - out10 + out01 - out00)
630
- out1 = grads[0] * tf.tile(out1 / delta_in1, dup0)
631
-
632
- # Avoid divide by zero nans
633
- out0 = tf.where(tf.abs(tf.tile(delta_in0, dup0)) < 1e-7, tf.zeros_like(out0), out0)
634
- out1 = tf.where(tf.abs(tf.tile(delta_in1, dup0)) < 1e-7, tf.zeros_like(out1), out1)
635
-
636
- # see if due to broadcasting our gradient shapes don't match our input shapes
637
- if (np.any(np.array(out1.shape) != np.array(in1.shape))):
638
- broadcast_index = np.where(np.array(out1.shape) != np.array(in1.shape))[0][0]
639
- out1 = tf.reduce_sum(out1, axis=broadcast_index, keepdims=True)
640
- elif (np.any(np.array(out0.shape) != np.array(in0.shape))):
641
- broadcast_index = np.where(np.array(out0.shape) != np.array(in0.shape))[0][0]
642
- out0 = tf.reduce_sum(out0, axis=broadcast_index, keepdims=True)
643
-
644
- return [out0, out1]
645
-
646
- def linearity_1d(input_ind):
647
- def handler(explainer, op, *grads):
648
- return linearity_1d_handler(input_ind, explainer, op, *grads)
649
- return handler
650
-
651
- def linearity_1d_handler(input_ind, explainer, op, *grads):
652
- # make sure only the given input varies (negative means only that input cannot vary, and is measured from the end of the list)
653
- for i in range(len(op.inputs)):
654
- if i != input_ind:
655
- assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
656
- if op.type.startswith("shap_"):
657
- op.type = op.type[5:]
658
- return explainer.orig_grads[op.type](op, *grads)
659
-
660
- def linearity_with_excluded(input_inds):
661
- def handler(explainer, op, *grads):
662
- return linearity_with_excluded_handler(input_inds, explainer, op, *grads)
663
- return handler
664
-
665
- def linearity_with_excluded_handler(input_inds, explainer, op, *grads):
666
- # make sure the given inputs don't vary (negative is measured from the end of the list)
667
- for i in range(len(op.inputs)):
668
- if i in input_inds or i - len(op.inputs) in input_inds:
669
- assert not explainer._variable_inputs(op)[i], str(i) + "th input to " + op.name + " cannot vary!"
670
- if op.type.startswith("shap_"):
671
- op.type = op.type[5:]
672
- return explainer.orig_grads[op.type](op, *grads)
673
-
674
- def passthrough(explainer, op, *grads):
675
- if op.type.startswith("shap_"):
676
- op.type = op.type[5:]
677
- return explainer.orig_grads[op.type](op, *grads)
678
-
679
- def break_dependence(explainer, op, *grads):
680
- """ This function name is used to break attribution dependence in the graph traversal.
681
-
682
- These operation types may be connected above input data values in the graph but their outputs
683
- don't depend on the input values (for example they just depend on the shape).
684
- """
685
- return [None for _ in op.inputs]
686
-
687
-
688
- op_handlers = {}
689
-
690
- # ops that are always linear
691
- op_handlers["Identity"] = passthrough
692
- op_handlers["StridedSlice"] = passthrough
693
- op_handlers["Squeeze"] = passthrough
694
- op_handlers["ExpandDims"] = passthrough
695
- op_handlers["Pack"] = passthrough
696
- op_handlers["BiasAdd"] = passthrough
697
- op_handlers["Unpack"] = passthrough
698
- op_handlers["Add"] = passthrough
699
- op_handlers["Sub"] = passthrough
700
- op_handlers["Merge"] = passthrough
701
- op_handlers["Sum"] = passthrough
702
- op_handlers["Mean"] = passthrough
703
- op_handlers["Cast"] = passthrough
704
- op_handlers["Transpose"] = passthrough
705
- op_handlers["Enter"] = passthrough
706
- op_handlers["Exit"] = passthrough
707
- op_handlers["NextIteration"] = passthrough
708
- op_handlers["Tile"] = passthrough
709
- op_handlers["TensorArrayScatterV3"] = passthrough
710
- op_handlers["TensorArrayReadV3"] = passthrough
711
- op_handlers["TensorArrayWriteV3"] = passthrough
712
-
713
-
714
- # ops that don't pass any attributions to their inputs
715
- op_handlers["Shape"] = break_dependence
716
- op_handlers["RandomUniform"] = break_dependence
717
- op_handlers["ZerosLike"] = break_dependence
718
- #op_handlers["StopGradient"] = break_dependence # this allows us to stop attributions when we want to (like softmax re-centering)
719
-
720
- # ops that are linear and only allow a single input to vary
721
- op_handlers["Reshape"] = linearity_1d(0)
722
- op_handlers["Pad"] = linearity_1d(0)
723
- op_handlers["ReverseV2"] = linearity_1d(0)
724
- op_handlers["ConcatV2"] = linearity_with_excluded([-1])
725
- op_handlers["Conv2D"] = linearity_1d(0)
726
- op_handlers["Switch"] = linearity_1d(0)
727
- op_handlers["AvgPool"] = linearity_1d(0)
728
- op_handlers["FusedBatchNorm"] = linearity_1d(0)
729
-
730
- # ops that are nonlinear and only allow a single input to vary
731
- op_handlers["Relu"] = nonlinearity_1d(0)
732
- op_handlers["Elu"] = nonlinearity_1d(0)
733
- op_handlers["Sigmoid"] = nonlinearity_1d(0)
734
- op_handlers["Tanh"] = nonlinearity_1d(0)
735
- op_handlers["Softplus"] = nonlinearity_1d(0)
736
- op_handlers["Exp"] = nonlinearity_1d(0)
737
- op_handlers["ClipByValue"] = nonlinearity_1d(0)
738
- op_handlers["Rsqrt"] = nonlinearity_1d(0)
739
- op_handlers["Square"] = nonlinearity_1d(0)
740
- op_handlers["Max"] = nonlinearity_1d(0)
741
-
742
- # ops that are nonlinear and allow two inputs to vary
743
- op_handlers["SquaredDifference"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: (x - y) * (x - y))
744
- op_handlers["Minimum"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.minimum(x, y))
745
- op_handlers["Maximum"] = nonlinearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.maximum(x, y))
746
-
747
- # ops that allow up to two inputs to vary but are linear when only one input varies
748
- op_handlers["Mul"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x * y)
749
- op_handlers["RealDiv"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: x / y)
750
- op_handlers["MatMul"] = linearity_1d_nonlinearity_2d(0, 1, lambda x, y: tf.matmul(x, y))
751
-
752
- # ops that need their own custom attribution functions
753
- op_handlers["GatherV2"] = gather
754
- op_handlers["ResourceGather"] = gather
755
- op_handlers["MaxPool"] = maxpool
756
- op_handlers["Softmax"] = softmax
757
-
758
-
759
- # TODO items
760
- # TensorArrayGatherV3
761
- # Max
762
- # TensorArraySizeV3
763
- # Range
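Likewise, a minimal sketch of reaching this TensorFlow backend through the public shap.DeepExplainer API with a Keras model; the model and data below are illustrative assumptions:

import numpy as np
import tensorflow as tf
import shap

# toy Keras model and random background data (illustrative only)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
    tf.keras.layers.Dense(1),
])
background = np.random.randn(50, 8).astype(np.float32)
to_explain = np.random.randn(3, 8).astype(np.float32)

explainer = shap.DeepExplainer(model, background)   # uses the TF backend shown above
shap_values = explainer.shap_values(to_explain)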
lib/shap/explainers/_deep/deep_utils.py DELETED
@@ -1,23 +0,0 @@
1
- import numpy as np
2
-
3
-
4
- def _check_additivity(explainer, model_output_values, output_phis):
5
- TOLERANCE = 1e-2
6
-
7
- assert len(explainer.expected_value) == model_output_values.shape[1], "Length of expected values and model outputs does not match."
8
-
9
- for t in range(len(explainer.expected_value)):
10
- if not explainer.multi_input:
11
- diffs = model_output_values[:, t] - explainer.expected_value[t] - output_phis[t].sum(axis=tuple(range(1, output_phis[t].ndim)))
12
- else:
13
- diffs = model_output_values[:, t] - explainer.expected_value[t]
14
-
15
- for i in range(len(output_phis[t])):
16
- diffs -= output_phis[t][i].sum(axis=tuple(range(1, output_phis[t][i].ndim)))
17
-
18
- maxdiff = np.abs(diffs).max()
19
-
20
- assert maxdiff < TOLERANCE, "The SHAP explanations do not sum up to the model's output! This is either because of a " \
21
- "rounding error or because an operator in your computation graph was not fully supported. If " \
22
- "the sum difference of %f is significant compared to the scale of your model outputs, please post " \
23
- f"as a github issue, with a reproducible example so we can debug it. Used framework: {explainer.framework} - Max. diff: {maxdiff} - Tolerance: {TOLERANCE}"
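The property enforced above is that, for every sample, the expected value plus the summed attributions reproduces the model output within the tolerance; a small numeric sketch with made-up values:

import numpy as np

expected_value = 0.2
phis = np.array([[0.10, -0.05, 0.30],      # attributions: 2 samples x 3 features
                 [0.00,  0.25, -0.10]])
model_output = np.array([0.551, 0.348])    # hypothetical model outputs
diffs = model_output - expected_value - phis.sum(axis=1)
assert np.abs(diffs).max() < 1e-2          # max deviation here is 0.002, so the check passes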
lib/shap/explainers/_exact.py DELETED
@@ -1,366 +0,0 @@
1
- import logging
2
-
3
- import numpy as np
4
- from numba import njit
5
-
6
- from .. import links
7
- from ..models import Model
8
- from ..utils import (
9
- MaskedModel,
10
- delta_minimization_order,
11
- make_masks,
12
- shapley_coefficients,
13
- )
14
- from ._explainer import Explainer
15
-
16
- log = logging.getLogger('shap')
17
-
18
-
19
- class ExactExplainer(Explainer):
20
- """ Computes SHAP values via an optimized exact enumeration.
21
-
22
- This works well for standard Shapley value maskers for models with less than ~15 features that vary
23
- from the background per sample. It also works well for Owen values from hclustering structured
24
- maskers when there are less than ~100 features that vary from the background per sample. This
25
- explainer minimizes the number of function evaluations needed by ordering the masking sets to
26
- minimize sequential differences. This is done using gray codes for standard Shapley values
27
- and a greedy sorting method for hclustering structured maskers.
28
- """
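The gray-code ordering mentioned in the docstring above can be illustrated with the standard reflected binary code, where consecutive indices differ in exactly one bit, so each successive mask flips a single feature; the helper below is an illustrative sketch, not the library's gray_code_indexes:

def gray_code_masks(n):
    # reflected binary Gray code: g(i) = i ^ (i >> 1)
    for i in range(2 ** n):
        g = i ^ (i >> 1)
        yield [(g >> b) & 1 for b in range(n)]

masks = list(gray_code_masks(3))
for a, b in zip(masks, masks[1:]):
    # exactly one feature changes between consecutive masks
    assert sum(x != y for x, y in zip(a, b)) == 1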
29
-
30
- def __init__(self, model, masker, link=links.identity, linearize_link=True, feature_names=None):
31
- """ Build an explainers.Exact object for the given model using the given masker object.
32
-
33
- Parameters
34
- ----------
35
- model : function
36
- A callable python object that executes the model given a set of input data samples.
37
-
38
- masker : function or numpy.array or pandas.DataFrame
39
- A callable python object used to "mask" out hidden features of the form `masker(mask, *fargs)`.
40
- It takes a single a binary mask and an input sample and returns a matrix of masked samples. These
41
- masked samples are evaluated using the model function and the outputs are then averaged.
42
- As a shortcut for the standard masking used by SHAP you can pass a background data matrix
43
- instead of a function and that matrix will be used for masking. To use a clustering
44
- game structure you can pass a shap.maskers.TabularPartitions(data) object.
45
-
46
- link : function
47
- The link function used to map between the output units of the model and the SHAP value units. By
48
- default it is shap.links.identity, but shap.links.logit can be useful so that expectations are
49
- computed in probability units while explanations remain in the (more naturally additive) log-odds
50
- units. For more details on how link functions work see any overview of link functions for generalized
51
- linear models.
52
-
53
- linearize_link : bool
54
- If we use a non-linear link function to take expectations then models that are additive with respect to that
55
- link function for a single background sample will no longer be additive when using a background masker with
56
- many samples. This for example means that a linear logistic regression model would have interaction effects
57
- that arise from the non-linear changes in expectation averaging. To retain the additively of the model with
58
- still respecting the link function we linearize the link function by default.
59
- """ # TODO link to the link linearization paper when done
60
- super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names)
61
-
62
- self.model = Model(model)
63
-
64
- if getattr(masker, "clustering", None) is not None:
65
- self._partition_masks, self._partition_masks_inds = partition_masks(masker.clustering)
66
- self._partition_delta_indexes = partition_delta_indexes(masker.clustering, self._partition_masks)
67
-
68
- self._gray_code_cache = {} # used to avoid regenerating the same gray code patterns
69
-
70
- def __call__(self, *args, max_evals=100000, main_effects=False, error_bounds=False, batch_size="auto", interactions=1, silent=False):
71
- """ Explains the output of model(*args), where args represents one or more parallel iterators.
72
- """
73
-
74
- # we entirely rely on the general call implementation, we override just to remove **kwargs
75
- # from the function signature
76
- return super().__call__(
77
- *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
78
- batch_size=batch_size, interactions=interactions, silent=silent
79
- )
80
-
81
- def _cached_gray_codes(self, n):
82
- if n not in self._gray_code_cache:
83
- self._gray_code_cache[n] = gray_code_indexes(n)
84
- return self._gray_code_cache[n]
85
-
86
- def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, interactions, silent):
87
- """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
88
- """
89
-
90
- # build a masked version of the model for the current input sample
91
- fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)
92
-
93
- # do the standard Shapley values
94
- inds = None
95
- if getattr(self.masker, "clustering", None) is None:
96
-
97
- # see which elements we actually need to perturb
98
- inds = fm.varying_inputs()
99
-
100
- # make sure we have enough evals
101
- if max_evals is not None and max_evals != "auto" and max_evals < 2**len(inds):
102
- raise ValueError(
103
- f"It takes {2**len(inds)} masked evaluations to run the Exact explainer on this instance, but max_evals={max_evals}!"
104
- )
105
-
106
- # generate the masks in gray code order (so that we change the inputs as little
107
- # as possible while we iterate to minimize the need to re-eval when the inputs
108
- # don't vary from the background)
109
- delta_indexes = self._cached_gray_codes(len(inds))
110
-
111
- # map to a larger mask that includes the invariant entries
112
- extended_delta_indexes = np.zeros(2**len(inds), dtype=int)
113
- for i in range(2**len(inds)):
114
- if delta_indexes[i] == MaskedModel.delta_mask_noop_value:
115
- extended_delta_indexes[i] = delta_indexes[i]
116
- else:
117
- extended_delta_indexes[i] = inds[delta_indexes[i]]
118
-
119
- # run the model
120
- outputs = fm(extended_delta_indexes, zero_index=0, batch_size=batch_size)
121
-
122
- # Shapley values
123
- # Care: Need to distinguish between `True` and `1`
124
- if interactions is False or (interactions == 1 and interactions is not True):
125
-
126
- # loop over all the outputs to update the rows
127
- coeff = shapley_coefficients(len(inds))
128
- row_values = np.zeros((len(fm),) + outputs.shape[1:])
129
- mask = np.zeros(len(fm), dtype=bool)
130
- _compute_grey_code_row_values(row_values, mask, inds, outputs, coeff, extended_delta_indexes, MaskedModel.delta_mask_noop_value)
131
-
132
- # Shapley-Taylor interaction values
133
- elif interactions is True or interactions == 2:
134
-
135
- # loop over all the outputs to update the rows
136
- coeff = shapley_coefficients(len(inds))
137
- row_values = np.zeros((len(fm), len(fm)) + outputs.shape[1:])
138
- mask = np.zeros(len(fm), dtype=bool)
139
- _compute_grey_code_row_values_st(row_values, mask, inds, outputs, coeff, extended_delta_indexes, MaskedModel.delta_mask_noop_value)
140
-
141
- elif interactions > 2:
142
- raise NotImplementedError("Currently the Exact explainer does not support interactions higher than order 2!")
143
-
144
- # do a partition tree constrained version of Shapley values
145
- else:
146
-
147
- # make sure we have enough evals
148
- if max_evals is not None and max_evals != "auto" and max_evals < len(fm)**2:
149
- raise ValueError(
150
- f"It takes {len(fm)**2} masked evaluations to run the Exact explainer on this instance, but max_evals={max_evals}!"
151
- )
152
-
153
- # generate the masks in a hclust order (so that we change the inputs as little
154
- # as possible while we iterate to minimize the need to re-eval when the inputs
155
- # don't vary from the background)
156
- delta_indexes = self._partition_delta_indexes
157
-
158
- # run the model
159
- outputs = fm(delta_indexes, batch_size=batch_size)
160
-
161
- # loop over each output feature
162
- row_values = np.zeros((len(fm),) + outputs.shape[1:])
163
- for i in range(len(fm)):
164
- on_outputs = outputs[self._partition_masks_inds[i][1]]
165
- off_outputs = outputs[self._partition_masks_inds[i][0]]
166
- row_values[i] = (on_outputs - off_outputs).mean(0)
167
-
168
- # compute the main effects if we need to
169
- main_effect_values = None
170
- if main_effects or interactions is True or interactions == 2:
171
- if inds is None:
172
- inds = np.arange(len(fm))
173
- main_effect_values = fm.main_effects(inds)
174
- if interactions is True or interactions == 2:
175
- for i in range(len(fm)):
176
- row_values[i, i] = main_effect_values[i]
177
-
178
- return {
179
- "values": row_values,
180
- "expected_values": outputs[0],
181
- "mask_shapes": fm.mask_shapes,
182
- "main_effects": main_effect_values if main_effects else None,
183
- "clustering": getattr(self.masker, "clustering", None)
184
- }
185
-
186
- @njit
187
- def _compute_grey_code_row_values(row_values, mask, inds, outputs, shapley_coeff, extended_delta_indexes, noop_code):
188
- set_size = 0
189
- M = len(inds)
190
- for i in range(2**M):
191
-
192
- # update the mask
193
- delta_ind = extended_delta_indexes[i]
194
- if delta_ind != noop_code:
195
- mask[delta_ind] = ~mask[delta_ind]
196
- if mask[delta_ind]:
197
- set_size += 1
198
- else:
199
- set_size -= 1
200
-
201
- # update the output row values
202
- on_coeff = shapley_coeff[set_size-1]
203
- if set_size < M:
204
- off_coeff = shapley_coeff[set_size]
205
- out = outputs[i]
206
- for j in inds:
207
- if mask[j]:
208
- row_values[j] += out * on_coeff
209
- else:
210
- row_values[j] -= out * off_coeff
211
-
212
- @njit
213
- def _compute_grey_code_row_values_st(row_values, mask, inds, outputs, shapley_coeff, extended_delta_indexes, noop_code):
214
- set_size = 0
215
- M = len(inds)
216
- for i in range(2**M):
217
-
218
- # update the mask
219
- delta_ind = extended_delta_indexes[i]
220
- if delta_ind != noop_code:
221
- mask[delta_ind] = ~mask[delta_ind]
222
- if mask[delta_ind]:
223
- set_size += 1
224
- else:
225
- set_size -= 1
226
-
227
- # distribute the effect of this mask set over all the terms it impacts
228
- out = outputs[i]
229
- for j in range(M):
230
- for k in range(j+1, M):
231
- if not mask[j] and not mask[k]:
232
- delta = out * shapley_coeff[set_size] # * 2
233
- elif (not mask[j] and mask[k]) or (mask[j] and not mask[k]):
234
- delta = -out * shapley_coeff[set_size - 1] # * 2
235
- else: # both true
236
- delta = out * shapley_coeff[set_size - 2] # * 2
237
- row_values[j,k] += delta
238
- row_values[k,j] += delta
239
-
240
- def partition_delta_indexes(partition_tree, all_masks):
241
- """ Return an delta index encoded array of all the masks possible while following the given partition tree.
242
- """
243
-
244
- # convert the masks to delta index format
245
- mask = np.zeros(all_masks.shape[1], dtype=bool)
246
- delta_inds = []
247
- for i in range(len(all_masks)):
248
- inds = np.where(mask ^ all_masks[i,:])[0]
249
-
250
- for j in inds[:-1]:
251
- delta_inds.append(-j - 1) # negative + (-1) means we have more inds still to change...
252
- if len(inds) == 0:
253
- delta_inds.append(MaskedModel.delta_mask_noop_value)
254
- else:
255
- delta_inds.extend(inds[-1:])
256
- mask = all_masks[i,:]
257
-
258
- return np.array(delta_inds)
259
-
260
- def partition_masks(partition_tree):
261
- """ Return an array of all the masks possible while following the given partition tree.
262
- """
263
-
264
- M = partition_tree.shape[0] + 1
265
- mask_matrix = make_masks(partition_tree)
266
- all_masks = []
267
- m00 = np.zeros(M, dtype=bool)
268
- all_masks.append(m00)
269
- all_masks.append(~m00)
270
- #inds_stack = [0,1]
271
- inds_lists = [[[], []] for i in range(M)]
272
- _partition_masks_recurse(len(partition_tree)-1, m00, 0, 1, inds_lists, mask_matrix, partition_tree, M, all_masks)
273
-
274
- all_masks = np.array(all_masks)
275
-
276
- # we resort the clustering matrix to minimize the sequential difference between the masks
277
- # this minimizes the number of model evaluations we need to run when the background sometimes
278
- # matches the foreground. We seem to average about 1.5 feature changes per mask with this
279
- # approach. This is not as clean as the grey code ordering, but a perfect 1 feature change
280
- # ordering is not possible with a clustering tree
281
- order = delta_minimization_order(all_masks)
282
- inverse_order = np.arange(len(order))[np.argsort(order)]
283
-
284
- for inds_list0,inds_list1 in inds_lists:
285
- for i in range(len(inds_list0)):
286
- inds_list0[i] = inverse_order[inds_list0[i]]
287
- inds_list1[i] = inverse_order[inds_list1[i]]
288
-
289
- # Care: inds_lists have different lengths, so partition_masks_inds is a "ragged" array. See GH #3063
290
- partition_masks = all_masks[order]
291
- partition_masks_inds = [[np.array(on), np.array(off)] for on, off in inds_lists]
292
- return partition_masks, partition_masks_inds
293
-
294
- # TODO: this should be a jit function... which would require preallocating the inds_lists (sizes are 2**depth of that ind)
295
- # TODO: we could also probably avoid making the masks at all and just record the deltas if we want...
296
- def _partition_masks_recurse(index, m00, ind00, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks):
297
- if index < 0:
298
- inds_lists[index + M][0].append(ind00)
299
- inds_lists[index + M][1].append(ind11)
300
- return
301
-
302
- # get our children indexes
303
- left_index = int(partition_tree[index,0] - M)
304
- right_index = int(partition_tree[index,1] - M)
305
-
306
- # build more refined masks
307
- m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
308
- m10[:] += mask_matrix[left_index+M, :]
309
- m01 = m00.copy()
310
- m01[:] += mask_matrix[right_index+M, :]
311
-
312
- # record the new masks we made
313
- ind01 = len(all_masks)
314
- all_masks.append(m01)
315
- ind10 = len(all_masks)
316
- all_masks.append(m10)
317
-
318
- # inds_stack.append(len(all_masks) - 2)
319
- # inds_stack.append(len(all_masks) - 1)
320
-
321
- # recurse left and right with both 1 (True) and 0 (False) contexts
322
- _partition_masks_recurse(left_index, m00, ind00, ind10, inds_lists, mask_matrix, partition_tree, M, all_masks)
323
- _partition_masks_recurse(right_index, m10, ind10, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks)
324
- _partition_masks_recurse(left_index, m01, ind01, ind11, inds_lists, mask_matrix, partition_tree, M, all_masks)
325
- _partition_masks_recurse(right_index, m00, ind00, ind01, inds_lists, mask_matrix, partition_tree, M, all_masks)
326
-
327
-
328
- def gray_code_masks(nbits):
329
- """ Produces an array of all binary patterns of size nbits in gray code order.
330
-
331
- This is based on code from: http://code.activestate.com/recipes/576592-gray-code-generatoriterator/
332
- """
333
- out = np.zeros((2**nbits, nbits), dtype=bool)
334
- li = np.zeros(nbits, dtype=bool)
335
-
336
- for term in range(2, (1<<nbits)+1):
337
- if term % 2 == 1: # odd
338
- for i in range(-1,-nbits,-1):
339
- if li[i] == 1:
340
- li[i-1] = li[i-1]^1
341
- break
342
- else: # even
343
- li[-1] = li[-1]^1
344
-
345
- out[term-1,:] = li
346
- return out
347
-
348
- def gray_code_indexes(nbits):
349
- """ Produces an array of which bits flip at which position.
350
-
351
- We assume the masks start at all zero and -1 means don't do a flip.
352
- This is a more efficient representation of the gray_code_masks version.
353
- """
354
- out = np.ones(2**nbits, dtype=int) * MaskedModel.delta_mask_noop_value
355
- li = np.zeros(nbits, dtype=bool)
356
- for term in range((1<<nbits)-1):
357
- if term % 2 == 1: # odd
358
- for i in range(-1,-nbits,-1):
359
- if li[i] == 1:
360
- li[i-1] = li[i-1]^1
361
- out[term+1] = nbits + (i-1)
362
- break
363
- else: # even
364
- li[-1] = li[-1]^1
365
- out[term+1] = nbits-1
366
- return out
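The two gray-code helpers above are what let the Exact explainer enumerate all 2**nbits coalitions while toggling only a single feature between consecutive model evaluations. A minimal standalone sketch of that property (illustrative only, not the removed implementation; the function name below is hypothetical and the bit ordering may differ):

import numpy as np

def gray_code_masks_sketch(nbits):
    # the i-th reflected Gray code is i ^ (i >> 1); unpack each code into a boolean mask of length nbits
    codes = np.arange(2 ** nbits)
    gray = codes ^ (codes >> 1)
    return ((gray[:, None] >> np.arange(nbits)) & 1).astype(bool)

masks = gray_code_masks_sketch(4)              # shape (16, 4), starting from the all-False mask
flips = (masks[1:] ^ masks[:-1]).sum(axis=1)   # number of features toggled between consecutive masks
assert np.all(flips == 1)                      # exactly one feature changes per step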
 
lib/shap/explainers/_explainer.py DELETED
@@ -1,457 +0,0 @@
1
- import copy
2
- import time
3
-
4
- import numpy as np
5
- import pandas as pd
6
- import scipy.sparse
7
-
8
- from .. import explainers, links, maskers, models
9
- from .._explanation import Explanation
10
- from .._serializable import Deserializer, Serializable, Serializer
11
- from ..maskers import Masker
12
- from ..models import Model
13
- from ..utils import safe_isinstance, show_progress
14
- from ..utils._exceptions import InvalidAlgorithmError
15
- from ..utils.transformers import is_transformers_lm
16
-
17
-
18
- class Explainer(Serializable):
19
- """ Uses Shapley values to explain any machine learning model or python function.
20
-
21
- This is the primary explainer interface for the SHAP library. It takes any combination
22
- of a model and masker and returns a callable subclass object that implements
23
- the particular estimation algorithm that was chosen.
24
- """
25
-
26
- def __init__(self, model, masker=None, link=links.identity, algorithm="auto", output_names=None, feature_names=None, linearize_link=True,
27
- seed=None, **kwargs):
28
- """ Build a new explainer for the passed model.
29
-
30
- Parameters
31
- ----------
32
- model : object or function
33
- User supplied function or model object that takes a dataset of samples and
34
- computes the output of the model for those samples.
35
-
36
- masker : function, numpy.array, pandas.DataFrame, tokenizer, None, or a list of these for each model input
37
- The function used to "mask" out hidden features of the form `masked_args = masker(*model_args, mask=mask)`.
38
- It takes input in the same form as the model, but for just a single sample with a binary
39
- mask, then returns an iterable of masked samples. These
40
- masked samples will then be evaluated using the model function and the outputs averaged.
41
- As a shortcut for the standard masking used by SHAP you can pass a background data matrix
42
- instead of a function and that matrix will be used for masking. Domain specific masking
43
- functions are available in shap such as shap.ImageMasker for images and shap.TokenMasker
44
- for text. In addition to determining how to replace hidden features, the masker can also
45
- constrain the rules of the cooperative game used to explain the model. For example
46
- shap.TabularMasker(data, hclustering="correlation") will enforce a hierarchical clustering
47
- of coalitions for the game (in this special case the attributions are known as the Owen values).
48
-
49
- link : function
50
- The link function used to map between the output units of the model and the SHAP value units. By
51
- default it is shap.links.identity, but shap.links.logit can be useful so that expectations are
52
- computed in probability units while explanations remain in the (more naturally additive) log-odds
53
- units. For more details on how link functions work see any overview of link functions for generalized
54
- linear models.
55
-
56
- algorithm : "auto", "permutation", "partition", "tree", or "linear"
57
- The algorithm used to estimate the Shapley values. There are many different algorithms that
58
- can be used to estimate the Shapley values (and the related value for constrained games), each
59
- of these algorithms have various tradeoffs and are preferable in different situations. By
60
- default the "auto" options attempts to make the best choice given the passed model and masker,
61
- but this choice can always be overridden by passing the name of a specific algorithm. The type of
62
- algorithm used will determine what type of subclass object is returned by this constructor, and
63
- you can also build those subclasses directly if you prefer or need more fine grained control over
64
- their options.
65
-
66
- output_names : None or list of strings
67
- The names of the model outputs. For example if the model is an image classifier, then output_names would
68
- be the names of all the output classes. This parameter is optional. When output_names is None then
69
- the Explanation objects produced by this explainer will not have any output_names, which could affect
70
- downstream plots.
71
-
72
- seed: None or int
73
- seed for reproducibility
74
-
75
- """
76
-
77
- self.model = model
78
- self.output_names = output_names
79
- self.feature_names = feature_names
80
-
81
- # wrap the incoming masker object as a shap.Masker object
82
- if (
83
- isinstance(masker, pd.DataFrame)
84
- or ((isinstance(masker, np.ndarray) or scipy.sparse.issparse(masker)) and len(masker.shape) == 2)
85
- ):
86
- if algorithm == "partition":
87
- self.masker = maskers.Partition(masker)
88
- else:
89
- self.masker = maskers.Independent(masker)
90
- elif safe_isinstance(masker, ["transformers.PreTrainedTokenizer", "transformers.tokenization_utils_base.PreTrainedTokenizerBase"]):
91
- if is_transformers_lm(self.model):
92
- # auto assign text infilling if model is a transformer model with lm head
93
- self.masker = maskers.Text(masker, mask_token="...", collapse_mask_token=True)
94
- else:
95
- self.masker = maskers.Text(masker)
96
- elif (masker is list or masker is tuple) and masker[0] is not str:
97
- self.masker = maskers.Composite(*masker)
98
- elif (masker is dict) and ("mean" in masker):
99
- self.masker = maskers.Independent(masker)
100
- elif masker is None and isinstance(self.model, models.TransformersPipeline):
101
- return self.__init__(
102
- self.model, self.model.inner_model.tokenizer,
103
- link=link, algorithm=algorithm, output_names=output_names, feature_names=feature_names, linearize_link=linearize_link, **kwargs
104
- )
105
- else:
106
- self.masker = masker
107
-
108
- # Check for transformer pipeline objects and wrap them
109
- if safe_isinstance(self.model, "transformers.pipelines.Pipeline"):
110
- if is_transformers_lm(self.model.model):
111
- return self.__init__(
112
- self.model.model, self.model.tokenizer if self.masker is None else self.masker,
113
- link=link, algorithm=algorithm, output_names=output_names, feature_names=feature_names, linearize_link=linearize_link, **kwargs
114
- )
115
- else:
116
- return self.__init__(
117
- models.TransformersPipeline(self.model), self.masker,
118
- link=link, algorithm=algorithm, output_names=output_names, feature_names=feature_names, linearize_link=linearize_link, **kwargs
119
- )
120
-
121
- # wrap self.masker and self.model for output text explanation algorithm
122
- if is_transformers_lm(self.model):
123
- self.model = models.TeacherForcing(self.model, self.masker.tokenizer)
124
- self.masker = maskers.OutputComposite(self.masker, self.model.text_generate)
125
- elif safe_isinstance(self.model, "shap.models.TeacherForcing") and safe_isinstance(self.masker, ["shap.maskers.Text", "shap.maskers.Image"]):
126
- self.masker = maskers.OutputComposite(self.masker, self.model.text_generate)
127
- elif safe_isinstance(self.model, "shap.models.TopKLM") and safe_isinstance(self.masker, "shap.maskers.Text"):
128
- self.masker = maskers.FixedComposite(self.masker)
129
-
130
- #self._brute_force_fallback = explainers.BruteForce(self.model, self.masker)
131
-
132
- # validate and save the link function
133
- if callable(link):
134
- self.link = link
135
- else:
136
- raise TypeError("The passed link function needs to be callable!")
137
- self.linearize_link = linearize_link
138
-
139
- # if we are called directly (as opposed to through super()) then we convert ourselves to the subclass
140
- # that implements the specific algorithm that was chosen
141
- if self.__class__ is Explainer:
142
-
143
- # do automatic algorithm selection
144
- #from .. import explainers
145
- if algorithm == "auto":
146
-
147
- # use implementation-aware methods if possible
148
- if explainers.LinearExplainer.supports_model_with_masker(model, self.masker):
149
- algorithm = "linear"
150
- elif explainers.TreeExplainer.supports_model_with_masker(model, self.masker): # TODO: check for Partition?
151
- algorithm = "tree"
152
- elif explainers.AdditiveExplainer.supports_model_with_masker(model, self.masker):
153
- algorithm = "additive"
154
-
155
- # otherwise use a model agnostic method
156
- elif callable(self.model):
157
- if issubclass(type(self.masker), maskers.Independent):
158
- if self.masker.shape[1] <= 10:
159
- algorithm = "exact"
160
- else:
161
- algorithm = "permutation"
162
- elif issubclass(type(self.masker), maskers.Partition):
163
- if self.masker.shape[1] <= 32:
164
- algorithm = "exact"
165
- else:
166
- algorithm = "permutation"
167
- elif (getattr(self.masker, "text_data", False) or getattr(self.masker, "image_data", False)) and hasattr(self.masker, "clustering"):
168
- algorithm = "partition"
169
- else:
170
- algorithm = "permutation"
171
-
172
- # if we get here then we don't know how to handle what was given to us
173
- else:
174
- raise TypeError("The passed model is not callable and cannot be analyzed directly with the given masker! Model: " + str(model))
175
-
176
- # build the right subclass
177
- if algorithm == "exact":
178
- self.__class__ = explainers.ExactExplainer
179
- explainers.ExactExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
180
- elif algorithm == "permutation":
181
- self.__class__ = explainers.PermutationExplainer
182
- explainers.PermutationExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, seed=seed, **kwargs)
183
- elif algorithm == "partition":
184
- self.__class__ = explainers.PartitionExplainer
185
- explainers.PartitionExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, output_names=self.output_names, **kwargs)
186
- elif algorithm == "tree":
187
- self.__class__ = explainers.TreeExplainer
188
- explainers.TreeExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
189
- elif algorithm == "additive":
190
- self.__class__ = explainers.AdditiveExplainer
191
- explainers.AdditiveExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
192
- elif algorithm == "linear":
193
- self.__class__ = explainers.LinearExplainer
194
- explainers.LinearExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
195
- elif algorithm == "deep":
196
- self.__class__ = explainers.DeepExplainer
197
- explainers.DeepExplainer.__init__(self, self.model, self.masker, link=self.link, feature_names=self.feature_names, linearize_link=linearize_link, **kwargs)
198
- else:
199
- raise InvalidAlgorithmError("Unknown algorithm type passed: %s!" % algorithm)
200
-
201
-
202
- def __call__(self, *args, max_evals="auto", main_effects=False, error_bounds=False, batch_size="auto",
203
- outputs=None, silent=False, **kwargs):
204
- """ Explains the output of model(*args), where args is a list of parallel iterable datasets.
205
-
206
- Note this default version could be an abstract method that is implemented by each algorithm-specific
207
- subclass of Explainer. Descriptions of each subclasses' __call__ arguments
208
- are available in their respective doc-strings.
209
- """
210
-
211
- # if max_evals == "auto":
212
- # self._brute_force_fallback
213
-
214
- start_time = time.time()
215
-
216
- if issubclass(type(self.masker), maskers.OutputComposite) and len(args)==2:
217
- self.masker.model = models.TextGeneration(target_sentences=args[1])
218
- args = args[:1]
219
- # parse our incoming arguments
220
- num_rows = None
221
- args = list(args)
222
- if self.feature_names is None:
223
- feature_names = [None for _ in range(len(args))]
224
- elif issubclass(type(self.feature_names[0]), (list, tuple)):
225
- feature_names = copy.deepcopy(self.feature_names)
226
- else:
227
- feature_names = [copy.deepcopy(self.feature_names)]
228
- for i in range(len(args)):
229
-
230
- # try and see if we can get a length from any of the arguments for our progress bar
231
- if num_rows is None:
232
- try:
233
- num_rows = len(args[i])
234
- except Exception:
235
- pass
236
-
237
- # convert DataFrames to numpy arrays
238
- if isinstance(args[i], pd.DataFrame):
239
- feature_names[i] = list(args[i].columns)
240
- args[i] = args[i].to_numpy()
241
-
242
- # convert nlp Dataset objects to lists
243
- if safe_isinstance(args[i], "nlp.arrow_dataset.Dataset"):
244
- args[i] = args[i]["text"]
245
- elif issubclass(type(args[i]), dict) and "text" in args[i]:
246
- args[i] = args[i]["text"]
247
-
248
- if batch_size == "auto":
249
- if hasattr(self.masker, "default_batch_size"):
250
- batch_size = self.masker.default_batch_size
251
- else:
252
- batch_size = 10
253
-
254
- # loop over each sample, filling in the values array
255
- values = []
256
- output_indices = []
257
- expected_values = []
258
- mask_shapes = []
259
- main_effects = []
260
- hierarchical_values = []
261
- clustering = []
262
- output_names = []
263
- error_std = []
264
- if callable(getattr(self.masker, "feature_names", None)):
265
- feature_names = [[] for _ in range(len(args))]
266
- for row_args in show_progress(zip(*args), num_rows, self.__class__.__name__+" explainer", silent):
267
- row_result = self.explain_row(
268
- *row_args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
269
- batch_size=batch_size, outputs=outputs, silent=silent, **kwargs
270
- )
271
- values.append(row_result.get("values", None))
272
- output_indices.append(row_result.get("output_indices", None))
273
- expected_values.append(row_result.get("expected_values", None))
274
- mask_shapes.append(row_result["mask_shapes"])
275
- main_effects.append(row_result.get("main_effects", None))
276
- clustering.append(row_result.get("clustering", None))
277
- hierarchical_values.append(row_result.get("hierarchical_values", None))
278
- tmp = row_result.get("output_names", None)
279
- output_names.append(tmp(*row_args) if callable(tmp) else tmp)
280
- error_std.append(row_result.get("error_std", None))
281
- if callable(getattr(self.masker, "feature_names", None)):
282
- row_feature_names = self.masker.feature_names(*row_args)
283
- for i in range(len(row_args)):
284
- feature_names[i].append(row_feature_names[i])
285
-
286
- # split the values up according to each input
287
- arg_values = [[] for a in args]
288
- for i, v in enumerate(values):
289
- pos = 0
290
- for j in range(len(args)):
291
- mask_length = np.prod(mask_shapes[i][j])
292
- arg_values[j].append(values[i][pos:pos+mask_length])
293
- pos += mask_length
294
-
295
- # collapse the arrays as possible
296
- expected_values = pack_values(expected_values)
297
- main_effects = pack_values(main_effects)
298
- output_indices = pack_values(output_indices)
299
- main_effects = pack_values(main_effects)
300
- hierarchical_values = pack_values(hierarchical_values)
301
- error_std = pack_values(error_std)
302
- clustering = pack_values(clustering)
303
-
304
- # getting output labels
305
- ragged_outputs = False
306
- if output_indices is not None:
307
- ragged_outputs = not all(len(x) == len(output_indices[0]) for x in output_indices)
308
- if self.output_names is None:
309
- if None not in output_names:
310
- if not ragged_outputs:
311
- sliced_labels = np.array(output_names)
312
- else:
313
- sliced_labels = [np.array(output_names[i])[index_list] for i,index_list in enumerate(output_indices)]
314
- else:
315
- sliced_labels = None
316
- else:
317
- assert output_indices is not None, "You have passed a list for output_names but the model seems to not have multiple outputs!"
318
- labels = np.array(self.output_names)
319
- sliced_labels = [labels[index_list] for index_list in output_indices]
320
- if not ragged_outputs:
321
- sliced_labels = np.array(sliced_labels)
322
-
323
- if isinstance(sliced_labels, np.ndarray) and len(sliced_labels.shape) == 2:
324
- if np.all(sliced_labels[0,:] == sliced_labels):
325
- sliced_labels = sliced_labels[0]
326
-
327
- # allow the masker to transform the input data to better match the masking pattern
328
- # (such as breaking text into token segments)
329
- if hasattr(self.masker, "data_transform"):
330
- new_args = []
331
- for row_args in zip(*args):
332
- new_args.append([pack_values(v) for v in self.masker.data_transform(*row_args)])
333
- args = list(zip(*new_args))
334
-
335
- # build the explanation objects
336
- out = []
337
- for j, data in enumerate(args):
338
-
339
- # reshape the attribution values using the mask_shapes
340
- tmp = []
341
- for i, v in enumerate(arg_values[j]):
342
- if np.prod(mask_shapes[i][j]) != np.prod(v.shape): # see if we have multiple outputs
343
- tmp.append(v.reshape(*mask_shapes[i][j], -1))
344
- else:
345
- tmp.append(v.reshape(*mask_shapes[i][j]))
346
- arg_values[j] = pack_values(tmp)
347
-
348
- if feature_names[j] is None:
349
- feature_names[j] = ["Feature " + str(i) for i in range(data.shape[1])]
350
-
351
-
352
- # build an explanation object for this input argument
353
- out.append(Explanation(
354
- arg_values[j], expected_values, data,
355
- feature_names=feature_names[j], main_effects=main_effects,
356
- clustering=clustering,
357
- hierarchical_values=hierarchical_values,
358
- output_names=sliced_labels, # self.output_names
359
- error_std=error_std,
360
- compute_time=time.time() - start_time
361
- # output_shape=output_shape,
362
- #lower_bounds=v_min, upper_bounds=v_max
363
- ))
364
- return out[0] if len(out) == 1 else out
365
-
366
- def explain_row(self, *row_args, max_evals, main_effects, error_bounds, outputs, silent, **kwargs):
367
- """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes, main_effects).
368
-
369
- This is an abstract method meant to be implemented by each subclass.
370
-
371
- Returns
372
- -------
373
- tuple
374
- A tuple of (row_values, row_expected_values, row_mask_shapes), where row_values is an array of the
375
- attribution values for each sample, row_expected_values is an array (or single value) representing
376
- the expected value of the model for each sample (which is the same for all samples unless there
377
- are fixed inputs present, like labels when explaining the loss), and row_mask_shapes is a list
378
- of all the input shapes (since the row_values is always flattened),
379
- """
380
-
381
- return {}
382
-
383
- @staticmethod
384
- def supports_model_with_masker(model, masker):
385
- """ Determines if this explainer can handle the given model.
386
-
387
- This is an abstract static method meant to be implemented by each subclass.
388
- """
389
- return False
390
-
391
- @staticmethod
392
- def _compute_main_effects(fm, expected_value, inds):
393
- """ A utility method to compute the main effects from a MaskedModel.
394
- """
395
-
396
- # mask each input on in isolation
397
- masks = np.zeros(2*len(inds)-1, dtype=int)
398
- last_ind = -1
399
- for i in range(len(inds)):
400
- if i > 0:
401
- masks[2*i - 1] = -last_ind - 1 # turn off the last input
402
- masks[2*i] = inds[i] # turn on this input
403
- last_ind = inds[i]
404
-
405
- # compute the main effects for the given indexes
406
- main_effects = fm(masks) - expected_value
407
-
408
- # expand the vector to the full input size
409
- expanded_main_effects = np.zeros(len(fm))
410
- for i, ind in enumerate(inds):
411
- expanded_main_effects[ind] = main_effects[i]
412
-
413
- return expanded_main_effects
414
-
415
- def save(self, out_file, model_saver=".save", masker_saver=".save"):
416
- """ Write the explainer to the given file stream.
417
- """
418
- super().save(out_file)
419
- with Serializer(out_file, "shap.Explainer", version=0) as s:
420
- s.save("model", self.model, model_saver)
421
- s.save("masker", self.masker, masker_saver)
422
- s.save("link", self.link)
423
-
424
- @classmethod
425
- def load(cls, in_file, model_loader=Model.load, masker_loader=Masker.load, instantiate=True):
426
- """ Load an Explainer from the given file stream.
427
-
428
- Parameters
429
- ----------
430
- in_file : The file stream to load objects from.
431
- """
432
- if instantiate:
433
- return cls._instantiated_load(in_file, model_loader=model_loader, masker_loader=masker_loader)
434
-
435
- kwargs = super().load(in_file, instantiate=False)
436
- with Deserializer(in_file, "shap.Explainer", min_version=0, max_version=0) as s:
437
- kwargs["model"] = s.load("model", model_loader)
438
- kwargs["masker"] = s.load("masker", masker_loader)
439
- kwargs["link"] = s.load("link")
440
- return kwargs
441
-
442
- def pack_values(values):
443
- """ Used the clean up arrays before putting them into an Explanation object.
444
- """
445
-
446
- if not hasattr(values, "__len__"):
447
- return values
448
-
449
- # collapse the values if we didn't compute them
450
- if values is None or values[0] is None:
451
- return None
452
-
453
- # convert to a single numpy matrix when the array is not ragged
454
- elif np.issubdtype(type(values[0]), np.number) or len(np.unique([len(v) for v in values])) == 1:
455
- return np.array(values)
456
- else:
457
- return np.array(values, dtype=object)
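The Explainer class removed above is the library's single entry point: it wraps tabular background data in a masker, auto-selects an estimation algorithm, and re-casts itself as the matching subclass. A hedged usage sketch (the model and data are made up; with a plain callable and a small background matrix the "auto" path falls back to a model-agnostic explainer as described in the docstring):

import numpy as np
import shap

def model(X):
    return 2.0 * X[:, 0] + X[:, 1]            # any callable mapping a batch of samples to one output per row

background = np.random.randn(100, 5)          # 2D background matrix -> wrapped as an Independent masker

explainer = shap.Explainer(model, background) # algorithm="auto" dispatches to a subclass
explanation = explainer(np.random.randn(3, 5))
print(explanation.values.shape)               # (3, 5): one attribution per sample and feature
# explanation.base_values holds the expected model output that the attributions sum up from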
 
lib/shap/explainers/_gpu_tree.py DELETED
@@ -1,179 +0,0 @@
1
- """GPU accelerated tree explanations"""
2
- import numpy as np
3
-
4
- from ..utils import assert_import, record_import_error
5
- from ._tree import TreeExplainer, feature_perturbation_codes, output_transform_codes
6
-
7
- try:
8
- from .. import _cext_gpu
9
- except ImportError as e:
10
- record_import_error("cext_gpu", "cuda extension was not built during install!", e)
11
-
12
-
13
- class GPUTreeExplainer(TreeExplainer):
14
- """
15
- Experimental GPU accelerated version of TreeExplainer. Currently requires source build with
16
- cuda available and 'CUDA_PATH' environment variable defined.
17
-
18
- Parameters
19
- ----------
20
- model : model object
21
- The tree based machine learning model that we want to explain. XGBoost, LightGBM,
22
- CatBoost, Pyspark and most tree-based scikit-learn models are supported.
23
-
24
- data : numpy.array or pandas.DataFrame
25
- The background dataset to use for integrating out features. This argument is optional when
26
- feature_perturbation="tree_path_dependent", since in that case we can use the number of
27
- training samples that went down each tree path as our background dataset (this is recorded
28
- in the model object).
29
-
30
- feature_perturbation : "interventional" (default) or "tree_path_dependent" (default when data=None)
31
- Since SHAP values rely on conditional expectations we need to decide how to handle correlated
32
- (or otherwise dependent) input features. The "interventional" approach breaks the dependencies
33
- between features according to the rules dictated by causal inference (Janzing et al. 2019). Note
34
- that the "interventional" option requires a background dataset and its runtime scales linearly
35
- with the size of the background dataset you use. Anywhere from 100 to 1000 random background samples
36
- are good sizes to use. The "tree_path_dependent" approach is to just follow the trees and use the
37
- number of training examples that went down each leaf to represent the background distribution.
38
- This approach does not require a background dataset and so is used by default when no background
39
- dataset is provided.
40
-
41
- model_output : "raw", "probability", "log_loss", or model method name
42
- What output of the model should be explained. If "raw" then we explain the raw output of the
43
- trees, which varies by model. For regression models "raw" is the standard output, for binary
44
- classification in XGBoost this is the log odds ratio. If model_output is the name of a
45
- supported prediction method on the model object then we explain the output of that model
46
- method name. For example model_output="predict_proba" explains the result of calling
47
- model.predict_proba. If "probability" then we explain the output of the model transformed into
48
- probability space (note that this means the SHAP values now sum to the probability output of the
49
- model). If "logloss" then we explain the log base e of the model loss function, so that the SHAP
50
- values sum up to the log loss of the model for each sample. This is helpful for breaking
51
- down model performance by feature. Currently the probability and logloss options are only
52
- supported when
53
- feature_dependence="independent".
54
-
55
- Examples
56
- --------
57
- See `GPUTree explainer examples <https://shap.readthedocs.io/en/latest/api_examples/explainers/GPUTreeExplainer.html>`_
58
- """
59
-
60
- def shap_values(self, X, y=None, tree_limit=None, approximate=False, check_additivity=True,
61
- from_call=False):
62
- """ Estimate the SHAP values for a set of samples.
63
-
64
- Parameters
65
- ----------
66
- X : numpy.array, pandas.DataFrame or catboost.Pool (for catboost)
67
- A matrix of samples (# samples x # features) on which to explain the model's output.
68
-
69
- y : numpy.array
70
- An array of label values for each sample. Used when explaining loss functions.
71
-
72
- tree_limit : None (default) or int
73
- Limit the number of trees used by the model. By default None means to use the limit
74
- of the
75
- original model, and -1 means no limit.
76
-
77
- approximate : bool
78
- Not supported.
79
-
80
- check_additivity : bool
81
- Run a validation check that the sum of the SHAP values equals the output of the
82
- model. This
83
- check takes only a small amount of time, and will catch potential unforeseen errors.
84
- Note that this check only runs right now when explaining the margin of the model.
85
-
86
- Returns
87
- -------
88
- array or list
89
- For models with a single output this returns a matrix of SHAP values
90
- (# samples x # features). Each row sums to the difference between the model output
91
- for that
92
- sample and the expected value of the model output (which is stored in the expected_value
93
- attribute of the explainer when it is constant). For models with vector outputs this
94
- returns
95
- a list of such matrices, one for each output.
96
- """
97
- assert not approximate, "approximate not supported"
98
-
99
- X, y, X_missing, flat_output, tree_limit, check_additivity = \
100
- self._validate_inputs(X, y,
101
- tree_limit,
102
- check_additivity)
103
- transform = self.model.get_transform()
104
-
105
- # run the core algorithm using the C extension
106
- assert_import("cext_gpu")
107
- phi = np.zeros((X.shape[0], X.shape[1] + 1, self.model.num_outputs))
108
- _cext_gpu.dense_tree_shap(
109
- self.model.children_left, self.model.children_right, self.model.children_default,
110
- self.model.features, self.model.thresholds, self.model.values,
111
- self.model.node_sample_weight,
112
- self.model.max_depth, X, X_missing, y, self.data, self.data_missing, tree_limit,
113
- self.model.base_offset, phi, feature_perturbation_codes[self.feature_perturbation],
114
- output_transform_codes[transform], False
115
- )
116
-
117
- out = self._get_shap_output(phi, flat_output)
118
- if check_additivity and self.model.model_output == "raw":
119
- self.assert_additivity(out, self.model.predict(X))
120
-
121
- return out
122
-
123
- def shap_interaction_values(self, X, y=None, tree_limit=None):
124
- """ Estimate the SHAP interaction values for a set of samples.
125
-
126
- Parameters
127
- ----------
128
- X : numpy.array, pandas.DataFrame or catboost.Pool (for catboost)
129
- A matrix of samples (# samples x # features) on which to explain the model's output.
130
-
131
- y : numpy.array
132
- An array of label values for each sample. Used when explaining loss functions (not
133
- yet supported).
134
-
135
- tree_limit : None (default) or int
136
- Limit the number of trees used by the model. By default None means to use the limit
137
- of the
138
- original model, and -1 means no limit.
139
-
140
- Returns
141
- -------
142
- array or list
143
- For models with a single output this returns a tensor of SHAP values
144
- (# samples x # features x # features). The matrix (# features x # features) for each
145
- sample sums
146
- to the difference between the model output for that sample and the expected value of
147
- the model output
148
- (which is stored in the expected_value attribute of the explainer). Each row of this
149
- matrix sums to the
150
- SHAP value for that feature for that sample. The diagonal entries of the matrix
151
- represent the
152
- "main effect" of that feature on the prediction and the symmetric off-diagonal
153
- entries represent the
154
- interaction effects between all pairs of features for that sample. For models with
155
- vector outputs
156
- this returns a list of tensors, one for each output.
157
- """
158
-
159
- assert self.model.model_output == "raw", "Only model_output = \"raw\" is supported for " \
160
- "SHAP interaction values right now!"
161
- assert self.feature_perturbation != "interventional", 'feature_perturbation="interventional" is not yet supported for ' + \
162
- 'interaction values. Use feature_perturbation="tree_path_dependent" instead.'
163
- transform = "identity"
164
-
165
- X, y, X_missing, flat_output, tree_limit, _ = self._validate_inputs(X, y, tree_limit,
166
- False)
167
- # run the core algorithm using the C extension
168
- assert_import("cext_gpu")
169
- phi = np.zeros((X.shape[0], X.shape[1] + 1, X.shape[1] + 1, self.model.num_outputs))
170
- _cext_gpu.dense_tree_shap(
171
- self.model.children_left, self.model.children_right, self.model.children_default,
172
- self.model.features, self.model.thresholds, self.model.values,
173
- self.model.node_sample_weight,
174
- self.model.max_depth, X, X_missing, y, self.data, self.data_missing, tree_limit,
175
- self.model.base_offset, phi, feature_perturbation_codes[self.feature_perturbation],
176
- output_transform_codes[transform], True
177
- )
178
-
179
- return self._get_shap_interactions_output(phi, flat_output)
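The class above only overrides the two estimation entry points of TreeExplainer, swapping in the CUDA kernel. A hedged usage sketch, assuming a source build of shap with the CUDA extension available (as the docstring requires) and an XGBoost model; the data below is synthetic and purely illustrative:

import numpy as np
import xgboost
from shap.explainers._gpu_tree import GPUTreeExplainer  # module path as in the file removed above

X = np.random.randn(500, 8)
y = (X[:, 0] + X[:, 1] > 0).astype(int)
model = xgboost.XGBClassifier(n_estimators=50, max_depth=3).fit(X, y)

explainer = GPUTreeExplainer(model)                       # no background data -> "tree_path_dependent"
shap_values = explainer.shap_values(X[:10])               # per-feature attributions for the raw (log-odds) output
interactions = explainer.shap_interaction_values(X[:10])  # pairwise interaction tensor (features x features per row)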
 
lib/shap/explainers/_gradient.py DELETED
@@ -1,592 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from packaging import version
6
-
7
- from .._explanation import Explanation
8
- from ..explainers._explainer import Explainer
9
- from ..explainers.tf_utils import (
10
- _get_graph,
11
- _get_model_inputs,
12
- _get_model_output,
13
- _get_session,
14
- )
15
-
16
- keras = None
17
- tf = None
18
- torch = None
19
-
20
-
21
- class GradientExplainer(Explainer):
22
- """ Explains a model using expected gradients (an extension of integrated gradients).
23
-
24
- Expected gradients is an extension of the integrated gradients method (Sundararajan et al. 2017), a
25
- feature attribution method designed for differentiable models based on an extension of Shapley
26
- values to infinite player games (Aumann-Shapley values). Integrated gradients values are a bit
27
- different from SHAP values, and require a single reference value to integrate from. As an adaptation
28
- to make them approximate SHAP values, expected gradients reformulates the integral as an expectation
29
- and combines that expectation with sampling reference values from the background dataset. This leads
30
- to a single combined expectation of gradients that converges to attributions that sum to the
31
- difference between the expected model output and the current output.
32
-
33
- Examples
34
- --------
35
- See :ref:`Gradient Explainer Examples <gradient_explainer_examples>`
36
- """
37
-
38
- def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0):
39
- """ An explainer object for a differentiable model using a given background dataset.
40
-
41
- Parameters
42
- ----------
43
- model : tf.keras.Model, (input : [tf.Tensor], output : tf.Tensor), torch.nn.Module, or a tuple
44
- (model, layer), where both are torch.nn.Module objects
45
-
46
- For TensorFlow this can be a model object, or a pair of TensorFlow tensors (or a list and
47
- a tensor) that specifies the input and output of the model to be explained. Note that for
48
- TensorFlow 2 you must pass a tensorflow function, not a tuple of input/output tensors.
49
-
50
- For PyTorch this can be a nn.Module object (model), or a tuple (model, layer), where both
51
- are nn.Module objects. The model is an nn.Module object which takes as input a tensor
52
- (or list of tensors) of shape data, and returns a single dimensional output. If the input
53
- is a tuple, the returned shap values will be for the input of the layer argument. layer must
54
- be a layer in the model, i.e. model.conv2.
55
-
56
- data : [numpy.array] or [pandas.DataFrame] or [torch.tensor]
57
- The background dataset to use for integrating out features. Gradient explainer integrates
58
- over these samples. The data passed here must match the input tensors given in the
59
- first argument. Single element lists can be passed unwrapped.
60
- """
61
-
62
- # first, we need to find the framework
63
- if type(model) is tuple:
64
- a, b = model
65
- try:
66
- a.named_parameters()
67
- framework = 'pytorch'
68
- except Exception:
69
- framework = 'tensorflow'
70
- else:
71
- try:
72
- model.named_parameters()
73
- framework = 'pytorch'
74
- except Exception:
75
- framework = 'tensorflow'
76
-
77
- if isinstance(data, pd.DataFrame):
78
- self.features = data.columns.values
79
- else:
80
- self.features = None
81
-
82
- if framework == 'tensorflow':
83
- self.explainer = _TFGradient(model, data, session, batch_size, local_smoothing)
84
- elif framework == 'pytorch':
85
- self.explainer = _PyTorchGradient(model, data, batch_size, local_smoothing)
86
-
87
- def __call__(self, X, nsamples=200):
88
- """ Return an explanation object for the model applied to X.
89
-
90
- Parameters
91
- ----------
92
- X : list,
93
- if framework == 'tensorflow': numpy.array, or pandas.DataFrame
94
- if framework == 'pytorch': torch.tensor
95
- A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
96
- explain the model's output.
97
- nsamples : int
98
- number of background samples
99
- Returns
100
- -------
101
- shap.Explanation:
102
- """
103
- shap_values = self.shap_values(X, nsamples)
104
- return Explanation(values=shap_values, data=X, feature_names=self.features)
105
-
106
- def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False):
107
- """ Return the values for the model applied to X.
108
-
109
- Parameters
110
- ----------
111
- X : list,
112
- if framework == 'tensorflow': numpy.array, or pandas.DataFrame
113
- if framework == 'pytorch': torch.tensor
114
- A tensor (or list of tensors) of samples (where X.shape[0] == # samples) on which to
115
- explain the model's output.
116
-
117
- ranked_outputs : None or int
118
- If ranked_outputs is None then we explain all the outputs in a multi-output model. If
119
- ranked_outputs is a positive integer then we only explain that many of the top model
120
- outputs (where "top" is determined by output_rank_order). Note that this causes a pair
121
- of values to be returned (shap_values, indexes), where shap_values is a list of numpy arrays
122
- for each of the output ranks, and indexes is a matrix that tells for each sample which output
123
- indexes were chosen as "top".
124
-
125
- output_rank_order : "max", "min", "max_abs", or "custom"
126
- How to order the model outputs when using ranked_outputs, either by maximum, minimum, or
127
- maximum absolute value. If "custom" Then "ranked_outputs" contains a list of output nodes.
128
-
129
- rseed : None or int
130
- Seeding the randomness in shap value computation (background example choice,
131
- interpolation between current and background example, smoothing).
132
-
133
- Returns
134
- -------
135
- array or list
136
- For models with a single output this returns a tensor of SHAP values with the same shape
137
- as X. For a model with multiple outputs this returns a list of SHAP value tensors, each of
138
- which are the same shape as X. If ranked_outputs is None then this list of tensors matches
139
- the number of model outputs. If ranked_outputs is a positive integer a pair is returned
140
- (shap_values, indexes), where shap_values is a list of tensors with a length of
141
- ranked_outputs, and indexes is a matrix that tells for each sample which output indexes
142
- were chosen as "top".
143
- """
144
- return self.explainer.shap_values(X, nsamples, ranked_outputs, output_rank_order, rseed, return_variances)
145
-
146
-
147
- class _TFGradient(Explainer):
148
-
149
- def __init__(self, model, data, session=None, batch_size=50, local_smoothing=0):
150
-
151
- # try and import keras and tensorflow
152
- global tf, keras
153
- if tf is None:
154
- import tensorflow as tf
155
- if version.parse(tf.__version__) < version.parse("1.4.0"):
156
- warnings.warn("Your TensorFlow version is older than 1.4.0 and not supported.")
157
- if keras is None:
158
- try:
159
- from tensorflow import keras
160
- if version.parse(keras.__version__) < version.parse("2.1.0"):
161
- warnings.warn("Your Keras version is older than 2.1.0 and not supported.")
162
- except Exception:
163
- pass
164
-
165
- # determine the model inputs and outputs
166
- self.model = model
167
- self.model_inputs = _get_model_inputs(model)
168
- self.model_output = _get_model_output(model)
169
- assert not isinstance(self.model_output, list), "The model output to be explained must be a single tensor!"
170
- assert len(self.model_output.shape) < 3, "The model output must be a vector or a single value!"
171
- self.multi_output = True
172
- if len(self.model_output.shape) == 1:
173
- self.multi_output = False
174
-
175
- # check if we have multiple inputs
176
- self.multi_input = True
177
- if not isinstance(self.model_inputs, list):
178
- self.model_inputs = [self.model_inputs]
179
- self.multi_input = len(self.model_inputs) > 1
180
- if isinstance(data, pd.DataFrame):
181
- data = [data.values]
182
- if not isinstance(data, list):
183
- data = [data]
184
-
185
- self.data = data
186
- self._num_vinputs = {}
187
- self.batch_size = batch_size
188
- self.local_smoothing = local_smoothing
189
-
190
- if not tf.executing_eagerly():
191
- self.session = _get_session(session)
192
- self.graph = _get_graph(self)
193
- # see if there is a keras operation we need to save
194
- self.keras_phase_placeholder = None
195
- for op in self.graph.get_operations():
196
- if 'keras_learning_phase' in op.name:
197
- self.keras_phase_placeholder = op.outputs[0]
198
-
199
- # save the expected output of the model (commented out because self.data could be huge for GradientExplainer)
200
- #self.expected_value = self.run(self.model_output, self.model_inputs, self.data).mean(0)
201
-
202
- if not self.multi_output:
203
- self.gradients = [None]
204
- else:
205
- self.gradients = [None for i in range(self.model_output.shape[1])]
206
-
207
- def gradient(self, i):
208
- global tf, keras
209
-
210
- if self.gradients[i] is None:
211
- if not tf.executing_eagerly():
212
- out = self.model_output[:,i] if self.multi_output else self.model_output
213
- self.gradients[i] = tf.gradients(out, self.model_inputs)
214
- else:
215
- @tf.function
216
- def grad_graph(x):
217
- phase = tf.keras.backend.learning_phase()
218
- tf.keras.backend.set_learning_phase(0)
219
-
220
- with tf.GradientTape(watch_accessed_variables=False) as tape:
221
- tape.watch(x)
222
- out = self.model(x)
223
- if self.multi_output:
224
- out = out[:,i]
225
-
226
- x_grad = tape.gradient(out, x)
227
-
228
- tf.keras.backend.set_learning_phase(phase)
229
-
230
- return x_grad
231
-
232
- self.gradients[i] = grad_graph
233
-
234
- return self.gradients[i]
235
-
236
- def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False):
237
- global tf, keras
238
-
239
- import tensorflow as tf
240
- import tensorflow.keras as keras
241
-
242
- # check if we have multiple inputs
243
- if not self.multi_input:
244
- assert not isinstance(X, list), "Expected a single tensor model input!"
245
- X = [X]
246
- else:
247
- assert isinstance(X, list), "Expected a list of model inputs!"
248
- assert len(self.model_inputs) == len(X), "Number of model inputs does not match the number given!"
249
-
250
- # rank and determine the model outputs that we will explain
251
- if not tf.executing_eagerly():
252
- model_output_values = self.run(self.model_output, self.model_inputs, X)
253
- else:
254
- model_output_values = self.run(self.model, self.model_inputs, X)
255
- if ranked_outputs is not None and self.multi_output:
256
- if output_rank_order == "max":
257
- model_output_ranks = np.argsort(-model_output_values)
258
- elif output_rank_order == "min":
259
- model_output_ranks = np.argsort(model_output_values)
260
- elif output_rank_order == "max_abs":
261
- model_output_ranks = np.argsort(np.abs(model_output_values))
262
- elif output_rank_order == "custom":
263
- model_output_ranks = ranked_outputs
264
- else:
265
- emsg = "output_rank_order must be max, min, max_abs or custom!"
266
- raise ValueError(emsg)
267
-
268
- if output_rank_order in ["max", "min", "max_abs"]:
269
- model_output_ranks = model_output_ranks[:,:ranked_outputs]
270
- else:
271
- model_output_ranks = np.tile(np.arange(len(self.gradients)), (X[0].shape[0], 1))
272
-
273
- # compute the attributions
274
- output_phis = []
275
- output_phi_vars = []
276
- samples_input = [np.zeros((nsamples,) + X[t].shape[1:], dtype=np.float32) for t in range(len(X))]
277
- samples_delta = [np.zeros((nsamples,) + X[t].shape[1:], dtype=np.float32) for t in range(len(X))]
278
- # use random seed if no argument given
279
- if rseed is None:
280
- rseed = np.random.randint(0, 1e6)
281
-
282
- for i in range(model_output_ranks.shape[1]):
283
- np.random.seed(rseed) # so we get the same noise patterns for each output class
284
- phis = []
285
- phi_vars = []
286
- for k in range(len(X)):
287
- phis.append(np.zeros(X[k].shape))
288
- phi_vars.append(np.zeros(X[k].shape))
289
- for j in range(X[0].shape[0]):
290
-
291
- # fill in the samples arrays
292
- for k in range(nsamples):
293
- rind = np.random.choice(self.data[0].shape[0])
294
- t = np.random.uniform()
295
- for u in range(len(X)):
296
- if self.local_smoothing > 0:
297
- x = X[u][j] + np.random.randn(*X[u][j].shape) * self.local_smoothing
298
- else:
299
- x = X[u][j]
300
- samples_input[u][k] = t * x + (1 - t) * self.data[u][rind]
301
- samples_delta[u][k] = x - self.data[u][rind]
302
-
303
- # compute the gradients at all the sample points
304
- find = model_output_ranks[j,i]
305
- grads = []
306
- for b in range(0, nsamples, self.batch_size):
307
- batch = [samples_input[a][b:min(b+self.batch_size,nsamples)] for a in range(len(X))]
308
- grads.append(self.run(self.gradient(find), self.model_inputs, batch))
309
- grad = [np.concatenate([g[a] for g in grads], 0) for a in range(len(X))]
310
-
311
- # assign the attributions to the right part of the output arrays
312
- for a in range(len(X)):
313
- samples = grad[a] * samples_delta[a]
314
- phis[a][j] = samples.mean(0)
315
- phi_vars[a][j] = samples.var(0) / np.sqrt(samples.shape[0]) # estimate variance of means
316
-
317
- # TODO: this could be avoided by integrating between endpoints if no local smoothing is used
318
- # correct the sum of the values to equal the output of the model using a linear
319
- # regression model with priors of the coefficients equal to the estimated variances for each
320
- # value (note that 1e-6 is designed to increase the weight of the sample and so closely
321
- # match the correct sum)
322
- # if False and self.local_smoothing == 0: # disabled right now to make sure it doesn't mask problems
323
- # phis_sum = np.sum([phis[l][j].sum() for l in range(len(X))])
324
- # phi_vars_s = np.stack([phi_vars[l][j] for l in range(len(X))], 0).flatten()
325
- # if self.multi_output:
326
- # sum_error = model_output_values[j,find] - phis_sum - self.expected_value[find]
327
- # else:
328
- # sum_error = model_output_values[j] - phis_sum - self.expected_value
329
-
330
- # # this is a ridge regression with one sample of all ones with sum_error as the label
331
- # # and 1/v as the ridge penalties. This simplified (and stable) form comes from the
332
- # # Sherman-Morrison formula
333
- # v = (phi_vars_s / phi_vars_s.max()) * 1e6
334
- # adj = sum_error * (v - (v * v.sum()) / (1 + v.sum()))
335
-
336
- # # add the adjustment to the output so the sum matches
337
- # offset = 0
338
- # for l in range(len(X)):
339
- # s = np.prod(phis[l][j].shape)
340
- # phis[l][j] += adj[offset:offset+s].reshape(phis[l][j].shape)
341
- # offset += s
342
-
343
- output_phis.append(phis[0] if not self.multi_input else phis)
344
- output_phi_vars.append(phi_vars[0] if not self.multi_input else phi_vars)
345
- if not self.multi_output:
346
- if return_variances:
347
- return output_phis[0], output_phi_vars[0]
348
- else:
349
- return output_phis[0]
350
- elif ranked_outputs is not None:
351
- if return_variances:
352
- return output_phis, output_phi_vars, model_output_ranks
353
- else:
354
- return output_phis, model_output_ranks
355
- else:
356
- if return_variances:
357
- return output_phis, output_phi_vars
358
- else:
359
- return output_phis
360
-
361
- def run(self, out, model_inputs, X):
362
- global tf, keras
363
-
364
- if not tf.executing_eagerly():
365
- feed_dict = dict(zip(model_inputs, X))
366
- if self.keras_phase_placeholder is not None:
367
- feed_dict[self.keras_phase_placeholder] = 0
368
- return self.session.run(out, feed_dict)
369
- else:
370
- # build inputs that are correctly shaped, typed, and tf-wrapped
371
- inputs = []
372
- for i in range(len(X)):
373
- shape = list(self.model_inputs[i].shape)
374
- shape[0] = -1
375
- v = tf.constant(X[i].reshape(shape), dtype=self.model_inputs[i].dtype)
376
- inputs.append(v)
377
- return out(inputs)
378
-
379
-
380
- class _PyTorchGradient(Explainer):
381
-
382
- def __init__(self, model, data, batch_size=50, local_smoothing=0):
383
-
384
- # try and import pytorch
385
- global torch
386
- if torch is None:
387
- import torch
388
- if version.parse(torch.__version__) < version.parse("0.4"):
389
- warnings.warn("Your PyTorch version is older than 0.4 and not supported.")
390
-
391
- # check if we have multiple inputs
392
- self.multi_input = False
393
- if isinstance(data, list):
394
- self.multi_input = True
395
- if not isinstance(data, list):
396
- data = [data]
397
-
398
- # for consistency, the method signature calls for data as the model input.
399
- # However, within this class, self.model_inputs is the input (i.e. the data passed by the user)
400
- # and self.data is the background data for the layer we want to assign importances to. If this layer is
401
- # the input, then self.data = self.model_inputs
402
- self.model_inputs = data
403
- self.batch_size = batch_size
404
- self.local_smoothing = local_smoothing
405
-
406
- self.layer = None
407
- self.input_handle = None
408
- self.interim = False
409
- if type(model) == tuple:
410
- self.interim = True
411
- model, layer = model
412
- model = model.eval()
413
- self.add_handles(layer)
414
- self.layer = layer
415
-
416
- # now, if we are taking an interim layer, the 'data' is going to be the input
417
- # of the interim layer; we will capture this using a forward hook
418
- with torch.no_grad():
419
- _ = model(*data)
420
- interim_inputs = self.layer.target_input
421
- if type(interim_inputs) is tuple:
422
- # this should always be true, but just to be safe
423
- self.data = [i.clone().detach() for i in interim_inputs]
424
- else:
425
- self.data = [interim_inputs.clone().detach()]
426
- else:
427
- self.data = data
428
- self.model = model.eval()
429
-
430
- multi_output = False
431
- outputs = self.model(*self.model_inputs)
432
- if len(outputs.shape) > 1 and outputs.shape[1] > 1:
433
- multi_output = True
434
- self.multi_output = multi_output
435
-
436
- if not self.multi_output:
437
- self.gradients = [None]
438
- else:
439
- self.gradients = [None for i in range(outputs.shape[1])]
440
-
441
- def gradient(self, idx, inputs):
442
- self.model.zero_grad()
443
- X = [x.requires_grad_() for x in inputs]
444
- outputs = self.model(*X)
445
- selected = [val for val in outputs[:, idx]]
446
- if self.input_handle is not None:
447
- interim_inputs = self.layer.target_input
448
- grads = [torch.autograd.grad(selected, input,
449
- retain_graph=True if idx + 1 < len(interim_inputs) else None)[0].cpu().numpy()
450
- for idx, input in enumerate(interim_inputs)]
451
- del self.layer.target_input
452
- else:
453
- grads = [torch.autograd.grad(selected, x,
454
- retain_graph=True if idx + 1 < len(X) else None)[0].cpu().numpy()
455
- for idx, x in enumerate(X)]
456
- return grads
457
-
458
- @staticmethod
459
- def get_interim_input(self, input, output):
460
- try:
461
- del self.target_input
462
- except AttributeError:
463
- pass
464
- setattr(self, 'target_input', input)
465
-
466
- def add_handles(self, layer):
467
- input_handle = layer.register_forward_hook(self.get_interim_input)
468
- self.input_handle = input_handle
469
-
470
- def shap_values(self, X, nsamples=200, ranked_outputs=None, output_rank_order="max", rseed=None, return_variances=False):
471
-
472
- # X ~ self.model_input
473
- # X_data ~ self.data
474
-
475
- # check if we have multiple inputs
476
- if not self.multi_input:
477
- assert not isinstance(X, list), "Expected a single tensor model input!"
478
- X = [X]
479
- else:
480
- assert isinstance(X, list), "Expected a list of model inputs!"
481
-
482
- if ranked_outputs is not None and self.multi_output:
483
- with torch.no_grad():
484
- model_output_values = self.model(*X)
485
- # rank and determine the model outputs that we will explain
486
- if output_rank_order == "max":
487
- _, model_output_ranks = torch.sort(model_output_values, descending=True)
488
- elif output_rank_order == "min":
489
- _, model_output_ranks = torch.sort(model_output_values, descending=False)
490
- elif output_rank_order == "max_abs":
491
- _, model_output_ranks = torch.sort(torch.abs(model_output_values), descending=True)
492
- else:
493
- emsg = "output_rank_order must be max, min, or max_abs!"
494
- raise ValueError(emsg)
495
- model_output_ranks = model_output_ranks[:, :ranked_outputs]
496
- else:
497
- model_output_ranks = (torch.ones((X[0].shape[0], len(self.gradients))).int() *
498
- torch.arange(0, len(self.gradients)).int())
499
-
500
- # if a cleanup happened, we need to add the handles back
501
- # this allows shap_values to be called multiple times, but the model to be
502
- # 'clean' at the end of each run for other uses
503
- if self.input_handle is None and self.interim is True:
504
- self.add_handles(self.layer)
505
-
506
- # compute the attributions
507
- X_batches = X[0].shape[0]
508
- output_phis = []
509
- output_phi_vars = []
510
- # samples_input = input to the model
511
- # samples_delta = (x - x') for the input being explained - may be an interim input
512
- samples_input = [torch.zeros((nsamples,) + X[t].shape[1:], device=X[t].device) for t in range(len(X))]
513
- samples_delta = [np.zeros((nsamples, ) + self.data[t].shape[1:]) for t in range(len(self.data))]
514
-
515
- # use random seed if no argument given
516
- if rseed is None:
517
- rseed = np.random.randint(0, 1e6)
518
-
519
- for i in range(model_output_ranks.shape[1]):
520
- np.random.seed(rseed) # so we get the same noise patterns for each output class
521
- phis = []
522
- phi_vars = []
523
- for k in range(len(self.data)):
524
- # for each of the inputs being explained - may be an interim input
525
- phis.append(np.zeros((X_batches,) + self.data[k].shape[1:]))
526
- phi_vars.append(np.zeros((X_batches, ) + self.data[k].shape[1:]))
527
- for j in range(X[0].shape[0]):
528
- # fill in the samples arrays
529
- for k in range(nsamples):
530
- rind = np.random.choice(self.data[0].shape[0])
531
- t = np.random.uniform()
532
- for a in range(len(X)):
533
- if self.local_smoothing > 0:
534
- # local smoothing is added to the base input, unlike in the TF gradient explainer
535
- x = X[a][j].clone().detach() + torch.empty(X[a][j].shape, device=X[a].device).normal_() \
536
- * self.local_smoothing
537
- else:
538
- x = X[a][j].clone().detach()
539
- samples_input[a][k] = (t * x + (1 - t) * (self.model_inputs[a][rind]).clone().detach()).\
540
- clone().detach()
541
- if self.input_handle is None:
542
- samples_delta[a][k] = (x - (self.data[a][rind]).clone().detach()).cpu().numpy()
543
-
544
- if self.interim is True:
545
- with torch.no_grad():
546
- _ = self.model(*[samples_input[a][k].unsqueeze(0) for a in range(len(X))])
547
- interim_inputs = self.layer.target_input
548
- del self.layer.target_input
549
- if type(interim_inputs) is tuple:
550
- if type(interim_inputs) is tuple:
551
- # this should always be true, but just to be safe
552
- for a in range(len(interim_inputs)):
553
- samples_delta[a][k] = interim_inputs[a].cpu().numpy()
554
- else:
555
- samples_delta[0][k] = interim_inputs.cpu().numpy()
556
-
557
- # compute the gradients at all the sample points
558
- find = model_output_ranks[j, i]
559
- grads = []
560
- for b in range(0, nsamples, self.batch_size):
561
- batch = [samples_input[c][b:min(b+self.batch_size,nsamples)].clone().detach() for c in range(len(X))]
562
- grads.append(self.gradient(find, batch))
563
- grad = [np.concatenate([g[z] for g in grads], 0) for z in range(len(self.data))]
564
- # assign the attributions to the right part of the output arrays
565
- for t in range(len(self.data)):
566
- samples = grad[t] * samples_delta[t]
567
- phis[t][j] = samples.mean(0)
568
- phi_vars[t][j] = samples.var(0) / np.sqrt(samples.shape[0]) # estimate variance of means
569
-
570
- output_phis.append(phis[0] if len(self.data) == 1 else phis)
571
- output_phi_vars.append(phi_vars[0] if not self.multi_input else phi_vars)
572
- # cleanup: remove the handles, if they were added
573
- if self.input_handle is not None:
574
- self.input_handle.remove()
575
- self.input_handle = None
576
- # note: the target input attribute is deleted in the loop
577
-
578
- if not self.multi_output:
579
- if return_variances:
580
- return output_phis[0], output_phi_vars[0]
581
- else:
582
- return output_phis[0]
583
- elif ranked_outputs is not None:
584
- if return_variances:
585
- return output_phis, output_phi_vars, model_output_ranks
586
- else:
587
- return output_phis, model_output_ranks
588
- else:
589
- if return_variances:
590
- return output_phis, output_phi_vars
591
- else:
592
- return output_phis
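
The class above implements the expected-gradients estimator for PyTorch: for each instance it draws background samples, interpolates between the instance and a background sample, evaluates gradients at those interpolation points, and averages gradient * (x - x'). A minimal usage sketch through the public shap.GradientExplainer wrapper follows; the toy network, tensor shapes and sample counts are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
import shap

# toy two-output network and background data (assumed shapes)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2)).eval()
background = torch.randn(100, 10)  # samples used as the baseline x'
X = torch.randn(5, 10)             # instances to explain

explainer = shap.GradientExplainer(model, background)
# nsamples = number of interpolation points averaged per instance (200 is the default above)
shap_values = explainer.shap_values(X, nsamples=200)

# to attribute an intermediate layer instead of the raw input, pass a (model, layer) pair,
# e.g. shap.GradientExplainer((model, model[0]), background)

With two model outputs, shap_values is a list holding one (samples x features) array per output, matching the multi_output branch above.
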
lib/shap/explainers/_kernel.py DELETED
@@ -1,696 +0,0 @@
1
- import copy
2
- import gc
3
- import itertools
4
- import logging
5
- import time
6
- import warnings
7
-
8
- import numpy as np
9
- import pandas as pd
10
- import scipy.sparse
11
- import sklearn
12
- from packaging import version
13
- from scipy.special import binom
14
- from sklearn.linear_model import Lasso, LassoLarsIC, lars_path
15
- from sklearn.pipeline import make_pipeline
16
- from sklearn.preprocessing import StandardScaler
17
- from tqdm.auto import tqdm
18
-
19
- from .._explanation import Explanation
20
- from ..utils import safe_isinstance
21
- from ..utils._exceptions import DimensionError
22
- from ..utils._legacy import (
23
- DenseData,
24
- SparseData,
25
- convert_to_data,
26
- convert_to_instance,
27
- convert_to_instance_with_index,
28
- convert_to_link,
29
- convert_to_model,
30
- match_instance_to_data,
31
- match_model_to_data,
32
- )
33
- from ._explainer import Explainer
34
-
35
- log = logging.getLogger('shap')
36
-
37
-
38
- class KernelExplainer(Explainer):
39
- """Uses the Kernel SHAP method to explain the output of any function.
40
-
41
- Kernel SHAP is a method that uses a special weighted linear regression
42
- to compute the importance of each feature. The computed importance values
43
- are Shapley values from game theory and also coefficients from a local linear
44
- regression.
45
-
46
- Parameters
47
- ----------
48
- model : function or iml.Model
49
- User supplied function that takes a matrix of samples (# samples x # features) and
50
- computes the output of the model for those samples. The output can be a vector
51
- (# samples) or a matrix (# samples x # model outputs).
52
-
53
- data : numpy.array or pandas.DataFrame or shap.common.DenseData or any scipy.sparse matrix
54
- The background dataset to use for integrating out features. To determine the impact
55
- of a feature, that feature is set to "missing" and the change in the model output
56
- is observed. Since most models aren't designed to handle arbitrary missing data at test
57
- time, we simulate "missing" by replacing the feature with the values it takes in the
58
- background dataset. So if the background dataset is a simple sample of all zeros, then
59
- we would approximate a feature being missing by setting it to zero. For small problems,
60
- this background dataset can be the whole training set, but for larger problems consider
61
- using a single reference value or using the ``kmeans`` function to summarize the dataset.
62
- Note: for the sparse case, we accept any sparse matrix but convert to lil format for
63
- performance.
64
-
65
- feature_names : list
66
- The names of the features in the background dataset. If the background dataset is
67
- supplied as a pandas.DataFrame, then ``feature_names`` can be set to ``None`` (default),
68
- and the feature names will be taken as the column names of the dataframe.
69
-
70
- link : "identity" or "logit"
71
- A generalized linear model link to connect the feature importance values to the model
72
- output. Since the feature importance values, phi, sum up to the model output, it often makes
73
- sense to connect them to the output with a link function where link(output) = sum(phi).
74
- Default is "identity" (a no-op).
75
- If the model output is a probability, then "logit" can be used to transform the SHAP values
76
- into log-odds units.
77
-
78
- Examples
79
- --------
80
- See :ref:`Kernel Explainer Examples <kernel_explainer_examples>`.
81
- """
82
-
83
- def __init__(self, model, data, feature_names=None, link="identity", **kwargs):
84
-
85
- if feature_names is not None:
86
- self.data_feature_names=feature_names
87
- elif isinstance(data, pd.DataFrame):
88
- self.data_feature_names = list(data.columns)
89
-
90
- # convert incoming inputs to standardized iml objects
91
- self.link = convert_to_link(link)
92
- self.keep_index = kwargs.get("keep_index", False)
93
- self.keep_index_ordered = kwargs.get("keep_index_ordered", False)
94
- self.model = convert_to_model(model, keep_index=self.keep_index)
95
- self.data = convert_to_data(data, keep_index=self.keep_index)
96
- model_null = match_model_to_data(self.model, self.data)
97
-
98
- # enforce our current input type limitations
99
- if not isinstance(self.data, (DenseData, SparseData)):
100
- emsg = "Shap explainer only supports the DenseData and SparseData input currently."
101
- raise TypeError(emsg)
102
- if self.data.transposed:
103
- emsg = "Shap explainer does not support transposed DenseData or SparseData currently."
104
- raise DimensionError(emsg)
105
-
106
- # warn users about large background data sets
107
- if len(self.data.weights) > 100:
108
- log.warning("Using " + str(len(self.data.weights)) + " background data samples could cause " +
109
- "slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to " +
110
- "summarize the background as K samples.")
111
-
112
- # init our parameters
113
- self.N = self.data.data.shape[0]
114
- self.P = self.data.data.shape[1]
115
- self.linkfv = np.vectorize(self.link.f)
116
- self.nsamplesAdded = 0
117
- self.nsamplesRun = 0
118
-
119
- # find E_x[f(x)]
120
- if isinstance(model_null, (pd.DataFrame, pd.Series)):
121
- model_null = np.squeeze(model_null.values)
122
- if safe_isinstance(model_null, "tensorflow.python.framework.ops.EagerTensor"):
123
- model_null = model_null.numpy()
124
- self.fnull = np.sum((model_null.T * self.data.weights).T, 0)
125
- self.expected_value = self.linkfv(self.fnull)
126
-
127
- # see if we have a vector output
128
- self.vector_out = True
129
- if len(self.fnull.shape) == 0:
130
- self.vector_out = False
131
- self.fnull = np.array([self.fnull])
132
- self.D = 1
133
- self.expected_value = float(self.expected_value)
134
- else:
135
- self.D = self.fnull.shape[0]
136
-
137
- def __call__(self, X):
138
-
139
- start_time = time.time()
140
-
141
- if isinstance(X, pd.DataFrame):
142
- feature_names = list(X.columns)
143
- else:
144
- feature_names = getattr(self, "data_feature_names", None)
145
-
146
- v = self.shap_values(X)
147
- if isinstance(v, list):
148
- v = np.stack(v, axis=-1) # put outputs at the end
149
-
150
- # the explanation object expects an expected value for each row
151
- if hasattr(self.expected_value, "__len__"):
152
- ev_tiled = np.tile(self.expected_value, (v.shape[0],1))
153
- else:
154
- ev_tiled = np.tile(self.expected_value, v.shape[0])
155
-
156
- return Explanation(
157
- v,
158
- base_values=ev_tiled,
159
- data=X.to_numpy() if isinstance(X, pd.DataFrame) else X,
160
- feature_names=feature_names,
161
- compute_time=time.time() - start_time,
162
- )
163
-
164
- def shap_values(self, X, **kwargs):
165
- """ Estimate the SHAP values for a set of samples.
166
-
167
- Parameters
168
- ----------
169
- X : numpy.array or pandas.DataFrame or any scipy.sparse matrix
170
- A matrix of samples (# samples x # features) on which to explain the model's output.
171
-
172
- nsamples : "auto" or int
173
- Number of times to re-evaluate the model when explaining each prediction. More samples
174
- lead to lower variance estimates of the SHAP values. The "auto" setting uses
175
- `nsamples = 2 * X.shape[1] + 2048`.
176
-
177
- l1_reg : "num_features(int)", "auto" (default for now, but deprecated), "aic", "bic", or float
178
- The l1 regularization to use for feature selection (the estimation procedure is based on
179
- a debiased lasso). The auto option currently uses "aic" when less than 20% of the possible sample
180
- space is enumerated, otherwise it uses no regularization. THE BEHAVIOR OF "auto" WILL CHANGE
181
- in a future version to be based on num_features instead of AIC.
182
- The "aic" and "bic" options use the AIC and BIC rules for regularization.
183
- Using "num_features(int)" selects a fix number of top features. Passing a float directly sets the
184
- "alpha" parameter of the sklearn.linear_model.Lasso model used for feature selection.
185
-
186
- gc_collect : bool
187
- Run garbage collection after each explanation round. Sometimes needed for memory-intensive explanations (default False).
188
-
189
- Returns
190
- -------
191
- array or list
192
- For models with a single output this returns a matrix of SHAP values
193
- (# samples x # features). Each row sums to the difference between the model output for that
194
- sample and the expected value of the model output (which is stored as expected_value
195
- attribute of the explainer). For models with vector outputs this returns a list
196
- of such matrices, one for each output.
197
- """
198
-
199
- # convert dataframes
200
- if isinstance(X, pd.Series):
201
- X = X.values
202
- elif isinstance(X, pd.DataFrame):
203
- if self.keep_index:
204
- index_value = X.index.values
205
- index_name = X.index.name
206
- column_name = list(X.columns)
207
- X = X.values
208
-
209
- x_type = str(type(X))
210
- arr_type = "'numpy.ndarray'>"
211
- # if sparse, convert to lil for performance
212
- if scipy.sparse.issparse(X) and not scipy.sparse.isspmatrix_lil(X):
213
- X = X.tolil()
214
- assert x_type.endswith(arr_type) or scipy.sparse.isspmatrix_lil(X), "Unknown instance type: " + x_type
215
-
216
- # single instance
217
- if len(X.shape) == 1:
218
- data = X.reshape((1, X.shape[0]))
219
- if self.keep_index:
220
- data = convert_to_instance_with_index(data, column_name, index_name, index_value)
221
- explanation = self.explain(data, **kwargs)
222
-
223
- # vector-output
224
- s = explanation.shape
225
- if len(s) == 2:
226
- outs = [np.zeros(s[0]) for j in range(s[1])]
227
- for j in range(s[1]):
228
- outs[j] = explanation[:, j]
229
- return outs
230
-
231
- # single-output
232
- else:
233
- out = np.zeros(s[0])
234
- out[:] = explanation
235
- return out
236
-
237
- # explain the whole dataset
238
- elif len(X.shape) == 2:
239
- explanations = []
240
- for i in tqdm(range(X.shape[0]), disable=kwargs.get("silent", False)):
241
- data = X[i:i + 1, :]
242
- if self.keep_index:
243
- data = convert_to_instance_with_index(data, column_name, index_value[i:i + 1], index_name)
244
- explanations.append(self.explain(data, **kwargs))
245
- if kwargs.get("gc_collect", False):
246
- gc.collect()
247
-
248
- # vector-output
249
- s = explanations[0].shape
250
- if len(s) == 2:
251
- outs = [np.zeros((X.shape[0], s[0])) for j in range(s[1])]
252
- for i in range(X.shape[0]):
253
- for j in range(s[1]):
254
- outs[j][i] = explanations[i][:, j]
255
- return outs
256
-
257
- # single-output
258
- else:
259
- out = np.zeros((X.shape[0], s[0]))
260
- for i in range(X.shape[0]):
261
- out[i] = explanations[i]
262
- return out
263
-
264
- else:
265
- emsg = "Instance must have 1 or 2 dimensions!"
266
- raise DimensionError(emsg)
267
-
268
- def explain(self, incoming_instance, **kwargs):
269
- # convert incoming input to a standardized iml object
270
- instance = convert_to_instance(incoming_instance)
271
- match_instance_to_data(instance, self.data)
272
-
273
- # find the feature groups we will test. If a feature does not change from its
274
- # current value then we know it doesn't impact the model
275
- self.varyingInds = self.varying_groups(instance.x)
276
- if self.data.groups is None:
277
- self.varyingFeatureGroups = np.array([i for i in self.varyingInds])
278
- self.M = self.varyingFeatureGroups.shape[0]
279
- else:
280
- self.varyingFeatureGroups = [self.data.groups[i] for i in self.varyingInds]
281
- self.M = len(self.varyingFeatureGroups)
282
- groups = self.data.groups
283
- # convert to a numpy array, which is much faster when the array is not jagged (all groups of the same length)
284
- if self.varyingFeatureGroups and all(len(groups[i]) == len(groups[0]) for i in self.varyingInds):
285
- self.varyingFeatureGroups = np.array(self.varyingFeatureGroups)
286
- # further performance optimization in case each group has a single value
287
- if self.varyingFeatureGroups.shape[1] == 1:
288
- self.varyingFeatureGroups = self.varyingFeatureGroups.flatten()
289
-
290
- # find f(x)
291
- if self.keep_index:
292
- model_out = self.model.f(instance.convert_to_df())
293
- else:
294
- model_out = self.model.f(instance.x)
295
- if isinstance(model_out, (pd.DataFrame, pd.Series)):
296
- model_out = model_out.values
297
- self.fx = model_out[0]
298
-
299
- if not self.vector_out:
300
- self.fx = np.array([self.fx])
301
-
302
- # if no features vary then no feature has an effect
303
- if self.M == 0:
304
- phi = np.zeros((self.data.groups_size, self.D))
305
- phi_var = np.zeros((self.data.groups_size, self.D))
306
-
307
- # if only one feature varies then it has all the effect
308
- elif self.M == 1:
309
- phi = np.zeros((self.data.groups_size, self.D))
310
- phi_var = np.zeros((self.data.groups_size, self.D))
311
- diff = self.link.f(self.fx) - self.link.f(self.fnull)
312
- for d in range(self.D):
313
- phi[self.varyingInds[0],d] = diff[d]
314
-
315
- # if more than one feature varies then we have to do real work
316
- else:
317
- self.l1_reg = kwargs.get("l1_reg", "auto")
318
-
319
- # pick a reasonable number of samples if the user didn't specify how many they wanted
320
- self.nsamples = kwargs.get("nsamples", "auto")
321
- if self.nsamples == "auto":
322
- self.nsamples = 2 * self.M + 2**11
323
-
324
- # if we have enough samples to enumerate all subsets then ignore the unneeded samples
325
- self.max_samples = 2 ** 30
326
- if self.M <= 30:
327
- self.max_samples = 2 ** self.M - 2
328
- if self.nsamples > self.max_samples:
329
- self.nsamples = self.max_samples
330
-
331
- # reserve space for some of our computations
332
- self.allocate()
333
-
334
- # weight the different subset sizes
335
- num_subset_sizes = int(np.ceil((self.M - 1) / 2.0))
336
- num_paired_subset_sizes = int(np.floor((self.M - 1) / 2.0))
337
- weight_vector = np.array([(self.M - 1.0) / (i * (self.M - i)) for i in range(1, num_subset_sizes + 1)])
338
- weight_vector[:num_paired_subset_sizes] *= 2
339
- weight_vector /= np.sum(weight_vector)
340
- log.debug(f"{weight_vector = }")
341
- log.debug(f"{num_subset_sizes = }")
342
- log.debug(f"{num_paired_subset_sizes = }")
343
- log.debug(f"{self.M = }")
344
-
345
- # fill out all the subset sizes we can completely enumerate
346
- # given nsamples*remaining_weight_vector[subset_size]
347
- num_full_subsets = 0
348
- num_samples_left = self.nsamples
349
- group_inds = np.arange(self.M, dtype='int64')
350
- mask = np.zeros(self.M)
351
- remaining_weight_vector = copy.copy(weight_vector)
352
- for subset_size in range(1, num_subset_sizes + 1):
353
-
354
- # determine how many subsets (and their complements) are of the current size
355
- nsubsets = binom(self.M, subset_size)
356
- if subset_size <= num_paired_subset_sizes:
357
- nsubsets *= 2
358
- log.debug(f"{subset_size = }")
359
- log.debug(f"{nsubsets = }")
360
- log.debug("self.nsamples*weight_vector[subset_size-1] = {}".format(
361
- num_samples_left * remaining_weight_vector[subset_size - 1]))
362
- log.debug("self.nsamples*weight_vector[subset_size-1]/nsubsets = {}".format(
363
- num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets))
364
-
365
- # see if we have enough samples to enumerate all subsets of this size
366
- if num_samples_left * remaining_weight_vector[subset_size - 1] / nsubsets >= 1.0 - 1e-8:
367
- num_full_subsets += 1
368
- num_samples_left -= nsubsets
369
-
370
- # rescale what's left of the remaining weight vector to sum to 1
371
- if remaining_weight_vector[subset_size - 1] < 1.0:
372
- remaining_weight_vector /= (1 - remaining_weight_vector[subset_size - 1])
373
-
374
- # add all the samples of the current subset size
375
- w = weight_vector[subset_size - 1] / binom(self.M, subset_size)
376
- if subset_size <= num_paired_subset_sizes:
377
- w /= 2.0
378
- for inds in itertools.combinations(group_inds, subset_size):
379
- mask[:] = 0.0
380
- mask[np.array(inds, dtype='int64')] = 1.0
381
- self.addsample(instance.x, mask, w)
382
- if subset_size <= num_paired_subset_sizes:
383
- mask[:] = np.abs(mask - 1)
384
- self.addsample(instance.x, mask, w)
385
- else:
386
- break
387
- log.info(f"{num_full_subsets = }")
388
-
389
- # add random samples from what is left of the subset space
390
- nfixed_samples = self.nsamplesAdded
391
- samples_left = self.nsamples - self.nsamplesAdded
392
- log.debug(f"{samples_left = }")
393
- if num_full_subsets != num_subset_sizes:
394
- remaining_weight_vector = copy.copy(weight_vector)
395
- remaining_weight_vector[:num_paired_subset_sizes] /= 2 # because we draw two samples each below
396
- remaining_weight_vector = remaining_weight_vector[num_full_subsets:]
397
- remaining_weight_vector /= np.sum(remaining_weight_vector)
398
- log.info(f"{remaining_weight_vector = }")
399
- log.info(f"{num_paired_subset_sizes = }")
400
- ind_set = np.random.choice(len(remaining_weight_vector), 4 * samples_left, p=remaining_weight_vector)
401
- ind_set_pos = 0
402
- used_masks = {}
403
- while samples_left > 0 and ind_set_pos < len(ind_set):
404
- mask.fill(0.0)
405
- ind = ind_set[ind_set_pos] # we call np.random.choice once to save time and then just read it here
406
- ind_set_pos += 1
407
- subset_size = ind + num_full_subsets + 1
408
- mask[np.random.permutation(self.M)[:subset_size]] = 1.0
409
-
410
- # only add the sample if we have not seen it before, otherwise just
411
- # increment a previous sample's weight
412
- mask_tuple = tuple(mask)
413
- new_sample = False
414
- if mask_tuple not in used_masks:
415
- new_sample = True
416
- used_masks[mask_tuple] = self.nsamplesAdded
417
- samples_left -= 1
418
- self.addsample(instance.x, mask, 1.0)
419
- else:
420
- self.kernelWeights[used_masks[mask_tuple]] += 1.0
421
-
422
- # add the complement sample
423
- if samples_left > 0 and subset_size <= num_paired_subset_sizes:
424
- mask[:] = np.abs(mask - 1)
425
-
426
- # only add the sample if we have not seen it before, otherwise just
427
- # increment a previous sample's weight
428
- if new_sample:
429
- samples_left -= 1
430
- self.addsample(instance.x, mask, 1.0)
431
- else:
432
- # we know the complement sample is the next one after the original sample, so + 1
433
- self.kernelWeights[used_masks[mask_tuple] + 1] += 1.0
434
-
435
- # normalize the kernel weights for the random samples to equal the weight left after
436
- # the fixed enumerated samples have been already counted
437
- weight_left = np.sum(weight_vector[num_full_subsets:])
438
- log.info(f"{weight_left = }")
439
- self.kernelWeights[nfixed_samples:] *= weight_left / self.kernelWeights[nfixed_samples:].sum()
440
-
441
- # execute the model on the synthetic samples we have created
442
- self.run()
443
-
444
- # solve then expand the feature importance (Shapley value) vector to contain the non-varying features
445
- phi = np.zeros((self.data.groups_size, self.D))
446
- phi_var = np.zeros((self.data.groups_size, self.D))
447
- for d in range(self.D):
448
- vphi, vphi_var = self.solve(self.nsamples / self.max_samples, d)
449
- phi[self.varyingInds, d] = vphi
450
- phi_var[self.varyingInds, d] = vphi_var
451
-
452
- if not self.vector_out:
453
- phi = np.squeeze(phi, axis=1)
454
- phi_var = np.squeeze(phi_var, axis=1)
455
-
456
- return phi
457
-
458
- @staticmethod
459
- def not_equal(i, j):
460
- number_types = (int, float, np.number)
461
- if isinstance(i, number_types) and isinstance(j, number_types):
462
- return 0 if np.isclose(i, j, equal_nan=True) else 1
463
- else:
464
- return 0 if i == j else 1
465
-
466
- def varying_groups(self, x):
467
- if not scipy.sparse.issparse(x):
468
- varying = np.zeros(self.data.groups_size)
469
- for i in range(0, self.data.groups_size):
470
- inds = self.data.groups[i]
471
- x_group = x[0, inds]
472
- if scipy.sparse.issparse(x_group):
473
- if all(j not in x.nonzero()[1] for j in inds):
474
- varying[i] = False
475
- continue
476
- x_group = x_group.todense()
477
- num_mismatches = np.sum(np.frompyfunc(self.not_equal, 2, 1)(x_group, self.data.data[:, inds]))
478
- varying[i] = num_mismatches > 0
479
- varying_indices = np.nonzero(varying)[0]
480
- return varying_indices
481
- else:
482
- varying_indices = []
483
- # go over all nonzero columns in background and evaluation data
484
- # if both background and evaluation are zero, the column does not vary
485
- varying_indices = np.unique(np.union1d(self.data.data.nonzero()[1], x.nonzero()[1]))
486
- remove_unvarying_indices = []
487
- for i in range(0, len(varying_indices)):
488
- varying_index = varying_indices[i]
489
- # now verify the nonzero values do vary
490
- data_rows = self.data.data[:, [varying_index]]
491
- nonzero_rows = data_rows.nonzero()[0]
492
-
493
- if nonzero_rows.size > 0:
494
- background_data_rows = data_rows[nonzero_rows]
495
- if scipy.sparse.issparse(background_data_rows):
496
- background_data_rows = background_data_rows.toarray()
497
- num_mismatches = np.sum(np.abs(background_data_rows - x[0, varying_index]) > 1e-7)
498
- # Note: If feature column non-zero but some background zero, can't remove index
499
- if num_mismatches == 0 and not \
500
- (np.abs(x[0, [varying_index]][0, 0]) > 1e-7 and len(nonzero_rows) < data_rows.shape[0]):
501
- remove_unvarying_indices.append(i)
502
- mask = np.ones(len(varying_indices), dtype=bool)
503
- mask[remove_unvarying_indices] = False
504
- varying_indices = varying_indices[mask]
505
- return varying_indices
506
-
507
- def allocate(self):
508
- if scipy.sparse.issparse(self.data.data):
509
- # We tile the sparse matrix in csr format but convert it to lil
510
- # for performance when adding samples
511
- shape = self.data.data.shape
512
- nnz = self.data.data.nnz
513
- data_rows, data_cols = shape
514
- rows = data_rows * self.nsamples
515
- shape = rows, data_cols
516
- if nnz == 0:
517
- self.synth_data = scipy.sparse.csr_matrix(shape, dtype=self.data.data.dtype).tolil()
518
- else:
519
- data = self.data.data.data
520
- indices = self.data.data.indices
521
- indptr = self.data.data.indptr
522
- last_indptr_idx = indptr[len(indptr) - 1]
523
- indptr_wo_last = indptr[:-1]
524
- new_indptrs = []
525
- for i in range(0, self.nsamples - 1):
526
- new_indptrs.append(indptr_wo_last + (i * last_indptr_idx))
527
- new_indptrs.append(indptr + ((self.nsamples - 1) * last_indptr_idx))
528
- new_indptr = np.concatenate(new_indptrs)
529
- new_data = np.tile(data, self.nsamples)
530
- new_indices = np.tile(indices, self.nsamples)
531
- self.synth_data = scipy.sparse.csr_matrix((new_data, new_indices, new_indptr), shape=shape).tolil()
532
- else:
533
- self.synth_data = np.tile(self.data.data, (self.nsamples, 1))
534
-
535
- self.maskMatrix = np.zeros((self.nsamples, self.M))
536
- self.kernelWeights = np.zeros(self.nsamples)
537
- self.y = np.zeros((self.nsamples * self.N, self.D))
538
- self.ey = np.zeros((self.nsamples, self.D))
539
- self.lastMask = np.zeros(self.nsamples)
540
- self.nsamplesAdded = 0
541
- self.nsamplesRun = 0
542
- if self.keep_index:
543
- self.synth_data_index = np.tile(self.data.index_value, self.nsamples)
544
-
545
- def addsample(self, x, m, w):
546
- offset = self.nsamplesAdded * self.N
547
- if isinstance(self.varyingFeatureGroups, (list,)):
548
- for j in range(self.M):
549
- for k in self.varyingFeatureGroups[j]:
550
- if m[j] == 1.0:
551
- self.synth_data[offset:offset+self.N, k] = x[0, k]
552
- else:
553
- # for non-jagged numpy array we can significantly boost performance
554
- mask = m == 1.0
555
- groups = self.varyingFeatureGroups[mask]
556
- if len(groups.shape) == 2:
557
- for group in groups:
558
- self.synth_data[offset:offset+self.N, group] = x[0, group]
559
- else:
560
- # further performance optimization in case each group has a single feature
561
- evaluation_data = x[0, groups]
562
- # In edge case where background is all dense but evaluation data
563
- # is all sparse, make evaluation data dense
564
- if scipy.sparse.issparse(x) and not scipy.sparse.issparse(self.synth_data):
565
- evaluation_data = evaluation_data.toarray()
566
- self.synth_data[offset:offset+self.N, groups] = evaluation_data
567
- self.maskMatrix[self.nsamplesAdded, :] = m
568
- self.kernelWeights[self.nsamplesAdded] = w
569
- self.nsamplesAdded += 1
570
-
571
- def run(self):
572
- num_to_run = self.nsamplesAdded * self.N - self.nsamplesRun * self.N
573
- data = self.synth_data[self.nsamplesRun*self.N:self.nsamplesAdded*self.N,:]
574
- if self.keep_index:
575
- index = self.synth_data_index[self.nsamplesRun*self.N:self.nsamplesAdded*self.N]
576
- index = pd.DataFrame(index, columns=[self.data.index_name])
577
- data = pd.DataFrame(data, columns=self.data.group_names)
578
- data = pd.concat([index, data], axis=1).set_index(self.data.index_name)
579
- if self.keep_index_ordered:
580
- data = data.sort_index()
581
- modelOut = self.model.f(data)
582
- if isinstance(modelOut, (pd.DataFrame, pd.Series)):
583
- modelOut = modelOut.values
584
- self.y[self.nsamplesRun * self.N:self.nsamplesAdded * self.N, :] = np.reshape(modelOut, (num_to_run, self.D))
585
-
586
- # find the expected value of each output
587
- for i in range(self.nsamplesRun, self.nsamplesAdded):
588
- eyVal = np.zeros(self.D)
589
- for j in range(0, self.N):
590
- eyVal += self.y[i * self.N + j, :] * self.data.weights[j]
591
-
592
- self.ey[i, :] = eyVal
593
- self.nsamplesRun += 1
594
-
595
- def solve(self, fraction_evaluated, dim):
596
- eyAdj = self.linkfv(self.ey[:, dim]) - self.link.f(self.fnull[dim])
597
- s = np.sum(self.maskMatrix, 1)
598
-
599
- # do feature selection if we have not well enumerated the space
600
- nonzero_inds = np.arange(self.M)
601
- log.debug(f"{fraction_evaluated = }")
602
- # if self.l1_reg == "auto":
603
- # warnings.warn(
604
- # "l1_reg=\"auto\" is deprecated and in the next version (v0.29) the behavior will change from a " \
605
- # "conditional use of AIC to simply \"num_features(10)\"!"
606
- # )
607
- if (self.l1_reg not in ["auto", False, 0]) or (fraction_evaluated < 0.2 and self.l1_reg == "auto"):
608
- w_aug = np.hstack((self.kernelWeights * (self.M - s), self.kernelWeights * s))
609
- log.info(f"{np.sum(w_aug) = }")
610
- log.info(f"{np.sum(self.kernelWeights) = }")
611
- w_sqrt_aug = np.sqrt(w_aug)
612
- eyAdj_aug = np.hstack((eyAdj, eyAdj - (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))))
613
- eyAdj_aug *= w_sqrt_aug
614
- mask_aug = np.transpose(w_sqrt_aug * np.transpose(np.vstack((self.maskMatrix, self.maskMatrix - 1))))
615
- #var_norms = np.array([np.linalg.norm(mask_aug[:, i]) for i in range(mask_aug.shape[1])])
616
-
617
- # select a fixed number of top features
618
- if isinstance(self.l1_reg, str) and self.l1_reg.startswith("num_features("):
619
- r = int(self.l1_reg[len("num_features("):-1])
620
- nonzero_inds = lars_path(mask_aug, eyAdj_aug, max_iter=r)[1]
621
-
622
- # use an adaptive regularization method
623
- elif self.l1_reg == "auto" or self.l1_reg == "bic" or self.l1_reg == "aic":
624
- c = "aic" if self.l1_reg == "auto" else self.l1_reg
625
-
626
- # "Normalize" parameter of LassoLarsIC was deprecated in sklearn version 1.2
627
- if version.parse(sklearn.__version__) < version.parse("1.2.0"):
628
- kwg = dict(normalize=False)
629
- else:
630
- kwg = {}
631
- model = make_pipeline(StandardScaler(with_mean=False), LassoLarsIC(criterion=c, **kwg))
632
- nonzero_inds = np.nonzero(model.fit(mask_aug, eyAdj_aug)[1].coef_)[0]
633
-
634
- # use a fixed regularization coefficient
635
- else:
636
- nonzero_inds = np.nonzero(Lasso(alpha=self.l1_reg).fit(mask_aug, eyAdj_aug).coef_)[0]
637
-
638
- if len(nonzero_inds) == 0:
639
- return np.zeros(self.M), np.ones(self.M)
640
-
641
- # eliminate one variable with the constraint that all features sum to the output
642
- eyAdj2 = eyAdj - self.maskMatrix[:, nonzero_inds[-1]] * (
643
- self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim]))
644
- etmp = np.transpose(np.transpose(self.maskMatrix[:, nonzero_inds[:-1]]) - self.maskMatrix[:, nonzero_inds[-1]])
645
- log.debug(f"{etmp[:4, :] = }")
646
-
647
- # solve a weighted least squares equation to estimate phi
648
- # least squares:
649
- # phi = min_w ||W^(1/2) (y - X w)||^2
650
- # the corresponding normal equation:
651
- # (X' W X) phi = X' W y
652
- # with
653
- # X = etmp
654
- # W = np.diag(self.kernelWeights)
655
- # y = eyAdj2
656
- #
657
- # We could just rely on scikit-learn
658
- # from sklearn.linear_model import LinearRegression
659
- # lm = LinearRegression(fit_intercept=False).fit(etmp, eyAdj2, sample_weight=self.kernelWeights)
660
- # Under the hood, as of scikit-learn version 1.3, LinearRegression still uses np.linalg.lstsq and
661
- # there are more performant options. See https://github.com/scikit-learn/scikit-learn/issues/22855.
662
- y = eyAdj2
663
- X = etmp
664
- WX = self.kernelWeights[:, None] * X
665
- try:
666
- w = np.linalg.solve(X.T @ WX, WX.T @ y)
667
- except np.linalg.LinAlgError:
668
- warnings.warn(
669
- "Linear regression equation is singular, a least squares solutions is used instead.\n"
670
- "To avoid this situation and get a regular matrix do one of the following:\n"
671
- "1) turn up the number of samples,\n"
672
- "2) turn up the L1 regularization with num_features(N) where N is less than the number of samples,\n"
673
- "3) group features together to reduce the number of inputs that need to be explained."
674
- )
675
- # XWX = np.linalg.pinv(X.T @ WX)
676
- # w = np.dot(XWX, np.dot(np.transpose(WX), y))
677
- sqrt_W = np.sqrt(self.kernelWeights)
678
- w = np.linalg.lstsq(sqrt_W[:, None] * X, sqrt_W * y, rcond=None)[0]
679
- log.debug(f"{np.sum(w) = }")
680
- log.debug("self.link(self.fx) - self.link(self.fnull) = {}".format(
681
- self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim])))
682
- log.debug(f"self.fx = {self.fx[dim]}")
683
- log.debug(f"self.link(self.fx) = {self.link.f(self.fx[dim])}")
684
- log.debug(f"self.fnull = {self.fnull[dim]}")
685
- log.debug(f"self.link(self.fnull) = {self.link.f(self.fnull[dim])}")
686
- phi = np.zeros(self.M)
687
- phi[nonzero_inds[:-1]] = w
688
- phi[nonzero_inds[-1]] = (self.link.f(self.fx[dim]) - self.link.f(self.fnull[dim])) - sum(w)
689
- log.info(f"{phi = }")
690
-
691
- # clean up any rounding errors
692
- for i in range(self.M):
693
- if np.abs(phi[i]) < 1e-10:
694
- phi[i] = 0
695
-
696
- return phi, np.ones(len(phi))
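
A minimal usage sketch of the KernelExplainer deleted above, tying together the docstring recommendations: a k-means background summary, the "auto" sampling budget, and num_features(...) l1 regularization. The scikit-learn model and synthetic data are assumptions for illustration only.

import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier

# synthetic data: 200 samples, 8 features, binary labels (illustrative)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 8))
y = (X[:, 0] + X[:, 1] > 0).astype(int)
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# summarize the background as 10 weighted samples to keep run times reasonable
background = shap.kmeans(X, 10)

explainer = shap.KernelExplainer(model.predict_proba, background)
# "auto" uses roughly 2 * n_features + 2048 model re-evaluations per explained instance;
# l1_reg="num_features(5)" keeps only the top 5 features in each explanation
shap_values = explainer.shap_values(X[:5], nsamples="auto", l1_reg="num_features(5)")

Because predict_proba has two outputs, shap_values is a list of two (samples x features) arrays, and each row of shap_values[k] sums to the class-k probability minus explainer.expected_value[k].
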
lib/shap/explainers/_linear.py DELETED
@@ -1,406 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
- import pandas as pd
5
- from scipy.sparse import issparse
6
- from tqdm.auto import tqdm
7
-
8
- from .. import links, maskers
9
- from ..utils._exceptions import (
10
- DimensionError,
11
- InvalidFeaturePerturbationError,
12
- InvalidModelError,
13
- )
14
- from ._explainer import Explainer
15
-
16
-
17
- class LinearExplainer(Explainer):
18
- """ Computes SHAP values for a linear model, optionally accounting for inter-feature correlations.
19
-
20
- This computes the SHAP values for a linear model and can account for the correlations among
21
- the input features. Assuming features are independent leads to interventional SHAP values which
22
- for a linear model are coef[i] * (x[i] - X.mean(0)[i]) for the ith feature. If instead we account
23
- for correlations then we prevent any problems arising from collinearity and share credit among
24
- correlated features. Accounting for correlations can be computationally challenging, but
25
- LinearExplainer uses sampling to estimate a transform that can then be applied to explain
26
- any prediction of the model.
27
-
28
- Parameters
29
- ----------
30
- model : (coef, intercept) or sklearn.linear_model.*
31
- User supplied linear model, given as either a parameter pair or an sklearn object.
32
-
33
- data : (mean, cov), numpy.array, pandas.DataFrame, iml.DenseData or scipy.csr_matrix
34
- The background dataset to use for computing conditional expectations. Note that only the
35
- mean and covariance of the dataset are used. This means passing a raw data matrix is just
36
- a convenient alternative to passing the mean and covariance directly.
37
- nsamples : int
38
- Number of samples to use when estimating the transformation matrix used to account for
39
- feature correlations.
40
- feature_perturbation : "interventional" (default) or "correlation_dependent"
41
- There are two ways we might want to compute SHAP values, either the full conditional SHAP
42
- values or the interventional SHAP values. For interventional SHAP values we break any
43
- dependence structure between features in the model and so uncover how the model would behave if we
44
- intervened and changed some of the inputs. For the full conditional SHAP values we respect
45
- the correlations among the input features, so if the model depends on one input but that
46
- input is correlated with another input, then both get some credit for the model's behavior. The
47
- interventional option stays "true to the model" meaning it will only give credit to features that are
48
- actually used by the model, while the correlation option stays "true to the data" in the sense that
49
- it only considers how the model would behave when respecting the correlations in the input data.
50
- For the sparse case, only the interventional option is supported.
51
-
52
- Examples
53
- --------
54
- See `Linear explainer examples <https://shap.readthedocs.io/en/latest/api_examples/explainers/LinearExplainer.html>`_
55
- """
56
-
57
- def __init__(self, model, masker, link=links.identity, nsamples=1000, feature_perturbation=None, **kwargs):
58
- if 'feature_dependence' in kwargs:
59
- warnings.warn('The option feature_dependence has been renamed to feature_perturbation!')
60
- feature_perturbation = kwargs["feature_dependence"]
61
- if feature_perturbation == "independent":
62
- warnings.warn('The option feature_perturbation="independent" is has been renamed to feature_perturbation="interventional"!')
63
- feature_perturbation = "interventional"
64
- elif feature_perturbation == "correlation":
65
- warnings.warn('The option feature_perturbation="correlation" is has been renamed to feature_perturbation="correlation_dependent"!')
66
- feature_perturbation = "correlation_dependent"
67
- if feature_perturbation is not None:
68
- warnings.warn("The feature_perturbation option is now deprecated in favor of using the appropriate masker (maskers.Independent, or maskers.Impute)")
69
- else:
70
- feature_perturbation = "interventional"
71
- self.feature_perturbation = feature_perturbation
72
-
73
- # wrap the incoming masker object as a shap.Masker object before calling
74
- # parent class constructor, which does the same but without respecting
75
- # the user-provided feature_perturbation choice
76
- if isinstance(masker, pd.DataFrame) or ((isinstance(masker, np.ndarray) or issparse(masker)) and len(masker.shape) == 2):
77
- if self.feature_perturbation == "correlation_dependent":
78
- masker = maskers.Impute(masker)
79
- else:
80
- masker = maskers.Independent(masker)
81
- elif issubclass(type(masker), tuple) and len(masker) == 2:
82
- if self.feature_perturbation == "correlation_dependent":
83
- masker = maskers.Impute({"mean": masker[0], "cov": masker[1]}, method="linear")
84
- else:
85
- masker = maskers.Independent({"mean": masker[0], "cov": masker[1]})
86
-
87
- super().__init__(model, masker, link=link, **kwargs)
88
-
89
- self.nsamples = nsamples
90
-
91
-
92
- # extract what we need from the given model object
93
- self.coef, self.intercept = LinearExplainer._parse_model(model)
94
-
95
- # extract the data
96
- if issubclass(type(self.masker), (maskers.Independent, maskers.Partition)):
97
- self.feature_perturbation = "interventional"
98
- elif issubclass(type(self.masker), maskers.Impute):
99
- self.feature_perturbation = "correlation_dependent"
100
- else:
101
- raise NotImplementedError("The Linear explainer only supports the Independent, Partition, and Impute maskers right now!")
102
- data = getattr(self.masker, "data", None)
103
-
104
- # convert DataFrame's to numpy arrays
105
- if isinstance(data, pd.DataFrame):
106
- data = data.values
107
-
108
- # get the mean and covariance of the model
109
- if getattr(self.masker, "mean", None) is not None:
110
- self.mean = self.masker.mean
111
- self.cov = self.masker.cov
112
- elif isinstance(data, dict) and len(data) == 2:
113
- self.mean = data["mean"]
114
- if isinstance(self.mean, pd.Series):
115
- self.mean = self.mean.values
116
-
117
- self.cov = data["cov"]
118
- if isinstance(self.cov, pd.DataFrame):
119
- self.cov = self.cov.values
120
- elif isinstance(data, tuple) and len(data) == 2:
121
- self.mean = data[0]
122
- if isinstance(self.mean, pd.Series):
123
- self.mean = self.mean.values
124
-
125
- self.cov = data[1]
126
- if isinstance(self.cov, pd.DataFrame):
127
- self.cov = self.cov.values
128
- elif data is None:
129
- raise ValueError("A background data distribution must be provided!")
130
- else:
131
- if issparse(data):
132
- self.mean = np.array(np.mean(data, 0))[0]
133
- if self.feature_perturbation != "interventional":
134
- raise NotImplementedError("Only feature_perturbation = 'interventional' is supported for sparse data")
135
- else:
136
- self.mean = np.array(np.mean(data, 0)).flatten() # assumes it is an array
137
- if self.feature_perturbation == "correlation_dependent":
138
- self.cov = np.cov(data, rowvar=False)
139
- #print(self.coef, self.mean.flatten(), self.intercept)
140
- # Note: mean can be numpy.matrixlib.defmatrix.matrix or numpy.matrix type depending on numpy version
141
- if issparse(self.mean) or str(type(self.mean)).endswith("matrix'>"):
142
- # accept both sparse and dense coef
143
- # if not issparse(self.coef):
144
- # self.coef = np.asmatrix(self.coef)
145
- self.expected_value = np.dot(self.coef, self.mean) + self.intercept
146
-
147
- # unwrap the matrix form
148
- if len(self.expected_value) == 1:
149
- self.expected_value = self.expected_value[0,0]
150
- else:
151
- self.expected_value = np.array(self.expected_value)[0]
152
- else:
153
- self.expected_value = np.dot(self.coef, self.mean) + self.intercept
154
-
155
- self.M = len(self.mean)
156
-
157
- # if needed, estimate the transform matrices
158
- if self.feature_perturbation == "correlation_dependent":
159
- self.valid_inds = np.where(np.diag(self.cov) > 1e-8)[0]
160
- self.mean = self.mean[self.valid_inds]
161
- self.cov = self.cov[:,self.valid_inds][self.valid_inds,:]
162
- self.coef = self.coef[self.valid_inds]
163
-
164
- # group perfectly redundant variables together
165
- self.avg_proj,sum_proj = duplicate_components(self.cov)
166
- self.cov = np.matmul(np.matmul(self.avg_proj, self.cov), self.avg_proj.T)
167
- self.mean = np.matmul(self.avg_proj, self.mean)
168
- self.coef = np.matmul(sum_proj, self.coef)
169
-
170
- # if we still have some multi-collinearity present then we just add regularization...
171
- e,_ = np.linalg.eig(self.cov)
172
- if e.min() < 1e-7:
173
- self.cov = self.cov + np.eye(self.cov.shape[0]) * 1e-6
174
-
175
- mean_transform, x_transform = self._estimate_transforms(nsamples)
176
- self.mean_transformed = np.matmul(mean_transform, self.mean)
177
- self.x_transform = x_transform
178
- elif self.feature_perturbation == "interventional":
179
- if nsamples != 1000:
180
- warnings.warn("Setting nsamples has no effect when feature_perturbation = 'interventional'!")
181
- else:
182
- raise InvalidFeaturePerturbationError("Unknown type of feature_perturbation provided: " + self.feature_perturbation)
183
-
184
- def _estimate_transforms(self, nsamples):
185
- """ Uses block matrix inversion identities to quickly estimate transforms.
186
-
187
- After a bit of matrix math we can isolate a transform matrix (# features x # features)
188
- that is independent of any sample we are explaining. It is the result of averaging over
189
- all feature permutations, but we just use a fixed number of samples to estimate the value.
190
-
191
- TODO: Do a brute force enumeration when # feature subsets is less than nsamples. This could
192
- happen through a recursive method that uses the same block matrix inversion as below.
193
- """
194
- M = len(self.coef)
195
-
196
- mean_transform = np.zeros((M,M))
197
- x_transform = np.zeros((M,M))
198
- inds = np.arange(M, dtype=int)
199
- for _ in tqdm(range(nsamples), "Estimating transforms"):
200
- np.random.shuffle(inds)
201
- cov_inv_SiSi = np.zeros((0,0))
202
- cov_Si = np.zeros((M,0))
203
- for j in range(M):
204
- i = inds[j]
205
-
206
- # use the last Si as the new S
207
- cov_S = cov_Si
208
- cov_inv_SS = cov_inv_SiSi
209
-
210
- # get the new cov_Si
211
- cov_Si = self.cov[:,inds[:j+1]]
212
-
213
- # compute the new cov_inv_SiSi from cov_inv_SS
214
- d = cov_Si[i,:-1].T
215
- t = np.matmul(cov_inv_SS, d)
216
- Z = self.cov[i, i]
217
- u = Z - np.matmul(t.T, d)
218
- cov_inv_SiSi = np.zeros((j+1, j+1))
219
- if j > 0:
220
- cov_inv_SiSi[:-1, :-1] = cov_inv_SS + np.outer(t, t) / u
221
- cov_inv_SiSi[:-1, -1] = cov_inv_SiSi[-1,:-1] = -t / u
222
- cov_inv_SiSi[-1, -1] = 1 / u
223
-
224
- # + coef @ (Q(bar(Sui)) - Q(bar(S)))
225
- mean_transform[i, i] += self.coef[i]
226
-
227
- # + coef @ R(Sui)
228
- coef_R_Si = np.matmul(self.coef[inds[j+1:]], np.matmul(cov_Si, cov_inv_SiSi)[inds[j+1:]])
229
- mean_transform[i, inds[:j+1]] += coef_R_Si
230
-
231
- # - coef @ R(S)
232
- coef_R_S = np.matmul(self.coef[inds[j:]], np.matmul(cov_S, cov_inv_SS)[inds[j:]])
233
- mean_transform[i, inds[:j]] -= coef_R_S
234
-
235
- # - coef @ (Q(Sui) - Q(S))
236
- x_transform[i, i] += self.coef[i]
237
-
238
- # + coef @ R(Sui)
239
- x_transform[i, inds[:j+1]] += coef_R_Si
240
-
241
- # - coef @ R(S)
242
- x_transform[i, inds[:j]] -= coef_R_S
243
-
244
- mean_transform /= nsamples
245
- x_transform /= nsamples
246
- return mean_transform, x_transform
247
-
248
- @staticmethod
249
- def _parse_model(model):
250
- """ Attempt to pull out the coefficients and intercept from the given model object.
251
- """
252
- # raw coefficients
253
- if type(model) == tuple and len(model) == 2:
254
- coef = model[0]
255
- intercept = model[1]
256
-
257
- # sklearn style model
258
- elif hasattr(model, "coef_") and hasattr(model, "intercept_"):
259
- # work around for multi-class with a single class
260
- if len(model.coef_.shape) > 1 and model.coef_.shape[0] == 1:
261
- coef = model.coef_[0]
262
- try:
263
- intercept = model.intercept_[0]
264
- except TypeError:
265
- intercept = model.intercept_
266
- else:
267
- coef = model.coef_
268
- intercept = model.intercept_
269
- else:
270
- raise InvalidModelError("An unknown model type was passed: " + str(type(model)))
271
-
272
- return coef,intercept
273
-
274
- @staticmethod
275
- def supports_model_with_masker(model, masker):
276
- """ Determines if we can parse the given model.
277
- """
278
-
279
- if not isinstance(masker, (maskers.Independent, maskers.Partition, maskers.Impute)):
280
- return False
281
-
282
- try:
283
- LinearExplainer._parse_model(model)
284
- except Exception:
285
- return False
286
- return True
287
-
288
- def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
289
- """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
290
- """
291
-
292
- assert len(row_args) == 1, "Only single-argument functions are supported by the Linear explainer!"
293
-
294
- X = row_args[0]
295
- if len(X.shape) == 1:
296
- X = X.reshape(1, -1)
297
-
298
- # convert dataframes
299
- if isinstance(X, (pd.Series, pd.DataFrame)):
300
- X = X.values
301
-
302
- if len(X.shape) not in (1, 2):
303
- raise DimensionError("Instance must have 1 or 2 dimensions! Not: %s" %len(X.shape))
304
-
305
- if self.feature_perturbation == "correlation_dependent":
306
- if issparse(X):
307
- raise InvalidFeaturePerturbationError("Only feature_perturbation = 'interventional' is supported for sparse data")
308
- phi = np.matmul(np.matmul(X[:,self.valid_inds], self.avg_proj.T), self.x_transform.T) - self.mean_transformed
309
- phi = np.matmul(phi, self.avg_proj)
310
-
311
- full_phi = np.zeros((phi.shape[0], self.M))
312
- full_phi[:,self.valid_inds] = phi
313
- phi = full_phi
314
-
315
- elif self.feature_perturbation == "interventional":
316
- if issparse(X):
317
- phi = np.array(np.multiply(X - self.mean, self.coef))
318
-
319
- # if len(self.coef.shape) == 1:
320
- # return np.array(np.multiply(X - self.mean, self.coef))
321
- # else:
322
- # return [np.array(np.multiply(X - self.mean, self.coef[i])) for i in range(self.coef.shape[0])]
323
- else:
324
- phi = np.array(X - self.mean) * self.coef
325
- # if len(self.coef.shape) == 1:
326
- # phi = np.array(X - self.mean) * self.coef
327
- # return np.array(X - self.mean) * self.coef
328
- # else:
329
- # return [np.array(X - self.mean) * self.coef[i] for i in range(self.coef.shape[0])]
330
-
331
- return {
332
- "values": phi.T,
333
- "expected_values": self.expected_value,
334
- "mask_shapes": (X.shape[1:],),
335
- "main_effects": phi.T,
336
- "clustering": None
337
- }
338
-
339
-
340
- def shap_values(self, X):
341
- """ Estimate the SHAP values for a set of samples.
342
-
343
- Parameters
344
- ----------
345
- X : numpy.array, pandas.DataFrame or scipy.csr_matrix
346
- A matrix of samples (# samples x # features) on which to explain the model's output.
347
-
348
- Returns
349
- -------
350
- array or list
351
- For models with a single output this returns a matrix of SHAP values
352
- (# samples x # features). Each row sums to the difference between the model output for that
353
- sample and the expected value of the model output (which is stored as expected_value
354
- attribute of the explainer).
355
- """
356
-
357
- # convert dataframes
358
- if isinstance(X, (pd.Series, pd.DataFrame)):
359
- X = X.values
360
-
361
- # assert isinstance(X, np.ndarray), "Unknown instance type: " + str(type(X))
362
- if len(X.shape) not in (1, 2):
363
- raise DimensionError("Instance must have 1 or 2 dimensions! Not: %s" % len(X.shape))
364
-
365
- if self.feature_perturbation == "correlation_dependent":
366
- if issparse(X):
367
- raise InvalidFeaturePerturbationError("Only feature_perturbation = 'interventional' is supported for sparse data")
368
- phi = np.matmul(np.matmul(X[:,self.valid_inds], self.avg_proj.T), self.x_transform.T) - self.mean_transformed
369
- phi = np.matmul(phi, self.avg_proj)
370
-
371
- full_phi = np.zeros((phi.shape[0], self.M))
372
- full_phi[:,self.valid_inds] = phi
373
-
374
- return full_phi
375
-
376
- elif self.feature_perturbation == "interventional":
377
- if issparse(X):
378
- if len(self.coef.shape) == 1:
379
- return np.array(np.multiply(X - self.mean, self.coef))
380
- else:
381
- return [np.array(np.multiply(X - self.mean, self.coef[i])) for i in range(self.coef.shape[0])]
382
- else:
383
- if len(self.coef.shape) == 1:
384
- return np.array(X - self.mean) * self.coef
385
- else:
386
- return [np.array(X - self.mean) * self.coef[i] for i in range(self.coef.shape[0])]
387
-
388
- def duplicate_components(C):
389
- D = np.diag(1/np.sqrt(np.diag(C)))
390
- C = np.matmul(np.matmul(D, C), D)
391
- components = -np.ones(C.shape[0], dtype=int)
392
- count = -1
393
- for i in range(C.shape[0]):
394
- found_group = False
395
- for j in range(C.shape[0]):
396
- if components[j] < 0 and np.abs(2*C[i,j] - C[i,i] - C[j,j]) < 1e-8:
397
- if not found_group:
398
- count += 1
399
- found_group = True
400
- components[j] = count
401
-
402
- proj = np.zeros((len(np.unique(components)), C.shape[0]))
403
- proj[0, 0] = 1
404
- for i in range(1,C.shape[0]):
405
- proj[components[i], i] = 1
406
- return (proj.T / proj.sum(1)).T, proj
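
A minimal usage sketch of the LinearExplainer deleted above, checking the interventional identity coef[i] * (x[i] - mean[i]) stated in its docstring. The Ridge model and synthetic data are assumptions for illustration only.

import numpy as np
import shap
from sklearn.linear_model import Ridge

# synthetic data (illustrative); 100 background rows so the Independent masker keeps them all
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + rng.normal(scale=0.1, size=100)
model = Ridge(alpha=1.0).fit(X, y)

explainer = shap.LinearExplainer(model, X)  # a 2D array is wrapped in maskers.Independent
shap_values = explainer.shap_values(X[:3])

# for the interventional case this reduces to coef * (x - background mean)
manual = (X[:3] - X.mean(0)) * model.coef_
print(np.allclose(shap_values, manual))     # expected: True
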
lib/shap/explainers/_partition.py DELETED
@@ -1,681 +0,0 @@
1
- import queue
2
- import time
3
-
4
- import numpy as np
5
- from numba import njit
6
- from tqdm.auto import tqdm
7
-
8
- from .. import Explanation, links
9
- from ..models import Model
10
- from ..utils import MaskedModel, OpChain, make_masks, safe_isinstance
11
- from ._explainer import Explainer
12
-
13
-
14
- class PartitionExplainer(Explainer):
15
- """Uses the Partition SHAP method to explain the output of any function.
16
-
17
- Partition SHAP computes Shapley values recursively through a hierarchy of features; this
18
- hierarchy defines feature coalitions and results in the Owen values from game theory.
19
-
20
- The PartitionExplainer has two particularly nice properties:
21
-
22
- 1) PartitionExplainer is model-agnostic but when using a balanced partition tree only has
23
- quadratic exact runtime (in terms of the number of input features). This is in contrast to the
24
- exponential exact runtime of KernelExplainer or SamplingExplainer.
25
- 2) PartitionExplainer always assigns to groups of correlated features the credit that set of features
26
- would have had if treated as a group. This means if the hierarchical clustering given to
27
- PartitionExplainer groups correlated features together, then feature correlations are
28
- "accounted for" in the sense that the total credit assigned to a group of tightly dependent features
29
- does not depend on how they behave if their correlation structure was broken during the explanation's
30
- perturbation process.
31
- Note that for linear models the Owen values that PartitionExplainer returns are the same as the standard
32
- non-hierarchical Shapley values.
33
- """
34
-
35
- def __init__(self, model, masker, *, output_names=None, link=links.identity, linearize_link=True,
36
- feature_names=None, **call_args):
37
- """Build a PartitionExplainer for the given model with the given masker.
38
-
39
- Parameters
40
- ----------
41
- model : function
42
- User supplied function that takes a matrix of samples (# samples x # features) and
43
- computes the output of the model for those samples.
44
-
45
- masker : function or numpy.array or pandas.DataFrame or tokenizer
46
- The function used to "mask" out hidden features of the form `masker(mask, x)`. It takes a
47
- single input sample and a binary mask and returns a matrix of masked samples. These
48
- masked samples will then be evaluated using the model function and the outputs averaged.
49
- As a shortcut for the standard masking used by SHAP you can pass a background data matrix
50
- instead of a function and that matrix will be used for masking. Domain specific masking
51
- functions are available in shap such as shap.maskers.Image for images and shap.maskers.Text
52
- for text.
53
-
54
- partition_tree : None or function or numpy.array
55
- A hierarchical clustering of the input features represented by a matrix that follows the format
56
- used by scipy.cluster.hierarchy (see the notebooks_html/partition_explainer directory for an example).
57
- If this is a function then the function produces a clustering matrix when given a single input
58
- example. If you are using a standard SHAP masker object then you can pass masker.clustering
59
- to use that masker's built-in clustering of the features, or if partition_tree is None then
60
- masker.clustering will be used by default.
61
-
62
- Examples
63
- --------
64
- See `Partition explainer examples <https://shap.readthedocs.io/en/latest/api_examples/explainers/PartitionExplainer.html>`_
65
- """
66
-
67
- super().__init__(model, masker, link=link, linearize_link=linearize_link, algorithm="partition", \
68
- output_names = output_names, feature_names=feature_names)
69
-
70
- # convert dataframes
71
- # if isinstance(masker, pd.DataFrame):
72
- # masker = TabularMasker(masker)
73
- # elif isinstance(masker, np.ndarray) and len(masker.shape) == 2:
74
- # masker = TabularMasker(masker)
75
- # elif safe_isinstance(masker, "transformers.PreTrainedTokenizer"):
76
- # masker = TextMasker(masker)
77
- # self.masker = masker
78
-
79
- # TODO: maybe? if we have a tabular masker then we build a PermutationExplainer that we
80
- # will use for sampling
81
- self.input_shape = masker.shape[1:] if hasattr(masker, "shape") and not callable(masker.shape) else None
82
- # self.output_names = output_names
83
- if not safe_isinstance(self.model, "shap.models.Model"):
84
- self.model = Model(self.model)#lambda *args: np.array(model(*args))
85
- self.expected_value = None
86
- self._curr_base_value = None
87
- if getattr(self.masker, "clustering", None) is None:
88
- raise ValueError("The passed masker must have a .clustering attribute defined! Try shap.maskers.Partition(data) for example.")
89
- # if partition_tree is None:
90
- # if not hasattr(masker, "partition_tree"):
91
- # raise ValueError("The passed masker does not have masker.clustering, so the partition_tree must be passed!")
92
- # self.partition_tree = masker.clustering
93
- # else:
94
- # self.partition_tree = partition_tree
95
-
96
- # handle higher dimensional tensor inputs
97
- if self.input_shape is not None and len(self.input_shape) > 1:
98
- self._reshaped_model = lambda x: self.model(x.reshape(x.shape[0], *self.input_shape))
99
- else:
100
- self._reshaped_model = self.model
101
-
102
- # if we don't have a dynamic clustering algorithm then we can precompute
103
- # a lot of information
104
- if not callable(self.masker.clustering):
105
- self._clustering = self.masker.clustering
106
- self._mask_matrix = make_masks(self._clustering)
107
-
108
- # if we have gotten default arguments for the call function we need to wrap ourselves in a new class that
109
- # has a call function with those new default arguments
110
- if len(call_args) > 0:
111
- class PartitionExplainer(self.__class__):
112
- # this signature should match the __call__ signature of the class defined below
113
- def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto",
114
- outputs=None, silent=False):
115
- return super().__call__(
116
- *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds,
117
- batch_size=batch_size, outputs=outputs, silent=silent
118
- )
119
- PartitionExplainer.__call__.__doc__ = self.__class__.__call__.__doc__
120
- self.__class__ = PartitionExplainer
121
- for k, v in call_args.items():
122
- self.__call__.__kwdefaults__[k] = v
123
-
124
- # note that changes to this function signature should be copied to the default call argument wrapper above
125
- def __call__(self, *args, max_evals=500, fixed_context=None, main_effects=False, error_bounds=False, batch_size="auto",
126
- outputs=None, silent=False):
127
- """ Explain the output of the model on the given arguments.
128
- """
129
- return super().__call__(
130
- *args, max_evals=max_evals, fixed_context=fixed_context, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size,
131
- outputs=outputs, silent=silent
132
- )
133
-
134
- def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent, fixed_context = "auto"):
135
- """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
136
- """
137
-
138
- if fixed_context == "auto":
139
- # if isinstance(self.masker, maskers.Text):
140
- # fixed_context = 1 # we err on the side of speed for text models
141
- # else:
142
- fixed_context = None
143
- elif fixed_context not in [0, 1, None]:
144
- raise ValueError("Unknown fixed_context value passed (must be 0, 1 or None): %s" %fixed_context)
145
-
146
- # build a masked version of the model for the current input sample
147
- fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)
148
-
149
- # make sure we have the base value and current value outputs
150
- M = len(fm)
151
- m00 = np.zeros(M, dtype=bool)
152
- # if not fixed background or no base value assigned then compute base value for a row
153
- if self._curr_base_value is None or not getattr(self.masker, "fixed_background", False):
154
- self._curr_base_value = fm(m00.reshape(1, -1), zero_index=0)[0] # the zero index param tells the masked model what the baseline is
155
- f11 = fm(~m00.reshape(1, -1))[0]
156
-
157
- if callable(self.masker.clustering):
158
- self._clustering = self.masker.clustering(*row_args)
159
- self._mask_matrix = make_masks(self._clustering)
160
-
161
- if hasattr(self._curr_base_value, 'shape') and len(self._curr_base_value.shape) > 0:
162
- if outputs is None:
163
- outputs = np.arange(len(self._curr_base_value))
164
- elif isinstance(outputs, OpChain):
165
- outputs = outputs.apply(Explanation(f11)).values
166
-
167
- out_shape = (2*self._clustering.shape[0]+1, len(outputs))
168
- else:
169
- out_shape = (2*self._clustering.shape[0]+1,)
170
-
171
- if max_evals == "auto":
172
- max_evals = 500
173
-
174
- self.values = np.zeros(out_shape)
175
- self.dvalues = np.zeros(out_shape)
176
-
177
- self.owen(fm, self._curr_base_value, f11, max_evals - 2, outputs, fixed_context, batch_size, silent)
178
-
179
- # if False:
180
- # if self.multi_output:
181
- # return [self.dvalues[:,i] for i in range(self.dvalues.shape[1])], oinds
182
- # else:
183
- # return self.dvalues.copy(), oinds
184
- # else:
185
- # drop the interaction terms down onto self.values
186
- self.values[:] = self.dvalues
187
-
188
- lower_credit(len(self.dvalues) - 1, 0, M, self.values, self._clustering)
189
-
190
- return {
191
- "values": self.values[:M].copy(),
192
- "expected_values": self._curr_base_value if outputs is None else self._curr_base_value[outputs],
193
- "mask_shapes": [s + out_shape[1:] for s in fm.mask_shapes],
194
- "main_effects": None,
195
- "hierarchical_values": self.dvalues.copy(),
196
- "clustering": self._clustering,
197
- "output_indices": outputs,
198
- "output_names": getattr(self.model, "output_names", None)
199
- }
200
-
201
- def __str__(self):
202
- return "shap.explainers.PartitionExplainer()"
203
-
204
- def owen(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent):
205
- """ Compute a nested set of recursive Owen values based on an ordering recursion.
206
- """
207
-
208
- #f = self._reshaped_model
209
- #r = self.masker
210
- #masks = np.zeros(2*len(inds)+1, dtype=int)
211
- M = len(fm)
212
- m00 = np.zeros(M, dtype=bool)
213
- #f00 = fm(m00.reshape(1,-1))[0]
214
- base_value = f00
215
- #f11 = fm(~m00.reshape(1,-1))[0]
216
- #f11 = self._reshaped_model(r(~m00, x)).mean(0)
217
- ind = len(self.dvalues)-1
218
-
219
- # make sure output_indexes is a list of indexes
220
- if output_indexes is not None:
221
- # assert self.multi_output, "output_indexes is only valid for multi-output models!"
222
- # inds = output_indexes.apply(f11, 0)
223
- # out_len = output_indexes_len(output_indexes)
224
- # if output_indexes.startswith("max("):
225
- # output_indexes = np.argsort(-f11)[:out_len]
226
- # elif output_indexes.startswith("min("):
227
- # output_indexes = np.argsort(f11)[:out_len]
228
- # elif output_indexes.startswith("max(abs("):
229
- # output_indexes = np.argsort(np.abs(f11))[:out_len]
230
-
231
- f00 = f00[output_indexes]
232
- f11 = f11[output_indexes]
233
-
234
- q = queue.PriorityQueue()
235
- q.put((0, 0, (m00, f00, f11, ind, 1.0)))
236
- eval_count = 0
237
- total_evals = min(max_evals, (M-1)*M) # TODO: (M-1)*M is only right for balanced clusterings, but this is just for plotting progress...
238
- pbar = None
239
- start_time = time.time()
240
- while not q.empty():
241
-
242
- # if we passed our execution limit then leave everything else on the internal nodes
243
- if eval_count >= max_evals:
244
- while not q.empty():
245
- m00, f00, f11, ind, weight = q.get()[2]
246
- self.dvalues[ind] += (f11 - f00) * weight
247
- break
248
-
249
- # create a batch of work to do
250
- batch_args = []
251
- batch_masks = []
252
- while not q.empty() and len(batch_masks) < batch_size and eval_count + len(batch_masks) < max_evals:
253
-
254
- # get our next set of arguments
255
- m00, f00, f11, ind, weight = q.get()[2]
256
-
257
- # get the left and right children of this cluster
258
- lind = int(self._clustering[ind-M, 0]) if ind >= M else -1
259
- rind = int(self._clustering[ind-M, 1]) if ind >= M else -1
260
-
261
- # get the distance of this cluster's children
262
- if ind < M:
263
- distance = -1
264
- else:
265
- if self._clustering.shape[1] >= 3:
266
- distance = self._clustering[ind-M, 2]
267
- else:
268
- distance = 1
269
-
270
- # check if we are a leaf node (or other negative distance cluster) and so should terminate our descent
271
- if distance < 0:
272
- self.dvalues[ind] += (f11 - f00) * weight
273
- continue
274
-
275
- # build the masks
276
- m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
277
- m10[:] += self._mask_matrix[lind, :]
278
- m01 = m00.copy()
279
- m01[:] += self._mask_matrix[rind, :]
280
-
281
- batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight))
282
- batch_masks.append(m10)
283
- batch_masks.append(m01)
284
-
285
- batch_masks = np.array(batch_masks)
286
-
287
- # run the batch
288
- if len(batch_args) > 0:
289
- fout = fm(batch_masks)
290
- if output_indexes is not None:
291
- fout = fout[:,output_indexes]
292
-
293
- eval_count += len(batch_masks)
294
-
295
- if pbar is None and time.time() - start_time > 5:
296
- pbar = tqdm(total=total_evals, disable=silent, leave=False)
297
- pbar.update(eval_count)
298
- if pbar is not None:
299
- pbar.update(len(batch_masks))
300
-
301
- # use the results of the batch to add new nodes
302
- for i in range(len(batch_args)):
303
-
304
- m00, m10, m01, f00, f11, ind, lind, rind, weight = batch_args[i]
305
-
306
- # get the evaluated model output on the two new masked inputs
307
- f10 = fout[2*i]
308
- f01 = fout[2*i+1]
309
-
310
- new_weight = weight
311
- if fixed_context is None:
312
- new_weight /= 2
313
- elif fixed_context == 0:
314
- self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
315
- elif fixed_context == 1:
316
- self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
317
-
318
- if fixed_context is None or fixed_context == 0:
319
- # recurse on the left node with zero context
320
- args = (m00, f00, f10, lind, new_weight)
321
- q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
322
-
323
- # recurse on the right node with zero context
324
- args = (m00, f00, f01, rind, new_weight)
325
- q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
326
-
327
- if fixed_context is None or fixed_context == 1:
328
- # recurse on the left node with one context
329
- args = (m01, f01, f11, lind, new_weight)
330
- q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))
331
-
332
- # recurse on the right node with one context
333
- args = (m10, f10, f11, rind, new_weight)
334
- q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))
335
-
336
- if pbar is not None:
337
- pbar.close()
338
-
339
- self.last_eval_count = eval_count
340
-
341
- return output_indexes, base_value
342
-
343
- def owen3(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent):
344
- """ Compute a nested set of recursive Owen values based on an ordering recursion.
345
- """
346
-
347
- #f = self._reshaped_model
348
- #r = self.masker
349
- #masks = np.zeros(2*len(inds)+1, dtype=int)
350
- M = len(fm)
351
- m00 = np.zeros(M, dtype=bool)
352
- #f00 = fm(m00.reshape(1,-1))[0]
353
- base_value = f00
354
- #f11 = fm(~m00.reshape(1,-1))[0]
355
- #f11 = self._reshaped_model(r(~m00, x)).mean(0)
356
- ind = len(self.dvalues)-1
357
-
358
- # make sure output_indexes is a list of indexes
359
- if output_indexes is not None:
360
- # assert self.multi_output, "output_indexes is only valid for multi-output models!"
361
- # inds = output_indexes.apply(f11, 0)
362
- # out_len = output_indexes_len(output_indexes)
363
- # if output_indexes.startswith("max("):
364
- # output_indexes = np.argsort(-f11)[:out_len]
365
- # elif output_indexes.startswith("min("):
366
- # output_indexes = np.argsort(f11)[:out_len]
367
- # elif output_indexes.startswith("max(abs("):
368
- # output_indexes = np.argsort(np.abs(f11))[:out_len]
369
-
370
- f00 = f00[output_indexes]
371
- f11 = f11[output_indexes]
372
-
373
- # our starting plan is to evaluate all the nodes with a fixed_context
374
- evals_planned = M
375
-
376
- q = queue.PriorityQueue()
377
- q.put((0, 0, (m00, f00, f11, ind, 1.0, fixed_context))) # (m00, f00, f11, tree_index, weight)
378
- eval_count = 0
379
- total_evals = min(max_evals, (M-1)*M) # TODO: (M-1)*M is only right for balanced clusterings, but this is just for plotting progress...
380
- pbar = None
381
- start_time = time.time()
382
- while not q.empty():
383
-
384
- # if we passed our execution limit then leave everything else on the internal nodes
385
- if eval_count >= max_evals:
386
- while not q.empty():
387
- m00, f00, f11, ind, weight, _ = q.get()[2]
388
- self.dvalues[ind] += (f11 - f00) * weight
389
- break
390
-
391
- # create a batch of work to do
392
- batch_args = []
393
- batch_masks = []
394
- while not q.empty() and len(batch_masks) < batch_size and eval_count < max_evals:
395
-
396
- # get our next set of arguments
397
- m00, f00, f11, ind, weight, context = q.get()[2]
398
-
399
- # get the left and right children of this cluster
400
- lind = int(self._clustering[ind-M, 0]) if ind >= M else -1
401
- rind = int(self._clustering[ind-M, 1]) if ind >= M else -1
402
-
403
- # get the distance of this cluster's children
404
- if ind < M:
405
- distance = -1
406
- else:
407
- distance = self._clustering[ind-M, 2]
408
-
409
- # check if we are a leaf node (or other negative distance cluster) and so should terminate our decent
410
- if distance < 0:
411
- self.dvalues[ind] += (f11 - f00) * weight
412
- continue
413
-
414
- # build the masks
415
- m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
416
- m10[:] += self._mask_matrix[lind, :]
417
- m01 = m00.copy()
418
- m01[:] += self._mask_matrix[rind, :]
419
-
420
- batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight, context))
421
- batch_masks.append(m10)
422
- batch_masks.append(m01)
423
-
424
- batch_masks = np.array(batch_masks)
425
-
426
- # run the batch
427
- if len(batch_args) > 0:
428
- fout = fm(batch_masks)
429
- if output_indexes is not None:
430
- fout = fout[:,output_indexes]
431
-
432
- eval_count += len(batch_masks)
433
-
434
- if pbar is None and time.time() - start_time > 5:
435
- pbar = tqdm(total=total_evals, disable=silent, leave=False)
436
- pbar.update(eval_count)
437
- if pbar is not None:
438
- pbar.update(len(batch_masks))
439
-
440
- # use the results of the batch to add new nodes
441
- for i in range(len(batch_args)):
442
-
443
- m00, m10, m01, f00, f11, ind, lind, rind, weight, context = batch_args[i]
444
-
445
- # get the number of leaves in this cluster
446
- if ind < M:
447
- num_leaves = 0
448
- else:
449
- num_leaves = self._clustering[ind-M, 3]
450
-
451
- # get the evaluated model output on the two new masked inputs
452
- f10 = fout[2*i]
453
- f01 = fout[2*i+1]
454
-
455
- # see if we have enough evaluations left to get both sides of a fixed context
456
- if max_evals - evals_planned > num_leaves:
457
- evals_planned += num_leaves
458
- ignore_context = True
459
- else:
460
- ignore_context = False
461
-
462
- new_weight = weight
463
- if context is None or ignore_context:
464
- new_weight /= 2
465
-
466
- if context is None or context == 0 or ignore_context:
467
- self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
468
-
469
- # recurse on the left node with zero context, flip the context for all descendents if we are ignoring it
470
- args = (m00, f00, f10, lind, new_weight, 0 if context == 1 else context)
471
- q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
472
-
473
- # recurse on the right node with zero context, flip the context for all descendents if we are ignoring it
474
- args = (m00, f00, f01, rind, new_weight, 0 if context == 1 else context)
475
- q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
476
-
477
- if context is None or context == 1 or ignore_context:
478
- self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
479
-
480
- # recurse on the left node with one context, flip the context for all descendents if we are ignoring it
481
- args = (m01, f01, f11, lind, new_weight, 1 if context == 0 else context)
482
- q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))
483
-
484
- # recurse on the right node with one context, flip the context for all descendents if we are ignoring it
485
- args = (m10, f10, f11, rind, new_weight, 1 if context == 0 else context)
486
- q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))
487
-
488
- if pbar is not None:
489
- pbar.close()
490
-
491
- self.last_eval_count = eval_count
492
-
493
- return output_indexes, base_value
494
-
495
-
496
-
497
- # def owen2(self, fm, f00, f11, max_evals, output_indexes, fixed_context, batch_size, silent):
498
- # """ Compute a nested set of recursive Owen values based on an ordering recursion.
499
- # """
500
-
501
- # #f = self._reshaped_model
502
- # #r = self.masker
503
- # #masks = np.zeros(2*len(inds)+1, dtype=int)
504
- # M = len(fm)
505
- # m00 = np.zeros(M, dtype=bool)
506
- # #f00 = fm(m00.reshape(1,-1))[0]
507
- # base_value = f00
508
- # #f11 = fm(~m00.reshape(1,-1))[0]
509
- # #f11 = self._reshaped_model(r(~m00, x)).mean(0)
510
- # ind = len(self.dvalues)-1
511
-
512
- # # make sure output_indexes is a list of indexes
513
- # if output_indexes is not None:
514
- # # assert self.multi_output, "output_indexes is only valid for multi-output models!"
515
- # # inds = output_indexes.apply(f11, 0)
516
- # # out_len = output_indexes_len(output_indexes)
517
- # # if output_indexes.startswith("max("):
518
- # # output_indexes = np.argsort(-f11)[:out_len]
519
- # # elif output_indexes.startswith("min("):
520
- # # output_indexes = np.argsort(f11)[:out_len]
521
- # # elif output_indexes.startswith("max(abs("):
522
- # # output_indexes = np.argsort(np.abs(f11))[:out_len]
523
-
524
- # f00 = f00[output_indexes]
525
- # f11 = f11[output_indexes]
526
-
527
- # fc_owen(m00, m11, 1)
528
- # fc_owen(m00, m11, 0)
529
-
530
- # def fc_owen(m00, m11, context):
531
-
532
- # # recurse on the left node with zero context
533
- # args = (m00, f00, f10, lind, new_weight)
534
- # q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
535
-
536
- # # recurse on the right node with zero context
537
- # args = (m00, f00, f01, rind, new_weight)
538
- # q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
539
- # fc_owen(m00, m11, 1)
540
- # m00 m11
541
- # owen(fc=1)
542
- # owen(fc=0)
543
-
544
- # q = queue.PriorityQueue()
545
- # q.put((0, 0, (m00, f00, f11, ind, 1.0, 1)))
546
- # eval_count = 0
547
- # total_evals = min(max_evals, (M-1)*M) # TODO: (M-1)*M is only right for balanced clusterings, but this is just for plotting progress...
548
- # pbar = None
549
- # start_time = time.time()
550
- # while not q.empty():
551
-
552
- # # if we passed our execution limit then leave everything else on the internal nodes
553
- # if eval_count >= max_evals:
554
- # while not q.empty():
555
- # m00, f00, f11, ind, weight, _ = q.get()[2]
556
- # self.dvalues[ind] += (f11 - f00) * weight
557
- # break
558
-
559
- # # create a batch of work to do
560
- # batch_args = []
561
- # batch_masks = []
562
- # while not q.empty() and len(batch_masks) < batch_size and eval_count < max_evals:
563
-
564
- # # get our next set of arguments
565
- # m00, f00, f11, ind, weight, context = q.get()[2]
566
-
567
- # # get the left and right children of this cluster
568
- # lind = int(self._clustering[ind-M, 0]) if ind >= M else -1
569
- # rind = int(self._clustering[ind-M, 1]) if ind >= M else -1
570
-
571
- # # get the distance of this cluster's children
572
- # if ind < M:
573
- # distance = -1
574
- # else:
575
- # if self._clustering.shape[1] >= 3:
576
- # distance = self._clustering[ind-M, 2]
577
- # else:
578
- # distance = 1
579
-
580
- # # check if we are a leaf node (or other negative distance cluster) and so should terminate our decent
581
- # if distance < 0:
582
- # self.dvalues[ind] += (f11 - f00) * weight
583
- # continue
584
-
585
- # # build the masks
586
- # m10 = m00.copy() # we separate the copy from the add so as to not get converted to a matrix
587
- # m10[:] += self._mask_matrix[lind, :]
588
- # m01 = m00.copy()
589
- # m01[:] += self._mask_matrix[rind, :]
590
-
591
- # batch_args.append((m00, m10, m01, f00, f11, ind, lind, rind, weight, context))
592
- # batch_masks.append(m10)
593
- # batch_masks.append(m01)
594
-
595
- # batch_masks = np.array(batch_masks)
596
-
597
- # # run the batch
598
- # if len(batch_args) > 0:
599
- # fout = fm(batch_masks)
600
- # if output_indexes is not None:
601
- # fout = fout[:,output_indexes]
602
-
603
- # eval_count += len(batch_masks)
604
-
605
- # if pbar is None and time.time() - start_time > 5:
606
- # pbar = tqdm(total=total_evals, disable=silent, leave=False)
607
- # pbar.update(eval_count)
608
- # if pbar is not None:
609
- # pbar.update(len(batch_masks))
610
-
611
- # # use the results of the batch to add new nodes
612
- # for i in range(len(batch_args)):
613
-
614
- # m00, m10, m01, f00, f11, ind, lind, rind, weight, context = batch_args[i]
615
-
616
- # # get the evaluated model output on the two new masked inputs
617
- # f10 = fout[2*i]
618
- # f01 = fout[2*i+1]
619
-
620
- # new_weight = weight
621
- # if fixed_context is None:
622
- # new_weight /= 2
623
- # elif fixed_context == 0:
624
- # self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
625
- # elif fixed_context == 1:
626
- # self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
627
-
628
- # if fixed_context is None or fixed_context == 0:
629
- # self.dvalues[ind] += (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
630
-
631
-
632
- # # recurse on the left node with zero context
633
- # args = (m00, f00, f10, lind, new_weight)
634
- # q.put((-np.max(np.abs(f10 - f00)) * new_weight, np.random.randn(), args))
635
-
636
- # # recurse on the right node with zero context
637
- # args = (m00, f00, f01, rind, new_weight)
638
- # q.put((-np.max(np.abs(f01 - f00)) * new_weight, np.random.randn(), args))
639
-
640
- # if fixed_context is None or fixed_context == 1:
641
- # self.dvalues[ind] -= (f11 - f10 - f01 + f00) * weight # leave the interaction effect on the internal node
642
-
643
-
644
- # # recurse on the left node with one context
645
- # args = (m01, f01, f11, lind, new_weight)
646
- # q.put((-np.max(np.abs(f11 - f01)) * new_weight, np.random.randn(), args))
647
-
648
- # # recurse on the right node with one context
649
- # args = (m10, f10, f11, rind, new_weight)
650
- # q.put((-np.max(np.abs(f11 - f10)) * new_weight, np.random.randn(), args))
651
-
652
- # if pbar is not None:
653
- # pbar.close()
654
-
655
- # return output_indexes, base_value
656
-
657
-
658
- def output_indexes_len(output_indexes):
659
- if output_indexes.startswith("max("):
660
- return int(output_indexes[4:-1])
661
- elif output_indexes.startswith("min("):
662
- return int(output_indexes[4:-1])
663
- elif output_indexes.startswith("max(abs("):
664
- return int(output_indexes[8:-2])
665
- elif not isinstance(output_indexes, str):
666
- return len(output_indexes)
667
-
668
- @njit
669
- def lower_credit(i, value, M, values, clustering):
670
- if i < M:
671
- values[i] += value
672
- return
673
- li = int(clustering[i-M,0])
674
- ri = int(clustering[i-M,1])
675
- group_size = int(clustering[i-M,3])
676
- lsize = int(clustering[li-M,3]) if li >= M else 1
677
- rsize = int(clustering[ri-M,3]) if ri >= M else 1
678
- assert lsize+rsize == group_size
679
- values[i] += value
680
- lower_credit(li, values[i] * lsize / group_size, M, values, clustering)
681
- lower_credit(ri, values[i] * rsize / group_size, M, values, clustering)
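For reference, a minimal usage sketch of the deleted `PartitionExplainer` (assuming the public `shap` API; the toy model and data below are illustrative, not from this repository). The masker must expose a `.clustering` attribute, e.g. a `shap.maskers.Partition` masker built from background data:

```python
import numpy as np
import shap

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 6))                         # hypothetical background data

def f(X):
    # toy model with an interaction term
    return X[:, 0] * X[:, 1] + X[:, 2]

# Partition masker = background data + hierarchical clustering of the features
masker = shap.maskers.Partition(X, clustering="correlation")
explainer = shap.PartitionExplainer(f, masker)

explanation = explainer(X[:5], max_evals=200)         # returns an Explanation object
print(explanation.values.shape)                       # (5, 6)
```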
lib/shap/explainers/_permutation.py DELETED
@@ -1,217 +0,0 @@
1
- import warnings
2
-
3
- import numpy as np
4
-
5
- from .. import links
6
- from ..models import Model
7
- from ..utils import MaskedModel, partition_tree_shuffle
8
- from ._explainer import Explainer
9
-
10
-
11
- class PermutationExplainer(Explainer):
12
- """ This method approximates the Shapley values by iterating through permutations of the inputs.
13
-
14
- This is a model agnostic explainer that guarantees local accuracy (additivity) by iterating completely
15
- through an entire permutation of the features in both forward and reverse directions (antithetic sampling).
16
- If we do this once, then we get the exact SHAP values for models with up to second order interaction effects.
17
- We can iterate this many times over many random permutations to get better SHAP value estimates for models
18
- with higher order interactions. This sequential ordering formulation also allows for easy reuse of
19
- model evaluations and the ability to efficiently avoid evaluating the model when the background values
20
- for a feature are the same as the current input value. We can also account for hierarchical data
21
- structures with partition trees, something not currently implemented for KernelExplainer or SamplingExplainer.
22
- """
23
-
24
- def __init__(self, model, masker, link=links.identity, feature_names=None, linearize_link=True, seed=None, **call_args):
25
- """ Build an explainers.Permutation object for the given model using the given masker object.
26
-
27
- Parameters
28
- ----------
29
- model : function
30
- A callable python object that executes the model given a set of input data samples.
31
-
32
- masker : function or numpy.array or pandas.DataFrame
33
- A callable python object used to "mask" out hidden features of the form `masker(binary_mask, x)`.
34
- It takes a single input sample and a binary mask and returns a matrix of masked samples. These
35
- masked samples are evaluated using the model function and the outputs are then averaged.
36
- As a shortcut for the standard masking used by SHAP you can pass a background data matrix
37
- instead of a function and that matrix will be used for masking. To use a clustering
38
- game structure you can pass a shap.maskers.Tabular(data, clustering=\"correlation\") object.
39
-
40
- seed: None or int
41
- Seed for reproducibility
42
-
43
- **call_args : valid argument to the __call__ method
44
- These arguments are saved and passed to the __call__ method as the new default values for these arguments.
45
- """
46
-
47
- # setting seed for random generation: if seed is not None, then shap values computation should be reproducible
48
- np.random.seed(seed)
49
-
50
- if masker is None:
51
- raise ValueError("masker cannot be None.")
52
-
53
- super().__init__(model, masker, link=link, linearize_link=linearize_link, feature_names=feature_names)
54
-
55
- if not isinstance(self.model, Model):
56
- self.model = Model(self.model)
57
-
58
- # if we have gotten default arguments for the call function we need to wrap ourselves in a new class that
59
- # has a call function with those new default arguments
60
- if len(call_args) > 0:
61
- # this signature should match the __call__ signature of the class defined below
62
- class PermutationExplainer(self.__class__):
63
- def __call__(self, *args, max_evals=500, main_effects=False, error_bounds=False, batch_size="auto",
64
- outputs=None, silent=False):
65
- return super().__call__(
66
- *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds,
67
- batch_size=batch_size, outputs=outputs, silent=silent
68
- )
69
- PermutationExplainer.__call__.__doc__ = self.__class__.__call__.__doc__
70
- self.__class__ = PermutationExplainer
71
- for k, v in call_args.items():
72
- self.__call__.__kwdefaults__[k] = v
73
-
74
- # note that changes to this function signature should be copied to the default call argument wrapper above
75
- def __call__(self, *args, max_evals=500, main_effects=False, error_bounds=False, batch_size="auto",
76
- outputs=None, silent=False):
77
- """ Explain the output of the model on the given arguments.
78
- """
79
- return super().__call__(
80
- *args, max_evals=max_evals, main_effects=main_effects, error_bounds=error_bounds, batch_size=batch_size,
81
- outputs=outputs, silent=silent
82
- )
83
-
84
- def explain_row(self, *row_args, max_evals, main_effects, error_bounds, batch_size, outputs, silent):
85
- """ Explains a single row and returns the tuple (row_values, row_expected_values, row_mask_shapes).
86
- """
87
-
88
- # build a masked version of the model for the current input sample
89
- fm = MaskedModel(self.model, self.masker, self.link, self.linearize_link, *row_args)
90
-
91
- # by default we run 10 permutations forward and backward
92
- if max_evals == "auto":
93
- max_evals = 10 * 2 * len(fm)
94
-
95
- # compute any custom clustering for this row
96
- row_clustering = None
97
- if getattr(self.masker, "clustering", None) is not None:
98
- if isinstance(self.masker.clustering, np.ndarray):
99
- row_clustering = self.masker.clustering
100
- elif callable(self.masker.clustering):
101
- row_clustering = self.masker.clustering(*row_args)
102
- else:
103
- raise NotImplementedError("The masker passed has a .clustering attribute that is not yet supported by the Permutation explainer!")
104
-
105
- # loop over many permutations
106
- inds = fm.varying_inputs()
107
- inds_mask = np.zeros(len(fm), dtype=bool)
108
- inds_mask[inds] = True
109
- masks = np.zeros(2*len(inds)+1, dtype=int)
110
- masks[0] = MaskedModel.delta_mask_noop_value
111
- npermutations = max_evals // (2*len(inds)+1)
112
- row_values = None
113
- row_values_history = None
114
- history_pos = 0
115
- main_effect_values = None
116
- if len(inds) > 0:
117
- for _ in range(npermutations):
118
-
119
- # shuffle the indexes so we get a random permutation ordering
120
- if row_clustering is not None:
121
- # [TODO] This shuffle does not work when inds is not a complete set of integers from 0 to M. TODO: still true?
122
- #assert len(inds) == len(fm), "Need to support partition shuffle when not all the inds vary!!"
123
- partition_tree_shuffle(inds, inds_mask, row_clustering)
124
- else:
125
- np.random.shuffle(inds)
126
-
127
- # create a large batch of masks to evaluate
128
- i = 1
129
- for ind in inds:
130
- masks[i] = ind
131
- i += 1
132
- for ind in inds:
133
- masks[i] = ind
134
- i += 1
135
-
136
- # evaluate the masked model
137
- outputs = fm(masks, zero_index=0, batch_size=batch_size)
138
-
139
- if row_values is None:
140
- row_values = np.zeros((len(fm),) + outputs.shape[1:])
141
-
142
- if error_bounds:
143
- row_values_history = np.zeros((2 * npermutations, len(fm),) + outputs.shape[1:])
144
-
145
- # update our SHAP value estimates
146
- i = 0
147
- for ind in inds: # forward
148
- row_values[ind] += outputs[i + 1] - outputs[i]
149
- if error_bounds:
150
- row_values_history[history_pos][ind] = outputs[i + 1] - outputs[i]
151
- i += 1
152
- history_pos += 1
153
- for ind in inds: # backward
154
- row_values[ind] += outputs[i] - outputs[i + 1]
155
- if error_bounds:
156
- row_values_history[history_pos][ind] = outputs[i] - outputs[i + 1]
157
- i += 1
158
- history_pos += 1
159
-
160
- if npermutations == 0:
161
- raise ValueError(f"max_evals={max_evals} is too low for the Permutation explainer, it must be at least 2 * num_features + 1 = {2 * len(inds) + 1}!")
162
-
163
- expected_value = outputs[0]
164
-
165
- # compute the main effects if we need to
166
- if main_effects:
167
- main_effect_values = fm.main_effects(inds, batch_size=batch_size)
168
- else:
169
- masks = np.zeros(1, dtype=int)
170
- outputs = fm(masks, zero_index=0, batch_size=1)
171
- expected_value = outputs[0]
172
- row_values = np.zeros((len(fm),) + outputs.shape[1:])
173
- if error_bounds:
174
- row_values_history = np.zeros((2 * npermutations, len(fm),) + outputs.shape[1:])
175
-
176
- return {
177
- "values": row_values / (2 * npermutations),
178
- "expected_values": expected_value,
179
- "mask_shapes": fm.mask_shapes,
180
- "main_effects": main_effect_values,
181
- "clustering": row_clustering,
182
- "error_std": None if row_values_history is None else row_values_history.std(0),
183
- "output_names": self.model.output_names if hasattr(self.model, "output_names") else None
184
- }
185
-
186
-
187
- def shap_values(self, X, npermutations=10, main_effects=False, error_bounds=False, batch_evals=True, silent=False):
188
- """ Legacy interface to estimate the SHAP values for a set of samples.
189
-
190
- Parameters
191
- ----------
192
- X : numpy.array or pandas.DataFrame or any scipy.sparse matrix
193
- A matrix of samples (# samples x # features) on which to explain the model's output.
194
-
195
- npermutations : int
196
- Number of times to cycle through all the features, re-evaluating the model at each step.
197
- Each cycle evaluates the model function 2 * (# features + 1) times on a data matrix of
198
- (# background data samples) rows. An exception to this is when PermutationExplainer can
199
- avoid evaluating the model because a feature's value is the same in X and the background
200
- dataset (which is common for example with sparse features).
201
-
202
- Returns
203
- -------
204
- array or list
205
- For models with a single output this returns a matrix of SHAP values
206
- (# samples x # features). Each row sums to the difference between the model output for that
207
- sample and the expected value of the model output (which is stored as expected_value
208
- attribute of the explainer). For models with vector outputs this returns a list
209
- of such matrices, one for each output.
210
- """
211
- warnings.warn("shap_values() is deprecated; use __call__().", DeprecationWarning)
212
-
213
- explanation = self(X, max_evals=npermutations * X.shape[1], main_effects=main_effects)
214
- return explanation.values
215
-
216
- def __str__(self):
217
- return "shap.explainers.PermutationExplainer()"
lib/shap/explainers/_sampling.py DELETED
@@ -1,199 +0,0 @@
1
- import logging
2
-
3
- import numpy as np
4
- import pandas as pd
5
-
6
- from .._explanation import Explanation
7
- from ..utils._exceptions import ExplainerError
8
- from ..utils._legacy import convert_to_instance, match_instance_to_data
9
- from ._kernel import KernelExplainer
10
-
11
- log = logging.getLogger('shap')
12
-
13
-
14
- class SamplingExplainer(KernelExplainer):
15
- """Computes SHAP values using an extension of the Shapley sampling values explanation method
16
- (also known as IME).
17
-
18
- SamplingExplainer computes SHAP values under the assumption of feature independence and is an
19
- extension of the algorithm proposed in "An Efficient Explanation of Individual Classifications
20
- using Game Theory", Erik Strumbelj, Igor Kononenko, JMLR 2010. It is a good alternative to
21
- KernelExplainer when you want to use a large background set (as opposed to a single reference
22
- value for example).
23
-
24
- Parameters
25
- ----------
26
- model : function
27
- User supplied function that takes a matrix of samples (# samples x # features) and
28
- computes the output of the model for those samples. The output can be a vector
29
- (# samples) or a matrix (# samples x # model outputs).
30
-
31
- data : numpy.array or pandas.DataFrame
32
- The background dataset to use for integrating out features. To determine the impact
33
- of a feature, that feature is set to "missing" and the change in the model output
34
- is observed. Since most models aren't designed to handle arbitrary missing data at test
35
- time, we simulate "missing" by replacing the feature with the values it takes in the
36
- background dataset. So if the background dataset is a simple sample of all zeros, then
37
- we would approximate a feature being missing by setting it to zero. Unlike the
38
- KernelExplainer, this data can be the whole training set, even if that is a large set. This
39
- is because SamplingExplainer only samples from this background dataset.
40
- """
41
-
42
- def __init__(self, model, data, **kwargs):
43
- # silence warning about large datasets
44
- level = log.level
45
- log.setLevel(logging.ERROR)
46
- super().__init__(model, data, **kwargs)
47
- log.setLevel(level)
48
-
49
- if str(self.link) != "identity":
50
- emsg = f"SamplingExplainer only supports the identity link, not {self.link}"
51
- raise ValueError(emsg)
52
-
53
- def __call__(self, X, y=None, nsamples=2000):
54
-
55
- if isinstance(X, pd.DataFrame):
56
- feature_names = list(X.columns)
57
- X = X.values
58
- else:
59
- feature_names = None # we can make self.feature_names from background data eventually if we have it
60
-
61
- v = self.shap_values(X, nsamples=nsamples)
62
- if isinstance(v, list):
63
- v = np.stack(v, axis=-1) # put outputs at the end
64
- e = Explanation(v, self.expected_value, X, feature_names=feature_names)
65
- return e
66
-
67
- def explain(self, incoming_instance, **kwargs):
68
- # convert incoming input to a standardized iml object
69
- instance = convert_to_instance(incoming_instance)
70
- match_instance_to_data(instance, self.data)
71
-
72
- if len(self.data.groups) != self.P:
73
- emsg = "SamplingExplainer does not support feature groups!"
74
- raise ExplainerError(emsg)
75
-
76
- # find the feature groups we will test. If a feature does not change from its
77
- # current value then we know it doesn't impact the model
78
- self.varyingInds = self.varying_groups(instance.x)
79
- #self.varyingFeatureGroups = [self.data.groups[i] for i in self.varyingInds]
80
- self.M = len(self.varyingInds)
81
-
82
- # find f(x)
83
- if self.keep_index:
84
- model_out = self.model.f(instance.convert_to_df())
85
- else:
86
- model_out = self.model.f(instance.x)
87
- if isinstance(model_out, (pd.DataFrame, pd.Series)):
88
- model_out = model_out.values[0]
89
- self.fx = model_out[0]
90
-
91
- if not self.vector_out:
92
- self.fx = np.array([self.fx])
93
-
94
- # if no features vary then no feature has an effect
95
- if self.M == 0:
96
- phi = np.zeros((len(self.data.groups), self.D))
97
- phi_var = np.zeros((len(self.data.groups), self.D))
98
-
99
- # if only one feature varies then it has all the effect
100
- elif self.M == 1:
101
- phi = np.zeros((len(self.data.groups), self.D))
102
- phi_var = np.zeros((len(self.data.groups), self.D))
103
- diff = self.fx - self.fnull
104
- for d in range(self.D):
105
- phi[self.varyingInds[0],d] = diff[d]
106
-
107
- # if more than one feature varies then we have to do real work
108
- else:
109
-
110
- # pick a reasonable number of samples if the user didn't specify how many they wanted
111
- self.nsamples = kwargs.get("nsamples", "auto")
112
- if self.nsamples == "auto":
113
- self.nsamples = 1000 * self.M
114
-
115
- min_samples_per_feature = kwargs.get("min_samples_per_feature", 100)
116
- round1_samples = self.nsamples
117
- round2_samples = 0
118
- if round1_samples > self.M * min_samples_per_feature:
119
- round2_samples = round1_samples - self.M * min_samples_per_feature
120
- round1_samples -= round2_samples
121
-
122
- # divide up the samples among the features for round 1
123
- nsamples_each1 = np.ones(self.M, dtype=np.int64) * 2 * (round1_samples // (self.M * 2))
124
- for i in range((round1_samples % (self.M * 2)) // 2):
125
- nsamples_each1[i] += 2
126
-
127
- # explain every feature in round 1
128
- phi = np.zeros((self.P, self.D))
129
- phi_var = np.zeros((self.P, self.D))
130
- self.X_masked = np.zeros((nsamples_each1.max() * 2, self.data.data.shape[1]))
131
- for i,ind in enumerate(self.varyingInds):
132
- phi[ind,:],phi_var[ind,:] = self.sampling_estimate(ind, self.model.f, instance.x, self.data.data, nsamples=nsamples_each1[i])
133
-
134
- # optimally allocate samples according to the variance
135
- if phi_var.sum() == 0:
136
- phi_var += 1 # spread samples uniformly if we found no variability
137
- phi_var /= phi_var.sum(0)[np.newaxis, :]
138
- nsamples_each2 = (phi_var[self.varyingInds,:].mean(1) * round2_samples).astype(int)
139
- for i in range(len(nsamples_each2)):
140
- if nsamples_each2[i] % 2 == 1:
141
- nsamples_each2[i] += 1
142
- for i in range(len(nsamples_each2)):
143
- if nsamples_each2.sum() > round2_samples:
144
- nsamples_each2[i] -= 2
145
- elif nsamples_each2.sum() < round2_samples:
146
- nsamples_each2[i] += 2
147
- else:
148
- break
149
-
150
- self.X_masked = np.zeros((nsamples_each2.max() * 2, self.data.data.shape[1]))
151
- for i,ind in enumerate(self.varyingInds):
152
- if nsamples_each2[i] > 0:
153
- val,var = self.sampling_estimate(ind, self.model.f, instance.x, self.data.data, nsamples=nsamples_each2[i])
154
-
155
- total_samples = nsamples_each1[i] + nsamples_each2[i]
156
- phi[ind,:] = (phi[ind,:] * nsamples_each1[i] + val * nsamples_each2[i]) / total_samples
157
- phi_var[ind,:] = (phi_var[ind,:] * nsamples_each1[i] + var * nsamples_each2[i]) / total_samples
158
-
159
- # convert from the variance of the differences to the variance of the mean (phi)
160
- for i,ind in enumerate(self.varyingInds):
161
- phi_var[ind,:] /= np.sqrt(nsamples_each1[i] + nsamples_each2[i])
162
-
163
- # correct the sum of the SHAP values to equal the output of the model using a linear
164
- # regression model with priors of the coefficients equal to the estimated variances for each
165
- # SHAP value (note that 1e6 is designed to increase the weight of the sample and so closely
166
- # match the correct sum)
167
- sum_error = self.fx - phi.sum(0) - self.fnull
168
- for i in range(self.D):
169
- # this is a ridge regression with one sample of all ones with sum_error[i] as the label
170
- # and 1/v as the ridge penalties. This simplified (and stable) form comes from the
171
- # Sherman-Morrison formula
172
- v = (phi_var[:,i] / phi_var[:,i].max()) * 1e6
173
- adj = sum_error[i] * (v - (v * v.sum()) / (1 + v.sum()))
174
- phi[:,i] += adj
175
-
176
- if phi.shape[1] == 1:
177
- phi = phi[:,0]
178
-
179
- return phi
180
-
181
- def sampling_estimate(self, j, f, x, X, nsamples=10):
182
- X_masked = self.X_masked[:nsamples * 2,:]
183
- inds = np.arange(X.shape[1])
184
-
185
- for i in range(0, nsamples):
186
- np.random.shuffle(inds)
187
- pos = np.where(inds == j)[0][0]
188
- rind = np.random.randint(X.shape[0])
189
- X_masked[i, :] = x
190
- X_masked[i, inds[pos+1:]] = X[rind, inds[pos+1:]]
191
- X_masked[-(i+1), :] = x
192
- X_masked[-(i+1), inds[pos:]] = X[rind, inds[pos:]]
193
-
194
- evals = f(X_masked)
195
- evals_on = evals[:nsamples]
196
- evals_off = evals[nsamples:][::-1]
197
- d = evals_on - evals_off
198
-
199
- return np.mean(d, 0), np.var(d, 0)
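Finally, a minimal usage sketch of the deleted `SamplingExplainer` (assuming the public `shap` API; data and model are illustrative). Unlike `KernelExplainer`, it can take a large background dataset directly, because it only samples from that set:

```python
import numpy as np
import shap

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 5))                        # a large background set is fine here

def f(X):
    return X[:, 0] - 2.0 * X[:, 1] + X[:, 2] * X[:, 3]

explainer = shap.SamplingExplainer(f, X)
explanation = explainer(X[:3], nsamples=2000)         # Explanation with values of shape (3, 5)

# the ridge adjustment in explain() above makes the values approximately additive
print(explanation.values.sum(axis=1) + explainer.expected_value)
print(f(X[:3]))
```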