Skip to content

Commit

Permalink
[SW-187215] Add valid_seq_len feature to patched SDPA module
Browse files Browse the repository at this point in the history
Change-Id: Ia627fe8134470d68a7e55fc978a972bb7f7b3d5b
  • Loading branch information
wszczurekhabana authored and Eran Geva committed Jul 25, 2024
1 parent 039af39 commit 3f61954
Showing 1 changed file with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,9 @@ def forward(
is_causal=False,
scale=None,
softmax_mode="None",
recompute=None,
valid_seq_len=None,
seq_padding_type="None",
):
qinput = self.quant_q(q).detach()
kinput = self.quant_k(k).detach()
Expand All @@ -746,6 +749,8 @@ def forward(
q_scale_o=self.scale_output,
d_scale_s=self.descale_amax,
is_amax_s=False,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type
)
output = results[0]
d_out = self.dequant_output(output)
Expand All @@ -761,6 +766,9 @@ def forward_measure(
is_causal=False,
scale=None,
softmax_mode="fast",
recompute=None,
valid_seq_len=None,
seq_padding_type="None",
):
dq = q.detach()
dk = k.detach()
Expand All @@ -777,6 +785,8 @@ def forward_measure(
# fp8_fused_sdpa in bf16 can use either FastSoftmax or regular
softmax_mode="fast",
is_amax_s=True,
valid_seq_len=valid_seq_len,
seq_padding_type=seq_padding_type
)
output = results[0]
amax = results[1]
Expand Down

0 comments on commit 3f61954

Please sign in to comment.