diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD
index 1861d4a0f5baa..55e65824b683c 100644
--- a/xla/service/gpu/BUILD
+++ b/xla/service/gpu/BUILD
@@ -260,6 +260,7 @@ xla_cc_test(
         "//xla/tests:xla_internal_test_main",
         "@llvm-project//llvm:Core",
         "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:TargetParser",
         "@tsl//tsl/platform:test",
     ],
 )
diff --git a/xla/service/gpu/elemental_ir_emitter.cc b/xla/service/gpu/elemental_ir_emitter.cc
index f629ecc4aff9c..093940e7f0f2f 100644
--- a/xla/service/gpu/elemental_ir_emitter.cc
+++ b/xla/service/gpu/elemental_ir_emitter.cc
@@ -66,23 +66,23 @@ absl::StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
     absl::string_view name) {
   // Device functions don't have f16 math functions, so we convert the operands
   // to f32 before calling the function and then convert the result back to f16.
-  bool cast_result_to_fp16 = false;
   std::vector<llvm::Value*> converted_operands(operands.begin(),
                                                operands.end());
   std::vector<PrimitiveType> converted_input_types(input_types.begin(),
                                                    input_types.end());
+  PrimitiveType original_output_type = output_type;
   switch (output_type) {
+    case BF16:
     case F16:
-      cast_result_to_fp16 = true;
       for (int64_t i = 0; i < operands.size(); ++i) {
-        if (input_types[i] == F16) {
+        if (input_types[i] == original_output_type) {
           converted_operands[i] =
               FPCast(converted_operands[i], b()->getFloatTy());
           converted_input_types[i] = F32;
         }
       }
       output_type = F32;
-      [[fallthrough]];
+      break;
     case F32:
       break;
     case F64:
@@ -92,13 +92,13 @@ absl::StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
                   PrimitiveType_Name(output_type));
   }
   const std::string& munged_callee = ObtainDeviceFunctionName(
-      funcid, output_type,
+      funcid, original_output_type,
       llvm::Triple(b()->GetInsertBlock()->getModule()->getTargetTriple()));
   llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
                                      converted_input_types, output_type, name)
                             .value();
-  if (cast_result_to_fp16) {
-    result = FPCast(result, b()->getHalfTy());
+  if (output_type != original_output_type) {
+    result = FPCast(result, operands[0]->getType());
   }
   return result;
 }
diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
index fd0dc7c570c3e..b3351d439497a 100644
--- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
+++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -607,12 +608,13 @@ template <typename MhloOp, typename... ExtraArgs>
 SmallVector<Value, 1> MapHloOp(mlir::Type result_type,
                                llvm::ArrayRef<mlir::Type> arg_types,
                                llvm::ArrayRef<Value> args,
+                               llvm::ArrayRef<mlir::NamedAttribute> attributes,
                                ImplicitLocOpBuilder& b,
                                ExtraArgs&&... extra_args) {
   Value result = mhlo::MhloOpToStdScalarOp::mapOpOfType<MhloOp>(
       b.getLoc(), result_type, arg_types,
       typename MhloOp::Adaptor(args, std::forward<ExtraArgs>(extra_args)...),
-      /*attributes=*/std::nullopt, &b);
+      attributes, &b);
   if (result.getType().isInteger(1)) {
     result = b.create<mlir::arith::ExtUIOp>(b.getI8Type(), result);
   }
@@ -620,11 +622,13 @@ SmallVector<Value, 1> MapHloOp(mlir::Type result_type,
 }
 
 template <typename MhloOp>
-SmallVector<Value, 1> MapElementwiseOp(llvm::ArrayRef<mlir::Type> arg_types,
-                                       llvm::ArrayRef<Value> args,
-                                       ImplicitLocOpBuilder& b) {
+SmallVector<Value, 1> MapElementwiseOp(
+    llvm::ArrayRef<mlir::Type> arg_types, llvm::ArrayRef<Value> args,
+    ImplicitLocOpBuilder& b,
+    llvm::ArrayRef<mlir::NamedAttribute> attributes = std::nullopt) {
   // We use the last argument's type because of select.
-  return MapHloOp<MhloOp>(args.back().getType(), arg_types, args, b);
+  return MapHloOp<MhloOp>(args.back().getType(), arg_types, args, attributes,
+                          b);
 }
 
 }  // namespace
