From 20120007854d93c3d78811c9e9712c6d941191dd Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 18 Jun 2024 14:18:54 -0700
Subject: [PATCH] [XLA] Add shardings for implicit operands and return values
 of CaseOp and IfOp.

We only add arg shardings if there are result shardings, otherwise it means
sharding propagation hasn't been done yet.

PiperOrigin-RevId: 644509565
---
 .../translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc  | 100 ++++++++++---
 .../translate/mhlo_to_hlo/tests/sharding.mlir | 140 ++++++++++++++++++
 2 files changed, 217 insertions(+), 23 deletions(-)

diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
index 1095eb746bc7b9..4a6ad69d70428e 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/mlir_hlo_to_hlo.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -677,6 +678,20 @@ std::optional<xla::OpSharding> CreateTupleSharding(
   return sharding;
 }
 
+// If `ops` has a single element, returns that element. Otherwise, returns
+// a tuple instruction with `ops` and attaches a tuple sharding from
+// `shardings`.
+xla::XlaOp CreateTupleIfMultipleOps(
+    xla::XlaBuilder* builder, llvm::ArrayRef<xla::XlaOp> ops,
+    llvm::ArrayRef<std::optional<xla::OpSharding>> shardings) {
+  if (ops.size() == 1) {
+    return ops[0];
+  }
+  xla::XlaScopedShardingAssignment scoped_sharding(
+      builder, CreateTupleSharding(shardings));
+  return Tuple(builder, ops);
+}
+
 // Returns the flattened result shardings of the given `op_sharding`, i.e.,
 // either:
 // - an empty vector if `op_sharding` is `std::nullopt`.
@@ -700,6 +715,20 @@ llvm::SmallVector<std::optional<xla::OpSharding>> GetResultShardings(
   return res_shardings;
 }
 
+// Returns the OpSharding of each op in `xla_ops`, or std::nullopt if the op
+// doesn't have a sharding.
+llvm::SmallVector<std::optional<xla::OpSharding>> GetXlaOpShardings(
+    llvm::ArrayRef<xla::XlaOp> xla_ops) {
+  llvm::SmallVector<std::optional<xla::OpSharding>> shardings;
+  shardings.reserve(xla_ops.size());
+  for (const xla::XlaOp& xla_op : xla_ops) {
+    auto sharding = xla_op.builder()->GetOpSharding(xla_op);
+    assert(sharding.ok() && "can't find XlaOp for argument");
+    shardings.push_back(*sharding);
+  }
+  return shardings;
+}
+
 namespace mlir {
 namespace {
 class ConvertToHloModule {
@@ -1602,17 +1631,37 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
   llvm::SmallVector<mlir::Value> implicit_false_operands(
       implicit_false_operand_set.begin(), implicit_false_operand_set.end());
 
+  llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
+      GetResultShardings(ctx.builder->sharding(), op->getNumResults());
+
+  llvm::SmallVector<xla::XlaOp> true_args;
+  if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args)))
+    return failure();
+
+  llvm::SmallVector<xla::XlaOp> false_args;
+  if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args)))
+    return failure();
+
+  llvm::SmallVector<std::optional<xla::OpSharding>> true_arg_shardings,
+      false_arg_shardings;
+  if (!ret_shardings.empty()) {
+    // We only add arg shardings if there are result shardings, otherwise it
+    // means sharding propagation hasn't been done yet.
+    true_arg_shardings = GetXlaOpShardings(true_args);
+    false_arg_shardings = GetXlaOpShardings(false_args);
+  }
+
   // Create xla parameters for functions corresponding to ifOp regions using the
   // implicit captures operands. Also export the instructions within those
   // regions.
   if (failed(ctx.converter->LowerRegionAsComputation(
           &op.getTrueBranch(), &true_branch,
           llvm::ArrayRef(implicit_true_operands),
-          /*ensure_single_arg*/ true)) ||
+          /*ensure_single_arg*/ true, true_arg_shardings, ret_shardings)) ||
       failed(ctx.converter->LowerRegionAsComputation(
           &op.getFalseBranch(), &false_branch,
           llvm::ArrayRef(implicit_false_operands),
-          /*ensure_single_arg*/ true))) {
+          /*ensure_single_arg*/ true, false_arg_shardings, ret_shardings))) {
     return failure();
   }
 
@@ -1621,18 +1670,12 @@ LogicalResult ExportXlaOp(IfOp op, OpLoweringContext ctx) {
   if (failed(GetXlaOp(op.getPred(), value_map, &pred, op))) return failure();
 
   // Create the true branch Xla argument.
-  llvm::SmallVector<xla::XlaOp> true_args;
-  if (failed(GetXlaOps(op, implicit_true_operands, ctx, true_args)))
-    return failure();
   xla::XlaOp true_arg =
-      true_args.size() == 1 ? true_args[0] : Tuple(ctx.builder, true_args);
+      CreateTupleIfMultipleOps(ctx.builder, true_args, true_arg_shardings);
 
   // Create the false branch Xla argument.
-  llvm::SmallVector<xla::XlaOp> false_args;
-  if (failed(GetXlaOps(op, implicit_false_operands, ctx, false_args)))
-    return failure();
   xla::XlaOp false_arg =
-      false_args.size() == 1 ? false_args[0] : Tuple(ctx.builder, false_args);
+      CreateTupleIfMultipleOps(ctx.builder, false_args, false_arg_shardings);
 
   // Create XLA Conditional op.
   auto ifop =
@@ -1673,10 +1716,22 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
     llvm::SmallVector<mlir::Value> implicit_operands(
         implicit_operand_set.begin(), implicit_operand_set.end());
 
+    llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
+        GetResultShardings(ctx.builder->sharding(), op->getNumResults());
+
     // Create the branches[i]'s Xla argument.
     llvm::SmallVector<xla::XlaOp> args;
     if (failed(GetXlaOps(op, implicit_operands, ctx, args))) return failure();
-    branch_operands[i] = args.size() == 1 ? args[0] : Tuple(ctx.builder, args);
+
+    llvm::SmallVector<std::optional<xla::OpSharding>> arg_shardings;
+    if (!ret_shardings.empty()) {
+      // We only add arg shardings if there are result shardings, otherwise it
+      // means sharding propagation hasn't been done yet.
+      arg_shardings = GetXlaOpShardings(args);
+    }
+
+    branch_operands[i] =
+        CreateTupleIfMultipleOps(ctx.builder, args, arg_shardings);
 
     // Create xla parameters for functions corresponding to region branches[i]
     // using the implicit captures operands. Also export the instructions within
@@ -1684,7 +1739,7 @@ LogicalResult ExportXlaOp(CaseOp op, OpLoweringContext ctx) {
     computations_p[i] = &computations[i];
     if (failed(ctx.converter->LowerRegionAsComputation(
             &branches[i], computations_p[i], llvm::ArrayRef(implicit_operands),
-            /*ensure_single_arg*/ true)))
+            /*ensure_single_arg*/ true, arg_shardings, ret_shardings)))
       return failure();
   }
 
@@ -3482,10 +3537,6 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     // Applicable for mhlo.IfOp or mhlo.CaseOp or mhlo.WhileOp.
     llvm::SmallVector<xla::Shape, 4> arg_shapes;
 
-    // The arguments of `block` are ignored if `implicit_operands` is set,
-    // therefore `arg_shardings` should be empty in that case.
-    assert(arg_shardings.empty() || !implicit_operands);
-
     auto args_size = block->getNumArguments();
     if (implicit_operands) args_size = implicit_operands->size();
 
@@ -3512,10 +3563,13 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
                          "arg_tuple");
 
       if (implicit_operands) {
-        int arg_index = 0;
-        for (auto implicit_operand : *implicit_operands)
-          lowering[implicit_operand] =
-              xla::GetTupleElement(tuple, arg_index++);
+        for (auto [arg_index, implicit_operand] :
+             llvm::enumerate(*implicit_operands)) {
+          xla::XlaScopedShardingAssignment scoped_sharding(
+              builder, arg_shardings.empty() ? std::nullopt
+                                             : arg_shardings[arg_index]);
+          lowering[implicit_operand] = xla::GetTupleElement(tuple, arg_index);
+        }
       } else {
         for (BlockArgument& arg : block->getArguments()) {
           auto num = arg.getArgNumber();
@@ -3528,6 +3582,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
     } else if (args_size == 1) {
       // Save the location information as a name. For example JAX will set the
      // name of the function argument. Want to preserve these for debugging.
+      xla::XlaScopedShardingAssignment scoped_sharding(
+          builder,
+          arg_shardings.empty() ? std::nullopt : arg_shardings.front());
       if (implicit_operands) {
         mlir::Value arg = (*implicit_operands)[0];
         xla::XlaScopedOpMetadataAssignment op_metadata(
@@ -3537,9 +3594,6 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
         mlir::BlockArgument arg = block->getArgument(0);
         xla::XlaScopedOpMetadataAssignment op_metadata(
             builder, GetOpNameMetadataFromLocation(arg));
-        xla::XlaScopedShardingAssignment scoped_sharding(
-            builder,
-            arg_shardings.empty() ? std::nullopt : arg_shardings.front());
         lowering[arg] = xla::Parameter(builder, 0, arg_shapes[0], "Arg_");
       }
     } else {
diff --git a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir
index a93bdee50abb8e..295c95075c9fc7 100644
--- a/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir
+++ b/third_party/xla/xla/translate/mhlo_to_hlo/tests/sharding.mlir
@@ -220,3 +220,143 @@ func.func @main(%arg0: tensor<i32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>)
   }
   func.return %0#1, %0#2 : tensor<4xf32>, tensor<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: HloModule main
+
+// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) {
+// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
+// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1
+// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+
+// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) {
+// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated}
+// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]}
+// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+
+// CHECK: ENTRY %main.22 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) {
+// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0)
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
+// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
+// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated}
+// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]}
+// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(s32[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), branch_computations={%region_0.8, %region_1.13},
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated}
+// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]}
+// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20)
+
+func.func @main(%arg0: tensor<i32>,
+                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
+                %arg2: tensor<4xf32>,
+                %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"},
+                %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) {
+  %0:2 = "mhlo.case"(%arg0) ( {
+    mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32>
+  }, {
+    mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32>
+  }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor<i32>) -> (tensor<4xf32>, tensor<4xf32>)
+  func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
+}
+
+
+// -----
+
+// CHECK-LABEL: HloModule main
+
+// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] {
+// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0)
+
+// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] {
+// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0)
+
+// CHECK: ENTRY %main.9 (Arg_0.1: s32[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] {
+// CHECK-NEXT: %Arg_0.1 = s32[] parameter(0)
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
+// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(s32[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), branch_computations={%region_0.4, %region_1.6}
+func.func @main(%arg0: tensor<i32>,
+                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
+                %arg2: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = "mhlo.case"(%arg0) ( {
+    mhlo.return %arg1 : tensor<4xf32>
+  }, {
+    mhlo.return %arg2 : tensor<4xf32>
+  }) : (tensor<i32>) -> tensor<4xf32>
+  func.return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: HloModule main
+
+// CHECK: %region_0.8 (arg_tuple.9: (f32[4], f32[4])) -> (f32[4], f32[4]) {
+// CHECK-NEXT: %arg_tuple.9 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
+// CHECK-NEXT: %get-tuple-element.10 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=0, sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+// CHECK-NEXT: %get-tuple-element.11 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.9), index=1
+// CHECK-NEXT: ROOT %tuple.12 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.10, f32[4] %get-tuple-element.11), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+
+// CHECK: %region_1.13 (arg_tuple.14: (f32[4], f32[4])) -> (f32[4], f32[4]) {
+// CHECK-NEXT: %arg_tuple.14 = (f32[4], f32[4]) parameter(0), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %get-tuple-element.15 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=0, sharding={replicated}
+// CHECK-NEXT: %get-tuple-element.16 = f32[4] get-tuple-element((f32[4], f32[4]) %arg_tuple.14), index=1, sharding={devices=[4]<=[4]}
+// CHECK-NEXT: ROOT %tuple.17 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.15, f32[4] %get-tuple-element.16), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+
+// CHECK: ENTRY %main.22 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4], Arg_3.4: f32[4], Arg_4.5: f32[4]) -> (f32[4], f32[4]) {
+// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0)
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
+// CHECK-NEXT: %tuple.6 = (f32[4], f32[4]) tuple(f32[4] %Arg_1.2, f32[4] %Arg_2.3), sharding={{\{}}{devices=[2,2]<=[4] last_tile_dim_replicate}, {replicated}}
+// CHECK-NEXT: %Arg_3.4 = f32[4] parameter(3), sharding={replicated}
+// CHECK-NEXT: %Arg_4.5 = f32[4] parameter(4), sharding={devices=[4]<=[4]}
+// CHECK-NEXT: %tuple.7 = (f32[4], f32[4]) tuple(f32[4] %Arg_3.4, f32[4] %Arg_4.5), sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %conditional.18 = (f32[4], f32[4]) conditional(pred[] %Arg_0.1, (f32[4], f32[4]) %tuple.6, (f32[4], f32[4]) %tuple.7), true_computation=%region_0.8, false_computation=%region_1.13,
+// CHECK-SAME: sharding={{\{}}{replicated}, {devices=[4]<=[4]}}
+// CHECK-NEXT: %get-tuple-element.19 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=0, sharding={replicated}
+// CHECK-NEXT: %get-tuple-element.20 = f32[4] get-tuple-element((f32[4], f32[4]) %conditional.18), index=1, sharding={devices=[4]<=[4]}
+// CHECK-NEXT: ROOT %tuple.21 = (f32[4], f32[4]) tuple(f32[4] %get-tuple-element.19, f32[4] %get-tuple-element.20)
+
+func.func @main(%arg0: tensor<i1>,
+                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
+                %arg2: tensor<4xf32>,
+                %arg3: tensor<4xf32> {mhlo.sharding = "{replicated}"},
+                %arg4: tensor<4xf32> {mhlo.sharding = "{devices=[4]<=[4]}"}) -> (tensor<4xf32>, tensor<4xf32>) {
+  %0:2 = "mhlo.if"(%arg0) ( {
+    mhlo.return %arg1, %arg2 : tensor<4xf32>, tensor<4xf32>
+  }, {
+    mhlo.return %arg3, %arg4 : tensor<4xf32>, tensor<4xf32>
+  }) {mhlo.sharding = "{{replicated},{devices=[4]<=[4]}}"} : (tensor<i1>) -> (tensor<4xf32>, tensor<4xf32>)
+  func.return %0#0, %0#1 : tensor<4xf32>, tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: HloModule main
+
+// CHECK: %region_0.4 (Arg_.5: f32[4]) -> f32[4] {
+// CHECK-NEXT: ROOT %Arg_.5 = f32[4] parameter(0)
+
+// CHECK: %region_1.6 (Arg_.7: f32[4]) -> f32[4] {
+// CHECK-NEXT: ROOT %Arg_.7 = f32[4] parameter(0)
+
+// CHECK: ENTRY %main.9 (Arg_0.1: pred[], Arg_1.2: f32[4], Arg_2.3: f32[4]) -> f32[4] {
+// CHECK-NEXT: %Arg_0.1 = pred[] parameter(0)
+// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1), sharding={devices=[2,2]<=[4] last_tile_dim_replicate}
+// CHECK-NEXT: %Arg_2.3 = f32[4] parameter(2)
+// CHECK-NEXT: ROOT %conditional.8 = f32[4] conditional(pred[] %Arg_0.1, f32[4] %Arg_1.2, f32[4] %Arg_2.3), true_computation=%region_0.4, false_computation=%region_1.6
+
+func.func @main(%arg0: tensor<i1>,
+                %arg1: tensor<4xf32> {mhlo.sharding = "{devices=[2,2]<=[4] last_tile_dim_replicate}"},
+                %arg2: tensor<4xf32>) -> tensor<4xf32> {
+  %0 = "mhlo.if"(%arg0) ( {
+    mhlo.return %arg1 : tensor<4xf32>
+  }, {
+    mhlo.return %arg2 : tensor<4xf32>
+  }) : (tensor<i1>) -> tensor<4xf32>
+  func.return %0 : tensor<4xf32>
+}
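
Note (illustrative, not part of the patch): the sketch below condenses how the helpers in the diff compose when building one branch operand of an IfOp/CaseOp. It is written as if it lived inside mlir_hlo_to_hlo.cc, so it assumes GetResultShardings, GetXlaOpShardings, and CreateTupleIfMultipleOps from the diff are in scope; BuildBranchOperand and its parameters are hypothetical names invented for this example.

// Hypothetical driver, assuming the file-local helpers from the patch are
// visible. Arg shardings are derived from the branch operands only when the
// conditional already carries result shardings, i.e. after sharding
// propagation has run.
#include <cstdint>
#include <optional>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/client/xla_builder.h"
#include "xla/xla_data.pb.h"

xla::XlaOp BuildBranchOperand(xla::XlaBuilder* builder,
                              llvm::ArrayRef<xla::XlaOp> branch_args,
                              int64_t num_results) {
  // Result shardings come from the sharding currently scoped onto the
  // conditional being built; an empty vector means propagation hasn't run.
  llvm::SmallVector<std::optional<xla::OpSharding>> ret_shardings =
      GetResultShardings(builder->sharding(), num_results);

  llvm::SmallVector<std::optional<xla::OpSharding>> arg_shardings;
  if (!ret_shardings.empty()) {
    // Mirror each implicit operand's sharding onto the branch argument.
    arg_shardings = GetXlaOpShardings(branch_args);
  }

  // A single operand passes through unchanged; multiple operands are wrapped
  // in a tuple that carries the tuple sharding assembled from arg_shardings.
  return CreateTupleIfMultipleOps(builder, branch_args, arg_shardings);
}

This is the behavior the new sharding.mlir tests pin down: when the conditional carries result shardings, the branch tuples (%tuple.6, %tuple.7) and region parameters (%arg_tuple.9, %arg_tuple.14) pick up per-operand shardings, while the unannotated tests lower exactly as before.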