LLH commited on
Commit
928d3e1
·
1 Parent(s): fa3edb1

2024/03/09/16:00

Browse files
analysis/others/shap_model.py CHANGED
@@ -7,7 +7,7 @@ from classes.static_custom_class import StaticValue
7
 
8
  def draw_shap_beeswarm(model, x, feature_names, type, paint_object):
9
  plt.clf()
10
- x = shap.sample(x, min(20, len(x)), random_state=StaticValue.RANDOM_STATE)
11
  explainer = shap.KernelExplainer(model.predict, x)
12
  shap_values = explainer(x)
13
 
@@ -21,7 +21,7 @@ def draw_shap_beeswarm(model, x, feature_names, type, paint_object):
21
 
22
  def draw_waterfall(model, x, feature_names, number, paint_object):
23
  plt.clf()
24
- x = shap.sample(x, min(20, len(x)), random_state=StaticValue.RANDOM_STATE)
25
  explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
26
  shap_values = explainer(x)
27
 
@@ -35,7 +35,7 @@ def draw_waterfall(model, x, feature_names, number, paint_object):
35
 
36
  def draw_force(model, x, feature_names, number, paint_object):
37
  plt.clf()
38
- x = shap.sample(x, min(20, len(x)), random_state=StaticValue.RANDOM_STATE)
39
  explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
40
  shap_values = explainer(x[number])
41
 
@@ -49,7 +49,7 @@ def draw_force(model, x, feature_names, number, paint_object):
49
 
50
  def draw_dependence(model, x, feature_names, col, paint_object):
51
  plt.clf()
52
- x = shap.sample(x, min(20, len(x)), random_state=StaticValue.RANDOM_STATE)
53
  explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
54
  shap_values = explainer(x)
55
 
 
7
 
8
  def draw_shap_beeswarm(model, x, feature_names, type, paint_object):
9
  plt.clf()
10
+ x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
11
  explainer = shap.KernelExplainer(model.predict, x)
12
  shap_values = explainer(x)
13
 
 
21
 
22
  def draw_waterfall(model, x, feature_names, number, paint_object):
23
  plt.clf()
24
+ x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
25
  explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
26
  shap_values = explainer(x)
27
 
 
35
 
36
  def draw_force(model, x, feature_names, number, paint_object):
37
  plt.clf()
38
+ x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
39
  explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
40
  shap_values = explainer(x[number])
41
 
 
49
 
50
  def draw_dependence(model, x, feature_names, col, paint_object):
51
  plt.clf()
52
+ x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
53
  explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
54
  shap_values = explainer(x)
55
 
app.py CHANGED
@@ -387,6 +387,9 @@ class Dataset:
387
 
388
  cls.container_dict = get_container_dict()
389
 
 
 
 
390
  @classmethod
391
  def check_descriptive_indicators_df(cls):
392
  return True if not cls.descriptive_indicators_df.empty else False
@@ -1205,7 +1208,7 @@ class Dataset:
1205
  EACH_ROW_NUM = 6 - 1
1206
  output_list = []
1207
 
1208
- if cls.cur_model and cls.choose_optimize:
1209
  output_dict = ChooseModelParams.choose(cls.cur_model)
1210
  row_unit_num_list = []
1211
  row_len = len(output_dict.keys())
@@ -1531,12 +1534,12 @@ def get_return(is_visible, extra_gr_dict: dict = None):
1531
  data_fit_button: gr.Button(LN.data_fit_button, visible=Dataset.check_before_train()),
1532
  waterfall_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
1533
  label=LN.waterfall_radio),
1534
- waterfall_number: gr.Slider(0, 20, value=0, step=1,
1535
  visible=Dataset.check_before_train(), label=LN.waterfall_number),
1536
  waterfall_button: gr.Button(LN.waterfall_button, visible=Dataset.check_before_train()),
1537
  force_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
1538
  label=LN.force_radio),
1539
- force_number: gr.Slider(0, 20, value=0, step=1,
1540
  visible=Dataset.check_before_train(), label=LN.force_number),
1541
  force_button: gr.Button(LN.force_button, visible=Dataset.check_before_train()),
1542
  dependence_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
 
387
 
388
  cls.container_dict = get_container_dict()
389
 
390
+ visualize = ""
391
+ choose_optimize = ""
392
+
393
  @classmethod
394
  def check_descriptive_indicators_df(cls):
395
  return True if not cls.descriptive_indicators_df.empty else False
 
1208
  EACH_ROW_NUM = 6 - 1
1209
  output_list = []
1210
 
1211
+ if cls.cur_model and cls.check_model_optimize_radio():
1212
  output_dict = ChooseModelParams.choose(cls.cur_model)
1213
  row_unit_num_list = []
1214
  row_len = len(output_dict.keys())
 
1534
  data_fit_button: gr.Button(LN.data_fit_button, visible=Dataset.check_before_train()),
1535
  waterfall_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
1536
  label=LN.waterfall_radio),
1537
+ waterfall_number: gr.Slider(0, StaticValue.SAMPLE_NUM, value=0, step=1,
1538
  visible=Dataset.check_before_train(), label=LN.waterfall_number),
1539
  waterfall_button: gr.Button(LN.waterfall_button, visible=Dataset.check_before_train()),
1540
  force_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
1541
  label=LN.force_radio),
1542
+ force_number: gr.Slider(0, StaticValue.SAMPLE_NUM, value=0, step=1,
1543
  visible=Dataset.check_before_train(), label=LN.force_number),
1544
  force_button: gr.Button(LN.force_button, visible=Dataset.check_before_train()),
1545
  dependence_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
classes/static_custom_class.py CHANGED
@@ -1,5 +1,7 @@
1
  # 全局静态变量值存储类
2
  class StaticValue:
 
 
3
  # 超参数文本框的最大组件数量
4
  MAX_PARAMS_NUM = 60
5
  # 颜色和标签显示的最大组件数量
 
1
  # 全局静态变量值存储类
2
  class StaticValue:
3
+ # SHAP抽样数量
4
+ SAMPLE_NUM = 20
5
  # 超参数文本框的最大组件数量
6
  MAX_PARAMS_NUM = 60
7
  # 颜色和标签显示的最大组件数量