Spaces:
Sleeping
Sleeping
LLH
commited on
Commit
·
928d3e1
1
Parent(s):
fa3edb1
2024/03/09/16:00
Browse files- analysis/others/shap_model.py +4 -4
- app.py +6 -3
- classes/static_custom_class.py +2 -0
analysis/others/shap_model.py
CHANGED
@@ -7,7 +7,7 @@ from classes.static_custom_class import StaticValue
|
|
7 |
|
8 |
def draw_shap_beeswarm(model, x, feature_names, type, paint_object):
|
9 |
plt.clf()
|
10 |
-
x = shap.sample(x, min(
|
11 |
explainer = shap.KernelExplainer(model.predict, x)
|
12 |
shap_values = explainer(x)
|
13 |
|
@@ -21,7 +21,7 @@ def draw_shap_beeswarm(model, x, feature_names, type, paint_object):
|
|
21 |
|
22 |
def draw_waterfall(model, x, feature_names, number, paint_object):
|
23 |
plt.clf()
|
24 |
-
x = shap.sample(x, min(
|
25 |
explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
|
26 |
shap_values = explainer(x)
|
27 |
|
@@ -35,7 +35,7 @@ def draw_waterfall(model, x, feature_names, number, paint_object):
|
|
35 |
|
36 |
def draw_force(model, x, feature_names, number, paint_object):
|
37 |
plt.clf()
|
38 |
-
x = shap.sample(x, min(
|
39 |
explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
|
40 |
shap_values = explainer(x[number])
|
41 |
|
@@ -49,7 +49,7 @@ def draw_force(model, x, feature_names, number, paint_object):
|
|
49 |
|
50 |
def draw_dependence(model, x, feature_names, col, paint_object):
|
51 |
plt.clf()
|
52 |
-
x = shap.sample(x, min(
|
53 |
explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
|
54 |
shap_values = explainer(x)
|
55 |
|
|
|
7 |
|
8 |
def draw_shap_beeswarm(model, x, feature_names, type, paint_object):
|
9 |
plt.clf()
|
10 |
+
x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
|
11 |
explainer = shap.KernelExplainer(model.predict, x)
|
12 |
shap_values = explainer(x)
|
13 |
|
|
|
21 |
|
22 |
def draw_waterfall(model, x, feature_names, number, paint_object):
|
23 |
plt.clf()
|
24 |
+
x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
|
25 |
explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
|
26 |
shap_values = explainer(x)
|
27 |
|
|
|
35 |
|
36 |
def draw_force(model, x, feature_names, number, paint_object):
|
37 |
plt.clf()
|
38 |
+
x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
|
39 |
explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
|
40 |
shap_values = explainer(x[number])
|
41 |
|
|
|
49 |
|
50 |
def draw_dependence(model, x, feature_names, col, paint_object):
|
51 |
plt.clf()
|
52 |
+
x = shap.sample(x, min(StaticValue.SAMPLE_NUM, len(x)), random_state=StaticValue.RANDOM_STATE)
|
53 |
explainer = shap.KernelExplainer(model.predict, x, feature_names=feature_names)
|
54 |
shap_values = explainer(x)
|
55 |
|
app.py
CHANGED
@@ -387,6 +387,9 @@ class Dataset:
|
|
387 |
|
388 |
cls.container_dict = get_container_dict()
|
389 |
|
|
|
|
|
|
|
390 |
@classmethod
|
391 |
def check_descriptive_indicators_df(cls):
|
392 |
return True if not cls.descriptive_indicators_df.empty else False
|
@@ -1205,7 +1208,7 @@ class Dataset:
|
|
1205 |
EACH_ROW_NUM = 6 - 1
|
1206 |
output_list = []
|
1207 |
|
1208 |
-
if cls.cur_model and cls.
|
1209 |
output_dict = ChooseModelParams.choose(cls.cur_model)
|
1210 |
row_unit_num_list = []
|
1211 |
row_len = len(output_dict.keys())
|
@@ -1531,12 +1534,12 @@ def get_return(is_visible, extra_gr_dict: dict = None):
|
|
1531 |
data_fit_button: gr.Button(LN.data_fit_button, visible=Dataset.check_before_train()),
|
1532 |
waterfall_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
|
1533 |
label=LN.waterfall_radio),
|
1534 |
-
waterfall_number: gr.Slider(0,
|
1535 |
visible=Dataset.check_before_train(), label=LN.waterfall_number),
|
1536 |
waterfall_button: gr.Button(LN.waterfall_button, visible=Dataset.check_before_train()),
|
1537 |
force_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
|
1538 |
label=LN.force_radio),
|
1539 |
-
force_number: gr.Slider(0,
|
1540 |
visible=Dataset.check_before_train(), label=LN.force_number),
|
1541 |
force_button: gr.Button(LN.force_button, visible=Dataset.check_before_train()),
|
1542 |
dependence_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
|
|
|
387 |
|
388 |
cls.container_dict = get_container_dict()
|
389 |
|
390 |
+
visualize = ""
|
391 |
+
choose_optimize = ""
|
392 |
+
|
393 |
@classmethod
|
394 |
def check_descriptive_indicators_df(cls):
|
395 |
return True if not cls.descriptive_indicators_df.empty else False
|
|
|
1208 |
EACH_ROW_NUM = 6 - 1
|
1209 |
output_list = []
|
1210 |
|
1211 |
+
if cls.cur_model and cls.check_model_optimize_radio():
|
1212 |
output_dict = ChooseModelParams.choose(cls.cur_model)
|
1213 |
row_unit_num_list = []
|
1214 |
row_len = len(output_dict.keys())
|
|
|
1534 |
data_fit_button: gr.Button(LN.data_fit_button, visible=Dataset.check_before_train()),
|
1535 |
waterfall_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
|
1536 |
label=LN.waterfall_radio),
|
1537 |
+
waterfall_number: gr.Slider(0, StaticValue.SAMPLE_NUM, value=0, step=1,
|
1538 |
visible=Dataset.check_before_train(), label=LN.waterfall_number),
|
1539 |
waterfall_button: gr.Button(LN.waterfall_button, visible=Dataset.check_before_train()),
|
1540 |
force_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
|
1541 |
label=LN.force_radio),
|
1542 |
+
force_number: gr.Slider(0, StaticValue.SAMPLE_NUM, value=0, step=1,
|
1543 |
visible=Dataset.check_before_train(), label=LN.force_number),
|
1544 |
force_button: gr.Button(LN.force_button, visible=Dataset.check_before_train()),
|
1545 |
dependence_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(),
|
classes/static_custom_class.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
# 全局静态变量值存储类
|
2 |
class StaticValue:
|
|
|
|
|
3 |
# 超参数文本框的最大组件数量
|
4 |
MAX_PARAMS_NUM = 60
|
5 |
# 颜色和标签显示的最大组件数量
|
|
|
1 |
# 全局静态变量值存储类
|
2 |
class StaticValue:
|
3 |
+
# SHAP抽样数量
|
4 |
+
SAMPLE_NUM = 20
|
5 |
# 超参数文本框的最大组件数量
|
6 |
MAX_PARAMS_NUM = 60
|
7 |
# 颜色和标签显示的最大组件数量
|