-
Notifications
You must be signed in to change notification settings - Fork 46
/
loss_recorder.py
82 lines (68 loc) · 2.57 KB
/
loss_recorder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
import os
class SingleLoss:
    """Record the history of one named quantity and mirror it to a writer.

    Every logged value is appended to ``loss_step``. Scalar values are also
    buffered until :meth:`epoch` is called. That call averages them, appends
    the average to ``loss_epoch`` and writes it via ``writer.add_scalar``.
    """

    def __init__(self, name: str, writer: "SummaryWriter", base=0):
        """Create an empty recorder.

        Args:
            name: Tag used for the writer entries and for the saved file names.
            writer: Object exposing ``add_<kind>(tag, value, step)`` methods.
                ``add_scalar`` is always needed; other kinds (e.g.
                ``add_figure``) are needed only if they are logged.
            base: If non-zero, pre-fill the histories with zeros: ``base``
                entries in ``loss_epoch`` and ``base * 10`` in ``loss_step``.
                This is presumably so indices line up when resuming at epoch
                ``base``; the factor 10 looks like an assumed number of steps
                per epoch (TODO confirm).
        """
        self.name = name
        self.writer = writer
        self.loss_step = []
        self.loss_epoch = []
        self.loss_epoch_tmp = []  # scalar values logged since the last epoch()
        if base:
            self.loss_epoch = [0] * base
            self.loss_step = [0] * base * 10

    def add_event(self, val, step=None, name='scalar'):
        """Log one value.

        Args:
            val: Value to log. ``None`` records a placeholder 0 in the
                history without writing anything to the writer.
            step: Global step passed to the writer. Defaults to the number of
                values recorded so far.
            name: Writer method suffix, e.g. ``'scalar'`` -> ``add_scalar``.
                Only ``'scalar'`` events contribute to the epoch average.
        """
        if step is None:
            step = len(self.loss_step)
        if val is None:
            val = 0
        else:
            getattr(self.writer, 'add_' + name)(self.name + '_step', val, step)
        self.loss_step.append(val)
        # Non-scalar payloads (e.g. figures) cannot be summed/averaged, so
        # they must stay out of the epoch buffer or epoch() would crash.
        if name == 'scalar':
            self.loss_epoch_tmp.append(val)

    def epoch(self, step=None):
        """Close the current epoch: average, record and write buffered scalars.

        Args:
            step: Global step for the epoch entry. Defaults to the number of
                epochs recorded so far.

        If no scalar was logged since the previous call, a placeholder 0 is
        appended to ``loss_epoch`` and nothing is written. This keeps the
        epoch history aligned and matches the 0 placeholders used for
        ``None`` values and ``base`` padding. Without this guard the average
        would raise ZeroDivisionError.
        """
        if step is None:
            step = len(self.loss_epoch)
        pending = self.loss_epoch_tmp
        self.loss_epoch_tmp = []
        if not pending:
            self.loss_epoch.append(0)
            return
        loss_avg = sum(pending) / len(pending)
        self.loss_epoch.append(loss_avg)
        self.writer.add_scalar('Train/epoch_' + self.name, loss_avg, step)

    def save(self, path):
        """Save the histories as ``<name>_step.npy`` and ``<name>_epoch.npy``.

        Args:
            path: Directory to write into (str or path-like). It is created
                if missing, with or without a trailing separator. An empty
                string means the current directory.
        """
        step_file = os.path.join(path, self.name + '_step.npy')
        epoch_file = os.path.join(path, self.name + '_epoch.npy')
        # Create `path` plus any sub-directory implied by a '/' in the name
        # (e.g. 'Train/loss'); dirname is '' for the current directory.
        folder = os.path.dirname(step_file)
        if folder:
            os.makedirs(folder, exist_ok=True)
        np.save(step_file, np.array(self.loss_step))
        np.save(epoch_file, np.array(self.loss_epoch))

    def last_epoch(self):
        """Return the most recent epoch value (IndexError if there is none)."""
        return self.loss_epoch[-1]
class LossRecorder:
    """Collection of ``SingleLoss`` recorders keyed by name.

    A recorder is created lazily the first time a name is logged. All
    recorders share this object's writer and ``base`` padding value.
    """

    def __init__(self, writer: "SummaryWriter", base=0):
        """Create an empty collection.

        Args:
            writer: Writer handed to every ``SingleLoss`` created here.
            base: Padding value forwarded to every ``SingleLoss``.
        """
        self.losses = {}  # name -> SingleLoss, in first-logged order
        self.writer = writer
        self.base = base

    def _get_or_create(self, name):
        """Return the recorder for ``name``, creating it on first use."""
        loss = self.losses.get(name)
        if loss is None:
            loss = SingleLoss(name, self.writer, self.base)
            self.losses[name] = loss
        return loss

    def add_scalar(self, name, val=None, step=None):
        """Log a scalar under ``name``.

        A ``torch.Tensor`` is converted with ``.item()`` first. ``None`` is
        forwarded unchanged to ``SingleLoss.add_event``.
        """
        if isinstance(val, torch.Tensor):
            val = val.item()
        self._get_or_create(name).add_event(val, step, 'scalar')

    def add_figure(self, name, val, step=None):
        """Log ``val`` under ``name`` through the writer's ``add_figure``."""
        self._get_or_create(name).add_event(val, step, 'figure')

    def verbose(self):
        """Return ``str`` of the name-sorted ``(name, latest step value)`` pairs."""
        latest = [(key, self.losses[key].loss_step[-1]) for key in sorted(self.losses)]
        return str(latest)

    def epoch(self, step=None):
        """Close the current epoch on every recorder, passing ``step`` through."""
        for loss in self.losses.values():
            loss.epoch(step)

    def save(self, path):
        """Save every recorder's histories under ``path``."""
        for loss in self.losses.values():
            loss.save(path)

    def last_epoch(self):
        """Return each recorder's last epoch value, in first-logged order."""
        return [loss.last_epoch() for loss in self.losses.values()]