File size: 2,302 Bytes
7145fd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
from matplotlib import pyplot as plt
import json
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Plot loss for the training log")
parser.add_argument("-t", "--trainlog", required=True, help="Training log file")
parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale")
parser.add_argument("-P", "--no-plot", default=True, action="store_false", help="Don't open matplotlib figure")
args = parser.parse_args()
train_inner_upd, train_inner_loss = [], []
train_upd, train_loss = [], []
val_upd, val_loss = [], []
with open(args.trainlog, "r") as tl:
for line in tl:
# Filter out json
if line[0] != "{":
continue
try:
data = json.loads(line.strip())
except:
continue
if "loss" in data:
loss = float(data["loss"])
upd = int(data["num_updates"])
if len(train_inner_upd) == 0 or train_inner_upd[-1] < upd:
train_inner_upd.append(upd)
train_inner_loss.append(loss)
if "valid_loss" in data:
loss = float(data["valid_loss"])
upd = int(data["valid_num_updates"])
if len(val_upd) == 0 or val_upd[-1] < upd:
val_upd.append(upd)
val_loss.append(loss)
if "train_loss" in data:
loss = float(data["train_loss"])
upd = int(data["train_num_updates"])
if len(train_upd) == 0 or train_upd[-1] < upd:
train_upd.append(upd)
train_loss.append(loss)
plt.figure()
plt.plot(train_upd, train_loss, "r")
plt.plot(val_upd, val_loss, "b")
if len(train_inner_upd) > 0:
plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3)
plt.legend(["train", "valid"])
if args.log_scale:
plt.yscale("log")
elif min(min(train_loss), min(val_loss)) < 1:
plt.ylim((0, 1))
plt.xlabel("Updates")
plt.ylabel("Loss")
if args.loss:
plt.savefig(args.loss)
if args.no_plot:
plt.show()
|