-
Notifications
You must be signed in to change notification settings - Fork 32
/
autograd_hacks.py
285 lines (204 loc) · 9.51 KB
/
autograd_hacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
"""
Library for extracting interesting quantities from autograd, see README.md
Not thread-safe because of module-level variables
Notation:
o: number of output classes (exact Hessian), number of Hessian samples (sampled Hessian)
n: batch-size
do: output dimension (output channels for convolution)
di: input dimension (input channels for convolution)
Hi: per-example Hessian of matmul, shaped as matrix of [dim, dim], indices have been row-vectorized
Hi_bias: per-example Hessian of bias
Oh, Ow: output height, output width (convolution)
Kh, Kw: kernel height, kernel width (convolution)
Jb: batch output Jacobian of matmul, output sensitivity for example,class pair, [o, n, ....]
Jb_bias: as above, but for bias
A, activations: inputs into current layer
B, backprops: backprop values (aka Lop aka Jacobian-vector product) observed at current layer
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
# Class names (as returned by _layer_type) of layers that receive hooks and per-example quantities.
_supported_layers: List[str] = ['Linear', 'Conv2d']  # Supported layer class types
# When True, the capture hooks return immediately without recording anything (see disable_hooks/enable_hooks).
_hooks_disabled: bool = False  # work-around for https://github.com/pytorch/pytorch/issues/25723
# Set to True by backprop_hess; the next backward hook to fire asserts that no stale backprops_list
# exists on its layer and then resets this flag to False.
_enforce_fresh_backprop: bool = False  # global switch to catch double backprop errors on Hessian computation
def add_hooks(model: nn.Module) -> None:
    """
    Instrument every supported layer of ``model`` with capture hooks.

    For each module in ``model.modules()`` whose class name is in ``_supported_layers``:
      1. a forward hook stores the layer's (detached) input as ``layer.activations``;
      2. a backward hook appends the (detached) output gradient to ``layer.backprops_list``.
    The hook handles are accumulated in ``model.autograd_hacks_hooks``. Calling this also re-enables
    hooks globally (clears the disabled flag).

    Args:
        model: module to instrument; the module itself is included if it is a supported layer.
    """
    global _hooks_disabled
    _hooks_disabled = False

    new_handles = []
    for module in model.modules():
        if _layer_type(module) not in _supported_layers:
            continue
        new_handles.append(module.register_forward_hook(_capture_activations))
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in favour of
        # register_full_backward_hook; kept unchanged here.
        new_handles.append(module.register_backward_hook(_capture_backprops))

    # Written through __dict__ so repeated calls extend the same list.
    model.__dict__.setdefault('autograd_hacks_hooks', []).extend(new_handles)
def remove_hooks(model: nn.Module) -> None:
    """
    Detach every hook handle that add_hooks(model) stored in ``model.autograd_hacks_hooks``.

    Currently guarded by an assertion that can never pass (a module never equals 0), so this raises
    AssertionError unless assertions are stripped (``python -O``). Prints a warning when the model
    carries no stored handles.
    """
    assert model == 0, "not working, remove this after fix to https://github.com/pytorch/pytorch/issues/25723"

    if not hasattr(model, 'autograd_hacks_hooks'):
        print("Warning, asked to remove hooks, but no hooks found")
        return

    for registered in model.autograd_hacks_hooks:
        registered.remove()
    del model.autograd_hacks_hooks
def disable_hooks() -> None:
    """
    Switch off all hooks installed by this library.

    Sets the module-level disabled flag; while it is set, the activation and backprop capture
    hooks return without recording anything.
    """
    global _hooks_disabled
    _hooks_disabled = True
def enable_hooks() -> None:
    """
    Switch hooks installed by this library back on (inverse of disable_hooks()).

    Clears the module-level disabled flag, so the capture hooks record values again.
    """
    global _hooks_disabled
    _hooks_disabled = False
def is_supported(layer: nn.Module) -> bool:
    """Return True when ``layer``'s class name is listed in ``_supported_layers``."""
    kind = _layer_type(layer)
    return kind in _supported_layers
