LLH commited on
Commit
11b81b9
·
1 Parent(s): 0136ac6

2024/02/16/14:00

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/EasyMachineLearningDemo.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
5
+ <option name="ignoredErrors">
6
+ <list>
7
+ <option value="E501" />
8
+ </list>
9
+ </option>
10
+ </inspection_tool>
11
+ <inspection_tool class="PyUnresolvedReferencesInspection" enabled="true" level="WARNING" enabled_by_default="true">
12
+ <option name="ignoredIdentifiers">
13
+ <list>
14
+ <option value="object.pop" />
15
+ </list>
16
+ </option>
17
+ </inspection_tool>
18
+ </profile>
19
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/EasyMachineLearningDemo.iml" filepath="$PROJECT_DIR$/.idea/EasyMachineLearningDemo.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: EasyMachineLearning v0.0
3
  emoji: 🔥
4
  colorFrom: red
5
  colorTo: red
 
1
  ---
2
+ title: EasyMachineLearning test
3
  emoji: 🔥
4
  colorFrom: red
5
  colorTo: red
analysis/shap_model.py CHANGED
@@ -3,15 +3,15 @@ import matplotlib.pyplot as plt
3
  import shap
4
 
5
 
6
- def shap_calculate(model, x, feature_names):
7
  explainer = shap.Explainer(model.predict, x)
8
  shap_values = explainer(x)
9
 
10
  shap.summary_plot(shap_values, x, feature_names=feature_names, show=False)
11
 
12
- return plt
13
 
14
- # title = "shap"
15
 
16
 
17
 
 
3
  import shap
4
 
5
 
6
+ def shap_calculate(model, x, feature_names, paint_object):
7
  explainer = shap.Explainer(model.predict, x)
8
  shap_values = explainer(x)
9
 
10
  shap.summary_plot(shap_values, x, feature_names=feature_names, show=False)
11
 
12
+ plt.title(paint_object.get_name())
13
 
14
+ return plt, paint_object
15
 
16
 
17
 
app.py CHANGED
@@ -11,8 +11,10 @@ from analysis.shap_model import shap_calculate
11
  from static.process import *
12
  from analysis.linear_model import *
13
  from visualization.draw_learning_curve_total import draw_learning_curve_total
 
14
 
15
  import warnings
 
16
  warnings.filterwarnings("ignore")
17
 
18
 
@@ -68,18 +70,34 @@ class Container:
68
  self.model = model
69
 
70
 
 
 
 
 
71
  class FilePath:
72
- base = "./diagram/{}.png"
 
 
 
 
 
 
73
  shap_beeswarm_plot = "shap_beeswarm_plot"
74
 
75
 
76
  class MN: # ModelName
77
  classification = "classification"
78
  regression = "regression"
 
79
  linear_regression = "linear_regression"
80
  polynomial_regression = "polynomial_regression"
81
  logistic_regression = "logistic_regression"
82
 
 
 
 
 
 
83
 
84
  class LN: # LabelName
85
  choose_dataset_radio = "选择所需数据源 [必选]"
@@ -104,19 +122,54 @@ class LN: # LabelName
104
  linear_regression_model_radio = "选择线性回归的模型"
105
  model_optimize_radio = "选择超参数优化方法"
106
  model_train_button = "训练"
 
 
 
 
 
 
 
 
 
107
  learning_curve_checkboxgroup = "选择所需绘制学习曲线的模型"
108
  learning_curve_train_button = "绘制训练集学习曲线"
109
  learning_curve_validation_button = "绘制验证集学习曲线"
110
- learning_curve_train_plot = "绘制训练集学习曲线"
111
- learning_curve_validation_plot = "绘制验证集学习曲线"
112
  shap_beeswarm_radio = "选择所需绘制蜂群特征图的模型"
113
  shap_beeswarm_button = "绘制蜂群特征图"
 
 
 
114
  shap_beeswarm_plot = "蜂群特征图"
115
- select_as_model_radio = "选择所需训练的模型"
116
 
117
 
118
- def get_outputs():
 
 
 
 
 
 
 
 
 
 
119
  gr_dict = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  choose_custom_dataset_file,
121
  display_dataset_dataframe,
122
  display_total_col_num_text,
@@ -141,26 +194,35 @@ def get_outputs():
141
  model_optimize_radio,
142
  model_train_button,
143
  model_train_checkbox,
 
 
 
 
 
 
 
 
 
 
144
  learning_curve_checkboxgroup,
145
  learning_curve_train_button,
146
  learning_curve_validation_button,
147
- learning_curve_train_plot,
148
- learning_curve_validation_plot,
149
  shap_beeswarm_radio,
150
  shap_beeswarm_button,
151
- shap_beeswarm_plot,
152
- shap_beeswarm_plot_file,
153
- select_as_model_radio,
154
- choose_assign_radio,
155
  }
156
 
157
- return gr_dict
 
 
 
 
158
 
159
 
160
  def get_return(is_visible, extra_gr_dict: dict = None):
161
  if is_visible:
162
  gr_dict = {
163
  display_dataset_dataframe: gr.Dataframe(add_index_into_df(Dataset.data), type="pandas", visible=True),
 
164
  display_total_col_num_text: gr.Textbox(str(Dataset.get_total_col_num()), visible=True, label=LN.display_total_col_num_text),
165
  display_total_row_num_text: gr.Textbox(str(Dataset.get_total_row_num()), visible=True, label=LN.display_total_row_num_text),
166
  display_na_list_text: gr.Textbox(Dataset.get_na_list_str(), visible=True, label=LN.display_na_list_text),
@@ -188,14 +250,25 @@ def get_return(is_visible, extra_gr_dict: dict = None):
188
 
189
  model_train_button: gr.Button(LN.model_train_button, visible=Dataset.check_before_train()),
190
  model_train_checkbox: gr.Checkbox(Dataset.get_model_container_status(), visible=Dataset.check_select_model(), label=Dataset.get_model_label()),
 
 
 
 
 
 
 
 
191
  learning_curve_checkboxgroup: gr.Checkboxgroup(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.learning_curve_checkboxgroup),
192
  learning_curve_train_button: gr.Button(LN.learning_curve_train_button, visible=Dataset.check_before_train()),
193
  learning_curve_validation_button: gr.Button(LN.learning_curve_validation_button, visible=Dataset.check_before_train()),
194
  shap_beeswarm_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.shap_beeswarm_radio),
