Skip to content

Commit

Permalink
[CrossEntropy] Fix where labels address not aligned to 16 bytes
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Oct 5, 2024
1 parent 53a4f34 commit bedf877
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flash_attn/ops/triton/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Tuple, Optional, Union

import torch
import torch.nn.functional as F

import triton
import triton.language as tl
Expand Down Expand Up @@ -160,6 +161,11 @@ def forward(
inplace_backward=False,
process_group=None,
):
# For some reason Triton generates wrong code when labels has dtype long and its address
# is not aligned to 16 bytes. The ld.global.b64 seems to load the wrong label index.
if labels.dtype == torch.long and labels.data_ptr() % 16 != 0:
labels = F.pad(labels, (0, 1))[..., :-1]
assert labels.data_ptr() % 16 == 0
n_rows, n_cols = logits.shape
assert labels.shape == (n_rows,)
world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
Expand Down

0 comments on commit bedf877

Please sign in to comment.