From 4ea8d5665dcf05c7c42757e889fe463cd56d9b11 Mon Sep 17 00:00:00 2001 From: Oscar Smith Date: Mon, 22 Aug 2022 04:13:39 -0400 Subject: [PATCH] Fix some pow edge cases (#46412) --- base/math.jl | 37 +++++++++++++++++++++++-------------- test/math.jl | 8 ++++++-- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/base/math.jl b/base/math.jl index 8a4735c43ebea..f1ee129305418 100644 --- a/base/math.jl +++ b/base/math.jl @@ -1098,14 +1098,18 @@ end # @constprop aggressive to help the compiler see the switch between the integer and float # variants for callers with constant `y` @constprop :aggressive function ^(x::Float64, y::Float64) - yint = unsafe_trunc(Int, y) # Note, this is actually safe since julia freezes the result - y == yint && return x^yint - #numbers greater than 2*inv(eps(T)) must be even, and the pow will overflow - y >= 2*inv(eps()) && return x^(typemax(Int64)-1) xu = reinterpret(UInt64, x) - x<0 && y > -4e18 && throw_exp_domainerror(x) # |y| is small enough that y isn't an integer - x === 1.0 && return 1.0 - x==0 && return abs(y)*Inf*(!(y>0)) + xu == reinterpret(UInt64, 1.0) && return 1.0 + # Exponents greater than this will always overflow or underflow. + # Note that NaN can pass through this, but that will end up fine. + if !(abs(y)<0x1.8p62) + isnan(y) && return y + y = sign(y)*0x1.8p62 + end + yint = unsafe_trunc(Int64, y) # This is actually safe since julia freezes the result + y == yint && return @noinline x^yint + 2*xu==0 && return abs(y)*Inf*(!(y>0)) # if x==0 + x<0 && throw_exp_domainerror(x) # |y| is small enough that y isn't an integer !isfinite(x) && return x*(y>0 || isnan(x)) # x is inf or NaN if xu < (UInt64(1)<<52) # x is subnormal xu = reinterpret(UInt64, x * 0x1p52) # normalize x @@ -1124,18 +1128,23 @@ end end @constprop :aggressive function ^(x::T, y::T) where T <: Union{Float16, Float32} - yint = unsafe_trunc(Int64, y) # Note, this is actually safe since julia freezes the result + x == 1 && return one(T) + # Exponents greater than this will always overflow or underflow. + # Note that NaN can pass through this, but that will end up fine. + max_exp = T == Float16 ? T(3<<14) : T(0x1.Ap30) + if !(abs(y)= 2*inv(eps(T)) && return x^(typemax(Int64)-1) - x < 0 && y > -4e18 && throw_exp_domainerror(x) # |y| is small enough that y isn't an integer + x < 0 && throw_exp_domainerror(x) + !isfinite(x) && return x*(y>0 || isnan(x)) + x==0 && return abs(y)*T(Inf)*(!(y>0)) return pow_body(x, y) end @inline function pow_body(x::T, y::T) where T <: Union{Float16, Float32} - x == 1 && return one(T) - !isfinite(x) && return x*(y>0 || isnan(x)) - x==0 && return abs(y)*T(Inf)*(!(y>0)) return T(exp2(log2(abs(widen(x))) * y)) end diff --git a/test/math.jl b/test/math.jl index bae1f571ef16a..6ba40b7daa968 100644 --- a/test/math.jl +++ b/test/math.jl @@ -1325,8 +1325,12 @@ end for T in (Float16, Float32, Float64) for x in (0.0, -0.0, 1.0, 10.0, 2.0, Inf, NaN, -Inf, -NaN) for y in (0.0, -0.0, 1.0, -3.0,-10.0 , Inf, NaN, -Inf, -NaN) - got, expected = T(x)^T(y), T(big(x))^T(y) - @test isnan_type(T, got) && isnan_type(T, expected) || (got === expected) + got, expected = T(x)^T(y), T(big(x)^T(y)) + if isnan(expected) + @test isnan_type(T, got) || T.((x,y)) + else + @test got == expected || T.((x,y)) + end end end for _ in 1:2^16