diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index 8e84a9f82a08..aff7a29d8781 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -7,29 +7,36 @@ @triton.jit -def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.constexpr): +def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, PADDING_OPTION: tl.constexpr, + TEST_LOWER_BOUND: tl.constexpr, TEST_UPPER_BOUND: tl.constexpr): pid = tl.program_id(0) + offset = pid * BLOCK_SIZE + if TEST_LOWER_BOUND: + offset = -N + elif TEST_UPPER_BOUND: + offset = N # We only copy half of the data to see if the padding works - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(N // 2, ), strides=(1, ), offsets=(offset, ), block_shape=(BLOCK_SIZE, ), order=(0, )) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(N, ), strides=(1, ), offsets=(offset, ), block_shape=(BLOCK_SIZE, ), order=(0, )) - if padding_option is None: + if PADDING_OPTION is None: a = tl.load(a_block_ptr, boundary_check=(0, )) else: - a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=padding_option) + a = tl.load(a_block_ptr, boundary_check=(0, ), padding_option=PADDING_OPTION) tl.store(b_block_ptr, a, boundary_check=(0, )) @pytest.mark.interpreter -@pytest.mark.parametrize("dtypes_str, n, padding_option", [ # - (dtypes_str, n, padding) +@pytest.mark.parametrize("dtypes_str, n, padding_option, boundary_check", [ # + (dtypes_str, n, padding, boundary_check) # for dtypes_str in (("bool", "bool"), ("int16", "int16"), ("int32", "int32"), ("float16", "float16"), ("float32", "float32"), ("bfloat16", "bfloat16")) for n in (64, 128, 256, 512, 1024) for padding in (None, "zero", "nan") # + for boundary_check in (None, "lower", "upper") ]) -def test_block_copy(dtypes_str, n, padding_option, device): +def test_block_copy(dtypes_str, n, padding_option, boundary_check, device): src_dtype_str = dtypes_str[0] dst_dtype_str = dtypes_str[1] src_dtype = getattr(torch, src_dtype_str) @@ -45,13 +52,17 @@ def test_block_copy(dtypes_str, n, padding_option, device): b = torch.zeros((n, ), device=device, dtype=dst_dtype) grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) - block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) + block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, PADDING_OPTION=padding_option, + TEST_LOWER_BOUND=boundary_check == "lower", TEST_UPPER_BOUND=boundary_check == "upper") a.to(dst_dtype) - assert torch.all(a[0:n // 2] == b[0:n // 2]) - if padding_option == "zero": - assert torch.all(b[n // 2:n] == 0) - elif padding_option == "nan": - assert torch.all(torch.isnan(b[n // 2:n])) + if (boundary_check == "lower") or (boundary_check == "upper"): + assert torch.all(b == 0) + else: + assert torch.all(a[0:n // 2] == b[0:n // 2]) + if padding_option == "zero": + assert torch.all(b[n // 2:n] == 0) + elif padding_option == "nan": + assert torch.all(torch.isnan(b[n // 2:n])) @triton.jit diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index b85cc2e9aec7..3b94f55ea3c0 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -65,7 +65,7 @@ def materialize_pointers(self, boundary_check): off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims) ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) if dim in boundary_check: - masks = np.logical_and(masks, off < self.shape[dim].data) + masks = masks & (off < self.shape[dim].data) & (off >= 0) ptrs = TensorHandle(ptrs, self.base.dtype.scalar) return ptrs, masks