diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index b640ea6b..df218a73 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -170,7 +170,11 @@ def _chunk_scan_chunk_state_bwd_dx_kernel( dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) - mask = k + offs_k[None, :] >= offs_m[:, None] + # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range + # the masked load above gives dA_cs_k = 0.0; if dA_cs_m is very negative, tl.exp(dA_cs_k - dA_cs_m) returns inf. + # Multiplying with cb, which is 0.0 outside the range, will make the result NaN (0.0 * inf). + # This will cause NaN in acc, and hence NaN in dx and ddt. + mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) cb = tl.where(mask, cb, 0.0) cb = cb.to(dout_ptr.dtype.element_ty) acc += tl.dot(cb, dout)