Spaces:
Sleeping
Sleeping
import numpy as np | |
from matplotlib import pyplot as plt | |
from static.config import Config | |
def draw_learning_curve_total(input_dict, type):
    """Plot one learning curve per entry of ``input_dict`` on a new figure.

    Each curve is drawn as a line with round markers plus a translucent
    band of one standard deviation around the mean.

    Args:
        input_dict: Mapping of legend label -> indexable container read
            positionally as ``(train_sizes, train_scores_mean,
            train_scores_std, test_scores_mean, test_scores_std)``.
        type: ``"train"`` plots the training mean/std (items 1 and 2);
            any other value plots the test mean/std (items 3 and 4).
            The name shadows the builtin but is kept so existing callers
            (positional or keyword) keep working.

    Returns:
        The ``matplotlib.pyplot`` module with the freshly created figure
        as the current figure. No title is set and nothing is saved or
        shown here; that is left to the caller.
    """
    # Both modes draw exactly the same artists; only the positions of the
    # mean/std arrays inside each value tuple differ.
    mean_idx, std_idx = (1, 2) if type == "train" else (3, 4)

    plt.figure(figsize=(10, 6), dpi=300)

    # NOTE(review): assumes Config.COLORS is a non-empty sequence of
    # matplotlib colour specs -- confirm against static/config.
    palette = Config.COLORS

    for index, (label_name, values) in enumerate(input_dict.items()):
        sizes = values[0]
        # Coerce to arrays so the +/- band arithmetic also works when the
        # scores arrive as plain Python lists.
        scores_mean = np.asarray(values[mean_idx])
        scores_std = np.asarray(values[std_idx])
        # Cycle through the palette instead of raising IndexError when
        # there are more series than configured colours.
        color = palette[index % len(palette)]

        plt.fill_between(
            sizes,
            scores_mean - scores_std,
            scores_mean + scores_std,
            alpha=0.1,
            color=color,
        )
        plt.plot(
            sizes,
            scores_mean,
            "o-",
            color=color,
            label=label_name,
        )

    plt.xlabel("Sizes")
    plt.ylabel("Adjusted R-square")
    plt.legend()
    return plt