Skip to content

Commit

Permalink
PR#4179 (#2435)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#124

Pull Request resolved: #2435

X-link: pytorch/FBGEMM#3027

This PR is a dependency of the grid_constant PR. The API for TMA descriptor fill methods was changed, so I fixed up all usages in fbcode.

triton-lang/triton#4179

Reviewed By: minjang

Differential Revision: D61729239

fbshipit-source-id: 8ce25b7c230c3f4ad960f76aa0dd29626c8ee4d2
  • Loading branch information
Elliot Gorokhovsky authored and facebook-github-bot committed Aug 26, 2024
1 parent 52103b5 commit babb128
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchbenchmark/util/kernels/triton_fused_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.numpy())
self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
desc_x.copy_(buf_x, non_blocking=True)


Expand All @@ -75,7 +75,7 @@ def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0,
else:
desc_x = self.cuda_descriptors[name]
buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.numpy())
self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
desc_x.copy_(buf_x, non_blocking=True)


Expand Down

0 comments on commit babb128

Please sign in to comment.