forked from torch/cunn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LogSoftMax.cu
314 lines (252 loc) · 8.22 KB
/
LogSoftMax.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
#include "utils.h"
// Binary "max" reduction functor used to find the per-row maximum.
// fmaxf returns the other operand when exactly one argument is NaN.
struct MaxFloat {
  __device__ __forceinline__ float operator()(float acc, float x) const {
    return fmaxf(acc, x);
  }
};
// Binary addition reduction functor (plain running sum).
struct SumFloat {
  __device__ __forceinline__ float operator()(float acc, float x) const {
    const float total = acc + x;
    return total;
  }
};
// Reduction functor that accumulates exp(x - max_k).
// Shifting by max_k avoids overflow in expf when max_k is the row maximum,
// which is how the forward kernel constructs it.
struct SumExpFloat {
  __device__ __forceinline__ SumExpFloat(float rowMax) : max_k(rowMax) {}
  __device__ __forceinline__ float operator()(float acc, float x) const {
    const float shifted = x - max_k;
    return acc + expf(shifted);
  }
  const float max_k;  // value subtracted from every element before expf
};
// Identity "finalize" step for blockReduce: the reduced value is returned as-is.
struct NoFinal {
  __device__ __forceinline__ float operator()(float reduced) const {
    const float result = reduced;
    return result;
  }
};
// Finalize step for the log-sum-exp reduction: turns sum(exp(x - max_k))
// into max_k + log(sum(exp(x - max_k))).
struct LSMFinal {
  __device__ __forceinline__ LSMFinal(float rowMax) : max_k(rowMax) {}
  __device__ __forceinline__ float operator()(float expSum) const {
    const float logOfSum = logf(expSum);
    return max_k + logOfSum;
  }
  const float max_k;  // shift that was subtracted before exponentiating
};
// Block-wide reduction of one float per thread, broadcast to all threads.
//
// smem       : shared scratch with at least blockDim.x floats.
// val        : this thread's partial value.
// r          : binary reduction functor (acc, x) -> acc.
// defaultVal : identity element of r (e.g. -FLT_MAX for max, 0 for sum).
// f          : finalize functor applied once to the block-wide result.
// Returns f(reduction of all threads' val) in every thread.
//
// Must be called by every thread of the block (it contains __syncthreads()).
// Works for any blockDim.x <= 1024; a partial last warp is included.
template <typename Reduction, typename Finalize>
__device__ __forceinline__ float
blockReduce(float* smem, float val,
            const Reduction& r,
            float defaultVal,
            const Finalize& f) {
  // To avoid RaW races from chaining blockReduce calls together (the previous
  // call's threads may still be reading smem[0]), we need a sync here.
  __syncthreads();
  smem[threadIdx.x] = val;
  __syncthreads();

  const unsigned int numThreads = blockDim.x;
  const unsigned int numWarps = (numThreads + 31) / 32;

  // Stage 1: thread w (w < numWarps, all inside warp 0) reduces the 32-entry
  // slice of smem belonging to warp w. The slice is clipped at blockDim.x so
  // a partial final warp is reduced instead of dropped.
  float warpVal = defaultVal;
  if (threadIdx.x < numWarps) {
    const unsigned int base = threadIdx.x * 32;
#pragma unroll
    for (int i = 0; i < 32; ++i) {
      if (base + i < numThreads) {
        warpVal = r(warpVal, smem[base + i]);
      }
    }
  }
  // The per-warp results go to smem[0 .. numWarps), which overlaps the slice
  // thread 0 reads above. Barrier before writing rather than relying on
  // implicit warp-synchronous execution, which independent thread scheduling
  // (Volta+) does not guarantee.
  __syncthreads();
  if (threadIdx.x < numWarps) {
    smem[threadIdx.x] = warpVal;
  }
  __syncthreads();

  // Stage 2: first thread reduces the per-warp results and finalizes.
  if (threadIdx.x == 0) {
    float blockVal = defaultVal;
    for (unsigned int i = 0; i < numWarps; ++i) {
      blockVal = r(blockVal, smem[i]);
    }
    smem[0] = f(blockVal);
  }
  // Sync and broadcast
  __syncthreads();
  return smem[0];
}
// Overload of blockReduce without a finalize step: the block-wide result of
// r over every thread's val is broadcast unchanged (NoFinal is the identity).
// Same requirements as the five-argument version (all threads must call it,
// smem holds at least blockDim.x floats).
template <typename Reduction>
__device__ __forceinline__ float
blockReduce(float* smem, float val,
            const Reduction& r,
            float defaultVal) {
  const NoFinal identity = NoFinal();
  return blockReduce(smem, val, r, defaultVal, identity);
}
// Per-thread partial reduction over data[0 .. size).
// Thread t visits indices t, t + blockDim.x, t + 2*blockDim.x, ...
// The main loop issues ILP independent loads before combining them, to
// expose instruction-level parallelism. The leftover tail (size modulo
// ILP*blockDim.x) is handled one element per step.
// Returns this thread's partial result, starting from defaultVal.
template <typename Reduction, int ILP>
__device__ __forceinline__ float
ilpReduce(float* data,
          int size,
          const Reduction& r,
          float defaultVal) {
  const unsigned int stride = blockDim.x;
  const unsigned int chunk = ILP * stride;
  // Elements [0, bodyEnd) are covered by whole ILP-wide chunks.
  const int bodyEnd = size - (int)(size % chunk);

  float acc = defaultVal;
  int idx = threadIdx.x;

  // Body: ILP loads first, then ILP reduction steps.
  while (idx < bodyEnd) {
    float loaded[ILP];
#pragma unroll
    for (int k = 0; k < ILP; ++k) {
      loaded[k] = data[idx + k * stride];
    }
#pragma unroll
    for (int k = 0; k < ILP; ++k) {
      acc = r(acc, loaded[k]);
    }
    idx += chunk;
  }

  // Tail: remaining elements, one per step.
  while (idx < size) {
    acc = r(acc, data[idx]);
    idx += stride;
  }
  return acc;
}
// Forward log-softmax over the rows of a contiguous (rows x classes) buffer.
//   output[row][c] = input[row][c] - (max_k + log(sum_j exp(input[row][j] - max_k)))
// Launch: one block per row (gridDim.x == number of rows).
// Dynamic shared memory: blockDim.x floats (scratch for blockReduce).
// ILP: number of independent loads each thread issues per unrolled step.
template <int ILP>
__global__ void
cunn_LogSoftMax_updateOutput_kernel(float *output, float *input, int classes) {
  extern __shared__ float buffer[];
  // Compute the row offset in 64 bits: blockIdx.x * classes is otherwise
  // evaluated in 32-bit unsigned arithmetic and wraps once rows * classes
  // reaches 2^32 elements.
  const long long rowOffset = static_cast<long long>(blockIdx.x) * classes;
  input += rowOffset;
  output += rowOffset;

  // Pass 1: row maximum, used to shift the exponentials for stability.
  float threadMax =
    ilpReduce<MaxFloat, ILP>(input, classes, MaxFloat(), -FLT_MAX);
  float max_k =
    blockReduce<MaxFloat>(buffer, threadMax, MaxFloat(), -FLT_MAX);

  // Pass 2: logsum_k = max_k + log(sum(exp(x - max_k))).
  float threadExp =
    ilpReduce<SumExpFloat, ILP>(input, classes, SumExpFloat(max_k), 0.0f);
  float logsum_k =
    blockReduce<SumFloat, LSMFinal>(
      buffer, threadExp, SumFloat(), 0.0f, LSMFinal(max_k));

  // Pass 3: output LSM (hand ILP), then the tail one element at a time.
  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);
  for ( ; offset < classes - last; offset += blockDim.x * ILP) {
    float tmp[ILP];
#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmp[j] = input[offset + j * blockDim.x];
    }
