"""Plot training / validation loss curves from a JSON-lines training log.

The log may interleave arbitrary text with JSON records; only lines that
start with ``{`` and parse as JSON are considered.  Three kinds of series
are extracted, keyed by the fields present in each record:

* ``loss`` / ``num_updates``              -> intra-epoch training loss
* ``train_loss`` / ``train_num_updates``  -> epoch-level training loss
* ``valid_loss`` / ``valid_num_updates``  -> validation loss
"""
import argparse
import json

from matplotlib import pyplot as plt

# (series name, key holding the loss, key holding the update counter)
_SERIES_KEYS = (
    ("train_inner", "loss", "num_updates"),
    ("valid", "valid_loss", "valid_num_updates"),
    ("train", "train_loss", "train_num_updates"),
)


def _parse_args(argv=None):
    """Parse the command line (``argv`` defaults to ``sys.argv[1:]``)."""
    # The text is passed as ``description``; given positionally it would be
    # taken as ``prog`` and end up in the usage line instead.
    parser = argparse.ArgumentParser(
        description="Plot loss for the training log"
    )
    parser.add_argument(
        "-t", "--trainlog", required=True, help="Training log file"
    )
    parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
    parser.add_argument(
        "--log-scale", default=False, action="store_true", help="Log scale"
    )
    parser.add_argument(
        "-P",
        "--no-plot",
        dest="show_plot",
        default=True,
        action="store_false",
        help="Don't open matplotlib figure",
    )
    return parser.parse_args(argv)


def _read_loss_curves(lines):
    """Collect loss series from an iterable of log lines.

    Returns a dict mapping each series name in ``_SERIES_KEYS`` to a pair
    ``(updates, losses)`` of equally long lists.  A point is kept only if
    its update counter is strictly greater than the last one stored for
    that series, so repeated or out-of-order records are dropped.
    """
    curves = {name: ([], []) for name, _, _ in _SERIES_KEYS}
    for line in lines:
        # Filter out everything that is not a JSON object.
        if not line.startswith("{"):
            continue
        try:
            record = json.loads(line)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        for name, loss_key, upd_key in _SERIES_KEYS:
            if loss_key not in record:
                continue
            try:
                loss = float(record[loss_key])
                upd = int(record[upd_key])
            except (KeyError, TypeError, ValueError):
                # Record lacks its update counter or holds a null /
                # non-numeric value: skip it rather than abort the plot.
                continue
            updates, losses = curves[name]
            if not updates or updates[-1] < upd:
                updates.append(upd)
                losses.append(loss)
    return curves


def _plot_curves(curves, log_scale=False):
    """Draw the collected series on a new matplotlib figure."""
    train_upd, train_loss = curves["train"]
    val_upd, val_loss = curves["valid"]
    inner_upd, inner_loss = curves["train_inner"]

    plt.figure()
    plt.plot(train_upd, train_loss, "r", label="train")
    plt.plot(val_upd, val_loss, "b", label="valid")
    if inner_upd:
        # Unlabelled on purpose: shown as a faint background trace that
        # does not get its own legend entry.
        plt.plot(inner_upd, inner_loss, "r", alpha=0.3)
    plt.legend()

    if log_scale:
        plt.yscale("log")
    else:
        # Either series may be empty (e.g. no validation pass logged yet);
        # calling min() on an empty list would raise ValueError.
        epoch_losses = train_loss + val_loss
        if epoch_losses and min(epoch_losses) < 1:
            plt.ylim((0, 1))

    plt.xlabel("Updates")
    plt.ylabel("Loss")


def main(argv=None):
    """Entry point: read the log, plot it, then optionally save and show."""
    args = _parse_args(argv)

    # errors="replace" keeps a stray undecodable byte in an interleaved
    # log from aborting the read; such lines simply fail to parse as JSON.
    with open(args.trainlog, "r", encoding="utf-8", errors="replace") as tl:
        curves = _read_loss_curves(tl)

    _plot_curves(curves, log_scale=args.log_scale)

    if args.loss:
        plt.savefig(args.loss)
    if args.show_plot:
        plt.show()


if __name__ == "__main__":
    main()