Skip to content

Commit

Permalink
[XLA] Add shardings for implicit operands and return values of CaseOp…
Browse files Browse the repository at this point in the history
… and IfOp.

We only add arg shardings if there are result shardings, otherwise it means sharding propagation hasn't been done yet.

PiperOrigin-RevId: 644509565
  • Loading branch information
tensorflower-gardener authored and pull[bot] committed Jun 22, 2024
1 parent 625c053 commit 2012000
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 23 deletions.
100 changes: 77 additions & 23 deletions third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>
Expand Down Expand Up @@ -677,6 +678,20 @@ std::optional<xla::OpSharding> CreateTupleSharding(
return sharding;
}

// If `ops` has a single element, returns that element. Otherwise, returns
// a tuple instruction with `ops` and attaches a tuple sharding from
// `shardings`.
// Returns `ops[0]` when there is exactly one op. Otherwise emits a tuple
// instruction over `ops`, annotated with a tuple sharding assembled from
// `shardings` (see CreateTupleSharding); no sharding is attached in the
// single-op case.
xla::XlaOp CreateTupleIfMultipleOps(
    xla::XlaBuilder* builder, llvm::ArrayRef<xla::XlaOp> ops,
    llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) {
  if (ops.size() != 1) {
    // Scoped assignment: the tuple sharding applies only to the Tuple
    // instruction created below.
    xla::XlaScopedShardingAssignment scoped_sharding(
        builder, CreateTupleSharding(shardings));
    return Tuple(builder, ops);
  }
  return ops.front();
}

// Returns the flattened result shardings of the given `op_sharding`, i.e.,
// either:
// - an empty vector if `op_sharding` is `std::nullopt`.
Expand All @@ -700,6 +715,20 @@ llvm::SmallVector<std::optional<xla::OpSharding>> GetResultShardings(
return res_shardings;
}

// Returns the OpSharding of each op in `xla_ops`, or std::nullopt if the op
// doesn't have a sharding.
// Returns the OpSharding of each op in `xla_ops`, or std::nullopt if the op
// doesn't have a sharding.
llvm::SmallVector<std::optional<xla::OpSharding>> GetXlaOpShardings(
    llvm::ArrayRef<xla::XlaOp> xla_ops) {
  llvm::SmallVector<std::optional<xla::OpSharding>> shardings;
  shardings.reserve(xla_ops.size());
  for (const xla::XlaOp& xla_op : xla_ops) {
    auto sharding = xla_op.builder()->GetOpSharding(xla_op);
    assert(sharding.ok() && "can't find XlaOp for argument");
    // The assert above compiles out under NDEBUG, so guard the dereference:
    // `*sharding` on a failed StatusOr is undefined behavior in release
    // builds. Fall back to "no sharding" instead.
    shardings.push_back(sharding.ok() ? *sharding : std::nullopt);
  }
  return shardings;
}

namespace mlir {
namespace {
class ConvertToHloModule {
Expand Down Expand Up @@ -1602,17 +1631,37 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
llvm::SmallVector<mlir::Value> implicit_false_operands(
implicit_false_operand_set.begin(), implicit_false_operand_set.end());

llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
GetResultShardings(ctx.builder->sharding(), op->getNumResults());

llvm::SmallVector<xla::XlaOp> true_args;
if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args)))
return failure();

llvm::SmallVector<xla::XlaOp> false_args;
if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args)))
return failure();

llvm::SmallVector<std::optional<xla::OpSharding>> true_arg_shardings,
false_arg_shardings;
if (!ret_shardings.empty()) {
// We only add arg shardings if there are result shardings, otherwise it
// means sharding propagation hasn't been done yet.
true_arg_shardings = GetXlaOpShardings(true_args);
false_arg_shardings = GetXlaOpShardings(false_args);
}

// Create xla parameters for functions corresponding to ifOp regions using the
// implicit captures operands. Also export the instructions within those
// regions.
if (failed(ctx.converter->LowerRegionAsComputation(
&op.getTrueBranch(), &true_branch,
llvm::ArrayRef(implicit_true_operands),
/*ensure_single_arg*/ true)) ||
/*ensure_single_arg*/ true, true_arg_shardings, ret_shardings)) ||
failed(ctx.converter->LowerRegionAsComputation(
&op.getFalseBranch(), &false_branch,
llvm::ArrayRef(implicit_false_operands),
/*ensure_single_arg*/ true))) {
/*ensure_single_arg*/ true, false_arg_shardings, ret_shardings))) {
return failure();
}