@@ -946,9 +950,9 @@ absl::StatusOr<SmallVector<Value>> EmitReducePrecision(
   mhlo::ReducePrecisionOp::Properties properties;
   properties.exponent_bits = builder.getI32IntegerAttr(instr->exponent_bits());
   properties.mantissa_bits = builder.getI32IntegerAttr(instr->mantissa_bits());
-  return MapHloOp<mhlo::ReducePrecisionOp>(operands.front().getType(),
-                                           arg_types, operands, builder,
-                                           nullptr, properties);
+  return MapHloOp<mhlo::ReducePrecisionOp>(
+      operands.front().getType(), arg_types, operands,
+      /*attributes=*/std::nullopt, builder, nullptr, properties);
 }
 
 absl::StatusOr<SmallVector<Value>> HloToMlir(
@@ -1014,11 +1018,12 @@
   TF_ASSIGN_OR_RETURN(auto operands,
                       GetOperands(instr, indices, operand_provider, builder));
 
+  llvm::SmallVector<mlir::NamedAttribute> attributes;
   switch (instr->opcode()) {
     case HloOpcode::kAbs:
-      return {
-          MapHloOp<mhlo::AbsOp>(PrimitiveTypeToMlirType(element_type, builder),
-                                arg_types, operands, builder)};
+      return {MapHloOp<mhlo::AbsOp>(
+          PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
+          /*attributes=*/std::nullopt, builder)};
     case HloOpcode::kAdd:
       if (element_type == PRED) {
         return MapElementwiseOp<mhlo::OrOp>(arg_types, operands, builder);
@@ -1041,7 +1046,7 @@
     case HloOpcode::kComplex:
       return MapHloOp<mhlo::ComplexOp>(
           PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
-          builder);
+          /*attributes=*/std::nullopt, builder);
     case HloOpcode::kCos:
       return MapElementwiseOp<mhlo::CosineOp>(arg_types, operands, builder);
     case HloOpcode::kDivide:
       return MapElementwiseOp<mhlo::DivOp>(arg_types, operands, builder);
@@ -1049,18 +1054,26 @@
     case HloOpcode::kErf:
       return MapElementwiseOp<mhlo::ErfOp>(arg_types, operands, builder);
     case HloOpcode::kExp:
-      return MapElementwiseOp<mhlo::ExpOp>(arg_types, operands, builder);
+      if (element_type == F16 || element_type == BF16) {
+        attributes.emplace_back(
+            mlir::StringAttr::get(builder.getContext(), "fastmath"),
+            mlir::arith::FastMathFlagsAttr::get(
+                builder.getContext(), mlir::arith::FastMathFlags::afn));
+      }
+      return MapElementwiseOp<mhlo::ExpOp>(arg_types, operands, builder,
+                                           attributes);
     case HloOpcode::kExpm1:
       return MapElementwiseOp<mhlo::Expm1Op>(arg_types, operands, builder);
     case HloOpcode::kFloor:
       return MapElementwiseOp<mhlo::FloorOp>(arg_types, operands, builder);
     case HloOpcode::kIsFinite:
       return MapHloOp<mhlo::IsFiniteOp>(builder.getI1Type(), arg_types,
-                                        operands, builder);
+                                        operands, /*attributes=*/std::nullopt,
+                                        builder);
     case HloOpcode::kImag:
       return MapHloOp<mhlo::ImagOp>(
           PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
-          builder);
+          /*attributes=*/std::nullopt, builder);
     case HloOpcode::kLog:
       return MapElementwiseOp<mhlo::LogOp>(arg_types, operands, builder);
     case HloOpcode::kLog1p:
@@ -1106,13 +1119,13 @@
     case HloOpcode::kPopulationCount:
       return MapHloOp<mhlo::PopulationCountOp>(
           PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
-          builder);
+          /*attributes=*/std::nullopt, builder);
     case HloOpcode::kPower:
       return MapElementwiseOp<mhlo::PowOp>(arg_types, operands, builder);
     case HloOpcode::kReal:
       return MapHloOp<mhlo::RealOp>(
           PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
-          builder);
+          /*attributes=*/std::nullopt, builder);
     case HloOpcode::kReducePrecision:
       return EmitReducePrecision(instr, arg_types, operands, builder);
     case HloOpcode::kRemainder:
