LLH
2024/02/20/14:15
10c7c36
raw
history blame
84.1 kB
import copy
import math
import os.path
import random
import gradio as gr
import matplotlib.pyplot as plt
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
import pandas as pd
from analysis.bayes_model import *
from analysis.distance_model import *
from analysis.gradient_model import *
from analysis.kernel_model import *
from analysis.shap_model import *
from analysis.tree_model import *
from metrics.calculate_classification_metrics import ClassificationMetrics
from metrics.calculate_regression_metrics import RegressionMetrics
from static.process import *
from analysis.linear_model import *
from visualization.draw_boxplot import draw_boxplot
from visualization.draw_data_fit_total import draw_data_fit_total
from visualization.draw_heat_map import draw_heat_map
from visualization.draw_histogram import draw_histogram
from visualization.draw_learning_curve_total import draw_learning_curve_total
from static.new_class import *
import warnings
warnings.filterwarnings("ignore")
# [模型]
class ChooseModelMetrics:
@classmethod
def choose(cls, cur_model):
if cur_model == MN.linear_regression:
return RegressionMetrics.get_metrics()
elif cur_model == MN.polynomial_regression:
return RegressionMetrics.get_metrics()
elif cur_model == MN.logistic_regression:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.decision_tree_classifier:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.random_forest_classifier:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.random_forest_regression:
return RegressionMetrics.get_metrics()
elif cur_model == MN.xgboost_classifier:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.lightGBM_classifier:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.gradient_boosting_regression:
return RegressionMetrics.get_metrics()
elif cur_model == MN.svm_classifier:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.svm_regression:
return RegressionMetrics.get_metrics()
elif cur_model == MN.knn_classifier:
return ClassificationMetrics.get_metrics()
elif cur_model == MN.knn_regression:
return RegressionMetrics.get_metrics()
elif cur_model == MN.naive_bayes_classification:
return ClassificationMetrics.get_metrics()
# [模型]
class ChooseModelParams:
@classmethod
def choose(cls, cur_model):
if cur_model == MN.linear_regression:
return LinearRegressionParams.get_params(Dataset.linear_regression_model_type)
elif cur_model == MN.polynomial_regression:
return PolynomialRegressionParams.get_params()
elif cur_model == MN.logistic_regression:
return LogisticRegressionParams.get_params()
elif cur_model == MN.decision_tree_classifier:
return DecisionTreeClassifierParams.get_params()
elif cur_model == MN.random_forest_classifier:
return RandomForestClassifierParams.get_params()
elif cur_model == MN.random_forest_regression:
return RandomForestRegressionParams.get_params()
elif cur_model == MN.xgboost_classifier:
return XgboostClassifierParams.get_params()
elif cur_model == MN.lightGBM_classifier:
return LightGBMClassifierParams.get_params()
elif cur_model == MN.gradient_boosting_regression:
return GradientBoostingParams.get_params()
elif cur_model == MN.svm_classifier:
return SVMClassifierParams.get_params()
elif cur_model == MN.svm_regression:
return SVMRegressionParams.get_params()
elif cur_model == MN.knn_classifier:
return KNNClassifierParams.get_params()
elif cur_model == MN.knn_regression:
return KNNRegressionParams.get_params()
elif cur_model == MN.naive_bayes_classification:
return NaiveBayesClassifierParams.get_params(Dataset.naive_bayes_classifier_model_type)
class StaticValue:
max_num = 20
class FilePath:
png_base = "./buffer/{}.png"
excel_base = "./buffer/{}.xlsx"
# [绘图]
display_dataset = "current_excel_data"
data_distribution_plot = "data_distribution_plot"
descriptive_indicators_plot = "descriptive_indicators_plot"
heatmap_plot = "heatmap_plot"
learning_curve_plot = "learning_curve_plot"
shap_beeswarm_plot = "shap_beeswarm_plot"
data_fit_plot = "data_fit_plot"
waterfall_plot = "waterfall_plot"
force_plot = "force_plot"
dependence_plot = "dependence_plot"
class MN: # ModelName
classification = "classification"
regression = "regression"
# [模型]
linear_regression = "linear regressor"
polynomial_regression = "polynomial regressor"
logistic_regression = "logistic regressor"
decision_tree_classifier = "decision tree classifier"
random_forest_classifier = "random forest classifier"
random_forest_regression = "random forest regressor"
xgboost_classifier = "xgboost classifier"
lightGBM_classifier = "lightGBM classifier"
gradient_boosting_regression = "gradient boosting regressor"
svm_classifier = "svm classifier"
svm_regression = "svm regressor"
knn_classifier = "knn classifier"
knn_regression = "knn regressor"
naive_bayes_classification = "naive bayes classification"
# [绘图]
data_distribution = "data_distribution"
descriptive_indicators = "descriptive_indicators"
heatmap = "heatmap"
learning_curve = "learning_curve"
shap_beeswarm = "shap_beeswarm"
data_fit = "data_fit"
waterfall = "waterfall"
force = "force"
dependence = "dependence"
class LN: # LabelName
choose_dataset_radio = "选择所需数据源 [必选]"
display_total_col_num_text = "总列数"
display_total_row_num_text = "总行数"
display_na_list_text = "存在缺失值的列"
del_all_na_col_button = "删除所有存在缺失值的列 [可选]"
display_duplicate_num_text = "重复的行数"
del_col_checkboxgroup = "选择所需删除的列"
del_col_button = "删除 [可选]"
remain_row_slider = "保留的行数"
remain_row_button = "保留 [可选]"
del_duplicate_button = "删除所有重复行 [可选]"
encode_label_checkboxgroup = "选择所需标签编码的字符型数值列"
display_encode_label_dataframe = "标签编码信息"
encode_label_button = "字符型转数值型 [可选]"
change_data_type_to_float_button = "将所有数据强制转换为浮点型(除第1列以外)[必选]"
standardize_data_checkboxgroup = "选择所需标准化的列"
standardize_data_button = "标准化 [可选]"
select_as_y_radio = "选择因变量 [必选]"
choose_assign_radio = "选择任务类型(同时会根据任务类型将第1列数据强制转换)[必选]"
model_optimize_radio = "选择超参数优化方法"
model_train_button = "训练"
model_train_params_dataframe = "训练后的模型参数"
model_train_metrics_dataframe = "训练后的模型指标"
select_as_model_radio = "选择所需训练的模型"
# [模型]
linear_regression_model_radio = "选择线性回归的模型"
naive_bayes_classification_model_radio = "选择朴素贝叶斯分类的模型"
title_name_textbox = "标题"
x_label_textbox = "x 轴名称"
y_label_textbox = "y 轴名称"
colors = ["颜色 {}".format(i) for i in range(StaticValue.max_num)]
labels = ["图例 {}".format(i) for i in range(StaticValue.max_num)]
# [绘图]
heatmap_is_rotate = "x轴标签是否旋转"
heatmap_checkboxgroup = "选择所需绘制系数热力图的列"
heatmap_button = "绘制系数热力图"
data_distribution_radio = "选择所需绘制数据分布图的列"
data_distribution_is_rotate = "x轴标签是否旋转"
data_distribution_button = "绘制数据分布图"
descriptive_indicators_checkboxgroup = "选择所需绘制箱线统计图的列"
descriptive_indicators_is_rotate = "x轴标签是否旋转"
descriptive_indicators_button = "绘制箱线统计图"
learning_curve_checkboxgroup = "选择所需绘制学习曲线图的模型"
learning_curve_button = "绘制学习曲线图"
shap_beeswarm_radio = "选择所需绘制特征蜂群图的模型"
shap_beeswarm_type = "选择图像类型"
shap_beeswarm_button = "绘制特征蜂群图"
data_fit_checkboxgroup = "选择所需绘制数据拟合图的模型"
data_fit_button = "绘制数据拟合图"
waterfall_radio = "选择所需绘制特征瀑布图的模型"
waterfall_number = "输入相关特征的变量索引"
waterfall_button = "绘制特征瀑布图"
force_radio = "选择所需绘制特征力图的模型"
force_number = "输入相关特征的变量索引"
force_button = "绘制特征力图"
dependence_radio = "选择所需绘制特征依赖图的模型"
dependence_col = "选择相应的列"
dependence_button = "绘制特征依赖图"
data_distribution_plot = "数据分布图"
descriptive_indicators_plot = "箱线统计图"
heatmap_plot = "系数热力图"
learning_curve_plot = "学习曲线图"
shap_beeswarm_plot = "特征蜂群图"
data_fit_plot = "数据拟合图"
waterfall_plot = "特征瀑布图"
force_plot = "特征力图"
dependence_plot = "特征依赖图"
def get_return_extra(is_visible, extra_gr_dict: dict = None):
if is_visible:
gr_dict = {
draw_file: gr.File(Dataset.after_get_file(), visible=Dataset.check_file()),
}
if extra_gr_dict:
gr_dict.update(extra_gr_dict)
return gr_dict
gr_dict = {
draw_plot: gr.Plot(visible=False),
draw_file: gr.File(visible=False),
}
gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
gr_dict.update({title_name_textbox: gr.Textbox(visible=False)})
gr_dict.update({x_label_textbox: gr.Textbox(visible=False)})
gr_dict.update({y_label_textbox: gr.Textbox(visible=False)})
return gr_dict
def get_outputs():
gr_set = {
choose_custom_dataset_file,
display_dataset_dataframe,
display_total_col_num_text,
display_total_row_num_text,
display_na_list_text,
del_all_na_col_button,
display_duplicate_num_text,
del_duplicate_button,
del_col_checkboxgroup,
del_col_button,
remain_row_slider,
remain_row_button,
encode_label_button,
display_encode_label_dataframe,
encode_label_checkboxgroup,
data_type_dataframe,
change_data_type_to_float_button,
standardize_data_checkboxgroup,
standardize_data_button,
select_as_y_radio,
model_optimize_radio,
model_train_button,
model_train_checkbox,
model_train_params_dataframe,
model_train_metrics_dataframe,
select_as_model_radio,
choose_assign_radio,
display_dataset,
draw_plot,
draw_file,
title_name_textbox,
x_label_textbox,
y_label_textbox,
# [模型]
linear_regression_model_radio,
naive_bayes_classification_model_radio,
# [绘图]
heatmap_is_rotate,
heatmap_checkboxgroup,
heatmap_button,
data_distribution_radio,
data_distribution_is_rotate,
data_distribution_button,
descriptive_indicators_checkboxgroup,
descriptive_indicators_is_rotate,
descriptive_indicators_dataframe,
descriptive_indicators_button,
learning_curve_checkboxgroup,
learning_curve_button,
shap_beeswarm_radio,
shap_beeswarm_type,
shap_beeswarm_button,
data_fit_checkboxgroup,
data_fit_button,
waterfall_radio,
waterfall_number,
waterfall_button,
force_radio,
force_number,
force_button,
dependence_radio,
dependence_col,
dependence_button,
}
gr_set.update(set(colorpickers))
gr_set.update(set(color_textboxs))
gr_set.update(set(legend_labels_textboxs))
return gr_set
def get_return(is_visible, extra_gr_dict: dict = None):
if is_visible:
gr_dict = {
display_dataset_dataframe: gr.Dataframe(add_index_into_df(Dataset.data), type="pandas", visible=True),
display_dataset: gr.File(Dataset.after_get_display_dataset_file(), visible=Dataset.check_display_dataset_file()),
display_total_col_num_text: gr.Textbox(str(Dataset.get_total_col_num()), visible=True, label=LN.display_total_col_num_text),
display_total_row_num_text: gr.Textbox(str(Dataset.get_total_row_num()), visible=True, label=LN.display_total_row_num_text),
display_na_list_text: gr.Textbox(Dataset.get_na_list_str(), visible=True, label=LN.display_na_list_text),
del_all_na_col_button: gr.Button(LN.del_all_na_col_button, visible=True),
display_duplicate_num_text: gr.Textbox(str(Dataset.get_duplicate_num()), visible=True, label=LN.display_duplicate_num_text),
del_duplicate_button: gr.Button(LN.del_duplicate_button, visible=True),
del_col_checkboxgroup: gr.Checkboxgroup(Dataset.get_col_list(), visible=True, label=LN.del_col_checkboxgroup),
del_col_button: gr.Button(LN.del_col_button, visible=True),
remain_row_slider: gr.Slider(0, Dataset.get_max_num(), value=Dataset.get_total_row_num(), step=1, visible=True, label=LN.remain_row_slider),
remain_row_button: gr.Button(LN.remain_row_button, visible=True),
encode_label_button: gr.Button(LN.encode_label_button, visible=True),
encode_label_checkboxgroup: gr.Checkboxgroup(Dataset.get_non_numeric_list(), visible=True, label=LN.encode_label_checkboxgroup),
display_encode_label_dataframe: gr.Dataframe(visible=False),
data_type_dataframe: gr.Dataframe(Dataset.get_data_type(), visible=True),
change_data_type_to_float_button: gr.Button(LN.change_data_type_to_float_button, visible=True),
select_as_y_radio: gr.Radio(Dataset.get_col_list(), visible=True, label=LN.select_as_y_radio),
standardize_data_checkboxgroup: gr.Checkboxgroup(Dataset.get_non_standardized_data(), visible=True, label=LN.standardize_data_checkboxgroup),
standardize_data_button: gr.Button(LN.standardize_data_button, visible=True),
choose_assign_radio: gr.Radio(Dataset.get_assign_list(), visible=True, label=LN.choose_assign_radio),
select_as_model_radio: gr.Radio(Dataset.get_model_list(), visible=Dataset.check_before_train(), label=LN.select_as_model_radio),
model_optimize_radio: gr.Radio(Dataset.get_optimize_list(), visible=Dataset.check_before_train(), label=LN.model_optimize_radio),
model_train_button: gr.Button(LN.model_train_button, visible=Dataset.check_before_train()),
model_train_checkbox: gr.Checkbox(Dataset.get_model_container_status(), visible=Dataset.check_select_model(), label=Dataset.get_model_label()),
model_train_params_dataframe: gr.Dataframe(Dataset.get_model_train_params_dataframe(), type="pandas", visible=Dataset.get_model_container_status()),
model_train_metrics_dataframe: gr.Dataframe(Dataset.get_model_train_metrics_dataframe(), type="pandas", visible=Dataset.get_model_container_status()),
draw_plot: gr.Plot(visible=False),
draw_file: gr.File(visible=False),
title_name_textbox: gr.Textbox(visible=False),
x_label_textbox: gr.Textbox(visible=False),
y_label_textbox: gr.Textbox(visible=False),
# [模型]
linear_regression_model_radio: gr.Radio(Dataset.get_linear_regression_model_list(), visible=Dataset.get_linear_regression_mark(), label=LN.linear_regression_model_radio),
naive_bayes_classification_model_radio: gr.Radio(Dataset.get_naive_bayes_classifier_model_list(), visible=Dataset.get_naive_bayes_classifier_mark(), label=LN.naive_bayes_classification_model_radio),
# [绘图]
heatmap_checkboxgroup: gr.Checkboxgroup(Dataset.get_float_col_list(), visible=True, label=LN.heatmap_checkboxgroup),
heatmap_is_rotate: gr.Checkbox(visible=True, label=LN.heatmap_is_rotate),
heatmap_button: gr.Button(LN.heatmap_button, visible=True),
descriptive_indicators_checkboxgroup: gr.Checkboxgroup(Dataset.get_float_col_list(), visible=True, label=LN.descriptive_indicators_checkboxgroup),
data_distribution_radio: gr.Radio(Dataset.get_str_col_list(), visible=True, label=LN.data_distribution_radio),
data_distribution_is_rotate: gr.Checkbox(visible=True, label=LN.data_distribution_is_rotate),
data_distribution_button: gr.Button(LN.data_distribution_button, visible=True),
descriptive_indicators_is_rotate: gr.Checkbox(visible=True, label=LN.descriptive_indicators_is_rotate),
descriptive_indicators_dataframe: gr.Dataframe(Dataset.get_descriptive_indicators_df(), type="pandas", visible=Dataset.check_descriptive_indicators_df()),
descriptive_indicators_button: gr.Button(LN.descriptive_indicators_button, visible=True),
learning_curve_checkboxgroup: gr.Checkboxgroup(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.learning_curve_checkboxgroup),
learning_curve_button: gr.Button(LN.learning_curve_button, visible=Dataset.check_before_train()),
shap_beeswarm_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.shap_beeswarm_radio),
shap_beeswarm_type: gr.Radio(Dataset.get_shap_beeswarm_plot_type(), visible=Dataset.check_before_train(), label=LN.shap_beeswarm_type),
shap_beeswarm_button: gr.Button(LN.shap_beeswarm_button, visible=Dataset.check_before_train()),
data_fit_checkboxgroup: gr.Checkboxgroup(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.data_fit_checkboxgroup),
data_fit_button: gr.Button(LN.data_fit_button, visible=Dataset.check_before_train()),
waterfall_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.waterfall_radio),
waterfall_number: gr.Slider(0, Dataset.get_total_row_num(), value=0, step=1, visible=Dataset.check_before_train(), label=LN.waterfall_number),
waterfall_button: gr.Button(LN.waterfall_button, visible=Dataset.check_before_train()),
force_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.force_radio),
force_number: gr.Slider(0, Dataset.get_total_row_num(), value=0, step=1, visible=Dataset.check_before_train(), label=LN.force_number),
force_button: gr.Button(LN.force_button, visible=Dataset.check_before_train()),
dependence_radio: gr.Radio(Dataset.get_trained_model_list(), visible=Dataset.check_before_train(), label=LN.dependence_radio),
dependence_col: gr.Radio(Dataset.get_col_list(), visible=Dataset.check_before_train(), label=LN.dependence_col),
dependence_button: gr.Button(LN.dependence_button, visible=Dataset.check_before_train()),
}
gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
if extra_gr_dict:
gr_dict.update(extra_gr_dict)
return gr_dict
gr_dict = {
choose_custom_dataset_file: gr.File(None, visible=True),
display_dataset_dataframe: gr.Dataframe(visible=False),
display_dataset: gr.File(visible=False),
display_total_col_num_text: gr.Textbox(visible=False),
display_total_row_num_text: gr.Textbox(visible=False),
display_na_list_text: gr.Textbox(visible=False),
del_all_na_col_button: gr.Button(visible=False),
display_duplicate_num_text: gr.Textbox(visible=False),
del_duplicate_button: gr.Button(visible=False),
del_col_checkboxgroup: gr.Checkboxgroup(visible=False),
del_col_button: gr.Button(visible=False),
remain_row_slider: gr.Slider(visible=False),
encode_label_button: gr.Button(visible=False),
display_encode_label_dataframe: gr.Dataframe(visible=False),
encode_label_checkboxgroup: gr.Checkboxgroup(visible=False),
data_type_dataframe: gr.Dataframe(visible=False),
change_data_type_to_float_button: gr.Button(visible=False),
standardize_data_checkboxgroup: gr.Checkboxgroup(visible=False),
standardize_data_button: gr.Button(visible=False),
select_as_y_radio: gr.Radio(visible=False),
model_optimize_radio: gr.Radio(visible=False),
model_train_button: gr.Button(visible=False),
model_train_checkbox: gr.Checkbox(visible=False),
model_train_metrics_dataframe: gr.Dataframe(visible=False),
model_train_params_dataframe: gr.Dataframe(visible=False),
select_as_model_radio: gr.Radio(visible=False),
choose_assign_radio: gr.Radio(visible=False),
draw_plot: gr.Plot(visible=False),
draw_file: gr.File(visible=False),
title_name_textbox: gr.Textbox(visible=False),
x_label_textbox: gr.Textbox(visible=False),
y_label_textbox: gr.Textbox(visible=False),
# [模型]
linear_regression_model_radio: gr.Radio(visible=False),
naive_bayes_classification_model_radio: gr.Radio(visible=False),
# [绘图]
heatmap_checkboxgroup: gr.Checkboxgroup(visible=False),
heatmap_is_rotate: gr.Checkbox(visible=False),
heatmap_button: gr.Button(visible=False),
data_distribution_radio: gr.Radio(visible=False),
data_distribution_is_rotate: gr.Checkbox(visible=False),
data_distribution_button: gr.Button(visible=False),
descriptive_indicators_checkboxgroup: gr.Checkboxgroup(visible=False),
descriptive_indicators_is_rotate: gr.Checkbox(visible=False),
descriptive_indicators_dataframe: gr.Dataframe(visible=False),
descriptive_indicators_button: gr.Button(visible=False),
learning_curve_checkboxgroup: gr.Checkboxgroup(visible=False),
learning_curve_button: gr.Button(visible=False),
shap_beeswarm_radio: gr.Radio(visible=False),
shap_beeswarm_type: gr.Radio(visible=False),
shap_beeswarm_button: gr.Button(visible=False),
data_fit_checkboxgroup: gr.Checkboxgroup(visible=False),
data_fit_button: gr.Button(visible=False),
waterfall_radio: gr.Radio(visible=False),
waterfall_number: gr.Slider(visible=False),
waterfall_button: gr.Button(visible=False),
force_radio: gr.Radio(visible=False),
force_number: gr.Slider(visible=False),
force_button: gr.Button(visible=False),
dependence_radio: gr.Radio(visible=False),
dependence_col: gr.Radio(visible=False),
dependence_button: gr.Button(visible=False),
}
gr_dict.update(dict(zip(colorpickers, [gr.ColorPicker(visible=False)] * StaticValue.max_num)))
gr_dict.update(dict(zip(color_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
gr_dict.update(dict(zip(legend_labels_textboxs, [gr.Textbox(visible=False)] * StaticValue.max_num)))
return gr_dict
class Dataset:
file = ""
data = pd.DataFrame()
na_list = []
non_numeric_list = []
str2int_mappings = {}
max_num = 0
data_copy = pd.DataFrame()
assign = ""
cur_model = ""
select_y_mark = False
descriptive_indicators_df = pd.DataFrame()
linear_regression_model_type = ""
naive_bayes_classifier_model_type = ""
container_dict = {
# [模型]
MN.linear_regression: Container(),
MN.polynomial_regression: Container(),
MN.logistic_regression: Container(),
MN.decision_tree_classifier: Container(),
MN.random_forest_classifier: Container(),
MN.random_forest_regression: Container(),
MN.xgboost_classifier: Container(),
MN.lightGBM_classifier: Container(),
MN.gradient_boosting_regression: Container(),
MN.svm_classifier: Container(),
MN.svm_regression: Container(),
MN.knn_classifier: Container(),
MN.knn_regression: Container(),
MN.naive_bayes_classification: Container(),
}
visualize = ""
@classmethod
def check_descriptive_indicators_df(cls):
return True if not cls.descriptive_indicators_df.empty else False
@classmethod
def get_descriptive_indicators_df(cls):
return cls.descriptive_indicators_df
@classmethod
def get_notes(cls):
notes = ""
with open("./data/notes.md", "r", encoding="utf-8") as f:
notes = str(f.read())
return notes
@classmethod
def get_dataset_list(cls):
return ["自定义", "Iris Dataset", "Wine Dataset", "Breast Cancer Dataset", "Diabetes Dataset", "California Housing Dataset"]
@classmethod
def get_col_list(cls):
return [x for x in cls.data.columns.values]
@classmethod
def get_na_list_str(cls) -> str:
na_series = cls.data.isna().any(axis=0)
na_list = []
na_list_str = ""
for i in range(len(na_series)):
cur_value = na_series[i]
cur_index = na_series.index[i]
if cur_value:
na_list_str += cur_index + ", "
na_list.append(cur_index)
na_list_str = na_list_str.rstrip(", ")
cls.na_list = na_list
if not na_list:
return "无"
return na_list_str
@classmethod
def get_total_col_num(cls) -> int:
return len(cls.data.columns)
@classmethod
def get_total_row_num(cls) -> int:
return len(cls.data)
@classmethod
def update(cls, file: str, data: pd.DataFrame):
cls.file = file
cls.data = data
cls.max_num = len(data)
cls.data_copy = data
@classmethod
def clear(cls):
cls.file = ""
cls.data = pd.DataFrame()
@classmethod
def get_display_dataset_file(cls):
file_path = FilePath.excel_base.format(FilePath.display_dataset)
return file_path
@classmethod
def check_display_dataset_file(cls):
return os.path.exists(cls.get_display_dataset_file())
@classmethod
def after_get_display_dataset_file(cls):
if not cls.data.empty:
cls.data.to_excel(cls.get_display_dataset_file(), index=False)
return cls.get_display_dataset_file() if cls.check_display_dataset_file() else None
@classmethod
def del_col(cls, col_list: list):
for col in col_list:
if col in cls.data.columns.values:
cls.data.drop(col, axis=1, inplace=True)
@classmethod
def get_max_num(cls):
return cls.max_num
@classmethod
def remain_row(cls, num):
cls.data = cls.data_copy.iloc[:num, :]
@classmethod
def del_all_na_col(cls):
for col in cls.na_list:
if col in cls.data.columns.values:
cls.data.drop(col, axis=1, inplace=True)
@classmethod
def get_duplicate_num(cls):
data_copy = copy.deepcopy(cls.data)
return len(cls.data) - len(data_copy.drop_duplicates())
@classmethod
def del_duplicate(cls):
cls.data = cls.data.drop_duplicates().reset_index().drop("index", axis=1)
@classmethod
def encode_label(cls, col_list: list, extra_mark=False):
data_copy = copy.deepcopy(cls.data)
str2int_mappings = dict(zip(col_list, [{} for _ in range(len(col_list))]))
for col in str2int_mappings.keys():
keys = np.array(data_copy[col].drop_duplicates())
values = [x for x in range(len(keys))]
str2int_mappings[col] = dict(zip(keys, values))
for col, mapping in str2int_mappings.items():
series = data_copy[col]
for k, v in mapping.items():
series.replace(k, v, inplace=True)
data_copy[col] = series
for k, v in str2int_mappings.items():
if np.nan in v.keys():
v.update({"nan": v.pop(np.nan)})
str2int_mappings[k] = v
if extra_mark:
return data_copy
else:
cls.data = data_copy
cls.str2int_mappings = str2int_mappings
@classmethod
def get_str2int_mappings_df(cls):
columns_list = ["列名", "字符型", "数值型"]
str2int_mappings_df = pd.DataFrame(columns=columns_list)
for k, v in cls.str2int_mappings.items():
cur_df = pd.DataFrame(columns=columns_list)
cur_df["列名"] = pd.DataFrame([k] * len(v.keys()))
cur_df["字符型"] = pd.DataFrame([x for x in v.keys()])
cur_df["数值型"] = pd.DataFrame([x for x in v.values()])
str2int_mappings_df = pd.concat([str2int_mappings_df, cur_df], axis=0)
blank_df = pd.DataFrame(columns=columns_list)
blank_df.loc[0] = ["", "", ""]
str2int_mappings_df = pd.concat([str2int_mappings_df, blank_df], axis=0)
return str2int_mappings_df.iloc[:-1, :]
@classmethod
def get_non_numeric_list(cls):
data_copy = copy.deepcopy(cls.data)
data_copy = data_copy.astype(str)
non_numeric_list = []
for col in data_copy.columns.values:
if pd.to_numeric(data_copy[col], errors="coerce").isnull().values.any():
non_numeric_list.append(col)
cls.non_numeric_list = non_numeric_list
return non_numeric_list
@classmethod
def get_data_type(cls):
columns_list = ["列名", "数据类型"]
data_type_dict = {}
for col in cls.data.columns.values:
data_type_dict[col] = cls.data[col].dtype.name
data_type_df = pd.DataFrame(columns=columns_list)
data_type_df["列名"] = [x for x in data_type_dict.keys()]
data_type_df["数据类型"] = [x for x in data_type_dict.values()]
return data_type_df
@classmethod
def change_data_type_to_float(cls):
data_copy = cls.data
for i, col in enumerate(data_copy.columns.values):
if i != 0:
data_copy[col] = data_copy[col].astype(float)
cls.data = data_copy
@classmethod
def get_non_standardized_data(cls):
not_standardized_data_list = []
for col in cls.data.columns.values:
if cls.data[col].dtype.name in ["int64", "float64"]:
if not np.array_equal(np.round(preprocessing.scale(cls.data[col]), decimals=2),
np.round(cls.data[col].values.round(2), decimals=2)):
not_standardized_data_list.append(col)
return not_standardized_data_list
@classmethod
def check_before_train(cls):
if cls.assign == "" or not cls.select_y_mark:
return False
for i, col in enumerate(cls.data.columns.values):
if i == 0:
if not (all(isinstance(x, str) for x in cls.data.iloc[:, 0]) or all(isinstance(x, float) for x in cls.data.iloc[:, 0])):
return False
else:
if cls.data[col].dtype.name != "float64":
return False
return True
@classmethod
def standardize_data(cls, col_list: list):
for col in col_list:
cls.data[col] = preprocessing.scale(cls.data[col])
@classmethod
def select_as_y(cls, col: str):
cls.data = pd.concat([cls.data[col], cls.data.drop(col, axis=1)], axis=1)
cls.select_y_mark = True
@classmethod
def get_optimize_list(cls):
return ["无", "网格搜索", "贝叶斯优化"]
@classmethod
def get_optimize_name_mapping(cls):
return dict(zip(cls.get_optimize_list(), ["None", "grid_search", "bayes_search"]))
@classmethod
def get_linear_regression_model_list(cls):
return ["线性回归", "Lasso回归", "Ridge回归", "弹性网络回归"]
@classmethod
def get_naive_bayes_classifier_model_list(cls):
return ["多项式朴素贝叶斯分类", "高斯朴素贝叶斯分类", "补充朴素贝叶斯分类"]
@classmethod
def get_linear_regression_model_name_mapping(cls):
return dict(zip(cls.get_linear_regression_model_list(), ["LinearRegression", "Lasso", "Ridge", "ElasticNet"]))
@classmethod
def get_naive_bayes_classifier_model_name_mapping(cls):
return dict(zip(cls.get_naive_bayes_classifier_model_list(), ["MultinomialNB", "GaussianNB", "ComplementNB"]))
@classmethod
def train_model(cls, optimize, linear_regression_model_type=None, naive_bayes_classifier_model_type=None):
optimize = cls.get_optimize_name_mapping()[optimize]
data_copy = cls.data
if cls.assign == MN.classification:
data_copy = cls.encode_label([cls.data.columns.values[0]], True)
x_train, x_test, y_train, y_test = train_test_split(
data_copy.values[:, 1:],
data_copy.values[:, :1],
random_state=Config.RANDOM_STATE,
train_size=0.8
)
container = Container(x_train, y_train, x_test, y_test, optimize)
# [模型]
if cls.cur_model == MN.linear_regression:
cls.linear_regression_model_type = cls.get_linear_regression_model_name_mapping()[linear_regression_model_type]
container = linear_regression(container, cls.linear_regression_model_type)
elif cls.cur_model == MN.polynomial_regression:
container = polynomial_regression(container)
elif cls.cur_model == MN.logistic_regression:
container = logistic_regression(container)
elif cls.cur_model == MN.decision_tree_classifier:
container = decision_tree_classifier(container)
elif cls.cur_model == MN.random_forest_classifier:
container = random_forest_classifier(container)
elif cls.cur_model == MN.random_forest_regression:
container = random_forest_regression(container)
elif cls.cur_model == MN.xgboost_classifier:
container = xgboost_classifier(container)
elif cls.cur_model == MN.lightGBM_classifier:
container = lightGBM_classifier(container)
elif cls.cur_model == MN.gradient_boosting_regression:
container = gradient_boosting_regression(container)
elif cls.cur_model == MN.svm_classifier:
container = svm_classifier(container)
elif cls.cur_model == MN.svm_regression:
container = svm_regression(container)
elif cls.cur_model == MN.knn_classifier:
container = knn_classifier(container)
elif cls.cur_model == MN.knn_regression:
container = knn_regression(container)
elif cls.cur_model == MN.naive_bayes_classification:
cls.naive_bayes_classifier_model_type = cls.get_naive_bayes_classifier_model_name_mapping()[naive_bayes_classifier_model_type]
container = naive_bayes_classification(container, cls.naive_bayes_classifier_model_type)
cls.container_dict[cls.cur_model] = container
@classmethod
def get_model_container_status(cls):
return True if cls.cur_model != "" and cls.container_dict[cls.cur_model].get_status() == "trained" else False
@classmethod
def get_model_label(cls):
return str(cls.get_model_name_mapping()[cls.cur_model]) + "模型是否完成训练" if cls.cur_model != "" else ""
@classmethod
def check_select_model(cls):
return True if cls.cur_model != "" and cls.check_before_train() else False
@classmethod
def get_model_name(cls):
return [x for x in cls.container_dict.keys()]
# [模型]
@classmethod
def get_model_chinese_name(cls):
return ["线性回归", "多项式回归", "逻辑斯谛分类", "决策树分类", "随机森林分类", "随机森林回归", "XGBoost分类", "LightGBM分类",
"梯度提升回归", "支持向量机分类", "支持向量机回归", "K-最近邻分类", "K-最近邻回归", "朴素贝叶斯分类"]
@classmethod
def get_model_name_mapping(cls):
return dict(zip(cls.get_model_name(), cls.get_model_chinese_name()))
@classmethod
def get_model_name_mapping_reverse(cls):
return dict(zip(cls.get_model_chinese_name(), cls.get_model_name()))
@classmethod
def get_trained_model_list(cls):
trained_model_list = []
for model_name, container in cls.container_dict.items():
if container.get_status() == "trained":
trained_model_list.append(cls.get_model_name_mapping()[model_name])
return trained_model_list
@classmethod
def draw_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
# [绘图]
if cls.visualize == MN.learning_curve:
return cls.draw_learning_curve_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.shap_beeswarm:
return cls.draw_shap_beeswarm_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.data_fit:
return cls.draw_data_fit_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.waterfall:
return cls.draw_waterfall_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.force:
return cls.draw_force_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.dependence:
return cls.draw_dependence_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.data_distribution:
return cls.draw_data_distribution_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.descriptive_indicators:
return cls.draw_descriptive_indicators_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
elif cls.visualize == MN.heatmap:
return cls.draw_heatmap_plot(select_model, color_list, label_list, name, x_label, y_label, is_default)
@classmethod
def draw_heatmap_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
color_cur_list = [] if is_default else color_list
x_cur_label = "Indicators" if is_default else x_label
y_cur_label = "Value" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
paint_object.set_color_cur_list(color_cur_list)
paint_object.set_x_cur_label(x_cur_label)
paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
if cls.check_col_list(select_model.get_heatmap_col()):
return cls.error_return_draw(paint_object)
df = Dataset.data
heatmap_col = select_model.get_heatmap_col()
covX = np.around(np.corrcoef(df[heatmap_col].T), decimals=3)
std_dev = np.sqrt(np.diag(covX))
pearson_matrix = covX / np.outer(std_dev, std_dev)
return draw_heat_map(pearson_matrix, heatmap_col, paint_object, select_model.get_heatmap_is_rotate())
@classmethod
def draw_descriptive_indicators_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
color_cur_list = [Config.COLORS[random.randint(0, 11)]]*3 if is_default else color_list
x_cur_label = "Indicators" if is_default else x_label
y_cur_label = "Value" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
paint_object.set_color_cur_list(color_cur_list)
paint_object.set_x_cur_label(x_cur_label)
paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
if cls.check_col_list(select_model.get_descriptive_indicators_col()):
return cls.error_return_draw(paint_object)
df = Dataset.data
descriptive_indicators_col = select_model.get_descriptive_indicators_col()
descriptive_indicators_df = pd.DataFrame(
index=list(descriptive_indicators_col),
columns=[
"Name",
"Min",
"Max",
"Avg",
"Standard Deviation",
"Standard Error",
"Upper Quartile",
"Median",
"Lower Quartile",
"Interquartile Distance",
"Kurtosis",
"Skewness",
"Coefficient of Variation"
]
)
for col in descriptive_indicators_col:
descriptive_indicators_df["Name"][col] = col
descriptive_indicators_df["Min"][col] = df[col].min()
descriptive_indicators_df["Max"][col] = df[col].max()
descriptive_indicators_df["Avg"][col] = df[col].mean()
descriptive_indicators_df["Standard Deviation"][col] = df[col].std()
descriptive_indicators_df["Standard Error"][col] = descriptive_indicators_df["Standard Deviation"][
col] / math.sqrt(len(df[col]))
descriptive_indicators_df["Upper Quartile"][col] = df[col].quantile(0.75)
descriptive_indicators_df["Median"][col] = df[col].quantile(0.5)
descriptive_indicators_df["Lower Quartile"][col] = df[col].quantile(0.25)
descriptive_indicators_df["Interquartile Distance"][col] = descriptive_indicators_df["Lower Quartile"][
col] - \
descriptive_indicators_df["Upper Quartile"][col]
descriptive_indicators_df["Kurtosis"][col] = df[col].kurt()
descriptive_indicators_df["Skewness"][col] = df[col].skew()
descriptive_indicators_df["Coefficient of Variation"][col] = \
descriptive_indicators_df["Standard Deviation"][col] / descriptive_indicators_df["Avg"][col]
cls.descriptive_indicators_df = descriptive_indicators_df
cur_df = df[descriptive_indicators_col].astype(float)
return draw_boxplot(cur_df, paint_object, select_model.get_descriptive_indicators_is_rotate())
@classmethod
def error_return_draw(cls, paint_object):
cur_plt = plt.Figure(figsize=(10, 8))
return cur_plt, paint_object
@classmethod
def draw_data_distribution_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
cur_col = select_model.get_data_distribution_col()
color_cur_list = [Config.COLORS[random.randint(0, 11)]] if is_default else color_list
x_cur_label = cur_col if is_default else x_label
y_cur_label = "Num" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
paint_object.set_color_cur_list(color_cur_list)
paint_object.set_x_cur_label(x_cur_label)
paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
if cls.check_col_list(select_model.get_data_distribution_col()):
return cls.error_return_draw(paint_object)
counts_mapping = {}
for x in Dataset.data.loc[:, cur_col].values:
if x in counts_mapping.keys():
counts_mapping[x] += 1
else:
counts_mapping[x] = 1
sorting = sorted(counts_mapping.items(), reverse=True, key=lambda m: m[1])
nums = [x[1] for x in sorting]
labels = [x[0] for x in sorting]
if Dataset.check_data_distribution_type(cur_col) == "histogram":
return draw_histogram(nums, labels, paint_object, select_model.get_data_distribution_is_rotate())
else:
return cls.error_return_draw(paint_object)
@classmethod
def draw_dependence_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
model_name = select_model.get_models()
model_name = cls.get_model_name_mapping_reverse()[model_name]
container = cls.container_dict[model_name]
# color_cur_list = Config.COLORS if is_default else color_list
# label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
# x_cur_label = "Train Sizes" if is_default else x_label
# y_cur_label = "Accuracy" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
# paint_object.set_color_cur_list(color_cur_list)
# paint_object.set_label_cur_list(label_cur_list)
# paint_object.set_x_cur_label(x_cur_label)
# paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
return draw_dependence(container.get_model(), container.x_train, cls.data.columns.values.tolist()[1:], select_model.get_dependence_col(), paint_object)
@classmethod
def draw_force_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
model_name = select_model.get_models()
model_name = cls.get_model_name_mapping_reverse()[model_name]
container = cls.container_dict[model_name]
# color_cur_list = Config.COLORS if is_default else color_list
# label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
# x_cur_label = "Train Sizes" if is_default else x_label
# y_cur_label = "Accuracy" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
# paint_object.set_color_cur_list(color_cur_list)
# paint_object.set_label_cur_list(label_cur_list)
# paint_object.set_x_cur_label(x_cur_label)
# paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
return draw_force(container.get_model(), container.x_train, cls.data.columns.values.tolist()[1:], select_model.get_force_number(), paint_object)
@classmethod
def draw_waterfall_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
model_name = select_model.get_models()
model_name = cls.get_model_name_mapping_reverse()[model_name]
container = cls.container_dict[model_name]
# color_cur_list = Config.COLORS if is_default else color_list
# label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
# x_cur_label = "Train Sizes" if is_default else x_label
# y_cur_label = "Accuracy" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
# paint_object.set_color_cur_list(color_cur_list)
# paint_object.set_label_cur_list(label_cur_list)
# paint_object.set_x_cur_label(x_cur_label)
# paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
return draw_waterfall(container.get_model(), container.x_train, cls.data.columns.values.tolist()[1:], select_model.get_waterfall_number(), paint_object)
@classmethod
def draw_learning_curve_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
cur_dict = {}
model_list = select_model.get_models()
for model_name in model_list:
model_name = cls.get_model_name_mapping_reverse()[model_name]
cur_dict[model_name] = cls.container_dict[model_name].get_learning_curve_values()
color_cur_list = Config.COLORS if is_default else color_list
if is_default:
label_cur_list = []
for x in cur_dict.keys():
label_cur_list.append("train " + str(x))
label_cur_list.append("validation " + str(x))
else:
label_cur_list = label_list
x_cur_label = "Train Sizes" if is_default else x_label
y_cur_label = "Accuracy" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
paint_object.set_color_cur_list(color_cur_list)
paint_object.set_label_cur_list(label_cur_list)
paint_object.set_x_cur_label(x_cur_label)
paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
if cls.check_cur_dict(cur_dict):
return cls.error_return_draw(paint_object)
return draw_learning_curve_total(cur_dict, paint_object)
@classmethod
def draw_shap_beeswarm_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
model_name = select_model.get_models()
model_name = cls.get_model_name_mapping_reverse()[model_name]
container = cls.container_dict[model_name]
# color_cur_list = Config.COLORS if is_default else color_list
# label_cur_list = [x for x in learning_curve_dict.keys()] if is_default else label_list
# x_cur_label = "Train Sizes" if is_default else x_label
# y_cur_label = "Accuracy" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
# paint_object.set_color_cur_list(color_cur_list)
# paint_object.set_label_cur_list(label_cur_list)
# paint_object.set_x_cur_label(x_cur_label)
# paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
return draw_shap_beeswarm(container.get_model(), container.x_train, cls.data.columns.values.tolist()[1:], select_model.get_beeswarm_plot_type(), paint_object)
@classmethod
def draw_data_fit_plot(cls, select_model, color_list: list, label_list: list, name: str, x_label: str, y_label: str, is_default: bool):
cur_dict = {}
model_list = select_model.get_models()
for model_name in model_list:
model_name = cls.get_model_name_mapping_reverse()[model_name]
cur_dict[model_name] = cls.container_dict[model_name].get_data_fit_values()
color_cur_list = Config.COLORS if is_default else color_list
if is_default:
label_cur_list = []
for x in cur_dict.keys():
label_cur_list.append("pred " + str(x))
label_cur_list.append("real data")
else:
label_cur_list = label_list
x_cur_label = "n value" if is_default else x_label
y_cur_label = "y value" if is_default else y_label
cur_name = "" if is_default else name
paint_object = PaintObject()
paint_object.set_color_cur_list(color_cur_list)
paint_object.set_label_cur_list(label_cur_list)
paint_object.set_x_cur_label(x_cur_label)
paint_object.set_y_cur_label(y_cur_label)
paint_object.set_name(cur_name)
return draw_data_fit_total(cur_dict, paint_object)
@classmethod
def get_shap_beeswarm_plot_type(cls):
return ["bar", "violin"]
@classmethod
def get_file(cls):
# [绘图]
if cls.visualize == MN.learning_curve:
return FilePath.png_base.format(FilePath.learning_curve_plot)
elif cls.visualize == MN.shap_beeswarm:
return FilePath.png_base.format(FilePath.shap_beeswarm_plot)
elif cls.visualize == MN.data_fit:
return FilePath.png_base.format(FilePath.data_fit_plot)
elif cls.visualize == MN.waterfall:
return FilePath.png_base.format(FilePath.waterfall_plot)
elif cls.visualize == MN.force:
return FilePath.png_base.format(FilePath.force_plot)
elif cls.visualize == MN.dependence:
return FilePath.png_base.format(FilePath.dependence_plot)
elif cls.visualize == MN.data_distribution:
return FilePath.png_base.format(FilePath.data_distribution_plot)
elif cls.visualize == MN.descriptive_indicators:
return FilePath.png_base.format(FilePath.descriptive_indicators_plot)
elif cls.visualize == MN.heatmap:
return FilePath.png_base.format(FilePath.heatmap_plot)
@classmethod
def check_file(cls):
return os.path.exists(cls.get_file())
@classmethod
def after_get_file(cls):
return cls.get_file() if cls.check_file() else None
@classmethod
def get_model_list(cls):
model_list = []
for model_name in cls.container_dict.keys():
model_list.append(cls.get_model_name_mapping()[model_name])
return model_list
@classmethod
def select_as_model(cls, model_name: str):
cls.cur_model = cls.get_model_name_mapping_reverse()[model_name]
@classmethod
def get_model_mark(cls):
return True if cls.cur_model != "" else False
@classmethod
def get_linear_regression_mark(cls):
return True if cls.cur_model == MN.linear_regression else False
@classmethod
def get_naive_bayes_classifier_mark(cls):
return True if cls.cur_model == MN.naive_bayes_classification else False
@classmethod
def get_assign_list(cls):
return ["分类", "回归"]
@classmethod
def get_assign_mapping_reverse(cls):
return dict(zip(cls.get_assign_list(), [MN.classification, MN.regression]))
@classmethod
def choose_assign(cls, assign: str):
cls.assign = cls.get_assign_mapping_reverse()[assign]
data_copy = cls.data
if cls.assign == MN.classification:
data_copy.iloc[:, 0] = data_copy.iloc[:, 0].astype(str)
else:
data_copy.iloc[:, 0] = data_copy.iloc[:, 0].astype(float)
cls.data = data_copy
cls.change_data_type_to_float()
@classmethod
def colorpickers_change(cls, paint_object):
cur_num = paint_object.get_color_cur_num()
true_list = [gr.ColorPicker(paint_object.get_color_cur_list()[i], visible=True, label=LN.colors[i]) for i in range(cur_num)]
return true_list + [gr.ColorPicker(visible=False)] * (StaticValue.max_num - cur_num)
@classmethod
def color_textboxs_change(cls, paint_object):
cur_num = paint_object.get_color_cur_num()
true_list = [gr.Textbox(paint_object.get_color_cur_list()[i], visible=True, show_label=False) for i in range(cur_num)]
return true_list + [gr.Textbox(visible=False)] * (StaticValue.max_num - cur_num)
@classmethod
def labels_change(cls, paint_object):
cur_num = paint_object.get_label_cur_num()
true_list = [gr.Textbox(paint_object.get_label_cur_list()[i], visible=True, label=LN.labels[i]) for i in range(cur_num)]
return true_list + [gr.Textbox(visible=False)] * (StaticValue.max_num - cur_num)
@classmethod
def get_model_train_metrics_dataframe(cls):
if cls.cur_model != "" and cls.get_model_container_status():
columns_list = ["指标", "数值"]
output_dict = cls.container_dict[cls.cur_model].get_info()["指标"]
output_df = pd.DataFrame(columns=columns_list)
output_df["指标"] = [x for x in output_dict.keys() if x in ChooseModelMetrics.choose(cls.cur_model)]
output_df["数值"] = [output_dict[x] for x in output_df["指标"]]
return output_df
@classmethod
def get_model_train_params_dataframe(cls):
if cls.cur_model != "" and cls.get_model_container_status():
columns_list = ["参数", "数值"]
output_dict = cls.container_dict[cls.cur_model].get_info()["参数"]
output_df = pd.DataFrame(columns=columns_list)
output_df["参数"] = [x for x in output_dict.keys() if x in ChooseModelParams.choose(cls.cur_model).keys()]
output_df["数值"] = [output_dict[x] for x in output_df["参数"]]
return output_df
@classmethod
def get_str_col_list(cls):
str_col_list = []
for col in cls.get_col_list():
if all(isinstance(x, str) for x in cls.data.loc[:, col]):
str_col_list.append(col)
return str_col_list
@classmethod
def get_float_col_list(cls):
float_col_list = []
for col in cls.get_col_list():
if all(isinstance(x, float) for x in cls.data.loc[:, col]):
float_col_list.append(col)
return float_col_list
@classmethod
def check_data_distribution_type(cls, col):
if all(isinstance(x, str) for x in cls.data.loc[:, col]):
return "histogram"
# elif all(isinstance(x, float) for x in cls.data.loc[:, col]):
# return "line_graph"
else:
gr.Warning("所选列的所有数据必须为字符型或浮点型")
@classmethod
def check_col_list(cls, col):
if not col:
gr.Warning("请选择所需列")
return True
return False
@classmethod
def check_train_model(cls, optimize):
if cls.cur_model == "":
gr.Warning("请选择所需训练的模型")
return True
if not optimize:
gr.Warning("请选择超参数优化方法")
return True
return False
@classmethod
def error_return_train(cls):
return get_return(True)
@classmethod
def check_train_model_other_related(cls, linear_regression_model_type, naive_bayes_classifier_model_type):
if cls.cur_model == MN.linear_regression:
if not linear_regression_model_type:
gr.Warning("请选择线性回归对应的模型")
return True
elif cls.cur_model == MN.naive_bayes_classification:
if not naive_bayes_classifier_model_type:
gr.Warning("请选择朴素贝叶斯对应的模型")
return True
return False
@classmethod
def check_cur_dict(cls, cur_dict):
if not cur_dict:
gr.Warning("请选择绘图所需的模型")
return True
return False
def choose_assign(assign: str):
Dataset.choose_assign(assign)
return get_return(True)
def select_as_model(model_name: str):
Dataset.select_as_model(model_name)
return get_return(True)
# [绘图]
def heatmap_first_draw_plot(*inputs):
Dataset.visualize = MN.heatmap
return before_train_first_draw_plot(inputs)
def descriptive_indicators_first_draw_plot(*inputs):
Dataset.visualize = MN.descriptive_indicators
return before_train_first_draw_plot(inputs)
def data_distribution_first_draw_plot(*inputs):
Dataset.visualize = MN.data_distribution
return before_train_first_draw_plot(inputs)
def dependence_first_draw_plot(*inputs):
Dataset.visualize = MN.dependence
return first_draw_plot(inputs)
def force_first_draw_plot(*inputs):
Dataset.visualize = MN.force
return first_draw_plot(inputs)
def waterfall_first_draw_plot(*inputs):
Dataset.visualize = MN.waterfall
return first_draw_plot(inputs)
def data_fit_first_draw_plot(*inputs):
Dataset.visualize = MN.data_fit
return first_draw_plot(inputs)
def shap_beeswarm_first_draw_plot(*inputs):
Dataset.visualize = MN.shap_beeswarm
return first_draw_plot(inputs)
def learning_curve_first_draw_plot(*inputs):
Dataset.visualize = MN.learning_curve
return first_draw_plot(inputs)
def before_train_first_draw_plot(inputs):
select_model = SelectModel()
x_label = ""
y_label = ""
name = ""
color_list = []
label_list = []
# [绘图][无训练模型]
if Dataset.visualize == MN.data_distribution:
select_model.set_data_distribution_col(inputs[0])
select_model.set_data_distribution_is_rotate(inputs[1])
elif Dataset.visualize == MN.descriptive_indicators:
select_model.set_descriptive_indicators_is_rotate(inputs[0])
select_model.set_descriptive_indicators_col(inputs[1])
elif Dataset.visualize == MN.heatmap:
select_model.set_heatmap_col(inputs[0])
select_model.set_heatmap_is_rotate(inputs[1])
cur_plt, paint_object = Dataset.draw_plot(select_model, color_list, label_list, name, x_label, y_label, True)
return first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object)
def first_draw_plot(inputs):
select_model = SelectModel()
select_model.set_models(inputs[0])
x_label = ""
y_label = ""
name = ""
color_list = []
label_list = []
# [绘图][有训练模型]
if Dataset.visualize == MN.shap_beeswarm:
select_model.set_beeswarm_plot_type(inputs[1])
elif Dataset.visualize == MN.waterfall:
select_model.set_waterfall_number(inputs[1])
elif Dataset.visualize == MN.force:
select_model.set_force_number(inputs[1])
elif Dataset.visualize == MN.dependence:
select_model.set_dependence_col(inputs[1])
cur_plt, paint_object = Dataset.draw_plot(select_model, color_list, label_list, name, x_label, y_label, True)
return first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object)
def out_non_first_draw_plot(*inputs):
return non_first_draw_plot(inputs)
def non_first_draw_plot(inputs):
name = inputs[0]
x_label = inputs[1]
y_label = inputs[2]
color_list = list(inputs[3: StaticValue.max_num+3])
label_list = list(inputs[StaticValue.max_num+3: 2*StaticValue.max_num+3])
start_index = 2*StaticValue.max_num+3
select_model = SelectModel()
# 绘图
if Dataset.visualize == MN.learning_curve:
select_model.set_models(inputs[start_index+0])
select_model.set_beeswarm_plot_type(inputs[start_index+1])
elif Dataset.visualize == MN.shap_beeswarm:
select_model.set_models(inputs[start_index+2])
elif Dataset.visualize == MN.data_fit:
select_model.set_models(inputs[start_index+3])
elif Dataset.visualize == MN.waterfall:
select_model.set_models(inputs[start_index+4])
select_model.set_waterfall_number(inputs[start_index+5])
elif Dataset.visualize == MN.force:
select_model.set_models(inputs[start_index+6])
select_model.set_force_number(inputs[start_index+7])
elif Dataset.visualize == MN.dependence:
select_model.set_models(inputs[start_index+8])
select_model.set_dependence_col(inputs[start_index+9])
elif Dataset.visualize == MN.data_distribution:
select_model.set_data_distribution_col(inputs[start_index+10])
select_model.set_data_distribution_is_rotate(inputs[start_index+11])
elif Dataset.visualize == MN.descriptive_indicators:
select_model.set_descriptive_indicators_is_rotate(inputs[start_index+12])
select_model.set_descriptive_indicators_col(inputs[start_index+13])
elif Dataset.visualize == MN.descriptive_indicators:
select_model.set_heatmap_col(inputs[start_index+14])
select_model.set_heatmap_is_rotate(inputs[start_index+15])
else:
select_model.set_models(inputs[start_index])
cur_plt, paint_object = Dataset.draw_plot(select_model, color_list, label_list, name, x_label, y_label, False)
return first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object)
def first_draw_plot_with_non_first_draw_plot(cur_plt, paint_object):
extra_gr_dict = {}
# [绘图]
if Dataset.visualize == MN.learning_curve:
cur_plt.savefig(FilePath.png_base.format(FilePath.learning_curve_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.learning_curve_plot)})
elif Dataset.visualize == MN.shap_beeswarm:
cur_plt.savefig(FilePath.png_base.format(FilePath.shap_beeswarm_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.shap_beeswarm_plot)})
elif Dataset.visualize == MN.data_fit:
cur_plt.savefig(FilePath.png_base.format(FilePath.data_fit_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.data_fit_plot)})
elif Dataset.visualize == MN.waterfall:
cur_plt.savefig(FilePath.png_base.format(FilePath.waterfall_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.waterfall_plot)})
elif Dataset.visualize == MN.force:
cur_plt.savefig(FilePath.png_base.format(FilePath.force_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.force_plot)})
elif Dataset.visualize == MN.dependence:
cur_plt.savefig(FilePath.png_base.format(FilePath.dependence_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.dependence_plot)})
elif Dataset.visualize == MN.data_distribution:
cur_plt.savefig(FilePath.png_base.format(FilePath.data_distribution_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.data_distribution_plot)})
elif Dataset.visualize == MN.descriptive_indicators:
cur_plt.savefig(FilePath.png_base.format(FilePath.descriptive_indicators_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.descriptive_indicators_plot)})
extra_gr_dict.update({descriptive_indicators_dataframe: gr.Dataframe(Dataset.get_descriptive_indicators_df(), type="pandas", visible=Dataset.check_descriptive_indicators_df())})
elif Dataset.visualize == MN.heatmap:
cur_plt.savefig(FilePath.png_base.format(FilePath.heatmap_plot), dpi=300)
extra_gr_dict.update({draw_plot: gr.Plot(cur_plt, visible=True, label=LN.heatmap_plot)})
extra_gr_dict.update(dict(zip(colorpickers, Dataset.colorpickers_change(paint_object))))
extra_gr_dict.update(dict(zip(color_textboxs, Dataset.color_textboxs_change(paint_object))))
extra_gr_dict.update(dict(zip(legend_labels_textboxs, Dataset.labels_change(paint_object))))
extra_gr_dict.update({title_name_textbox: gr.Textbox(paint_object.get_name(), visible=True, label=LN.title_name_textbox)})
extra_gr_dict.update({x_label_textbox: gr.Textbox(paint_object.get_x_cur_label(), visible=True, label=LN.x_label_textbox)})
extra_gr_dict.update({y_label_textbox: gr.Textbox(paint_object.get_y_cur_label(), visible=True, label=LN.y_label_textbox)})
return get_return_extra(True, extra_gr_dict)
# [模型]
def train_model(optimize, linear_regression_model_type, naive_bayes_classifier_model_type):
if Dataset.check_train_model(optimize):
return Dataset.error_return_train()
if Dataset.check_train_model_other_related(linear_regression_model_type, naive_bayes_classifier_model_type):
return Dataset.error_return_train()
Dataset.train_model(optimize, linear_regression_model_type, naive_bayes_classifier_model_type)
return get_return(True)
def select_as_y(col: str):
Dataset.select_as_y(col)
return get_return(True)
def standardize_data(col_list: list):
Dataset.standardize_data(col_list)
return get_return(True)
def change_data_type_to_float():
Dataset.change_data_type_to_float()
return get_return(True)
def encode_label(col_list: list):
Dataset.encode_label(col_list)
return get_return(True, {display_encode_label_dataframe: gr.Dataframe(Dataset.get_str2int_mappings_df(), type="pandas", visible=True, label=LN.display_encode_label_dataframe)})
def del_duplicate():
Dataset.del_duplicate()
return get_return(True)
def del_all_na_col():
Dataset.del_all_na_col()
return get_return(True)
def remain_row(num):
Dataset.remain_row(num)
return get_return(True)
def del_col(col_list: list):
Dataset.del_col(col_list)
return get_return(True)
def add_index_into_df(df: pd.DataFrame) -> pd.DataFrame:
if df.empty:
return df
index_df = pd.DataFrame([x for x in range(len(df))], columns=["[*index]"])
return pd.concat([index_df, df], axis=1)
def choose_dataset(file: str):
if file == "自定义":
Dataset.clear()
return get_return(False)
df = load_data(file)
Dataset.update(file, df)
return get_return(True, {choose_custom_dataset_file: gr.File(visible=False)})
def choose_custom_dataset(file: str):
df = load_custom_data(file)
Dataset.update(file, df)
return get_return(True, {choose_custom_dataset_file: gr.File(Dataset.file, visible=True)})
with gr.Blocks(js=Config.JS_0) as demo:
'''
组件
'''
with gr.Tab("机器学习"):
# 选择数据源
with gr.Accordion("数据源"):
with gr.Group():
choose_dataset_radio = gr.Radio(Dataset.get_dataset_list(), label=LN.choose_dataset_radio)
choose_custom_dataset_file = gr.File(visible=False)
# 显示数据表信息
with gr.Accordion("当前数据信息"):
display_dataset_dataframe = gr.Dataframe(visible=False)
display_dataset = gr.File(visible=False)
with gr.Row():
display_total_col_num_text = gr.Textbox(visible=False)
display_total_row_num_text = gr.Textbox(visible=False)
with gr.Column():
remain_row_slider = gr.Slider(visible=False)
remain_row_button = gr.Button(visible=False)
with gr.Row():
with gr.Column():
with gr.Row():
display_na_list_text = gr.Textbox(visible=False)
display_duplicate_num_text = gr.Textbox(visible=False)
with gr.Row():
del_all_na_col_button = gr.Button(visible=False)
del_duplicate_button = gr.Button(visible=False)
# 操作数据表
with gr.Accordion("数据处理"):
select_as_y_radio = gr.Radio(visible=False)
with gr.Row():
with gr.Column():
data_type_dataframe = gr.Dataframe(visible=False)
change_data_type_to_float_button = gr.Button(visible=False)
choose_assign_radio = gr.Radio(visible=False)
with gr.Column():
del_col_checkboxgroup = gr.Checkboxgroup(visible=False)
del_col_button = gr.Button(visible=False)
encode_label_checkboxgroup = gr.Checkboxgroup(visible=False)
encode_label_button = gr.Button(visible=False)
display_encode_label_dataframe = gr.Dataframe(visible=False)
standardize_data_checkboxgroup = gr.Checkboxgroup(visible=False)
standardize_data_button = gr.Button(visible=False)
# 数据模型
with gr.Accordion("数据模型"):
# [模型]
select_as_model_radio = gr.Radio(visible=False)
linear_regression_model_radio = gr.Radio(visible=False)
naive_bayes_classification_model_radio = gr.Radio(visible=False)
model_optimize_radio = gr.Radio(visible=False)
model_train_button = gr.Button(visible=False)
model_train_checkbox = gr.Checkbox(visible=False)
model_train_params_dataframe = gr.Dataframe(visible=False)
model_train_metrics_dataframe = gr.Dataframe(visible=False)
# 可视化
with gr.Accordion("数据可视化"):
with gr.Tab("数据分布图"):
data_distribution_radio = gr.Radio(visible=False)
data_distribution_is_rotate = gr.Checkbox(visible=False)
data_distribution_button = gr.Button(visible=False)
with gr.Tab("箱线统计图"):
descriptive_indicators_checkboxgroup = gr.Checkboxgroup(visible=False)
descriptive_indicators_is_rotate = gr.Checkbox(visible=False)
descriptive_indicators_button = gr.Button(visible=False)
descriptive_indicators_dataframe = gr.Dataframe(visible=False)
with gr.Tab("系数热力图"):
heatmap_checkboxgroup = gr.Checkboxgroup(visible=False)
heatmap_is_rotate = gr.Checkbox(visible=False)
heatmap_button = gr.Button(visible=False)
# with gr.Tab("主成分分析"):
# pca_button = gr.Button(visible=False)
# pca_replace_data_button = gr.Button(visible=False)
with gr.Tab("学习曲线图"):
learning_curve_checkboxgroup = gr.Checkboxgroup(visible=False)
learning_curve_button = gr.Button(visible=False)
with gr.Tab("数据拟合图"):
data_fit_checkboxgroup = gr.Checkboxgroup(visible=False)
data_fit_button = gr.Button(visible=False)
with gr.Tab("特征蜂群图"):
shap_beeswarm_radio = gr.Radio(visible=False)
shap_beeswarm_type = gr.Radio(visible=False)
shap_beeswarm_button = gr.Button(visible=False)
with gr.Tab("特征瀑布图"):
waterfall_radio = gr.Radio(visible=False)
waterfall_number = gr.Slider(visible=False)
waterfall_button = gr.Button(visible=False)
with gr.Tab("特征力图"):
force_radio = gr.Radio(visible=False)
force_number = gr.Slider(visible=False)
force_button = gr.Button(visible=False)
with gr.Tab("特征依赖图"):
dependence_radio = gr.Radio(visible=False)
dependence_col = gr.Radio(visible=False)
dependence_button = gr.Button(visible=False)
legend_labels_textboxs = []
with gr.Accordion("图例"):
with gr.Row():
for i in range(StaticValue.max_num):
with gr.Row():
label = gr.Textbox(visible=False)
legend_labels_textboxs.append(label)
with gr.Accordion("坐标轴"):
with gr.Row():
title_name_textbox = gr.Textbox(visible=False)
x_label_textbox = gr.Textbox(visible=False)
y_label_textbox = gr.Textbox(visible=False)
colorpickers = []
color_textboxs = []
with gr.Accordion("颜色"):
with gr.Row():
for i in range(StaticValue.max_num):
with gr.Row():
colorpicker = gr.ColorPicker(visible=False)
colorpickers.append(colorpicker)
color_textbox = gr.Textbox(visible=False)
color_textboxs.append(color_textbox)
draw_plot = gr.Plot(visible=False)
draw_file = gr.File(visible=False)
with gr.Tab("文字说明"):
notes = gr.Markdown(Dataset.get_notes(), visible=True)
'''
监听事件
'''
# 选择数据源
choose_dataset_radio.change(fn=choose_dataset, inputs=[choose_dataset_radio], outputs=get_outputs())
choose_custom_dataset_file.upload(fn=choose_custom_dataset, inputs=[choose_custom_dataset_file], outputs=get_outputs())
# 操作数据表
# 删除所选列
del_col_button.click(fn=del_col, inputs=[del_col_checkboxgroup], outputs=get_outputs())
# 保留行
remain_row_button.click(fn=remain_row, inputs=[remain_row_slider], outputs=get_outputs())
# 删除所有存在缺失值的列
del_all_na_col_button.click(fn=del_all_na_col, outputs=get_outputs())
# 删除所有重复的行
del_duplicate_button.click(fn=del_duplicate, outputs=get_outputs())
# 字符型列转数值型列
encode_label_button.click(fn=encode_label, inputs=[encode_label_checkboxgroup], outputs=get_outputs())
# 将所有数据强制转换为浮点型(除第1列之外)
change_data_type_to_float_button.click(fn=change_data_type_to_float, outputs=get_outputs())
# 标准化数据
standardize_data_button.click(fn=standardize_data, inputs=[standardize_data_checkboxgroup], outputs=get_outputs())
# 选择因变量
select_as_y_radio.change(fn=select_as_y, inputs=[select_as_y_radio], outputs=get_outputs())
# 选择任务类型(强制转换第1列)
choose_assign_radio.change(fn=choose_assign, inputs=[choose_assign_radio], outputs=get_outputs())
# 数据模型
select_as_model_radio.change(fn=select_as_model, inputs=[select_as_model_radio], outputs=get_outputs())
# [模型]
model_train_button.click(fn=train_model, inputs=[model_optimize_radio, linear_regression_model_radio, naive_bayes_classification_model_radio], outputs=get_outputs())
# [绘图]
# 可视化
data_distribution_button.click(fn=data_distribution_first_draw_plot, inputs=[data_distribution_radio] + [data_distribution_is_rotate], outputs=get_outputs())
descriptive_indicators_button.click(fn=descriptive_indicators_first_draw_plot, inputs=[descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup], outputs=get_outputs())
heatmap_button.click(fn=heatmap_first_draw_plot, inputs=[heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
learning_curve_button.click(fn=learning_curve_first_draw_plot, inputs=[learning_curve_checkboxgroup], outputs=get_outputs())
shap_beeswarm_button.click(fn=shap_beeswarm_first_draw_plot, inputs=[shap_beeswarm_radio] + [shap_beeswarm_type], outputs=get_outputs())
data_fit_button.click(fn=data_fit_first_draw_plot, inputs=[data_fit_checkboxgroup], outputs=get_outputs())
waterfall_button.click(fn=waterfall_first_draw_plot, inputs=[waterfall_radio] + [waterfall_number], outputs=get_outputs())
force_button.click(fn=force_first_draw_plot, inputs=[force_radio] + [force_number], outputs=get_outputs())
dependence_button.click(fn=dependence_first_draw_plot, inputs=[dependence_radio] + [dependence_col], outputs=get_outputs())
title_name_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio] + [shap_beeswarm_type] + [data_fit_checkboxgroup] + [waterfall_radio] + [waterfall_number]
+ [force_radio] + [force_number] + [dependence_radio] + [dependence_col] + [data_distribution_radio] + [data_distribution_is_rotate]
+ [descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup] + [heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
x_label_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio] + [shap_beeswarm_type] + [data_fit_checkboxgroup] + [waterfall_radio] + [waterfall_number]
+ [force_radio] + [force_number] + [dependence_radio] + [dependence_col] + [data_distribution_radio] + [data_distribution_is_rotate]
+ [descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup] + [heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
y_label_textbox.blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio] + [shap_beeswarm_type] + [data_fit_checkboxgroup] + [waterfall_radio] + [waterfall_number]
+ [force_radio] + [force_number] + [dependence_radio] + [dependence_col] + [data_distribution_radio] + [data_distribution_is_rotate]
+ [descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup] + [heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
for i in range(StaticValue.max_num):
colorpickers[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio] + [shap_beeswarm_type] + [data_fit_checkboxgroup] + [waterfall_radio] + [waterfall_number]
+ [force_radio] + [force_number] + [dependence_radio] + [dependence_col] + [data_distribution_radio] + [data_distribution_is_rotate]
+ [descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup] + [heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
color_textboxs[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + color_textboxs + legend_labels_textboxs
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio] + [shap_beeswarm_type] + [data_fit_checkboxgroup] + [waterfall_radio] + [waterfall_number]
+ [force_radio] + [force_number] + [dependence_radio] + [dependence_col] + [data_distribution_radio] + [data_distribution_is_rotate]
+ [descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup] + [heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
legend_labels_textboxs[i].blur(fn=out_non_first_draw_plot, inputs=[title_name_textbox] + [x_label_textbox] + [y_label_textbox] + colorpickers + legend_labels_textboxs
+ [learning_curve_checkboxgroup] + [shap_beeswarm_radio] + [shap_beeswarm_type] + [data_fit_checkboxgroup] + [waterfall_radio] + [waterfall_number]
+ [force_radio] + [force_number] + [dependence_radio] + [dependence_col] + [data_distribution_radio] + [data_distribution_is_rotate]
+ [descriptive_indicators_is_rotate] + [descriptive_indicators_checkboxgroup] + [heatmap_checkboxgroup] + [heatmap_is_rotate], outputs=get_outputs())
if __name__ == "__main__":
demo.launch()