Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch Float16 LLVM representation from i16 to half #26381

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ const llvmtypes = IdDict{Any,String}(
Int32 => "i32", UInt32 => "i32",
Int64 => "i64", UInt64 => "i64",
Int128 => "i128", UInt128 => "i128",
Float16 => "i16", # half
Float16 => "half",
Float32 => "float",
Float64 => "double",
)
Expand Down
185 changes: 34 additions & 151 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,8 @@ A not-a-number value of type [`Float64`](@ref).
const NaN = NaN64

## conversions to floating-point ##
Float16(x::Integer) = convert(Float16, convert(Float32, x))
for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128)
@eval promote_rule(::Type{Float16}, ::Type{$t}) = Float16
end
promote_rule(::Type{Float16}, ::Type{Bool}) = Float16

for t1 in (Float32, Float64)
for t1 in (Float16, Float32, Float64)
for st in (Int8, Int16, Int32, Int64)
@eval begin
(::Type{$t1})(x::($st)) = sitofp($t1, x)
Expand All @@ -65,14 +60,15 @@ for t1 in (Float32, Float64)
end
end
end
(::Type{T})(x::Float16) where {T<:Integer} = T(Float32(x))

Bool(x::Real) = x==0 ? false : x==1 ? true : throw(InexactError(:Bool, Bool, x))

promote_rule(::Type{Float64}, ::Type{UInt128}) = Float64
promote_rule(::Type{Float64}, ::Type{Int128}) = Float64
promote_rule(::Type{Float32}, ::Type{UInt128}) = Float32
promote_rule(::Type{Float32}, ::Type{Int128}) = Float32
promote_rule(::Type{Float16}, ::Type{UInt128}) = Float16
promote_rule(::Type{Float16}, ::Type{Int128}) = Float16

function Float64(x::UInt128)
x == 0 && return 0.0
Expand Down Expand Up @@ -134,115 +130,16 @@ function Float32(x::Int128)
reinterpret(Float32, s | d + y)
end

function Float16(val::Float32)
f = reinterpret(UInt32, val)
if isnan(val)
t = 0x8000 ⊻ (0x8000 & ((f >> 0x10) % UInt16))
return reinterpret(Float16, t ⊻ ((f >> 0xd) % UInt16))
end
i = (f >> 23) & 0x1ff + 1
sh = shifttable[i]
f &= 0x007fffff
h::UInt16 = basetable[i] + (f >> sh)
# round
# NOTE: we maybe should ignore NaNs here, but the payload is
# getting truncated anyway so "rounding" it might not matter
nextbit = (f >> (sh-1)) & 1
if nextbit != 0
# Round halfway to even or check lower bits
if h&1 == 1 || (f & ((1<<(sh-1))-1)) != 0
h += 1
end
end
reinterpret(Float16, h)
end

function Float32(val::Float16)
local ival::UInt32 = reinterpret(UInt16, val)
local sign::UInt32 = (ival & 0x8000) >> 15
local exp::UInt32 = (ival & 0x7c00) >> 10
local sig::UInt32 = (ival & 0x3ff) >> 0
local ret::UInt32

if exp == 0
if sig == 0
sign = sign << 31
ret = sign | exp | sig
else
n_bit = 1
bit = 0x0200
while (bit & sig) == 0
n_bit = n_bit + 1
bit = bit >> 1
end
sign = sign << 31
exp = (-14 - n_bit + 127) << 23
sig = ((sig & (~bit)) << n_bit) << (23 - 10)
ret = sign | exp | sig
end
elseif exp == 0x1f
if sig == 0 # Inf
if sign == 0
ret = 0x7f800000
else
ret = 0xff800000
end
else # NaN
ret = 0x7fc00000 | (sign<<31) | (sig<<(23-10))
end
else
sign = sign << 31
exp = (exp - 15 + 127) << 23
sig = sig << (23 - 10)
ret = sign | exp | sig
end
return reinterpret(Float32, ret)
end