Expand All @@ -1621,18 +1670,12 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
if (failed(GetXlaOp(op.getPred(), value_map, &pred, op))) return failure();

// Create the true branch Xla argument.
llvm::SmallVector<xla::XlaOp> true_args;
if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args)))
return failure();
xla::XlaOp true_arg =
true_args.size() == 1 ? true_args[0] : Tuple(ctx.builder, true_args);
CreateTupleIfMultipleOps(ctx.builder, true_args, true_arg_shardings);

// Create the false branch Xla argument.
llvm::SmallVector<xla::XlaOp> false_args;
if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args)))
return failure();
xla::XlaOp false_arg =
false_args.size() == 1 ? false_args[0] : Tuple(ctx.builder, false_args);
CreateTupleIfMultipleOps(ctx.builder, false_args, false_arg_shardings);

// Create XLA Conditional op.
auto ifop =
Expand Down Expand Up @@ -1673,18 +1716,30 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
llvm::SmallVector<mlir::Value> implicit_operands(
implicit_operand_set.begin(), implicit_operand_set.end());

llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
GetResultShardings(ctx.builder->sharding(), op->getNumResults());

// Create the branches[i]'s Xla argument.
llvm::SmallVector<xla::XlaOp> args;
if (failed(GetXlaOps(op, implicit_operands, ctx, args))) return failure();
branch_operands[i] = args.size() == 1 ? args[0] : Tuple(ctx.builder, args);

llvm::SmallVector<std::optional<xla::OpSharding>> arg_shardings;
if (!ret_shardings.empty()) {
// We only add arg shardings if there are result shardings, otherwise it
// means sharding propagation hasn't been done yet.
arg_shardings = GetXlaOpShardings(args);
}

branch_operands[i] =
CreateTupleIfMultipleOps(ctx.builder, args, arg_shardings);

// Create xla parameters for functions corresponding to region branches[i]
// using the implicit captures operands. Also export the instructions within
// that region.
computations_p[i] = &computations[i];
if (failed(ctx.converter->LowerRegionAsComputation(
&branches[i], computations_p[i], llvm::ArrayRef(implicit_operands),
/*ensure_single_arg*/ true)))
/*ensure_single_arg*/ true, arg_shardings, ret_shardings)))
return failure();
}

Expand Down Expand Up @@ -3482,10 +3537,6 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
// Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp.
llvm::SmallVector<xla::Shape, 4> arg_shapes;

// The arguments of `block` are ignored if `implicit_operands` is set,
// therefore `arg_shardings` should be empty in that case.
assert(arg_shardings.empty() || !implicit_operands);

auto args_size = block->getNumArguments();
if (implicit_operands) args_size = implicit_operands->size();

Expand All @@ -3512,10 +3563,13 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
"arg_tuple");

if (implicit_operands) {
int arg_index = 0;
for (auto implicit_operand : *implicit_operands)
lowering[implicit_operand] =
xla::GetTupleElement(tuple, arg_index++);
for (auto [arg_index, implicit_operand] :
llvm::enumerate(*implicit_operands)) {
xla::XlaScopedShardingAssignment scoped_sharding(
builder, arg_shardings.empty() ? std::nullopt
: arg_shardings[arg_index]);
lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index);
}
} else {
for (BlockArgument& arg : block->getArguments()) {
auto num = arg.getArgNumber();
Expand All @@ -3528,6 +3582,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
} else if (args_size == 1) {
// Save the location information as a name. For example JAX will set the
// name of the function argument. Want to preserve these for debugging.
xla::XlaScopedShardingAssignment scoped_sharding(
builder,
arg_shardings.empty() ? std::nullopt : arg_shardings.front());
if (implicit_operands) {
mlir::Value arg = (*implicit_operands)[0];
xla::XlaScopedOpMetadataAssignment op_metadata(
Expand All @@ -3537,9 +3594,6 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
mlir::BlockArgument arg = block->getArgument(0);
xla::XlaScopedOpMetadataAssignment op_metadata(
builder, GetOpNameMetadataFromLocation(arg));
xla::XlaScopedShardingAssignment scoped_sharding(
builder,
arg_shardings.empty() ? std::nullopt : arg_shardings.front());
lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
}
} else {
Expand Down
140 changes: 140 additions & 0 deletions third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,143 @@ func.func @main(%arg0: tensor<i32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>)
}
func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32>
}

// -----

// CHECK-LABEL: HloModule main

// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) {
// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1
// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}

// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) {
// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated}
// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]}
// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}

// CHECK: ENTRY %main.22 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) {
// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0)
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated}
// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]}
// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(s32[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), branch_computations={%region_0.8, %region_1.13},
// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated}
// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]}
// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20)

