From 342e3c1f84a9708fd22f3ca19dba7a282d2ade68 Mon Sep 17 00:00:00 2001 From: Shraiysh Date: Wed, 13 Nov 2024 02:03:43 -0800 Subject: [PATCH] PR #19160: Add support for fusion operations in GetLoopInductionVariableIndex Imported from GitHub PR https://github.com/openxla/xla/pull/19160 After optimizations, while loop analysis was unable to handle the fused operations, primarily for GetLoopInductionVariableIndex and ComputeWhileLoopTripCount. This patch handles only GetloopInductionVariableIdx. If the GetGTEOperandIndex function fails to find a precise GTE operand, we try to deduce the unique dependence by using the HloExtractor. Copybara import of the project: -- b87de101421d7a4b98050379e931507d37bb90ab by Shraiysh Vaishay : Add support for fusion operations in GetLoopInductionVariableIndex After optimizations, while loop analysis was unable to handle the fused operations, primarily for GetLoopInductionVariableIndex and ComputeWhileLoopTripCount. This patch handles only GetloopInductionVariableIdx. If the GetGTEOperandIndex function fails to find a precise GTE operand, we try to deduce the unique dependence by using the HloExtractor. Merging this change closes #19160 COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/19160 from shraiysh:while_loop_enhancements b87de101421d7a4b98050379e931507d37bb90ab PiperOrigin-RevId: 696040922 --- xla/hlo/analysis/BUILD | 2 + xla/hlo/analysis/while_loop_analysis.cc | 106 +++++++++-- xla/hlo/analysis/while_loop_analysis_test.cc | 189 +++++++++++++++++++ 3 files changed, 285 insertions(+), 12 deletions(-) diff --git a/xla/hlo/analysis/BUILD b/xla/hlo/analysis/BUILD index 466ae6505f302..396157e85d454 100644 --- a/xla/hlo/analysis/BUILD +++ b/xla/hlo/analysis/BUILD @@ -145,6 +145,8 @@ cc_library( "//xla/hlo/utils:hlo_query", "//xla/service:collective_ops_utils", "//xla/service:pattern_matcher", + "//xla/tools:hlo_extractor", + "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", diff --git a/xla/hlo/analysis/while_loop_analysis.cc b/xla/hlo/analysis/while_loop_analysis.cc index 8a87dd27cc6e7..fd554a5f31178 100644 --- a/xla/hlo/analysis/while_loop_analysis.cc +++ b/xla/hlo/analysis/while_loop_analysis.cc @@ -18,7 +18,10 @@ limitations under the License. #include #include #include +#include +#include +#include "absl/algorithm/container.h" #include "absl/base/casts.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" @@ -35,6 +38,7 @@ limitations under the License. #include "xla/service/collective_ops_utils.h" #include "xla/service/pattern_matcher.h" #include "xla/shape_util.h" +#include "xla/tools/hlo_extractor.h" namespace xla { @@ -42,10 +46,11 @@ using std::nullopt; using std::optional; namespace m = match; -// Finds and returns the non-constant operand in instr. +// Finds and returns the non-constant operand in instr, if there is only one +// such operand. // -// If the instruction doesn't have exactly one unique non-constant operand, -// nullptr is returned. +// Returns nullptr if instr doesn't have exactly one unique non-constant +// operand. static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { const HloInstruction* result = nullptr; for (const HloInstruction* operand : instr->operands()) { @@ -73,25 +78,24 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, // copy that is then used. optional tuple_idx; for (const HloInstruction* operand : instr->operands()) { - if (Match(operand, m::Constant())) { + if (operand->opcode() == HloOpcode::kConstant) { continue; } - auto possibly_gte_operand = operand; + auto possibly_gte = operand; if (operand->opcode() == HloOpcode::kCopy) { - possibly_gte_operand = operand->operand(0); + possibly_gte = operand->operand(0); } - if (possibly_gte_operand->opcode() != HloOpcode::kGetTupleElement) { + if (possibly_gte->opcode() != HloOpcode::kGetTupleElement) { return nullopt; } - if (!Match(possibly_gte_operand, - m::GetTupleElement(m::Op().Is(gte_operand)))) { + if (possibly_gte->operand(0) != gte_operand) { return nullopt; } - int64_t operand_tuple_idx = possibly_gte_operand->tuple_index(); + int64_t operand_tuple_idx = possibly_gte->tuple_index(); // This is the first GTE we are seeing. Set tuple_idx. if (!tuple_idx.has_value()) { tuple_idx = operand_tuple_idx; @@ -104,6 +108,84 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, return tuple_idx; } +// If `out` is a function of a single value in the tuple `in` and has no other +// dependence, i.e. if `out=f(gte(in))`, then this function will return the +// unique get-tuple-element index for the dependence. +// +// For example, in the following HLO, this function will return `1`: +// in = (s32[], s32[], s32[]) tuple(a,b,c) +// gte.1 = get-tuple-element(in), index=1 +// out = fusion(gte.1), ... +std::optional GetUniqueGTEDependenceIndex(const HloInstruction* out, + const HloInstruction* in) { + // Fast path : pattern matching. + std::optional tuple_idx = GetGTEOperandIndex(out, in); + if (tuple_idx != std::nullopt) { + return tuple_idx; + } + + if (out->parent() != in->parent() || !in->shape().IsTuple()) { + return std::nullopt; + } + + // Extracts the instruction `out` as a function of the instruction `in`. + // HloModule extracted + // ENTRY main { + // in = parameter(0) + // //... some calculations + // ROOT out = ... + // } + std::unique_ptr extracted = ExtractModule( + /*instruction=*/out, /*height=*/-1, /*extract_selector=*/ + [in](const HloInstruction* inst) -> bool { return inst != in; }, + /*replace_type_selector=*/ + [](const HloInstruction* inst) -> ReplaceType { + return ReplaceType::kReplaceParam; + }); + HloComputation* entry = extracted->entry_computation(); + + // Check that the extracted module takes nothing but `in` as input. If `out` + // does not depend on in, the extracted module will have some other shape for + // input. + if (entry->num_parameters() != 1 || + entry->parameter_instruction(0)->shape() != in->shape()) { + return std::nullopt; + } + HloInstruction* param = entry->parameter_instruction(0); + + // If there are no users for the input `in`, it would mean that `out` does not + // depend on a get-tuple-element of `in`. + if (param->user_count() == 0) { + return nullopt; + } + + // If any of the users of the input `in` is not a get-tuple-element + // instruction, then that would mean that the output does not depend uniquely + // on a get-tuple-element of on `in`, instead it depends on some other + // calculations on `in`. + if (absl::c_any_of(param->users(), [](const HloInstruction* inst) -> bool { + return inst->opcode() != HloOpcode::kGetTupleElement; + })) { + return std::nullopt; + } + + // We extract the candidate index from the first user. At this point we + // already know that the all the users are get-tuple-elements and that there + // is atleast one user. + int64_t candidate_index = param->users()[0]->tuple_index(); + + // We check that all the users of the input instruction `in` (which we already + // know to be get-tuple-element instructions) have the same tuple index. + if (absl::c_any_of(param->users(), + [candidate_index](const HloInstruction* inst) -> bool { + return inst->tuple_index() != candidate_index; + })) { + return std::nullopt; + } + + return candidate_index; +} + // The below function identifies a subset of all possible auxiliary // induction variables (AIV). Specifically, candidates are gtes, e.g., // gte(param0, N) @@ -277,7 +359,7 @@ optional GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { auto* while_cond_root = while_cond->root_instruction(); auto* while_cond_param = while_cond->parameter_instruction(0); optional indvar_tuple_idx = - GetGTEOperandIndex(while_cond_root, while_cond_param); + GetUniqueGTEDependenceIndex(while_cond_root, while_cond_param); if (!indvar_tuple_idx) { VLOG(2) << "Induction variable not found in loop condition: " << while_cond->root_instruction()->ToString(); @@ -303,7 +385,7 @@ optional GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { while_body_inc = while_body_root->operand(*indvar_tuple_idx); auto* while_body_param = while_body->parameter_instruction(0); optional while_body_indvar_tuple_idx = - GetGTEOperandIndex(while_body_inc, while_body_param); + GetUniqueGTEDependenceIndex(while_body_inc, while_body_param); if (!while_body_indvar_tuple_idx) { VLOG(2) << "Induction variable not found in while body increment instruction: " diff --git a/xla/hlo/analysis/while_loop_analysis_test.cc b/xla/hlo/analysis/while_loop_analysis_test.cc index 3757f44d092b3..607ee64a58fa9 100644 --- a/xla/hlo/analysis/while_loop_analysis_test.cc +++ b/xla/hlo/analysis/while_loop_analysis_test.cc @@ -629,5 +629,194 @@ TEST_F(WhileLoopAnalysisTest, AvoidBruteForceForHugeParams) { EXPECT_EQ(trip_count, std::nullopt); } +TEST_F(WhileLoopAnalysisTest, LoopFusionForLoopVariable) { + // This test verifies that fusions in initialization, condition and update are + // accepted by while loop analysis. + const char* hlo = R"( + HloModule test + fused_add.11 { + param_0.968 = s32[] parameter(0) + constant_1239_1 = s32[] constant(1) + ROOT add.1041.1 = s32[] add(param_0.968, constant_1239_1) + } + fused_add.11.clone.2 { + param_0.2169 = s32[] parameter(0) + constant_1239_4 = s32[] constant(1) + ROOT add.1041.4 = s32[] add(param_0.2169, constant_1239_4) + } + body { + param.1 = (s32[], s32[]) parameter(0) + loop_iter = s32[] get-tuple-element(param.1), index=0 + data = s32[] get-tuple-element(param.1), index=1 + loop_add_fusion.11 = s32[] fusion(loop_iter), kind=kLoop, calls=fused_add.11 + loop_add_fusion.11.double_buffer_clone = s32[] fusion(loop_add_fusion.11), kind=kLoop, calls=fused_add.11.clone.2 + ROOT tuple = (s32[], s32[]) tuple(loop_add_fusion.11.double_buffer_clone, data) + } + fused_compare { + param_0.987 = s32[] parameter(0) + constant_1238_1 = s32[] constant(7) + ROOT compare.98.1 = pred[] compare(param_0.987, constant_1238_1), direction=LT + } + condition { + param.2 = (s32[], s32[]) parameter(0) + loop_iter = s32[] get-tuple-element(param.2), index=0 + ROOT loop_compare_fusion = pred[] fusion(loop_iter), kind=kLoop, calls=fused_compare + } + fused_add.12 { + param_0.968 = s32[] parameter(0) + constant_1239_1 = s32[] constant(1) + ROOT add.1041.1 = s32[] add(param_0.968, constant_1239_1) + } + ENTRY main { + data = s32[] parameter(0) + c.0 = s32[] constant(0) + c.1 = s32[] constant(1) + add.1 = s32[] add(c.0, c.1) + c.0.loop_double_buffer_peeled = s32[] fusion(add.1), kind=kLoop, calls=fused_add.12 + tuple = (s32[], s32[]) tuple(c.0.loop_double_buffer_peeled, data) + ROOT while = while(tuple), body=body, condition=condition + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloInstruction* while_op = module->entry_computation()->root_instruction(); + auto loop_induction_variable = GetLoopInductionVarTupleIdx(while_op); + ASSERT_TRUE(loop_induction_variable.has_value()); + EXPECT_EQ(loop_induction_variable.value(), 0); +} + +TEST_F(WhileLoopAnalysisTest, UpdateIsMultipleOperationsWithConstantOperand) { + const char* hlo = R"( + HloModule test + body { + param.1 = (s32[], s32[8,8]) parameter(0) + iter.1 = s32[] get-tuple-element(param.1), index=0 + c.1 = s32[] constant(1) + add.1 = s32[] add(iter.1, c.1) + add.2 = s32[] add(add.1, c.1) + data.1 = s32[8,8] get-tuple-element(param.1), index=1 + ROOT tuple = (s32[], s32[8,8]) tuple(add.2, data.1) + } + condition { + param = (s32[], s32[8,8]) parameter(0) + iter = s32[] get-tuple-element(param), index=0 + c.10 = s32[] constant(10) + ROOT compare = pred[] compare(iter, c.10), direction=LT + } + ENTRY main { + c.0 = s32[] constant(0) + data = s32[8,8] parameter(0) + tuple = tuple(c.0, data) + ROOT while = while(tuple), body=body, condition=condition + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::optional indvar_idx = GetLoopInductionVarTupleIdx(while_op); + ASSERT_NE(indvar_idx, std::nullopt); + EXPECT_EQ(*indvar_idx, 0); + std::optional trip_count = ComputeWhileLoopTripCount(while_op); + EXPECT_EQ(trip_count, std::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, + UpdateIsMultipleOperationsWithoutConstantOperand) { + const char* hlo = R"( + HloModule test + body { + param.1 = (s32[], s32[8,8]) parameter(0) + iter.1 = s32[] get-tuple-element(param.1), index=0 + c.1 = s32[] constant(1) + add.1 = s32[] add(c.1, c.1) + add.2 = s32[] add(iter.1, add.1) + data.1 = s32[8,8] get-tuple-element(param.1), index=1 + ROOT tuple = (s32[], s32[8,8]) tuple(add.2, data.1) + } + condition { + param = (s32[], s32[8,8]) parameter(0) + iter = s32[] get-tuple-element(param), index=0 + c.10 = s32[] constant(10) + ROOT compare = pred[] compare(iter, c.10), direction=LT + } + ENTRY main { + c.0 = s32[] constant(0) + data = s32[8,8] parameter(0) + tuple = tuple(c.0, data) + ROOT while = while(tuple), body=body, condition=condition + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::optional indvar_idx = GetLoopInductionVarTupleIdx(while_op); + ASSERT_NE(indvar_idx, std::nullopt); + EXPECT_EQ(*indvar_idx, 0); + std::optional trip_count = ComputeWhileLoopTripCount(while_op); + EXPECT_EQ(trip_count, std::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, + ConditionIsMultipleOperationsWithConstantOperand) { + const char* hlo = R"( + HloModule test + body { + param.1 = (s32[], s32[8,8]) parameter(0) + iter.1 = s32[] get-tuple-element(param.1), index=0 + c.1 = s32[] constant(1) + add.1 = s32[] add(iter.1, c.1) + data.1 = s32[8,8] get-tuple-element(param.1), index=1 + ROOT tuple = (s32[], s32[8,8]) tuple(add.1, data.1) + } + condition { + param = (s32[], s32[8,8]) parameter(0) + iter = s32[] get-tuple-element(param), index=0 + c.10 = s32[] constant(10) + add.10 = s32[] add(iter, c.10) + ROOT compare = pred[] compare(add.10, c.10), direction=LT + } + ENTRY main { + c.0 = s32[] constant(0) + data = s32[8,8] parameter(0) + tuple = tuple(c.0, data) + ROOT while = while(tuple), body=body, condition=condition + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::optional indvar_idx = GetLoopInductionVarTupleIdx(while_op); + ASSERT_NE(indvar_idx, std::nullopt); + EXPECT_EQ(*indvar_idx, 0); + std::optional trip_count = ComputeWhileLoopTripCount(while_op); + EXPECT_EQ(trip_count, std::nullopt); +} + +TEST_F(WhileLoopAnalysisTest, + ConditionIsMultipleOperationsWithoutConstantOperand) { + const char* hlo = R"( + HloModule test + body { + param.1 = (s32[], s32[8,8]) parameter(0) + iter.1 = s32[] get-tuple-element(param.1), index=0 + c.1 = s32[] constant(1) + add.1 = s32[] add(iter.1, c.1) + data.1 = s32[8,8] get-tuple-element(param.1), index=1 + ROOT tuple = (s32[], s32[8,8]) tuple(add.1, data.1) + } + condition { + param = (s32[], s32[8,8]) parameter(0) + iter = s32[] get-tuple-element(param), index=0 + c.5 = s32[] constant(5) + add.10 = s32[] add(c.5, c.5) + ROOT compare = pred[] compare(iter, add.10), direction=LT + } + ENTRY main { + c.0 = s32[] constant(0) + data = s32[8,8] parameter(0) + tuple = tuple(c.0, data) + ROOT while = while(tuple), body=body, condition=condition + })"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo)); + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::optional indvar_idx = GetLoopInductionVarTupleIdx(while_op); + ASSERT_NE(indvar_idx, std::nullopt); + EXPECT_EQ(*indvar_idx, 0); + std::optional trip_count = ComputeWhileLoopTripCount(while_op); + EXPECT_EQ(trip_count, std::nullopt); +} + } // namespace } // namespace xla