# Float32 -> Float16 algorithm from:
# "Fast Half Float Conversion" by Jeroen van der Zijp
# ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf

const basetable = Vector{UInt16}(uninitialized, 512)
const shifttable = Vector{UInt8}(uninitialized, 512)

for i = 0:255
e = i - 127
if e < -24 # Very small numbers map to zero
basetable[i|0x000+1] = 0x0000
basetable[i|0x100+1] = 0x8000
shifttable[i|0x000+1] = 24
shifttable[i|0x100+1] = 24
elseif e < -14 # Small numbers map to denorms
basetable[i|0x000+1] = (0x0400>>(-e-14))
basetable[i|0x100+1] = (0x0400>>(-e-14)) | 0x8000
shifttable[i|0x000+1] = -e-1
shifttable[i|0x100+1] = -e-1
elseif e <= 15 # Normal numbers just lose precision
basetable[i|0x000+1] = ((e+15)<<10)
basetable[i|0x100+1] = ((e+15)<<10) | 0x8000
shifttable[i|0x000+1] = 13
shifttable[i|0x100+1] = 13
elseif e < 128 # Large numbers map to Infinity
basetable[i|0x000+1] = 0x7C00
basetable[i|0x100+1] = 0xFC00
shifttable[i|0x000+1] = 24
shifttable[i|0x100+1] = 24
else # Infinity and NaN's stay Infinity and NaN's
basetable[i|0x000+1] = 0x7C00
basetable[i|0x100+1] = 0xFC00
shifttable[i|0x000+1] = 13
shifttable[i|0x100+1] = 13
end
end
Float16(x::UInt128) = convert(Float16, Float32(x))
Float16(x::Int128) = convert(Float16, Float32(x))

#convert(::Type{Float16}, x::Float32) = fptrunc(Float16, x)
Float16(x::Float32) = fptrunc(Float16, x)
Float16(x::Float64) = fptrunc(Float16, x)
Float32(x::Float64) = fptrunc(Float32, x)
Float16(x::Float64) = Float16(Float32(x))

#convert(::Type{Float32}, x::Float16) = fpext(Float32, x)
Float32(x::Float16) = fpext(Float32, x)
Float64(x::Float32) = fpext(Float64, x)
Float64(x::Float16) = Float64(Float32(x))
Float64(x::Float16) = fpext(Float64, x)

AbstractFloat(x::Bool) = Float64(x)
AbstractFloat(x::Int8) = Float64(x)
Expand Down Expand Up @@ -293,14 +190,14 @@ function unsafe_trunc end

for Ti in (Int8, Int16, Int32, Int64)
@eval begin
unsafe_trunc(::Type{$Ti}, x::Float16) = unsafe_trunc($Ti, Float32(x))
unsafe_trunc(::Type{$Ti}, x::Float16) = fptosi($Ti, x)
unsafe_trunc(::Type{$Ti}, x::Float32) = fptosi($Ti, x)
unsafe_trunc(::Type{$Ti}, x::Float64) = fptosi($Ti, x)
end
end
for Ti in (UInt8, UInt16, UInt32, UInt64)
@eval begin
unsafe_trunc(::Type{$Ti}, x::Float16) = unsafe_trunc($Ti, Float32(x))
unsafe_trunc(::Type{$Ti}, x::Float16) = fptosi($Ti, x)
unsafe_trunc(::Type{$Ti}, x::Float32) = fptoui($Ti, x)
unsafe_trunc(::Type{$Ti}, x::Float64) = fptoui($Ti, x)
end
Expand Down Expand Up @@ -339,37 +236,36 @@ unsafe_trunc(::Type{Int128}, x::Float16) = unsafe_trunc(Int128, Float32(x))

