Skip to content

Commit

Permalink
PR#4179 (pytorch#3027)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#124

X-link: pytorch/benchmark#2435

Pull Request resolved: pytorch#3027

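This PR is a dependency of the grid_constant PR. The API of the TMA descriptor fill methods (`fill_1d_tma_descriptor` and `fill_2d_tma_descriptor`) changed: as the diff below shows, the descriptor buffers are now allocated as `torch.int8` tensors instead of NumPy arrays, and the fill methods receive the buffer's address via `.data_ptr()` rather than the array object itself. I updated all usages in fbcode accordingly.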

triton-lang/triton#4179

Reviewed By: minjang

Differential Revision: D61729239
  • Loading branch information
Elliot Gorokhovsky authored and facebook-github-bot committed Aug 23, 2024
1 parent a4a6661 commit 36866e9
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,7 +885,6 @@ def persistent_grid(META):

# used by TMA persistent kernel
TMA_SIZE = 128
import numpy as np

# autotune doesn't work with TMA
# https://github.com/triton-lang/triton/blob/main/python/tutorials/09-persistent-matmul.py#L312
Expand All @@ -897,11 +896,11 @@ def persistent_grid(META):
num_stages = 3
num_warps = 8

desc_a = np.empty(TMA_SIZE, dtype=np.int8)
desc_b = np.empty(TMA_SIZE, dtype=np.int8)
desc_c = np.empty(TMA_SIZE, dtype=np.int8)
desc_a_scale = np.empty(TMA_SIZE, dtype=np.int8)
desc_b_scale = np.empty(TMA_SIZE, dtype=np.int8)
desc_a = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_b = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_c = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_a_scale = torch.empty(TMA_SIZE, dtype=torch.int8)
desc_b_scale = torch.empty(TMA_SIZE, dtype=torch.int8)

triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
a_tl.data_ptr(),
Expand All @@ -910,7 +909,7 @@ def persistent_grid(META):
BLOCK_M,
BLOCK_K,
a_tl.element_size(),
desc_a,
desc_a.data_ptr(),
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
b_tl.data_ptr(),
Expand All @@ -919,7 +918,7 @@ def persistent_grid(META):
BLOCK_N,
BLOCK_K,
b_tl.element_size(),
desc_b,
desc_b.data_ptr(),
)
triton.runtime.driver.active.utils.fill_2d_tma_descriptor(
c.data_ptr(),
Expand All @@ -928,21 +927,21 @@ def persistent_grid(META):
BLOCK_M,
BLOCK_N,
c.element_size(),
desc_c,
desc_c.data_ptr(),
)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
a_scale.data_ptr(),
M,
BLOCK_M,
a_scale.element_size(),
desc_a_scale,
desc_a_scale.data_ptr(),
)
triton.runtime.driver.active.utils.fill_1d_tma_descriptor(
b_scale.data_ptr(),
N,
BLOCK_N,
b_scale.element_size(),
desc_b_scale,
desc_b_scale.data_ptr(),
)
desc_a = torch.tensor(desc_a, device="cuda")
desc_b = torch.tensor(desc_b, device="cuda")
Expand Down

0 comments on commit 36866e9

Please sign in to comment.