import torch
import math
'''
REDISTRIBUTION
'''
def momentum_redistribution(masking, name, weight, mask):
    """Calculates momentum redistribution statistics.

    Args:
        masking     Masking class with state about the current
                    layers and the entire sparse network.
        name        The name of the layer. This can be used to
                    access layer-specific statistics in the
                    masking class.
        weight      The weight of the respective sparse layer.
                    This is a torch parameter.
        mask        The binary mask. 1s indicate active weights.

    Returns:
        Layer statistic     The unnormalized layer statistic
                    for the layer "name". A higher value indicates
                    that more pruned parameters are redistributed
                    to this layer compared to layers with a lower value.
                    The values will be automatically sum-normalized
                    after this step.

    The calculation of redistribution statistics is the first
    step in this sparse learning library.
    """
    grad = masking.get_momentum_for_weight(weight)
    # Mean momentum magnitude over the active weights of this layer.
    mean_magnitude = torch.abs(grad[mask.bool()]).mean().item()
    return mean_magnitude
def magnitude_redistribution(masking, name, weight, mask):
    # Mean absolute value of the active weights in this layer.
    mean_magnitude = torch.abs(weight)[mask.bool()].mean().item()
    return mean_magnitude

def nonzero_redistribution(masking, name, weight, mask):
    # Nonzero count, which keeps the current sparsity distribution unchanged.
    nonzero = (weight != 0.0).sum().item()
    return nonzero

def no_redistribution(masking, name, weight, mask):
    # Proportional to layer size relative to the baseline parameter budget.
    num_params = masking.baseline_nonzero
    n = weight.numel()
    return n / float(num_params)
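
# Illustrative sketch (not part of this library) of a custom redistribution
# function written against the interface documented above: it scores each
# layer by the variance of its momentum over active weights, the idea being
# that layers with unsettled momentum may benefit from extra capacity. The
# name and registration key below are assumptions.
def momentum_variance_redistribution(masking, name, weight, mask):
    grad = masking.get_momentum_for_weight(weight)
    # Variance of momentum over active weights; sum-normalized downstream.
    return torch.var(grad[mask.bool()].float()).item()
# It could be registered at the bottom of this file with:
#   redistribution_funcs['momentum_variance'] = momentum_variance_redistribution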
'''
PRUNE
'''
def magnitude_prune(masking, mask, weight, name):
    """Prunes the weights with the smallest magnitude.

    The pruning functions in this sparse learning library
    work by constructing a binary mask variable "mask"
    which prevents gradient flow to weights and also
    sets the weights to zero where the binary mask is 0.
    Thus 1s in the "mask" variable indicate where the sparse
    network has active weights. In this function, name
    and masking can be used to access global statistics
    about the specific layer (name) and the sparse network
    as a whole.

    Args:
        masking     Masking class with state about the current
                    layers and the entire sparse network.
        mask        The binary mask. 1s indicate active weights.
        weight      The weight of the respective sparse layer.
                    This is a torch parameter.
        name        The name of the layer. This can be used to
                    access layer-specific statistics in the
                    masking class.

    Returns:
        mask        Pruned binary mask where 1s indicate active
                    weights. Can be modified in-place or newly
                    constructed.

    Accessible global statistics:

    Layer statistics:
        Non-zero count of layer:
            masking.name2nonzeros[name]
        Zero count of layer:
            masking.name2zeros[name]
        Redistribution proportion:
            masking.name2variance[name]
        Number of items removed through pruning:
            masking.name2removed[name]

    Network statistics:
        Total number of nonzero parameters in the network:
            masking.total_nonzero
        Total number of zero-valued parameters in the network:
            masking.total_zero
        Total number of parameters removed in pruning:
            masking.total_removed
    """
    num_remove = math.ceil(masking.prune_rate * masking.name2nonzeros[name])
    if num_remove == 0.0: return weight.data != 0.0
    num_zeros = masking.name2zeros[name]
    # Zero out the k smallest-magnitude entries: the num_zeros weights that
    # are already zero plus num_remove currently active weights.
    k = int(num_zeros + num_remove)
    x, idx = torch.sort(torch.abs(weight.data.view(-1)))
    mask.data.view(-1)[idx[:k]] = 0.0
    return mask
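
# A standalone toy walk-through (illustrative only, touching no library
# state) of the pattern magnitude_prune relies on: sort by absolute value
# and zero the k smallest entries of the mask.
#
#   weight = torch.tensor([0.5, -0.1, 0.0, 0.9, -0.3])
#   mask = (weight != 0.0).float()          # -> [1., 1., 0., 1., 1.]
#   k = 3                                   # 1 zero already present + 2 to remove
#   _, idx = torch.sort(torch.abs(weight))  # indices by |w| ascending: [2, 1, 4, 0, 3]
#   mask.view(-1)[idx[:k]] = 0.0            # -> [1., 0., 0., 1., 0.]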
def global_magnitude_prune(masking):
    # Assumes a single global prune rate; takes the rate of the last masked layer.
    prune_rate = 0.0
    for name in masking.name2prune_rate:
        if name in masking.masks:
            prune_rate = masking.name2prune_rate[name]
    tokill = math.ceil(prune_rate * masking.baseline_nonzero)
    total_removed = 0
    prev_removed = 0
    # Nudge the global magnitude threshold up or down until the number of
    # removed weights falls within the tolerance band around the target.
    while total_removed < tokill * (1.0 - masking.tolerance) or (total_removed > tokill * (1.0 + masking.tolerance)):
        total_removed = 0
        for module in masking.modules:
            for name, weight in module.named_parameters():
                if name not in masking.masks: continue
                remain = (torch.abs(weight.data) > masking.prune_threshold).sum().item()
                total_removed += masking.name2nonzeros[name] - remain

        if prev_removed == total_removed: break
        prev_removed = total_removed
        if total_removed > tokill * (1.0 + masking.tolerance):
            masking.prune_threshold *= 1.0 - masking.increment
            masking.increment *= 0.99
        elif total_removed < tokill * (1.0 - masking.tolerance):
            masking.prune_threshold *= 1.0 + masking.increment
            masking.increment *= 0.99

    for module in masking.modules:
        for name, weight in module.named_parameters():
            if name not in masking.masks: continue
            masking.masks[name][:] = torch.abs(weight.data) > masking.prune_threshold
    return int(total_removed)
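
# Standalone illustration (assumed names, not library code) of the threshold
# search global_magnitude_prune performs: raise or lower a magnitude
# threshold until the removed count lands inside the tolerance band, damping
# the step size so the search settles.
def _toy_threshold_search(weight, target_removed, tolerance=0.05, increment=0.2, max_iters=200):
    threshold = weight.abs().mean().item()  # arbitrary starting point
    for _ in range(max_iters):
        removed = (weight.abs() <= threshold).sum().item()
        if target_removed * (1 - tolerance) <= removed <= target_removed * (1 + tolerance):
            break
        # Too many removed -> lower the threshold; too few -> raise it.
        threshold *= (1 - increment) if removed > target_removed else (1 + increment)
        increment *= 0.99  # damp the step so the search converges
    return threshold
# e.g. _toy_threshold_search(torch.randn(10000), 3000) returns a threshold
# that prunes roughly 3000 of the 10000 entries.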
def magnitude_and_negativity_prune(masking, mask, weight, name):
    num_remove = math.ceil(masking.name2prune_rate[name] * masking.name2nonzeros[name])
    if num_remove == 0.0: return weight.data != 0.0
    num_zeros = masking.name2zeros[name]
    k = math.ceil(num_zeros + (num_remove / 2.0))

    # Remove the weights whose absolute value is smallest.
    x, idx = torch.sort(torch.abs(weight.data.view(-1)))
    mask.data.view(-1)[idx[:k]] = 0.0

    # Remove the most negative weights.
    x, idx = torch.sort(weight.data.view(-1))
    mask.data.view(-1)[idx[:math.ceil(num_remove / 2.0)]] = 0.0
    return mask
'''
GROWTH
'''
def random_growth(masking, name, new_mask, total_regrowth, weight):
    n = (new_mask == 0).sum().item()
    if n == 0: return new_mask
    # Activate each inactive position independently with the probability
    # that yields total_regrowth grown weights in expectation.
    expected_growth_probability = (total_regrowth / n)
    new_weights = torch.rand(new_mask.shape, device=new_mask.device) < expected_growth_probability
    return new_mask.bool() | new_weights
def random_unfired_growth(masking, name, new_mask, total_regrowth, weight):
    n = (new_mask == 0).sum().item()
    if n == 0: return new_mask
    num_nonfired_weights = (masking.fired_masks[name] == 0).sum().item()

    if total_regrowth <= num_nonfired_weights:
        # Grow a random subset of the weights that have never been active.
        idx = (masking.fired_masks[name].flatten() == 0).nonzero().flatten()
        indices = torch.randperm(len(idx))[:total_regrowth]
        new_mask.data.view(-1)[idx[indices]] = 1.0
    else:
        # Activate every never-fired weight, then grow the remainder at random.
        new_mask[masking.fired_masks[name] == 0] = 1.0
        n = (new_mask == 0).sum().item()
        expected_growth_probability = ((total_regrowth - num_nonfired_weights) / n)
        new_weights = torch.rand(new_mask.shape, device=new_mask.device) < expected_growth_probability
        new_mask = new_mask.bool() | new_weights
    return new_mask
def gradient_growth(masking, name, new_mask, total_regrowth, weight):
    grad = masking.get_gradient_for_weights(weight)
    # Only consider gradients at currently inactive (pruned) positions.
    if grad.dtype == torch.float16:
        grad = grad * (new_mask == 0).half()
    else:
        grad = grad * (new_mask == 0).float()

    y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
    new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
    return new_mask
def mix_growth(masking, name, new_mask, total_regrowth, weight):
    gradient_grow = int(total_regrowth * masking.mix)
    random_grow = total_regrowth - gradient_grow

    # Gradient-based part: grow where the gradient magnitude is largest.
    grad = masking.get_gradient_for_weights(weight)
    if grad.dtype == torch.float16:
        grad = grad * (new_mask == 0).half()
    else:
        grad = grad * (new_mask == 0).float()
    y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
    new_mask.data.view(-1)[idx[:gradient_grow]] = 1.0

    # Random part: grow the remainder uniformly among inactive positions.
    n = (new_mask == 0).sum().item()
    expected_growth_probability = (random_grow / n)
    new_weights = torch.rand(new_mask.shape, device=new_mask.device) < expected_growth_probability
    new_mask = new_mask.bool() | new_weights
    return new_mask
def momentum_growth(masking, name, new_mask, total_regrowth, weight):
    """Grows weights in places where the momentum is largest.

    Growth functions in the sparse learning library work by
    changing 0s to 1s in a binary mask, which will enable
    gradient flow. Newly grown weights default to a value of 0,
    which can be changed in this function. The number of
    parameters to be regrown is determined by the total_regrowth
    parameter. The masking object in conjunction with the name
    of the layer enables access to further statistics
    and objects that allow more flexibility to implement
    custom growth functions.

    Args:
        masking     Masking class with state about the current
                    layers and the entire sparse network.
        name        The name of the layer. This can be used to
                    access layer-specific statistics in the
                    masking class.
        new_mask    The binary mask. 1s indicate active weights.
                    This binary mask has already been pruned in the
                    pruning step that precedes the growth step.
        total_regrowth  This variable determines the number of
                    parameters to regrow in this function.
                    It is automatically determined by the
                    redistribution function and algorithms
                    internal to the sparselearning library.
        weight      The weight of the respective sparse layer.
                    This is a torch parameter.

    Returns:
        mask        Binary mask with newly grown weights.
                    1s indicate active weights in the binary mask.

    Access to optimizer:
        masking.optimizer

    Access to momentum/Adam update:
        masking.get_momentum_for_weight(weight)

    Accessible global statistics:

    Layer statistics:
        Non-zero count of layer:
            masking.name2nonzeros[name]
        Zero count of layer:
            masking.name2zeros[name]
        Redistribution proportion:
            masking.name2variance[name]
        Number of items removed through pruning:
            masking.name2removed[name]

    Network statistics:
        Total number of nonzero parameters in the network:
            masking.total_nonzero
        Total number of zero-valued parameters in the network:
            masking.total_zero
        Total number of parameters removed in pruning:
            masking.total_removed
    """
    grad = masking.get_momentum_for_weight(weight)
    # Only consider momentum at currently inactive (pruned) positions.
    if grad.dtype == torch.float16:
        grad = grad * (new_mask == 0).half()
    else:
        grad = grad * (new_mask == 0).float()

    y, idx = torch.sort(torch.abs(grad).flatten(), descending=True)
    new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
    return new_mask
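
# Illustrative sketch (not part of this library) of a custom growth function
# written against the interface documented above: identical to
# momentum_growth except that a small amount of uniform noise breaks ties
# between equal-magnitude momentum entries. The name is an assumption.
def noisy_momentum_growth(masking, name, new_mask, total_regrowth, weight):
    grad = masking.get_momentum_for_weight(weight)
    # Rank inactive positions by momentum magnitude plus tie-breaking noise.
    score = torch.abs(grad).float() + 1e-6 * torch.rand(grad.shape, device=grad.device)
    score = score * (new_mask == 0).float()
    y, idx = torch.sort(score.flatten(), descending=True)
    new_mask.data.view(-1)[idx[:total_regrowth]] = 1.0
    return new_mask
# It could be registered at the bottom of this file with:
#   growth_funcs['noisy_momentum'] = noisy_momentum_growth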
def momentum_neuron_growth(masking, name, new_mask, total_regrowth, weight):
    grad = masking.get_momentum_for_weight(weight)
    M = torch.abs(grad)
    # Reduce over all dimensions except the output-neuron dimension
    # (assumes linear weights of shape [out, in] or conv weights of
    # shape [out, in, kH, kW]).
    if len(M.shape) == 2: sum_dim = [1]
    elif len(M.shape) == 4: sum_dim = [1, 2, 3]

    # Fraction of the regrowth budget assigned to each neuron.
    v = M.mean(sum_dim).data
    v /= v.sum()

    slots_per_neuron = (new_mask == 0).sum(sum_dim)

    M = M * (new_mask == 0).float()
    for i, fraction in enumerate(v):
        neuron_regrowth = math.floor(fraction.item() * total_regrowth)
        available = slots_per_neuron[i].item()

        y, idx = torch.sort(M[i].flatten())
        if neuron_regrowth > available:
            neuron_regrowth = available
        # TODO: work this into a more stable growth method.
        threshold = y[-(neuron_regrowth)].item()
        if threshold == 0.0: continue
        if neuron_regrowth < 10: continue
        new_mask[i] = new_mask[i] | (M[i] > threshold)

    return new_mask
def global_momentum_growth(masking, total_regrowth):
    togrow = total_regrowth
    total_grown = 0
    last_grown = 0
    # Nudge the global growth threshold up or down until the number of grown
    # weights falls within the tolerance band around the target.
    while total_grown < togrow * (1.0 - masking.tolerance) or (total_grown > togrow * (1.0 + masking.tolerance)):
        total_grown = 0
        total_possible = 0
        for module in masking.modules:
            for name, weight in module.named_parameters():
                if name not in masking.masks: continue

                new_mask = masking.masks[name]
                grad = masking.get_momentum_for_weight(weight)
                grad = grad * (new_mask == 0).float()
                possible = (grad != 0.0).sum().item()
                total_possible += possible
                grown = (torch.abs(grad.data) > masking.growth_threshold).sum().item()
                total_grown += grown

        if total_grown == last_grown: break
        last_grown = total_grown

        if total_grown > togrow * (1.0 + masking.tolerance):
            masking.growth_threshold *= 1.02
            #masking.growth_increment *= 0.95
        elif total_grown < togrow * (1.0 - masking.tolerance):
            masking.growth_threshold *= 0.98
            #masking.growth_increment *= 0.95

    total_new_nonzeros = 0
    for module in masking.modules:
        for name, weight in module.named_parameters():
            if name not in masking.masks: continue

            new_mask = masking.masks[name]
            grad = masking.get_momentum_for_weight(weight)
            grad = grad * (new_mask == 0).float()
            masking.masks[name][:] = (new_mask.bool() | (torch.abs(grad.data) > masking.growth_threshold)).float()
            total_new_nonzeros += new_mask.sum().item()
    return total_new_nonzeros
# Registries that map option names (e.g. from the command line) to the
# prune, growth, and redistribution implementations above.
prune_funcs = {}
prune_funcs['magnitude'] = magnitude_prune
prune_funcs['SET'] = magnitude_and_negativity_prune
prune_funcs['global_magnitude'] = global_magnitude_prune

growth_funcs = {}
growth_funcs['random'] = random_growth
growth_funcs['random_unfired'] = random_unfired_growth
growth_funcs['momentum'] = momentum_growth
growth_funcs['gradient'] = gradient_growth
growth_funcs['mix'] = mix_growth
growth_funcs['momentum_neuron'] = momentum_neuron_growth
growth_funcs['global_momentum_growth'] = global_momentum_growth

redistribution_funcs = {}
redistribution_funcs['momentum'] = momentum_redistribution
redistribution_funcs['nonzero'] = nonzero_redistribution
redistribution_funcs['magnitude'] = magnitude_redistribution
redistribution_funcs['none'] = no_redistribution
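
# Hedged usage sketch: the library's Masking class (defined elsewhere) is
# expected to look up its strategy functions by name in these registries,
# roughly along the lines of:
#
#     prune_func = prune_funcs['magnitude']
#     growth_func = growth_funcs['momentum']
#     redistribution_func = redistribution_funcs['momentum']
#
# The exact lookup happens inside the Masking class, so the lines above are
# only illustrative of how the registries are meant to be consumed.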