From 61328aeecda81a3418c6c742502e9c3c82e36c00 Mon Sep 17 00:00:00 2001
From: Ivan Sidorenko
Date: Tue, 7 May 2024 11:45:34 +0000
Subject: [PATCH] [CUBLAS][FP8] Enable R.matmul + R.multiply offloading

This commit enables offloading of the following pattern to cuBLAS:

  mm = R.linear(data, weights)
  scale = R.multiply(a_scale, w_scale)
  out = R.multiply(mm, scale)
  out = R.astype(out, dtype)
---
 python/tvm/relax/backend/contrib/cublas.py  | 11 ++-
 python/tvm/relax/backend/patterns.py        | 38 +++++++++
 src/relax/backend/contrib/cublas/codegen.cc |  5 +-
 src/runtime/contrib/cublas/cublas.cc        | 14 +++-
 .../contrib/cublas/cublas_json_runtime.cc   | 15 ++--
 src/runtime/contrib/cublas/cublas_utils.h   |  6 +-
 tests/python/relax/test_codegen_cublas.py   | 79 +++++++++++++++++++
 7 files changed, 156 insertions(+), 12 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cublas.py b/python/tvm/relax/backend/contrib/cublas.py
index e5bc55c32751d..db4bd332c5bad 100644
--- a/python/tvm/relax/backend/contrib/cublas.py
+++ b/python/tvm/relax/backend/contrib/cublas.py
@@ -25,7 +25,11 @@
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
+from ..patterns import (
+    make_matmul_pattern,
+    make_matmul_dequantize_pattern,
+    make_matmul_multiply_pattern,
+)
 from ..utils import has_leaking_intermediate_variables
 
 
@@ -202,6 +206,11 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             *make_matmul_dequantize_pattern(transposed_rhs=True),
             _check_matmul,
         ),
+        (
+            "cublas.matmul_transposed_multiply",
+            *make_matmul_multiply_pattern(transposed_rhs=True),
+            _check_matmul,
+        ),
     ]
 )
 
diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py
index 404f7dc97526d..26fde7d9c1ce3 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -376,6 +376,44 @@ def make_matmul_dequantize_pattern(
     return out, annotations
 
+
+def make_matmul_multiply_pattern(
+    transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    Create pattern for matrix multiplication followed by a multiply (scale) operation.
+
+    Parameters
+    ----------
+    transposed_rhs: bool
+        Whether the right hand side of multiplication is transposed.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a matrix multiplication.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract important expressions from
+        match result, to power the partition check function and codegen.
+    """
+
+    lhs = wildcard()
+    rhs = wildcard()
+    scaleA = wildcard()
+    scaleB = wildcard()
+    annotations = {"lhs": lhs, "rhs": rhs, "scaleA": scaleA, "scaleB": scaleB}
+
+    if transposed_rhs:
+        rhs = is_op("relax.permute_dims")(rhs)
+    out = is_op("relax.matmul")(lhs, rhs)
+    annotations["root"] = out
+    scale = is_op("relax.multiply")(scaleA.has_shape((1,)), scaleB.has_shape((1,)))
+    out = is_op("relax.multiply")(out, scale)
+    out = is_op("relax.astype")(out)
+
+    return out, annotations
+
 
 def make_attention_rewrite_pattern(
     qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
diff --git a/src/relax/backend/contrib/cublas/codegen.cc b/src/relax/backend/contrib/cublas/codegen.cc
index 9f29d21aaa3d2..e92ee57a5a02b 100644
--- a/src/relax/backend/contrib/cublas/codegen.cc
+++ b/src/relax/backend/contrib/cublas/codegen.cc
@@ -62,7 +62,7 @@ class CublasJSONSerializer : public JSONSerializer {
       inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
     }
 
-    ICHECK(inputs_tmp.size() <= 3);
+    ICHECK(inputs_tmp.size() <= 4);
     NodeEntries inputs(inputs_tmp.size());
 
     auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
@@ -70,6 +70,9 @@ class CublasJSONSerializer : public JSONSerializer {
     inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
     if (inputs_tmp.size() == 3) {
       inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+    } else if (inputs_tmp.size() == 4) {
+      inputs[2] = inputs_tmp[arg_idx["scaleA"]->value];
+      inputs[3] = inputs_tmp[arg_idx["scaleB"]->value];
     }
 
     auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index 1edb6b95c962b..8925080abfbcb 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -137,8 +137,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue,
+                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue,
                   std::optional<float> dq_scale) {
   ICHECK(TypeEqual(A->dtype, B->dtype));
   // Reversed strides indicates an in-place transpose operation.
@@ -193,6 +194,15 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                                                      &bias->data, sizeof(float*)));
   }
 
+  if (scaleA != nullptr && scaleB != nullptr) {
+    auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
+    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+                                                      &scaleA_data, sizeof(float*)));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+                                                      &scaleB_data, sizeof(float*)));
+  }
+
   if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                       &epilogue, sizeof(epilogue)));
diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc
index 8578d86789b8b..49ff061da5df8 100644
--- a/src/runtime/contrib/cublas/cublas_json_runtime.cc
+++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -97,12 +97,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
       return dl_tensors[eid];
     };
 
-    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
-      const DLTensor* bias = nullptr;
+    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias, bool has_scale) {
+      const DLTensor *bias = nullptr, *scaleA = nullptr, *scaleB = nullptr;
       if (has_bias) {
         bias = get_input(node, 2);
+      } else if (has_scale) {
+        scaleA = get_input(node, 2);
+        scaleB = get_input(node, 3);
       }
-      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias, scaleA, scaleB);
     };
 
     for (size_t i = 0; i < nodes_.size(); ++i) {
@@ -127,7 +130,9 @@ class CublasJSONRuntime : public JSONRuntimeBase {
           epilogue = CUBLASLT_EPILOGUE_BIAS;
         }
 
-        auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
+        bool has_scale = op_name.find("multiply") != std::string::npos;
+        auto [a_ptr, b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr] =
+            get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT, has_scale);
 
         std::optional<float> dq_scale = std::nullopt;
         if (op_name.find("dequantize") != std::string::npos) {
@@ -135,7 +140,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
         }
 
         tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
-                                   b_ptr, bias_ptr, out_ptr, transa, transb,
+                                   b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr, out_ptr, transa, transb,
                                    entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue,
                                    dq_scale);
       }
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 2906279f904a0..387065093eaa5 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -123,9 +123,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
 /*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size,
-                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
+                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
                   std::optional<float> dq_scale = std::nullopt);
 
 }  // namespace contrib
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 4ff498ae2b939..913f203d1965a 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
     return tvm.IRModule({"main": func})
 
+
+def get_relax_matmul_multiply_module(
+    x_shape,
+    y_shape,
+    z_shape,
+    in_dtype,
+    acc_dtype,
+    out_dtype,
+    transposed_y=False,
+):
+    """Create a matmul op followed by multiply operations."""
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, in_dtype))
+            y = R.arg("y", R.Tensor(y_shape, in_dtype))
+            scaleA = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
+            scaleB = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))
+
+            with R.dataflow() as frame:
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype=acc_dtype))
+                z = R.emit(R.multiply(scaleA, scaleB))
+                result = R.emit(R.multiply(result, z))
+                if acc_dtype != out_dtype:
+                    result = R.emit(R.astype(result, out_dtype))
+                R.output(result)
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
 
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
     [
@@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
+
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_multiply_offload():
+    x_shape = (10, 32)
+    y_shape = (64, 32)
+    z_shape = (1,)
+    in_dtype, acc_dtype = ("e4m3_float8", "float32")
+
+    mod = get_relax_matmul_multiply_module(
+        x_shape,
+        y_shape,
+        z_shape,
+        in_dtype,
+        acc_dtype,
+        "float16",
+        transposed_y=True,
+    )
+
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    scaleA = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    scaleB = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    args = (x, y, scaleA, scaleB)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
 
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
@@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings
     assert len(mod["main"].body.blocks[0].bindings) == num_bindings
 
+
+def test_cublas_partition_fp8_matmul_multiply():
+    M, N, K = (32, 64, 128)
+    mod = get_relax_matmul_multiply_module(
+        (M, K),
+        (N, K),
+        (1,),
+        "e4m3_float8",
+        "float32",
+        "float16",
+        transposed_y=True,
+    )
+    mod = partition_for_cublas(mod)
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+
 
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))