-
Notifications
You must be signed in to change notification settings - Fork 334
/
softmax.py
419 lines (346 loc) · 16.4 KB
/
softmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused scaled masked softmax functions"""
import functools
import os
from typing import Callable, Tuple, Union, Optional

import torch
from torch import nn
import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils

import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32
# Thread counts consumed by FusedScaleMaskSoftmax.get_batch_per_block when it
# estimates how many rows one CUDA block of the fused kernel handles
# (presumably mirroring the tex kernels' launch configuration).
THREADS_PER_WARP: int = 32
THREADS_PER_BLOCK: int = 128
# Lazily populated cache of boolean causal masks, keyed by (mask_type, sq, sk).
_default_causal_mask: dict = dict()
def _get_default_causal_mask(mask_type: str, sq: int, sk: int) -> torch.Tensor:
    """Return a cached boolean upper-triangular causal mask of shape (sq, sk).

    Entries on or above diagonal ``offset`` are True. ``offset`` is
    ``sk - sq + 1`` when ``mask_type`` contains "bottom_right" (mask aligned to
    the bottom-right corner), and 1 otherwise. Each mask is built once per
    ``(mask_type, sq, sk)`` and then served from ``_default_causal_mask``.
    """
    cache_key = (mask_type, sq, sk)
    cached = _default_causal_mask.get(cache_key)
    if cached is None:
        offset = sk - sq + 1 if "bottom_right" in mask_type else 1
        ones = torch.ones(sq, sk, dtype=torch.bool, device="cuda")
        cached = torch.triu(ones, diagonal=offset)
        # NOTE(review): the cache key omits the device. "cuda" resolves to the
        # current device on first use, so a process driving several GPUs would
        # reuse a mask living on another device -- confirm callers stay on one GPU.
        _default_causal_mask[cache_key] = cached
    return cached
def _get_onnx_export_causal_mask(
    seq_q: int, seq_k: int, onnx_causal_mask: torch.Tensor
) -> torch.Tensor:
    """Return the causal upper triangular mask for softmax input, for ONNX export.

    ONNX does not support dynamic control-flow and requires non-square masks when
    using a KV-cache (seq_k's length is len(context)+len(generative) while seq_q's
    length is 1).

    Argument `onnx_causal_mask` is a square triu (k=1) mask that is sliced to the
    correct shape for GPT context and generation phases. In the context phase the
    derived mask is a square triu of shape (seq_k, seq_k), and in the generation
    phase the mask is rectangular with shape (1, seq_k).

    Returns:
        ``onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k]``, of shape (seq_q, seq_k).

    Raises:
        AssertionError: if the mask is not a square 2D tensor, if ``seq_q`` is not
            in ``[0, seq_k]``, or if the mask is smaller than ``seq_k``.
    """
    assert onnx_causal_mask.dim() == 2
    assert onnx_causal_mask.size(0) == onnx_causal_mask.size(1)
    assert 0 <= seq_q <= seq_k
    # The slice below reads seq_k rows and columns. Checking only seq_k - seq_q
    # let an undersized mask be silently truncated to a wrong-shaped result.
    assert onnx_causal_mask.size(0) >= seq_k
    return onnx_causal_mask[seq_k - seq_q : seq_k, :seq_k]
def fp32_compute(onnx_symbolic_fn):
    """A decorator that wraps an ONNX symbolic function with FP32 compute operators.

    The returned wrapper hands ``onnx_symbolic_fn`` to ``compute_in_fp32``
    together with the graph, the input value and every remaining argument.
    The third positional parameter is named ``scale`` but is forwarded
    unchanged, so it may be any symbolic argument (e.g. a mask).
    ``functools.wraps`` keeps the wrapped function's name and docstring.
    """

    @functools.wraps(onnx_symbolic_fn)
    def wrapper(g: torch.Graph, inp: torch._C.Value, scale: float, *args, **kwargs):
        return compute_in_fp32(g, inp, onnx_symbolic_fn, scale, *args, **kwargs)

    return wrapper
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused scale -> upper-triangular (causal) mask -> softmax, as typically used
    in GPT-style self-attention. Forward and backward dispatch to the ``tex``
    extension kernels; ``symbolic`` describes the same math for ONNX export.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledUpperTriangMaskedSoftmax fwd"""
        # The float scale is wrapped in a 1-element tensor so it can be stored via
        # save_for_backward; the kernel receives its single element.
        scale_holder = torch.tensor([scale])
        probs = tex.scaled_upper_triang_masked_softmax_forward(inputs, scale_holder[0])
        ctx.save_for_backward(probs, scale_holder)
        return probs

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledUpperTriangMaskedSoftmax bwd"""
        probs, scale_holder = ctx.saved_tensors
        grad_inputs = tex.scaled_upper_triang_masked_softmax_backward(
            output_grads, probs, scale_holder[0]
        )
        # One gradient per forward input (inputs, scale); scale receives None.
        return grad_inputs, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledUpperTriangMaskedSoftmax symbolic method"""
        # Captures the logic of scaled_upper_triang_masked_softmax_warp_forward:
        #   causal = bool(triu(ones_like(inputs), k=1))
        #   out    = softmax((1 - causal) * (inputs * scale) + causal * -10000)
        # Ops are emitted in a fixed order so the exported graph is stable.
        all_ones = torch.onnx.symbolic_opset9.ones_like(
            g, inputs, _type_utils.JitScalarType.INT64
        )
        diag = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        upper = g.op("Trilu", all_ones, diag, upper_i=1)
        causal = g.op("Cast", upper, to_i=_C_onnx.TensorProtoDataType.BOOL)

        int_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        keep = g.op("Sub", int_one, causal)
        big_neg = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16))
        penalty = g.op("Mul", causal, big_neg)
        scale_const = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_const)
        kept = g.op("Mul", keep, scaled)
        logits = g.op("Add", kept, penalty)
        return g.op("Softmax", logits)
