-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathroot_finding.py
277 lines (204 loc) · 8.25 KB
/
root_finding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
"""
Bisection implementation of alpha-entmax (Peters et al., 2019).
Backward pass wrt alpha per (Correia et al., 2019). See
https://arxiv.org/pdf/1905.05702 for detailed description.
"""
# Author: Goncalo M Correia
# Author: Ben Peters
# Author: Vlad Niculae <[email protected]>
import torch
import torch.nn as nn
from torch.autograd import Function
class EntmaxBisectFunction(Function):
    """Autograd function computing alpha-entmax by bisection on the threshold.

    The forward pass searches for a threshold ``tau`` such that
    ``sum(clamp((alpha - 1) * X - tau, min=0) ** (1 / (alpha - 1))) == 1``
    along ``dim``. The backward pass returns gradients with respect to ``X``
    and, when requested by autograd, with respect to ``alpha``.
    """

    @classmethod
    def _gp(cls, x, alpha):
        """Return ``x ** (alpha - 1)``."""
        return x ** (alpha - 1)

    @classmethod
    def _gp_inv(cls, y, alpha):
        """Return ``y ** (1 / (alpha - 1))``, the inverse of :meth:`_gp`."""
        return y ** (1 / (alpha - 1))

    @classmethod
    def _p(cls, X, alpha):
        """Map thresholded scores to (unnormalized) probabilities.

        Negative entries are clamped to zero before applying
        :meth:`_gp_inv`, which is what produces exact zeros in the output.
        """
        return cls._gp_inv(torch.clamp(X, min=0), alpha)

    @classmethod
    def forward(cls, ctx, X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
        """Compute alpha-entmax of ``X`` along ``dim`` by bisection.

        Parameters
        ----------
        ctx : autograd context
            Receives ``alpha`` (expanded), ``dim`` and the saved output.
        X : torch.Tensor
            Input scores.
        alpha : float or torch.Tensor
            Non-tensor values are converted to a tensor with the dtype and
            device of ``X``. The value is then expanded to the shape of
            ``X`` with size 1 along ``dim``.
        dim : int
            Dimension along which to normalize.
        n_iter : int
            Number of bisection steps. With ``n_iter <= 0`` no step is
            taken and the estimate at the lower bracket is returned.
        ensure_sum_one : bool
            If true, divide the result by its sum along ``dim``.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``X``.
        """
        if not isinstance(alpha, torch.Tensor):
            alpha = torch.tensor(alpha, dtype=X.dtype, device=X.device)

        # One alpha per slice along `dim`.
        alpha_shape = list(X.shape)
        alpha_shape[dim] = 1
        alpha = alpha.expand(*alpha_shape)

        ctx.alpha = alpha
        ctx.dim = dim
        d = X.shape[dim]

        max_val, _ = X.max(dim=dim, keepdim=True)
        X = X * (alpha - 1)
        max_val = max_val * (alpha - 1)

        # Note: when alpha < 1, tau_lo > tau_hi. This still works since dm < 0.
        tau_lo = max_val - cls._gp(1, alpha)
        tau_hi = max_val - cls._gp(1 / d, alpha)

        # Evaluate the lower bracket once: it provides the reference sign
        # f_lo for the bisection and doubles as the result when the loop
        # below does not execute (n_iter <= 0), so p_m is always bound.
        p_m = cls._p(X - tau_lo, alpha)
        f_lo = p_m.sum(dim) - 1

        dm = tau_hi - tau_lo

        for _ in range(n_iter):
            dm /= 2
            tau_m = tau_lo + dm
            p_m = cls._p(X - tau_m, alpha)
            f_m = p_m.sum(dim) - 1

            # Move the lower bracket up wherever the midpoint has the same
            # sign as the lower end.
            mask = (f_m * f_lo >= 0).unsqueeze(dim)
            tau_lo = torch.where(mask, tau_m, tau_lo)

        if ensure_sum_one:
            p_m /= p_m.sum(dim=dim).unsqueeze(dim=dim)

        ctx.save_for_backward(p_m)

        return p_m

    @classmethod
    def backward(cls, ctx, dY):
        """Return gradients for ``(X, alpha, dim, n_iter, ensure_sum_one)``.

        Only ``X`` and (if ``ctx.needs_input_grad[1]``) ``alpha`` receive a
        gradient; the remaining inputs get ``None``.
        """
        (Y,) = ctx.saved_tensors

        # Y ** (2 - alpha) on the support, zero elsewhere.
        gppr = torch.where(Y > 0, Y ** (2 - ctx.alpha), Y.new_zeros(1))

        dX = dY * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr

        d_alpha = None
        if ctx.needs_input_grad[1]:
            # alpha gradient computation
            # d_alpha = (partial_y / partial_alpha) * dY
            # NOTE: ensure alpha is not close to 1
            # since there is an indetermination

            # shannon terms
            S = torch.where(Y > 0, Y * torch.log(Y), Y.new_zeros(1))
            # sum of the shannon terms along dim
            ent = S.sum(ctx.dim).unsqueeze(ctx.dim)
            Y_skewed = gppr / gppr.sum(ctx.dim).unsqueeze(ctx.dim)

            d_alpha = dY * (Y - Y_skewed) / ((ctx.alpha - 1) ** 2)
            d_alpha -= dY * (S - Y_skewed * ent) / (ctx.alpha - 1)
            d_alpha = d_alpha.sum(ctx.dim).unsqueeze(ctx.dim)

        return dX, d_alpha, None, None, None
# slightly more efficient special case for sparsemax
class SparsemaxBisectFunction(EntmaxBisectFunction):
    """Sparsemax (alpha = 2) specialization of the bisection autograd function.

    With alpha fixed to 2 the power maps reduce to the identity, so the
    helpers below skip the exponentiation performed by the base class.
    """

    @classmethod
    def _gp(cls, x, alpha):
        """Identity: ``x ** (alpha - 1)`` with alpha = 2."""
        return x

    @classmethod
    def _gp_inv(cls, y, alpha):
        """Identity: ``y ** (1 / (alpha - 1))`` with alpha = 2."""
        return y

    @classmethod
    def _p(cls, x, alpha):
        """Clamp negative entries to zero."""
        return torch.clamp(x, min=0)

    @classmethod
    def forward(cls, ctx, X, dim=-1, n_iter=50, ensure_sum_one=True):
        """Compute sparsemax of ``X`` along ``dim`` by bisection.

        Parameters
        ----------
        ctx : autograd context
        X : torch.Tensor
            Input scores.
        dim : int
            Dimension along which to normalize.
        n_iter : int or None
            Number of bisection steps. ``None`` selects the default of 50.
        ensure_sum_one : bool
            If true, divide the result by its sum along ``dim``.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``X``.
        """
        # Callers in this file may pass n_iter=None to mean "use the
        # default"; resolve it here so the base class always gets an int.
        if n_iter is None:
            n_iter = 50

        # Forward the caller's settings instead of hard-coded values.
        return super().forward(
            ctx,
            X,
            alpha=2,
            dim=dim,
            n_iter=n_iter,
            ensure_sum_one=ensure_sum_one,
        )

    @classmethod
    def backward(cls, ctx, dY):
        """Return gradients for ``(X, dim, n_iter, ensure_sum_one)``.

        Only ``X`` receives a gradient; the remaining inputs get ``None``.
        """
        (Y,) = ctx.saved_tensors

        # Indicator of the support, in the dtype of the incoming gradient.
        gppr = (Y > 0).to(dtype=dY.dtype)

        dX = dY * gppr
        q = dX.sum(ctx.dim) / gppr.sum(ctx.dim)
        q = q.unsqueeze(ctx.dim)
        dX -= q * gppr

        return dX, None, None, None
def entmax_bisect(X, alpha=1.5, dim=-1, n_iter=50, ensure_sum_one=True):
    """alpha-entmax: a sparse, normalizing alternative to softmax.

    Solves the optimization problem:

        max_p <x, p> - H_a(p)    s.t.    p >= 0, sum(p) == 1.

    where H_a(p) is the Tsallis alpha-entropy with custom alpha >= 1,
    by means of a bisection (root finding, binary search) algorithm.

    The result is differentiable with respect to both X and alpha.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor.

    alpha : float or torch.Tensor
        Alpha parameter(s) (> 1) to use. A scalar or python float is
        applied to every row; a tensor must have shape (or be expandable
        to)
        alpha.shape[j] == (X.shape[j] if j != dim else 1)
        alpha=2 corresponds to sparsemax and alpha=1 to softmax (but
        computing softmax this way is likely unstable).

    dim : int
        The dimension along which to apply alpha-entmax.

    n_iter : int
        Number of bisection iterations. For float32, 24 iterations should
        suffice for machine precision.

    ensure_sum_one : bool,
        Whether to divide the result by its sum. If false, the result
        might sum to close but not exactly 1, which might cause downstream
        problems.

    Returns
    -------
    P : torch tensor, same shape as X
        The projection result, such that P.sum(dim=dim) == 1 elementwise.
    """
    # autograd Function.apply takes positional arguments only, in the
    # order expected by EntmaxBisectFunction.forward.
    bisect_args = (X, alpha, dim, n_iter, ensure_sum_one)
    return EntmaxBisectFunction.apply(*bisect_args)
def sparsemax_bisect(X, dim=-1, n_iter=50, ensure_sum_one=True):
    """sparsemax computed by bisection: a sparse alternative to softmax.

    Solves the projection:

        min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

    Parameters
    ----------
    X : torch.Tensor
        The input tensor.

    dim : int
        The dimension along which to apply sparsemax. It is forwarded
        unchanged to the underlying bisection routine.

    n_iter : int
        Number of bisection iterations. For float32, 24 iterations should
        suffice for machine precision.

    ensure_sum_one : bool,
        Whether to divide the result by its sum. If false, the result
        might sum to close but not exactly 1, which might cause downstream
        problems.

    Returns
    -------
    P : torch tensor, same shape as X
        The projection result, such that P.sum(dim=dim) == 1 elementwise.
    """
    # autograd Function.apply takes positional arguments only, in the
    # order expected by SparsemaxBisectFunction.forward.
    bisect_args = (X, dim, n_iter, ensure_sum_one)
    return SparsemaxBisectFunction.apply(*bisect_args)
class SparsemaxBisect(nn.Module):
    """sparsemax: normalizing sparse transform (a la softmax) via bisection.

    Solves the projection:

        min_p ||x - p||_2   s.t.    p >= 0, sum(p) == 1.

    Parameters
    ----------
    dim : int
        The dimension along which to apply sparsemax.

    n_iter : int or None
        Number of bisection iterations. For float32, 24 iterations should
        suffice for machine precision. ``None`` (the default) selects 50
        iterations, the default of ``sparsemax_bisect``.
    """

    def __init__(self, dim=-1, n_iter=None):
        # Initialize nn.Module state before assigning any attribute.
        super().__init__()
        self.dim = dim
        # Store a concrete iteration count so that None is never handed
        # to the bisection loop.
        self.n_iter = 50 if n_iter is None else n_iter

    def forward(self, X):
        """Apply sparsemax to ``X`` along ``self.dim``."""
        return sparsemax_bisect(X, dim=self.dim, n_iter=self.n_iter)
class EntmaxBisect(nn.Module):
    """alpha-entmax: normalizing sparse map (a la softmax) via bisection.

    Solves the optimization problem:

        max_p <x, p> - H_a(p)    s.t.    p >= 0, sum(p) == 1.

    where H_a(p) is the Tsallis alpha-entropy with custom alpha >= 1,
    using a bisection (root finding, binary search) algorithm.

    Parameters
    ----------
    alpha : float or torch.Tensor
        Tensor of alpha parameters (> 1) to use. If scalar
        or python float, the same value is used for all rows, otherwise,
        it must have shape (or be expandable to)
        alpha.shape[j] == (X.shape[j] if j != dim else 1)
        A value of alpha=2 corresponds to sparsemax; alpha=1 corresponds
        to softmax (but computing it this way is likely unstable).

    dim : int
        The dimension along which to apply alpha-entmax.

    n_iter : int
        Number of bisection iterations. For float32, 24 iterations should
        suffice for machine precision.
    """

    def __init__(self, alpha=1.5, dim=-1, n_iter=50):
        # The one-argument form super(EntmaxBisect) yields an unbound
        # super object and never runs nn.Module.__init__; the zero-argument
        # form does, and it must run before any attribute is assigned so
        # that an nn.Parameter alpha can be registered.
        super().__init__()
        self.dim = dim
        self.n_iter = n_iter
        self.alpha = alpha

    def forward(self, X):
        """Apply alpha-entmax to ``X`` along ``self.dim``."""
        return entmax_bisect(
            X, alpha=self.alpha, dim=self.dim, n_iter=self.n_iter
        )