@@ -1154,7 +1167,7 @@
     case HloOpcode::kBitcastConvert:
       return MapHloOp<mhlo::BitcastConvertOp>(
           PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
-          builder);
+          /*attributes=*/std::nullopt, builder);
     case HloOpcode::kConvert:
       return EmitConvert(instr, arg_types, operands, builder);
     case HloOpcode::kBitcast:
@@ -1265,7 +1278,7 @@ class SubgraphConverter {
   absl::StatusOr<SmallVector<Value>> EmitInstruction(
       const HloInstruction* instr, ValueRange indices);
   absl::StatusOr<SmallVector<Value>> EmitElementwiseInstruction(
-      const HloInstruction* instr, ValueRange indices);
+      const HloInstruction* root, ValueRange indices);
 
  private:
   const PartitionedComputation& computation_;
diff --git a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
index fff26b5587806..ee973b2f64db5 100644
--- a/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
+++ b/xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
@@ -1337,6 +1337,43 @@ TEST_F(ElementalHloToMlirTest, ConvertToUnsigned64Saturation) {
   )"));
 }
 
+TEST_F(ElementalHloToMlirTest, ExpF16_UsesFastmathFlag) {
+  TF_EXPECT_OK(Run(R"(
+    ENTRY main {
+      p0 = f16[4] parameter(0)
+      ROOT exp = f16[4] exponential(p0)
+    })",
+                   R"(
+    // CHECK: @main_exp(
+    // CHECK: math.exp %{{.*}} fastmath<afn> : f16
+  )"));
+}
+
+TEST_F(ElementalHloToMlirTest, ExpBF16_UsesFastmathFlag) {
+  TF_EXPECT_OK(Run(R"(
+    ENTRY main {
+      p0 = bf16[4] parameter(0)
+      ROOT exp = bf16[4] exponential(p0)
+    })",
+                   R"(
+    // CHECK: @main_exp(
+    // CHECK: math.exp %{{.*}} fastmath<afn> : bf16
+  )"));
+}
+
+TEST_F(ElementalHloToMlirTest, ExpF32_DoesntUseFastmathFlag) {
+  TF_EXPECT_OK(Run(R"(
+    ENTRY main {
+      p0 = f32[4] parameter(0)
+      ROOT exp = f32[4] exponential(p0)
+    })",
+                   R"(
+    // CHECK: @main_exp(
+    // CHECK: math.exp
+    // CHECK-NOT: fastmath
+  )"));
+}
+
 TEST_F(ElementalHloToMlirTest, PopulationCountUnsigned) {
   TF_EXPECT_OK(Run(R"(
     ENTRY main{
diff --git a/xla/service/gpu/fusions/triton/emitter_helpers.cc b/xla/service/gpu/fusions/triton/emitter_helpers.cc
index c09333b221d82..83039d1c72133 100644
--- a/xla/service/gpu/fusions/triton/emitter_helpers.cc
+++ b/xla/service/gpu/fusions/triton/emitter_helpers.cc
@@ -353,19 +353,17 @@ absl::StatusOr<Value> EmitElementwiseLibdeviceFunction(
     triple.setTriple("amdgcn-unknown-unknown");
   }
   llvm::SmallVector<Value> casted_inputs;
-  PrimitiveType casted_output_type = output_type;
   if (output_type == PrimitiveType::BF16 || output_type == PrimitiveType::F16) {
     // Upcast the inputs to F32.
     for (int64_t i = 0; i < inputs.size(); ++i) {
       casted_inputs.push_back(Cast(b, inputs[i], b.getF32Type()));
     }
-    casted_output_type = F32;
   } else {
     casted_inputs.assign(inputs.begin(), inputs.end());
   }
   Value res = b.create<mt::ExternElementwiseOp>(
       casted_inputs[0].getType(), casted_inputs, "libdevice", libdevice_path,
-      ObtainDeviceFunctionName(dev_fn_id.value(), casted_output_type, triple),
+      ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
       /*pure=*/true);
   if (output_type == PrimitiveType::BF16 || output_type == PrimitiveType::F16) {
     // Downcast back to the original output type.
diff --git a/xla/service/gpu/gpu_float_support.cc b/xla/service/gpu/gpu_float_support.cc
index c02e158349962..bf9cb208738c5 100644
--- a/xla/service/gpu/gpu_float_support.cc
+++ b/xla/service/gpu/gpu_float_support.cc
@@ -97,9 +97,11 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const {
     case HloOpcode::kReducePrecision:
       return true;
     // Elementwise ops.
+    case HloOpcode::kExp:
+      return LowPrecisionType() == BF16;
     case HloOpcode::kAdd:
-    case HloOpcode::kSubtract:
-    case HloOpcode::kMultiply: {
+    case HloOpcode::kMultiply:
+    case HloOpcode::kSubtract: {
       if (LowPrecisionType() == BF16) {
         auto* cuda_compute_capability =
             std::get_if<se::CudaComputeCapability>(&compute_capability_);
diff --git a/xla/service/gpu/gpu_float_support_test.cc b/xla/service/gpu/gpu_float_support_test.cc
index 1c1c7f4e8149d..dca2bb345373a 100644
--- a/xla/service/gpu/gpu_float_support_test.cc
+++ b/xla/service/gpu/gpu_float_support_test.cc
@@ -271,6 +271,21 @@ ENTRY main {
   EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32));
 }
 
+TEST_F(FloatSupportTest, Bf16ExpIsNotNormalized) {
+  auto cc = se::CudaComputeCapability::Ampere();
+  constexpr absl::string_view kHloModule = R"(
+HloModule m
+
+ENTRY main {
+  p0 = bf16[] parameter(0)
+  ROOT r = bf16[] exponential(p0)
+})";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32));
+}
+
 TEST_F(FloatSupportTest,
        BF16ReductionOnHopperIsOnlyNormalizedIfReducerIsUnsupported) {
   auto cc = se::CudaComputeCapability::Hopper();
diff --git a/xla/service/gpu/target_util.cc b/xla/service/gpu/target_util.cc
index c1152173f42e3..a93083266c032 100644
--- a/xla/service/gpu/target_util.cc
+++ b/xla/service/gpu/target_util.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "xla/service/gpu/target_util.h"
 
+#include <cstring>
 #include
 #include
 #include
@@ -309,6 +310,14 @@ std::optional<TargetDeviceFunctionID> GetTargetDeviceFunctionID(HloOpcode op) {
   return std::nullopt;
 }
 
+namespace {
+// TODO(b/370452608): Add more functions that have a fast approximation for f32
+// that we can use for f16 types.
+bool HasFastF32Approximation(TargetDeviceFunctionID func_id) {
+  return func_id == TargetDeviceFunctionID::kExp;
+}
+}  // namespace
+
 std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                      PrimitiveType output_type,
                                      llvm::Triple target_triple) {
@@ -317,8 +326,17 @@
   // the root name are specific to the target.
   struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
   if (target_triple.isNVPTX()) {
-    if (output_type == F32) {
-      return StrCat(gpu_root_names.nvptx_root, "f");
+    bool is_supported_output_type =
+        output_type == BF16 || output_type == F16 || output_type == F32;
+    if (is_supported_output_type) {
+      std::string function_name = StrCat(gpu_root_names.nvptx_root, "f");
+      if (HasFastF32Approximation(func_id) &&
+          (output_type == BF16 || output_type == F16)) {
+        // All function names start with "__nv". The approximate version of the
+        // function name continues with "_fast".
+        return function_name.insert(strlen("__nv"), "_fast");
+      }
+      return function_name;
     } else if (output_type == F64) {
       return gpu_root_names.nvptx_root;
     } else {
@@ -326,7 +344,9 @@
           << primitive_util::LowercasePrimitiveTypeName(output_type);
     }
   } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
-    if (output_type == F32) {
+    // TODO(b/370452608): Are there approximate functions we can use for BF16
+    // and F16 types?
+    if (output_type == BF16 || output_type == F16 || output_type == F32) {
       return StrCat(gpu_root_names.amdgpu_root, "_f32");
     } else if (output_type == F64) {
       return StrCat(gpu_root_names.amdgpu_root, "_f64");
@@ -334,7 +354,9 @@
     } else {
       LOG(FATAL) << "Unexpected type while getting device function name.";
     }
   } else if (target_triple.isSPIR()) {
-    if (output_type == F32) {
+    // TODO(b/370452608): Are there approximate functions we can use for BF16
+    // and F16 types?
+    if (output_type == BF16 || output_type == F16 || output_type == F32) {
       if (gpu_root_names.spir_root == "_Z17__spirv_ocl_hypot" ||
           gpu_root_names.spir_root == "_Z15__spirv_ocl_pow" ||
           gpu_root_names.spir_root == "_Z17__spirv_ocl_atan2" ||
diff --git a/xla/service/gpu/target_util.h b/xla/service/gpu/target_util.h
index af83b7a849af3..297ff45bd6fd7 100644
--- a/xla/service/gpu/target_util.h
+++ b/xla/service/gpu/target_util.h
@@ -94,6 +94,8 @@ llvm::CallInst* EmitCallToTargetIntrinsic(
 void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
                                  llvm::IRBuilder<>* b);
 
+// 'output_type' is the type of the math op corresponding to 'func_id' for
+// which we want to obtain the device function name.
 std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
                                      PrimitiveType output_type,
                                      llvm::Triple target_triple);
diff --git a/xla/service/gpu/target_util_test.cc b/xla/service/gpu/target_util_test.cc
index 751efdecf23b5..c82c1f077cfd9 100644
--- a/xla/service/gpu/target_util_test.cc
+++ b/xla/service/gpu/target_util_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Support/raw_ostream.h"
+#include "llvm/TargetParser/Triple.h"
 #include "tsl/platform/test.h"
 
 namespace xla {
@@ -63,6 +64,16 @@ TEST_F(TargetUtilTest, AMDGCNGroupBarrier) {
   EXPECT_FALSE(llvm::verifyModule(module_, &llvm::errs()));
 }
 
+TEST(TargetUtil, ObtainDeviceFunctionNameExp) {
+  llvm::Triple triple("nvptx64-unknown-unknown");
+  EXPECT_EQ(ObtainDeviceFunctionName(TargetDeviceFunctionID::kExp,
+                                     /*output_type=*/F32, triple),
+            "__nv_expf");
+  EXPECT_EQ(ObtainDeviceFunctionName(TargetDeviceFunctionID::kExp,
+                                     /*output_type=*/BF16, triple),
+            "__nv_fast_expf");
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/xla/service/gpu/tests/single_instruction.hlo b/xla/service/gpu/tests/single_instruction.hlo
index 250499a8b590d..8903ff4204490 100644
--- a/xla/service/gpu/tests/single_instruction.hlo
+++ b/xla/service/gpu/tests/single_instruction.hlo
@@ -50,9 +50,10 @@ ENTRY main {
 
 // -----
 
-// CHECK-DAG: ex2.approx.ftz.f32
+// CHECK: ex2.approx.ftz.f32 %[[APPROX:.*]], %{{.*}}
+// CHECK: mul.rn.f32 %{{.*}}, %[[APPROX]], %{{.*}}
 
-HloModule Test, is_scheduled=true
+HloModule DoesntUseEx2ApproximationDirectly, is_scheduled=true
 
 fused_computation {
   param_0 = f32[] parameter(0)
@@ -66,6 +67,23 @@ ENTRY main {
 
 // -----
 
+// CHECK: ex2.approx.f32
+// CHECK-NOT: mul
+
+HloModule UsesEx2ApproximationDirectly, is_scheduled=true
+
+fused_computation {
+  param_0 = f16[] parameter(0)
+  ROOT b.1 = f16[] exponential(f16[] param_0)
+}
+
+ENTRY main {
+  a = f16[] parameter(0)
+  ROOT wrapped_b = f16[] fusion(f16[] a), kind=kLoop, calls=fused_computation
+}
+
+// -----
+
 // CHECK-SM80: min.NaN.f32
 
 HloModule Test, is_scheduled=true