regularizations.py
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
import copy
import time

import torch

class EWC:
    def __init__(self, params):
        self.cfg = params

    def run(self, action_agent, q_agent, logger, info):
        logger = logger.get_logger(type(self).__name__ + "/")
        logger.message("Starting EWC procedure: computing Fisher matrix")
        _training_start_time = time.time()
        # Estimate the diagonal Fisher matrix from cfg.n_samples observations
        # drawn from the replay buffer, processed one sample at a time.
        action_agent.train()
        action_agent = action_agent.to(self.cfg.device)
        # Pick the policy corresponding to the current task (or the last one available)
        policy = action_agent[-1].model[min(len(action_agent[-1].model) - 1, info["task_id"])]
        output_dim = action_agent[-1].output_dimension
        policy.zero_grad()
        reg_weights = [copy.deepcopy(param.grad) for param in policy.parameters()]
        batch_obs = info['replay_buffer'].get(self.cfg.n_samples).to(self.cfg.device)["env/env_obs"][0]
        for obs in batch_obs:
            # Gather the gradient of each mean output w.r.t. the policy parameters
            grads_mu = []
            for i in range(output_dim // 2):
                mu_i = policy(obs)[i]
                mu_i.backward()
                grads_mu.append([copy.deepcopy(param.grad) for param in policy.parameters()])
                policy.zero_grad()
            # Gather the gradient of each std output (the raw output is a log-std:
            # clip it to [-20, 2] and exponentiate before differentiating)
            grads_std = []
            stds = []
            for i in range(output_dim // 2, output_dim):
                std_i = policy(obs)[i]
                std_i = torch.clip(std_i, min=-20., max=2.)
                std_i = std_i.exp()
                std_i.backward()
                grads_std.append([copy.deepcopy(param.grad) for param in policy.parameters()])
                stds.append(std_i)
                policy.zero_grad()
            # Accumulate the diagonal Fisher matrix for this observation
            fisher = [copy.deepcopy(param.grad) for param in policy.parameters()]
            for grad_mu, grad_std, std in zip(grads_mu, grads_std, stds):  # for each output scalar
                for i in range(len(fisher)):  # for each policy parameter
                    # closed form for a diagonal Gaussian policy, see page 21 in https://arxiv.org/pdf/2105.10919.pdf
                    fisher[i] += (grad_mu[i] ** 2 + 2 * grad_std[i] ** 2) / (std ** 2 + 1e-6)
            del grads_mu, grads_std, stds
            # Average over the batch dimension
            for i in range(len(reg_weights)):
                # clip from below, see https://github.com/awarelab/continual_world/blob/main/continualworld/methods/ewc.py#L66
                fisher[i] = torch.clamp(fisher[i], min=1e-5)
                reg_weights[i] += fisher[i] / self.cfg.n_samples
        # Register the new regularization weights for the next task
        action_agent[-1].register_and_consolidate(reg_weights)
        r = {"n_epochs": self.cfg.n_samples, "training_time": time.time() - _training_start_time, "n_interactions": 0}
        return r, action_agent.to('cpu'), q_agent.to('cpu'), info
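
# --- Illustrative sketch, not part of the original file ---------------------
# `register_and_consolidate` lives on the agent elsewhere in the repo, so the
# quadratic EWC penalty itself is not visible in this file. A minimal sketch
# of how the consolidated weights are typically used is given below; the names
# `fisher`, `anchor_params` and `coef` are assumptions made for illustration,
# and this helper is not called by the repo code.
def ewc_penalty(policy, fisher, anchor_params, coef=1.0):
    """Quadratic penalty coef * sum_i F_i * (theta_i - theta*_i) ** 2."""
    penalty = 0.0
    for param, f, anchor in zip(policy.parameters(), fisher, anchor_params):
        penalty = penalty + (f * (param - anchor) ** 2).sum()
    return coef * penalty
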
class L2:
    def __init__(self, params):
        self.cfg = params

    def run(self, action_agent, q_agent, logger, info):
        logger = logger.get_logger(type(self).__name__ + "/")
        logger.message("Starting L2 regularization")
        # Register the new regularization weights for the next task
        # (no Fisher weights are passed, so every parameter is penalized uniformly)
        action_agent = action_agent.to(self.cfg.device)
        action_agent[-1].register_and_consolidate()
        r = {"n_epochs": 0, "training_time": 0, "n_interactions": 0}
        return r, action_agent.to('cpu'), q_agent.to('cpu'), info
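
# --- Illustrative sanity check, not part of the original file ---------------
# The Fisher accumulation above uses the closed-form diagonal Fisher of a
# diagonal Gaussian policy:
#     F_jj = (d mu / d theta_j) ** 2 / sigma ** 2 + 2 * (d sigma / d theta_j) ** 2 / sigma ** 2
# The toy check below compares that expression with a Monte-Carlo estimate of
# E[(d log pi(a) / d theta) ** 2] for a two-parameter policy (mu, log_std).
# The parameterization and sample size are illustrative assumptions only.
if __name__ == "__main__":
    torch.manual_seed(0)
    theta = torch.randn(2, requires_grad=True)  # theta[0] -> mu, theta[1] -> log_std

    def toy_gaussian(theta):
        mu = theta[0]
        std = torch.clip(theta[1], min=-20., max=2.).exp()
        return mu, std

    # Closed-form diagonal Fisher
    mu, std = toy_gaussian(theta)
    (grad_mu,) = torch.autograd.grad(mu, theta, retain_graph=True)
    (grad_std,) = torch.autograd.grad(std, theta)
    closed_form = (grad_mu ** 2 + 2 * grad_std ** 2) / (std.detach() ** 2 + 1e-6)

    # Monte-Carlo estimate of E[(d log pi(a) / d theta) ** 2]
    n_mc = 20000
    mc_estimate = torch.zeros(2)
    for _ in range(n_mc):
        mu, std = toy_gaussian(theta)
        dist = torch.distributions.Normal(mu, std)
        (g,) = torch.autograd.grad(dist.log_prob(dist.sample()), theta)
        mc_estimate += g ** 2 / n_mc

    # The two estimates should agree up to Monte-Carlo noise
    print("closed form :", closed_form)
    print("monte-carlo :", mc_estimate)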