def _layer_type(layer: nn.Module) -> str:
return layer.__class__.__name__
def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
    """
    Forward hook: remember the layer's first input as ``layer.activations``.

    The stored tensor is detached from the autograd graph. Does nothing while hooks are disabled.
    """
    if not _hooks_disabled:
        assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen"
        layer.activations = input[0].detach()
def _capture_backprops(layer: nn.Module, _input, output):
    """
    Backward hook: append the gradient w.r.t. the layer's output to ``layer.backprops_list``.

    ``output`` is the hook's grad_output tuple; only its first element is kept, detached. The list is
    created on first use. If backprop_hess armed the freshness check, this call asserts that the layer
    has no leftover list from an earlier backward pass and then disarms the check. Does nothing while
    hooks are disabled.
    """
    global _enforce_fresh_backprop
    if _hooks_disabled:
        return

    if _enforce_fresh_backprop:
        assert not hasattr(layer, 'backprops_list'), "Seeing result of previous backprop, use clear_backprops(model) to clear"
        _enforce_fresh_backprop = False

    try:
        collected = layer.backprops_list
    except AttributeError:
        collected = []
        layer.backprops_list = collected
    collected.append(output[0].detach())
def clear_backprops(model: nn.Module) -> None:
    """Remove the ``backprops_list`` attribute from every module in ``model`` (including itself) that has one."""
    for module in filter(lambda m: hasattr(m, 'backprops_list'), model.modules()):
        delattr(module, 'backprops_list')
def compute_grad1(model: nn.Module, loss_type: str = 'mean') -> None:
    """
    Compute per-example gradients and save them under 'param.grad1'. Must be called after loss.backward()

    For every supported layer in ``model.modules()`` this reads ``layer.activations`` (A) and the single
    entry of ``layer.backprops_list`` (B) recorded by the hooks from add_hooks(model). It stores
    ``weight.grad1`` with shape ``[n, *weight.shape]`` and, when the layer has a bias, ``bias.grad1``
    with shape ``[n, *bias.shape]``, where ``n = A.shape[0]``.

    Conv2d layers honour their stride, padding (numeric, 'valid' or 'same'), padding_mode, dilation and
    groups. Linear layers may receive inputs with extra dimensions between batch and features
    (e.g. ``[n, seq, di]``); the per-example gradient then sums over those dimensions.

    Args:
        model: module instrumented with add_hooks(model); it may itself be a supported layer.
        loss_type: either "mean" or "sum" depending whether backpropped loss was averaged or summed over batch.
            For "mean", B is multiplied by n to undo the 1/n factor of the batch average.
    """
    assert loss_type in ('sum', 'mean')
    for layer in model.modules():
        layer_type = _layer_type(layer)
        if layer_type not in _supported_layers:
            continue
        assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
        assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
        assert len(layer.backprops_list) == 1, "Multiple backprops detected, make sure to call clear_backprops(model)"

        A = layer.activations
        n = A.shape[0]
        B = layer.backprops_list[0]
        if loss_type == 'mean':
            B = B * n

        if layer_type == 'Linear':
            # n, do, di: outer products, summed over any dimensions between batch and features.
            setattr(layer.weight, 'grad1', torch.einsum('n...i,n...j->nij', B, A))
            if layer.bias is not None:
                bias_grad1 = B if B.dim() == 2 else B.reshape(n, -1, B.shape[-1]).sum(dim=1)
                setattr(layer.bias, 'grad1', bias_grad1)
        elif layer_type == 'Conv2d':
            A = _unfold_conv2d_input(layer, A)   # n, di * Kh * Kw, Oh * Ow
            n_pos = A.shape[-1]
            B = B.reshape(n, -1, n_pos)          # n, do, Oh * Ow
            groups = layer.groups
            # Output channels of group g only see the input channels of group g; both are contiguous
            # along the channel axis, so a reshape splits them.
            grad1 = torch.einsum('ngjl,ngkl->ngjk',
                                 B.reshape(n, groups, -1, n_pos),
                                 A.reshape(n, groups, -1, n_pos))  # n, g, do/g, di/g * Kh * Kw
            setattr(layer.weight, 'grad1', grad1.reshape([n] + list(layer.weight.shape)))
            if layer.bias is not None:
                setattr(layer.bias, 'grad1', torch.sum(B, dim=2))


