feat: add gemma_rmsnorm and gemma_fused_add_rmsnorm (#477)
for gemma2 cc @yzh119
zhyncs authored Aug 27, 2024
1 parent 9ee26e7 commit 1a6b17e
Showing 7 changed files with 345 additions and 1 deletion.
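The new kernels implement Gemma's RMSNorm variant: statistics are accumulated in fp32 and the learned weight is applied as a `(1 + weight)` scale rather than a plain `weight` scale. A minimal reference sketch of the math (mirroring the test helpers added below; the function name is illustrative, not part of this commit):

```python
import torch

def gemma_rms_norm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate statistics in fp32, then cast back to the input dtype.
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma scales by (1 + w) instead of w.
    return (x * (1.0 + w.float())).to(orig_dtype)
```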
184 changes: 184 additions & 0 deletions include/flashinfer/norm.cuh
@@ -212,6 +212,190 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
T* __restrict__ output, const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];

float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
sum_sq += float(input_vec[j]) * float(input_vec[j]);
}
}

// first, warp reduce sum
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}

smem[ty] = sum_sq;
__syncthreads();
// then, cross warp reduce sum using only the first warp
if (ty == 0) {
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}
smem[0] = sum_sq;
}
__syncthreads();

float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> output_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&input, &weight, &output, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
T* __restrict__ weight, const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];

float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0.f);
vec_t<T, VEC_SIZE> residual_vec;
residual_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
float x = float(input_vec[j]);
x += float(residual_vec[j]);
sum_sq += x * x;
residual_vec[j] = (T)x;
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}

// first, warp reduce sum
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}

smem[ty] = sum_sq;
__syncthreads();
// then, cross warp reduce sum using only the first warp
if (ty == 0) {
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}
smem[0] = sum_sq;
}
__syncthreads();

float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> residual_vec;
input_vec.fill(0.f);
weight_vec.fill(0.f);
residual_vec.fill(0.f);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&input, &residual, &weight, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});

return cudaSuccess;
}

} // namespace norm

} // namespace flashinfer
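Both host wrappers derive the launch shape from the hidden size: `vec_size = gcd(16 / sizeof(T), d)` keeps each vectorized access 16 bytes wide whenever `d` allows, the block is capped at 1024 threads, and shared memory holds one float per warp for the cross-warp reduction. A quick sanity check of that arithmetic (a hypothetical helper, assuming fp16 inputs; not part of the commit):

```python
import math

def launch_config(d: int, dtype_bytes: int = 2):
    # Mirrors the host-side arithmetic in GemmaRMSNorm / GemmaFusedAddRMSNorm.
    vec_size = math.gcd(16 // dtype_bytes, d)  # elements per vectorized load
    block_size = min(1024, d // vec_size)      # threads per block
    num_warps = (block_size + 31) // 32        # ceil_div(block_size, 32)
    smem_bytes = num_warps * 4                 # one float per warp for the reduction
    return vec_size, block_size, num_warps, smem_bytes

# fp16, d = 4096 -> vec_size = 8, block_size = 512, num_warps = 16, smem = 64 bytes
print(launch_config(4096))
```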
3 changes: 3 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -39,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
"Gemma Fused add root mean square normalization");
m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
5 changes: 5 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -77,6 +77,11 @@ torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);

void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
54 changes: 54 additions & 0 deletions python/csrc/norm.cu
@@ -73,3 +73,57 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor
return true;
});
}

torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
CHECK_INPUT(input);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto output = torch::empty_like(input);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(weight.data_ptr()),
static_cast<c_type*>(output.data_ptr()), batch_size,
hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
return output;
}

void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps) {
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(residual.device(), device);
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(0), residual.size(0));
CHECK_EQ(input.size(1), residual.size(1));
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::GemmaFusedAddRMSNorm(
static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps,
torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "GemmaFusedAddRMSNorm failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}
2 changes: 1 addition & 1 deletion python/flashinfer/__init__.py
@@ -29,7 +29,7 @@
single_decode_with_kv_cache,
)
from .gemm import SegmentGEMMWrapper, bmm_fp8
from .norm import fused_add_rmsnorm, rmsnorm
from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (
BatchPrefillWithPagedKVCacheWrapper,
39 changes: 39 additions & 0 deletions python/flashinfer/norm.py
@@ -69,3 +69,42 @@ def fused_add_rmsnorm(
Epsilon for numerical stability.
"""
_kernels.fused_add_rmsnorm(input, residual, weight, eps)


def gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
r"""Gemma Root mean square normalization.
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
Returns
-------
output: torch.Tensor
Normalized tensor, shape (batch_size, hidden_size).
"""
return _kernels.gemma_rmsnorm(input, weight, eps)


def gemma_fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
r"""Gemma Fused add root mean square normalization.
Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
"""
_kernels.gemma_fused_add_rmsnorm(input, residual, weight, eps)
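A minimal usage sketch for the two new Python entry points (assuming a CUDA device and fp16 tensors; shapes follow the docstrings above):

```python
import torch
import flashinfer

batch_size, hidden_size = 4, 4096
x = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

# Out-of-place Gemma RMSNorm.
y = flashinfer.gemma_rmsnorm(x, w, eps=1e-6)

# Fused variant: `x` is updated in place with the normalized value of
# (x + residual), and `residual` is overwritten with the pre-norm sum.
flashinfer.gemma_fused_add_rmsnorm(x, residual, w, eps=1e-6)
```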
59 changes: 59 additions & 0 deletions python/tests/test_norm.py
@@ -29,6 +29,28 @@ def _norm(x):
return output * w


def gemma_rms_norm(x, w, eps=1e-6):
orig_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * (1.0 + w)
x = x.to(orig_dtype)
return x


def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
orig_dtype = x.dtype
x = x + residual
residual = x
x = x.float()
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * (1.0 + w)
x = x.to(orig_dtype)
return x, residual


def fused_add_rms_norm(x, residual, weight, eps):
orig_dtype = x.dtype
x = x.to(torch.float32)
@@ -76,3 +76,40 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):

torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_norm(batch_size, hidden_size, dtype):
x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
w = torch.randn(hidden_size).to(0).to(dtype)

y_ref = gemma_rms_norm(x, w)
y = flashinfer.norm.gemma_rmsnorm(x, w)

numpy.testing.assert_allclose(
y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
)


@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
eps = 1e-6

x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

x_native, residual_native = gemma_fused_add_rms_norm(
x.clone(), residual.clone(), weight, eps
)

x_fused = x.clone()
residual_fused = residual.clone()
flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)

torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
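One way to run just the Gemma tests added here (a sketch; the path assumes the repository root as the working directory):

```python
import pytest

# Select only the Gemma-related tests in the norm test module.
raise SystemExit(pytest.main(["python/tests/test_norm.py", "-k", "gemma"]))
```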
