Use fast version of exp if type is F16 or BF16.
There seems to be no dedicated libdevice call for Exp with the F16 or BF16
types. Currently we upcast to F32 and call __nv_expf. However, __nv_fast_expf
seems likely to be accurate enough for F16 and BF16, so use it instead, as it
is considerably faster.
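
For intuition, a before/after sketch in CUDA C++ (an illustration, not XLA
code: the function names are made up, and it assumes CUDA's __expf intrinsic
lowers to the same fast approximation path as libdevice's __nv_fast_expf,
with explicit half conversions standing in for XLA's FPCasts):

  #include <cuda_fp16.h>

  // Before: f16 exp went through the accurate f32 path (__nv_expf).
  __device__ __half exp_f16_before(__half x) {
    return __float2half(expf(__half2float(x)));
  }

  // After: the widened call uses the fast approximation (__nv_fast_expf).
  __device__ __half exp_f16_after(__half x) {
    return __float2half(__expf(__half2float(x)));
  }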

PiperOrigin-RevId: 692841933
akuegel authored and Google-ML-Automation committed Nov 4, 2024
1 parent 336cc6e commit 9b19353
Showing 11 changed files with 156 additions and 37 deletions.
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -260,6 +260,7 @@ xla_cc_test(
"//xla/tests:xla_internal_test_main",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TargetParser",
"@tsl//tsl/platform:test",
],
)
14 changes: 7 additions & 7 deletions xla/service/gpu/elemental_ir_emitter.cc
@@ -66,23 +66,23 @@ absl::StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
absl::string_view name) {
// Device functions don't have f16/bf16 math functions, so we convert the
// operands to f32 before calling the function and then convert the result
// back to the original type.
bool cast_result_to_fp16 = false;
std::vector<llvm::Value*> converted_operands(operands.begin(),
operands.end());
std::vector<PrimitiveType> converted_input_types(input_types.begin(),
input_types.end());
PrimitiveType original_output_type = output_type;
switch (output_type) {
case BF16:
case F16:
cast_result_to_fp16 = true;
for (int64_t i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
if (input_types[i] == original_output_type) {
converted_operands[i] =
FPCast(converted_operands[i], b()->getFloatTy());
converted_input_types[i] = F32;
}
}
output_type = F32;
[[fallthrough]];
break;
case F32:
break;
case F64:
@@ -92,13 +92,13 @@ absl::StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitDeviceMathCall(
PrimitiveType_Name(output_type));
}
const std::string& munged_callee = ObtainDeviceFunctionName(
funcid, output_type,
funcid, original_output_type,
llvm::Triple(b()->GetInsertBlock()->getModule()->getTargetTriple()));
llvm::Value* result = EmitMathCall(munged_callee, converted_operands,
converted_input_types, output_type, name)
.value();
if (cast_result_to_fp16) {
result = FPCast(result, b()->getHalfTy());
if (output_type != original_output_type) {
result = FPCast(result, operands[0]->getType());
}
return result;
}
51 changes: 32 additions & 19 deletions xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <functional>
#include <iterator>
#include <optional>
#include <queue>
#include <utility>
#include <vector>
@@ -607,24 +608,27 @@ template <typename MhloOp, typename... ExtraArgs>
SmallVector<Value, 1> MapHloOp(mlir::Type result_type,
llvm::ArrayRef<mlir::Type> arg_types,
llvm::ArrayRef<Value> args,
llvm::ArrayRef<mlir::NamedAttribute> attributes,
ImplicitLocOpBuilder& b,
ExtraArgs&&... extra_args) {
Value result = mhlo::MhloOpToStdScalarOp::mapOpOfType<MhloOp>(
b.getLoc(), result_type, arg_types,
typename MhloOp::Adaptor(args, std::forward<ExtraArgs>(extra_args)...),
/*attributes=*/std::nullopt, &b);
attributes, &b);
if (result.getType().isInteger(1)) {
result = b.create<mlir::arith::ExtUIOp>(b.getI8Type(), result);
}
return {result};
}

template <typename MhloOp>
SmallVector<Value, 1> MapElementwiseOp(llvm::ArrayRef<mlir::Type> arg_types,
llvm::ArrayRef<Value> args,
ImplicitLocOpBuilder& b) {
SmallVector<Value, 1> MapElementwiseOp(
llvm::ArrayRef<mlir::Type> arg_types, llvm::ArrayRef<Value> args,
ImplicitLocOpBuilder& b,
llvm::ArrayRef<mlir::NamedAttribute> attributes = std::nullopt) {
// We use the last argument's type because of select.
return MapHloOp<MhloOp>(args.back().getType(), arg_types, args, b);
return MapHloOp<MhloOp>(args.back().getType(), arg_types, args, attributes,
b);
}

} // namespace
@@ -946,9 +950,9 @@ absl::StatusOr<SmallVector<Value, 1>> EmitReducePrecision(
mhlo::ReducePrecisionOp::Properties properties;
properties.exponent_bits = builder.getI32IntegerAttr(instr->exponent_bits());
properties.mantissa_bits = builder.getI32IntegerAttr(instr->mantissa_bits());
return MapHloOp<mhlo::ReducePrecisionOp>(operands.front().getType(),
arg_types, operands, builder,
nullptr, properties);
return MapHloOp<mhlo::ReducePrecisionOp>(
operands.front().getType(), arg_types, operands,
/*attributes=*/std::nullopt, builder, nullptr, properties);
}

absl::StatusOr<SmallVector<Value, 1>> HloToMlir(
@@ -1014,11 +1018,12 @@ absl::StatusOr<SmallVector<Value, 1>> HloToMlir(
TF_ASSIGN_OR_RETURN(auto operands,
GetOperands(instr, indices, operand_provider, builder));

llvm::SmallVector<mlir::NamedAttribute> attributes;
switch (instr->opcode()) {
case HloOpcode::kAbs:
return {
MapHloOp<mhlo::AbsOp>(PrimitiveTypeToMlirType(element_type, builder),
arg_types, operands, builder)};
return {MapHloOp<mhlo::AbsOp>(
PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
/*attributes=*/std::nullopt, builder)};
case HloOpcode::kAdd:
if (element_type == PRED) {
return MapElementwiseOp<mhlo::OrOp>(arg_types, operands, builder);
@@ -1041,26 +1046,34 @@ absl::StatusOr<SmallVector<Value, 1>> HloToMlir(
case HloOpcode::kComplex:
return MapHloOp<mhlo::ComplexOp>(
PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
builder);
/*attributes=*/std::nullopt, builder);
case HloOpcode::kCos:
return MapElementwiseOp<mhlo::CosineOp>(arg_types, operands, builder);
case HloOpcode::kDivide:
return MapElementwiseOp<mhlo::DivOp>(arg_types, operands, builder);
case HloOpcode::kErf:
return MapElementwiseOp<mhlo::ErfOp>(arg_types, operands, builder);
case HloOpcode::kExp:
return MapElementwiseOp<mhlo::ExpOp>(arg_types, operands, builder);
if (element_type == F16 || element_type == BF16) {
attributes.emplace_back(
mlir::StringAttr::get(builder.getContext(), "fastmath"),
mlir::arith::FastMathFlagsAttr::get(
builder.getContext(), mlir::arith::FastMathFlags::afn));
}
return MapElementwiseOp<mhlo::ExpOp>(arg_types, operands, builder,
attributes);
case HloOpcode::kExpm1:
return MapElementwiseOp<mhlo::Expm1Op>(arg_types, operands, builder);
case HloOpcode::kFloor:
return MapElementwiseOp<mhlo::FloorOp>(arg_types, operands, builder);
case HloOpcode::kIsFinite:
return MapHloOp<mhlo::IsFiniteOp>(builder.getI1Type(), arg_types,
operands, builder);
operands, /*attributes=*/std::nullopt,
builder);
case HloOpcode::kImag:
return MapHloOp<mhlo::ImagOp>(
PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
builder);
/*attributes=*/std::nullopt, builder);
case HloOpcode::kLog:
return MapElementwiseOp<mhlo::LogOp>(arg_types, operands, builder);
case HloOpcode::kLog1p:
@@ -1106,13 +1119,13 @@ absl::StatusOr<SmallVector<Value, 1>> HloToMlir(
case HloOpcode::kPopulationCount:
return MapHloOp<mhlo::PopulationCountOp>(
PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
builder);
/*attributes=*/std::nullopt, builder);
case HloOpcode::kPower:
return MapElementwiseOp<mhlo::PowOp>(arg_types, operands, builder);
case HloOpcode::kReal:
return MapHloOp<mhlo::RealOp>(
PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
builder);
/*attributes=*/std::nullopt, builder);
case HloOpcode::kReducePrecision:
return EmitReducePrecision(instr, arg_types, operands, builder);
case HloOpcode::kRemainder:
Expand Down Expand Up @@ -1154,7 +1167,7 @@ absl::StatusOr<SmallVector<Value, 1>> HloToMlir(
case HloOpcode::kBitcastConvert:
return MapHloOp<mhlo::BitcastConvertOp>(
PrimitiveTypeToMlirType(element_type, builder), arg_types, operands,
builder);
/*attributes=*/std::nullopt, builder);
case HloOpcode::kConvert:
return EmitConvert(instr, arg_types, operands, builder);
case HloOpcode::kBitcast:
@@ -1265,7 +1278,7 @@ class SubgraphConverter {
absl::StatusOr<SmallVector<Value>> EmitInstruction(
const HloInstruction* instr, ValueRange indices);
absl::StatusOr<SmallVector<Value>> EmitElementwiseInstruction(
const HloInstruction* instr, ValueRange indices);
const HloInstruction* root, ValueRange indices);

private:
const PartitionedComputation& computation_;
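
The attribute attached in the kExp case above corresponds to LLVM's "afn"
(approximate functions) fast-math flag, which permits lowering a math call
such as exp to a cheaper approximation. A minimal LLVM-side sketch (assuming
a recent LLVM with llvm/IR/FMF.h; the helper name is hypothetical):

  #include "llvm/IR/FMF.h"

  // Build a FastMathFlags value with only 'afn' set, mirroring
  // mlir::arith::FastMathFlags::afn in the diff above.
  llvm::FastMathFlags AfnOnlyFlags() {
    llvm::FastMathFlags fmf;  // all flags clear by default
    fmf.setApproxFunc();      // afn: allow approximate math functions
    return fmf;
  }
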
37 changes: 37 additions & 0 deletions xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir_test.cc
@@ -1337,6 +1337,43 @@ TEST_F(ElementalHloToMlirTest, ConvertToUnsigned64Saturation) {
)"));
}

TEST_F(ElementalHloToMlirTest, ExpF16_UsesFastmathFlag) {
TF_EXPECT_OK(Run(R"(
ENTRY main {
p0 = f16[4] parameter(0)
ROOT exp = f16[4] exponential(p0)
})",
R"(
// CHECK: @main_exp(
// CHECK: math.exp %{{.*}} fastmath<afn> : f16
)"));
}

TEST_F(ElementalHloToMlirTest, ExpBF16_UsesFastmathFlag) {
TF_EXPECT_OK(Run(R"(
ENTRY main {
p0 = bf16[4] parameter(0)
ROOT exp = bf16[4] exponential(p0)
})",
R"(
// CHECK: @main_exp(
// CHECK: math.exp %{{.*}} fastmath<afn> : bf16
)"));
}

TEST_F(ElementalHloToMlirTest, ExpF32_DoesntUseFastmathFlag) {
TF_EXPECT_OK(Run(R"(
ENTRY main {
p0 = f32[4] parameter(0)
ROOT exp = f32[4] exponential(p0)
})",
R"(
// CHECK: @main_exp(
// CHECK: math.exp
// CHECK-NOT: fastmath
)"));
}

TEST_F(ElementalHloToMlirTest, PopulationCountUnsigned) {
TF_EXPECT_OK(Run(R"(
ENTRY main{
4 changes: 1 addition & 3 deletions xla/service/gpu/fusions/triton/emitter_helpers.cc
@@ -353,19 +353,17 @@ absl::StatusOr<Value> EmitElementwiseLibdeviceFunction(
triple.setTriple("amdgcn-unknown-unknown");
}
llvm::SmallVector<Value, 2> casted_inputs;
PrimitiveType casted_output_type = output_type;
if (output_type == PrimitiveType::BF16 || output_type == PrimitiveType::F16) {
// Upcast the inputs to F32.
for (int64_t i = 0; i < inputs.size(); ++i) {
casted_inputs.push_back(Cast(b, inputs[i], b.getF32Type()));
}
casted_output_type = F32;
} else {
casted_inputs.assign(inputs.begin(), inputs.end());
}
Value res = b.create<mt::ExternElementwiseOp>(
casted_inputs[0].getType(), casted_inputs, "libdevice", libdevice_path,
ObtainDeviceFunctionName(dev_fn_id.value(), casted_output_type, triple),
ObtainDeviceFunctionName(dev_fn_id.value(), output_type, triple),
/*pure=*/true);
if (output_type == PrimitiveType::BF16 || output_type == PrimitiveType::F16) {
// Downcast back to the original output type.
6 changes: 4 additions & 2 deletions xla/service/gpu/gpu_float_support.cc
@@ -97,9 +97,11 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const {
case HloOpcode::kReducePrecision:
return true;
// Elementwise ops.
case HloOpcode::kExp:
return LowPrecisionType() == BF16;
case HloOpcode::kAdd:
case HloOpcode::kSubtract:
case HloOpcode::kMultiply: {
case HloOpcode::kMultiply:
case HloOpcode::kSubtract: {
if (LowPrecisionType() == BF16) {
auto* cuda_compute_capability =
std::get_if<se::CudaComputeCapability>(&compute_capability_);
15 changes: 15 additions & 0 deletions xla/service/gpu/gpu_float_support_test.cc
@@ -271,6 +271,21 @@ ENTRY main {
EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32));
}

TEST_F(FloatSupportTest, Bf16ExpIsNotNormalized) {
auto cc = se::CudaComputeCapability::Ampere();
constexpr absl::string_view kHloModule = R"(
HloModule m
ENTRY main {
p0 = bf16[] parameter(0)
ROOT r = bf16[] exponential(p0)
})";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(kHloModule));
EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32));
}

TEST_F(FloatSupportTest,
BF16ReductionOnHopperIsOnlyNormalizedIfReducerIsUnsupported) {
auto cc = se::CudaComputeCapability::Hopper();
30 changes: 26 additions & 4 deletions xla/service/gpu/target_util.cc
@@ -17,6 +17,7 @@ limitations under the License.

#include "xla/service/gpu/target_util.h"

#include <cstring>
#include <functional>
#include <optional>
#include <string>
@@ -309,6 +310,14 @@ std::optional<TargetDeviceFunctionID> GetTargetDeviceFunctionID(HloOpcode op) {
return std::nullopt;
}

namespace {
// TODO(b/370452608): Add more functions that have a fast approximation for f32
// that we can use for f16 types.
bool HasFastF32Approximation(TargetDeviceFunctionID func_id) {
return func_id == TargetDeviceFunctionID::kExp;
}
} // namespace

std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
PrimitiveType output_type,
llvm::Triple target_triple) {
@@ -317,24 +326,37 @@ std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
// the root name are specific to the target.
struct TargetDeviceFunction gpu_root_names = GetDeviceFunctionRoot(func_id);
if (target_triple.isNVPTX()) {
if (output_type == F32) {
return StrCat(gpu_root_names.nvptx_root, "f");
bool is_supported_output_type =
output_type == BF16 || output_type == F16 || output_type == F32;
if (is_supported_output_type) {
std::string function_name = StrCat(gpu_root_names.nvptx_root, "f");
if (HasFastF32Approximation(func_id) &&
(output_type == BF16 || output_type == F16)) {
// All NVPTX function names start with "__nv"; the approximate versions
// insert "_fast" after that prefix (e.g. __nv_expf -> __nv_fast_expf).
return function_name.insert(strlen("__nv"), "_fast");
}
return function_name;
} else if (output_type == F64) {
return gpu_root_names.nvptx_root;
} else {
LOG(FATAL) << "Unexpected type while getting device function name: "
<< primitive_util::LowercasePrimitiveTypeName(output_type);
}
} else if (target_triple.getArch() == llvm::Triple::amdgcn) {
if (output_type == F32) {
// TODO(b/370452608): Are there approximate functions we can use for BF16
// and F16 types?
if (output_type == BF16 || output_type == F16 || output_type == F32) {
return StrCat(gpu_root_names.amdgpu_root, "_f32");
} else if (output_type == F64) {
return StrCat(gpu_root_names.amdgpu_root, "_f64");
} else {
LOG(FATAL) << "Unexpected type while getting device function name.";
}
} else if (target_triple.isSPIR()) {
if (output_type == F32) {
// TODO(b/370452608): Are there approximate functions we can use for BF16
// and F16 types?
if (output_type == BF16 || output_type == F16 || output_type == F32) {
if (gpu_root_names.spir_root == "_Z17__spirv_ocl_hypot" ||
gpu_root_names.spir_root == "_Z15__spirv_ocl_pow" ||
gpu_root_names.spir_root == "_Z17__spirv_ocl_atan2" ||
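
As a worked example of the string splice above (plain standard C++; the unit
test further below checks the same names end-to-end):

  #include <cassert>
  #include <cstring>
  #include <string>

  int main() {
    std::string name = "__nv_expf";             // accurate f32 libdevice name
    name.insert(std::strlen("__nv"), "_fast");  // splice in after the prefix
    assert(name == "__nv_fast_expf");           // fast-approximation name
    return 0;
  }
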
2 changes: 2 additions & 0 deletions xla/service/gpu/target_util.h
@@ -94,6 +94,8 @@ llvm::CallInst* EmitCallToTargetIntrinsic(
void AnnotateFunctionAsGpuKernel(llvm::Module* module, llvm::Function* func,
llvm::IRBuilder<>* b);

// 'output_type' is the output type of the math op identified by 'func_id';
// it selects which device function name is returned.
std::string ObtainDeviceFunctionName(TargetDeviceFunctionID func_id,
PrimitiveType output_type,
llvm::Triple target_triple);
11 changes: 11 additions & 0 deletions xla/service/gpu/target_util_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "tsl/platform/test.h"

namespace xla {
@@ -63,6 +64,16 @@ TEST_F(TargetUtilTest, AMDGCNGroupBarrier) {
EXPECT_FALSE(llvm::verifyModule(module_, &llvm::errs()));
}

TEST(TargetUtil, ObtainDeviceFunctionNameExp) {
llvm::Triple triple("nvptx64-unknown-unknown");
EXPECT_EQ(ObtainDeviceFunctionName(TargetDeviceFunctionID::kExp,
/*output_type=*/F32, triple),
"__nv_expf");
EXPECT_EQ(ObtainDeviceFunctionName(TargetDeviceFunctionID::kExp,
/*output_type=*/BF16, triple),
"__nv_fast_expf");
}

} // namespace
} // namespace gpu
} // namespace xla