Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INTERPRETER] Fix lower bound check for block pointers #5201

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 25 additions & 17 deletions python/test/unit/language/test_block_pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,36 @@


@triton.jit
def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: tl.constexpr,
                      TEST_LOWER_BOUND: tl.constexpr, TEST_UPPER_BOUND: tl.constexpr):
    """Copy the first half of `a` into `b` through block pointers.

    Each program copies one BLOCK_SIZE-wide tile. When TEST_LOWER_BOUND or
    TEST_UPPER_BOUND is set, the block offset is forced fully out of range
    (-N or N respectively) so that every element is masked out by the
    boundary check — used to verify that out-of-bounds loads/stores are
    suppressed on both sides of the valid range.
    """
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE
    if TEST_LOWER_BOUND:
        # Entirely below the valid range: the lower-bound check must mask it.
        offset = -N
    elif TEST_UPPER_BOUND:
        # Entirely above the valid range: the upper-bound check must mask it.
        offset = N
    # We only copy half of the data to see if the padding works
    a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(offset, ),
                                    block_shape=(BLOCK_SIZE, ), order=(0, ))
    b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(offset, ),
                                    block_shape=(BLOCK_SIZE, ), order=(0, ))
    if PADDING_OPTION is None:
        a = tl.load(a_block_ptr, boundary_check=(0, ))
    else:
        a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=PADDING_OPTION)
    tl.store(b_block_ptr, a, boundary_check=(0, ))


@pytest.mark.interpreter
@pytest.mark.parametrize("dtypes_str, n, padding_option", [ #
(dtypes_str, n, padding)
@pytest.mark.parametrize("dtypes_str, n, padding_option, boundary_check", [ #
(dtypes_str, n, padding, boundary_check) #
for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"),
("float32", "float32"), ("bfloat16", "bfloat16"))
for n in (64, 128, 256, 512, 1024)
for padding in (None, "zero", "nan") #
for boundary_check in (None, "lower", "upper")
])
def test_block_copy(dtypes_str, n, padding_option, device):
def test_block_copy(dtypes_str, n, padding_option, boundary_check, device):
src_dtype_str = dtypes_str[0]
dst_dtype_str = dtypes_str[1]
src_dtype = getattr(torch, src_dtype_str)
Expand All @@ -45,13 +52,17 @@ def test_block_copy(dtypes_str, n, padding_option, device):
b = torch.zeros((n, ), device=device, dtype=dst_dtype)

grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), )
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option)
block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, PADDING_OPTION=padding_option,
TEST_LOWER_BOUND=boundary_check == "lower", TEST_UPPER_BOUND=boundary_check == "upper")
a.to(dst_dtype)
assert torch.all(a[0:n // 2] == b[0:n // 2])
if padding_option == "zero":
assert torch.all(b[n // 2:n] == 0)
elif padding_option == "nan":
assert torch.all(torch.isnan(b[n // 2:n]))
if (boundary_check == "lower") or (boundary_check == "upper"):
assert torch.all(b == 0)
else:
assert torch.all(a[0:n // 2] == b[0:n // 2])
if padding_option == "zero":
assert torch.all(b[n // 2:n] == 0)
elif padding_option == "nan":
assert torch.all(torch.isnan(b[n // 2:n]))


@triton.jit
Expand All @@ -69,9 +80,6 @@ def matmul_no_scf_with_advance_kernel( #
block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),
block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
# Below two lines are just for testing negative offsets for the `advance` API, which could be removed
a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K))
a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K))
peterbell10 marked this conversation as resolved.
Show resolved Hide resolved
a = tl.load(a_block_ptr, boundary_check=(1, ), padding_option="zero")
b = tl.load(b_block_ptr, boundary_check=(0, ), padding_option="zero")

Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def materialize_pointers(self, boundary_check):
off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
if dim in boundary_check:
masks = np.logical_and(masks, off < self.shape[dim].data)
masks = masks & (off < self.shape[dim].data) & (off >= 0)
ptrs = TensorHandle(ptrs, self.base.dtype.scalar)
return ptrs, masks

Expand Down
Loading