# matches convert methods
# also determines floor, ceil, round
trunc(::Type{Signed}, x::Float16) = trunc(Int,x)
trunc(::Type{Signed}, x::Float32) = trunc(Int,x)
trunc(::Type{Signed}, x::Float64) = trunc(Int,x)
trunc(::Type{Unsigned}, x::Float16) = trunc(UInt,x)
trunc(::Type{Unsigned}, x::Float32) = trunc(UInt,x)
trunc(::Type{Unsigned}, x::Float64) = trunc(UInt,x)
trunc(::Type{Integer}, x::Float16) = trunc(Int,x)
trunc(::Type{Integer}, x::Float32) = trunc(Int,x)
trunc(::Type{Integer}, x::Float64) = trunc(Int,x)
trunc(::Type{T}, x::Float16) where {T<:Integer} = trunc(T, Float32(x))

# fallbacks
floor(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,floor(x))
floor(::Type{T}, x::Float16) where {T<:Integer} = floor(T, Float32(x))
ceil(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,ceil(x))
ceil(::Type{T}, x::Float16) where {T<:Integer} = ceil(T, Float32(x))
round(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,round(x))
round(::Type{T}, x::Float16) where {T<:Integer} = round(T, Float32(x))

trunc(x::Float64) = trunc_llvm(x)
trunc(x::Float32) = trunc_llvm(x)
trunc(x::Float16) = Float16(trunc(Float32(x)))
trunc(x::Float16) = trunc_llvm(x)

floor(x::Float64) = floor_llvm(x)
floor(x::Float32) = floor_llvm(x)
floor(x::Float16) = Float16(floor(Float32(x)))
floor(x::Float16) = floor_llvm(x)

ceil(x::Float64) = ceil_llvm(x)
ceil(x::Float32) = ceil_llvm(x)
ceil(x::Float16) = Float16( ceil(Float32(x)))
ceil(x::Float16) = ceil_llvm(x)

round(x::Float64) = rint_llvm(x)
round(x::Float32) = rint_llvm(x)
round(x::Float16) = Float16(round(Float32(x)))
round(x::Float16) = rint_llvm(x)

## floating point promotions ##
promote_rule(::Type{Float32}, ::Type{Float16}) = Float32
Expand All @@ -384,36 +280,30 @@ _default_type(T::Union{Type{Real},Type{AbstractFloat}}) = Float64
## floating point arithmetic ##
-(x::Float64) = neg_float(x)
-(x::Float32) = neg_float(x)
-(x::Float16) = reinterpret(Float16, reinterpret(UInt16, x) ⊻ 0x8000)
-(x::Float16) = neg_float(x)

for op in (:+, :-, :*, :/, :\, :^)
@eval ($op)(a::Float16, b::Float16) = Float16(($op)(Float32(a), Float32(b)))
end
+(x::Float16, y::Float16) = add_float(x, y)
+(x::Float32, y::Float32) = add_float(x, y)
+(x::Float64, y::Float64) = add_float(x, y)
-(x::Float16, y::Float16) = sub_float(x, y)
-(x::Float32, y::Float32) = sub_float(x, y)
-(x::Float64, y::Float64) = sub_float(x, y)
*(x::Float16, y::Float16) = mul_float(x, y)
*(x::Float32, y::Float32) = mul_float(x, y)
*(x::Float64, y::Float64) = mul_float(x, y)
/(x::Float16, y::Float16) = div_float(x, y)
/(x::Float32, y::Float32) = div_float(x, y)
/(x::Float64, y::Float64) = div_float(x, y)

muladd(x::Float32, y::Float32, z::Float32) = muladd_float(x, y, z)
muladd(x::Float64, y::Float64, z::Float64) = muladd_float(x, y, z)
function muladd(a::Float16, b::Float16, c::Float16)
Float16(muladd(Float32(a), Float32(b), Float32(c)))
end
muladd(x::Float16, y::Float16, z::Float16) = muladd_float(x, y, z)

# TODO: faster floating point div?
# TODO: faster floating point fld?
# TODO: faster floating point mod?

for func in (:div,:fld,:cld,:rem,:mod)
@eval begin
$func(a::Float16,b::Float16) = Float16($func(Float32(a),Float32(b)))
end
end

