LLH commited on
Commit
4b04dd5
·
1 Parent(s): 9636812

2024/03/09/18:30

Browse files
Files changed (2) hide show
  1. app.py +30 -26
  2. classes/static_custom_class.py +5 -0
app.py CHANGED
@@ -356,7 +356,7 @@ class Dataset:
356
 
357
  @classmethod
358
  def check_model_optimize_radio(cls):
359
- if cls.cur_model and cls.choose_optimize != "None" and cls.choose_optimize:
360
  if cls.cur_model == MN.linear_regressor:
361
  if cls.linear_regression_model_type:
362
  return True
@@ -640,8 +640,8 @@ class Dataset:
640
 
641
  @classmethod
642
  def get_optimize_name_mapping(cls):
643
- # return dict(zip(cls.get_optimize_list(), ["None", "grid_search", "bayes_search"]))
644
- return dict(zip(cls.get_optimize_list(), ["None", "grid_search"]))
645
 
646
  @classmethod
647
  def get_linear_regression_model_list(cls):
@@ -1214,10 +1214,9 @@ class Dataset:
1214
  def get_model_train_input_params(cls):
1215
  EACH_ROW_NUM = 6
1216
  output_list = []
1217
- print(cls.cur_model)
1218
  if cls.check_model_optimize_radio():
1219
  output_dict = ChooseModelParams.choose(cls.cur_model)
1220
- print("output_dict: {}".format(str(output_dict)))
1221
  row_unit_num_list = []
1222
  row_len = len(output_dict.keys())
1223
  dict_keys_list = [x for x in output_dict.keys()]
@@ -1226,9 +1225,7 @@ class Dataset:
1226
  row_unit_num_list.append(len(v))
1227
  for x in v:
1228
  output_list.append(x)
1229
- print("output_list: {}".format(str(output_list)))
1230
- print("row_len: {}".format(str(row_len)))
1231
- print("dict_keys_list: {}".format(str(dict_keys_list)))
1232
  return_list = []
1233
  cumulative_sum = 0
1234
  for j in range(row_len):
@@ -1240,7 +1237,7 @@ class Dataset:
1240
  return_list.extend(
1241
  [gr.Textbox("", visible=False)] * (EACH_ROW_NUM - 1 - row_unit_num_list[j])
1242
  )
1243
- print("return_list: {}".format(str(return_list)))
1244
  cumulative_sum += row_unit_num_list[j]
1245
 
1246
  return_list.extend(["", gr.Textbox(visible=False)] * (StaticValue.MAX_PARAMS_NUM - EACH_ROW_NUM * row_len))
@@ -2005,12 +2002,12 @@ def choose_custom_dataset(file: str):
2005
  def select_model_optimize_radio(optimize):
2006
  optimize = Dataset.get_optimize_name_mapping()[optimize]
2007
 
2008
- if optimize == "grid_search":
2009
- Dataset.choose_optimize = "grid_search"
2010
- elif optimize == "bayes_search":
2011
- Dataset.choose_optimize = "bayes_search"
2012
- elif optimize == "None":
2013
- Dataset.choose_optimize = "None"
2014
 
2015
  return get_return(True)
2016
 
@@ -2446,17 +2443,24 @@ with gr.Blocks(js="./design/welcome.js", css="./design/custom.css") as demo:
2446
 
2447
  # 数据模型
2448
  # !有BUG
2449
- select_as_model_radio.change(
2450
- fn=select_as_model,
2451
- inputs=[select_as_model_radio],
2452
- outputs=get_outputs()
2453
- ).then(
2454
- fn=reset_select_model_optimize_radio_part_1,
2455
- outputs=get_outputs()
2456
- ).then(
2457
- fn=reset_select_model_optimize_radio_part_2,
2458
- outputs=get_outputs()
2459
- )
 
 
 
 
 
 
 
2460
 
2461
  # [模型]
