[CUBLAS][FP8] Enable R.matmul + R.multiply offloading #16974

Merged 1 commit on May 8, 2024
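This PR enables the cuBLAS BYOC backend to offload a Relax matmul whose float32 result is multiplied by the product of two scalar scales (the usual FP8 dequantization epilogue) and cast to the output dtype, all as a single cuBLASLt call. As orientation, here is a minimal sketch of a module the new pattern is meant to match; it is illustrative only (not code from this PR), and the shapes and dtypes mirror the tests below:

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((10, 32), "e4m3_float8"),
        y: R.Tensor((64, 32), "e4m3_float8"),
        scaleA: R.Tensor((1,), "float32"),
        scaleB: R.Tensor((1,), "float32"),
    ) -> R.Tensor((10, 64), "float16"):
        with R.dataflow():
            # Transposed RHS, matching the registered pattern (transposed_rhs=True).
            y_t = R.permute_dims(y, axes=[1, 0])
            # FP8 matmul accumulating in float32.
            acc = R.matmul(x, y_t, out_dtype="float32")
            # Multiply by the product of the two (1,)-shaped scales, then cast down.
            scale = R.multiply(scaleA, scaleB)
            scaled = R.multiply(acc, scale)
            out = R.astype(scaled, "float16")
            R.output(out)
        return out

After partition_for_cublas, the whole dataflow block should collapse into one call to a "cublas.matmul_transposed_multiply" composite function, which the runtime maps onto a single cuBLASLt matmul.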
Changes from all commits
11 changes: 10 additions & 1 deletion python/tvm/relax/backend/contrib/cublas.py
@@ -25,7 +25,11 @@
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
+from ..patterns import (
+    make_matmul_pattern,
+    make_matmul_dequantize_pattern,
+    make_matmul_multiply_pattern,
+)
 from ..utils import has_leaking_intermediate_variables


@@ -202,6 +206,11 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             *make_matmul_dequantize_pattern(transposed_rhs=True),
             _check_matmul,
         ),
+        (
+            "cublas.matmul_transposed_multiply",
+            *make_matmul_multiply_pattern(transposed_rhs=True),
+            _check_matmul,
+        ),
     ]
 )

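Because the new entry is registered under the "cublas." prefix, the existing partitioning entry point (partition_for_cublas, which collects patterns via get_patterns_with_prefix, imported above) should pick up the matmul + multiply fusion with no further wiring.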
38 changes: 38 additions & 0 deletions python/tvm/relax/backend/patterns.py
@@ -376,6 +376,44 @@ def make_matmul_dequantize_pattern(
     return out, annotations
 
 
+def make_matmul_multiply_pattern(
+    transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    Create a pattern for matrix multiplication followed by a multiply operation.
+
+    Parameters
+    ----------
+    transposed_rhs: bool
+        Whether the right hand side of the matrix multiplication is transposed.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a matrix multiplication followed by a multiply.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract important expressions
+        from the match result, to power the partition check function and codegen.
+    """
+
+    lhs = wildcard()
+    rhs = wildcard()
+    scaleA = wildcard()
+    scaleB = wildcard()
+    annotations = {"lhs": lhs, "rhs": rhs, "scaleA": scaleA, "scaleB": scaleB}
+
+    if transposed_rhs:
+        rhs = is_op("relax.permute_dims")(rhs)
+    out = is_op("relax.matmul")(lhs, rhs)
+    annotations["root"] = out
+    scale = is_op("relax.multiply")(scaleA.has_shape((1,)), scaleB.has_shape((1,)))
+    out = is_op("relax.multiply")(out, scale)
+    out = is_op("relax.astype")(out)
+
+    return out, annotations
+
+
 def make_attention_rewrite_pattern(
     qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
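As a quick usage sketch (illustrative, not part of this PR), the new helper can be called directly to inspect the pattern it builds; the import path follows the file above:

# Illustrative only: build the pattern and inspect its annotations.
# It matches astype(matmul(lhs, permute_dims(rhs)) * (scaleA * scaleB)),
# where scaleA and scaleB are (1,)-shaped tensors.
from tvm.relax.backend.patterns import make_matmul_multiply_pattern

pattern, annotations = make_matmul_multiply_pattern(transposed_rhs=True)
print(sorted(annotations))  # ['lhs', 'rhs', 'root', 'scaleA', 'scaleB']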
5 changes: 4 additions & 1 deletion src/relax/backend/contrib/cublas/codegen.cc
@@ -62,14 +62,17 @@ class CublasJSONSerializer : public JSONSerializer {
       inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
     }
 
-    ICHECK(inputs_tmp.size() <= 3);
+    ICHECK(inputs_tmp.size() <= 4);
     NodeEntries inputs(inputs_tmp.size());
 
     auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
     inputs[0] = inputs_tmp[arg_idx["lhs"]->value];
     inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
     if (inputs_tmp.size() == 3) {
       inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+    } else if (inputs_tmp.size() == 4) {
+      inputs[2] = inputs_tmp[arg_idx["scaleA"]->value];
+      inputs[3] = inputs_tmp[arg_idx["scaleB"]->value];
     }
 
     auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */
14 changes: 12 additions & 2 deletions src/runtime/contrib/cublas/cublas.cc
@@ -137,8 +137,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue,
+                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue,
                   std::optional<float> dq_scale) {
   ICHECK(TypeEqual(A->dtype, B->dtype));
   // Reversed strides indicate an in-place transpose operation.
@@ -193,6 +194,15 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                                                  &bias->data, sizeof(float*)));
   }
 
+  if (scaleA != nullptr && scaleB != nullptr) {
+    auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
+    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+                                                      &scaleA_data, sizeof(float*)));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+                                                      &scaleB_data, sizeof(float*)));
+  }
+
   if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                       &epilogue, sizeof(epilogue)));
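For context on the two attributes set above: per the cuBLASLt documentation, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER and CUBLASLT_MATMUL_DESC_B_SCALE_POINTER expect device pointers to float32 scalars that are applied to the FP8 A and B operands. Since scaling A by *scaleA and B by *scaleB scales the product by (*scaleA) * (*scaleB), this reproduces the R.multiply(matmul, scaleA * scaleB) epilogue captured by the Relax pattern without an extra elementwise kernel.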
15 changes: 10 additions & 5 deletions src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -97,12 +97,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
       return dl_tensors[eid];
     };
 
-    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
-      const DLTensor* bias = nullptr;
+    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias, bool has_scale) {
+      const DLTensor *bias = nullptr, *scaleA = nullptr, *scaleB = nullptr;
       if (has_bias) {
         bias = get_input(node, 2);
+      } else if (has_scale) {
+        scaleA = get_input(node, 2);
+        scaleB = get_input(node, 3);
       }
-      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias, scaleA, scaleB);
     };
 
     for (size_t i = 0; i < nodes_.size(); ++i) {
@@ -127,15 +130,17 @@
        epilogue = CUBLASLT_EPILOGUE_BIAS;
      }
 
-      auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
+      bool has_scale = op_name.find("multiply") != std::string::npos;
+      auto [a_ptr, b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr] =
+          get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT, has_scale);
 
       std::optional<float> dq_scale = std::nullopt;
       if (op_name.find("dequantize") != std::string::npos) {
        dq_scale = std::stof(node.GetAttr<std::vector<std::string>>("dq_scale")[0]);
      }
 
      tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
-                                 b_ptr, bias_ptr, out_ptr, transa, transb,
+                                 b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr, out_ptr, transa, transb,
                                  entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue,
                                  dq_scale);
     }
6 changes: 3 additions & 3 deletions src/runtime/contrib/cublas/cublas_utils.h
@@ -123,9 +123,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
 /*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size,
-                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
+                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
                   std::optional<float> dq_scale = std::nullopt);
 
 }  // namespace contrib
79 changes: 79 additions & 0 deletions tests/python/relax/test_codegen_cublas.py
@@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
     return tvm.IRModule({"main": func})
 
 
+def get_relax_matmul_multiply_module(
+    x_shape,
+    y_shape,
+    z_shape,
+    in_dtype,
+    acc_dtype,
+    out_dtype,
+    transposed_y=False,
+):
+    """Create a matmul op followed by a multiply operation."""
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, in_dtype))
+            y = R.arg("y", R.Tensor(y_shape, in_dtype))
+            scaleA = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
+            scaleB = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))
+
+            with R.dataflow() as frame:
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype=acc_dtype))
+                z = R.emit(R.multiply(scaleA, scaleB))
+                result = R.emit(R.multiply(result, z))
+                if acc_dtype != out_dtype:
+                    result = R.emit(R.astype(result, out_dtype))
+                R.output(result)
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
     [
@@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
 
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_multiply_offload():
+    x_shape = (10, 32)
+    y_shape = (64, 32)
+    z_shape = (1,)
+    in_dtype, acc_dtype = ("e4m3_float8", "float32")
+
+    mod = get_relax_matmul_multiply_module(
+        x_shape,
+        y_shape,
+        z_shape,
+        in_dtype,
+        acc_dtype,
+        "float16",
+        transposed_y=True,
+    )
+
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    scaleA = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    scaleB = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    args = (x, y, scaleA, scaleB)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
@@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings):
     assert len(mod["main"].body.blocks[0].bindings) == num_bindings
 
 
+def test_cublas_partition_fp8_matmul_multiply():
+    M, N, K = (32, 64, 128)
+    mod = get_relax_matmul_multiply_module(
+        (M, K),
+        (N, K),
+        (1,),
+        "e4m3_float8",
+        "float32",
+        "float16",
+        transposed_y=True,
+    )
+    mod = partition_for_cublas(mod)
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
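Note on the new partition test above: after partition_for_cublas, test_cublas_partition_fp8_matmul_multiply expects exactly one binding in main's dataflow block, i.e. the permute_dims, the matmul, the scale product, the output multiply, and the astype have all been absorbed into a single call to the fused cuBLAS composite function.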