195
  shap_beeswarm_button: gr.Button(LN.shap_beeswarm_button, visible=Dataset.check_before_train()),
196
- shap_beeswarm_plot_file: gr.File(Dataset.after_get_shap_beeswarm_plot_file(), visible=Dataset.check_shap_beeswarm_plot_file()),
197
  }
198
 
 
 
 
 
199
  if extra_gr_dict:
200
  gr_dict.update(extra_gr_dict)
201
 
@@ -204,6 +277,7 @@ def get_return(is_visible, extra_gr_dict: dict = None):
204
  gr_dict = {
205
  choose_custom_dataset_file: gr.File(None, visible=True),
206
  display_dataset_dataframe: gr.Dataframe(visible=False),
 
207
  display_total_col_num_text: gr.Textbox(visible=False),
208
  display_total_row_num_text: gr.Textbox(visible=False),
209
  display_na_list_text: gr.Textbox(visible=False),
@@ -225,19 +299,27 @@ def get_return(is_visible, extra_gr_dict: dict = None):
225
  model_optimize_radio: gr.Radio(visible=False),
226
  model_train_button: gr.Button(visible=False),
227
  model_train_checkbox: gr.Checkbox(visible=False),
 
 
 
 
 
 
 
 
 
 
228
  learning_curve_checkboxgroup: gr.Checkboxgroup(visible=False),
229
  learning_curve_train_button: gr.Button(visible=False),
230
  learning_curve_validation_button: gr.Button(visible=False),
231
- learning_curve_train_plot: gr.Plot(visible=False),
232
- learning_curve_validation_plot: gr.Plot(visible=False),
233
  shap_beeswarm_radio: gr.Radio(visible=False),
234
  shap_beeswarm_button: gr.Button(visible=False),
235
- shap_beeswarm_plot: gr.Plot(visible=False),
236
- shap_beeswarm_plot_file: gr.File(visible=False),
237
- select_as_model_radio: gr.Radio(visible=False),
238
- choose_assign_radio: gr.Radio(visible=False),
239
  }
240
 
 
 
 
 
241
  return gr_dict
242
 
243
 
@@ -260,6 +342,8 @@ class Dataset:
260
  MN.logistic_regression: Container(),
261
  }
262
 
 
 
263
  @classmethod
264
  def get_dataset_list(cls):
265
  return ["Iris Dataset", "Wine Dataset", "Breast Cancer Dataset", "自定义"]
@@ -309,6 +393,23 @@ class Dataset:
309
  cls.file = ""
310
  cls.data = pd.DataFrame()
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  @classmethod
313
  def del_col(cls, col_list: list):
314
  for col in col_list:
@@ -431,7 +532,8 @@ class Dataset:
431
 
432
  for col in cls.data.columns.values:
433
  if cls.data[col].dtype.name in ["int64", "float64"]:
434
- if not np.array_equal(np.round(preprocessing.scale(cls.data[col]), decimals=2), np.round(cls.data[col].values.round(2), decimals=2)):
 
435
  not_standardized_data_list.append(col)
436
 
437
  return not_standardized_data_list
@@ -443,7 +545,8 @@ class Dataset:
443
 
444
  for i, col in enumerate(cls.data.columns.values):
445
  if i == 0:
446
- if not (all(isinstance(x, str) for x in cls.data.iloc[:, 0]) or all(isinstance(x, float) for x in cls.data.iloc[:, 0])):
 
447
  return False
448
  else:
449
  if cls.data[col].dtype.name != "float64":
@@ -541,43 +644,98 @@ class Dataset:
541
  return trained_model_list
542
 
543
  @classmethod
544
- def draw_learning_curve_train_plot(cls, model_list: list) -> plt.Figure:
 
 
 
 
 
 
 
 
 
 
545
  learning_curve_dict = {}
546
 
547
  for model_name in model_list:
548
  model_name = cls.get_model_name_mapping_reverse()[model_name]
549
  learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
550
 
551
- return draw_learning_curve_total(learning_curve_dict, "train")
 
 
 
 
 
 
 
 
 
 
 
 
 
552
 
553
  @classmethod
554
- def draw_learning_curve_validation_plot(cls, model_list: list) -> plt.Figure:
555
  learning_curve_dict = {}
556
 
557
  for model_name in model_list:
558
  model_name = cls.get_model_name_mapping_reverse()[model_name]
559
  learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
560
 
561
- return draw_learning_curve_total(learning_curve_dict, "validation")
 
 
 
 
 
 
 
 
 
 
 
 
 
562
 
563
  @classmethod
564
- def draw_shap_beeswarm_plot(cls, model_name) -> plt.Figure:
565
  model_name = cls.get_model_name_mapping_reverse()[model_name]
566
  container = cls.container_dict[model_name]
567
 
568
- return shap_calculate(container.get_model(), container.x_train, cls.data.columns.values)
 
 
 
 
 
 
 
 
 
 
 
 
 
569
 
570
  @classmethod
571
- def get_shap_beeswarm_plot_file(cls):
572
- return FilePath.base.format(FilePath.shap_beeswarm_plot)
 
 
 
 
 
 
573
 
574
  @classmethod
575
- def check_shap_beeswarm_plot_file(cls):
576
- return os.path.exists(cls.get_shap_beeswarm_plot_file())
577
 
578
  @classmethod
579
- def after_get_shap_beeswarm_plot_file(cls):
580
- return cls.get_shap_beeswarm_plot_file() if cls.check_shap_beeswarm_plot_file() else None
581
 
582
  @classmethod
583
  def get_model_list(cls):
@@ -614,13 +772,37 @@ class Dataset:
614
  data_copy = cls.data
615
 
616
  if cls.assign == MN.classification:
617
- data_copy.iloc[0, :] = data_copy.iloc[0, :].astype(str)
618
  else:
619
- data_copy.iloc[0, :] = data_copy.iloc[0, :].astype(float)
620
 
621
  cls.data = data_copy
622
  cls.change_data_type_to_float()
623
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624
 
625
  def choose_assign(assign: str):
626
  Dataset.choose_assign(assign)
@@ -634,24 +816,85 @@ def select_as_model(model_name: str):
634
  return get_return(True)
635
 
636
 