def _unfold_conv2d_input(layer: nn.Module, A: torch.Tensor) -> torch.Tensor:
    """Extract the Conv2d layer's sliding input patches from A as [n, di * Kh * Kw, Oh * Ow]."""
    kh, kw = layer.kernel_size
    dh, dw = layer.dilation
    padding = layer.padding
    if isinstance(padding, str):
        if padding == 'valid':
            pad = (0, 0, 0, 0)
        else:
            # 'same': split the total padding like nn.Conv2d does, extra unit on the right/bottom.
            total_h, total_w = dh * (kh - 1), dw * (kw - 1)
            pad = (total_w // 2, total_w - total_w // 2, total_h // 2, total_h - total_h // 2)
    else:
        ph, pw = padding
        pad = (pw, pw, ph, ph)
    if any(pad):
        # Pad explicitly so non-zero padding modes ('reflect', 'replicate', 'circular') are respected.
        padding_mode = getattr(layer, 'padding_mode', 'zeros')
        mode = 'constant' if padding_mode == 'zeros' else padding_mode
        A = F.pad(A, pad, mode=mode)
    return F.unfold(A, layer.kernel_size, dilation=layer.dilation, stride=layer.stride)
def compute_hess(model: nn.Module,) -> None:
    """
    Save Hessian under param.hess for each param in the model.

    For each supported layer, uses ``layer.activations`` and every entry of ``layer.backprops_list``
    (one per backward pass, e.g. from backprop_hess) and averages over the n examples of the batch:
    ``hess = (1/n) * sum_{o, n} J_{o,n}^T J_{o,n}``, with J_{o,n} the per-example gradient for pass o.

    Stored shapes:
        Linear: weight.hess [do * di, do * di] (weight indices row-vectorized), bias.hess [do, do]
        Conv2d: weight.hess [do, di * Kh * Kw, do, di * Kh * Kw], bias.hess [do, do]

    Raises:
        NotImplementedError: for a Conv2d layer with groups != 1, 'same' padding, or non-zero padding
            combined with a padding_mode other than 'zeros'.
    """
    for layer in model.modules():
        layer_type = _layer_type(layer)
        if layer_type not in _supported_layers:
            continue
        assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
        assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"

        if layer_type == 'Linear':
            A = layer.activations
            B = torch.stack(layer.backprops_list)                          # o, n, do
            n = A.shape[0]
            o = B.shape[0]
            A = torch.stack([A] * o)                                       # o, n, di
            Jb = torch.einsum("oni,onj->onij", B, A).reshape(n * o, -1)    # one row-vectorized gradient per (o, n)
            H = torch.einsum('ni,nj->ij', Jb, Jb) / n
            setattr(layer.weight, 'hess', H)
            if layer.bias is not None:
                setattr(layer.bias, 'hess', torch.einsum('oni,onj->ij', B, B) / n)
        elif layer_type == 'Conv2d':
            Kh, Kw = layer.kernel_size
            do = layer.out_channels
            padding = (0, 0) if layer.padding == 'valid' else layer.padding
            if (isinstance(padding, str) or layer.groups != 1
                    or (any(padding) and getattr(layer, 'padding_mode', 'zeros') != 'zeros')):
                raise NotImplementedError(
                    "compute_hess supports Conv2d only with groups=1 and numeric zero padding")
            A = layer.activations.detach()
            # Use the layer's own windows; otherwise B's output positions do not line up with A's patches.
            A = F.unfold(A, (Kh, Kw), dilation=layer.dilation, padding=padding,
                         stride=layer.stride)                              # n, di * Kh * Kw, Oh * Ow
            n = A.shape[0]
            B = torch.stack([Bt.reshape(n, do, -1) for Bt in layer.backprops_list])  # o, n, do, Oh*Ow
            o = B.shape[0]
            A = torch.stack([A] * o)                                       # o, n, di * Kh * Kw, Oh*Ow
            Jb = torch.einsum('onij,onkj->onik', B, A)                     # o, n, do, di * Kh * Kw

            Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb)                  # n, do, di*Kh*Kw, do, di*Kh*Kw
            Jb_bias = torch.einsum('onij->oni', B)
            Hi_bias = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)

            setattr(layer.weight, 'hess', Hi.mean(dim=0))
            if layer.bias is not None:
                setattr(layer.bias, 'hess', Hi_bias.mean(dim=0))
