Triton PR#4179 (pytorch#2435)
Summary:
X-link: facebookresearch/FBGEMM#124

Pull Request resolved: pytorch#2435

X-link: pytorch/FBGEMM#3027

This PR is a dependency of the grid_constant PR. The API for TMA descriptor fill methods was changed, so I fixed up all usages in fbcode.

triton-lang/triton#4179

Reviewed By: minjang

Differential Revision: D61729239
Elliot Gorokhovsky authored and facebook-github-bot committed Aug 23, 2024
1 parent 0968f5e commit ec1d4da
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchbenchmark/util/kernels/triton_fused_attention.py
@@ -62,7 +62,7 @@ def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
         else:
             desc_x = self.cuda_descriptors[name]
             buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.numpy())
+            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
             desc_x.copy_(buf_x, non_blocking=True)


@@ -75,7 +75,7 @@ def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
         else:
             desc_x = self.cuda_descriptors[name]
             buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.numpy())
+            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
             desc_x.copy_(buf_x, non_blocking=True)


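For reference, below is a minimal sketch (not verbatim from the file) of the pattern this commit updates: the descriptor bytes are written into a pinned host staging buffer through its raw data pointer, which is the new fill API from triton-lang/triton#4179, and then copied asynchronously into a cached on-device tensor. The class scaffolding, TMA_SIZE, and the allocation step are assumptions for illustration; only the lines inside fill_1d_tma_descriptor mirror the hunks above.

import torch

TMA_SIZE = 128  # assumption: descriptor size in bytes


class TmaDescriptorHelperSketch:
    """Illustrative stand-in for the helper class edited in this diff."""

    def __init__(self, fill_1d_tma_descriptor_inner):
        # assumption: the inner fill routine is supplied by the Triton runtime
        self.fill_1d_tma_descriptor_inner = fill_1d_tma_descriptor_inner
        self.cuda_descriptors = {}

    def init_tma_descriptor(self, name):
        # assumption: descriptors are cached per name on the GPU
        self.cuda_descriptors[name] = torch.empty(
            TMA_SIZE, device="cuda", dtype=torch.int8
        )

    def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
        desc_x = self.cuda_descriptors[name]
        # Stage in pinned host memory so the device copy can be non-blocking.
        buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
        # New API (triton-lang/triton#4179): pass a raw pointer via
        # buf_x.data_ptr() instead of a NumPy view of the buffer.
        self.fill_1d_tma_descriptor_inner(
            ptr, dim, block_dim, element_size, buf_x.data_ptr()
        )
        desc_x.copy_(buf_x, non_blocking=True)

In the benchmark, the inner fill routine typically comes from Triton's runtime driver utils; the 2D variant follows the same staging pattern with two dims and two block dims.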