637
- def draw_shap_beeswarm_plot(model_name):
638
- cur_plt = Dataset.draw_shap_beeswarm_plot(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
639
 
640
- cur_plt.savefig(FilePath.base.format(FilePath.shap_beeswarm_plot), dpi=300)
 
 
 
 
 
 
641
 
642
- return get_return(True, {shap_beeswarm_plot: gr.Plot(cur_plt, visible=True, label=LN.shap_beeswarm_plot)})
643
 
 
644
 
645
- def draw_learning_curve_validation_plot(model_list: list):
646
- cur_plt = Dataset.draw_learning_curve_validation_plot(model_list)
647
 
648
- return get_return(True, {learning_curve_validation_plot: gr.Plot(cur_plt, visible=True, label=LN.learning_curve_validation_plot)})
 
649
 
650
 
651
- def draw_learning_curve_train_plot(model_list: list):
652
- cur_plt = Dataset.draw_learning_curve_train_plot(model_list)
 
 
 
 
 
653
 
654
- return get_return(True, {learning_curve_train_plot: gr.Plot(cur_plt, visible=True, label=LN.learning_curve_train_plot)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
655
 
656
 
657
  def train_model(optimize, linear_regression_model_type):
@@ -681,7 +924,9 @@ def change_data_type_to_float():
681
  def encode_label(col_list: list):
682
  Dataset.encode_label(col_list)
683
 
684
- return get_return(True, {display_encode_label_dataframe: gr.Dataframe(Dataset.get_str2int_mappings_df(), type="pandas", visible=True, label=LN.display_encode_label_dataframe)})
 
 
685
 
686
 
687
  def del_duplicate():
@@ -737,7 +982,6 @@ def choose_custom_dataset(file: str):
737
 
738
 
739
  with gr.Blocks() as demo:
740
-
741
  '''
742
  组件
743
  '''
@@ -752,6 +996,7 @@ with gr.Blocks() as demo:
752
  # 显示数据表信息
753
  with gr.Accordion("当前数据信息"):
754
  display_dataset_dataframe = gr.Dataframe(visible=False)
 
755
  with gr.Row():
756
  display_total_col_num_text = gr.Textbox(visible=False)
757
  display_total_row_num_text = gr.Textbox(visible=False)
@@ -794,17 +1039,43 @@ with gr.Blocks() as demo:
794
 
795
  # 可视化
796
  with gr.Accordion("数据可视化"):
797
- learning_curve_checkboxgroup = gr.Checkboxgroup(visible=False)
798
- with gr.Row():
799
- learning_curve_train_button = gr.Button(visible=False)
800
- learning_curve_validation_button = gr.Button(visible=False)
801
- learning_curve_train_plot = gr.Plot(visible=False)
802
- learning_curve_validation_plot = gr.Plot(visible=False)
803
- shap_beeswarm_radio = gr.Radio(visible=False)
804
- shap_beeswarm_button = gr.Button(visible=False)
805
- with gr.Group():
806
- shap_beeswarm_plot = gr.Plot(visible=False)
807
- shap_beeswarm_plot_file = gr.File(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
808
 
809
  '''
810
  监听事件
@@ -840,9 +1111,23 @@ with gr.Blocks() as demo:
840
  model_train_button.click(fn=train_model, inputs=[model_optimize_radio, linear_regression_model_radio], outputs=get_outputs())
841
 
842
  # 可视化
843
- learning_curve_train_button.click(fn=draw_learning_curve_train_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
844
- learning_curve_validation_button.click(fn=draw_learning_curve_validation_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
845
- shap_beeswarm_button.click(fn=draw_shap_beeswarm_plot, inputs=[shap_beeswarm_radio], outputs=get_outputs())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
846
 
847
  if __name__ == "__main__":
848
  demo.launch()
 
11
  from static.process import *
12
  from analysis.linear_model import *
13
  from visualization.draw_learning_curve_total import draw_learning_curve_total
14
+ from static.paint import *
15
 
16
  import warnings
17
+
18
  warnings.filterwarnings("ignore")
19
 
20
 
 
70
  self.model = model
71
 
72
 
73
+ class StaticValue:
74
+ max_num = 10
75
+
76
+
77
  class FilePath:
78
+ png_base = "./buffer/{}.png"
79
+ excel_base = "./buffer/{}.xlsx"
80
+
81
+ # [绘图]
82
+ display_dataset = "current_excel_data"
83
+ learning_curve_train_plot = "learning_curve_train_plot"
84
+ learning_curve_validation_plot = "learning_curve_validation_plot"
85
  shap_beeswarm_plot = "shap_beeswarm_plot"
86
 
87
 
88
  class MN: # ModelName
89
  classification = "classification"
90
  regression = "regression"
91
+
92
  linear_regression = "linear_regression"
93
  polynomial_regression = "polynomial_regression"
94
  logistic_regression = "logistic_regression"
95
 
96
+ # [绘图]
97
+ learning_curve_train = "learning_curve_train"
98
+ learning_curve_validation = "learning_curve_validation"
99
+ shap_beeswarm = "shap_beeswarm"
100
+
101
 
102
  class LN: # LabelName
103
  choose_dataset_radio = "选择所需数据源 [必选]"
 
122
  linear_regression_model_radio = "选择线性回归的模型"
123
  model_optimize_radio = "选择超参数优化方法"
124
  model_train_button = "训练"
125
+ select_as_model_radio = "选择所需训练的模型"
126
+
127
+ title_name_textbox = "标题"
128
+ x_label_textbox = "x 轴名称"
129
+ y_label_textbox = "y 轴名称"
130
+ colors = ["颜色 {}".format(i) for i in range(StaticValue.max_num)]
131
+ labels = ["图例 {}".format(i) for i in range(StaticValue.max_num)]
132
+
133
+ # [绘图]
134
  learning_curve_checkboxgroup = "选择所需绘制学习曲线的模型"
135
  learning_curve_train_button = "绘制训练集学习曲线"
136
  learning_curve_validation_button = "绘制验证集学习曲线"
 
 
137
  shap_beeswarm_radio = "选择所需绘制蜂群特征图的模型"
138
  shap_beeswarm_button = "绘制蜂群特征图"
139
+
140
+ learning_curve_train_plot = "训练集学习曲线"
141
+ learning_curve_validation_plot = "验证集学习曲线"
142
  shap_beeswarm_plot = "蜂群特征图"
 
143
 
144
 
145
+ def get_return_extra(is_visible, extra_gr_dict: dict = None):
146
+ if is_visible:
147
+ gr_dict = {
148
+ draw_file: gr.File(Dataset.after_get_file(), visible=Dataset.check_file()),
149
+ }
150
+
151
+ if extra_gr_dict:
152
+ gr_dict.update(extra_gr_dict)
153
+
154
+ return gr_dict
155
+
156
  gr_dict = {
157
+ draw_plot: gr.Plot(visible=False),
158
+ draw_file: gr.File(visible=False),
159
+ }
160
+
161
+ gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
162
+ gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
163
+ gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
164
+ gr_dict.update({title_name_textbox: gr.Textbox(visible=False)})
165
+ gr_dict.update({x_label_textbox: gr.Textbox(visible=False)})
166
+ gr_dict.update({y_label_textbox: gr.Textbox(visible=False)})
167
+
168
+ return gr_dict
169
+
170
+
171
+ def get_outputs():
172
+ gr_set = {
173
  choose_custom_dataset_file,
174
  display_dataset_dataframe,
175
  display_total_col_num_text,
 
194
  model_optimize_radio,
195
  model_train_button,
196
  model_train_checkbox,
197
+ select_as_model_radio,
198
+ choose_assign_radio,
199
+ display_dataset,
200
+ draw_plot,
201
+ draw_file,
202
+ title_name_textbox,
203
+ x_label_textbox,
204
+ y_label_textbox,
205
+
206
+ # [绘图]
207
  learning_curve_checkboxgroup,
208
  learning_curve_train_button,
209
  learning_curve_validation_button,
 
 
210
  shap_beeswarm_radio,
211
  shap_beeswarm_button,
 
 
 
 
212
  }
213
 
214
+ gr_set.update(set(colorpickers))
215
+ gr_set.update(set(color_textboxs))
216
+ gr_set.update(set(legend_labels_textboxs))
217
+
218
+ return gr_set
219
 
220
 
221
  def get_return(is_visible, extra_gr_dict: dict = None):
222
  if is_visible:
223
  gr_dict = {
224
  display_dataset_dataframe: gr.Dataframe(add_index_into_df(Dataset.data), type="pandas", visible=True),
225
+ display_dataset: gr.File(Dataset.after_get_display_dataset_file(), visible=Dataset.check_display_dataset_file()),
226
  display_total_col_num_text: gr.Textbox(str(Dataset.get_total_col_num()), visible=True, label=LN.display_total_col_num_text),
227
  display_total_row_num_text: gr.Textbox(str(Dataset.get_total_row_num()), visible=True, label=LN.display_total_row_num_text),
228
  display_na_list_text: gr.Textbox(Dataset.get_na_list_str(), visible=True, label=LN.display_na_list_text),
 
250
 
251
  model_train_button: gr.Button(LN.model_train_button, visible=Dataset.check_before_train()),
252
  model_train_checkbox: gr.Checkbox(Dataset.get_model_container_status(), visible=Dataset.check_select_model(), label=Dataset.get_model_label()),
253
+
254
+ draw_plot: gr.Plot(visible=False),
255
+ draw_file: gr.File(visible=False),
256
+ title_name_textbox: gr.Textbox(visible=False),
257
+ x_label_textbox: gr.Textbox(visible=False),
258
+ y_label_textbox: gr.Textbox(visible=False),
259
+
260
+ # [绘图]
261
  learning_curve_checkboxgroup: gr.Checkboxgroup(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.learning_curve_checkboxgroup),
262
  learning_curve_train_button: gr.Button(LN.learning_curve_train_button, visible=Dataset.check_before_train()),
263
  learning_curve_validation_button: gr.Button(LN.learning_curve_validation_button, visible=Dataset.check_before_train()),
264
  shap_beeswarm_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.shap_beeswarm_radio),
265
  shap_beeswarm_button: gr.Button(LN.shap_beeswarm_button, visible=Dataset.check_before_train()),
 
266
  }
267
 
268
+ gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
269
+ gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
270
+ gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
271
+
272
  if extra_gr_dict:
273
  gr_dict.update(extra_gr_dict)
274
 
 
277
  gr_dict = {
278
  choose_custom_dataset_file: gr.File(None, visible=True),
279
  display_dataset_dataframe: gr.Dataframe(visible=False),
280
+ display_dataset: gr.File(visible=False),
281
  display_total_col_num_text: gr.Textbox(visible=False),
282
  display_total_row_num_text: gr.Textbox(visible=False),
283
  display_na_list_text: gr.Textbox(visible=False),
 
299
  model_optimize_radio: gr.Radio(visible=False),
300
  model_train_button: gr.Button(visible=False),
301
  model_train_checkbox: gr.Checkbox(visible=False),
302
+ select_as_model_radio: gr.Radio(visible=False),
303
+ choose_assign_radio: gr.Radio(visible=False),
304
+
305
+ draw_plot: gr.Plot(visible=False),
306
+ draw_file: gr.File(visible=False),
307
+ title_name_textbox: gr.Textbox(visible=False),
308
+ x_label_textbox: gr.Textbox(visible=False),
309
+ y_label_textbox: gr.Textbox(visible=False),
310
+
311
+ # [绘图]
312
  learning_curve_checkboxgroup: gr.Checkboxgroup(visible=False),
313
  learning_curve_train_button: gr.Button(visible=False),
314
  learning_curve_validation_button: gr.Button(visible=False),
 
 
315
  shap_beeswarm_radio: gr.Radio(visible=False),
316
  shap_beeswarm_button: gr.Button(visible=False),
 
 
 
 
317
  }
318
 
319
+ gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
320
+ gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
321
+ gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
322
+
323
  return gr_dict
324
 
325
 
 
342
  MN.logistic_regression: Container(),
343
  }
344
 
345
+ visualize = ""
346
+
347
  @classmethod
348
  def get_dataset_list(cls):
349
  return ["Iris Dataset", "Wine Dataset", "Breast Cancer Dataset", "自定义"]
 
393
  cls.file = ""
394
  cls.data = pd.DataFrame()
395
 
396
+ @classmethod
397
+ def get_display_dataset_file(cls):
398
+ file_path = FilePath.excel_base.format(FilePath.display_dataset)
399
+
400
+ return file_path
401
+
402
+ @classmethod
403
+ def check_display_dataset_file(cls):
404
+ return os.path.exists(cls.get_display_dataset_file())
405
+
406
+ @classmethod
407
+ def after_get_display_dataset_file(cls):
408
+ if not cls.data.empty:
409
+ cls.data.to_excel(cls.get_display_dataset_file(), index=False)
410
+
411
+ return cls.get_display_dataset_file() if cls.check_display_dataset_file() else None
412
+
413
  @classmethod
414
  def del_col(cls, col_list: list):
415
  for col in col_list:
 
532
 
533
  for col in cls.data.columns.values:
534
  if cls.data[col].dtype.name in ["int64", "float64"]:
535
+ if not np.array_equal(np.round(preprocessing.scale(cls.data[col]), decimals=2),
536
+ np.round(cls.data[col].values.round(2), decimals=2)):
537
  not_standardized_data_list.append(col)
538
 
539
  return not_standardized_data_list
 
545
 
546
  for i, col in enumerate(cls.data.columns.values):
547
  if i == 0:
548
+ if not (all(isinstance(x, str) for x in cls.data.iloc[:, 0]) or all(
549
+ isinstance(x, float) for x in cls.data.iloc[:, 0])):
550
  return False
551
  else:
552
  if cls.data[col].dtype.name != "float64":
 
644
  return trained_model_list
645
 
646
  @classmethod
647
+ def draw_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
648
+ # [绘图]
649
+ if cls.visualize == MN.learning_curve_train:
650
+ return cls.draw_learning_curve_train_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
651
+ elif cls.visualize == MN.learning_curve_validation:
652
+ return cls.draw_learning_curve_validation_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
653
+ elif cls.visualize == MN.shap_beeswarm:
654
+ return cls.draw_shap_beeswarm_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
655
+
656
+ @classmethod
657
+ def draw_learning_curve_train_plot(cls, model_list, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
658
  learning_curve_dict = {}
659
 
660
  for model_name in model_list:
661
  model_name = cls.get_model_name_mapping_reverse()[model_name]
662
  learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
663
 
664
+ color_cur_list = Config.COLORS if is_default else color_list
665
+ label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
666
+ x_cur_label = "Train Sizes" if is_default else x_label
667
+ y_cur_label = "Accuracy" if is_default else y_label
668
+ cur_name = "" if is_default else name
669
+
670
+ paint_object = PaintObject()
671
+ paint_object.set_color_cur_list(color_cur_list)
672
+ paint_object.set_label_cur_list(label_cur_list)
673
+ paint_object.set_x_cur_label(x_cur_label)
674
+ paint_object.set_y_cur_label(y_cur_label)
675
+ paint_object.set_name(cur_name)
676
+
677
+ return draw_learning_curve_total(learning_curve_dict, "train", paint_object)
678
 
679
  @classmethod
680
+ def draw_learning_curve_validation_plot(cls, model_list, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
681
  learning_curve_dict = {}
682
 
683
  for model_name in model_list:
684
  model_name = cls.get_model_name_mapping_reverse()[model_name]
685
  learning_curve_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
686
 
687
+ color_cur_list = Config.COLORS if is_default else color_list
688
+ label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
689
+ x_cur_label = "Train Sizes" if is_default else x_label
690
+ y_cur_label = "Accuracy" if is_default else y_label
691
+ cur_name = "" if is_default else name
692
+
693
+ paint_object = PaintObject()
694
+ paint_object.set_color_cur_list(color_cur_list)
695
+ paint_object.set_label_cur_list(label_cur_list)
696
+ paint_object.set_x_cur_label(x_cur_label)
697
+ paint_object.set_y_cur_label(y_cur_label)
698
+ paint_object.set_name(cur_name)
699
+
700
+ return draw_learning_curve_total(learning_curve_dict, "validation", paint_object)
701
 
702
  @classmethod
703
+ def draw_shap_beeswarm_plot(cls, model_name, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
704
  model_name = cls.get_model_name_mapping_reverse()[model_name]
705
  container = cls.container_dict[model_name]
706
 
707
+ # color_cur_list = Config.COLORS if is_default else color_list
708
+ # label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
709
+ # x_cur_label = "Train Sizes" if is_default else x_label
710
+ # y_cur_label = "Accuracy" if is_default else y_label
711
+ cur_name = "" if is_default else name
712
+
713
+ paint_object = PaintObject()
714
+ # paint_object.set_color_cur_list(color_cur_list)
715
+ # paint_object.set_label_cur_list(label_cur_list)
716
+ # paint_object.set_x_cur_label(x_cur_label)
717
+ # paint_object.set_y_cur_label(y_cur_label)
718
+ paint_object.set_name(cur_name)
719
+
720
+ return shap_calculate(container.get_model(), container.x_train, cls.data.columns.values, paint_object)
721
 
722
  @classmethod
723
+ def get_file(cls):
724
+ # [绘图]
725
+ if cls.visualize == MN.learning_curve_train:
726
+ return FilePath.png_base.format(FilePath.learning_curve_train_plot)
727
+ elif cls.visualize == MN.learning_curve_validation:
728
+ return FilePath.png_base.format(FilePath.learning_curve_validation_plot)
729
+ elif cls.visualize == MN.shap_beeswarm:
730
+ return FilePath.png_base.format(FilePath.shap_beeswarm_plot)
731
 
732
  @classmethod
733
+ def check_file(cls):
734
+ return os.path.exists(cls.get_file())
735
 
736
  @classmethod
737
+ def after_get_file(cls):
738
+ return cls.get_file() if cls.check_file() else None
739
 
740
  @classmethod
741
  def get_model_list(cls):
 
772
  data_copy = cls.data
773
 
774
  if cls.assign == MN.classification:
775
+ data_copy.iloc[:, 0] = data_copy.iloc[:, 0].astype(str)
776
  else:
777
+ data_copy.iloc[:, 0] = data_copy.iloc[:, 0].astype(float)
778
 
779
  cls.data = data_copy
780
  cls.change_data_type_to_float()
781
 
782
+ @classmethod
783
+ def colorpickers_change(cls, paint_object):
784
+ cur_num = paint_object.get_color_cur_num()
785
+
786
+ true_list = [gr.ColorPicker(paint_object.get_color_cur_list()[i], visible=True, label=LN.colors[i]) for i in range(cur_num)]
787
+
788
+ return true_list + [gr.ColorPicker(visible=False)] * (StaticValue.max_num - cur_num)
789
+
790
+ @classmethod
791
+ def color_textboxs_change(cls, paint_object):
792
+ cur_num = paint_object.get_color_cur_num()
793
+
794
+ true_list = [gr.Textbox(paint_object.get_color_cur_list()[i], visible=True, show_label=False) for i in range(cur_num)]
795
+
796
+ return true_list + [gr.Textbox(visible=False)] * (StaticValue.max_num - cur_num)
797
+
798
+ @classmethod
799
+ def labels_change(cls, paint_object):
800
+ cur_num = paint_object.get_label_cur_num()
801
+
802
+ true_list = [gr.Textbox(paint_object.get_label_cur_list()[i], visible=True, label=LN.labels[i]) for i in range(cur_num)]
803
+
804
+ return true_list + [gr.Textbox(visible=False)] * (StaticValue.max_num - cur_num)
805
+
806
 
807
  def choose_assign(assign: str):
808
  Dataset.choose_assign(assign)
 
816
  return get_return(True)
817
 
818
 
819
+ # [绘图]
820
+ def shap_beeswarm_first_draw_plot(*inputs):
821
+ Dataset.visualize = MN.shap_beeswarm
822
+ return first_draw_plot(inputs)
823
+
824
+
825
+ def learning_curve_validation_first_draw_plot(*inputs):
826
+ Dataset.visualize = MN.learning_curve_validation
827
+ return first_draw_plot(inputs)
828
+
829
+
830
+ def learning_curve_train_first_draw_plot(*inputs):
831
+ Dataset.visualize = MN.learning_curve_train
832
+ return first_draw_plot(inputs)
833
+
834
 
835
+ def first_draw_plot(inputs):
836
+ select_model = inputs[0]
837
+ x_label = ""
838
+ y_label = ""
839
+ name = ""
840
+ color_list = []
841
+ label_list = []
842
 
843
+ cur_plt, paint_object = Dataset.draw_plot(select_model, color_list, label_list, name, x_label, y_label, True)
844
 
845
+ return first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object)
846
 
 
 
847
 
848
+ def out_non_first_draw_plot(*inputs):
849
+ return non_first_draw_plot(inputs)
850
 
851
 
852
def non_first_draw_plot(inputs):
    """Redraw the active plot using the styling values supplied in ``inputs``.

    Expected layout of ``inputs`` (M = StaticValue.max_num):
        [0]            plot title
        [1]            x-axis label
        [2]            y-axis label
        [3 : 3+M]      colours
        [3+M : 3+2M]   legend labels
        [3+2M]         model selection used by the learning-curve plots
        [3+2M + 1]     model selection used by the SHAP beeswarm plot

    Args:
        inputs: Sequence laid out as described above.

    Returns:
        The result of first_draw_plot_with_non_first_draw_plot for the
        redrawn figure.
    """
    slot_count = StaticValue.max_num

    name = inputs[0]
    x_label = inputs[1]
    y_label = inputs[2]

    # Boundaries of the two fixed-width sections that follow the three labels.
    colors_end = slot_count + 3
    start_index = 2 * slot_count + 3
    color_list = list(inputs[3:colors_end])
    label_list = list(inputs[colors_end:start_index])

    # Drawing: choose the model selection that belongs to the active plot type.
    active = Dataset.visualize
    if active in (MN.learning_curve_train, MN.learning_curve_validation):
        select_model = inputs[start_index]
    elif active == MN.shap_beeswarm:
        select_model = inputs[start_index + 1]
    else:
        # Fallback keeps a one-element slice (not the bare element).
        select_model = inputs[start_index:start_index + 1]

    figure, painter = Dataset.draw_plot(select_model, color_list, label_list, name, x_label, y_label, False)

    return first_draw_plot_with_non_first_draw_plot(figure, painter)
874
+
875
+
876
def first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object):
    """Save the drawn figure and collect the component updates that display it.

    Shared final step of first_draw_plot and non_first_draw_plot.

    Args:
        cur_plt: Figure-like object returned by Dataset.draw_plot; savefig()
            is called on it and it is handed to gr.Plot.
        paint_object: Styling state returned by Dataset.draw_plot; its
            getters populate the title, axis, colour and legend widgets.

    Returns:
        The result of get_return_extra(True, updates), where ``updates`` maps
        UI components to their new state.
    """
    # Drawing: pick the PNG name and plot caption for the active plot type.
    active = Dataset.visualize
    if active == MN.learning_curve_train:
        selection = (FilePath.learning_curve_train_plot, LN.learning_curve_train_plot)
    elif active == MN.learning_curve_validation:
        selection = (FilePath.learning_curve_validation_plot, LN.learning_curve_validation_plot)
    elif active == MN.shap_beeswarm:
        selection = (FilePath.shap_beeswarm_plot, LN.shap_beeswarm_plot)
    else:
        # Unknown plot type: nothing is saved and the plot component is untouched.
        selection = None

    updates = {}
    if selection is not None:
        png_name, caption = selection
        cur_plt.savefig(FilePath.png_base.format(png_name), dpi=300)
        updates[draw_plot] = gr.Plot(cur_plt, visible=True, label=caption)

    # Per-slot widgets: pair each component with its freshly built update.
    updates.update(zip(colorpickers, Dataset.colorpickers_change(paint_object)))
    updates.update(zip(color_textboxs, Dataset.color_textboxs_change(paint_object)))
    updates.update(zip(legend_labels_textboxs, Dataset.labels_change(paint_object)))

    # Title and axis-label boxes reflect what was actually drawn.
    updates[title_name_textbox] = gr.Textbox(paint_object.get_name(), visible=True, label=LN.title_name_textbox)
    updates[x_label_textbox] = gr.Textbox(paint_object.get_x_cur_label(), visible=True, label=LN.x_label_textbox)
    updates[y_label_textbox] = gr.Textbox(paint_object.get_y_cur_label(), visible=True, label=LN.y_label_textbox)

    return get_return_extra(True, updates)
898
 
899
 
900
  def train_model(optimize, linear_regression_model_type):
 
924
def encode_label(col_list: list):
    """Label-encode the given columns and show the resulting mapping table.

    Args:
        col_list: Columns passed straight through to Dataset.encode_label.

    Returns:
        The result of get_return(True, ...) with the encode-label dataframe
        component made visible and filled from
        Dataset.get_str2int_mappings_df().
    """
    Dataset.encode_label(col_list)

    mapping_table = gr.Dataframe(
        Dataset.get_str2int_mappings_df(),
        type="pandas",
        visible=True,
        label=LN.display_encode_label_dataframe,
    )
    return get_return(True, {display_encode_label_dataframe: mapping_table})
930
 
931
 
932
  def del_duplicate():
 
982
 
983
 
984
  with gr.Blocks() as demo:
 
985
  '''
986
  组件
987
  '''
 
996
  # 显示数据表信息
997
  with gr.Accordion("当前数据信息"):
998
  display_dataset_dataframe = gr.Dataframe(visible=False)
999
+ display_dataset = gr.File(visible=False)
1000
  with gr.Row():
1001
  display_total_col_num_text = gr.Textbox(visible=False)
1002
  display_total_row_num_text = gr.Textbox(visible=False)
 
1039
 
1040
  # 可视化
1041
  with gr.Accordion("数据可视化"):
1042
+ with gr.Tab("学习曲线图"):
1043
+ learning_curve_checkboxgroup = gr.Checkboxgroup(visible=False)
1044
+ with gr.Row():
1045
+ learning_curve_train_button = gr.Button(visible=False)
1046
+ learning_curve_validation_button = gr.Button(visible=False)
1047
+
1048
+ with gr.Tab("蜂群特征图"):
1049
+ shap_beeswarm_radio = gr.Radio(visible=False)
1050
+ shap_beeswarm_button = gr.Button(visible=False)
1051
+
1052
+ legend_labels_textboxs = []
1053
+ with gr.Accordion("图例"):
1054
+ with gr.Row():
1055
+ for i in range(StaticValue.max_num):
1056
+ with gr.Row():
1057
+ label = gr.Textbox(visible=False)
1058
+ legend_labels_textboxs.append(label)
1059
+
1060
+ with gr.Accordion("坐标轴"):
1061
+ with gr.Row():
1062
+ title_name_textbox = gr.Textbox(visible=False)
1063
+ x_label_textbox = gr.Textbox(visible=False)
1064
+ y_label_textbox = gr.Textbox(visible=False)
1065
+
1066
+ colorpickers = []
1067
+ color_textboxs = []
1068
+ with gr.Accordion("颜色"):
1069
+ with gr.Row():
1070
+ for i in range(StaticValue.max_num):
1071
+ with gr.Row():
1072
+ colorpicker = gr.ColorPicker(visible=False)
1073
+ colorpickers.append(colorpicker)
1074
+ color_textbox = gr.Textbox(visible=False)
1075
+ color_textboxs.append(color_textbox)
1076
+
1077
+ draw_plot = gr.Plot(visible=False)
1078
+ draw_file = gr.File(visible=False)
1079
 
1080
  '''
1081
  监听事件
 
1111
  model_train_button.click(fn=train_model, inputs=[model_optimize_radio, linear_regression_model_radio], outputs=get_outputs())
1112
 
1113
  # 可视化
1114
+ learning_curve_train_button.click(fn=learning_curve_train_first_draw_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
1115
+ learning_curve_validation_button.click(fn=learning_curve_validation_first_draw_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
1116
+ shap_beeswarm_button.click(fn=shap_beeswarm_first_draw_plot, inputs=[shap_beeswarm_radio], outputs=get_outputs())
1117
+
1118
+ title_name_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
1119
+ + [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
1120
+ x_label_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
1121
+ + [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
1122
+ y_label_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
1123
+ + [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
1124
+ for i in range(StaticValue.max_num):
1125
+ colorpickers[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
1126
+ + [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
1127
+ color_textboxs[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + color_textboxs + legend_labels_textboxs
1128
+ + [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
1129
+ legend_labels_textboxs[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
1130
+ + [learning_curve_checkboxgroup] + [shap_beeswarm_radio], outputs=get_outputs())
1131
 
1132
  if __name__ == "__main__":
1133
  demo.launch()
{diagram → buffer}/__init__.py RENAMED
File without changes
static/config.py CHANGED
@@ -12,6 +12,9 @@ class Config:
12
  "#EF8B67",
13
  "#F0C284"
14
  ]
 
 
 
15
  COLORS_1 = [
16
  "#91CCC0",
17
  "#7FABD1",
 
12
  "#EF8B67",
13
  "#F0C284"
14
  ]
15
+
16
+
17
+
18
  COLORS_1 = [
19
  "#91CCC0",
20
  "#7FABD1",
static/paint.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class PaintObject:
    """Mutable holder for the styling state of one plot.

    Stores the colours, legend labels, axis labels and title of a figure and
    exposes each through a get_*/set_* accessor pair. The counts and the
    lists are stored independently; nothing here keeps them in sync.
    """

    def __init__(self):
        """Start with no colours, no legend labels and empty text fields."""
        self.color_cur_num, self.color_cur_list = 0, []
        self.label_cur_num, self.label_cur_list = 0, []
        self.x_cur_label = ""
        self.y_cur_label = ""
        self.name = ""

    # --- colours -----------------------------------------------------------

    def get_color_cur_num(self):
        """Return the stored colour count."""
        return self.color_cur_num

    def set_color_cur_num(self, color_cur_num):
        """Store a new colour count."""
        self.color_cur_num = color_cur_num

    def get_color_cur_list(self):
        """Return the stored colour list (the same object, not a copy)."""
        return self.color_cur_list

    def set_color_cur_list(self, color_cur_list):
        """Store a new colour list (kept by reference)."""
        self.color_cur_list = color_cur_list

    # --- legend labels -----------------------------------------------------

    def get_label_cur_num(self):
        """Return the stored legend-label count."""
        return self.label_cur_num

    def set_label_cur_num(self, label_cur_num):
        """Store a new legend-label count."""
        self.label_cur_num = label_cur_num

    def get_label_cur_list(self):
        """Return the stored legend-label list (the same object, not a copy)."""
        return self.label_cur_list

    def set_label_cur_list(self, label_cur_list):
        """Store a new legend-label list (kept by reference)."""
        self.label_cur_list = label_cur_list

    # --- axis labels and title --------------------------------------------

    def get_x_cur_label(self):
        """Return the stored x-axis label."""
        return self.x_cur_label

    def set_x_cur_label(self, x_cur_label):
        """Store a new x-axis label."""
        self.x_cur_label = x_cur_label

    def get_y_cur_label(self):
        """Return the stored y-axis label."""
        return self.y_cur_label

    def set_y_cur_label(self, y_cur_label):
        """Store a new y-axis label."""
        self.y_cur_label = y_cur_label

    def get_name(self):
        """Return the stored plot name."""
        return self.name

    def set_name(self, name):
        """Store a new plot name."""
        self.name = name
static/process.py CHANGED
@@ -196,7 +196,10 @@ def load_data(sort):
196
 
197
 
198
  def load_custom_data(file):
199
- return pd.read_csv(file)
 
 
 
200
 
201
 
202
  def preprocess_raw_data_filtering(df):
 
196
 
197
 
198
def load_custom_data(file):
    """Load a user-supplied table, choosing the reader from the file extension.

    Args:
        file: Path of the file to load (str or os.PathLike).

    Returns:
        pandas.DataFrame read with pd.read_excel for Excel workbooks
        (.xlsx, .xls, .xlsm, .xlsb) or with pd.read_csv for .csv files.

    Raises:
        ValueError: If the extension is none of the supported ones.
    """
    import os  # local import: this module's import block is outside the visible chunk

    # Decide on the real extension only. A substring test on the whole path
    # would send e.g. "/tmp/xls_uploads/data.csv" to the Excel reader, and a
    # case-sensitive test would reject "DATA.CSV".
    extension = os.path.splitext(os.fspath(file))[1].lower()

    if extension in (".xlsx", ".xls", ".xlsm", ".xlsb"):
        return pd.read_excel(file)
    if extension == ".csv":
        return pd.read_csv(file)

    # Fail with a clear message instead of silently returning None.
    raise ValueError(
        "Unsupported file type {!r}: expected a .csv, .xls or .xlsx file".format(extension)
    )
203
 
204
 
205
  def preprocess_raw_data_filtering(df):
visualization/draw_learning_curve_total.py CHANGED
@@ -1,15 +1,15 @@
1
  import numpy as np
2
  from matplotlib import pyplot as plt
3
 
 
4
  from static.config import Config
5
 
6
 
7
- def draw_learning_curve_total(input_dict, type):
8
  plt.figure(figsize=(10, 6), dpi=300)
9
 
10
  if type == "train":
11
- i = 0
12
- for label_name, values in input_dict.items():
13
  train_sizes = values[0]
14
  train_scores_mean = values[1]
15
  train_scores_std = values[2]
@@ -21,25 +21,19 @@ def draw_learning_curve_total(input_dict, type):
21
  train_scores_mean - train_scores_std,
22
  train_scores_mean + train_scores_std,
23
  alpha=0.1,
24
- color=Config.COLORS[i]
25
  )
26
 
27
  plt.plot(
28
  train_sizes,
29
  train_scores_mean,
30
  "o-",
31
- color=Config.COLORS[i],
32
- label=label_name
33
  )
34
 
35
- i += 1
36
-
37
- title = "Training Learning curve"
38
- # plt.title(title)
39
-
40
  else:
41
- i = 0
42
- for label_name, values in input_dict.items():
43
  train_sizes = values[0]
44
  train_scores_mean = values[1]
45
  train_scores_std = values[2]
@@ -51,26 +45,27 @@ def draw_learning_curve_total(input_dict, type):
51
  test_scores_mean - test_scores_std,
52
  test_scores_mean + test_scores_std,
53
  alpha=0.1,
54
- color=Config.COLORS[i]
55
  )
56
  plt.plot(
57
  train_sizes,
58
  test_scores_mean,
59
  "o-",
60
- color=Config.COLORS[i],
61
- label=label_name
62
  )
63
 
64
- i += 1
65
-
66
- title = "Cross-validation Learning curve"
67
- # plt.title(title)
68
 
69
- plt.xlabel("Sizes")
70
- plt.ylabel("Adjusted R-square")
71
  plt.legend()
72
 
73
  # plt.savefig("./diagram/{}.png".format(title), dpi=300)
74
  # plt.show()
75
- return plt
 
 
 
 
76
 
 
1
  import numpy as np
2
  from matplotlib import pyplot as plt
3
 
4
+ from static.paint import PaintObject
5
  from static.config import Config
6
 
7
 
8
+ def draw_learning_curve_total(input_dict, type, paint_object: PaintObject):
9
  plt.figure(figsize=(10, 6), dpi=300)
10
 
11
  if type == "train":
12
+ for i, values in enumerate(input_dict.values()):
 
13
  train_sizes = values[0]
14
  train_scores_mean = values[1]
15
  train_scores_std = values[2]
 
21
  train_scores_mean - train_scores_std,
22
  train_scores_mean + train_scores_std,
23
  alpha=0.1,
24
+ color=paint_object.get_color_cur_list()[i]
25
  )
26
 
27
  plt.plot(
28
  train_sizes,
29
  train_scores_mean,
30
  "o-",
31
+ color=paint_object.get_color_cur_list()[i],
32
+ label=paint_object.get_label_cur_list()[i]
33
  )
34
 
 
 
 
 
 
35
  else:
36
+ for i, values in enumerate(input_dict.values()):
 
37
  train_sizes = values[0]
38
  train_scores_mean = values[1]
39
  train_scores_std = values[2]
 
45
  test_scores_mean - test_scores_std,
46
  test_scores_mean + test_scores_std,
47
  alpha=0.1,
48
+ color=paint_object.get_color_cur_list()[i]
49
  )
50
  plt.plot(
51
  train_sizes,
52
  test_scores_mean,
53
  "o-",
54
+ color=paint_object.get_color_cur_list()[i],
55
+ label=paint_object.get_label_cur_list()[i]
56
  )
57
 
58
+ plt.title(paint_object.get_name())
 
 
 
59
 
60
+ plt.xlabel(paint_object.get_x_cur_label())
61
+ plt.ylabel(paint_object.get_y_cur_label())
62
  plt.legend()
63
 
64
  # plt.savefig("./diagram/{}.png".format(title), dpi=300)
65
  # plt.show()
66
+
67
+ paint_object.set_color_cur_num(len(input_dict.keys()))
68
+ paint_object.set_label_cur_num(len(input_dict.keys()))
69
+
70
+ return plt, paint_object
71