Skip to content

Commit

Permalink
PR #19160: Add support for fusion operations in GetLoopInductionVaria…
Browse files Browse the repository at this point in the history
…bleIndex

Imported from GitHub PR #19160

After optimizations, while loop analysis was unable to handle the fused operations, primarily for GetLoopInductionVariableIndex and ComputeWhileLoopTripCount. This patch handles only GetLoopInductionVariableIndex.

If the GetGTEOperandIndex function fails to find a precise GTE operand, we try to deduce the unique dependence by using the HloExtractor.
Copybara import of the project:

--
b87de10 by Shraiysh Vaishay <[email protected]>:

Add support for fusion operations in GetLoopInductionVariableIndex

After optimizations, while loop analysis was unable to handle the fused
operations, primarily for GetLoopInductionVariableIndex and
ComputeWhileLoopTripCount. This patch handles only
GetLoopInductionVariableIndex.

If the GetGTEOperandIndex function fails to find a precise GTE operand,
we try to deduce the unique dependence by using the HloExtractor.

Merging this change closes #19160

COPYBARA_INTEGRATE_REVIEW=#19160 from shraiysh:while_loop_enhancements b87de10
PiperOrigin-RevId: 696040922
  • Loading branch information
shraiysh authored and Google-ML-Automation committed Nov 13, 2024
1 parent a610534 commit 342e3c1
Show file tree
Hide file tree
Showing 3 changed files with 285 additions and 12 deletions.
2 changes: 2 additions & 0 deletions xla/hlo/analysis/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ cc_library(
"//xla/hlo/utils:hlo_query",
"//xla/service:collective_ops_utils",
"//xla/service:pattern_matcher",
"//xla/tools:hlo_extractor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
Expand Down
106 changes: 94 additions & 12 deletions xla/hlo/analysis/while_loop_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ limitations under the License.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <optional>

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
Expand All @@ -35,17 +38,19 @@ limitations under the License.
#include "xla/service/collective_ops_utils.h"
#include "xla/service/pattern_matcher.h"
#include "xla/shape_util.h"
#include "xla/tools/hlo_extractor.h"

namespace xla {

using std::nullopt;
using std::optional;
namespace m = match;

// Finds and returns the non-constant operand in instr.
// Finds and returns the non-constant operand in instr, if there is only one
// such operand.
//
// If the instruction doesn't have exactly one unique non-constant operand,
// nullptr is returned.
// Returns nullptr if instr doesn't have exactly one unique non-constant
// operand.
static const HloInstruction* NonConstantOperand(const HloInstruction* instr) {
const HloInstruction* result = nullptr;
for (const HloInstruction* operand : instr->operands()) {
Expand Down Expand Up @@ -73,25 +78,24 @@ static optional<int64_t> GetGTEOperandIndex(const HloInstruction* instr,
// copy that is then used.
optional<int64_t> tuple_idx;
for (const HloInstruction* operand : instr->operands()) {
if (Match(operand, m::Constant())) {
if (operand->opcode() == HloOpcode::kConstant) {
continue;
}
auto possibly_gte_operand = operand;
auto possibly_gte = operand;

if (operand->opcode() == HloOpcode::kCopy) {
possibly_gte_operand = operand->operand(0);
possibly_gte = operand->operand(0);
}

if (possibly_gte_operand->opcode() != HloOpcode::kGetTupleElement) {
if (possibly_gte->opcode() != HloOpcode::kGetTupleElement) {
return nullopt;
}

if (!Match(possibly_gte_operand,
m::GetTupleElement(m::Op().Is(gte_operand)))) {
if (possibly_gte->operand(0) != gte_operand) {
return nullopt;
}

int64_t operand_tuple_idx = possibly_gte_operand->tuple_index();
int64_t operand_tuple_idx = possibly_gte->tuple_index();
// This is the first GTE we are seeing. Set tuple_idx.
if (!tuple_idx.has_value()) {
tuple_idx = operand_tuple_idx;
Expand All @@ -104,6 +108,84 @@ static optional<int64_t> GetGTEOperandIndex(const HloInstruction* instr,
return tuple_idx;
}

// If `out` is a function of a single value in the tuple `in` and has no other
// dependence, i.e. if `out=f(gte(in))`, then this function will return the
// unique get-tuple-element index for the dependence.
//
// For example, in the following HLO, this function will return `1`:
// in = (s32[], s32[], s32[]) tuple(a,b,c)
// gte.1 = get-tuple-element(in), index=1
// out = fusion(gte.1), ...
std::optional<int64_t> GetUniqueGTEDependenceIndex(const HloInstruction* out,
const HloInstruction* in) {
// Fast path : pattern matching.
std::optional<int64_t> tuple_idx = GetGTEOperandIndex(out, in);
if (tuple_idx != std::nullopt) {
return tuple_idx;
}

if (out->parent() != in->parent() || !in->shape().IsTuple()) {
return std::nullopt;
}

// Extracts the instruction `out` as a function of the instruction `in`.
// HloModule extracted
// ENTRY main {
// in = parameter(0)
// //... some calculations
// ROOT out = ...
// }
std::unique_ptr<HloModule> extracted = ExtractModule(
/*instruction=*/out, /*height=*/-1, /*extract_selector=*/
[in](const HloInstruction* inst) -> bool { return inst != in; },
/*replace_type_selector=*/
[](const HloInstruction* inst) -> ReplaceType {
return ReplaceType::kReplaceParam;
});
HloComputation* entry = extracted->entry_computation();

// Check that the extracted module takes nothing but `in` as input. If `out`
// does not depend on in, the extracted module will have some other shape for
// input.
if (entry->num_parameters() != 1 ||
entry->parameter_instruction(0)->shape() != in->shape()) {
return std::nullopt;
}
HloInstruction* param = entry->parameter_instruction(0);

// If there are no users for the input `in`, it would mean that `out` does not
// depend on a get-tuple-element of `in`.
if (param->user_count() == 0) {
return nullopt;
}

// If any of the users of the input `in` is not a get-tuple-element
// instruction, then that would mean that the output does not depend uniquely
// on a get-tuple-element of on `in`, instead it depends on some other
// calculations on `in`.
if (absl::c_any_of(param->users(), [](const HloInstruction* inst) -> bool {
return inst->opcode() != HloOpcode::kGetTupleElement;
})) {
return std::nullopt;
}

// We extract the candidate index from the first user. At this point we
// already know that the all the users are get-tuple-elements and that there
// is atleast one user.
int64_t candidate_index = param->users()[0]->tuple_index();

// We check that all the users of the input instruction `in` (which we already
// know to be get-tuple-element instructions) have the same tuple index.
if (absl::c_any_of(param->users(),
[candidate_index](const HloInstruction* inst) -> bool {
return inst->tuple_index() != candidate_index;
})) {
return std::nullopt;
}

return candidate_index;
}

// The below function identifies a subset of all possible auxiliary
// induction variables (AIV). Specifically, candidates are gtes, e.g.,
// gte(param0, N)
Expand Down Expand Up @@ -277,7 +359,7 @@ optional<int64_t> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
auto* while_cond_root = while_cond->root_instruction();
auto* while_cond_param = while_cond->parameter_instruction(0);
optional<int64_t> indvar_tuple_idx =
GetGTEOperandIndex(while_cond_root, while_cond_param);
GetUniqueGTEDependenceIndex(while_cond_root, while_cond_param);
if (!indvar_tuple_idx) {
VLOG(2) << "Induction variable not found in loop condition: "
<< while_cond->root_instruction()->ToString();
Expand All @@ -303,7 +385,7 @@ optional<int64_t> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
while_body_inc = while_body_root->operand(*indvar_tuple_idx);
auto* while_body_param = while_body->parameter_instruction(0);
optional<int64_t> while_body_indvar_tuple_idx =
GetGTEOperandIndex(while_body_inc, while_body_param);
GetUniqueGTEDependenceIndex(while_body_inc, while_body_param);
if (!while_body_indvar_tuple_idx) {
VLOG(2)
<< "Induction variable not found in while body increment instruction: "
Expand Down
189 changes: 189 additions & 0 deletions xla/hlo/analysis/while_loop_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,5 +629,194 @@ TEST_F(WhileLoopAnalysisTest, AvoidBruteForceForHugeParams) {
EXPECT_EQ(trip_count, std::nullopt);
}

TEST_F(WhileLoopAnalysisTest, LoopFusionForLoopVariable) {
// This test verifies that fusions in initialization, condition and update are
// accepted by while loop analysis.
//
// In the HLO below:
//  - the body updates tuple element 0 through two chained kLoop fusions, each
//    of which adds the constant 1 to its single parameter;
//  - the condition is a kLoop fusion that compares tuple element 0 against the
//    constant 7 with direction=LT;
//  - the initial value of tuple element 0 in ENTRY is itself produced by a
//    fusion.
const char* hlo = R"(
HloModule test
fused_add.11 {
param_0.968 = s32[] parameter(0)
constant_1239_1 = s32[] constant(1)
ROOT add.1041.1 = s32[] add(param_0.968, constant_1239_1)
}
fused_add.11.clone.2 {
param_0.2169 = s32[] parameter(0)
constant_1239_4 = s32[] constant(1)
ROOT add.1041.4 = s32[] add(param_0.2169, constant_1239_4)
}
body {
param.1 = (s32[], s32[]) parameter(0)
loop_iter = s32[] get-tuple-element(param.1), index=0
data = s32[] get-tuple-element(param.1), index=1
loop_add_fusion.11 = s32[] fusion(loop_iter), kind=kLoop, calls=fused_add.11
loop_add_fusion.11.double_buffer_clone = s32[] fusion(loop_add_fusion.11), kind=kLoop, calls=fused_add.11.clone.2
ROOT tuple = (s32[], s32[]) tuple(loop_add_fusion.11.double_buffer_clone, data)
}
fused_compare {
param_0.987 = s32[] parameter(0)
constant_1238_1 = s32[] constant(7)
ROOT compare.98.1 = pred[] compare(param_0.987, constant_1238_1), direction=LT
}
condition {
param.2 = (s32[], s32[]) parameter(0)
loop_iter = s32[] get-tuple-element(param.2), index=0
ROOT loop_compare_fusion = pred[] fusion(loop_iter), kind=kLoop, calls=fused_compare
}
fused_add.12 {
param_0.968 = s32[] parameter(0)
constant_1239_1 = s32[] constant(1)
ROOT add.1041.1 = s32[] add(param_0.968, constant_1239_1)
}
ENTRY main {
data = s32[] parameter(0)
c.0 = s32[] constant(0)
c.1 = s32[] constant(1)
add.1 = s32[] add(c.0, c.1)
c.0.loop_double_buffer_peeled = s32[] fusion(add.1), kind=kLoop, calls=fused_add.12
tuple = (s32[], s32[]) tuple(c.0.loop_double_buffer_peeled, data)
ROOT while = while(tuple), body=body, condition=condition
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
// The while instruction is the root of the entry computation.
HloInstruction* while_op = module->entry_computation()->root_instruction();
// The induction variable must be found and must be tuple element 0.
auto loop_induction_variable = GetLoopInductionVarTupleIdx(while_op);
ASSERT_TRUE(loop_induction_variable.has_value());
EXPECT_EQ(loop_induction_variable.value(), 0);
}

TEST_F(WhileLoopAnalysisTest, UpdateIsMultipleOperationsWithConstantOperand) {
// The body updates tuple element 0 through two chained adds, each taking the
// constant c.1 as its second operand: add.2 = (iter.1 + c.1) + c.1. The
// condition compares tuple element 0 directly against the constant 10.
//
// Expected: the induction variable index is found (0), while
// ComputeWhileLoopTripCount returns nullopt for this loop.
const char* hlo = R"(
HloModule test
body {
param.1 = (s32[], s32[8,8]) parameter(0)
iter.1 = s32[] get-tuple-element(param.1), index=0
c.1 = s32[] constant(1)
add.1 = s32[] add(iter.1, c.1)
add.2 = s32[] add(add.1, c.1)
data.1 = s32[8,8] get-tuple-element(param.1), index=1
ROOT tuple = (s32[], s32[8,8]) tuple(add.2, data.1)
}
condition {
param = (s32[], s32[8,8]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c.10 = s32[] constant(10)
ROOT compare = pred[] compare(iter, c.10), direction=LT
}
ENTRY main {
c.0 = s32[] constant(0)
data = s32[8,8] parameter(0)
tuple = tuple(c.0, data)
ROOT while = while(tuple), body=body, condition=condition
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
// The while instruction is the root of the entry computation.
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::optional<int64_t> indvar_idx = GetLoopInductionVarTupleIdx(while_op);
// ASSERT (not EXPECT) so the dereference below is never reached on nullopt.
ASSERT_NE(indvar_idx, std::nullopt);
EXPECT_EQ(*indvar_idx, 0);
std::optional<int64_t> trip_count = ComputeWhileLoopTripCount(while_op);
EXPECT_EQ(trip_count, std::nullopt);
}

TEST_F(WhileLoopAnalysisTest,
UpdateIsMultipleOperationsWithoutConstantOperand) {
// The body updates tuple element 0 with add.2 = iter.1 + add.1, where the
// step add.1 = c.1 + c.1 is itself an add instruction rather than a constant,
// so the final update has no direct constant operand. The condition compares
// tuple element 0 directly against the constant 10.
//
// Expected: the induction variable index is found (0), while
// ComputeWhileLoopTripCount returns nullopt for this loop.
const char* hlo = R"(
HloModule test
body {
param.1 = (s32[], s32[8,8]) parameter(0)
iter.1 = s32[] get-tuple-element(param.1), index=0
c.1 = s32[] constant(1)
add.1 = s32[] add(c.1, c.1)
add.2 = s32[] add(iter.1, add.1)
data.1 = s32[8,8] get-tuple-element(param.1), index=1
ROOT tuple = (s32[], s32[8,8]) tuple(add.2, data.1)
}
condition {
param = (s32[], s32[8,8]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c.10 = s32[] constant(10)
ROOT compare = pred[] compare(iter, c.10), direction=LT
}
ENTRY main {
c.0 = s32[] constant(0)
data = s32[8,8] parameter(0)
tuple = tuple(c.0, data)
ROOT while = while(tuple), body=body, condition=condition
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
// The while instruction is the root of the entry computation.
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::optional<int64_t> indvar_idx = GetLoopInductionVarTupleIdx(while_op);
// ASSERT (not EXPECT) so the dereference below is never reached on nullopt.
ASSERT_NE(indvar_idx, std::nullopt);
EXPECT_EQ(*indvar_idx, 0);
std::optional<int64_t> trip_count = ComputeWhileLoopTripCount(while_op);
EXPECT_EQ(trip_count, std::nullopt);
}

TEST_F(WhileLoopAnalysisTest,
ConditionIsMultipleOperationsWithConstantOperand) {
// The body is a plain increment of tuple element 0 by the constant 1. The
// condition does not compare the tuple element directly: it first computes
// add.10 = iter + c.10 and then compares add.10 against the constant c.10.
//
// Expected: the induction variable index is found (0), while
// ComputeWhileLoopTripCount returns nullopt for this loop.
const char* hlo = R"(
HloModule test
body {
param.1 = (s32[], s32[8,8]) parameter(0)
iter.1 = s32[] get-tuple-element(param.1), index=0
c.1 = s32[] constant(1)
add.1 = s32[] add(iter.1, c.1)
data.1 = s32[8,8] get-tuple-element(param.1), index=1
ROOT tuple = (s32[], s32[8,8]) tuple(add.1, data.1)
}
condition {
param = (s32[], s32[8,8]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c.10 = s32[] constant(10)
add.10 = s32[] add(iter, c.10)
ROOT compare = pred[] compare(add.10, c.10), direction=LT
}
ENTRY main {
c.0 = s32[] constant(0)
data = s32[8,8] parameter(0)
tuple = tuple(c.0, data)
ROOT while = while(tuple), body=body, condition=condition
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
// The while instruction is the root of the entry computation.
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::optional<int64_t> indvar_idx = GetLoopInductionVarTupleIdx(while_op);
// ASSERT (not EXPECT) so the dereference below is never reached on nullopt.
ASSERT_NE(indvar_idx, std::nullopt);
EXPECT_EQ(*indvar_idx, 0);
std::optional<int64_t> trip_count = ComputeWhileLoopTripCount(while_op);
EXPECT_EQ(trip_count, std::nullopt);
}

TEST_F(WhileLoopAnalysisTest,
ConditionIsMultipleOperationsWithoutConstantOperand) {
// The body is a plain increment of tuple element 0 by the constant 1. The
// condition compares tuple element 0 against add.10 = c.5 + c.5, i.e. the
// bound is an add instruction rather than a constant, so the compare has no
// direct constant operand.
//
// Expected: the induction variable index is found (0), while
// ComputeWhileLoopTripCount returns nullopt for this loop.
const char* hlo = R"(
HloModule test
body {
param.1 = (s32[], s32[8,8]) parameter(0)
iter.1 = s32[] get-tuple-element(param.1), index=0
c.1 = s32[] constant(1)
add.1 = s32[] add(iter.1, c.1)
data.1 = s32[8,8] get-tuple-element(param.1), index=1
ROOT tuple = (s32[], s32[8,8]) tuple(add.1, data.1)
}
condition {
param = (s32[], s32[8,8]) parameter(0)
iter = s32[] get-tuple-element(param), index=0
c.5 = s32[] constant(5)
add.10 = s32[] add(c.5, c.5)
ROOT compare = pred[] compare(iter, add.10), direction=LT
}
ENTRY main {
c.0 = s32[] constant(0)
data = s32[8,8] parameter(0)
tuple = tuple(c.0, data)
ROOT while = while(tuple), body=body, condition=condition
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
// The while instruction is the root of the entry computation.
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::optional<int64_t> indvar_idx = GetLoopInductionVarTupleIdx(while_op);
// ASSERT (not EXPECT) so the dereference below is never reached on nullopt.
ASSERT_NE(indvar_idx, std::nullopt);
EXPECT_EQ(*indvar_idx, 0);
std::optional<int64_t> trip_count = ComputeWhileLoopTripCount(while_op);
EXPECT_EQ(trip_count, std::nullopt);
}

} // namespace
} // namespace xla

0 comments on commit 342e3c1

Please sign in to comment.