Fix mask in _chunk_scan_chunk_state_bwd_dx that could cause NaN
Only affects cases where the sequence length is not a multiple of 256.
tridao committed Jun 12, 2024
1 parent 8f42a5e commit bb3a82a
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion mamba_ssm/ops/triton/ssd_combined.py
@@ -170,7 +170,11 @@ def _chunk_scan_chunk_state_bwd_dx_kernel(
 dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0)
 dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32)
 cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None])
-mask = k + offs_k[None, :] >= offs_m[:, None]
+# If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range,
+# we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf.
+# Multiplying with cb, which is 0.0 outside the range, will make the result NaN.
+# This will cause NaN in acc, and hence NaN in dx and ddt.
+mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX)
 cb = tl.where(mask, cb, 0.0)
 cb = cb.to(dout_ptr.dtype.element_ty)
 acc += tl.dot(cb, dout)
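
The failure mode described in the added comment can be reproduced outside Triton. The following is a minimal pure-Python sketch, not the kernel itself: `saturating_exp`, `scaled_cb`, the tile sizes, and all numeric values are hypothetical and chosen only so that the exponent for the padded column is large and positive. It shows that `0.0 * inf` yields NaN when only the causal mask is applied, and that the extra `< K_MAX` range mask zeroes the padded column.

```python
import math


def saturating_exp(x):
    """Hypothetical helper: exp that overflows to +inf, as floating-point exp does."""
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf


K_MAX = 3          # assumed number of valid columns
BLOCK = 4          # assumed tile width, one column past K_MAX
offs_m = 0         # a single row index, for simplicity
dA_cs_m = -1000.0  # assumed cumulative value for that row

# The last entry is outside the valid range and padded with 0.0,
# mirroring tl.load(..., other=0.0).
dA_cs_k = [-1000.5, -1001.0, -1001.5, 0.0]
cb = [0.5, 0.25, 0.125, 0.0]  # cb is 0.0 outside the range


def scaled_cb(use_range_mask):
    """Scale cb by exp(dA_cs_k - dA_cs_m), then apply the mask."""
    out = []
    for j in range(BLOCK):
        value = cb[j] * saturating_exp(dA_cs_k[j] - dA_cs_m)
        mask = j >= offs_m                   # mask before the fix
        if use_range_mask:
            mask = mask and (j < K_MAX)      # extra term added by the fix
        out.append(value if mask else 0.0)
    return out


old = scaled_cb(use_range_mask=False)
new = scaled_cb(use_range_mask=True)
print("causal mask only:", old)
print("with range mask: ", new)
```

Because the NaN sits in a column that the first mask keeps, any later sum or dot product over that tile would propagate it. The range mask replaces the padded column with 0.0 and leaves the in-range columns unchanged.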