[XLA:CPU][oneDNN] Refactor code that fuses Add operation with oneDNN primitives #18616

Closed
268 changes: 141 additions & 127 deletions xla/service/cpu/onednn_contraction_rewriter.cc
@@ -200,6 +200,36 @@ std::optional<float> GetConstantValueAsFloat32(const HloInstruction* inst) {
}
}

auto GetOneDnnContractionVariant(
absl::StatusOr<BackendConfig>* backend_config) {
return ((*backend_config)->backend_config_oneof_case() == kOnednnConvConfig)
? OneDnnContractionVariant(PrimitiveTrait<kOnednnConvConfig>{})
: OneDnnContractionVariant(PrimitiveTrait<kOnednnMatmulConfig>{});
}

// Returns the correct mutable config instance for the given contraction
// variant based on the template parameter.
template <typename TransformationType>
TransformationType GetTransformationConfig(
absl::StatusOr<BackendConfig>* backend_config) {
return std::visit(
[&](auto&& config) -> TransformationType {
using T = std::decay_t<decltype(config)>;
return PrimitiveTrait<T::kConfigVal, TransformationType>::
GetTransformationConfig(
GetKernelConfig<T::kConfigVal>(backend_config));
},
GetOneDnnContractionVariant(backend_config));
}

auto GetFusionsConfig(absl::StatusOr<BackendConfig>* backend_config) {
return GetTransformationConfig<OneDnnFusionConfig*>(backend_config);
}

auto GetOptimizationsConfig(absl::StatusOr<BackendConfig>* backend_config) {
return GetTransformationConfig<OneDnnOptimizationConfig*>(backend_config);
}
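
For reference, a minimal usage sketch (not part of this diff) of how these two helpers are meant to be called later in the rewriter: fetch the backend config once, then mutate the fusion and optimization settings without switching on whether the contraction is a matmul or a convolution. The `contraction` instruction and the surrounding Status-returning function are assumed context.

// Illustrative only: mirrors the calls made further down in this file.
absl::StatusOr<BackendConfig> cfg = contraction->backend_config<BackendConfig>();
OneDnnFusionConfig* fusions = GetFusionsConfig(&cfg);
OneDnnOptimizationConfig* opts = GetOptimizationsConfig(&cfg);
fusions->add_ops(OneDnnFusionConfig::BIAS);  // record a fused bias add
opts->set_bias_broadcast(true);              // flag the broadcast optimization
TF_RETURN_IF_ERROR(contraction->set_backend_config(*cfg));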

inline auto BcastConstScalarNear(double value) {
return m::Broadcast(ConstScalarNear(value));
}
@@ -289,7 +319,7 @@ auto GELUActivation(HloInstruction* instr, HloInstruction** src) {
return OneDnnFusionConfig::UNDEFINED;
}

// OneDNN matmul / convolution can fuse add operation with automatic
// OneDNN matmul and convolution can fuse add operation with automatic
// broadcasting along the addend's dimensions that are 1s. When compatible,
// Broadcast can be replaced by Bitcast, which is much cheaper. Compute new
// shape for the Bitcast.
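
As a standalone illustration of that idea (this is not the AdjustBiasShape implementation from this file, and it uses plain vectors instead of XLA's Shape): padding the addend's dimensions with leading 1s up to the output rank reinterprets the same contiguous buffer, so the Broadcast can be swapped for a Bitcast whenever each padded dimension is either 1 or already matches the output.

#include <cstdint>
#include <optional>
#include <vector>

// Returns the bitcast-compatible shape for the addend, or nullopt if the
// broadcast actually replicates data and cannot be reduced to a bitcast.
std::optional<std::vector<int64_t>> BitcastableBiasShape(
    const std::vector<int64_t>& addend_dims,    // e.g. {768}
    const std::vector<int64_t>& output_dims) {  // e.g. {16, 128, 768}
  if (addend_dims.size() > output_dims.size()) return std::nullopt;
  // Pad with leading 1s up to the output rank; this moves no data.
  std::vector<int64_t> new_dims(output_dims.size() - addend_dims.size(), 1);
  new_dims.insert(new_dims.end(), addend_dims.begin(), addend_dims.end());
  for (size_t i = 0; i < new_dims.size(); ++i) {
    if (new_dims[i] != 1 && new_dims[i] != output_dims[i]) return std::nullopt;
  }
  return new_dims;  // {1, 1, 768}: oneDNN broadcasts over the size-1 dims.
}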
@@ -616,142 +646,126 @@ class OneDnnContractionRewriteVisitor : public DfsHloRewriteVisitor {
m::Op(&addend_intermediate));

if (Match(instr, pattern)) {
HANDLE_OP_INTERNAL(HandleAddInternal, contraction, instr,
addend_intermediate, optional_contraction_convert,
optional_contraction_bitcast);
}

return absl::OkStatus();
}

template <typename BackendConfig::BackendConfigOneofCase config>
absl::Status HandleAddInternal(HloInstruction* contraction,
HloInstruction* instr,
HloInstruction* addend_intermediate,
HloInstruction* optional_contraction_convert,
HloInstruction* optional_contraction_bitcast) {
if (!IsSupportedType(contraction->shape().element_type()))
return absl::OkStatus();
// TODO(intel-tf): Remove the condition below when the fusion Contraction +
// Add(bias) + Add(e.g., residual) is enabled.
auto contraction_config = contraction->backend_config<BackendConfig>();
if (!GetKernelConfig<config>(&contraction_config)
->mutable_fusions()
->ops()
.empty() &&
GetKernelConfig<config>(&contraction_config)
->mutable_fusions()
->ops(0) == OneDnnFusionConfig::BIAS) {
return absl::OkStatus();
}
std::vector<HloInstruction*> new_operands;
for (auto operand : contraction->operands()) {
new_operands.push_back(operand);
}
if (!IsSupportedType(contraction->shape().element_type()))
return absl::OkStatus();
// TODO(intel-tf): Remove the condition below when the fusion Contraction
// + Add(bias) + Add(e.g., residual) is enabled.
auto contraction_config = contraction->backend_config<BackendConfig>();
auto orig_fusion_config = GetFusionsConfig(&contraction_config);
if (!orig_fusion_config->ops().empty() &&
orig_fusion_config->ops(0) == OneDnnFusionConfig::BIAS) {
return absl::OkStatus();
}
std::vector<HloInstruction*> new_operands;
for (auto operand : contraction->operands()) {
new_operands.push_back(operand);
}

// At this point, the addend could have one of the following
// possibilities that the current fusion can handle:
//
// - addend -> Convert -> Broadcast -> Add
// - addend -> Broadcast -> Convert -> Add
// - addend -> Convert
// - addend -> Broadcast
// - addend
//
// Hunt for addend through possible sequences above and check the addend
// is compatible for onednn fusion.
HloInstruction* addend = nullptr;
HloInstruction* optional_addend_broadcast = nullptr;
auto addend_pattern = m::AnyOf<HloInstruction>(
m::Broadcast(&optional_addend_broadcast, m::Convert(&addend, m::Op())),
m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
m::Convert(&addend, m::Op()),
m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
m::Op(&addend));
if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();

if (optional_addend_broadcast &&
// At this point, the addend could have one of the following
// possibilities that the current fusion can handle:
//
// - addend -> Convert -> Broadcast -> Add
// - addend -> Broadcast -> Convert -> Add
// - addend -> Convert
// - addend -> Broadcast
// - addend
//
// Hunt for addend through possible sequences above and check the addend
// is compatible for onednn fusion.
HloInstruction* addend = nullptr;
HloInstruction* optional_addend_broadcast = nullptr;
auto addend_pattern = m::AnyOf<HloInstruction>(
m::Broadcast(&optional_addend_broadcast,
m::Convert(&addend, m::Op())),
m::Convert(m::Broadcast(&optional_addend_broadcast, m::Op(&addend))),
m::Convert(&addend, m::Op()),
m::Broadcast(&optional_addend_broadcast, m::Op(&addend)),
m::Op(&addend));
if (!Match(addend_intermediate, addend_pattern)) return absl::OkStatus();

if (optional_addend_broadcast &&
(IsOneDnnMatmulInstr(contraction) || addend->shape().rank() != 1)) {
auto new_shape =
AdjustBiasShape(optional_addend_broadcast, contraction->shape());
if (new_shape.ok()) {
addend = addend->AddInstruction(
HloInstruction::CreateBitcast(new_shape.value(), addend));
auto new_shape =
AdjustBiasShape(optional_addend_broadcast, contraction->shape());
if (new_shape.ok()) {
addend = addend->AddInstruction(
HloInstruction::CreateBitcast(new_shape.value(), addend));
} else {
VLOG(2) << new_shape.status();
return absl::OkStatus();
}
}

// Validate addend for fusion.
if (IsSupportedType(addend->shape().element_type()) &&
IsOperandFusible(addend, contraction)) {
new_operands.push_back(addend);
} else {
VLOG(2) << new_shape.status();
return absl::OkStatus();
}
}

// Validate addend for fusion.
if (IsSupportedType(addend->shape().element_type()) &&
IsOperandFusible(addend, contraction)) {
new_operands.push_back(addend);
} else {
return absl::OkStatus();
}

auto custom_call = Cast<HloCustomCallInstruction>(instr->AddInstruction(
contraction->CloneWithNewOperands(contraction->shape(), new_operands)));

auto backend_config = custom_call->backend_config<BackendConfig>();

// TODO(intel-tf): Remove this restriction once oneDNN has an optimized
// implementation for broadcasted add across all dimensions.
OneDnnFusionConfig_FusionKind kind = OneDnnFusionConfig::UNDEFINED;
kind =
(ShapeUtil::TrueRank(addend->shape()) == 1)
? (GetKernelConfig<config>(&backend_config)->fusions().ops().empty()
? OneDnnFusionConfig::BIAS
: OneDnnFusionConfig::UNDEFINED)
: OneDnnFusionConfig::BINARY_ADD;
if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();

GetKernelConfig<config>(&backend_config)->mutable_fusions()->add_ops(kind);

if (optional_addend_broadcast) {
GetKernelConfig<config>(&backend_config)
->mutable_optimization_config()
->set_bias_broadcast(true);
}
TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config));

HloInstruction* new_instr;
// If matched pattern has custom-call -> bitcast -> add, then we need to
// insert bitcast after the new fusion to maintain the correct shape
// (new-custom-call -> bitcast). Also, this will optionally be followed
// by -> convert for bf16 case to avoid datatype mismatch.
if (optional_contraction_bitcast != nullptr &&
optional_contraction_bitcast->opcode() == HloOpcode::kBitcast) {
if (optional_contraction_convert != nullptr &&
optional_contraction_convert->opcode() == HloOpcode::kConvert) {
auto bitcast_call =
custom_call->AddInstruction(HloInstruction::CreateBitcast(
ShapeUtil::ChangeElementType(
instr->shape(), custom_call->shape().element_type()),
custom_call));
new_instr = bitcast_call->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(
bitcast_call->shape(),
optional_contraction_convert->shape().element_type()),
bitcast_call));
} else {
new_instr = custom_call->AddInstruction(
HloInstruction::CreateBitcast(instr->shape(), custom_call));
auto custom_call = Cast<HloCustomCallInstruction>(
instr->AddInstruction(contraction->CloneWithNewOperands(
contraction->shape(), new_operands)));

auto backend_config = custom_call->backend_config<BackendConfig>();
auto fusions_config = GetFusionsConfig(&backend_config);
auto optimization_config = GetOptimizationsConfig(&backend_config);
// TODO(intel-tf): Here, we allow 1D addends only when they are the first
// fused op. Remove this restriction once oneDNN has an optimized
// implementation for broadcasted add across all dimensions.
OneDnnFusionConfig_FusionKind kind =
(ShapeUtil::TrueRank(addend->shape()) == 1)
? (fusions_config->ops().empty() ? OneDnnFusionConfig::BIAS
: OneDnnFusionConfig::UNDEFINED)
Review comment (Member):
Does this mean the code doesn't support having a 1D addend (bias) and fused ops at the same time? Please document this as a comment in the code.

Reply (Contributor, Author):
This should ensure that 1D addends are fused only when they are the first post-op operation. So the following should be acceptable:

  1. Bias (1D) + <0 or more non-add post-ops>
  2. Bias (1D) + <0 or more non-add post-ops> + Add (non-1D) + <0 or more non-add post-ops>

The following is not allowed:

  1. Bias (1D) + <0 or more non-add post-ops> + Add (1D) + <0 or more non-add post-ops>

This condition was added a few months ago because, at the time, oneDNN had optimized implementations for broadcasted add operations across certain dimensions only. As a result, some cases defaulted to the reference implementation, which significantly impacted performance. We can re-evaluate this restriction with the latest oneDNN release, and/or relax it a bit by blocking only those 1D cases where broadcasting occurs along specific dimensions.

: OneDnnFusionConfig::BINARY_ADD;
if (kind == OneDnnFusionConfig::UNDEFINED) return absl::OkStatus();

fusions_config->add_ops(kind);

if (optional_addend_broadcast) {
optimization_config->set_bias_broadcast(true);
}
} else {
if (optional_contraction_convert != nullptr &&
optional_contraction_convert->opcode() == HloOpcode::kConvert) {
new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(
custom_call->shape(),
optional_contraction_convert->shape().element_type()),
custom_call));
TF_RETURN_IF_ERROR(custom_call->set_backend_config(*backend_config));

HloInstruction* new_instr;
// If matched pattern has custom-call -> bitcast -> add, then we need to
// insert bitcast after the new fusion to maintain the correct shape
// (new-custom-call -> bitcast). Also, this will optionally be followed
// by -> convert for bf16 case to avoid datatype mismatch.
if (optional_contraction_bitcast != nullptr &&
optional_contraction_bitcast->opcode() == HloOpcode::kBitcast) {
if (optional_contraction_convert != nullptr &&
optional_contraction_convert->opcode() == HloOpcode::kConvert) {
auto bitcast_call =
custom_call->AddInstruction(HloInstruction::CreateBitcast(
ShapeUtil::ChangeElementType(
instr->shape(), custom_call->shape().element_type()),
custom_call));
new_instr =
bitcast_call->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(
bitcast_call->shape(),
optional_contraction_convert->shape().element_type()),
bitcast_call));
} else {
new_instr = custom_call->AddInstruction(
HloInstruction::CreateBitcast(instr->shape(), custom_call));
}
} else {
new_instr = custom_call;
if (optional_contraction_convert != nullptr &&
optional_contraction_convert->opcode() == HloOpcode::kConvert) {
new_instr = custom_call->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(
custom_call->shape(),
optional_contraction_convert->shape().element_type()),
custom_call));
} else {
new_instr = custom_call;
}
}
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
}
TF_RETURN_IF_ERROR(ReplaceInstruction(instr, new_instr));
return absl::OkStatus();
}

31 changes: 20 additions & 11 deletions xla/service/cpu/onednn_contraction_rewriter.h
@@ -24,6 +24,8 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/service/cpu/onednn_matmul.h"
#include "xla/service/cpu/onednn_convolution.h"
#include "tsl/platform/threadpool.h"

namespace xla {
@@ -61,18 +63,25 @@ class OneDnnContractionRewriter : public HloModulePass {
const tsl::thread::ThreadPool* compile_threadpool_;
};

#define HANDLE_OP_INTERNAL(internal_callee, contraction, ...) \
switch (contraction->backend_config<BackendConfig>() \
->backend_config_oneof_case()) { \
case BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig: \
return internal_callee< \
BackendConfig::BackendConfigOneofCase::kOnednnMatmulConfig>( \
contraction, __VA_ARGS__); \
default: \
return internal_callee< \
BackendConfig::BackendConfigOneofCase::kOnednnConvConfig>( \
contraction, __VA_ARGS__); \
using OneDnnContractionVariant =
std::variant<PrimitiveTrait<kOnednnConvConfig>,
PrimitiveTrait<kOnednnMatmulConfig>>;

template <BackendConfigOneofCase config>
struct PrimitiveTrait<config, OneDnnFusionConfig*> {
static OneDnnFusionConfig* GetTransformationConfig(
typename PrimitiveTrait<config>::pointer_type kernel_config) {
return kernel_config->mutable_fusions();
}
};

template <BackendConfigOneofCase config>
struct PrimitiveTrait<config, OneDnnOptimizationConfig*> {
static OneDnnOptimizationConfig* GetTransformationConfig(
typename PrimitiveTrait<config>::pointer_type kernel_config) {
return kernel_config->mutable_optimization_config();
}
};

} // namespace cpu
} // namespace xla
1 change: 1 addition & 0 deletions xla/service/cpu/onednn_convolution.h
@@ -31,6 +31,7 @@ extern void __xla_cpu_runtime_OneDnnConvolution(void* result, void** args);
template <>
struct PrimitiveTrait<kOnednnConvConfig> {
using pointer_type = xla::cpu::OneDnnConvolutionConfig*;
static const BackendConfigOneofCase kConfigVal = kOnednnConvConfig;
};

} // namespace cpu
1 change: 1 addition & 0 deletions xla/service/cpu/onednn_matmul.h
@@ -43,6 +43,7 @@ extern void __xla_cpu_runtime_OneDnnMatMulReorder(void* result, void** args);
template <>
struct PrimitiveTrait<kOnednnMatmulConfig> {
using pointer_type = xla::cpu::OneDnnMatMulConfig*;
static const BackendConfigOneofCase kConfigVal = kOnednnMatmulConfig;
};

} // namespace cpu
2 changes: 1 addition & 1 deletion xla/service/cpu/onednn_util.h
@@ -72,7 +72,7 @@ typedef BackendConfig::BackendConfigOneofCase BackendConfigOneofCase;
template <typename PrimDesc>
std::unique_ptr<PrimDesc> CreateOneDnnPrimDesc(HloInstruction*);

template <BackendConfigOneofCase config>
template <BackendConfigOneofCase config, typename TransformationType = void>
struct PrimitiveTrait;

template <BackendConfigOneofCase config>
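
To make the header change easier to read in isolation, here is a self-contained toy (simplified names, not the actual XLA/oneDNN types) of the pattern that replaces the HANDLE_OP_INTERNAL macro: tag types carry the compile-time config case, a partial specialization maps a (config case, transformation type) pair to the right accessor, and std::visit over a variant of the tag types performs the dispatch that the macro's switch used to do.

#include <iostream>
#include <variant>

enum ConfigCase { kMatmulCase, kConvCase };
struct MatmulConfig { int fused_ops = 0; };
struct ConvConfig { int fused_ops = 0; };

// Primary trait with a defaulted second parameter -- analogous to PrimitiveTrait.
template <ConfigCase c, typename TransformationType = void>
struct Trait;

template <>
struct Trait<kMatmulCase> {
  using pointer_type = MatmulConfig*;
  static constexpr ConfigCase kConfigVal = kMatmulCase;
};
template <>
struct Trait<kConvCase> {
  using pointer_type = ConvConfig*;
  static constexpr ConfigCase kConfigVal = kConvCase;
};

// Partial specialization: how to pull one "transformation" field (here an
// int*) out of either config type -- analogous to GetTransformationConfig.
template <ConfigCase c>
struct Trait<c, int*> {
  static int* Get(typename Trait<c>::pointer_type cfg) {
    return &cfg->fused_ops;
  }
};

using ContractionVariant = std::variant<Trait<kMatmulCase>, Trait<kConvCase>>;

// std::visit replaces the old macro switch over backend_config_oneof_case().
int* GetFusions(ContractionVariant v, MatmulConfig* mm, ConvConfig* cv) {
  return std::visit(
      [&](auto tag) -> int* {
        using T = decltype(tag);
        if constexpr (T::kConfigVal == kMatmulCase) {
          return Trait<T::kConfigVal, int*>::Get(mm);
        } else {
          return Trait<T::kConfigVal, int*>::Get(cv);
        }
      },
      v);
}

int main() {
  MatmulConfig mm;
  ConvConfig cv;
  *GetFusions(Trait<kMatmulCase>{}, &mm, &cv) = 3;  // routed to the matmul config
  std::cout << mm.fused_ops << "\n";                // prints 3
}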