feat: correct casts in RMSNorm to match references #92

Merged (7 commits) on Aug 28, 2024
136 changes: 118 additions & 18 deletions src/liger_kernel/ops/rms_norm.py
@@ -21,6 +21,11 @@
from triton.language.math import rsqrt


_CASTING_MODE_NONE = tl.constexpr(-1)
_CASTING_MODE_LLAMA = tl.constexpr(0)
_CASTING_MODE_GEMMA = tl.constexpr(1)


@triton.jit
def _rms_norm_forward(
Y_ptr,
@@ -34,10 +39,11 @@ def _rms_norm_forward(
n_cols,
eps,
offset,
casting_mode: tl.constexpr, # constexpr so the `if` blocks can be optimized out
BLOCK_SIZE: tl.constexpr,
):
"""
y_i = (x_i / (RMS)) * wi, RMS = sqrt(sum(x_i^2) / N)
y_i = (x_i / (RMS)) * (offset + wi), RMS = sqrt(sum(x_i^2) / N)

Reference:
1. https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
@@ -54,8 +60,18 @@ def _rms_norm_forward(
r_ptr += row_idx * r_row_stride

X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
X_row_dtype = X_row.dtype
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)

    # On Llama, only inv_rms is computed in fp32
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)

    # Gemma computes everything in fp32, then casts the output back to the original dtype
if casting_mode == _CASTING_MODE_GEMMA:
W_row = W_row.to(tl.float32)
X_row = X_row.to(tl.float32)

mean_square = tl.sum(X_row * X_row, axis=0) / n_cols
inv_rms = rsqrt(mean_square + eps)

@@ -64,7 +80,13 @@
# However, on the computation side, it can save 4 operations (*, sum, /, sqrt).
tl.store(r_ptr, inv_rms)

Y_row = X_row * inv_rms * (offset + W_row)
X_row = X_row * inv_rms

    # On Llama, the multiplication with the weight is done in the original dtype
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(X_row_dtype)

Y_row = X_row * (offset + W_row)

tl.store(Y_ptr + col_offsets, Y_row, mask=mask)

@@ -84,10 +106,11 @@ def _rms_norm_backward(
n_cols,
eps,
offset,
casting_mode: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
dx = (1 / RMS) * [dy * w - (1 / N) * (1 / RMS^2) * ((dy * w) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
dx = (1 / RMS) * [dy * (w + offset) - (1 / N) * (1 / RMS^2) * ((dy * (w + offset)) dot x) * x]. * means element-wise multiplication, whereas dot means dot product
dw = sum(dy * (x / RMS)). summation over BxT dimension
"""

@@ -103,34 +126,95 @@
dY_row = tl.load(dY_ptr + col_offsets, mask=mask, other=0)
X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)
original_x_dtype = X_row.dtype

# Get cached rms
inv_rms_row = tl.load(r_ptr)

W_row = W_row + offset
dX_row = (inv_rms_row) * (
dY_row * W_row
- (1 / n_cols)
* inv_rms_row
* inv_rms_row
* tl.sum(dY_row * W_row * X_row, axis=0)
* X_row
)
tl.store(dY_ptr + col_offsets, dX_row, mask=mask)

    # Different backward graphs for different casting modes
if casting_mode == _CASTING_MODE_LLAMA:
X_row = X_row.to(tl.float32)
m = (dY_row * W_row).to(tl.float32)
dX_row = inv_rms_row * m

dX_row += (inv_rms_row) * (
-(1 / n_cols)
* inv_rms_row
* inv_rms_row
* tl.sum(m * X_row, axis=0)
* X_row
)

if casting_mode == _CASTING_MODE_GEMMA:
dY_row, W_row, X_row = (
dY_row.to(tl.float32),
W_row.to(tl.float32),
X_row.to(tl.float32),
)
dX_row = inv_rms_row * dY_row * W_row

dX_row += (inv_rms_row) * (
-(1 / n_cols)
* inv_rms_row
* inv_rms_row
* tl.sum(dY_row * W_row * X_row, axis=0)
* X_row
)

# calculate the gradient of W
dW_row = dY_row * X_row * inv_rms_row
if casting_mode == _CASTING_MODE_LLAMA:
dW_row = dY_row * (X_row * inv_rms_row).to(original_x_dtype)
else:
# here X_row is already in fp32 (see previous if block)
dW_row = dY_row * (X_row * inv_rms_row)

tl.store(dY_ptr + col_offsets, dX_row, mask=mask)
tl.store(dW_ptr + col_offsets, dW_row, mask=mask)


_str_to_casting_mode = {
"llama": _CASTING_MODE_LLAMA.value,
"gemma": _CASTING_MODE_GEMMA.value,
"none": _CASTING_MODE_NONE.value,
}


class LigerRMSNormFunction(torch.autograd.Function):
"""
Performs RMSNorm (Root Mean Square Normalization), which normalizes the input tensor `X` using the
weight tensor `W`, with an optional offset and casting mode.

Some models use an 'offset' to shift the weight tensor `W` by a constant value. For example, Gemma
uses an offset of 1.0, so the computation becomes `(X / RMS(X)) * (W + 1.0)` instead of the usual
`(X / RMS(X)) * W`. You can pass the offset value as an argument to the forward function.

In addition, different models cast their inputs at different places during RMSNorm computation. For
    example, Gemma casts everything to fp32 before starting the computation, while Llama casts only the
inverse RMS to fp32. You can specify the casting mode using the `casting_mode` argument. We currently
support the following casting modes (they match HuggingFace Transformers' implementations):
- 'llama': matches the Llama implementation, where only the inverse RMS is computed on fp32.
- 'gemma': matches the Gemma implementation, where everything is cast to fp32, then computed, then cast back to the original dtype.
- 'none': no casting is done. The computation is done in the original dtype. This saves memory and is slightly faster, but has more error w.r.t. the original implementation.
"""

@staticmethod
@ensure_contiguous
def forward(ctx, X, W, eps, offset=0.0):
def forward(ctx, X, W, eps, offset=0.0, casting_mode="llama"):
"""
X: (B, T, H) or (BxT, H)
W: (H,)
"""
if not isinstance(casting_mode, int):
assert (
casting_mode in _str_to_casting_mode
), f"Invalid casting mode: {casting_mode}"
casting_mode = _str_to_casting_mode[casting_mode]
else:
assert (
casting_mode in _str_to_casting_mode.values()
), f"Invalid casting mode: {casting_mode}"

shape = X.shape
dim = shape[-1]
@@ -140,7 +224,13 @@ def forward(ctx, X, W, eps, offset=0.0):

Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)
# r is to cache (1/rms) for each row
r = torch.empty(n_rows, dtype=X.dtype, device=X.device)
# r is always computed/stored in fp32 if we are using Llama or Gemma casting mode
r_dtype = (
torch.float32
if casting_mode in (_CASTING_MODE_LLAMA.value, _CASTING_MODE_GEMMA.value)
else X.dtype
)
r = torch.empty(n_rows, dtype=r_dtype, device=X.device)

# Check constraints.
assert (
@@ -159,11 +249,13 @@
n_cols,
eps,
offset,
casting_mode,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
ctx.eps = eps
ctx.offset = offset
ctx.casting_mode = casting_mode
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps

@@ -182,7 +274,14 @@ def backward(ctx, dY):
dY = dY.view(-1, dim)
X, W, r = ctx.saved_tensors
n_rows, n_cols = dY.shape
dW = torch.zeros_like(X)
dW = torch.empty_like(
X,
dtype=(
torch.float32
if ctx.casting_mode == _CASTING_MODE_GEMMA.value
else W.dtype
),
)

# Here we use dY to store the value of dX to save memory
_rms_norm_backward[(n_rows,)](
@@ -199,9 +298,10 @@ def backward(ctx, dY):
n_cols,
ctx.eps,
ctx.offset,
ctx.casting_mode,
BLOCK_SIZE=ctx.BLOCK_SIZE,
num_warps=ctx.num_warps,
)
dX = dY.view(*shape)
dW = torch.sum(dW, dim=0)
return dX, dW, None, None
dW = torch.sum(dW, dim=0).to(W.dtype)
return dX, dW, None, None, None
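
Note (not part of the diff): as a quick mental model for the casting rules described in the kernel docstrings above, the following eager-mode PyTorch sketch shows where each `casting_mode` upcasts to fp32. The function name and signature are illustrative only; the Triton kernels above are the actual implementation.

```python
import torch


def rms_norm_reference(x, w, eps=1e-6, offset=0.0, casting_mode="llama"):
    """Eager sketch of the three casting modes (illustrative, not the kernel)."""
    input_dtype = x.dtype
    if casting_mode == "gemma":
        # Gemma: compute everything in fp32, cast the output back at the end.
        x, w = x.float(), w.float()
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return (x * inv_rms * (offset + w)).to(input_dtype)
    if casting_mode == "llama":
        # Llama: the inverse RMS and the x * inv_rms product are in fp32,
        # but the multiplication with the weight stays in the original dtype.
        x_fp32 = x.float()
        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
        return (x_fp32 * inv_rms).to(input_dtype) * (offset + w)
    # "none": stay in the original dtype throughout.
    inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x * inv_rms * (offset + w)
```

Comparing a reference like this (or HuggingFace's `LlamaRMSNorm` / `GemmaRMSNorm` modules) against the Triton kernel is essentially what the updated tests below do.
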
4 changes: 3 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -153,7 +153,9 @@ def apply_liger_kernel_to_gemma(
modeling_gemma.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
# https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L109
modeling_gemma.GemmaRMSNorm = partial(LigerRMSNorm, offset=1.0, init_fn="zeros")
modeling_gemma.GemmaRMSNorm = partial(
LigerRMSNorm, offset=1.0, init_fn="zeros", casting_mode="gemma"
)
if cross_entropy:
modeling_gemma.CrossEntropyLoss = LigerCrossEntropyLoss
if geglu:
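
A minimal usage sketch of the patched path (not from this PR): the import path is inferred from this file's location and the `rms_norm` flag from the hunk above; the model id is just an example.

```python
from transformers import AutoModelForCausalLM

# Import path assumed from src/liger_kernel/transformers/monkey_patch.py.
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma

# Patch HF's Gemma classes before instantiating the model, so GemmaRMSNorm is
# replaced by LigerRMSNorm(offset=1.0, init_fn="zeros", casting_mode="gemma").
apply_liger_kernel_to_gemma(rms_norm=True)

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")  # example model id
```
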
17 changes: 13 additions & 4 deletions src/liger_kernel/transformers/rms_norm.py
@@ -5,7 +5,9 @@


class LigerRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6, offset=0.0, init_fn="ones"):
def __init__(
self, hidden_size, eps=1e-6, offset=0.0, casting_mode="llama", init_fn="ones"
):
super().__init__()
assert init_fn in [
"ones",
@@ -14,10 +16,17 @@ def __init__(self, hidden_size, eps=1e-6, offset=0.0, init_fn="ones"):
self.weight = nn.Parameter(
torch.ones(hidden_size) if init_fn == "ones" else torch.zeros(hidden_size)
)
self.variance_epsilon = eps
self.offset = offset
self.variance_epsilon, self.offset, self.casting_mode = (
eps,
offset,
casting_mode,
)

def forward(self, hidden_states):
return LigerRMSNormFunction.apply(
hidden_states, self.weight, self.variance_epsilon, self.offset
hidden_states,
self.weight,
self.variance_epsilon,
self.offset,
self.casting_mode,
)
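
For the standalone module, a small hedged example of the new `casting_mode` argument; the constructor arguments come from the `__init__` shown above, while the hidden size, shapes, and dtype are arbitrary.

```python
import torch

from liger_kernel.transformers.rms_norm import LigerRMSNorm

# Gemma-style norm: zero-initialized weight, offset of 1.0, fp32 internal compute.
norm = LigerRMSNorm(
    hidden_size=4096, offset=1.0, casting_mode="gemma", init_fn="zeros"
).to("cuda", torch.bfloat16)

x = torch.randn(2, 128, 4096, device="cuda", dtype=torch.bfloat16)
y = norm(x)  # same shape and dtype as the input
```
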
25 changes: 18 additions & 7 deletions test/convergence/test_mini_models.py
@@ -1,4 +1,5 @@
import functools
import os
from test.utils import (
DEFAULT_DATASET_PATH,
MiniModelConfig,
@@ -29,6 +30,16 @@
apply_liger_kernel_to_qwen2,
)

torch.use_deterministic_algorithms(True)

# Only setting torch.use_deterministic_algorithms(True) throws the following error:
# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

MINI_MODEL_SETUPS = {
"mini_llama3": MiniModelConfig(
liger_kernel_patch_func=functools.partial(
@@ -332,13 +343,13 @@ def run_mini_model(
"model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
[
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
("mini_gemma1", 32, 1e-4, torch.float32, 1e-6, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-6, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-2, 1e-4, 2e-1, 1e-5, 1e-2, 1e-5),
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
# TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine
# TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass.
42 changes: 28 additions & 14 deletions test/transformers/test_rms_norm.py
@@ -1,9 +1,22 @@
import os
from test.utils import assert_verbose_allclose

import pytest
import torch
import torch.nn as nn

from liger_kernel.transformers.rms_norm import LigerRMSNorm

torch.use_deterministic_algorithms(True)

# Only setting torch.use_deterministic_algorithms(True) might throw the following error:
# RuntimeError: Deterministic behavior was enabled with either `torch.use_deterministic_algorithms(True)` or `at::Context::setDeterministicAlgorithms(true)`,
# but this operation is not deterministic because it uses CuBLAS and you have CUDA >= 10.2. To enable deterministic behavior in this case, you must set an
# environment variable before running your PyTorch application: CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information,
# go to https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility

os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

SLEEP_SECONDS = 0.1


@@ -58,17 +71,18 @@ def forward(self, x):
"dtype, atol, rtol",
[
(torch.float32, 1e-4, 1e-6),
(torch.bfloat16, 5.0, 1e-5),
(torch.bfloat16, 2e-1, 2e-2),
(torch.float16, 2e-1, 2e-2),
],
)
@pytest.mark.parametrize(
"reference, offset",
"reference, offset, casting_mode",
[
(LlamaRMSNorm, 0.0),
(GemmaRMSNorm, 1.0),
(LlamaRMSNorm, 0.0, "llama"),
(GemmaRMSNorm, 1.0, "gemma"),
],
)
def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset):
def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset, casting_mode):
# h
_tensor = torch.randn(bs, sl, hd, device="cuda", dtype=dtype)

Expand All @@ -84,16 +98,16 @@ def test_correctness(bs, sl, hd, dtype, atol, rtol, reference, offset):
ref_o.backward(do.clone(), retain_graph=True)

# triton
triton_rms = LigerRMSNorm(hidden_size=hd, offset=offset).to("cuda").to(dtype)
triton_rms = (
LigerRMSNorm(hidden_size=hd, offset=offset, casting_mode=casting_mode)
.to("cuda")
.to(dtype)
)
triton_o = triton_rms(h2)
triton_o.backward(do.clone(), retain_graph=True)

assert torch.allclose(ref_o, triton_o, atol=atol, rtol=rtol) is True
assert (
torch.allclose(
ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol
)
is True
assert_verbose_allclose(ref_o, triton_o, atol=atol, rtol=rtol)
assert_verbose_allclose(
ref_rms.weight.grad, triton_rms.weight.grad, atol=atol, rtol=rtol
)

assert torch.allclose(h1.grad, h2.grad, atol=atol, rtol=rtol) is True
assert_verbose_allclose(h1.grad, h2.grad, atol=atol, rtol=rtol)