forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DistanceOpsKernel.cpp
447 lines (379 loc) · 18.1 KB
/
DistanceOpsKernel.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
#include <ATen/native/Distance.h>
#include <numeric>
#include <iterator>
#include <algorithm>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/cpu/vml.h>
namespace at { namespace native { namespace {
template<typename scalar_t>
struct Dist {
  using Vec = vec256::Vec256<scalar_t>;

  // Depending on the value of the p-norm, there are specific implementations
  // that are much faster than std::pow(std::abs(a - b), p), but that share
  // the same loop structure over the input vectors. To reuse the main outer
  // loop while still guaranteeing that the compiler inlines every
  // p-specific function, we break the inner norm logic into structs with
  // static functions that represent what's done differently, and template
  // the outer loop on those structs.
  //
  // The four functions are:
  // map      : How to transform |a - b| into the component that gets
  //            accumulated.
  // red      : How to accumulate the results of map. This is separate
  //            because the inf norm uses max instead of a sum.
  // finish   : What to do with the accumulated value to obtain the norm.
  //            Generally this is val ^ (1 / p).
  // backward : The gradient contribution for one pair of vectors, given the
  //            signed difference, the incoming gradient, the forward
  //            distance and p.
  //
  // There are a few cases where these aren't all used. The 0 norm has no
  // backward, because its gradient is always 0, so that case is
  // short-circuited earlier. There's a special implementation of the general
  // backward pass when p is less than two, so there's a struct with only a
  // backward pass for that case.

  // Elementwise sign for finite lanes: +1 for positive, -1 for negative and
  // 0 for zero, built by clamping ceil/floor of the input.
  // TODO This is an inefficient way to compute sign, and can be much faster
  // using native SSE instructions that should be added to Vec256.
  static inline Vec sign(Vec val) {
    return vec256::minimum(vec256::maximum(Vec(0), val.ceil()), Vec(1)) +
      vec256::minimum(vec256::maximum(Vec(-1), val.floor()), Vec(0));
  }

  // Same-named overloads for Vec and scalar_t, so the *_calc structs below
  // can be instantiated with either data type: Vec for the vectorized pdist
  // forward, scalar_t for the scalar cdist forward.
  static inline Vec abs(Vec val) {
    return val.abs();
  }

  static inline scalar_t abs(scalar_t val) {
    return std::abs(val);
  }

  static inline Vec ceil(Vec val) {
    return val.ceil();
  }

  static inline scalar_t ceil(scalar_t val) {
    return std::ceil(val);
  }

  static inline Vec min(Vec val, scalar_t other) {
    return vec256::minimum(val, Vec(other));
  }

  static inline scalar_t min(scalar_t val, scalar_t other) {
    return std::min(val, other);
  }

  static inline Vec max(Vec val, Vec other) {
    return vec256::maximum(val, other);
  }

  static inline scalar_t max(scalar_t val, scalar_t other) {
    return std::max(val, other);
  }

  static inline Vec pow(Vec val, Vec p) {
    return val.pow(p);
  }

  static inline scalar_t pow(scalar_t val, scalar_t p) {
    return std::pow(val, p);
  }

  // Zero norm: counts the components in which the two vectors differ.
  template<typename data_t>
  struct zdist_calc {
    // 1 where |diff| > 0 (ceil rounds any positive value up to >= 1, then it
    // is clamped to 1), 0 where the components are equal.
    static inline data_t map(const data_t& diff, const data_t& p) { return min(ceil(abs(diff)), 1); }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
  };

  // One norm: sum of absolute differences.
  template<typename data_t>
  struct odist_calc {
    // map already receives |a - b|, so it is accumulated unchanged.
    static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff); }
  };

  // Special general p-norm derivative if p is less than two:
  // grad * sign(diff) * |diff|^(p-1) / dist^(p-1), and 0 when dist == 0.
  struct lttdist_calc {
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) {
      Vec result = (dist == 0.0) ? Vec(0) : (sign(diff) * diff.abs().pow(p - Vec(1)) * Vec(grad) / Vec(dist).pow(p - Vec(1)));
      // For p < 1 and diff == 0, |diff|^(p-1) is inf and sign(diff) is 0, so
      // the product above is NaN; force those lanes to 0.
      result = Vec::blendv(result, Vec(0), (diff == Vec(0)) & (p < Vec(1)));
      return result;
    }
  };

