Skip to content

Commit

Permalink
[XLA:SPMD] Support shard-as propagation with unspecified_dims.
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 629857357
  • Loading branch information
Tongfei-Guo authored and copybara-github committed May 1, 2024
1 parent c3366f8 commit ad8d093
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 5 deletions.
84 changes: 79 additions & 5 deletions xla/service/sharding_propagation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,49 @@ bool InferUnspecifiedDimsFromUsers(HloInstruction* annotate_op,
return changed;
}

// Refines the sharding of `annotate_op` along its `unspecified_dims` using
// the shardings of the other members of its shard group.
//
// For each spatially partitioned member whose sharding is tiled, the member's
// sharding is partially replicated on every dimension except the unspecified
// ones, and the result is merged into `annotate_op`'s current sharding when
// compatible. Members that are ShardBarrierFrom custom-calls are skipped, and
// nothing is propagated to a ShardBarrierTo annotate op.
//
// Returns true iff `annotate_op`'s sharding was changed.
bool InferUnspecifiedDimsFromShardGroup(
    HloInstruction* annotate_op, absl::Span<const int64_t> unspecified_dims,
    const absl::flat_hash_set<HloInstruction*>& shard_group) {
  // ProcessShardingInstruction will either keep the "Sharding" custom call as
  // is or replace it with a copy.
  CHECK(annotate_op->IsCustomCall("Sharding") ||
        annotate_op->opcode() == HloOpcode::kCopy);

  // Do not propagate sharding to ShardBarrierTo custom-call.
  if (annotate_op->IsCustomCall(spmd::kShardBarrierTo)) {
    return false;
  }

  bool changed = false;
  for (const HloInstruction* member : shard_group) {
    // An op cannot refine itself.
    if (member == annotate_op) {
      continue;
    }
    // Do not propagate sharding from ShardBarrierFrom custom-call.
    if (member->IsCustomCall(spmd::kShardBarrierFrom)) {
      continue;
    }
    if (!IsSpatiallyPartitioned(member)) {
      continue;
    }
    const HloSharding& member_sharding = member->sharding();
    // Only tiled shardings carry per-dimension tiling information that can be
    // transplanted onto the unspecified dims.
    if (!member_sharding.IsTiled()) {
      continue;
    }
    // Keep the member's tiling only on the unspecified dims; every other dim
    // is replicated so the merge cannot disturb dims the user pinned.
    HloSharding partial_replicated =
        hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
            member_sharding, unspecified_dims);
    HloSharding sharding = annotate_op->sharding();
    // NumTiles() + 1 asks the merge to only succeed when it is a genuine
    // refinement (more tiles than the current sharding).
    if (!hlo_sharding_util::MergeShardingIfCompatible(
            partial_replicated, sharding.NumTiles() + 1, &sharding)) {
      continue;
    }
    annotate_op->set_sharding(sharding);
    // Plain assignment instead of the original `changed |= true;` — the
    // bit-or with a constant true is a no-op disguise for assignment.
    changed = true;
  }
  return changed;
}

