File size: 2,302 Bytes
7145fd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from matplotlib import pyplot as plt
import json

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Plot loss for the training log")
    parser.add_argument("-t", "--trainlog", required=True, help="Training log file")
    parser.add_argument("-l", "--loss", help="Loss plot to save (optional)")
    parser.add_argument("--log-scale", default=False, action="store_true", help="Log scale")
    parser.add_argument("-P", "--no-plot", default=True, action="store_false", help="Don't open matplotlib figure")
    args = parser.parse_args()

    train_inner_upd, train_inner_loss = [], []
    train_upd, train_loss = [], []
    val_upd, val_loss = [], []

    with open(args.trainlog, "r") as tl:
        for line in tl:
            # Filter out json
            if line[0] != "{":
                continue
            try:
                data = json.loads(line.strip())
            except:
                continue
            if "loss" in data:
                loss = float(data["loss"])
                upd = int(data["num_updates"])
                if len(train_inner_upd) == 0 or train_inner_upd[-1] < upd:
                    train_inner_upd.append(upd)
                    train_inner_loss.append(loss)
            if "valid_loss" in data:
                loss = float(data["valid_loss"])
                upd = int(data["valid_num_updates"])
                if len(val_upd) == 0 or val_upd[-1] < upd:
                    val_upd.append(upd)
                    val_loss.append(loss)
            if "train_loss" in data:
                loss = float(data["train_loss"])
                upd = int(data["train_num_updates"])
                if len(train_upd) == 0 or train_upd[-1] < upd:
                    train_upd.append(upd)
                    train_loss.append(loss)

    plt.figure()
    plt.plot(train_upd, train_loss, "r")
    plt.plot(val_upd, val_loss, "b")
    if len(train_inner_upd) > 0:
        plt.plot(train_inner_upd, train_inner_loss, "r", alpha=0.3)
    plt.legend(["train", "valid"])
    if args.log_scale:
        plt.yscale("log")
    elif min(min(train_loss), min(val_loss)) < 1:
        plt.ylim((0, 1))
    plt.xlabel("Updates")
    plt.ylabel("Loss")
    if args.loss:
        plt.savefig(args.loss)
    if args.no_plot:
        plt.show()