rem(x::Float16, y::Float16) = rem_float(x, y)
rem(x::Float32, y::Float32) = rem_float(x, y)
rem(x::Float64, y::Float64) = rem_float(x, y)

Expand All @@ -431,33 +321,25 @@ function mod(x::T, y::T) where T<:AbstractFloat
end

## floating point comparisons ##
function ==(x::Float16, y::Float16)
ix = reinterpret(UInt16,x)
iy = reinterpret(UInt16,y)
if (ix|iy)&0x7fff > 0x7c00 #isnan(x) || isnan(y)
return false
end
if (ix|iy)&0x7fff == 0x0000
return true
end
return ix == iy
end
==(x::Float16, y::Float16) = eq_float(x, y)
==(x::Float32, y::Float32) = eq_float(x, y)
==(x::Float64, y::Float64) = eq_float(x, y)
!=(x::Float16, y::Float16) = ne_float(x, y)
!=(x::Float32, y::Float32) = ne_float(x, y)
!=(x::Float64, y::Float64) = ne_float(x, y)
<( x::Float16, y::Float16) = lt_float(x, y)
<( x::Float32, y::Float32) = lt_float(x, y)
<( x::Float64, y::Float64) = lt_float(x, y)
<=(x::Float16, y::Float16) = le_float(x, y)
<=(x::Float32, y::Float32) = le_float(x, y)
<=(x::Float64, y::Float64) = le_float(x, y)

isequal(x::Float16, y::Float16) = fpiseq(x, y)
isequal(x::Float32, y::Float32) = fpiseq(x, y)
isequal(x::Float64, y::Float64) = fpiseq(x, y)
isless( x::Float16, y::Float16) = fpislt(x, y)
isless( x::Float32, y::Float32) = fpislt(x, y)
isless( x::Float64, y::Float64) = fpislt(x, y)
for op in (:<, :<=, :isless)
@eval ($op)(a::Float16, b::Float16) = ($op)(Float32(a), Float32(b))
end

# Exact Float (Tf) vs Integer (Ti) comparisons
# Assumes:
Expand Down Expand Up @@ -512,7 +394,7 @@ end
<=(x::Union{Int32,UInt32}, y::Float32) = Float64(x)<=Float64(y)


abs(x::Float16) = reinterpret(Float16, reinterpret(UInt16, x) & 0x7fff)
abs(x::Float16) = abs_float(x)
abs(x::Float32) = abs_float(x)
abs(x::Float64) = abs_float(x)

Expand Down Expand Up @@ -648,7 +530,7 @@ such `y` exists (e.g. if `x` is `-Inf` or `NaN`), then return `x`.
prevfloat(x::AbstractFloat) = nextfloat(x,-1)

for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128)
for Tf in (Float32, Float64)
for Tf in (Float16, Float32, Float64)
if Ti <: Unsigned || sizeof(Ti) < sizeof(Tf)
# Here `Tf(typemin(Ti))-1` is exact, so we can compare the lower-bound
# directly. `Tf(typemax(Ti))+1` is either always exactly representable, or
Expand Down Expand Up @@ -825,6 +707,7 @@ eps(::AbstractFloat)


## byte order swaps for arbitrary-endianness serialization/deserialization ##
bswap(x::Float16) = bswap_int(x)
bswap(x::Float32) = bswap_int(x)
bswap(x::Float64) = bswap_int(x)

Expand Down
2 changes: 2 additions & 0 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,8 @@ end
end
z
end
@inline ^(x::Float16, y::Float16) = Float16(Float32(x)^Float32(y))

@inline ^(x::Float64, y::Integer) = ccall("llvm.pow.f64", llvmcall, Float64, (Float64, Float64), x, Float64(y))
@inline ^(x::Float32, y::Integer) = ccall("llvm.pow.f32", llvmcall, Float32, (Float32, Float32), x, Float32(y))
@inline ^(x::Float16, y::Integer) = Float16(Float32(x) ^ y)
Expand Down
Loading