From 6e9eefeec077f49c2b22bfeee8da537ed8517b22 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 13 Nov 2024 03:22:00 -0800 Subject: [PATCH] Reverts 14ca4eaba3d7e0c0ef53c64ff0c7d7d86d452140 PiperOrigin-RevId: 696061109 --- xla/service/gpu/gpu_float_support.cc | 2 ++ xla/service/gpu/gpu_float_support_test.cc | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/xla/service/gpu/gpu_float_support.cc b/xla/service/gpu/gpu_float_support.cc index df9e3efaeb31e..bf9cb208738c5 100644 --- a/xla/service/gpu/gpu_float_support.cc +++ b/xla/service/gpu/gpu_float_support.cc @@ -97,6 +97,8 @@ bool GpuFloatSupport::IsSupported(const HloInstruction& hlo) const { case HloOpcode::kReducePrecision: return true; // Elementwise ops. + case HloOpcode::kExp: + return LowPrecisionType() == BF16; case HloOpcode::kAdd: case HloOpcode::kMultiply: case HloOpcode::kSubtract: { diff --git a/xla/service/gpu/gpu_float_support_test.cc b/xla/service/gpu/gpu_float_support_test.cc index 1c1c7f4e8149d..dca2bb345373a 100644 --- a/xla/service/gpu/gpu_float_support_test.cc +++ b/xla/service/gpu/gpu_float_support_test.cc @@ -271,6 +271,21 @@ ENTRY main { EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32)); } +TEST_F(FloatSupportTest, Bf16ExpIsNotNormalized) { + auto cc = se::CudaComputeCapability::Ampere(); + constexpr absl::string_view kHloModule = R"( +HloModule m + +ENTRY main { + p0 = bf16[] parameter(0) + ROOT r = bf16[] exponential(p0) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kHloModule)); + EXPECT_FALSE(Normalize(module.get(), cc, BF16, F32)); +} + TEST_F(FloatSupportTest, BF16ReductionOnHopperIsOnlyNormalizedIfReducerIsUnsupported) { auto cc = se::CudaComputeCapability::Hopper();