#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      output[offset + j * blockDim.x] = tmp[j] - logsum_k;
    }
  }
  for (; offset < classes; offset += blockDim.x) {
    output[offset] = input[offset] - logsum_k;
  }
}
// Backward log-softmax over the rows of contiguous (rows x classes) buffers.
//   gradInput[row][c] = gradOutput[row][c] - exp(output[row][c]) * sum_j gradOutput[row][j]
// where output holds the forward log-probabilities.
// Launch: one block per row (gridDim.x == number of rows).
// Dynamic shared memory: blockDim.x floats (scratch for blockReduce).
// ILP: number of independent loads each thread issues per unrolled step.
template <int ILP>
__global__ void
cunn_LogSoftMax_updateGradInput_kernel(float *gradInput,
                                       float *output,
                                       float *gradOutput,
                                       int classes) {
  extern __shared__ float buffer[];
  // Row offset in 64 bits to avoid 32-bit wraparound for very large tensors.
  const long long rowOffset = static_cast<long long>(blockIdx.x) * classes;
  gradInput += rowOffset;
  output += rowOffset;
  gradOutput += rowOffset;

  // Row sum of gradOutput; use the kernel's ILP like every other loop here.
  float threadSum =
    ilpReduce<SumFloat, ILP>(gradOutput, classes, SumFloat(), 0.0f);
  float sum_k =
    blockReduce<SumFloat>(buffer, threadSum, SumFloat(), 0.0f);

  // Update gradInput (hand ILP). expf is used rather than the __expf fast
  // intrinsic, whose error grows with |x|, to match the forward kernel's
  // accuracy.
  int offset = threadIdx.x;
  int last = classes % (ILP * blockDim.x);
  for ( ; offset < classes - last; offset += blockDim.x * ILP) {
    float tmpGradOutput[ILP];
    float tmpOutput[ILP];
#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      tmpGradOutput[j] = gradOutput[offset + j * blockDim.x];
      tmpOutput[j] = output[offset + j * blockDim.x];
    }
#pragma unroll
    for (int j = 0; j < ILP; ++j) {
      gradInput[offset + j * blockDim.x] =
        tmpGradOutput[j] - expf(tmpOutput[j]) * sum_k;
    }
  }
  for (; offset < classes; offset += blockDim.x) {
    gradInput[offset] =
      gradOutput[offset] - expf(output[offset]) * sum_k;
  }
}
// Lua binding: nn.LogSoftMax forward on a CudaTensor.
// Stack: arg 1 = module (its "output" field is written), arg 2 = input
// (1-D vector or 2-D batch x classes matrix). Returns 1 Lua value.
// THError longjmps out of this function, so every shape check runs before
// any temporary tensor is created, and temporaries are freed before the
// CUDA error check.
static int cunn_LogSoftMax_updateOutput(lua_State *L) {
  THCState *state = getCutorchState(L);
  THCudaTensor *input =
    (THCudaTensor*)luaT_checkudata(L, 2, "torch.CudaTensor");
  THCudaTensor *output =
    (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 2, input, output));

  // Validate the shape first; nothing has been allocated yet, so an error
  // here cannot leak.
  int batchSize = 1;
  int classSize = 0;
  const int nDim = THCudaTensor_nDimension(state, input);
  if (nDim == 1) {
    classSize = THCudaTensor_size(state, input, 0);
  } else if (nDim == 2) {
    batchSize = THCudaTensor_size(state, input, 0);
    classSize = THCudaTensor_size(state, input, 1);
  } else {
    THError("vector or matrix expected");
  }

  input = THCudaTensor_newContiguous(state, input);
  THCudaTensor_resizeAs(state, output, input);

  // One 1024-thread block per row; blockDim.x floats of dynamic shared memory.
  dim3 grid(batchSize);
  dim3 block(1024);
  cunn_LogSoftMax_updateOutput_kernel<2>
    <<<grid, block, block.x * sizeof(float), THCState_getCurrentStream(state)>>>(
      THCudaTensor_data(state, output),
      THCudaTensor_data(state, input),
      classSize);

  // Release the contiguous temporary before a THError could longjmp past it.
  THCudaTensor_free(state, input);

  cudaError errcode = cudaGetLastError();
  if (errcode != cudaSuccess) {
    THError(cudaGetErrorString(errcode));
  }
  return 1;
}
// Lua binding: nn.LogSoftMax backward on CudaTensors.
// Stack: arg 1 = module (reads its "output" field, writes "gradInput"),
// arg 3 = gradOutput. Returns 1 Lua value.
// THError/THArgCheck longjmp out of this function, so every validation runs
// before temporaries are created, and temporaries are freed before the
// CUDA error check.
static int cunn_LogSoftMax_updateGradInput(lua_State *L) {
  THCState *state = getCutorchState(L);
  THCudaTensor *gradOutput =
    (THCudaTensor*)luaT_checkudata(L, 3, "torch.CudaTensor");
  THCudaTensor *output =
    (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "output", "torch.CudaTensor");
  THCudaTensor *gradInput =
    (THCudaTensor*)luaT_getfieldcheckudata(L, 1, "gradInput", "torch.CudaTensor");
  THAssert(THCudaTensor_checkGPU(state, 3, output, gradOutput, gradInput));

  // The kernel indexes gradOutput with the same offsets as output, so a
  // gradOutput with fewer elements would be read out of bounds.
  THArgCheck(THCudaTensor_nElement(state, gradOutput) ==
             THCudaTensor_nElement(state, output), 3,
             "gradOutput and output must have the same number of elements");

  // gradInput is resized to output's shape below, so validate output's shape.
  int batchSize = 1;
  int classSize = 0;
  const int nDim = THCudaTensor_nDimension(state, output);
  if (nDim == 1) {
    classSize = THCudaTensor_size(state, output, 0);
  } else if (nDim == 2) {
    batchSize = THCudaTensor_size(state, output, 0);
    classSize = THCudaTensor_size(state, output, 1);
  } else {
    THError("vector or matrix expected");
  }

  output = THCudaTensor_newContiguous(state, output);
  gradOutput = THCudaTensor_newContiguous(state, gradOutput);
  THCudaTensor_resizeAs(state, gradInput, output);

  // One 1024-thread block per row; blockDim.x floats of dynamic shared memory.
  dim3 grid(batchSize);
  dim3 block(1024);
  cunn_LogSoftMax_updateGradInput_kernel<2>
    <<<grid, block, block.x * sizeof(float), THCState_getCurrentStream(state)>>>(
      THCudaTensor_data(state, gradInput),
      THCudaTensor_data(state, output),
      THCudaTensor_data(state, gradOutput),
      classSize);

  // Release the contiguous temporaries before a THError could longjmp past them.
  THCudaTensor_free(state, gradOutput);
  THCudaTensor_free(state, output);

  cudaError errcode = cudaGetLastError();
  if (errcode != cudaSuccess) {
    THError(cudaGetErrorString(errcode));
  }
  return 1;
}
// Lua name -> C function table for this module. cunn_LogSoftMax_init
// registers it under the name "nn" on the torch.CudaTensor metatable.
static const struct luaL_Reg cunn_LogSoftMax__ [] = {
  { "LogSoftMax_updateOutput",    cunn_LogSoftMax_updateOutput    },
  { "LogSoftMax_updateGradInput", cunn_LogSoftMax_updateGradInput },
  { NULL,                         NULL                            }  /* end marker */
};
// Module initializer: registers the cunn_LogSoftMax__ functions under "nn"
// on the torch.CudaTensor metatable.
void cunn_LogSoftMax_init(lua_State *L) {
  luaT_pushmetatable(L, "torch.CudaTensor");        // push the metatable
  luaT_registeratname(L, cunn_LogSoftMax__, "nn");
  lua_pop(L, 1);                                    // pop the metatable again
}