def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
    """
    Call backprop 1 or more times to get values needed for Hessian computation.

    Runs ``o`` backward passes through ``output`` (one per output column, ``retain_graph=True``), each
    seeded with a different ``[n, o]`` gradient. Hooks installed by add_hooks therefore append one entry
    per pass to every supported layer's ``backprops_list``. Also arms the freshness check so that the
    first hook to fire asserts no stale ``backprops_list`` is present.

    Args:
        output: prediction of neural network, shaped [n, o] (ie, input of nn.CrossEntropyLoss())
        hess_type: type of Hessian propagation.
            "LeastSquares": seed k is the k-th row of the identity, repeated for every example.
            "CrossEntropy": seeds are the rows of the symmetric square root of each example's softmax
            Hessian diag(p) - p p^T, which results in exact Hessian for CrossEntropy.
    """
    assert hess_type in ('LeastSquares', 'CrossEntropy')
    global _enforce_fresh_backprop
    n, o = output.shape
    _enforce_fresh_backprop = True

    # The seeds are constants for the backward passes: build them outside the autograd graph and on
    # output's device/dtype, so backward() accepts them (e.g. for CUDA outputs).
    with torch.no_grad():
        if hess_type == 'CrossEntropy':
            probs = F.softmax(output, dim=1)
            hess = torch.diag_embed(probs) - torch.einsum('ij,ik->ijk', probs, probs)
            assert hess.shape == (n, o, o)
            for i in range(n):
                hess[i, :, :] = symsqrt(hess[i, :, :])
            hess = hess.transpose(0, 1)  # o, n, o: hess[k] is the seed of pass k
        else:  # hess_type == 'LeastSquares'
            id_mat = torch.eye(o, dtype=output.dtype, device=output.device)
            hess = [id_mat[k].repeat(n, 1) for k in range(o)]

    for k in range(o):
        output.backward(hess[k], retain_graph=True)
def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
    """Symmetric square root of a positive semi-definite matrix.
    See https://github.com/pytorch/pytorch/issues/25481

    Computes ``U diag(sqrt(s)) U^T`` from the eigendecomposition ``a = U diag(s) U^T``. Only eigenvalues
    with ``|s| > cond * max|s|`` are kept, so rank-deficient inputs give a pseudo square root.

    Args:
        a: symmetric square matrix; only its upper triangle is read.
        cond: relative eigenvalue cutoff; None or -1 selects the default for ``dtype``.
        return_rank: if True, also return the number of eigenvalues kept.
        dtype: torch.float32 or torch.float64; only selects the default ``cond`` (``a`` is not cast).

    Returns:
        The square root B, or ``(B, rank)`` when ``return_rank`` is True. A kept negative eigenvalue
        (non-PSD input) produces NaN entries.
    """
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
        # torch.symeig is deprecated/removed in newer PyTorch; eigh(UPLO='U') is the replacement and,
        # like symeig's default, reads the upper triangle and returns ascending eigenvalues.
        s, u = torch.linalg.eigh(a, UPLO='U')
    else:  # PyTorch < 1.8
        s, u = torch.symeig(a, eigenvectors=True)
    cond_dict = {torch.float32: 1e3 * 1.1920929e-07, torch.float64: 1E6 * 2.220446049250313e-16}

    if cond in [None, -1]:
        cond = cond_dict[dtype]

    above_cutoff = (abs(s) > cond * torch.max(abs(s)))

    psigma_diag = torch.sqrt(s[above_cutoff])
    u = u[:, above_cutoff]

    B = u @ torch.diag(psigma_diag) @ u.t()
    if return_rank:
        return B, len(psigma_diag)
    else:
        return B