// Verifies that a CaseOp with a result sharding (`mhlo.sharding` on the op)
// propagates shardings to the implicit operands captured by each branch:
// the exported branch computations take a single tuple parameter whose
// sharding mirrors the captured arguments' shardings.
func.func @main(%arg0: tensor<i32>,
                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
                %arg2: tensor<4xf32>,
                %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"},
                %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) {
  // Branch 0 implicitly captures %arg1/%arg2; branch 1 captures %arg3/%arg4.
  %0:2 = "mhlo.case"(%arg0) ( {
    mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32>
  }, {
    mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32>
  }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor<i32>) -> (tensor<4xf32>, tensor<4xf32>)
  func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
}


// -----

// CHECK-LABEL: HloModule main

// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] {
// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0)

// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] {
// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0)

// CHECK: ENTRY %main.9 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] {
// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0)
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(s32[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), branch_computations={%region_0.4, %region_1.6}
// Verifies the negative case for CaseOp: without a result sharding on the op
// (sharding propagation not yet run), no argument shardings are added to the
// branch computations even though %arg1 itself carries a sharding.
func.func @main(%arg0: tensor<i32>,
                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
                %arg2: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "mhlo.case"(%arg0) ( {
    mhlo.return %arg1 : tensor<4xf32>
  }, {
    mhlo.return %arg2 : tensor<4xf32>
  }) : (tensor<i32>) -> tensor<4xf32>
  func.return %0 : tensor<4xf32>
}

// -----

// CHECK-LABEL: HloModule main

// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) {
// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1
// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}

// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) {
// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated}
// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]}
// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}

// CHECK: ENTRY %main.22 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) {
// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0)
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated}
// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]}
// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(pred[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), true_computation=%region_0.8, false_computation=%region_1.13,
// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated}
// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]}
// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20)

// Verifies that an IfOp with a result sharding propagates shardings to the
// implicit operands captured by the true/false branches: each exported branch
// computation takes a tuple parameter sharded like its captured arguments.
func.func @main(%arg0: tensor<i1>,
                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
                %arg2: tensor<4xf32>,
                %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"},
                %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) {
  // True branch implicitly captures %arg1/%arg2; false branch %arg3/%arg4.
  %0:2 = "mhlo.if"(%arg0) ( {
    mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32>
  }, {
    mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32>
  }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor<i1>) -> (tensor<4xf32>, tensor<4xf32>)
  func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
}

// -----

// CHECK-LABEL: HloModule main

// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] {
// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0)

// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] {
// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0)

// CHECK: ENTRY %main.9 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] {
// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0)
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(pred[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), true_computation=%region_0.4, false_computation=%region_1.6

// Verifies the negative case for IfOp: without a result sharding on the op,
// no argument shardings are added to the branch computations even though
// %arg1 itself carries a sharding.
func.func @main(%arg0: tensor<i1>,
                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
                %arg2: tensor<4xf32>) -> tensor<4xf32> {
  %0 = "mhlo.if"(%arg0) ( {
    mhlo.return %arg1 : tensor<4xf32>
  }, {
    mhlo.return %arg2 : tensor<4xf32>
  }) : (tensor<i1>) -> tensor<4xf32>
  func.return %0 : tensor<4xf32>
}

0 comments on commit 2012000

Please sign in to comment.