2462
  linear_regression_model_radio.change(
 
356
 
357
  @classmethod
358
  def check_model_optimize_radio(cls):
359
+ if cls.cur_model and cls.choose_optimize != MN.none and cls.choose_optimize:
360
  if cls.cur_model == MN.linear_regressor:
361
  if cls.linear_regression_model_type:
362
  return True
 
640
 
641
  @classmethod
642
  def get_optimize_name_mapping(cls):
643
+ # return dict(zip(cls.get_optimize_list(), [MN.none, MN.grid_search, MN.bayes_search]))
644
+ return dict(zip(cls.get_optimize_list(), [MN.none, MN.grid_search]))
645
 
646
  @classmethod
647
  def get_linear_regression_model_list(cls):
 
1214
  def get_model_train_input_params(cls):
1215
  EACH_ROW_NUM = 6
1216
  output_list = []
1217
+
1218
  if cls.check_model_optimize_radio():
1219
  output_dict = ChooseModelParams.choose(cls.cur_model)
 
1220
  row_unit_num_list = []
1221
  row_len = len(output_dict.keys())
1222
  dict_keys_list = [x for x in output_dict.keys()]
 
1225
  row_unit_num_list.append(len(v))
1226
  for x in v:
1227
  output_list.append(x)
1228
+
 
 
1229
  return_list = []
1230
  cumulative_sum = 0
1231
  for j in range(row_len):
 
1237
  return_list.extend(
1238
  [gr.Textbox("", visible=False)] * (EACH_ROW_NUM - 1 - row_unit_num_list[j])
1239
  )
1240
+
1241
  cumulative_sum += row_unit_num_list[j]
1242
 
1243
  return_list.extend(["", gr.Textbox(visible=False)] * (StaticValue.MAX_PARAMS_NUM - EACH_ROW_NUM * row_len))
 
2002
  def select_model_optimize_radio(optimize):
2003
  optimize = Dataset.get_optimize_name_mapping()[optimize]
2004
 
2005
+ if optimize == MN.grid_search:
2006
+ Dataset.choose_optimize = MN.grid_search
2007
+ elif optimize == MN.bayes_search:
2008
+ Dataset.choose_optimize = MN.bayes_search
2009
+ elif optimize == MN.none:
2010
+ Dataset.choose_optimize = MN.none
2011
 
2012
  return get_return(True)
2013
 
 
2443
 
2444
  # 数据模型
2445
  # !有BUG
2446
+ if Dataset.choose_optimize and Dataset.choose_optimize != MN.none:
2447
+ select_as_model_radio.change(
2448
+ fn=select_as_model,
2449
+ inputs=[select_as_model_radio],
2450
+ outputs=get_outputs()
2451
+ ).then(
2452
+ fn=reset_select_model_optimize_radio_part_1,
2453
+ outputs=get_outputs()
2454
+ ).then(
2455
+ fn=reset_select_model_optimize_radio_part_2,
2456
+ outputs=get_outputs()
2457
+ )
2458
+ else:
2459
+ select_as_model_radio.change(
2460
+ fn=select_as_model,
2461
+ inputs=[select_as_model_radio],
2462
+ outputs=get_outputs()
2463
+ )
2464
 
2465
  # [模型]
2466
  linear_regression_model_radio.change(
classes/static_custom_class.py CHANGED
@@ -148,6 +148,11 @@ class MN: # ModelName
148
  naive_bayes_classifier = "naive bayes classifier"
149
  # 模型Step 4:在这里添加新的模型名称
150
 
 
 
 
 
 
151
  # [绘图]
152
  data_distribution = "data_distribution"
153
  descriptive_indicators = "descriptive_indicators"
 
148
  naive_bayes_classifier = "naive bayes classifier"
149
  # 模型Step 4:在这里添加新的模型名称
150
 
151
+ none = "None"
152
+ grid_search = "grid_search"
153
+ bayes_search = "bayes_search"
154
+
155
+
156
  # [绘图]
157
  data_distribution = "data_distribution"
158
  descriptive_indicators = "descriptive_indicators"