Spaces:
Sleeping
Sleeping
LLH
committed on
Commit
·
11b81b9
1
Parent(s):
0136ac6
2024/02/16/14:00
Browse files- .idea/.gitignore +8 -0
- .idea/EasyMachineLearningDemo.iml +12 -0
- .idea/inspectionProfiles/Project_Default.xml +19 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- README.md +1 -1
- analysis/shap_model.py +3 -3
- app.py +346 -61
- {diagram → buffer}/__init__.py +0 -0
- static/config.py +3 -0
- static/paint.py +51 -0
- static/process.py +4 -1
- visualization/draw_learning_curve_total.py +18 -23
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Editor-based HTTP Client requests
|
5 |
+
/httpRequests/
|
6 |
+
# Datasource local storage ignored files
|
7 |
+
/dataSources/
|
8 |
+
/dataSources.local.xml
|
.idea/EasyMachineLearningDemo.iml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$" />
|
5 |
+
<orderEntry type="inheritedJdk" />
|
6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
+
</component>
|
8 |
+
<component name="PyDocumentationSettings">
|
9 |
+
<option name="format" value="PLAIN" />
|
10 |
+
<option name="myDocStringFormat" value="Plain" />
|
11 |
+
</component>
|
12 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
5 |
+
<option name="ignoredErrors">
|
6 |
+
<list>
|
7 |
+
<option value="E501" />
|
8 |
+
</list>
|
9 |
+
</option>
|
10 |
+
</inspection_tool>
|
11 |
+
<inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
12 |
+
<option name="ignoredIdentifiers">
|
13 |
+
<list>
|
14 |
+
<option value="object.pop" />
|
15 |
+
</list>
|
16 |
+
</option>
|
17 |
+
</inspection_tool>
|
18 |
+
</profile>
|
19 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/EasyMachineLearningDemo.iml" filepath="$PROJECT_DIR$/.idea/EasyMachineLearningDemo.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title: EasyMachineLearning
|
3 |
emoji: 🔥
|
4 |
colorFrom: red
|
5 |
colorTo: red
|
|
|
1 |
---
|
2 |
+
title: EasyMachineLearning test
|
3 |
emoji: 🔥
|
4 |
colorFrom: red
|
5 |
colorTo: red
|
analysis/shap_model.py
CHANGED
@@ -3,15 +3,15 @@ import matplotlib.pyplot as plt
|
|
3 |
import shap
|
4 |
|
5 |
|
6 |
-
def shap_calculate(model, x, feature_names):
|
7 |
explainer = shap.Explainer(model.predict, x)
|
8 |
shap_values = explainer(x)
|
9 |
|
10 |
shap.summary_plot(shap_values, x, feature_names=feature_names, show=False)
|
11 |
|
12 |
-
|
13 |
|
14 |
-
|
15 |
|
16 |
|
17 |
|
|
|
3 |
import shap
|
4 |
|
5 |
|
6 |
+
def shap_calculate(model, x, feature_names, paint_object):
|
7 |
explainer = shap.Explainer(model.predict, x)
|
8 |
shap_values = explainer(x)
|
9 |
|
10 |
shap.summary_plot(shap_values, x, feature_names=feature_names, show=False)
|
11 |
|
12 |
+
plt.title(paint_object.get_name())
|
13 |
|
14 |
+
return plt, paint_object
|
15 |
|
16 |
|
17 |
|
app.py
CHANGED
@@ -11,8 +11,10 @@ from analysis.shap_model import shap_calculate
|
|
11 |
from static.process import *
|
12 |
from analysis.linear_model import *
|
13 |
from visualization.draw_learning_curve_total import draw_learning_curve_total
|
|
|
14 |
|
15 |
import warnings
|
|
|
16 |
warnings.filterwarnings("ignore")
|
17 |
|
18 |
|
@@ -68,18 +70,34 @@ class Container:
|
|
68 |
self.model = model
|
69 |
|
70 |
|
|
|
|
|
|
|
|
|
71 |
class FilePath:
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
shap_beeswarm_plot = "shap_beeswarm_plot"
|
74 |
|
75 |
|
76 |
class MN: # ModelName
|
77 |
classification = "classification"
|
78 |
regression = "regression"
|
|
|
79 |
linear_regression = "linear_regression"
|
80 |
polynomial_regression = "polynomial_regression"
|
81 |
logistic_regression = "logistic_regression"
|
82 |
|
|
|
|
|
|
|
|
|
|
|
83 |
|
84 |
class LN: # LabelName
|
85 |
choose_dataset_radio = "选择所需数据源 [必选]"
|
@@ -104,19 +122,54 @@ class LN: # LabelName
|
|
104 |
linear_regression_model_radio = "选择线性回归的模型"
|
105 |
model_optimize_radio = "选择超参数优化方法"
|
106 |
model_train_button = "训练"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
learning_curve_checkboxgroup = "选择所需绘制学习曲线的模型"
|
108 |
learning_curve_train_button = "绘制训练集学习曲线"
|
109 |
learning_curve_validation_button = "绘制验证集学习曲线"
|
110 |
-
learning_curve_train_plot = "绘制训练集学习曲线"
|
111 |
-
learning_curve_validation_plot = "绘制验证集学习曲线"
|
112 |
shap_beeswarm_radio = "选择所需绘制蜂群特征图的模型"
|
113 |
shap_beeswarm_button = "绘制蜂群特征图"
|
|
|
|
|
|
|
114 |
shap_beeswarm_plot = "蜂群特征图"
|
115 |
-
select_as_model_radio = "选择所需训练的模型"
|
116 |
|
117 |
|
118 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
gr_dict = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
choose_custom_dataset_file,
|
121 |
display_dataset_dataframe,
|
122 |
display_total_col_num_text,
|
@@ -141,26 +194,35 @@ def get_outputs():
|
|
141 |
model_optimize_radio,
|
142 |
model_train_button,
|
143 |
model_train_checkbox,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
learning_curve_checkboxgroup,
|
145 |
learning_curve_train_button,
|
146 |
learning_curve_validation_button,
|
147 |
-
learning_curve_train_plot,
|
148 |
-
learning_curve_validation_plot,
|
149 |
shap_beeswarm_radio,
|
150 |
shap_beeswarm_button,
|
151 |
-
shap_beeswarm_plot,
|
152 |
-
shap_beeswarm_plot_file,
|
153 |
-
select_as_model_radio,
|
154 |
-
choose_assign_radio,
|
155 |
}
|
156 |
|
157 |
-
|
|
|
|
|
|
|
|
|
158 |
|
159 |
|
160 |
def get_return(is_visible, extra_gr_dict: dict = None):
|
161 |
if is_visible:
|
162 |
gr_dict = {
|
163 |
display_dataset_dataframe: gr.Dataframe(add_index_into_df(Dataset.data), type="pandas", visible=True),
|
|
|
164 |
display_total_col_num_text: gr.Textbox(str(Dataset.get_total_col_num()), visible=True, label=LN.display_total_col_num_text),
|
165 |
display_total_row_num_text: gr.Textbox(str(Dataset.get_total_row_num()), visible=True, label=LN.display_total_row_num_text),
|
166 |
display_na_list_text: gr.Textbox(Dataset.get_na_list_str(), visible=True, label=LN.display_na_list_text),
|
@@ -188,14 +250,25 @@ def get_return(is_visible, extra_gr_dict: dict = None):
|
|
188 |
|
189 |
model_train_button: gr.Button(LN.model_train_button, visible=Dataset.check_before_train()),
|
190 |
model_train_checkbox: gr.Checkbox(Dataset.get_model_container_status(), visible=Dataset.check_select_model(), label=Dataset.get_model_label()),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
learning_curve_checkboxgroup: gr.Checkboxgroup(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.learning_curve_checkboxgroup),
|
192 |
learning_curve_train_button: gr.Button(LN.learning_curve_train_button, visible=Dataset.check_before_train()),
|
193 |
learning_curve_validation_button: gr.Button(LN.learning_curve_validation_button, visible=Dataset.check_before_train()),
|
194 |
shap_beeswarm_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.shap_beeswarm_radio),
|
195 |
shap_beeswarm_button: gr.Button(LN.shap_beeswarm_button, visible=Dataset.check_before_train()),
|
196 |
-
shap_beeswarm_plot_file: gr.File(Dataset.after_get_shap_beeswarm_plot_file(), visible=Dataset.check_shap_beeswarm_plot_file()),
|
197 |
}
|
198 |
|
|
|
|
|
|
|
|
|
199 |
if extra_gr_dict:
|
200 |
gr_dict.update(extra_gr_dict)
|
201 |
|
@@ -204,6 +277,7 @@ def get_return(is_visible, extra_gr_dict: dict = None):
|
|
204 |
gr_dict = {
|
205 |
choose_custom_dataset_file: gr.File(None, visible=True),
|
206 |
display_dataset_dataframe: gr.Dataframe(visible=False),
|
|
|
207 |
display_total_col_num_text: gr.Textbox(visible=False),
|
208 |
display_total_row_num_text: gr.Textbox(visible=False),
|
209 |
display_na_list_text: gr.Textbox(visible=False),
|
@@ -225,19 +299,27 @@ def get_return(is_visible, extra_gr_dict: dict = None):
|
|
225 |
model_optimize_radio: gr.Radio(visible=False),
|
226 |
model_train_button: gr.Button(visible=False),
|
227 |
model_train_checkbox: gr.Checkbox(visible=False),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
228 |
learning_curve_checkboxgroup: gr.Checkboxgroup(visible=False),
|
229 |
learning_curve_train_button: gr.Button(visible=False),
|
230 |
learning_curve_validation_button: gr.Button(visible=False),
|
231 |
-
learning_curve_train_plot: gr.Plot(visible=False),
|
232 |
-
learning_curve_validation_plot: gr.Plot(visible=False),
|
233 |
shap_beeswarm_radio: gr.Radio(visible=False),
|
234 |
shap_beeswarm_button: gr.Button(visible=False),
|
235 |
-
shap_beeswarm_plot: gr.Plot(visible=False),
|
236 |
-
shap_beeswarm_plot_file: gr.File(visible=False),
|
237 |
-
select_as_model_radio: gr.Radio(visible=False),
|
238 |
-
choose_assign_radio: gr.Radio(visible=False),
|
239 |
}
|
240 |
|
|
|
|
|
|
|
|
|
241 |
return gr_dict
|
242 |
|
243 |
|
@@ -260,6 +342,8 @@ class Dataset:
|
|
260 |
MN.logistic_regression: Container(),
|
261 |
}
|
262 |
|
|
|
|
|
263 |
@classmethod
|
264 |
def get_dataset_list(cls):
|
265 |
return ["Iris Dataset", "Wine Dataset", "Breast Cancer Dataset", "自定义"]
|
@@ -309,6 +393,23 @@ class Dataset:
|
|
309 |
cls.file = ""
|
310 |
cls.data = pd.DataFrame()
|
311 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
312 |
@classmethod
|
313 |
def del_col(cls, col_list: list):
|
314 |
for col in col_list:
|
@@ -431,7 +532,8 @@ class Dataset:
|
|
431 |
|
432 |
for col in cls.data.columns.values:
|
433 |
if cls.data[col].dtype.name in ["int64", "float64"]:
|
434 |
-
if not np.array_equal(np.round(preprocessing.scale(cls.data[col]), decimals=2),
|
|
|
435 |
not_standardized_data_list.append(col)
|
436 |
|
437 |
return not_standardized_data_list
|
@@ -443,7 +545,8 @@ class Dataset:
|
|
443 |
|
444 |
for i, col in enumerate(cls.data.columns.values):
|
445 |
if i == 0:
|
446 |
-
if not (all(isinstance(x, str) for x in cls.data.iloc[:, 0]) or all(
|
|
|
447 |
return False
|
448 |
else:
|
449 |
if cls.data[col].dtype.name != "float64":
|
@@ -541,43 +644,98 @@ class Dataset:
|
|
541 |
return trained_model_list
|
542 |
|
543 |
@classmethod
|
544 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
learning_curve_dict = {}
|
546 |
|
547 |
for model_name in model_list:
|
548 |
model_name = cls.get_model_name_mapping_reverse()[model_name]
|
549 |
learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
|
550 |
|
551 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
552 |
|
553 |
@classmethod
|
554 |
-
def draw_learning_curve_validation_plot(cls, model_list: list
|
555 |
learning_curve_dict = {}
|
556 |
|
557 |
for model_name in model_list:
|
558 |
model_name = cls.get_model_name_mapping_reverse()[model_name]
|
559 |
learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
|
560 |
|
561 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
|
563 |
@classmethod
|
564 |
-
def draw_shap_beeswarm_plot(cls, model_name
|
565 |
model_name = cls.get_model_name_mapping_reverse()[model_name]
|
566 |
container = cls.container_dict[model_name]
|
567 |
|
568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
|
570 |
@classmethod
|
571 |
-
def
|
572 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
573 |
|
574 |
@classmethod
|
575 |
-
def
|
576 |
-
return os.path.exists(cls.
|
577 |
|
578 |
@classmethod
|
579 |
-
def
|
580 |
-
return cls.
|
581 |
|
582 |
@classmethod
|
583 |
def get_model_list(cls):
|
@@ -614,13 +772,37 @@ class Dataset:
|
|
614 |
data_copy = cls.data
|
615 |
|
616 |
if cls.assign == MN.classification:
|
617 |
-
data_copy.iloc[0
|
618 |
else:
|
619 |
-
data_copy.iloc[0
|
620 |
|
621 |
cls.data = data_copy
|
622 |
cls.change_data_type_to_float()
|
623 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
|
625 |
def choose_assign(assign: str):
|
626 |
Dataset.choose_assign(assign)
|
@@ -634,24 +816,85 @@ def select_as_model(model_name: str):
|
|
634 |
return get_return(True)
|
635 |
|
636 |
|
637 |
-
|
638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
|
640 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
641 |
|
642 |
-
|
643 |
|
|
|
644 |
|
645 |
-
def draw_learning_curve_validation_plot(model_list: list):
|
646 |
-
cur_plt = Dataset.draw_learning_curve_validation_plot(model_list)
|
647 |
|
648 |
-
|
|
|
649 |
|
650 |
|
651 |
-
def
|
652 |
-
|
|
|
|
|
|
|
|
|
|
|
653 |
|
654 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
655 |
|
656 |
|
657 |
def train_model(optimize, linear_regression_model_type):
|
@@ -681,7 +924,9 @@ def change_data_type_to_float():
|
|
681 |
def encode_label(col_list: list):
|
682 |
Dataset.encode_label(col_list)
|
683 |
|
684 |
-
return get_return(True, {
|
|
|
|
|
685 |
|
686 |
|
687 |
def del_duplicate():
|
@@ -737,7 +982,6 @@ def choose_custom_dataset(file: str):
|
|
737 |
|
738 |
|
739 |
with gr.Blocks() as demo:
|
740 |
-
|
741 |
'''
|
742 |
组件
|
743 |
'''
|
@@ -752,6 +996,7 @@ with gr.Blocks() as demo:
|
|
752 |
# 显示数据表信息
|
753 |
with gr.Accordion("当前数据信息"):
|
754 |
display_dataset_dataframe = gr.Dataframe(visible=False)
|
|
|
755 |
with gr.Row():
|
756 |
display_total_col_num_text = gr.Textbox(visible=False)
|
757 |
display_total_row_num_text = gr.Textbox(visible=False)
|
@@ -794,17 +1039,43 @@ with gr.Blocks() as demo:
|
|
794 |
|
795 |
# 可视化
|
796 |
with gr.Accordion("数据可视化"):
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
808 |
|
809 |
'''
|
810 |
监听事件
|
@@ -840,9 +1111,23 @@ with gr.Blocks() as demo:
|
|
840 |
model_train_button.click(fn=train_model, inputs=[model_optimize_radio, linear_regression_model_radio], outputs=get_outputs())
|
841 |
|
842 |
# 可视化
|
843 |
-
learning_curve_train_button.click(fn=
|
844 |
-
learning_curve_validation_button.click(fn=
|
845 |
-
shap_beeswarm_button.click(fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
846 |
|
847 |
if __name__ == "__main__":
|
848 |
demo.launch()
|
|
|
11 |
from static.process import *
|
12 |
from analysis.linear_model import *
|
13 |
from visualization.draw_learning_curve_total import draw_learning_curve_total
|
14 |
+
from static.paint import *
|
15 |
|
16 |
import warnings
|
17 |
+
|
18 |
warnings.filterwarnings("ignore")
|
19 |
|
20 |
|
|
|
70 |
self.model = model
|
71 |
|
72 |
|
73 |
+
class StaticValue:
|
74 |
+
max_num = 10
|
75 |
+
|
76 |
+
|
77 |
class FilePath:
|
78 |
+
png_base = "./buffer/{}.png"
|
79 |
+
excel_base = "./buffer/{}.xlsx"
|
80 |
+
|
81 |
+
# [绘图]
|
82 |
+
display_dataset = "current_excel_data"
|
83 |
+
learning_curve_train_plot = "learning_curve_train_plot"
|
84 |
+
learning_curve_validation_plot = "learning_curve_validation_plot"
|
85 |
shap_beeswarm_plot = "shap_beeswarm_plot"
|
86 |
|
87 |
|
88 |
class MN: # ModelName
|
89 |
classification = "classification"
|
90 |
regression = "regression"
|
91 |
+
|
92 |
linear_regression = "linear_regression"
|
93 |
polynomial_regression = "polynomial_regression"
|
94 |
logistic_regression = "logistic_regression"
|
95 |
|
96 |
+
# [绘图]
|
97 |
+
learning_curve_train = "learning_curve_train"
|
98 |
+
learning_curve_validation = "learning_curve_validation"
|
99 |
+
shap_beeswarm = "shap_beeswarm"
|
100 |
+
|
101 |
|
102 |
class LN: # LabelName
|
103 |
choose_dataset_radio = "选择所需数据源 [必选]"
|
|
|
122 |
linear_regression_model_radio = "选择线性回归的模型"
|
123 |
model_optimize_radio = "选择超参数优化方法"
|
124 |
model_train_button = "训练"
|
125 |
+
select_as_model_radio = "选择所需训练的模型"
|
126 |
+
|
127 |
+
title_name_textbox = "标题"
|
128 |
+
x_label_textbox = "x 轴名称"
|
129 |
+
y_label_textbox = "y 轴名称"
|
130 |
+
colors = ["颜色 {}".format(i) for i in range(StaticValue.max_num)]
|
131 |
+
labels = ["图例 {}".format(i) for i in range(StaticValue.max_num)]
|
132 |
+
|
133 |
+
# [绘图]
|
134 |
learning_curve_checkboxgroup = "选择所需绘制学习曲线的模型"
|
135 |
learning_curve_train_button = "绘制训练集学习曲线"
|
136 |
learning_curve_validation_button = "绘制验证集学习曲线"
|
|
|
|
|
137 |
shap_beeswarm_radio = "选择所需绘制蜂群特征图的模型"
|
138 |
shap_beeswarm_button = "绘制蜂群特征图"
|
139 |
+
|
140 |
+
learning_curve_train_plot = "训练集学习曲线"
|
141 |
+
learning_curve_validation_plot = "验证集学习曲线"
|
142 |
shap_beeswarm_plot = "蜂群特征图"
|
|
|
143 |
|
144 |
|
145 |
+
def get_return_extra(is_visible, extra_gr_dict: dict = None):
|
146 |
+
if is_visible:
|
147 |
+
gr_dict = {
|
148 |
+
draw_file: gr.File(Dataset.after_get_file(), visible=Dataset.check_file()),
|
149 |
+
}
|
150 |
+
|
151 |
+
if extra_gr_dict:
|
152 |
+
gr_dict.update(extra_gr_dict)
|
153 |
+
|
154 |
+
return gr_dict
|
155 |
+
|
156 |
gr_dict = {
|
157 |
+
draw_plot: gr.Plot(visible=False),
|
158 |
+
draw_file: gr.File(visible=False),
|
159 |
+
}
|
160 |
+
|
161 |
+
gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
|
162 |
+
gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
|
163 |
+
gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
|
164 |
+
gr_dict.update({title_name_textbox: gr.Textbox(visible=False)})
|
165 |
+
gr_dict.update({x_label_textbox: gr.Textbox(visible=False)})
|
166 |
+
gr_dict.update({y_label_textbox: gr.Textbox(visible=False)})
|
167 |
+
|
168 |
+
return gr_dict
|
169 |
+
|
170 |
+
|
171 |
+
def get_outputs():
|
172 |
+
gr_set = {
|
173 |
choose_custom_dataset_file,
|
174 |
display_dataset_dataframe,
|
175 |
display_total_col_num_text,
|
|
|
194 |
model_optimize_radio,
|
195 |
model_train_button,
|
196 |
model_train_checkbox,
|
197 |
+
select_as_model_radio,
|
198 |
+
choose_assign_radio,
|
199 |
+
display_dataset,
|
200 |
+
draw_plot,
|
201 |
+
draw_file,
|
202 |
+
title_name_textbox,
|
203 |
+
x_label_textbox,
|
204 |
+
y_label_textbox,
|
205 |
+
|
206 |
+
# [绘图]
|
207 |
learning_curve_checkboxgroup,
|
208 |
learning_curve_train_button,
|
209 |
learning_curve_validation_button,
|
|
|
|
|
210 |
shap_beeswarm_radio,
|
211 |
shap_beeswarm_button,
|
|
|
|
|
|
|
|
|
212 |
}
|
213 |
|
214 |
+
gr_set.update(set(colorpickers))
|
215 |
+
gr_set.update(set(color_textboxs))
|
216 |
+
gr_set.update(set(legend_labels_textboxs))
|
217 |
+
|
218 |
+
return gr_set
|
219 |
|
220 |
|
221 |
def get_return(is_visible, extra_gr_dict: dict = None):
|
222 |
if is_visible:
|
223 |
gr_dict = {
|
224 |
display_dataset_dataframe: gr.Dataframe(add_index_into_df(Dataset.data), type="pandas", visible=True),
|
225 |
+
display_dataset: gr.File(Dataset.after_get_display_dataset_file(), visible=Dataset.check_display_dataset_file()),
|
226 |
display_total_col_num_text: gr.Textbox(str(Dataset.get_total_col_num()), visible=True, label=LN.display_total_col_num_text),
|
227 |
display_total_row_num_text: gr.Textbox(str(Dataset.get_total_row_num()), visible=True, label=LN.display_total_row_num_text),
|
228 |
display_na_list_text: gr.Textbox(Dataset.get_na_list_str(), visible=True, label=LN.display_na_list_text),
|
|
|
250 |
|
251 |
model_train_button: gr.Button(LN.model_train_button, visible=Dataset.check_before_train()),
|
252 |
model_train_checkbox: gr.Checkbox(Dataset.get_model_container_status(), visible=Dataset.check_select_model(), label=Dataset.get_model_label()),
|
253 |
+
|
254 |
+
draw_plot: gr.Plot(visible=False),
|
255 |
+
draw_file: gr.File(visible=False),
|
256 |
+
title_name_textbox: gr.Textbox(visible=False),
|
257 |
+
x_label_textbox: gr.Textbox(visible=False),
|
258 |
+
y_label_textbox: gr.Textbox(visible=False),
|
259 |
+
|
260 |
+
# [绘图]
|
261 |
learning_curve_checkboxgroup: gr.Checkboxgroup(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.learning_curve_checkboxgroup),
|
262 |
learning_curve_train_button: gr.Button(LN.learning_curve_train_button, visible=Dataset.check_before_train()),
|
263 |
learning_curve_validation_button: gr.Button(LN.learning_curve_validation_button, visible=Dataset.check_before_train()),
|
264 |
shap_beeswarm_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.shap_beeswarm_radio),
|
265 |
shap_beeswarm_button: gr.Button(LN.shap_beeswarm_button, visible=Dataset.check_before_train()),
|
|
|
266 |
}
|
267 |
|
268 |
+
gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
|
269 |
+
gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
|
270 |
+
gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
|
271 |
+
|
272 |
if extra_gr_dict:
|
273 |
gr_dict.update(extra_gr_dict)
|
274 |
|
|
|
277 |
gr_dict = {
|
278 |
choose_custom_dataset_file: gr.File(None, visible=True),
|
279 |
display_dataset_dataframe: gr.Dataframe(visible=False),
|
280 |
+
display_dataset: gr.File(visible=False),
|
281 |
display_total_col_num_text: gr.Textbox(visible=False),
|
282 |
display_total_row_num_text: gr.Textbox(visible=False),
|
283 |
display_na_list_text: gr.Textbox(visible=False),
|
|
|
299 |
model_optimize_radio: gr.Radio(visible=False),
|
300 |
model_train_button: gr.Button(visible=False),
|
301 |
model_train_checkbox: gr.Checkbox(visible=False),
|
302 |
+
select_as_model_radio: gr.Radio(visible=False),
|
303 |
+
choose_assign_radio: gr.Radio(visible=False),
|
304 |
+
|
305 |
+
draw_plot: gr.Plot(visible=False),
|
306 |
+
draw_file: gr.File(visible=False),
|
307 |
+
title_name_textbox: gr.Textbox(visible=False),
|
308 |
+
x_label_textbox: gr.Textbox(visible=False),
|
309 |
+
y_label_textbox: gr.Textbox(visible=False),
|
310 |
+
|
311 |
+
# [绘图]
|
312 |
learning_curve_checkboxgroup: gr.Checkboxgroup(visible=False),
|
313 |
learning_curve_train_button: gr.Button(visible=False),
|
314 |
learning_curve_validation_button: gr.Button(visible=False),
|
|
|
|
|
315 |
shap_beeswarm_radio: gr.Radio(visible=False),
|
316 |
shap_beeswarm_button: gr.Button(visible=False),
|
|
|
|
|
|
|
|
|
317 |
}
|
318 |
|
319 |
+
gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
|
320 |
+
gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
|
321 |
+
gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
|
322 |
+
|
323 |
return gr_dict
|
324 |
|
325 |
|
|
|
342 |
MN.logistic_regression: Container(),
|
343 |
}
|
344 |
|
345 |
+
visualize = ""
|
346 |
+
|
347 |
@classmethod
|
348 |
def get_dataset_list(cls):
|
349 |
return ["Iris Dataset", "Wine Dataset", "Breast Cancer Dataset", "自定义"]
|
|
|
393 |
cls.file = ""
|
394 |
cls.data = pd.DataFrame()
|
395 |
|
396 |
+
@classmethod
|
397 |
+
def get_display_dataset_file(cls):
|
398 |
+
file_path = FilePath.excel_base.format(FilePath.display_dataset)
|
399 |
+
|
400 |
+
return file_path
|
401 |
+
|
402 |
+
@classmethod
|
403 |
+
def check_display_dataset_file(cls):
|
404 |
+
return os.path.exists(cls.get_display_dataset_file())
|
405 |
+
|
406 |
+
@classmethod
|
407 |
+
def after_get_display_dataset_file(cls):
|
408 |
+
if not cls.data.empty:
|
409 |
+
cls.data.to_excel(cls.get_display_dataset_file(), index=False)
|
410 |
+
|
411 |
+
return cls.get_display_dataset_file() if cls.check_display_dataset_file() else None
|
412 |
+
|
413 |
@classmethod
|
414 |
def del_col(cls, col_list: list):
|
415 |
for col in col_list:
|
|
|
532 |
|
533 |
for col in cls.data.columns.values:
|
534 |
if cls.data[col].dtype.name in ["int64", "float64"]:
|
535 |
+
if not np.array_equal(np.round(preprocessing.scale(cls.data[col]), decimals=2),
|
536 |
+
np.round(cls.data[col].values.round(2), decimals=2)):
|
537 |
not_standardized_data_list.append(col)
|
538 |
|
539 |
return not_standardized_data_list
|
|
|
545 |
|
546 |
for i, col in enumerate(cls.data.columns.values):
|
547 |
if i == 0:
|
548 |
+
if not (all(isinstance(x, str) for x in cls.data.iloc[:, 0]) or all(
|
549 |
+
isinstance(x, float) for x in cls.data.iloc[:, 0])):
|
550 |
return False
|
551 |
else:
|
552 |
if cls.data[col].dtype.name != "float64":
|
|
|
644 |
return trained_model_list
|
645 |
|
646 |
@classmethod
|
647 |
+
def draw_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
|
648 |
+
# [绘图]
|
649 |
+
if cls.visualize == MN.learning_curve_train:
|
650 |
+
return cls.draw_learning_curve_train_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
|
651 |
+
elif cls.visualize == MN.learning_curve_validation:
|
652 |
+
return cls.draw_learning_curve_validation_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
|
653 |
+
elif cls.visualize == MN.shap_beeswarm:
|
654 |
+
return cls.draw_shap_beeswarm_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
|
655 |
+
|
656 |
+
@classmethod
|
657 |
+
def draw_learning_curve_train_plot(cls, model_list, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
|
658 |
learning_curve_dict = {}
|
659 |
|
660 |
for model_name in model_list:
|
661 |
model_name = cls.get_model_name_mapping_reverse()[model_name]
|
662 |
learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
|
663 |
|
664 |
+
color_cur_list = Config.COLORS if is_default else color_list
|
665 |
+
label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
|
666 |
+
x_cur_label = "Train Sizes" if is_default else x_label
|
667 |
+
y_cur_label = "Accuracy" if is_default else y_label
|
668 |
+
cur_name = "" if is_default else name
|
669 |
+
|
670 |
+
paint_object = PaintObject()
|
671 |
+
paint_object.set_color_cur_list(color_cur_list)
|
672 |
+
paint_object.set_label_cur_list(label_cur_list)
|
673 |
+
paint_object.set_x_cur_label(x_cur_label)
|
674 |
+
paint_object.set_y_cur_label(y_cur_label)
|
675 |
+
paint_object.set_name(cur_name)
|
676 |
+
|
677 |
+
return draw_learning_curve_total(learning_curve_dict, "train", paint_object)
|
678 |
|
679 |
@classmethod
|
680 |
+
def draw_learning_curve_validation_plot(cls, model_list, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
|
681 |
learning_curve_dict = {}
|
682 |
|
683 |
for model_name in model_list:
|
684 |
model_name = cls.get_model_name_mapping_reverse()[model_name]
|
685 |
learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
|
686 |
|
687 |
+
color_cur_list = Config.COLORS if is_default else color_list
|
688 |
+
label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
|
689 |
+
x_cur_label = "Train Sizes" if is_default else x_label
|
690 |
+
y_cur_label = "Accuracy" if is_default else y_label
|
691 |
+
cur_name = "" if is_default else name
|
692 |
+
|
693 |
+
paint_object = PaintObject()
|
694 |
+
paint_object.set_color_cur_list(color_cur_list)
|
695 |
+
paint_object.set_label_cur_list(label_cur_list)
|
696 |
+
paint_object.set_x_cur_label(x_cur_label)
|
697 |
+
paint_object.set_y_cur_label(y_cur_label)
|
698 |
+
paint_object.set_name(cur_name)
|
699 |
+
|
700 |
+
return draw_learning_curve_total(learning_curve_dict, "validation", paint_object)
|
701 |
|
702 |
@classmethod
|
703 |
+
def draw_shap_beeswarm_plot(cls, model_name, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
|
704 |
model_name = cls.get_model_name_mapping_reverse()[model_name]
|
705 |
container = cls.container_dict[model_name]
|
706 |
|
707 |
+
# color_cur_list = Config.COLORS if is_default else color_list
|
708 |
+
# label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
|
709 |
+
# x_cur_label = "Train Sizes" if is_default else x_label
|
710 |
+
# y_cur_label = "Accuracy" if is_default else y_label
|
711 |
+
cur_name = "" if is_default else name
|
712 |
+
|
713 |
+
paint_object = PaintObject()
|
714 |
+
# paint_object.set_color_cur_list(color_cur_list)
|
715 |
+
# paint_object.set_label_cur_list(label_cur_list)
|
716 |
+
# paint_object.set_x_cur_label(x_cur_label)
|
717 |
+
# paint_object.set_y_cur_label(y_cur_label)
|
718 |
+
paint_object.set_name(cur_name)
|
719 |
+
|
720 |
+
return shap_calculate(container.get_model(), container.x_train, cls.data.columns.values, paint_object)
|
721 |
|
722 |
@classmethod
|
723 |
+
def get_file(cls):
|
724 |
+
# [绘图]
|
725 |
+
if cls.visualize == MN.learning_curve_train:
|
726 |
+
return FilePath.png_base.format(FilePath.learning_curve_train_plot)
|
727 |
+
elif cls.visualize == MN.learning_curve_validation:
|
728 |
+
return FilePath.png_base.format(FilePath.learning_curve_validation_plot)
|
729 |
+
elif cls.visualize == MN.shap_beeswarm:
|
730 |
+
return FilePath.png_base.format(FilePath.shap_beeswarm_plot)
|
731 |
|
732 |
@classmethod
|
733 |
+
def check_file(cls):
|
734 |
+
return os.path.exists(cls.get_file())
|
735 |
|
736 |
@classmethod
|
737 |
+
def after_get_file(cls):
|
738 |
+
return cls.get_file() if cls.check_file() else None
|
739 |
|
740 |
@classmethod
|
741 |
def get_model_list(cls):
|
|
|
772 |
data_copy = cls.data
|
773 |
|
774 |
if cls.assign == MN.classification:
|
775 |
+
data_copy.iloc[:, 0] = data_copy.iloc[:, 0].astype(str)
|
776 |
else:
|
777 |
+
data_copy.iloc[:, 0] = data_copy.iloc[:, 0].astype(float)
|
778 |
|
779 |
cls.data = data_copy
|
780 |
cls.change_data_type_to_float()
|
781 |
|
782 |
+
@classmethod
|
783 |
+
def colorpickers_change(cls, paint_object):
|
784 |
+
cur_num = paint_object.get_color_cur_num()
|
785 |
+
|
786 |
+
true_list = [gr.ColorPicker(paint_object.get_color_cur_list()[i], visible=True, label=LN.colors[i]) for i in range(cur_num)]
|
787 |
+
|
788 |
+
return true_list + [gr.ColorPicker(visible=False)] * (StaticValue.max_num - cur_num)
|
789 |
+
|
790 |
+
@classmethod
|
791 |
+
def color_textboxs_change(cls, paint_object):
|
792 |
+
cur_num = paint_object.get_color_cur_num()
|
793 |
+
|
794 |
+
true_list = [gr.Textbox(paint_object.get_color_cur_list()[i], visible=True, show_label=False) for i in range(cur_num)]
|
795 |
+
|
796 |
+
return true_list + [gr.Textbox(visible=False)] * (StaticValue.max_num - cur_num)
|
797 |
+
|
798 |
+
@classmethod
|
799 |
+
def labels_change(cls, paint_object):
|
800 |
+
cur_num = paint_object.get_label_cur_num()
|
801 |
+
|
802 |
+
true_list = [gr.Textbox(paint_object.get_label_cur_list()[i], visible=True, label=LN.labels[i]) for i in range(cur_num)]
|
803 |
+
|
804 |
+
return true_list + [gr.Textbox(visible=False)] * (StaticValue.max_num - cur_num)
|
805 |
+
|
806 |
|
807 |
def choose_assign(assign: str):
|
808 |
Dataset.choose_assign(assign)
|
|
|
816 |
return get_return(True)
|
817 |
|
818 |
|
819 |
+
# [Plotting]
|
820 |
+
def shap_beeswarm_first_draw_plot(*inputs):
    """Mark the SHAP beeswarm plot as the active visualisation and draw it.

    The positional arguments are forwarded to ``first_draw_plot`` as a
    single tuple.
    """
    Dataset.visualize = MN.shap_beeswarm
    outputs = first_draw_plot(inputs)
    return outputs
|
823 |
+
|
824 |
+
|
825 |
+
def learning_curve_validation_first_draw_plot(*inputs):
    """Mark the validation learning curve as the active visualisation and draw it.

    The positional arguments are forwarded to ``first_draw_plot`` as a
    single tuple.
    """
    Dataset.visualize = MN.learning_curve_validation
    outputs = first_draw_plot(inputs)
    return outputs
|
828 |
+
|
829 |
+
|
830 |
+
def learning_curve_train_first_draw_plot(*inputs):
    """Mark the training learning curve as the active visualisation and draw it.

    The positional arguments are forwarded to ``first_draw_plot`` as a
    single tuple.
    """
    Dataset.visualize = MN.learning_curve_train
    outputs = first_draw_plot(inputs)
    return outputs
|
833 |
+
|
834 |
|
835 |
+
def first_draw_plot(inputs):
    """Draw the active plot with no user-supplied styling.

    ``inputs[0]`` is passed to ``Dataset.draw_plot`` as the model selection.
    The colour list, label list, title and both axis labels are passed empty.
    The trailing ``True`` is presumably the "first draw" flag; confirm against
    ``Dataset.draw_plot``.
    """
    cur_plt, paint_object = Dataset.draw_plot(inputs[0], [], [], "", "", "", True)

    return first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object)
|
846 |
|
|
|
|
|
847 |
|
848 |
+
def out_non_first_draw_plot(*inputs):
    """Variadic adapter that hands its arguments to ``non_first_draw_plot`` as one tuple."""
    redrawn = non_first_draw_plot(inputs)
    return redrawn
|
850 |
|
851 |
|
852 |
+
def non_first_draw_plot(inputs):
    """Redraw the active plot using the styling values in ``inputs``.

    Layout of ``inputs`` as read here:
      * ``inputs[0]``, ``inputs[1]``, ``inputs[2]``: title, x label, y label
      * next ``StaticValue.max_num`` items: colours
      * next ``StaticValue.max_num`` items: legend labels
      * remaining items: model selections, indexed from ``extras_begin``
    """
    name = inputs[0]
    x_label = inputs[1]
    y_label = inputs[2]

    slot_count = StaticValue.max_num
    colors_begin = 3
    labels_begin = colors_begin + slot_count
    extras_begin = labels_begin + slot_count

    color_list = list(inputs[colors_begin:labels_begin])
    label_list = list(inputs[labels_begin:extras_begin])

    # Plotting: pick the model selection that belongs to the active plot.
    active_plot = Dataset.visualize
    if active_plot == MN.learning_curve_train or active_plot == MN.learning_curve_validation:
        # Both learning-curve plots read the same entry.
        select_model = inputs[extras_begin]
    elif active_plot == MN.shap_beeswarm:
        select_model = inputs[extras_begin + 1]
    else:
        # Fallback keeps a one-element slice rather than a single value.
        select_model = inputs[extras_begin:extras_begin + 1]

    cur_plt, paint_object = Dataset.draw_plot(select_model, color_list, label_list, name, x_label, y_label, False)

    return first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object)
|
874 |
+
|
875 |
+
|
876 |
+
def first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object):
    """Save the drawn figure and assemble the component updates for the UI.

    The figure is written to the PNG path of the active plot and shown in
    ``draw_plot``. The colour pickers, colour text boxes, legend-label boxes
    and the title/axis text boxes are then refreshed from ``paint_object``.
    The updates are returned through ``get_return_extra``.
    """
    # Plotting: (plot kind, PNG file stem, plot component label) per supported plot.
    plot_specs = (
        (MN.learning_curve_train, FilePath.learning_curve_train_plot, LN.learning_curve_train_plot),
        (MN.learning_curve_validation, FilePath.learning_curve_validation_plot, LN.learning_curve_validation_plot),
        (MN.shap_beeswarm, FilePath.shap_beeswarm_plot, LN.shap_beeswarm_plot),
    )

    extra_gr_dict = {}
    for plot_kind, file_stem, plot_label in plot_specs:
        if Dataset.visualize == plot_kind:
            cur_plt.savefig(FilePath.png_base.format(file_stem), dpi=300)
            extra_gr_dict[draw_plot] = gr.Plot(cur_plt, visible=True, label=plot_label)
            break

    # Pair every pre-built component with its refreshed counterpart.
    for components, refreshed in (
        (colorpickers, Dataset.colorpickers_change(paint_object)),
        (color_textboxs, Dataset.color_textboxs_change(paint_object)),
        (legend_labels_textboxs, Dataset.labels_change(paint_object)),
    ):
        extra_gr_dict.update(zip(components, refreshed))

    extra_gr_dict[title_name_textbox] = gr.Textbox(paint_object.get_name(), visible=True, label=LN.title_name_textbox)
    extra_gr_dict[x_label_textbox] = gr.Textbox(paint_object.get_x_cur_label(), visible=True, label=LN.x_label_textbox)
    extra_gr_dict[y_label_textbox] = gr.Textbox(paint_object.get_y_cur_label(), visible=True, label=LN.y_label_textbox)

    return get_return_extra(True, extra_gr_dict)
|
898 |
|
899 |
|
900 |
def train_model(optimize, linear_regression_model_type):
|
|
|
924 |
def encode_label(col_list: list):
    """Label-encode the given columns and show the resulting mapping table.

    Delegates the encoding to ``Dataset.encode_label``. It then returns the
    standard outputs with ``display_encode_label_dataframe`` updated from
    ``Dataset.get_str2int_mappings_df()``.
    """
    Dataset.encode_label(col_list)

    mapping_table = gr.Dataframe(
        Dataset.get_str2int_mappings_df(),
        type="pandas",
        visible=True,
        label=LN.display_encode_label_dataframe,
    )
    return get_return(True, {display_encode_label_dataframe: mapping_table})
|
930 |
|
931 |
|
932 |
def del_duplicate():
|
|
|
982 |
|
983 |
|
984 |
with gr.Blocks() as demo:
|
|
|
985 |
'''
|
986 |
组件
|
987 |
'''
|
|
|
996 |
# 显示数据表信息
|
997 |
with gr.Accordion("当前数据信息"):
|
998 |
display_dataset_dataframe = gr.Dataframe(visible=False)
|
999 |
+
display_dataset = gr.File(visible=False)
|
1000 |
with gr.Row():
|
1001 |
display_total_col_num_text = gr.Textbox(visible=False)
|
1002 |
display_total_row_num_text = gr.Textbox(visible=False)
|
|
|
1039 |
|
1040 |
# 可视化
|
1041 |
with gr.Accordion("数据可视化"):
|
1042 |
+
with gr.Tab("学习曲线图"):
|
1043 |
+
learning_curve_checkboxgroup = gr.Checkboxgroup(visible=False)
|
1044 |
+
with gr.Row():
|
1045 |
+
learning_curve_train_button = gr.Button(visible=False)
|
1046 |
+
learning_curve_validation_button = gr.Button(visible=False)
|
1047 |
+
|
1048 |
+
with gr.Tab("蜂群特征图"):
|
1049 |
+
shap_beeswarm_radio = gr.Radio(visible=False)
|
1050 |
+
shap_beeswarm_button = gr.Button(visible=False)
|
1051 |
+
|
1052 |
+
legend_labels_textboxs = []
|
1053 |
+
with gr.Accordion("图例"):
|
1054 |
+
with gr.Row():
|
1055 |
+
for i in range(StaticValue.max_num):
|
1056 |
+
with gr.Row():
|
1057 |
+
label = gr.Textbox(visible=False)
|
1058 |
+
legend_labels_textboxs.append(label)
|
1059 |
+
|
1060 |
+
with gr.Accordion("坐标轴"):
|
1061 |
+
with gr.Row():
|
1062 |
+
title_name_textbox = gr.Textbox(visible=False)
|
1063 |
+
x_label_textbox = gr.Textbox(visible=False)
|
1064 |
+
y_label_textbox = gr.Textbox(visible=False)
|
1065 |
+
|
1066 |
+
colorpickers = []
|
1067 |
+
color_textboxs = []
|
1068 |
+
with gr.Accordion("颜色"):
|
1069 |
+
with gr.Row():
|
1070 |
+
for i in range(StaticValue.max_num):
|
1071 |
+
with gr.Row():
|
1072 |
+
colorpicker = gr.ColorPicker(visible=False)
|
1073 |
+
colorpickers.append(colorpicker)
|
1074 |
+
color_textbox = gr.Textbox(visible=False)
|
1075 |
+
color_textboxs.append(color_textbox)
|
1076 |
+
|
1077 |
+
draw_plot = gr.Plot(visible=False)
|
1078 |
+
draw_file = gr.File(visible=False)
|
1079 |
|
1080 |
'''
|
1081 |
监听事件
|
|
|
1111 |
model_train_button.click(fn=train_model, inputs=[model_optimize_radio, linear_regression_model_radio], outputs=get_outputs())
|
1112 |
|
1113 |
# 可视化
|
1114 |
+
learning_curve_train_button.click(fn=learning_curve_train_first_draw_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
|
1115 |
+
learning_curve_validation_button.click(fn=learning_curve_validation_first_draw_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
|
1116 |
+
shap_beeswarm_button.click(fn=shap_beeswarm_first_draw_plot, inputs=[shap_beeswarm_radio], outputs=get_outputs())
|
1117 |
+
|
1118 |
+
title_name_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
|
1119 |
+
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
|
1120 |
+
x_label_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
|
1121 |
+
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
|
1122 |
+
y_label_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
|
1123 |
+
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
|
1124 |
+
for i in range(StaticValue.max_num):
|
1125 |
+
colorpickers[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
|
1126 |
+
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
|
1127 |
+
color_textboxs[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + color_textboxs + legend_labels_textboxs
|
1128 |
+
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
|
1129 |
+
legend_labels_textboxs[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
|
1130 |
+
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
|
1131 |
|
1132 |
if __name__ == "__main__":
|
1133 |
demo.launch()
|
{diagram → buffer}/__init__.py
RENAMED
File without changes
|
static/config.py
CHANGED
@@ -12,6 +12,9 @@ class Config:
|
|
12 |
"#EF8B67",
|
13 |
"#F0C284"
|
14 |
]
|
|
|
|
|
|
|
15 |
COLORS_1 = [
|
16 |
"#91CCC0",
|
17 |
"#7FABD1",
|
|
|
12 |
"#EF8B67",
|
13 |
"#F0C284"
|
14 |
]
|
15 |
+
|
16 |
+
|
17 |
+
|
18 |
COLORS_1 = [
|
19 |
"#91CCC0",
|
20 |
"#7FABD1",
|
static/paint.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class PaintObject:
    """Mutable holder for the styling state of a single plot.

    Tracks the colours and legend labels in use (with their counts), the axis
    labels and the plot title. Access goes through explicit getter and setter
    methods.
    """

    def __init__(self) -> None:
        # Plot title and axis labels start out empty.
        self.name = ""
        self.x_cur_label = ""
        self.y_cur_label = ""
        # Colours and legend labels in use, with their counts.
        self.color_cur_list = []
        self.color_cur_num = 0
        self.label_cur_list = []
        self.label_cur_num = 0

    # --- colours ---------------------------------------------------------

    def get_color_cur_num(self):
        """Return the stored colour count."""
        return self.color_cur_num

    def set_color_cur_num(self, color_cur_num):
        """Store the colour count."""
        self.color_cur_num = color_cur_num

    def get_color_cur_list(self):
        """Return the stored colour list (the same object, not a copy)."""
        return self.color_cur_list

    def set_color_cur_list(self, color_cur_list):
        """Store the colour list (by reference)."""
        self.color_cur_list = color_cur_list

    # --- legend labels ---------------------------------------------------

    def get_label_cur_num(self):
        """Return the stored legend-label count."""
        return self.label_cur_num

    def set_label_cur_num(self, label_cur_num):
        """Store the legend-label count."""
        self.label_cur_num = label_cur_num

    def get_label_cur_list(self):
        """Return the stored legend-label list (the same object, not a copy)."""
        return self.label_cur_list

    def set_label_cur_list(self, label_cur_list):
        """Store the legend-label list (by reference)."""
        self.label_cur_list = label_cur_list

    # --- axes and title --------------------------------------------------

    def get_x_cur_label(self):
        """Return the stored x-axis label."""
        return self.x_cur_label

    def set_x_cur_label(self, x_cur_label):
        """Store the x-axis label."""
        self.x_cur_label = x_cur_label

    def get_y_cur_label(self):
        """Return the stored y-axis label."""
        return self.y_cur_label

    def set_y_cur_label(self, y_cur_label):
        """Store the y-axis label."""
        self.y_cur_label = y_cur_label

    def get_name(self):
        """Return the stored plot title."""
        return self.name

    def set_name(self, name):
        """Store the plot title."""
        self.name = name
|
static/process.py
CHANGED
@@ -196,7 +196,10 @@ def load_data(sort):
|
|
196 |
|
197 |
|
198 |
def load_custom_data(file):
|
199 |
-
|
|
|
|
|
|
|
200 |
|
201 |
|
202 |
def preprocess_raw_data_filtering(df):
|
|
|
196 |
|
197 |
|
198 |
def load_custom_data(file):
    """Load a user-supplied Excel or CSV file into a pandas DataFrame.

    The reader is chosen from the file name's extension, compared
    case-insensitively. Earlier versions looked for the substrings "xls" or
    "csv" anywhere in the path, so a CSV stored under a directory such as
    ``xls_exports/`` was handed to the Excel reader.

    Args:
        file: Path to the file (``str`` or ``os.PathLike``).

    Returns:
        The loaded ``DataFrame`` for ``.xlsx``/``.xls``/``.csv`` files, or
        ``None`` when the extension is not one of those (unchanged
        behaviour, now explicit).
    """
    import os  # local import: the module's import block is outside this block

    extension = os.path.splitext(os.fspath(file))[1].lower()

    if extension in (".xlsx", ".xls"):
        return pd.read_excel(file)
    if extension == ".csv":
        return pd.read_csv(file)

    # Unsupported extension: callers previously received None here.
    return None
|
203 |
|
204 |
|
205 |
def preprocess_raw_data_filtering(df):
|
visualization/draw_learning_curve_total.py
CHANGED
@@ -1,15 +1,15 @@
|
|
1 |
import numpy as np
|
2 |
from matplotlib import pyplot as plt
|
3 |
|
|
|
4 |
from static.config import Config
|
5 |
|
6 |
|
7 |
-
def draw_learning_curve_total(input_dict, type):
|
8 |
plt.figure(figsize=(10, 6), dpi=300)
|
9 |
|
10 |
if type == "train":
|
11 |
-
i
|
12 |
-
for label_name, values in input_dict.items():
|
13 |
train_sizes = values[0]
|
14 |
train_scores_mean = values[1]
|
15 |
train_scores_std = values[2]
|
@@ -21,25 +21,19 @@ def draw_learning_curve_total(input_dict, type):
|
|
21 |
train_scores_mean - train_scores_std,
|
22 |
train_scores_mean + train_scores_std,
|
23 |
alpha=0.1,
|
24 |
-
color=
|
25 |
)
|
26 |
|
27 |
plt.plot(
|
28 |
train_sizes,
|
29 |
train_scores_mean,
|
30 |
"o-",
|
31 |
-
color=
|
32 |
-
label=
|
33 |
)
|
34 |
|
35 |
-
i += 1
|
36 |
-
|
37 |
-
title = "Training Learning curve"
|
38 |
-
# plt.title(title)
|
39 |
-
|
40 |
else:
|
41 |
-
i
|
42 |
-
for label_name, values in input_dict.items():
|
43 |
train_sizes = values[0]
|
44 |
train_scores_mean = values[1]
|
45 |
train_scores_std = values[2]
|
@@ -51,26 +45,27 @@ def draw_learning_curve_total(input_dict, type):
|
|
51 |
test_scores_mean - test_scores_std,
|
52 |
test_scores_mean + test_scores_std,
|
53 |
alpha=0.1,
|
54 |
-
color=
|
55 |
)
|
56 |
plt.plot(
|
57 |
train_sizes,
|
58 |
test_scores_mean,
|
59 |
"o-",
|
60 |
-
color=
|
61 |
-
label=
|
62 |
)
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
title = "Cross-validation Learning curve"
|
67 |
-
# plt.title(title)
|
68 |
|
69 |
-
plt.xlabel(
|
70 |
-
plt.ylabel(
|
71 |
plt.legend()
|
72 |
|
73 |
# plt.savefig("./diagram/{}.png".format(title), dpi=300)
|
74 |
# plt.show()
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
|
|
|
1 |
import numpy as np
|
2 |
from matplotlib import pyplot as plt
|
3 |
|
4 |
+
from static.paint import PaintObject
|
5 |
from static.config import Config
|
6 |
|
7 |
|
8 |
+
def draw_learning_curve_total(input_dict, type, paint_object: PaintObject):
|
9 |
plt.figure(figsize=(10, 6), dpi=300)
|
10 |
|
11 |
if type == "train":
|
12 |
+
for i, values in enumerate(input_dict.values()):
|
|
|
13 |
train_sizes = values[0]
|
14 |
train_scores_mean = values[1]
|
15 |
train_scores_std = values[2]
|
|
|
21 |
train_scores_mean - train_scores_std,
|
22 |
train_scores_mean + train_scores_std,
|
23 |
alpha=0.1,
|
24 |
+
color=paint_object.get_color_cur_list()[i]
|
25 |
)
|
26 |
|
27 |
plt.plot(
|
28 |
train_sizes,
|
29 |
train_scores_mean,
|
30 |
"o-",
|
31 |
+
color=paint_object.get_color_cur_list()[i],
|
32 |
+
label=paint_object.get_label_cur_list()[i]
|
33 |
)
|
34 |
|
|
|
|
|
|
|
|
|
|
|
35 |
else:
|
36 |
+
for i, values in enumerate(input_dict.values()):
|
|
|
37 |
train_sizes = values[0]
|
38 |
train_scores_mean = values[1]
|
39 |
train_scores_std = values[2]
|
|
|
45 |
test_scores_mean - test_scores_std,
|
46 |
test_scores_mean + test_scores_std,
|
47 |
alpha=0.1,
|
48 |
+
color=paint_object.get_color_cur_list()[i]
|
49 |
)
|
50 |
plt.plot(
|
51 |
train_sizes,
|
52 |
test_scores_mean,
|
53 |
"o-",
|
54 |
+
color=paint_object.get_color_cur_list()[i],
|
55 |
+
label=paint_object.get_label_cur_list()[i]
|
56 |
)
|
57 |
|
58 |
+
plt.title(paint_object.get_name())
|
|
|
|
|
|
|
59 |
|
60 |
+
plt.xlabel(paint_object.get_x_cur_label())
|
61 |
+
plt.ylabel(paint_object.get_y_cur_label())
|
62 |
plt.legend()
|
63 |
|
64 |
# plt.savefig("./diagram/{}.png".format(title), dpi=300)
|
65 |
# plt.show()
|
66 |
+
|
67 |
+
paint_object.set_color_cur_num(len(input_dict.keys()))
|
68 |
+
paint_object.set_label_cur_num(len(input_dict.keys()))
|
69 |
+
|
70 |
+
return plt, paint_object
|
71 |
|