From 68bcdf81a47fb0f753d837c034931094c5cd8017 Mon Sep 17 00:00:00 2001
From: Akhil Goel
Date: Tue, 22 Oct 2024 12:02:40 -0700
Subject: [PATCH 1/2] Refactor Add Handler

---
 .../cpu/onednn_contraction_rewriter.cc        | 268 ++++++++++--------
 xla/service/cpu/onednn_contraction_rewriter.h |  32 ++-
 xla/service/cpu/onednn_convolution.h          |   1 +
 xla/service/cpu/onednn_matmul.h               |   1 +
 xla/service/cpu/onednn_util.h                 |   2 +-
 5 files changed, 166 insertions(+), 138 deletions(-)

diff --git a/xla/service/cpu/onednn_contraction_rewriter.cc b/xla/service/cpu/onednn_contraction_rewriter.cc
index 01ffb340e07c1..e351c1d0ce656 100644
--- a/xla/service/cpu/onednn_contraction_rewriter.cc
+++ b/xla/service/cpu/onednn_contraction_rewriter.cc
@@ -196,6 +196,38 @@ std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
   }
 }
 
+ContractionVariant GetContractionVariant(
+    absl::StatusOr<BackendConfig>* backend_config) {
+  return ((*backend_config)->backend_config_oneof_case() == kOnednnConvConfig)
+             ? ContractionVariant(PrimitiveTrait<kOnednnConvConfig>{})
+             : ContractionVariant(PrimitiveTrait<kOnednnMatmulConfig>{});
+}
+
+// Return the correct mutable config instance for the given contraction variant
+// based on the template parameter
+template <typename TransformationType>
+TransformationType GetTransformationConfig(
+    absl::StatusOr<BackendConfig>* backend_config) {
+  return std::visit(
+      [&](auto&& config) -> TransformationType {
+        using T = std::decay_t<decltype(config)>;
+        return PrimitiveTrait<T::kConfigVal, TransformationType>::
+            GetTransformationConfig(
+                GetKernelConfig<T::kConfigVal>(backend_config));
+      },
+      GetContractionVariant(backend_config));
+}
+
+FusionsConfigPointer GetFusionsConfig(
+    absl::StatusOr<BackendConfig>* backend_config) {
+  return GetTransformationConfig<FusionsConfigPointer>(backend_config);
+}
+
+OptimizationsConfigPointer GetOptimizationsConfig(
+    absl::StatusOr<BackendConfig>* backend_config) {
+  return GetTransformationConfig<OptimizationsConfigPointer>(backend_config);
+}
+
 inline auto BcastConstScalarNear(double value) {
   return m::Broadcast(ConstScalarNear(value));
 }
@@ -285,7 +317,7 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
   return OneDnnFusionConfig::UNDEFINED;
 }
 
-// OneDNN matmul / convolution can fuse add operation with automatic
+// OneDNN matmul and convolution can fuse add operation with automatic
 // broadcasting along the addend's dimensions that are 1s. When compatible,
 // Broadcast can be replaced by Bitcast, which is much cheaper. Compute new
 // shape for the Bitcast.
@@ -612,141 +644,125 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
         m::Op(&addend_intermediate));
 
     if (Match(instr, pattern)) {
-      HANDLE_OP_INTERNAL(HandleAddInternal, contraction, instr,
-                         addend_intermediate, optional_contraction_convert,
-                         optional_contraction_bitcast);
-    }
-
-    return absl::OkStatus();
-  }
+      if (!IsSupportedType(contraction->shape().element_type()))
+        return absl::OkStatus();
+      // TODO(intel-tf): Remove the condition below when the fusion Contraction
+      // + Add(bias) + Add(e.g., residual) is enabled.
+      auto contraction_config = contraction->backend_config<BackendConfig>();
+      auto orig_fusion_config = GetFusionsConfig(&contraction_config);
+      if (!orig_fusion_config->ops().empty() &&
+          orig_fusion_config->ops(0) == OneDnnFusionConfig::BIAS) {
+        return absl::OkStatus();
+      }
+      std::vector<HloInstruction*> new_operands;
+      for (auto operand : contraction->operands()) {
+        new_operands.push_back(operand);
+      }
 
-  template <BackendConfig::BackendConfigOneofCase config>
-  absl::Status HandleAddInternal(HloInstruction* contraction,
-                                 HloInstruction* instr,
-                                 HloInstruction* addend_intermediate,
-                                 HloInstruction* optional_contraction_convert,
-                                 HloInstruction* optional_contraction_bitcast) {
-    if (!IsSupportedType(contraction->shape().element_type()))
-      return absl::OkStatus();
-    // TODO(intel-tf): Remove the condition below when the fusion Contraction +
-    // Add(bias) + Add(e.g., residual) is enabled.
-    auto contraction_config = contraction->backend_config<BackendConfig>();
-    if (!GetKernelConfig<config>(&contraction_config)
-             ->mutable_fusions()
-             ->ops()
-             .empty() &&
-        GetKernelConfig<config>(&contraction_config)
-                ->mutable_fusions()
-                ->ops(0) == OneDnnFusionConfig::BIAS) {
-      return absl::OkStatus();
-    }
-    std::vector<HloInstruction*> new_operands;
-    for (auto operand : contraction->operands()) {
-      new_operands.push_back(operand);
-    }
 
+      // At this point, the addend could have one of the following
+      // possibilities that the current fusion can handle:
+      //
+      // - addend -> Convert -> Broadcast -> Add
+      // - addend -> Broadcast -> Convert -> Add
+      // - addend -> Convert
+      // - addend -> Broadcast
+      // - addend
+      //
+      // Hunt for addend through possible sequences above and check the addend
+      // is compatible for onednn fusion.
+      HloInstruction* addend = nullptr;
+      HloInstruction* optional_addend_broadcast = nullptr;
+      auto addend_pattern = m::AnyOf<HloInstruction>(
+          m::Broadcast(&optional_addend_broadcast,
+                       m::Convert(&addend, m::Op())),
+          m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
+          m::Convert(&addend, m::Op()),
+          m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
+          m::Op(&addend));
+      if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
+
+      if (optional_addend_broadcast && addend->shape().rank() != 1) {
+        auto new_shape =
+            AdjustBiasShape(optional_addend_broadcast, contraction->shape());
+        if (new_shape.ok()) {
+          addend = addend->AddInstruction(
+              HloInstruction::CreateBitcast(new_shape.value(), addend));
+        } else {
+          VLOG(2) << new_shape.status();
+          return absl::OkStatus();
+        }
+      }
 
-    // At this point, the addend could have one of the following
-    // possiblities that the current fusion can handle:
-    //
-    // - addend -> Convert -> Broadcast -> Add
-    // - addend -> Broadcast -> Convert -> Add
-    // - addend -> Convert
-    // - addend -> Broadcast
-    // - addend
-    //
-    // Hunt for addend through possible sequences above and check the addend
-    // is compatible for onednn fusion.
-    HloInstruction* addend = nullptr;
-    HloInstruction* optional_addend_broadcast = nullptr;
-    auto addend_pattern = m::AnyOf<HloInstruction>(
-        m::Broadcast(&optional_addend_broadcast, m::Convert(&addend, m::Op())),
-        m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
-        m::Convert(&addend, m::Op()),
-        m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
-        m::Op(&addend));
-    if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();
-
-    if (optional_addend_broadcast && addend->shape().rank() != 1) {
-      auto new_shape =
-          AdjustBiasShape(optional_addend_broadcast, contraction->shape());
-      if (new_shape.ok()) {
-        addend = addend->AddInstruction(
-            HloInstruction::CreateBitcast(new_shape.value(), addend));
+      // Validate addend for fusion.
+      if (IsSupportedType(addend->shape().element_type()) &&
+          IsOperandFusible(addend, contraction)) {
+        new_operands.push_back(addend);
       } else {
-        VLOG(2) << new_shape.status();
         return absl::OkStatus();
       }
-    }
-
-    // Validate addend for fusion.
-    if (IsSupportedType(addend->shape().element_type()) &&
-        IsOperandFusible(addend, contraction)) {
-      new_operands.push_back(addend);
-    } else {
-      return absl::OkStatus();
-    }
 
-    auto custom_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
-        contraction->CloneWithNewOperands(contraction->shape(), new_operands)));
-
-    auto backend_config = custom_call->backend_config<BackendConfig>();
-
-    // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
-    // implementation for broadcasted add across all dimensions.
-    OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
-    kind =
-        (addend->shape().rank() == 1)
-            ? (GetKernelConfig<config>(&backend_config)->fusions().ops().empty()
-                   ? OneDnnFusionConfig::BIAS
-                   : OneDnnFusionConfig::UNDEFINED)
-            : OneDnnFusionConfig::BINARY_ADD;
-    if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
-
-    GetKernelConfig<config>(&backend_config)->mutable_fusions()->add_ops(kind);
-
-    if (optional_addend_broadcast) {
-      GetKernelConfig<config>(&backend_config)
-          ->mutable_optimization_config()
-          ->set_bias_broadcast(true);
-    }
-    TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config));
-
-    HloInstruction* new_instr;
-    // If matched pattern has custom-call -> bitcast -> add, then we need to
-    // insert bitcast after the new fusion to maintain the correct shape
-    // (new-custom-call -> bitcast). Also, this will optionally be followed
-    // by -> convert for bf16 case to avoid datatype mismatch.
-    if (optional_contraction_bitcast != nullptr &&
-        optional_contraction_bitcast->opcode() == HloOpcode::kBitcast) {
-      if (optional_contraction_convert != nullptr &&
-          optional_contraction_convert->opcode() == HloOpcode::kConvert) {
-        auto bitcast_call =
-            custom_call->AddInstruction(HloInstruction::CreateBitcast(
-                ShapeUtil::ChangeElementType(
-                    instr->shape(), custom_call->shape().element_type()),
-                custom_call));
-        new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert(
-            ShapeUtil::ChangeElementType(
-                bitcast_call->shape(),
-                optional_contraction_convert->shape().element_type()),
-            bitcast_call));
-      } else {
-        new_instr = custom_call->AddInstruction(
-            HloInstruction::CreateBitcast(instr->shape(), custom_call));
+      auto custom_call = Cast<HloCustomCallInstruction>(
+          instr->AddInstruction(contraction->CloneWithNewOperands(
+              contraction->shape(), new_operands)));
+
+      auto backend_config = custom_call->backend_config<BackendConfig>();
+      auto fusions_config = GetFusionsConfig(&backend_config);
+      auto optimization_config = GetOptimizationsConfig(&backend_config);
+      // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
+      // implementation for broadcasted add across all dimensions.
+      OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
+      kind =
+          (addend->shape().rank() == 1)
+              ? (fusions_config->ops().empty() ? OneDnnFusionConfig::BIAS
+                                               : OneDnnFusionConfig::UNDEFINED)
+              : OneDnnFusionConfig::BINARY_ADD;
+      if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();
+
+      fusions_config->add_ops(kind);
+
+      if (optional_addend_broadcast) {
+        optimization_config->set_bias_broadcast(true);
       }
-    } else {
-      if (optional_contraction_convert != nullptr &&
-          optional_contraction_convert->opcode() == HloOpcode::kConvert) {
-        new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert(
-            ShapeUtil::ChangeElementType(
-                custom_call->shape(),
-                optional_contraction_convert->shape().element_type()),
-            custom_call));
+      TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config));
+
+      HloInstruction* new_instr;
+      // If matched pattern has custom-call -> bitcast -> add, then we need to
+      // insert bitcast after the new fusion to maintain the correct shape
+      // (new-custom-call -> bitcast). Also, this will optionally be followed
+      // by -> convert for bf16 case to avoid datatype mismatch.
+      if (optional_contraction_bitcast != nullptr &&
+          optional_contraction_bitcast->opcode() == HloOpcode::kBitcast) {
+        if (optional_contraction_convert != nullptr &&
+            optional_contraction_convert->opcode() == HloOpcode::kConvert) {
+          auto bitcast_call =
+              custom_call->AddInstruction(HloInstruction::CreateBitcast(
+                  ShapeUtil::ChangeElementType(
+                      instr->shape(), custom_call->shape().element_type()),
+                  custom_call));
+          new_instr =
+              bitcast_call->AddInstruction(HloInstruction::CreateConvert(
+                  ShapeUtil::ChangeElementType(
+                      bitcast_call->shape(),
+                      optional_contraction_convert->shape().element_type()),
+                  bitcast_call));
+        } else {
+          new_instr = custom_call->AddInstruction(
+              HloInstruction::CreateBitcast(instr->shape(), custom_call));
+        }
       } else {
+        if (optional_contraction_convert != nullptr &&
+            optional_contraction_convert->opcode() == HloOpcode::kConvert) {
+          new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert(
+              ShapeUtil::ChangeElementType(
+                  custom_call->shape(),
+                  optional_contraction_convert->shape().element_type()),
+              custom_call));
+        } else {
+          new_instr = custom_call;
         }
-      } else {
-        if (optional_contraction_convert != nullptr &&
-            optional_contraction_convert->opcode() == HloOpcode::kConvert) {
-          new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert(
-              ShapeUtil::ChangeElementType(
-                  custom_call->shape(),
-                  optional_contraction_convert->shape().element_type()),
-              custom_call));
-        } else {
-          new_instr = custom_call;
+      }
+      TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
     }
-    }
-    TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
 
     return absl::OkStatus();
   }
 
diff --git a/xla/service/cpu/onednn_contraction_rewriter.h b/xla/service/cpu/onednn_contraction_rewriter.h
index 2706d05d1ef92..f96a500c4ff81 100644
--- a/xla/service/cpu/onednn_contraction_rewriter.h
+++ b/xla/service/cpu/onednn_contraction_rewriter.h
@@ -24,6 +24,8 @@ limitations under the License.
 #include "xla/hlo/ir/hlo_instructions.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/pass/hlo_pass_interface.h"
+#include "xla/service/cpu/onednn_matmul.h"
+#include "xla/service/cpu/onednn_convolution.h"
 #include "tsl/platform/threadpool.h"
 
 namespace xla {
@@ -61,18 +63,26 @@ class OneDnnContractionRewriter : public HloModulePass {
   const tsl::thread::ThreadPool* compile_threadpool_;
 };
 
-#define HANDLE_OP_INTERNAL(internal_callee, contraction, ...)           \
-  switch (contraction->backend_config<BackendConfig>()                  \
-              ->backend_config_oneof_case()) {                          \
-    case BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig:    \
-      return internal_callee<                                           \
-          BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig>(  \
-          contraction, __VA_ARGS__);                                    \
-    default:                                                            \
-      return internal_callee<                                           \
-          BackendConfig::BackendConfigOneofCase::kOnednnConvConfig>(    \
-          contraction, __VA_ARGS__);                                    \
+using ContractionVariant = std::variant<PrimitiveTrait<kOnednnConvConfig>,
+                                        PrimitiveTrait<kOnednnMatmulConfig>>;
+using FusionsConfigPointer = xla::cpu::OneDnnFusionConfig*;
+using OptimizationsConfigPointer = xla::cpu::OneDnnOptimizationConfig*;
+
+template <BackendConfigOneofCase config>
+struct PrimitiveTrait<config, FusionsConfigPointer> {
+  static FusionsConfigPointer GetTransformationConfig(
+      typename PrimitiveTrait<config>::pointer_type kernel_config) {
+    return kernel_config->mutable_fusions();
   }
+};
+
+template <BackendConfigOneofCase config>
+struct PrimitiveTrait<config, OptimizationsConfigPointer> {
+  static OptimizationsConfigPointer GetTransformationConfig(
+      typename PrimitiveTrait<config>::pointer_type kernel_config) {
+    return kernel_config->mutable_optimization_config();
+  }
+};
 
 }  // namespace cpu
 }  // namespace xla

diff --git a/xla/service/cpu/onednn_convolution.h b/xla/service/cpu/onednn_convolution.h
index 657cddffb21af..7d61193b9e3b3 100644
--- a/xla/service/cpu/onednn_convolution.h
+++ b/xla/service/cpu/onednn_convolution.h
@@ -31,6 +31,7 @@ extern void __xla_cpu_runtime_OneDnnConvolution(void* result, void** args);
 template <>
 struct PrimitiveTrait<kOnednnConvConfig> {
   using pointer_type = xla::cpu::OneDnnConvolutionConfig*;
+  static const BackendConfigOneofCase kConfigVal = kOnednnConvConfig;
 };
 
 }  // namespace cpu

diff --git a/xla/service/cpu/onednn_matmul.h b/xla/service/cpu/onednn_matmul.h
index bf452e9d9f051..8429e8a2d3922 100644
--- a/xla/service/cpu/onednn_matmul.h
+++ b/xla/service/cpu/onednn_matmul.h
@@ -43,6 +43,7 @@ extern void __xla_cpu_runtime_OneDnnMatMulReorder(void* result, void** args);
 template <>
 struct PrimitiveTrait<kOnednnMatmulConfig> {
   using pointer_type = xla::cpu::OneDnnMatMulConfig*;
+  static const BackendConfigOneofCase kConfigVal = kOnednnMatmulConfig;
 };
 
 }  // namespace cpu

diff --git a/xla/service/cpu/onednn_util.h b/xla/service/cpu/onednn_util.h
index 09bc5efb4b757..9de597aef9435 100644
--- a/xla/service/cpu/onednn_util.h
+++ b/xla/service/cpu/onednn_util.h
@@ -66,7 +66,7 @@ typedef BackendConfig::BackendConfigOneofCase BackendConfigOneofCase;
 template <typename PrimDesc>
 std::unique_ptr<PrimDesc> CreateOneDnnPrimDesc(HloInstruction*);
 
-template <BackendConfigOneofCase config>
+template <BackendConfigOneofCase config, typename TransformationType = void>
 struct PrimitiveTrait;
 
 template <BackendConfigOneofCase config>

From 462890bb75f2fcea3fdc5966bfa7a2b8f94b255a Mon Sep 17 00:00:00 2001
From: Akhil Goel
Date: Tue, 5 Nov 2024 18:10:14 -0800
Subject: [PATCH 2/2] Address review comments

---
 .../cpu/onednn_contraction_rewriter.cc        | 24 +++++++++----------
 xla/service/cpu/onednn_contraction_rewriter.h | 15 ++++++------
 2 files changed, 18 insertions(+), 21 deletions(-)

diff --git a/xla/service/cpu/onednn_contraction_rewriter.cc b/xla/service/cpu/onednn_contraction_rewriter.cc
index e351c1d0ce656..29a07a4d062d9 100644
--- a/xla/service/cpu/onednn_contraction_rewriter.cc
+++ b/xla/service/cpu/onednn_contraction_rewriter.cc
@@ -196,11 +196,11 @@ std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
   }
 }
 
-ContractionVariant GetContractionVariant(
+auto GetOneDnnContractionVariant(
     absl::StatusOr<BackendConfig>* backend_config) {
   return ((*backend_config)->backend_config_oneof_case() == kOnednnConvConfig)
-             ? ContractionVariant(PrimitiveTrait<kOnednnConvConfig>{})
-             : ContractionVariant(PrimitiveTrait<kOnednnMatmulConfig>{});
+             ? OneDnnContractionVariant(PrimitiveTrait<kOnednnConvConfig>{})
+             : OneDnnContractionVariant(PrimitiveTrait<kOnednnMatmulConfig>{});
 }
 
 // Return the correct mutable config instance for the given contraction variant
 // based on the template parameter
@@ -215,17 +215,15 @@ TransformationType GetTransformationConfig(
             GetTransformationConfig(
                 GetKernelConfig<T::kConfigVal>(backend_config));
       },
-      GetContractionVariant(backend_config));
+      GetOneDnnContractionVariant(backend_config));
 }
 
-FusionsConfigPointer GetFusionsConfig(
-    absl::StatusOr<BackendConfig>* backend_config) {
-  return GetTransformationConfig<FusionsConfigPointer>(backend_config);
+auto GetFusionsConfig(absl::StatusOr<BackendConfig>* backend_config) {
+  return GetTransformationConfig<OneDnnFusionConfig*>(backend_config);
 }
 
-OptimizationsConfigPointer GetOptimizationsConfig(
-    absl::StatusOr<BackendConfig>* backend_config) {
-  return GetTransformationConfig<OptimizationsConfigPointer>(backend_config);
+auto GetOptimizationsConfig(absl::StatusOr<BackendConfig>* backend_config) {
+  return GetTransformationConfig<OneDnnOptimizationConfig*>(backend_config);
 }
 
 inline auto BcastConstScalarNear(double value) {
@@ -708,10 +706,10 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
       auto backend_config = custom_call->backend_config<BackendConfig>();
       auto fusions_config = GetFusionsConfig(&backend_config);
       auto optimization_config = GetOptimizationsConfig(&backend_config);
-      // TODO(intel-tf): Remove this restriction once oneDNN has an optimized
+      // TODO(intel-tf): Here, we allow 1D addends only when they are the first
+      // fused op. Remove this restriction once oneDNN has an optimized
       // implementation for broadcasted add across all dimensions.
-      OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
-      kind =
+      OneDnnFusionConfig_FusionKind kind =
           (addend->shape().rank() == 1)
               ? (fusions_config->ops().empty() ? OneDnnFusionConfig::BIAS
                                                : OneDnnFusionConfig::UNDEFINED)
               : OneDnnFusionConfig::BINARY_ADD;

diff --git a/xla/service/cpu/onednn_contraction_rewriter.h b/xla/service/cpu/onednn_contraction_rewriter.h
index f96a500c4ff81..3d641891eac44 100644
--- a/xla/service/cpu/onednn_contraction_rewriter.h
+++ b/xla/service/cpu/onednn_contraction_rewriter.h
@@ -63,22 +63,21 @@ class OneDnnContractionRewriter : public HloModulePass {
   const tsl::thread::ThreadPool* compile_threadpool_;
 };
 
-using ContractionVariant = std::variant<PrimitiveTrait<kOnednnConvConfig>,
-                                        PrimitiveTrait<kOnednnMatmulConfig>>;
-using FusionsConfigPointer = xla::cpu::OneDnnFusionConfig*;
-using OptimizationsConfigPointer = xla::cpu::OneDnnOptimizationConfig*;
+using OneDnnContractionVariant =
+    std::variant<PrimitiveTrait<kOnednnConvConfig>,
+                 PrimitiveTrait<kOnednnMatmulConfig>>;
 
 template <BackendConfigOneofCase config>
-struct PrimitiveTrait<config, FusionsConfigPointer> {
-  static FusionsConfigPointer GetTransformationConfig(
+struct PrimitiveTrait<config, OneDnnFusionConfig*> {
+  static OneDnnFusionConfig* GetTransformationConfig(
       typename PrimitiveTrait<config>::pointer_type kernel_config) {
     return kernel_config->mutable_fusions();
   }
 };
 
 template <BackendConfigOneofCase config>
-struct PrimitiveTrait<config, OptimizationsConfigPointer> {
-  static OptimizationsConfigPointer GetTransformationConfig(
+struct PrimitiveTrait<config, OneDnnOptimizationConfig*> {
+  static OneDnnOptimizationConfig* GetTransformationConfig(
      typename PrimitiveTrait<config>::pointer_type kernel_config) {
     return kernel_config->mutable_optimization_config();
   }
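
Note: the following standalone sketch (not part of the patch) illustrates the variant/trait dispatch idiom this series introduces. All names in it (Kind, Tag, NameTrait, TargetName) are simplified, hypothetical stand-ins for BackendConfigOneofCase, PrimitiveTrait, and the oneDNN config protos; the point is only how a runtime oneof case is turned back into a compile-time constant via std::variant and std::visit so that a trait specialization can be selected.

// Illustrative sketch only; compiles with -std=c++17.
#include <iostream>
#include <variant>

enum Kind { kMatmul, kConv };

// Empty tag type carrying the enum value as a compile-time constant,
// analogous to PrimitiveTrait<kOnednn...Config> with its kConfigVal member.
template <Kind k>
struct Tag {
  static constexpr Kind kVal = k;
};

// Trait specialized per kind; stands in for the PrimitiveTrait specializations.
template <Kind k>
struct NameTrait;
template <>
struct NameTrait<kMatmul> {
  static const char* Name() { return "matmul"; }
};
template <>
struct NameTrait<kConv> {
  static const char* Name() { return "convolution"; }
};

using KindVariant = std::variant<Tag<kMatmul>, Tag<kConv>>;

// Runtime enum -> variant of tags (mirrors GetOneDnnContractionVariant).
KindVariant MakeVariant(Kind k) {
  return (k == kConv) ? KindVariant(Tag<kConv>{}) : KindVariant(Tag<kMatmul>{});
}

// std::visit recovers the compile-time constant from the active alternative
// and uses it to pick a specialization (mirrors GetTransformationConfig).
const char* TargetName(Kind k) {
  return std::visit(
      [](auto tag) { return NameTrait<decltype(tag)::kVal>::Name(); },
      MakeVariant(k));
}

int main() {
  std::cout << TargetName(kConv) << "\n";  // prints "convolution"
  return 0;
}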