// Returns whether an op is a target for CSE prevention.
bool IsCSEPreventionTarget(const HloInstruction* instruction) {
// Scalar broadcasts are the most common CSE target that causes cross-layer
Expand Down Expand Up @@ -1582,7 +1625,7 @@ absl::StatusOr<bool> ProcessShardingInstruction(
if (instruction->IsCustomCall("Sharding") && !replaced_with_copy) {
// Pass shard group to operand sharding custom-call if it's not
// replaced with a copy, meaning that the shardings are to annotate
// shard_group or shard_barrier only.
// shard_group.
HloSharding operand_sharding = instruction->operand(0)->has_sharding()
? instruction->operand(0)->sharding()
: HloSharding::Unknown();
Expand Down Expand Up @@ -2238,7 +2281,8 @@ bool ShardingPropagation::InferShardingFromShardGroup(
// Propagate manual sharding.
if (!instruction->has_sharding() || instruction->sharding().IsTileMaximal()) {
for (const HloInstruction* member : shard_group) {
if (!member->has_sharding() || !member->sharding().IsManual()) {
if (!member->has_sharding() || !member->sharding().IsManual() ||
member == instruction) {
continue;
}
instruction->set_sharding(member->sharding());
Expand All @@ -2249,7 +2293,9 @@ bool ShardingPropagation::InferShardingFromShardGroup(
const bool may_combine_partial_sharding = is_spmd_ && aggressiveness > 0;
bool changed = false;
for (const HloInstruction* member : shard_group) {
if (member->IsCustomCall(spmd::kShardBarrierFrom)) {
// Do not propagate sharding from ShardBarrierFrom custom-call.
if (member == instruction ||
member->IsCustomCall(spmd::kShardBarrierFrom)) {
continue;
}
changed |= MaybeImproveInstructionSharding(member->sharding(), instruction,
Expand Down Expand Up @@ -3309,6 +3355,20 @@ absl::StatusOr<bool> ShardingPropagation::Run(
? shard_group_id_to_shard_as_group.at(shard_group_id)
: shard_group_id_to_shard_like_group.at(shard_group_id);
if (provided_shardings.contains(instruction)) {
if (!may_merge_partial) {
continue;
}
auto it = unspecified_dims.find(instruction);
if (it != unspecified_dims.end() &&
InferUnspecifiedDimsFromShardGroup(instruction, it->second,
shard_group)) {
++inferred_from_shard_group_counter;
VLOG(2) << "Refined partial sharding (shard group): "
<< instruction->ToString();
clear_cache(instruction);
already_inferred_from_shard_group.insert(instruction);
changed_last_iter = true;
}
continue;
}
already_inferred_from_shard_group.insert(instruction);
Expand Down Expand Up @@ -3469,9 +3529,23 @@ absl::StatusOr<bool> ShardingPropagation::Run(
VLOG(2) << "Aligning shard group: " << shard_as_group_id
<< " to sharding:" << common_sharding.ToString();
for (HloInstruction* member : shard_as_group) {
if (!member->IsCustomCall(spmd::kShardBarrierTo)) {
member->set_sharding(common_sharding);
if (member->IsCustomCall(spmd::kShardBarrierTo)) {
continue;
}
if (provided_shardings.contains(member)) {
auto it = unspecified_dims.find(member);
if (it != unspecified_dims.end()) {
HloSharding partial_replicated =
hlo_sharding_util::PartiallyReplicateTiledShardingOnAllDimsExcept(
common_sharding, it->second);
HloSharding sharding = member->sharding();
if (hlo_sharding_util::MergeShardingIfCompatible(
partial_replicated, sharding.NumTiles() + 1, &sharding)) {
member->set_sharding(sharding);
}
}
}
member->set_sharding(common_sharding);
}
}

Expand Down
74 changes: 74 additions & 0 deletions xla/service/sharding_propagation_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11793,5 +11793,79 @@ ENTRY entry_computation {
}
}

// Verifies shard-as propagation when one group member sits behind a
// ShardBarrierTo custom-call and another member has partially specified
// sharding (backend_config="unspecified_dims=[1,2]"). The two
// {unknown shard_as 1} annotations tie custom-call.5 and custom-call.6
// together; propagation should refine the unspecified dims of the group
// without propagating through the barrier.
TEST_F(ShardingPropagationTest, ShardAsWithShardBarrier) {
  const char* const hlo_string = R"(
HloModule pjit_f
ENTRY main.11 {
Arg_0.1 = bf16[384,1408]{1,0} parameter(0), sharding={devices=[1,16,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
broadcast.4 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
custom-call.5 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.4), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
broadcast.2 = bf16[8,384,1408]{2,1,0} broadcast(Arg_0.1), dimensions={1,2}
custom-call.3 = bf16[8,384,1408]{2,1,0} custom-call(broadcast.2), custom_call_target="Sharding", sharding={devices=[8,1,1,1024]<=[8192] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2]"
custom-call.6 = bf16[8,384,1408]{2,1,0} custom-call(custom-call.3), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
%shard-barrier-to = bf16[8,384,1408]{2,1,0} custom-call(%custom-call.6), custom_call_target="ShardBarrierTo", custom_call_has_side_effect=true
slice.7 = bf16[1,384,1408]{2,1,0} slice(shard-barrier-to), slice={[1:2], [0:384], [0:1408]}
reshape.8 = bf16[384,1408]{1,0} reshape(slice.7)
tuple.9 = (bf16[384,1408]{1,0}) tuple(reshape.8)
get-tuple-element.10 = bf16[384,1408]{1,0} get-tuple-element(tuple.9), index=0, sharding={devices=[16,1,512]<=[8,16,64]T(1,0,2) last_tile_dim_replicate}
ROOT tuple.13 = (bf16[384,1408]{1,0}, bf16[8,384,1408]{2,1,0}) tuple(get-tuple-element.10, custom-call.5)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Run SPMD propagation; output propagation is allowed, parameter
  // propagation is disabled for both parameters.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      ShardingPropagation(
          /*is_spmd=*/true, /*propagate_metadata=*/true,
          /*allow_spmd_sharding_propagation_to_output=*/{true},
          /*allow_spmd_sharding_propagation_to_parameters=*/{false, false})
          .Run(module.get()));
  EXPECT_TRUE(changed);

  XLA_VLOG_LINES(1, module->ToString());
  // broadcast.4 feeds the shard-as member custom-call.5 and should pick up
  // the group's aligned sharding.
  auto* broadcast_4 = FindInstruction(module.get(), "broadcast.4");
  ASSERT_NE(broadcast_4, nullptr);
  EXPECT_THAT(
      broadcast_4,
      op::Sharding("{devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}"));
  // The "Sharding" custom-call replaced with a copy during preprocessing
  // should carry the same aligned sharding.
  auto* copy = FindInstruction(module.get(), "copy");
  ASSERT_NE(copy, nullptr);
  EXPECT_THAT(
      copy,
      op::Sharding("{devices=[8,1,16,64]<=[8192] last_tile_dim_replicate}"));
}

// Verifies shard-as propagation when a group member is fed through a
// ShardBarrierFrom custom-call: custom-call.0's sharding must not leak into
// the shard-as group, while the unspecified dims of both partially annotated
// custom-calls (unspecified_dims=[1,2,3] and unspecified_dims=[0]) are still
// refined from the group.
TEST_F(ShardingPropagationTest, ShardAsWithShardBarrier2) {
  const char* const hlo_string = R"(
HloModule module
ENTRY %elementwise {
%param0 = f32[5,7,11,13]{3,2,1,0} parameter(0)
%custom-call.0 = f32[5,7,11,13]{3,2,1,0} custom-call(param0), custom_call_target="Sharding", sharding={devices=[2,1,1,1,4]<=[8] last_tile_dim_replicate}, backend_config="unspecified_dims=[1,2,3]"
%shard-barrier-from = f32[5,7,11,13]{3,2,1,0} custom-call(%custom-call.0), custom_call_target="ShardBarrierFrom", custom_call_has_side_effect=true
%custom-call.2 = f32[5,7,11,13]{3,2,1,0} custom-call(shard-barrier-from), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
%param1 = f32[5,7,11,13]{3,2,1,0} parameter(1)
%custom-call.1 = f32[5,7,11,13]{3,2,1,0} custom-call(param1), custom_call_target="Sharding", sharding={devices=[1,2,2,1,2]<=[2,4]T(1,0) last_tile_dim_replicate}, backend_config="unspecified_dims=[0]"
%custom-call.3 = f32[5,7,11,13]{3,2,1,0} custom-call(custom-call.1), custom_call_target="Sharding", custom_call_has_side_effect=true, sharding={unknown shard_as 1}
ROOT %tuple = (f32[5,7,11,13]{3,2,1,0}, f32[5,7,11,13]{3,2,1,0}) tuple(%custom-call.0, %custom-call.3)
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Run SPMD propagation; output propagation is allowed, parameter
  // propagation is disabled for both parameters.
  TF_ASSERT_OK_AND_ASSIGN(
      bool changed,
      ShardingPropagation(
          /*is_spmd=*/true, /*propagate_metadata=*/true,
          /*allow_spmd_sharding_propagation_to_output=*/{true},
          /*allow_spmd_sharding_propagation_to_parameters=*/{false, false})
          .Run(module.get()));
  EXPECT_TRUE(changed);

  XLA_VLOG_LINES(1, module->ToString());
  // Tuple element 0 (custom-call.0) keeps its specified dim 0 and gains
  // tiling on the previously unspecified dims; element 1 (custom-call.3)
  // retains the sharding provided on custom-call.1.
  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Sharding(
          "{{devices=[2,2,2,1]<=[8]}, {devices=[1,2,2,1,2]<=[2,4]T(1,0) "
          "last_tile_dim_replicate}}"));
}

} // namespace
} // namespace xla

0 comments on commit ad8d093

Please sign in to comment.