import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn


def sample_top(a, top_k=10):
    """Sample one index from the top-k entries of a 1-D probability array.

    The top-k probabilities are renormalized so they sum to 1 before sampling.
    """
    idx = np.argsort(a)[::-1]  # indices sorted by descending probability
    idx = idx[:top_k]          # keep the k most likely indices
    probs = a[idx]
    probs = probs / np.sum(probs)
    choice = np.random.choice(idx, p=probs)
    return choice
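
# Usage sketch (illustrative probabilities; after renormalizing the top-3 mass,
# index 4 is drawn ~82% of the time, index 3 ~16%, and index 0 ~2%):
# sample_top(np.array([0.02, 0.01, 0.01, 0.16, 0.8]), top_k=3)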


# fajie: variant that returns the top-k indices themselves instead of sampling.
def sample_top_k(a, top_k=10):
    """Return the indices of the top-k entries of a 1-D array, largest first."""
    idx = np.argsort(a)[::-1]  # indices sorted by descending value
    idx = idx[:top_k]
    return idx
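
# Usage sketch (same illustrative vector as the original commented test):
# sample_top_k(np.array([0.02, 0.01, 0.01, 0.16, 0.8]), 3)  # -> array([4, 3, 0])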


def to_var(x, requires_grad=False, volatile=False):
    """Wrap a tensor in a Variable, placed on GPU when CUDA is available.

    Note: Variable and the volatile flag are deprecated since PyTorch 0.4;
    plain tensors plus torch.no_grad() are the modern equivalent.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad, volatile=volatile)
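
# Usage sketch (hypothetical tensors; behaves the same on CPU and GPU):
# x = to_var(torch.zeros(2, 3), requires_grad=True)   # training-time input
# y = to_var(torch.zeros(2, 3), volatile=True)        # old-style inference mode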


def prune_rate(model, verbose=True):
    """
    Print the pruning rate of each layer and of the whole network.
    """
    total_nb_param = 0
    nb_zero_param = 0
    layer_id = 0
    for parameter in model.parameters():
        # total number of parameters in this tensor
        param_this_layer = 1
        for dim in parameter.data.size():
            param_this_layer *= dim
        total_nb_param += param_this_layer

        # only conv and linear weights are pruned; skip 1-D tensors (biases, BN)
        if len(parameter.data.size()) != 1:
            layer_id += 1
            zero_param_this_layer = \
                np.count_nonzero(parameter.cpu().data.numpy() == 0)
            nb_zero_param += zero_param_this_layer
            if verbose:
                print("Layer {} | {} layer | {:.2f}% parameters pruned".format(
                    layer_id,
                    'Conv' if len(parameter.data.size()) == 4 else 'Linear',
                    100. * zero_param_this_layer / param_this_layer,
                ))
    pruning_perc = 100. * nb_zero_param / total_nb_param
    if verbose:
        print("Final pruning rate: {:.2f}%".format(pruning_perc))
    return pruning_perc
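
# Usage sketch (hypothetical two-layer model; zeroing one weight row fakes
# pruning 10 of the 67 parameters, so the final rate prints as ~14.93%):
# model = nn.Sequential(nn.Linear(10, 5), nn.Linear(5, 2))
# model[0].weight.data[0].zero_()
# prune_rate(model)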


def arg_nonzero_min(a):
    """
    Return (value, index) of the smallest nonzero entry of a non-negative array.
    """
    if a is None or len(a) == 0:
        return
    min_ix, min_v = None, None
    # find a starting value (the first nonzero entry)
    for i, e in enumerate(a):
        if e != 0:
            min_ix = i
            min_v = e
            break
    # `not min_ix` would misfire here when the first nonzero sits at index 0
    if min_v is None:
        print('Warning: all zero')
        return np.inf, np.inf
    # search for the smallest nonzero entry
    for i, e in enumerate(a):
        if e != 0 and e < min_v:
            min_v = e
            min_ix = i
    return min_v, min_ix
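
# Usage sketch (illustrative input whose only nonzero entry sits at index 0,
# the case the original `not min_ix` check misreported as all-zero):
# arg_nonzero_min([0.3, 0.0, 0.0])   # -> (0.3, 0)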


def getOneHot(y):  # y: [batch_size, class_num]
    """Convert class scores to a hard one-hot tensor of the same shape."""
    shape = y.size()
    _, ind = y.max(dim=-1)                             # argmax along the class axis
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)             # write a 1 at each argmax
    y_hard = y_hard.view(*shape)
    return y_hard
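
# Usage sketch (hypothetical logits):
# logits = torch.tensor([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]])
# getOneHot(logits)   # -> tensor([[0., 1., 0.], [1., 0., 0.]])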