forked from GeorgeCazenavette/glad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
reparam_module.py
159 lines (136 loc) · 6.59 KB
/
reparam_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import torch
import torch.nn as nn
import warnings
import types
from collections import namedtuple
from contextlib import contextmanager
class ReparamModule(nn.Module):
    """Run a module with all of its parameters packed into one flat tensor.

    At construction every parameter of ``module`` is detached, flattened and
    concatenated (in ``named_modules`` / ``named_parameters`` order, with each
    shared parameter counted once) into a single ``nn.Parameter`` registered
    as ``flat_param``. The original parameter registrations are deleted and
    replaced by plain attributes holding views into a flat tensor. The wrapped
    module can then be evaluated with any caller-supplied flat vector via
    ``forward(..., flat_param=...)``.

    Buffers are not flattened. ``forward(..., buffers=...)`` temporarily swaps
    them for caller-supplied tensors for the duration of one call.

    All parameters must share one dtype (checked with an ``assert``).
    """

    def _get_module_from_name(self, mn):
        """Return the submodule at dotted path ``mn`` relative to ``self``.

        ``''`` denotes ``self``. Otherwise each dot-separated component is
        resolved with ``getattr``.
        """
        if mn == '':
            return self
        m = self
        for p in mn.split('.'):
            m = getattr(m, p)
        return m

    def __init__(self, module):
        super(ReparamModule, self).__init__()
        self.module = module

        param_infos = []  # (module name/path, param name)
        shared_param_memo = {}  # param tensor -> (module name/path, param name) of first occurrence
        shared_param_infos = []  # (module name/path, param name, src module name/path, src param_name)
        params = []
        param_numels = []
        param_shapes = []
        for mn, m in self.named_modules():
            for n, p in m.named_parameters(recurse=False):
                if p is None:
                    continue
                if p in shared_param_memo:
                    # Same tensor already seen under another name: record the
                    # alias instead of flattening it a second time.
                    shared_mn, shared_n = shared_param_memo[p]
                    shared_param_infos.append((mn, n, shared_mn, shared_n))
                else:
                    shared_param_memo[p] = (mn, n)
                    param_infos.append((mn, n))
                    params.append(p.detach())
                    param_numels.append(p.numel())
                    param_shapes.append(p.size())

        assert len(set(p.dtype for p in params)) <= 1, \
            "expects all parameters in module to have same dtype"

        # store the info for unflatten
        self._param_infos = tuple(param_infos)
        self._shared_param_infos = tuple(shared_param_infos)
        self._param_numels = tuple(param_numels)
        self._param_shapes = tuple(param_shapes)

        # flatten
        flat_param = nn.Parameter(torch.cat([p.reshape(-1) for p in params], 0))
        self.register_parameter('flat_param', flat_param)
        self.param_numel = flat_param.numel()
        del params
        del shared_param_memo

        # deregister the names as parameters
        for mn, n in self._param_infos:
            delattr(self._get_module_from_name(mn), n)
        for mn, n, _, _ in self._shared_param_infos:
            delattr(self._get_module_from_name(mn), n)

        # register the views as plain attributes
        self._unflatten_param(self.flat_param)

        # Buffers are not reparametrized. Record (module path, name, tensor);
        # the stored tensor is only the value seen at construction time, so
        # code that needs the live buffer should use _current_buffers().
        buffer_infos = []
        for mn, m in self.named_modules():
            for n, b in m.named_buffers(recurse=False):
                if b is not None:
                    buffer_infos.append((mn, n, b))
        self._buffer_infos = tuple(buffer_infos)

        self._traced_self = None

    def _current_buffers(self):
        """Return the tensors currently bound to each recorded buffer slot.

        These are read live rather than taken from ``_buffer_infos``, because
        ``.to()`` / ``.cuda()`` may have replaced the buffer objects since
        construction.
        """
        return tuple(
            getattr(self._get_module_from_name(mn), n)
            for mn, n, _ in self._buffer_infos
        )

    def trace(self, example_input, **trace_kwargs):
        """JIT-trace both forward helpers and route future calls through them.

        Args:
            example_input: a tensor or an iterable of tensors used as the
                positional example inputs of the wrapped module.
            **trace_kwargs: forwarded to ``torch.jit.trace_module``.

        Returns:
            ``self``, with ``_forward_with_param`` and
            ``_forward_with_param_and_buffers`` replaced by traced methods.
        """
        assert self._traced_self is None, 'This ReparamModule is already traced'
        if isinstance(example_input, torch.Tensor):
            example_input = (example_input,)
        example_input = tuple(example_input)
        example_param = (self.flat_param.detach().clone(),)
        example_buffers = (tuple(b.detach().clone() for b in self._current_buffers()),)
        self._traced_self = torch.jit.trace_module(
            self,
            inputs=dict(
                _forward_with_param=example_param + example_input,
                _forward_with_param_and_buffers=example_param + example_buffers + example_input,
            ),
            **trace_kwargs,
        )

        # replace forwards with traced versions
        self._forward_with_param = self._traced_self._forward_with_param
        self._forward_with_param_and_buffers = self._traced_self._forward_with_param_and_buffers
        return self

    def clear_views(self):
        """Set every parameter attribute of the wrapped module to ``None``.

        This covers both primary parameters and shared (tied) aliases. This
        matches the state ``unflattened_param`` restores, where aliases are
        re-read from their now-``None`` sources.
        """
        for mn, n in self._param_infos:
            setattr(self._get_module_from_name(mn), n, None)  # This will set as plain attr
        for mn, n, _, _ in self._shared_param_infos:
            setattr(self._get_module_from_name(mn), n, None)

    def _apply(self, *args, **kwargs):
        # Once traced, device/dtype conversions are applied to the traced copy.
        if self._traced_self is not None:
            self._traced_self._apply(*args, **kwargs)
            return self
        return super(ReparamModule, self)._apply(*args, **kwargs)

    def _unflatten_param(self, flat_param):
        """Bind views of the 1-D ``flat_param`` as the module's parameter attributes."""
        ps = (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))
        for (mn, n), p in zip(self._param_infos, ps):
            setattr(self._get_module_from_name(mn), n, p)  # This will set as plain attr
        for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
            setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))

    @contextmanager
    def unflattened_param(self, flat_param):
        """Temporarily bind views of ``flat_param`` as the module's parameters.

        The previously bound views are restored on exit, even if the body
        raises.
        """
        saved_views = [getattr(self._get_module_from_name(mn), n) for mn, n in self._param_infos]
        try:
            self._unflatten_param(flat_param)
            yield
        finally:
            # Why not just `self._unflatten_param(self.flat_param)`?
            # 1. because of https://github.com/pytorch/pytorch/issues/17583
            # 2. slightly faster since it does not require reconstruct the split+view
            #    graph
            for (mn, n), p in zip(self._param_infos, saved_views):
                setattr(self._get_module_from_name(mn), n, p)
            for (mn, n, shared_mn, shared_n) in self._shared_param_infos:
                setattr(self._get_module_from_name(mn), n, getattr(self._get_module_from_name(shared_mn), shared_n))

    @contextmanager
    def replaced_buffers(self, buffers):
        """Temporarily bind ``buffers`` (matched positionally) as the module's buffers.

        Slots beyond ``len(buffers)`` are left untouched. The buffers bound on
        entry are restored on exit, even if the body raises.
        """
        saved = self._current_buffers()
        try:
            for (mn, n, _), new_b in zip(self._buffer_infos, buffers):
                setattr(self._get_module_from_name(mn), n, new_b)
            yield
        finally:
            for (mn, n, _), old_b in zip(self._buffer_infos, saved):
                setattr(self._get_module_from_name(mn), n, old_b)

    def _forward_with_param_and_buffers(self, flat_param, buffers, *inputs, **kwinputs):
        with self.unflattened_param(flat_param):
            with self.replaced_buffers(buffers):
                return self.module(*inputs, **kwinputs)

    def _forward_with_param(self, flat_param, *inputs, **kwinputs):
        with self.unflattened_param(flat_param):
            return self.module(*inputs, **kwinputs)

    def forward(self, *inputs, flat_param=None, buffers=None, **kwinputs):
        """Run the wrapped module with the given flat parameters and buffers.

        Args:
            *inputs, **kwinputs: passed through to the wrapped module.
            flat_param: flat parameter vector with ``param_numel`` elements,
                or ``None`` to use ``self.flat_param``. Extra singleton
                dimensions such as ``(1, N)`` are accepted and collapsed.
            buffers: iterable of tensors matched positionally to the recorded
                buffers, or ``None`` to keep the currently bound buffers.

        Returns:
            Whatever the wrapped module returns.
        """
        if flat_param is None:
            flat_param = self.flat_param
        else:
            # The original squeezed here, presumably for per-replica inputs
            # shaped (1, N). reshape(-1) accepts the same inputs but keeps a
            # 1-element vector 1-D, whereas squeeze would make it 0-d and
            # split() would then reject it.
            flat_param = flat_param.reshape(-1)
        if buffers is None:
            return self._forward_with_param(flat_param, *inputs, **kwinputs)
        else:
            return self._forward_with_param_and_buffers(flat_param, tuple(buffers), *inputs, **kwinputs)