Skip to content

Commit

Permalink
[xla:gpu] NFC: Remove LMHLO op argument from EmitKernel #6224
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 574213434
  • Loading branch information
anlunx authored and copybara-github committed Oct 17, 2023
1 parent fcdf7df commit dba73eb
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 36 deletions.
11 changes: 8 additions & 3 deletions xla/service/gpu/fusions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ cc_library(
hdrs = ["in_place_dynamic_update_slice.h"],
deps = [
":fusion_emitter",
"//xla/hlo/ir:hlo",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:launch_dimensions",
"//xla/service/llvm_ir:dynamic_update_slice_util",
"//xla/service/llvm_ir:fused_ir_emitter",
"//xla/service/llvm_ir:ir_array",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:ir_headers",
],
Expand Down Expand Up @@ -42,15 +44,13 @@ cc_library(
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:target_util",
"//xla/service/gpu:thunk",
"//xla/service/llvm_ir:buffer_assignment_util",
"//xla/service/llvm_ir:ir_array",
"//xla/service/llvm_ir:llvm_util",
"//xla/translate/mhlo_to_hlo:location_exporter",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:ir_headers",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@tsl//tsl/platform:errors",
],
)

Expand Down Expand Up @@ -90,9 +90,12 @@ cc_library(
"//xla/mlir_hlo:lhlo",
"//xla/service:elemental_ir_emitter",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:ir_emitter_context",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:parallel_loop_emitter",
"//xla/service/llvm_ir:fused_ir_emitter",
"//xla/service/llvm_ir:ir_array",
"@llvm-project//llvm:ir_headers",
],
)
Expand Down Expand Up @@ -189,6 +192,7 @@ cc_library(
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:ir_emitter_context",
"//xla/service/gpu:launch_dimensions",
"//xla/service/gpu:target_util",
"//xla/service/llvm_ir:fused_ir_emitter",
"//xla/service/llvm_ir:ir_array",
Expand All @@ -203,6 +207,7 @@ cc_library(
hdrs = ["input_slices.h"],
deps = [
":fusion_emitter",
"//xla/hlo/ir:hlo",
"//xla/service:elemental_ir_emitter",
"//xla/service/gpu:hlo_fusion_analysis",
"//xla/service/gpu:ir_emission_utils",
Expand Down
7 changes: 4 additions & 3 deletions xla/service/gpu/fusions/fusion_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "xla/service/gpu/target_util.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "tsl/platform/errors.h"

namespace xla {
namespace gpu {
Expand Down Expand Up @@ -199,9 +200,9 @@ StatusOr<FusionEmissionResult> KernelFusionEmitterBase::Emit(
ir_emitter_context, suggested_kernel_name,
kernel_arguments.args(), fusion_op.getInputBuffers().size(),
launch_dims, builder);
TF_RETURN_IF_ERROR(EmitKernel(
ir_emitter_context, elemental_emitter, fusion_op, fusion,
launch_dims, std::move(inputs), std::move(outputs), builder, i));
TF_RETURN_IF_ERROR(EmitKernel(ir_emitter_context, elemental_emitter,
fusion, launch_dims, std::move(inputs),
std::move(outputs), builder, i));
// TODO(jreiffers): Return shmem_bytes from EmitKernel when
// converting the Triton emitters to this infrastructure.
return KernelReuseCache::Entry{kernel->getName().str(), launch_dims,
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/fusions/fusion_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class KernelFusionEmitterBase : public FusionInterface {
protected:
virtual Status EmitKernel(IrEmitterContext& ir_emitter_context,
ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op,
const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs,
Expand Down
9 changes: 5 additions & 4 deletions xla/service/gpu/fusions/in_place_dynamic_update_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ limitations under the License.

#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/IRBuilder.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/llvm_ir/dynamic_update_slice_util.h"
#include "xla/service/llvm_ir/fused_ir_emitter.h"
#include "xla/service/llvm_ir/ir_array.h"

namespace xla {
namespace gpu {
Expand All @@ -35,10 +37,9 @@ StatusOr<LaunchDimensions> InPlaceDynamicUpdateSliceEmitter::launch_dimensions(

Status InPlaceDynamicUpdateSliceEmitter::EmitKernel(
IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder,
int kernel_index) const {
const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs, std::vector<llvm_ir::IrArray> outputs,
llvm::IRBuilder<>* builder, int kernel_index) const {
// In case a dynamic slice update's output is bitcasted, we need to ensure we
// write to the output array using the shape and layout of the dynamic slice
// update. This cast is known to be safe to do iff, in the case the output of
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/fusions/in_place_dynamic_update_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class InPlaceDynamicUpdateSliceEmitter : public KernelFusionEmitterBase {
protected:
Status EmitKernel(IrEmitterContext& ir_emitter_context,
ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op,
const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs,
Expand Down
14 changes: 7 additions & 7 deletions xla/service/gpu/fusions/input_slices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.
#include "xla/service/gpu/fusions/input_slices.h"

#include "llvm/IR/IRBuilder.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/elemental_ir_emitter.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/parallel_loop_emitter.h"
Expand Down Expand Up @@ -158,10 +159,9 @@ StatusOr<LaunchDimensions> InputSlicesFusion::launch_dimensions(

Status InputSlicesFusion::EmitKernel(
IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder,
int kernel_index) const {
const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs, std::vector<llvm_ir::IrArray> outputs,
llvm::IRBuilder<>* builder, int kernel_index) const {
TF_ASSIGN_OR_RETURN(Shape element_shape,
GetConsistentInputShapeForRootSlices(
fusion.fused_instructions_computation()));
Expand All @@ -172,9 +172,9 @@ Status InputSlicesFusion::EmitKernel(
inputs, outputs, index, builder);
},
element_shape, launch_dims, builder)
.EmitLoop(llvm_ir::IrName(GetIrNameFromLoc(fusion_op.getLoc())),
GetIndexTypeForKernel(fusion_op, launch_dims.launch_bound(),
builder));
.EmitLoop(
fusion.name(),
GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder));
}

} // namespace gpu
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/fusions/input_slices.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ class InputSlicesFusion : public KernelFusionEmitterBase {
protected:
Status EmitKernel(IrEmitterContext& ir_emitter_context,
ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op,
const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs,
Expand Down
17 changes: 10 additions & 7 deletions xla/service/gpu/fusions/loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,23 @@ limitations under the License.
#include <vector>

#include "llvm/IR/IRBuilder.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/parallel_loop_emitter.h"
#include "xla/service/llvm_ir/fused_ir_emitter.h"
#include "xla/service/llvm_ir/ir_array.h"

namespace xla {
namespace gpu {

Status LoopFusion::EmitKernel(
IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder,
int kernel_index) const {
const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs, std::vector<llvm_ir::IrArray> outputs,
llvm::IRBuilder<>* builder, int kernel_index) const {
FusedIrEmitter fused_emitter(elemental_emitter);
for (int i = 0; i < fusion_op.getInputBuffers().size(); i++) {
for (int i = 0; i < fusion.fused_parameters().size(); i++) {
fused_emitter.BindGenerator(
*fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) {
return inputs[i].EmitReadArrayElement(index, builder);
Expand All @@ -41,11 +44,11 @@ Status LoopFusion::EmitKernel(
fused_emitter.GetGenerator(*fusion.fused_expression_root()));

llvm::Type* index_type =
GetIndexTypeForKernel(fusion_op, launch_dims.launch_bound(), builder);
GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);

return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder,
*analysis_.GetLoopFusionConfig())
.EmitLoop(GetIrNameFromLoc(fusion_op->getLoc()), index_type);
.EmitLoop(fusion.name(), index_type);
}

StatusOr<LaunchDimensions> LoopFusion::launch_dimensions(
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/fusions/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class LoopFusion : public KernelFusionEmitterBase {
protected:
Status EmitKernel(IrEmitterContext& ir_emitter_context,
ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op,
const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs,
Expand Down
11 changes: 6 additions & 5 deletions xla/service/gpu/fusions/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ limitations under the License.
#include <vector>

#include "llvm/IR/IRBuilder.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/permutation_util.h"
#include "xla/service/gpu/fusions/tiling_util.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/launch_dimensions.h"
#include "xla/service/gpu/target_util.h"
#include "xla/service/llvm_ir/fused_ir_emitter.h"
#include "xla/service/llvm_ir/ir_array.h"
Expand Down Expand Up @@ -73,10 +75,9 @@ llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index,

Status TransposeFusion::EmitKernel(
IrEmitterContext& ir_emitter_context, ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op, const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims, std::vector<llvm_ir::IrArray> inputs,
std::vector<llvm_ir::IrArray> outputs, llvm::IRBuilder<>* builder,
int kernel_index) const {
const HloFusionInstruction& fusion, const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs, std::vector<llvm_ir::IrArray> outputs,
llvm::IRBuilder<>* builder, int kernel_index) const {
const auto& tiling_scheme = *analysis_.GetTransposeTilingScheme();
const auto& hlo_roots = analysis_.fusion_roots();
FusedIrEmitter fused_emitter(elemental_emitter);
Expand Down Expand Up @@ -233,7 +234,7 @@ Status TransposeFusion::EmitKernel(
};

llvm::Type* index_type =
GetIndexTypeForKernel(fusion_op, launch_dims.launch_bound(), builder);
GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder);
return EmitTilingKernel(builder, tiling_scheme, index_type, tile_generator)
.status();
}
Expand Down
1 change: 0 additions & 1 deletion xla/service/gpu/fusions/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class TransposeFusion : public KernelFusionEmitterBase {
protected:
Status EmitKernel(IrEmitterContext& ir_emitter_context,
ElementalIrEmitter& elemental_emitter,
mlir::lmhlo::FusionOp fusion_op,
const HloFusionInstruction& fusion,
const LaunchDimensions& launch_dims,
std::vector<llvm_ir::IrArray> inputs,
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/tests/fusion.hlo
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ HloModule TestModule, is_scheduled=true
// CHECK: %[[VAL_36:.*]] = udiv i32 %[[VAL_29]], 802816
// CHECK: %[[VAL_37:.*]] = icmp ult i32 %[[VAL_5]], 102760448
// CHECK: br i1 %[[VAL_37]], label %[[VAL_38:.*]], label %[[VAL_39:.*]]
// CHECK: fusion_1.in_bounds-after: ; preds = %[[VAL_38]], %[[VAL_40:.*]]
// CHECK: fusion.1.in_bounds-after: ; preds = %[[VAL_38]], %[[VAL_40:.*]]
// CHECK: ret void
// CHECK: fusion_1.in_bounds-true: ; preds = %[[VAL_40]]
// CHECK: fusion.1.in_bounds-true: ; preds = %[[VAL_40]]
// CHECK: %[[VAL_41:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_42:.*]], i32 0, i32 %[[VAL_7]]
// CHECK: %[[VAL_43:.*]] = load float, ptr %[[VAL_41]], align 4, !invariant.load
// CHECK: %[[VAL_44:.*]] = getelementptr inbounds [64 x float], ptr %[[VAL_45:.*]], i32 0, i32 %[[VAL_7]]
Expand Down

0 comments on commit dba73eb

Please sign in to comment.