Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INTERPRETER][NFC] Rename tensor_shape -> block_shape in interpreter #5195

Merged
merged 3 commits into from
Nov 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, data, dtype):
'''
data: numpy array
dtype: triton type, either pointer_type or scalar_type.
we don't store block_type here because the shape information is already availale in the data field
we don't store block_type here because the shape information is already available in the data field
attr: a dictionary of attributes
'''
self.data = data
Expand All @@ -46,24 +46,23 @@ def set_attr(self, key, value):

class BlockPointerHandle:

def __init__(self, base, shape, strides, offsets, tensor_shape, order):
def __init__(self, base, shape, strides, offsets, block_shape, order):
self.base = base
self.shape = shape
self.strides = strides
self.offsets = offsets
self.tensor_shape = tensor_shape
self.block_shape = block_shape
self.order = order

def materialize_pointers(self, boundary_check):
dtype_tt = self.base.get_element_ty()
n_bytes = dtype_tt.primitive_bitwidth // 8
tensor_shape = self.tensor_shape
ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
masks = np.ones(self.tensor_shape, dtype=bool)
for dim in range(len(tensor_shape)):
bcast_dims = [1] * len(tensor_shape)
bcast_dims[dim] = tensor_shape[dim]
off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
ptrs = np.broadcast_to(self.base.data, self.block_shape)
masks = np.ones(self.block_shape, dtype=bool)
for dim in range(len(self.block_shape)):
bcast_dims = [1] * len(self.block_shape)
bcast_dims[dim] = self.block_shape[dim]
off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
if dim in boundary_check:
masks = np.logical_and(masks, off < self.shape[dim].data)
Expand Down Expand Up @@ -655,17 +654,17 @@ def create_barrier(self):
# Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
pass

def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
# Create new offsets to avoid modifying the original
new_offsets = [offset.clone() for offset in offsets]
return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)

def create_advance(self, ptr, offsets):
if len(ptr.offsets) != len(offsets):
raise ValueError("len(ptr.offsets) != len(offsets)")
# Create new offsets to avoid modifying the original
new_offsets = [offset.clone() for offset in ptr.offsets]
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
for i in range(len(offsets)):
ret.offsets[i].data += offsets[i].data
return ret
Expand Down
Loading