Spaces:
Sleeping
Sleeping
File size: 1,997 Bytes
bd39f54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import numpy as np
from matplotlib import pyplot as plt
from static.config import Config
def draw_learning_curve_total(input_dict, type):
    """Plot one learning curve per entry of ``input_dict`` on a single figure.

    Each curve is drawn as a marked line (the score mean) surrounded by a
    translucent band of one standard deviation on either side.

    Args:
        input_dict: Mapping of legend label to an indexable sequence whose
            items are, in order: train sizes, train score mean, train score
            std, test score mean, test score std.
        type: ``"train"`` plots the training scores; any other value plots
            the test (cross-validation) scores. The name shadows the builtin
            but is kept so existing keyword callers keep working.

    Returns:
        The ``matplotlib.pyplot`` module, with the new figure current, so the
        caller can add a title, save, or show it.
    """
    plt.figure(figsize=(10, 6), dpi=300)

    # The two modes differ only in which (mean, std) pair is read from each
    # entry, so select the indices once instead of duplicating the loop.
    mean_idx, std_idx = (1, 2) if type == "train" else (3, 4)

    # NOTE(review): assumes Config.COLORS is a non-empty sequence supporting
    # len() and integer indexing -- confirm against static/config.py.
    palette = Config.COLORS

    for position, (label_name, values) in enumerate(input_dict.items()):
        sizes = values[0]
        # Coerce to arrays so plain Python lists also support the element-wise
        # mean +/- std arithmetic below.
        scores_mean = np.asarray(values[mean_idx])
        scores_std = np.asarray(values[std_idx])
        # Wrap around the palette so having more curves than colours reuses
        # colours instead of raising IndexError.
        color = palette[position % len(palette)]

        plt.fill_between(
            sizes,
            scores_mean - scores_std,
            scores_mean + scores_std,
            alpha=0.1,
            color=color,
        )
        plt.plot(
            sizes,
            scores_mean,
            "o-",
            color=color,
            label=label_name,
        )

    plt.xlabel("Sizes")
    plt.ylabel("Adjusted R-square")
    plt.legend()
    return plt
|