Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[draft]Fp8 fast accumulation #6388

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions xla/service/gpu/matmul_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -901,12 +901,16 @@ StatusOr<se::gpu::BlasLt::Epilogue> AsBlasLtEpilogue(
se::blas::ComputationType computation_type,
GetBlasComputationType(lhs_layout.dtype, output_layout.dtype,
config.compute_precision));

// For FP8 matmuls, fast acumulation is only turned on when both operands's
// precision are DEFAULT.
bool fast_accum = (primitive_util::IsF8Type(lhs_layout.dtype) ||
primitive_util::IsF8Type(rhs_layout.dtype)) &&
config.compute_precision == 0;
TF_ASSIGN_OR_RETURN(
se::gpu::BlasLt::MatmulDesc op_desc,
se::gpu::BlasLt::MatmulDesc::Create(
computation_type, GetScaleType(output_dtype, computation_type),
trans_a, trans_b, epilogue));
trans_a, trans_b, epilogue, fast_accum));

TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::MatrixLayout a_desc,
AsBlasLtMatrixLayout(lhs_layout));
Expand Down
2 changes: 2 additions & 0 deletions xla/stream_executor/cuda/cuda_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
AsCublasOperation(trans_b)));
TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
TF_RETURN_IF_ERROR(
SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, int8_t(fast_accum)));
return std::move(desc);
}

Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class BlasLt {
blas::ComputationType compute_type, blas::DataType scale_type,
blas::Transpose trans_a = blas::Transpose::kNoTranspose,
blas::Transpose trans_b = blas::Transpose::kNoTranspose,
Epilogue epilogue = Epilogue::kDefault,
Epilogue epilogue = Epilogue::kDefault, bool fast_accum = false,
PointerMode pointer_mode = PointerMode::kHost);

cublasComputeType_t compute_type() const;
Expand Down