class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply causal mask aligned to the bottom right corner of the input matrix
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledAlignedCausalMaskedSoftmax fwd"""
        # scale is held in a 1-element tensor so it can go through save_for_backward.
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledAlignedCausalMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )
        # One gradient per forward input (inputs, scale); scale receives None.
        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledAlignedCausalMaskedSoftmax symbolic method

        Emits softmax((1 - mask) * (inputs * scale) + mask * -10000), where mask is
        True on and above diagonal k = sk - sq + 1 of the trailing (sq, sk) dims,
        i.e. a causal mask aligned to the bottom-right corner.
        """

        def triangular_mask():
            dtype = _type_utils.JitScalarType.INT64
            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
            # Bottom-right alignment needs diagonal offset k = sk - sq + 1 (the same
            # offset _get_default_causal_mask uses for "bottom_right"). In ONNX Trilu
            # the offset is the ``k`` input, while ``upper`` is only a 0/1 flag that
            # selects the upper triangle. The offset is therefore computed from the
            # runtime shape and fed to ``k`` instead of being passed as ``upper``.
            shape = g.op("Shape", inputs)
            q_axis = g.op("Constant", value_t=torch.tensor(-2, dtype=torch.int64))
            k_axis = g.op("Constant", value_t=torch.tensor(-1, dtype=torch.int64))
            seq_q = g.op("Gather", shape, q_axis, axis_i=0)
            seq_k = g.op("Gather", shape, k_axis, axis_i=0)
            one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
            diag_shift = g.op("Add", g.op("Sub", seq_k, seq_q), one)
            mask = g.op("Trilu", ones, diag_shift, upper_i=1)
            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
            return mask

        # Captures the logic of function scaled_aligned_masked_softmax_warp_forward
        mask = triangular_mask()
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        inv_mask = g.op("Sub", one, mask)
        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)
        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        out = g.op("Softmax", masked)
        return out
