From 1a6b17e2b78fc811d50030b9326a4d01f1ff956f Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Wed, 28 Aug 2024 02:06:52 +1000
Subject: [PATCH] feat: add gemma_rmsnorm and gemma_fused_add_rmsnorm (#477)

for gemma2

cc @yzh119
---
 include/flashinfer/norm.cuh   | 184 ++++++++++++++++++++++++++++++++++
 python/csrc/flashinfer_ops.cu |   3 +
 python/csrc/flashinfer_ops.h  |   5 +
 python/csrc/norm.cu           |  54 ++++++++++
 python/flashinfer/__init__.py |   2 +-
 python/flashinfer/norm.py     |  39 +++++++
 python/tests/test_norm.py     |  59 +++++++++++
 7 files changed, 345 insertions(+), 1 deletion(-)

diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index 6fa7b317..82d2513d 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -212,6 +212,190 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz
   return cudaSuccess;
 }
 
+template <uint32_t VEC_SIZE, typename T>
+__global__ void GemmaRMSNormKernel(T* __restrict__ input, T* __restrict__ weight,
+                                   T* __restrict__ output, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      sum_sq += float(input_vec[j]) * float(input_vec[j]);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> output_vec;
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      output_vec[j] = float(input_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
+                         float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &weight, &output, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = GemmaRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <uint32_t VEC_SIZE, typename T>
+__global__ void GemmaFusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
+                                           T* __restrict__ weight, const uint32_t d, float eps) {
+  const uint32_t bx = blockIdx.x;
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t warp_size = 32;
+  const uint32_t num_warps = blockDim.y;
+  const uint32_t thread_id = tx + ty * warp_size;
+  const uint32_t num_threads = num_warps * warp_size;
+  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
+  extern __shared__ float smem[];
+
+  float sum_sq = 0.f;
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    input_vec.fill(0.f);
+    vec_t<T, VEC_SIZE> residual_vec;
+    residual_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      float x = float(input_vec[j]);
+      x += float(residual_vec[j]);
+      sum_sq += x * x;
+      residual_vec[j] = (T)x;
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+
+  // first, warp reduce sum
+#pragma unroll
+  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+    sum_sq += math::shfl_xor_sync(sum_sq, offset);
+  }
+
+  smem[ty] = sum_sq;
+  __syncthreads();
+  // then, cross warp reduce sum using only the first warp
+  if (ty == 0) {
+    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
+#pragma unroll
+    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
+      sum_sq += math::shfl_xor_sync(sum_sq, offset);
+    }
+    smem[0] = sum_sq;
+  }
+  __syncthreads();
+
+  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);
+
+  for (uint32_t i = 0; i < rounds; i++) {
+    vec_t<T, VEC_SIZE> input_vec;
+    vec_t<T, VEC_SIZE> weight_vec;
+    vec_t<T, VEC_SIZE> residual_vec;
+    input_vec.fill(0.f);
+    weight_vec.fill(0.f);
+    residual_vec.fill(0.f);
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; j++) {
+      input_vec[j] = float(residual_vec[j]) * rms_rcp * (1.0f + float(weight_vec[j]));
+    }
+    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
+      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
+    }
+  }
+}
+
+template <typename T>
+cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
+                                 float eps = 1e-5, cudaStream_t stream = 0) {
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
+  const uint32_t num_warps = ceil_div(block_size, 32);
+  dim3 nblks(batch_size);
+  dim3 nthrs(32, num_warps);
+  const uint32_t smem_size = num_warps * sizeof(float);
+  void* args[] = {&input, &residual, &weight, &d, &eps};
+
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = GemmaFusedAddRMSNormKernel<VEC_SIZE, T>;
+    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+
+  return cudaSuccess;
+}
+
 }  // namespace norm
 
 }  // namespace flashinfer
diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index aa28b2b0..ff72fd65 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -39,6 +39,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Speculative sampling from sequence of probabilities");
   m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
   m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
+  m.def("gemma_rmsnorm", &gemma_rmsnorm, "Gemma Root mean square normalization");
+  m.def("gemma_fused_add_rmsnorm", &gemma_fused_add_rmsnorm,
+        "Gemma Fused add root mean square normalization");
   m.def("silu_and_mul", &silu_and_mul, "Fused SiLU and Mul");
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Fused GeLU Tanh and Mul");
   m.def("gelu_and_mul", &gelu_and_mul, "Fused GeLU and Mul");
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index 86c921f8..9bd8389d 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -77,6 +77,11 @@ torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
 void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
                        double eps);
 
+torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);
+
+void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
+                             double eps);
+
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/python/csrc/norm.cu b/python/csrc/norm.cu
index 64be041a..6d2b6dd6 100644
--- a/python/csrc/norm.cu
+++ b/python/csrc/norm.cu
@@ -73,3 +73,57 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso
     return true;
   });
 }
+
+torch::Tensor gemma_rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);   // input: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);  // weight: (hidden_size)
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  auto output = torch::empty_like(input);
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::GemmaRMSNorm(static_cast<c_type*>(input.data_ptr()),
+                                            static_cast<c_type*>(weight.data_ptr()),
+                                            static_cast<c_type*>(output.data_ptr()), batch_size,
+                                            hidden_size, eps, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "GemmaRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
+    return true;
+  });
+  return output;
+}
+
+void gemma_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
+                             double eps) {
+  CHECK_INPUT(input);
+  CHECK_INPUT(residual);
+  CHECK_INPUT(weight);
+  auto device = input.device();
+  CHECK_EQ(residual.device(), device);
+  CHECK_EQ(weight.device(), device);
+  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
+  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
+  CHECK_DIM(1, weight);    // weight: (hidden_size)
+  CHECK_EQ(input.size(0), residual.size(0));
+  CHECK_EQ(input.size(1), residual.size(1));
+  CHECK_EQ(input.size(1), weight.size(0));
+  unsigned int batch_size = input.size(0);
+  unsigned int hidden_size = input.size(1);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
+    cudaError_t status = norm::GemmaFusedAddRMSNorm(
+        static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
+        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps,
+        torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess, "GemmaFusedAddRMSNorm failed with error code " +
+                                           std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index 1871cfb0..33473c10 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -29,7 +29,7 @@
     single_decode_with_kv_cache,
 )
 from .gemm import SegmentGEMMWrapper, bmm_fp8
-from .norm import fused_add_rmsnorm, rmsnorm
+from .norm import fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
 from .page import append_paged_kv_cache
 from .prefill import (
     BatchPrefillWithPagedKVCacheWrapper,
diff --git a/python/flashinfer/norm.py b/python/flashinfer/norm.py
index 63a078ff..67ca7e8b 100644
--- a/python/flashinfer/norm.py
+++ b/python/flashinfer/norm.py
@@ -69,3 +69,42 @@ def fused_add_rmsnorm(
         Epsilon for numerical stability.
     """
     _kernels.fused_add_rmsnorm(input, residual, weight, eps)
+
+
+def gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
+    r"""Gemma Root mean square normalization.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma Normalized tensor, shape (batch_size, hidden_size).
+    """
+    return _kernels.gemma_rmsnorm(input, weight, eps)
+
+
+def gemma_fused_add_rmsnorm(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+):
+    r"""Gemma Fused add root mean square normalization.
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    """
+    _kernels.gemma_fused_add_rmsnorm(input, residual, weight, eps)
diff --git a/python/tests/test_norm.py b/python/tests/test_norm.py
index 79877cb0..1e8453b4 100644
--- a/python/tests/test_norm.py
+++ b/python/tests/test_norm.py
@@ -29,6 +29,28 @@ def _norm(x):
     return output * w
 
 
+def gemma_rms_norm(x, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * (1.0 + w)
+    x = x.to(orig_dtype)
+    return x
+
+
+def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
+    orig_dtype = x.dtype
+    x = x + residual
+    residual = x
+    x = x.float()
+    variance = x.pow(2).mean(dim=-1, keepdim=True)
+    x = x * torch.rsqrt(variance + eps)
+    x = x * (1.0 + w)
+    x = x.to(orig_dtype)
+    return x, residual
+
+
 def fused_add_rms_norm(x, residual, weight, eps):
     orig_dtype = x.dtype
     x = x.to(torch.float32)
@@ -76,3 +98,40 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
 
     torch.testing.assert_close(x_fused, x_native, rtol=1e-2, atol=1e-2)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_gemma_norm(batch_size, hidden_size, dtype):
+    x = torch.randn(batch_size, hidden_size).to(0).to(dtype)
+    w = torch.randn(hidden_size).to(0).to(dtype)
+
+    y_ref = gemma_rms_norm(x, w)
+    y = flashinfer.norm.gemma_rmsnorm(x, w)
+
+    numpy.testing.assert_allclose(
+        y_ref.cpu().numpy(), y.cpu().numpy(), rtol=1e-3, atol=1e-3
+    )
+
+
+@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
+@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 4096, 8192])
+@pytest.mark.parametrize("dtype", [torch.float16])
+def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
+    eps = 1e-6
+
+    x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
+    residual = torch.randn_like(x)
+    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
+
+    x_native, residual_native = gemma_fused_add_rms_norm(
+        x.clone(), residual.clone(), weight, eps
+    )
+
+    x_fused = x.clone()
+    residual_fused = residual.clone()
+    flashinfer.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+
+    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
+    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
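Usage sketch (illustrative, not part of the diff): the patch exports two new Python entry points, `flashinfer.gemma_rmsnorm` and `flashinfer.gemma_fused_add_rmsnorm`, via the `__init__.py` change above. The snippet below shows one way to call them; the batch size, hidden size, dtype, and eps value are placeholders chosen for illustration.

    import torch
    import flashinfer

    batch_size, hidden_size = 4, 4096  # illustrative sizes
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16, device="cuda")
    residual = torch.randn_like(x)
    w = torch.randn(hidden_size, dtype=torch.float16, device="cuda")

    # Out-of-place Gemma-style RMSNorm: y = x * rsqrt(mean(x^2) + eps) * (1 + w)
    y = flashinfer.gemma_rmsnorm(x, w, eps=1e-6)

    # Fused variant, in place: `residual` becomes x + residual, and `x` becomes the
    # Gemma-normalized value of that sum (mirroring the reference in test_norm.py).
    flashinfer.gemma_fused_add_rmsnorm(x, residual, w, eps=1e-6)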