fix bug in fused softmax
hyunwoongko committed Feb 23, 2022
1 parent 0885768 commit 80dcfb0
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions megatron/model/fused_softmax.py
@@ -150,8 +150,8 @@ def is_kernel_available(self, mask, b, np, sq, sk):
             self.fusion  # user wants to fuse
             and self.input_in_float16  # input must be fp16
             and mask is not None  # mask tensor must not be None
-            and 16 < sq <= 2048  # sq must be 16 ~ 2048
-            and sk % 4 == 0  # sk must be divisor of 4
+            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and sq % 4 == 0  # sq must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
             if 0 <= sk <= 2048:
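In short, the availability check had the two sequence lengths swapped: it is the key length sk that must fall in (16, 2048] and the query length sq that must be divisible by 4 (along with np * b). Below is a minimal standalone sketch of the corrected predicate, not the class method itself; the helper name and the plain fusion/input_in_float16 flags are hypothetical stand-ins for the instance attributes.

# Sketch (not from the commit): the corrected availability check as a
# standalone function, with self.fusion / self.input_in_float16 as plain flags.
def fused_kernel_ok(mask, b, np, sq, sk, fusion=True, input_in_float16=True):
    attn_batches = b * np
    return (
        fusion                         # user wants to fuse
        and input_in_float16           # input must be fp16
        and mask is not None           # mask tensor must not be None
        and 16 < sk <= 2048            # key length must be 16 ~ 2048
        and sq % 4 == 0                # query length must be divisible by 4
        and attn_batches % 4 == 0      # np * b must be divisible by 4
    )

# Example: sq=4096, sk=1024 satisfies the corrected check, but the pre-fix
# version (which tested 16 < sq <= 2048) would have rejected this shape.
assert fused_kernel_ok(mask=object(), b=8, np=16, sq=4096, sk=1024)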
@@ -197,7 +197,7 @@ def forward_torch_softmax(self, input, mask):
         return probs

     @staticmethod
-    def get_batch_per_block(b, np, sq, sk):
+    def get_batch_per_block(sq, sk, b, np):
         import scaled_masked_softmax_cuda

         return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
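The wrapper's body always forwards its arguments positionally as (sq, sk, b, np), so the Python signature must declare them in that same order; with the old (b, np, sq, sk) order, a caller passing (sq, sk, b, np) positionally would have had every argument re-labelled before reaching the CUDA extension. A small self-contained sketch, using a hypothetical stand-in for the compiled module and assuming callers pass arguments positionally in (sq, sk, b, np) order:

class _FakeKernel:  # stand-in for the compiled scaled_masked_softmax_cuda module
    @staticmethod
    def get_batch_per_block(sq, sk, b, np):
        return f"kernel saw sq={sq}, sk={sk}, b={b}, np={np}"

def get_batch_per_block_old(b, np, sq, sk):   # pre-fix signature
    return _FakeKernel.get_batch_per_block(sq, sk, b, np)

def get_batch_per_block_new(sq, sk, b, np):   # post-fix signature
    return _FakeKernel.get_batch_per_block(sq, sk, b, np)

# Caller intent: sq=2048, sk=2048, b=4, np=16, passed positionally.
print(get_batch_per_block_old(2048, 2048, 4, 16))  # kernel saw sq=4, sk=16, b=2048, np=2048
print(get_batch_per_block_new(2048, 2048, 4, 16))  # kernel saw sq=2048, sk=2048, b=4, np=16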