class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused scale -> explicit mask -> softmax. Forward and backward dispatch to the
    ``tex`` extension kernels; ``symbolic`` shows that positions where ``mask`` is
    set are pushed to -10000 before the softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledMaskedSoftmax fwd"""
        # Wrap the float scale so it can be stored alongside the results.
        scale_holder = torch.tensor([scale])
        probs = tex.scaled_masked_softmax_forward(inputs, mask, scale_holder[0])
        ctx.save_for_backward(probs, scale_holder)
        return probs

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledMaskedSoftmax bwd"""
        probs, scale_holder = ctx.saved_tensors
        grad_inputs = tex.scaled_masked_softmax_backward(output_grads, probs, scale_holder[0])
        # Gradients for (inputs, mask, scale); only inputs receives one.
        return grad_inputs, None, None

    @staticmethod
    @fp32_compute
    def symbolic(
        g: torch.Graph, inputs: torch._C.Value, mask: torch._C.Value, scale: float
    ) -> torch._C.Value:
        """ScaledMaskedSoftmax symbolic method"""
        # Captures the logic of scaled_masked_softmax_warp_forward:
        #   output = softmax(mask(input * scale))
        # computed as
        #   softmax((1 - mask) * (input * scale) + mask * -10000)
        # Ops are emitted in a fixed order so the exported graph is stable.
        scale_const = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_const)
        int_one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        keep = g.op("Sub", int_one, mask)
        # The constant's dtype is fixed at FP16 because the softmax input is FP16 or BF16.
        big_neg = g.op("Constant", value_t=torch.tensor(-10000.0, dtype=torch.float16))
        penalty = g.op("Mul", mask, big_neg)
        kept = g.op("Mul", keep, scaled)
        logits = g.op("Add", kept, penalty)
        return g.op("Softmax", logits)
class ScaledSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following two operations in sequence
    1. Scale the tensor.
    2. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledSoftmax fwd"""
        # scale is held in a 1-element tensor so it can go through save_for_backward.
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledSoftmax bwd

        Returns one gradient per forward input: (grad for ``inputs``, None for ``scale``).
        """
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_softmax_backward(output_grads, softmax_results, scale_t[0])
        # forward() takes exactly two inputs (inputs, scale), so exactly two
        # gradients are returned; the earlier third None had no matching input.
        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledSoftmax symbolic method"""
        # softmax(inputs * scale); the scale constant is FP16, matching the sibling ops.
        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        out = g.op("Softmax", scaled)
        return out
class FusedScaleMaskSoftmax(nn.Module):
    """
    Scale + mask + softmax over 4D attention scores, using a fused ``tex`` kernel
    when the input qualifies and a plain PyTorch implementation otherwise.

    Arguments:
        mask_func: callable invoked as ``mask_func(scores, mask)`` on the unfused
            path to apply ``mask`` to the scores.
        softmax_in_fp32: if true, the unfused path upcasts fp16/bf16 input to fp32
            for the softmax and casts the probabilities back afterwards.
    """

    def __init__(
        self,
        mask_func: Callable,
        softmax_in_fp32: bool = True,
    ) -> None:
        super().__init__()
        # NVTE_MASKED_SOFTMAX_FUSION=0 turns the fused kernels off entirely.
        fusion_env = os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1")
        self.scaled_masked_softmax_fusion = bool(int(fusion_env))
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32

        # Users exporting to ONNX can optimize the attention mask for GPT text
        # generation: a square triu(k=1) mask of this size is built once and
        # sliced per step by _get_onnx_export_causal_mask.
        self.kvcache_max_seq = int(os.getenv("NVTE_ONNX_KVCACHE_MAX_SEQ_LEN", "-1"))
        if self.kvcache_max_seq > 0:
            side = self.kvcache_max_seq
            square_ones = torch.ones(side, side, device="cuda")
            self.register_buffer(
                "onnx_causal_mask",
                torch.triu(square_ones, diagonal=1).bool(),
                persistent=False,
            )

    def forward(
        self,
        inp: torch.Tensor,
        mask: torch.Tensor,
        attn_mask_type: str,
        scale: Optional[float] = None,
    ) -> torch.Tensor:
        """FusedScaleMaskSoftmax fprop

        ``inp`` must be 4D, laid out as [b, np, sq, sk]. The dtype flags and
        ``attn_mask_type`` are stored on ``self`` because ``is_kernel_available``
        and both forward paths read them. The fused path is never taken while
        exporting to ONNX.
        """
        assert inp.dim() == 4
        dtype = inp.dtype
        self.input_in_fp16 = dtype == torch.float16
        self.input_in_bf16 = dtype == torch.bfloat16
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type

        assert scale is None or self.softmax_in_fp32, "softmax should be in fp32 when scaled"

        fused_ok = self.is_kernel_available(mask, *inp.size()) and not is_in_onnx_export_mode()
        if not fused_ok:
            return self.forward_torch_softmax(inp, mask, scale)
        return self.forward_fused_softmax(inp, mask, scale)

    def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk: int) -> bool:
        """Check FusedScaleMaskSoftmax kernel availability based on size

        Each early ``False`` below names one constraint of the fused kernels.
        For padding/arbitrary mask types an explicit mask of shape
        ``(1 or b, 1, sq, sk)`` is also required.
        """
        batch_heads = b * np
        if not self.scaled_masked_softmax_fusion:
            return False  # fusion disabled by the user
        if not self.input_in_float16:
            return False  # input must be fp16 or bf16
        if sk <= 16 or sk >= 16384:
            return False  # requires 16 < sk < 16384 (both bounds exclusive)
        if sk % 8 != 0:
            return False  # sk must be a multiple of 8
        if sq == 1:
            return False  # sq must be > 1
        if self.attn_mask_type == "causal" and sq != sk:
            return False  # fused causal kernel only covers the bottom-right-aligned case
        if sq % 4 != 0 or batch_heads % 4 != 0:
            return False  # sq and b * np must both be multiples of 4

        batch_per_block = self.get_batch_per_block(int(sk))
        needs_explicit_mask = (
            "padding" in self.attn_mask_type or self.attn_mask_type == "arbitrary"
        )
        if needs_explicit_mask:
            if mask is None or sq % batch_per_block != 0:
                return False
            return mask.shape[0] in [1, b] and mask.shape[1:] == (1, sq, sk)
        return sq % batch_per_block == 0

    def forward_fused_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """
        Fused masked softmax path.

        attn_mask_type                                       | module
        -----------------------------------------------------------------------------------------
        no_mask                                              | ScaledSoftmax
        causal (self-attention), causal_bottom_right         | ScaledAlignedCausalMaskedSoftmax
        padding, padding_causal, padding_causal_bottom_right | ScaledMaskedSoftmax
        arbitrary ([1, 1, sq, sk] or [b, 1, sq, sk])         | ScaledMaskedSoftmax

        A missing ``scale`` is treated as 1.0.
        """
        effective_scale = 1.0 if scale is None else scale
        mask_type = self.attn_mask_type
        if mask_type in ("causal", "causal_bottom_right"):
            return ScaledAlignedCausalMaskedSoftmax.apply(inp, effective_scale)
        # For padding/arbitrary types, is_kernel_available has checked that the mask is
        # 4D with shape (1, 1, sq, sk) or (b, 1, sq, sk).
        if mask is None or mask_type == "no_mask":
            return ScaledSoftmax.apply(inp, effective_scale)
        return ScaledMaskedSoftmax.apply(inp, mask, effective_scale)

    def forward_torch_softmax(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        """Framework softmax

        Optionally upcasts to fp32, scales, merges in a causal mask for the causal
        mask types, applies ``mask_func``, runs softmax over the last dim, and casts
        back to the input's half-precision dtype when it upcast.
        """
        upcast = self.input_in_float16 and self.softmax_in_fp32
        if upcast:
            inp = inp.float()
        if scale is not None:
            inp = inp * scale

        if self.attn_mask_type in ("causal", "causal_bottom_right"):
            sq, sk = inp.size(2), inp.size(3)
            if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
                assert self.kvcache_max_seq >= sk
                causal_mask = _get_onnx_export_causal_mask(sq, sk, self.onnx_causal_mask)
            else:
                causal_mask = _get_default_causal_mask(self.attn_mask_type, sq, sk)
            mask = causal_mask if mask is None else torch.logical_or(mask, causal_mask)

        scores = inp
        if mask is not None and self.attn_mask_type != "no_mask":
            scores = self.mask_func(inp, mask)
        probs = torch.nn.Softmax(dim=-1)(scores)

        if upcast:
            probs = probs.half() if self.input_in_fp16 else probs.bfloat16()
        return probs

    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int:
        """Softmax utility

        Rounds ``key_seq_len`` up to a power of two, caps the warp width at
        THREADS_PER_WARP, and returns warps-per-block times rows-per-warp
        (2 rows per warp when the padded length is at most 128, else 1).
        """
        padded = 1 << (key_seq_len - 1).bit_length()
        warp_width = min(padded, THREADS_PER_WARP)
        rows_per_warp = 2 if padded <= 128 else 1
        return (THREADS_PER_BLOCK // warp_width) * rows_per_warp