|
from matplotlib import pyplot as plt |
|
import json |
|
|
|
if __name__ == "__main__": |
|
import argparse |
|
parser = argparse.ArgumentParser("Plot loss for the training log") |
|
parser.add_argument("-t", "--trainlog", required=True, help="Training log file") |
|
parser.add_argument("-l", "--loss", help="Loss plot to save (optional)") |
|
parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale") |
|
parser.add_argument("-P", "--no-plot", default=True, action="store_false", help="Don't open matplotlib figure") |
|
args = parser.parse_args() |
|
|
|
train_inner_upd, train_inner_loss = [], [] |
|
train_upd, train_loss = [], [] |
|
val_upd, val_loss = [], [] |
|
|
|
with open(args.trainlog, "r") as tl: |
|
for line in tl: |
|
|
|
if line[0] != "{": |
|
continue |
|
try: |
|
data = json.loads(line.strip()) |
|
except: |
|
continue |
|
if "loss" in data: |
|
loss = float(data["loss"]) |
|
upd = int(data["num_updates"]) |
|
if len(train_inner_upd) == 0 or train_inner_upd[-1] < upd: |
|
train_inner_upd.append(upd) |
|
train_inner_loss.append(loss) |
|
if "valid_loss" in data: |
|
loss = float(data["valid_loss"]) |
|
upd = int(data["valid_num_updates"]) |
|
if len(val_upd) == 0 or val_upd[-1] < upd: |
|
val_upd.append(upd) |
|
val_loss.append(loss) |
|
if "train_loss" in data: |
|
loss = float(data["train_loss"]) |
|
upd = int(data["train_num_updates"]) |
|
if len(train_upd) == 0 or train_upd[-1] < upd: |
|
train_upd.append(upd) |
|
train_loss.append(loss) |
|
|
|
plt.figure() |
|
plt.plot(train_upd, train_loss, "r") |
|
plt.plot(val_upd, val_loss, "b") |
|
if len(train_inner_upd) > 0: |
|
plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3) |
|
plt.legend(["train", "valid"]) |
|
if args.log_scale: |
|
plt.yscale("log") |
|
elif min(min(train_loss), min(val_loss)) < 1: |
|
plt.ylim((0, 1)) |
|
plt.xlabel("Updates") |
|
plt.ylabel("Loss") |
|
if args.loss: |
|
plt.savefig(args.loss) |
|
if args.no_plot: |
|
plt.show() |
|
|