Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
vgokhale committed Dec 12, 2024
1 parent ed152a2 commit faa9efe
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions python/perf-kernels/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def layernorm_kernel_blocked_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_ro

@triton.autotune(configs=get_autotune_config(), key=['n_rows', 'n_cols'], use_cuda_graph=True)
@triton.jit
def layernorm_kernel_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, eps, N_COLS: tl.constexpr):
def layernorm_kernel_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride, n_rows, n_cols, eps, BLOCK_SIZE: tl.constexpr):

tl.assume(x_row_stride > 0)
tl.assume(y_row_stride > 0)
Expand All @@ -99,28 +99,29 @@ def layernorm_kernel_impl(x_ptr, y_ptr, w_ptr, b_ptr, x_row_stride, y_row_stride
tl.assume(row > 0)
x_ptr_start = x_ptr + (row * x_row_stride)
y_ptr_start = y_ptr + (row * y_row_stride)
col_offs = tl.arange(0, N_COLS)
col_offs = tl.arange(0, BLOCK_SIZE)

#calculate mean
x_ptrs = x_ptr_start + col_offs
x_block = tl.load(x_ptrs, cache_modifier=".cg").to(tl.float32) #Unmasked loads
mean = tl.sum(x_block, axis=0) / N_COLS
_x_block = x_block - mean
var = tl.sum(_x_block * _x_block, axis=0) / N_COLS
mask = col_offs < n_cols
x_block = tl.load(x_ptrs, cache_modifier=".cg", mask=mask, other=0.0).to(tl.float32) #Unmasked loads
mean = tl.sum(x_block, axis=0) / n_cols
_x_block = tl.where(mask, x_block - mean, 0.0)
var = tl.sum(_x_block * _x_block, axis=0) / n_cols
rstd = tl.rsqrt(var + eps)

w_block = tl.load(w_ptr + col_offs)
b_block = tl.load(b_ptr + col_offs)
w_block = tl.load(w_ptr + col_offs, mask=mask, other=0.0)
b_block = tl.load(b_ptr + col_offs, mask=mask, other=0.0)
y_block = (x_block - mean) * rstd
y_block = y_block * w_block + b_block
tl.store(y_ptr_start + col_offs, y_block)
tl.store(y_ptr_start + col_offs, y_block, mask=mask)

def layernorm(x, y, w, b, eps=1e-5):
n_rows, n_cols = x.shape

grid = lambda meta: (n_rows, )
if n_cols <= 8192:
layernorm_kernel_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, eps, N_COLS=triton.next_power_of_2(n_cols))
layernorm_kernel_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps, BLOCK_SIZE=triton.next_power_of_2(n_cols))
else:
layernorm_kernel_blocked_impl[grid](x, y, w, b, x.stride(0), y.stride(0), n_rows, n_cols, eps, BLOCK_SIZE=2048)

Expand Down

0 comments on commit faa9efe

Please sign in to comment.