  // Two norm: sqrt of the sum of squared differences.
  template<typename data_t>
  struct tdist_calc {
    // TODO This can probably use fused add multiply to get better perf
    static inline data_t map(const data_t& diff, const data_t& p) { return diff * diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::sqrt(agg); }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : Vec(grad) * diff / Vec(dist); }
  };

  // General p norm: (sum |diff|^p)^(1/p).
  template<typename data_t>
  struct pdist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return pow(diff, p); }
    static inline data_t red(const data_t& agg, const data_t& up) { return agg + up; }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return std::pow(agg, 1.0 / p); }
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return dist == 0.0 ? Vec(0) : diff * diff.abs().pow(p - Vec(2)) * Vec(grad) / Vec(dist).pow(p - Vec(1)); }
  };

  // Inf norm: maximum absolute difference.
  template<typename data_t>
  struct idist_calc {
    static inline data_t map(const data_t& diff, const data_t& p) { return diff; }
    static inline data_t red(const data_t& agg, const data_t& up) { return max(agg, up); }
    static inline scalar_t finish(const scalar_t agg, const scalar_t p) { return agg; }

    // The gradient is grad * sign(diff) on the lanes where |diff| == dist
    // (the maximizing components) and 0 elsewhere. The mask is computed as
    // 1 - min(1, ceil(| |diff| - dist |)), which is 1 only on an exact match.
    // TODO This backward pass uses a very complex expression to compute (diff
    // == dist) that could be much faster if using SSE instructions.
    static inline Vec backward(const Vec& diff, const scalar_t grad, const scalar_t dist, const Vec& p) { return Vec(grad) * sign(diff) * (Vec(1) - vec256::minimum(Vec(1), (diff.abs() - Vec(dist)).abs().ceil())); }
  };

  // Forward pdist. self has n rows of length m (read via size(0)/size(1)
  // and row offsets of i * m). result[k] is the distance between rows i < j,
  // with k enumerating pairs in row-major upper-triangular order:
  // (0,1), (0,2), ..., (0,n-1), (1,2), ...
  // Each per-pair reduction is vectorized over m via map2_reduce_all.
  template <typename F>
  static void run_parallel_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
    const scalar_t * const self_start = self.data_ptr<scalar_t>();
    const scalar_t * const self_end = self_start + self.numel();
    int64_t n = self.size(0);
    int64_t m = self.size(1);

    scalar_t * const res_start = result.data_ptr<scalar_t>();
    int64_t combs = result.numel(); // n * (n - 1) / 2

    // We conceptually iterate over tuples of (i, j, k) where i is the first
    // vector from the input, j is the second, and k is the result index. This
    // parallelizes over the range of k and infers what i and j are from the
    // value of k.
    parallel_for(0, combs, internal::GRAIN_SIZE / (16 * m), [p, self_start, self_end, n, m, res_start](int64_t k, int64_t end) {
      const Vec pvec(p);
      double n2 = n - .5;
      // The -1 accounts for floating point truncation issues
      int64_t i = static_cast<int64_t>((n2 - std::sqrt(n2 * n2 - 2 * k - 1)));
      int64_t j = k - n * i + i * (i + 1) / 2 + i + 1;

      const scalar_t * self_i = self_start + i * m;
      const scalar_t * self_j = self_start + j * m;
      scalar_t * res = res_start + k;
      const scalar_t * const res_end = res_start + end;

      while (res != res_end) {
        *res = F::finish(vec256::map2_reduce_all<scalar_t>(
          [&pvec](Vec a, Vec b) { return F::map((a - b).abs(), pvec); },
          F::red, self_i, self_j, m), p);

        res += 1;
        self_j += m;
        if (self_j == self_end) {
          // All partners of row i are done; continue with pair (i+1, i+2).
          self_i += m;
          self_j = self_i + m;
        }
      }
    });
  }

  // Assumes self is nonempty, contiguous, and 2D
  // Picks the specialized norm for p (0, 1, 2, inf) or falls back to the
  // general pow-based implementation.
  static void apply_pdist(Tensor& result, const Tensor& self, const scalar_t p) {
    if (p == 0.0) {
      run_parallel_pdist<zdist_calc<Vec>>(result, self, p);
    } else if (p == 1.0) {
      run_parallel_pdist<odist_calc<Vec>>(result, self, p);
    } else if (p == 2.0) {
      run_parallel_pdist<tdist_calc<Vec>>(result, self, p);
    } else if (std::isinf(p)) {
      run_parallel_pdist<idist_calc<Vec>>(result, self, p);
    } else {
      run_parallel_pdist<pdist_calc<Vec>>(result, self, p);
    }
  }

  // Forward cdist: result[l][i][j] = ||t1[l][i] - t2[l][j]||_p, walking the
  // flat result index in (l, i, j) order. The pointer arithmetic below
  // assumes t1, t2 and result are densely packed. It also assumes d =
  // t1.size(0) is the batch count, i.e. t1/t2 are 3D (d, r, m) --
  // NOTE(review): confirm with the caller. Unlike pdist, the per-pair
  // reduction here is scalar, not vectorized.
  template <typename F>
  static void run_parallel_cdist(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
    const scalar_t * const t1_start = t1.data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.data_ptr<scalar_t>();
    int64_t d = t1.size(0);
    int64_t r1 = t1.size(-2);
    int64_t r2 = t2.size(-2);
    int64_t m = t1.size(-1);

    scalar_t * const res_start = result.data_ptr<scalar_t>();
    int64_t combs = r1 * r2;
    int64_t size1 = r1 * m;
    int64_t size2 = r2 * m;

    parallel_for(0, combs * d, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
      scalar_t * res = res_start + start;
      const scalar_t * const res_end = res_start + end;

      // Decompose the flat start index into (batch l, row i of t1, row j of
      // t2); i and j are kept as element offsets (row * m) within the batch.
      int64_t l = start / combs;
      int64_t k = start % combs;
      int64_t i = k / r2;
      int64_t j = k % r2;
      i = i * m;
      j = j * m;

      while (res != res_end) {
        const scalar_t * self_i = t1_start + size1 * l + i;
        const scalar_t * self_j = t2_start + size2 * l + j;

        scalar_t agg = 0;
        // int64_t to match m: an int counter would overflow for m > INT_MAX.
        for (int64_t x = 0; x < m; x++) {
          scalar_t a = *(self_i + x);
          scalar_t b = *(self_j + x);
          agg = F::red(agg, F::map(std::abs(a-b), p));
        }
        *res = F::finish(agg, p);

        res += 1;
        j += m;
        if (j == size2) {
          j = 0;
          i += m;
          if (i == size1) {
            i = 0;
            l += 1;
          }
        }
      }
    });
  }

  // Picks the specialized norm for p (0, 1, 2, inf) or falls back to the
  // general pow-based implementation; the norm structs use scalar_t data.
  static void apply_cdist(Tensor& result, const Tensor& x1, const Tensor& x2, const scalar_t p) {
    if (p == 0.0) {
      run_parallel_cdist<zdist_calc<scalar_t>>(result, x1, x2, p);
    } else if (p == 1.0) {
      run_parallel_cdist<odist_calc<scalar_t>>(result, x1, x2, p);
    } else if (p == 2.0) {
      run_parallel_cdist<tdist_calc<scalar_t>>(result, x1, x2, p);
    } else if (std::isinf(p)) {
      run_parallel_cdist<idist_calc<scalar_t>>(result, x1, x2, p);
    } else {
      run_parallel_cdist<pdist_calc<scalar_t>>(result, x1, x2, p);
    }
  }

  // This does a backward pass down a Vec column of the input.
  // self_i / res_i point at `count` adjacent columns of row 0 (rows are m
  // apart). For every pair of rows i < j, visited in the same order as the
  // forward pass (so grad_k advances by gs and dist_k by 1 per pair), the
  // pair's contribution is added to row i and subtracted from row j.
  template <typename F>
  inline static void backward_down_column_pdist(const scalar_t * self_i, scalar_t * res_i, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t n, int64_t m, int64_t gs, int64_t count = Vec::size()) {
    for (const scalar_t * const self_end = self_i + m * n; self_i != self_end - m; self_i += m, res_i += m) {
      const Vec self_vec_i = Vec::loadu(self_i, count);
      Vec res_vec_i = Vec::loadu(res_i, count);

      const scalar_t * self_j = self_i + m;
      scalar_t * res_j = res_i + m;
      for (; self_j != self_end; self_j += m, res_j += m, grad_k += gs, dist_k += 1) {
        const Vec self_vec_j = Vec::loadu(self_j, count);
        Vec res_vec_j = Vec::loadu(res_j, count);

        const Vec contrib = F::backward(self_vec_i - self_vec_j, *grad_k, *dist_k, pvec);
        res_vec_i = res_vec_i + contrib;
        res_vec_j = res_vec_j - contrib;

        res_vec_j.store(res_j, count);
      }

      res_vec_i.store(res_i, count);
    }
  }

  // Backward pdist over an (n, m) input: full Vec-width column chunks are
  // processed in parallel, and the trailing m % Vec::size() columns are
  // handled afterwards on the calling thread with partial loads/stores.
  template <typename F>
  static void run_backward_parallel_pdist(Tensor& result, const Tensor & grad, const Tensor & self, const scalar_t p, const Tensor& dist) {
    const int64_t n = self.size(0);
    const int64_t m = self.size(1);
    const int64_t gs = grad.stride(0);

    const scalar_t * const grad_start = grad.data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.data_ptr<scalar_t>();
    const scalar_t * const self_start = self.data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    // The only way to parallelize and avoid locking requires parallelizing
    // over the columns of the input, i.e. we compute the gradient for the
    // first section of each vector independently of the second section, etc.
    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (8 * n * n), [p, n, m, gs, grad_start, dist_start, self_start, res_start](int64_t l, int64_t end) {
      const Vec pvec(p);

      const scalar_t * self_l = self_start + l * Vec::size();
      scalar_t * res_l = res_start + l * Vec::size();

      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; self_l += Vec::size(), res_l += Vec::size()) {
        backward_down_column_pdist<F>(self_l, res_l, grad_start, dist_start, pvec, n, m, gs);
      }
    });
    const int64_t remainder = m % Vec::size();
    if (remainder) {
      backward_down_column_pdist<F>(self_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), n, m, gs, remainder);
    }
  }

  // Assumes self is nonempty, contiguous, and 2D and dist is also contiguous
  // result is zero-filled first, so for p == 0 (gradient identically zero)
  // nothing else is done. Any p < 2 other than 1 uses lttdist_calc.
  static void apply_backward_pdist(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
    result.fill_(0);
    if (p == 0.0) {
    } else if (p == 1.0) {
      run_backward_parallel_pdist<odist_calc<Vec>>(result, grad, self, p, dist);
    } else if (p < 2.0) {
      run_backward_parallel_pdist<lttdist_calc>(result, grad, self, p, dist);
    } else if (p == 2.0) {
      run_backward_parallel_pdist<tdist_calc<Vec>>(result, grad, self, p, dist);
    } else if (std::isinf(p)) {
      run_backward_parallel_pdist<idist_calc<Vec>>(result, grad, self, p, dist);
    } else {
      run_backward_parallel_pdist<pdist_calc<Vec>>(result, grad, self, p, dist);
    }
  }

  // Backward cdist down one Vec column (gradient w.r.t. t1 only; res walks
  // in lockstep with t1). For each of the d batches and each row of t1 in
  // that batch, it sums F::backward(t1_row - t2_row) over all rows of t2 in
  // the same batch. grad_k/dist_k advance once per (t1 row, t2 row) pair,
  // continuously across batches. r1 and r2 are unused here.
  template <typename F>
  inline static void backward_down_column_cdist(const scalar_t * t1, const scalar_t * t2, scalar_t * res, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t r1, int64_t r2, int64_t m, int64_t d, int64_t gs, int64_t l1_size, int64_t l2_size, int64_t count = Vec::size()) {
    const scalar_t * t1_end = t1 + l1_size;
    const scalar_t * t2_end = t2 + l2_size;

    for (int64_t l = 0; l < d; l++) {
      for (; t1 != t1_end; t1 += m, res += m) {
        const Vec vec_t1 = Vec::loadu(t1, count);
        Vec res_vec = Vec::loadu(res, count);

        for (const scalar_t * t2_curr = t2; t2_curr != t2_end; t2_curr += m, grad_k += gs, dist_k += 1) {
          const Vec vec_t2 = Vec::loadu(t2_curr, count);
          // Named `contrib` (not `res`) so it does not shadow the output pointer.
          const Vec contrib = F::backward(vec_t1 - vec_t2, *grad_k, *dist_k, pvec);
          res_vec = res_vec + contrib;
        }

        res_vec.store(res, count);
      }
      // Move the row windows to the next batch; t1/res already sit at its start.
      t1_end += l1_size;
      t2_end += l2_size;
      t2 += l2_size;
    }
  }

  // Backward cdist: full Vec-width column chunks in parallel, then the
  // trailing m % Vec::size() columns on the calling thread.
  template <typename F>
  static void run_backward_parallel_cdist(Tensor& result, const Tensor & grad, const Tensor & t1, const Tensor & t2, const scalar_t p, const Tensor& dist) {
    const int64_t r1 = t1.size(-2);
    const int64_t r2 = t2.size(-2);
    const int64_t m = t1.size(-1);
    const int64_t d = result.size(0);
    const int64_t l1_size = r1 * m;
    const int64_t l2_size = r2 * m;
    //current implementation supports only tensor that can be collapsed to 1D. However, to avoid checking if grad satisfies this assumption,
    //we call .contiguous() on grad before backward, thus stride is guaranteed to be 1
    //don't use grad.stride(-1), because if last dimension is 1, stride can be bogus.
    const int64_t gs = 1;

    const scalar_t * const grad_start = grad.data_ptr<scalar_t>();
    const scalar_t * const dist_start = dist.data_ptr<scalar_t>();
    const scalar_t * const t1_start = t1.data_ptr<scalar_t>();
    const scalar_t * const t2_start = t2.data_ptr<scalar_t>();
    scalar_t * const res_start = result.data_ptr<scalar_t>();

    at::parallel_for(0, m / Vec::size(), internal::GRAIN_SIZE / (16 * r1), [=](int64_t l, int64_t end) {
      const Vec pvec(p);

      const scalar_t * i = t1_start + l * Vec::size();
      const scalar_t * j = t2_start + l * Vec::size();
      scalar_t * res_l = res_start + l * Vec::size();

      for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; i += Vec::size(), j += Vec::size(), res_l += Vec::size()) {
        backward_down_column_cdist<F>(i, j, res_l, grad_start, dist_start, pvec, r1, r2, m, d, gs, l1_size, l2_size);
      }
    });
    const int64_t remainder = m % Vec::size();
    if (remainder) {
      backward_down_column_cdist<F>(t1_start + (m - remainder), t2_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), r1, r2, m, d, gs, l1_size, l2_size, remainder);
    }
  }

  // result is zero-filled first, so for p == 0 (gradient identically zero)
  // nothing else is done. Any p < 2 other than 1 uses lttdist_calc.
  static void apply_backward_cdist(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
    result.fill_(0);
    if (p == 0.0) {
    } else if (p == 1.0) {
      run_backward_parallel_cdist<odist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else if (p < 2.0) {
      run_backward_parallel_cdist<lttdist_calc>(result, grad, x1, x2, p, dist);
    } else if (p == 2.0) {
      run_backward_parallel_cdist<tdist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else if (std::isinf(p)) {
      run_backward_parallel_cdist<idist_calc<Vec>>(result, grad, x1, x2, p, dist);
    } else {
      run_backward_parallel_cdist<pdist_calc<Vec>>(result, grad, x1, x2, p, dist);
    }
  }
};
// Dtype-dispatching entry point for the pdist forward pass; declared static
// like the other three kernel entry points in this file. The typed work is
// done by Dist<scalar_t>::apply_pdist.
static void pdist_forward_kernel_impl(Tensor& result, const Tensor& self, const double p) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist", [&] {
    Dist<scalar_t>::apply_pdist(result, self, p);
  });
}
// Dtype-dispatching entry point for the pdist backward pass. The dtype is
// taken from `self`; the typed work is done by
// Dist<scalar_t>::apply_backward_pdist.
static void pdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& self, const double p, const Tensor& dist) {
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "pdist_backward", [&]() {
    using Impl = Dist<scalar_t>;
    Impl::apply_backward_pdist(result, grad, self, p, dist);
  });
}
// Dtype-dispatching entry point for the cdist forward pass. The dtype is
// taken from the output tensor; the typed work is done by
// Dist<scalar_t>::apply_cdist.
static void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, const double p) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist", [&]() {
    using Impl = Dist<scalar_t>;
    Impl::apply_cdist(result, x1, x2, p);
  });
}
// Dtype-dispatching entry point for the cdist backward pass. The dtype is
// taken from the output tensor; the typed work is done by
// Dist<scalar_t>::apply_backward_cdist.
static void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor& x1, const Tensor& x2, const double p, const Tensor& dist) {
  AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_backward", [&]() {
    using Impl = Dist<scalar_t>;
    Impl::apply_backward_cdist(result, grad, x1, x2, p, dist);
  });
}
} // anonymous namespace
// Register the CPU kernel entry points defined above with their dispatch
// stubs. NOTE(review): the *_stub objects are presumably declared in
// ATen/native/Distance.h (included at the top of this file) -- confirm.
REGISTER_DISPATCH(pdist_forward_stub, &pdist_forward_kernel_impl);
REGISTER_DISPATCH(pdist_backward_stub, &pdist_backward_kernel_impl);
REGISTER_DISPATCH(cdist_stub, &cdist_kernel_impl);
REGISTER_DISPATCH(cdist_backward_stub, &cdist_backward_kernel_impl);
}} // namespace at::native