# REMEND / remend/plot_loss.py
# Last commit: 7145fd6 "Add REMEND python module" (udiboy1209)
from matplotlib import pyplot as plt
import json
# Per-series (loss key, update-count key) pairs read from each JSON log record.
_SERIES_KEYS = {
    "train_inner": ("loss", "num_updates"),
    "valid": ("valid_loss", "valid_num_updates"),
    "train": ("train_loss", "train_num_updates"),
}


def parse_log_lines(lines):
    """Collect loss curves from the JSON records of a training log.

    Only lines beginning with "{" that decode as JSON are used. Everything
    else (plain-text log output, truncated records) is skipped. A record
    contributes to every series whose loss key it carries. Within a series, a
    point is kept only if its update count is strictly greater than the
    previously kept one, so repeated or out-of-order entries are dropped.

    Args:
        lines: Iterable of log lines, e.g. an open text file.

    Returns:
        Dict mapping "train_inner", "valid" and "train" to a pair
        ``(updates, losses)`` of equal-length lists of ints and floats.

    Raises:
        KeyError: A record has a loss key but not its matching update key.
        ValueError, TypeError: A loss/update value cannot be converted.
    """
    series = {name: ([], []) for name in _SERIES_KEYS}
    for line in lines:
        # Cheap pre-filter: JSON records start with "{"; startswith is also
        # safe on an empty string, unlike line[0].
        if not line.startswith("{"):
            continue
        try:
            data = json.loads(line.strip())
        except json.JSONDecodeError:
            # Malformed/partial record (e.g. log cut mid-write) - skip it.
            continue
        for name, (loss_key, upd_key) in _SERIES_KEYS.items():
            if loss_key not in data:
                continue
            loss = float(data[loss_key])
            upd = int(data[upd_key])
            upds, losses = series[name]
            if not upds or upds[-1] < upd:
                upds.append(upd)
                losses.append(loss)
    return series


def main():
    """Parse CLI arguments, plot the loss curves, and optionally save/show them."""
    import argparse

    parser = argparse.ArgumentParser("Plot loss for the training log")
    parser.add_argument("-t", "--trainlog", required=True, help="Training log file")
    parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
    parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale")
    # Passing -P/--no-plot turns the interactive figure OFF; the dest is named
    # for what the attribute actually means (True => show the figure).
    parser.add_argument("-P", "--no-plot", dest="show_plot", default=True, action="store_false",
                        help="Don't open matplotlib figure")
    args = parser.parse_args()

    with open(args.trainlog, "r") as tl:
        series = parse_log_lines(tl)
    train_upd, train_loss = series["train"]
    val_upd, val_loss = series["valid"]
    train_inner_upd, train_inner_loss = series["train_inner"]

    plt.figure()
    plt.plot(train_upd, train_loss, "r")
    plt.plot(val_upd, val_loss, "b")
    if train_inner_upd:
        # Faint per-update training curve behind the per-epoch ones.
        plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3)
    # Labels map onto the first two plotted lines (train, valid) in order.
    plt.legend(["train", "valid"])
    if args.log_scale:
        plt.yscale("log")
    else:
        # min() of an empty list raises, so only look at series that have data.
        minima = [min(losses) for losses in (train_loss, val_loss) if losses]
        if minima and min(minima) < 1:
            plt.ylim((0, 1))
    plt.xlabel("Updates")
    plt.ylabel("Loss")
    if args.loss:
        plt.savefig(args.loss)
    if args.show_plot:
        plt.show()


if __name__ == "__main__":
    main()