Skip to content

Commit

Permalink
basic soft float16 support
Browse files Browse the repository at this point in the history
nolta committed Aug 7, 2013

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 13456a4 commit 135a6b9
Showing 14 changed files with 114 additions and 14 deletions.
3 changes: 2 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ export
Box, Function, IntrinsicFunction, LambdaStaticData, Method, MethodTable,
Module, Nothing, Symbol, Task, Array,
# numeric types
Bool, FloatingPoint, Float32, Float64, Number, Integer, Int, Int8, Int16,
Bool, FloatingPoint, Float16, Float32, Float64, Number, Integer, Int, Int8, Int16,
Int32, Int64, Int128, Ptr, Real, Signed, Uint, Uint8, Uint16, Uint32,
Uint64, Uint128, Unsigned,
# string types
@@ -171,6 +171,7 @@ abstract Integer <: Real
abstract Signed <: Integer
abstract Unsigned <: Integer

bitstype 16 Float16 <: FloatingPoint
bitstype 32 Float32 <: FloatingPoint
bitstype 64 Float64 <: FloatingPoint

4 changes: 3 additions & 1 deletion base/exports.jl
Original file line number Diff line number Diff line change
@@ -156,12 +156,14 @@ export
ENDIAN_BOM,
ENV,
Inf,
Inf16,
Inf32,
LOAD_PATH,
MS_ASYNC,
MS_INVALIDATE,
MS_SYNC,
NaN,
NaN16,
NaN32,
OS_NAME,
RTLD_DEEPBIND,
@@ -248,7 +250,6 @@ export
At_rdiv_Bt,

# scalar math
#float16,
abs,
abs2,
acos,
@@ -323,6 +324,7 @@ export
fld1,
flipsign,
float,
float16,
float32,
float64,
floor,
38 changes: 29 additions & 9 deletions base/float.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#bitstype 16 Float16 <: FloatingPoint
## conversions to floating-point ##

convert(::Type{Float32}, x::Int128) = float32(uint128(abs(x)))*(1-2(x<0))
@@ -11,7 +10,12 @@ convert(::Type{Float64}, x::Uint128) = float64(uint64(x)) + ldexp(float64(uint64
promote_rule(::Type{Float64}, ::Type{Int128} ) = Float64
promote_rule(::Type{Float64}, ::Type{Uint128}) = Float64

for t1 in (Float32,Float64) #,Float16)
convert(::Type{Float16}, x::Union(Signed,Unsigned)) = convert(Float16, convert(Float32,x))
for t in (Bool,Char,Int8,Int16,Int32,Int64,Uint8,Uint16,Uint32,Uint64)
@eval promote_rule(::Type{Float16}, ::Type{$t}) = Float32
end

for t1 in (Float32,Float64)
for st in (Int8,Int16,Int32,Int64)
@eval begin
convert(::Type{$t1},x::($st)) = box($t1,sitofp($t1,unbox($st,x)))
@@ -25,12 +29,12 @@ for t1 in (Float32,Float64) #,Float16)
end
end
end
#convert(::Type{Float16}, x::Union(Float32,Float64)) = box(Float16,fptrunc(x,Float16))
#convert(::Type{Float32}, x::Float16) = box(Float32,fpext(Float32,x))
#convert(::Type{Float16}, x::Float32) = box(Float16,fptrunc(Float16,x))
convert(::Type{Float16}, x::Float64) = convert(Float16, convert(Float32,x))
convert(::Type{Float32}, x::Float64) = box(Float32,fptrunc(Float32,x))

# REPLACE when enabling Float16
#convert(::Type{Float64}, x::Union(Float32,Float16)) = box(Float64,fpext(Float64,x))
#convert(::Type{Float32}, x::Float16) = box(Float32,fpext(Float32,x))
convert(::Type{Float64}, x::Float16) = convert(Float64, convert(Float32,x))
convert(::Type{Float64}, x::Float32) = box(Float64,fpext(Float64,x))

convert(::Type{FloatingPoint}, x::Bool) = convert(Float32, x)
@@ -46,11 +50,14 @@ convert(::Type{FloatingPoint}, x::Uint32) = convert(Float64, x)
convert(::Type{FloatingPoint}, x::Uint64) = convert(Float64, x) # LOSSY
convert(::Type{FloatingPoint}, x::Uint128) = convert(Float64, x) # LOSSY

#float16(x) = convert(Float16, x)
float16(x) = convert(Float16, x)
float32(x) = convert(Float32, x)
float64(x) = convert(Float64, x)
float(x) = convert(FloatingPoint, x)

# possibly a hack, but useful for `f(x::Real) = f(float(x))` fallbacks
float(x::Float16) = float32(x)

## conversions from floating-point ##

# fallbacks using only convert, trunc, ceil, floor, round
@@ -105,16 +112,22 @@ floor(x::Float64) = ccall((:floor, Base.libm_name), Float64, (Float64,), x)

## floating point promotions ##

#promote_rule(::Type{Float32}, ::Type{Float16}) = Float32
promote_rule(::Type{Float32}, ::Type{Float16}) = Float32
promote_rule(::Type{Float64}, ::Type{Float16}) = Float64
promote_rule(::Type{Float64}, ::Type{Float32}) = Float64

#morebits(::Type{Float16}) = Float32
morebits(::Type{Float16}) = Float32
morebits(::Type{Float32}) = Float64

## floating point arithmetic ##

-(x::Float16) = reinterpret(Float16, reinterpret(Uint16,x) $ 0x8000)
-(x::Float32) = box(Float32,neg_float(unbox(Float32,x)))
-(x::Float64) = box(Float64,neg_float(unbox(Float64,x)))

for op in (:+,:-,:*,:/)
@eval ($op)(a::Float16, b::Float16) = ($op)(float32(a), float32(b))
end
+(x::Float32, y::Float32) = box(Float32,add_float(unbox(Float32,x),unbox(Float32,y)))
+(x::Float64, y::Float64) = box(Float64,add_float(unbox(Float64,x),unbox(Float64,y)))
-(x::Float32, y::Float32) = box(Float32,sub_float(unbox(Float32,x),unbox(Float32,y)))
@@ -198,12 +211,16 @@ isless (a::FloatingPoint, b::Integer) = (a<b) | isless(a,float(b))

## floating point traits ##

const Inf16 = box(Float16,unbox(Uint16,0x7c00))
const NaN16 = box(Float16,unbox(Uint16,0x7e00))
const Inf32 = box(Float32,unbox(Uint32,0x7f800000))
const NaN32 = box(Float32,unbox(Uint32,0x7fc00000))
const Inf = box(Float64,unbox(Uint64,0x7ff0000000000000))
const NaN = box(Float64,unbox(Uint64,0x7ff8000000000000))

@eval begin
inf(::Type{Float16}) = $Inf16
nan(::Type{Float16}) = $NaN16
inf(::Type{Float32}) = $Inf32
nan(::Type{Float32}) = $NaN32
inf(::Type{Float64}) = $Inf
@@ -214,6 +231,8 @@ const NaN = box(Float64,unbox(Uint64,0x7ff8000000000000))
issubnormal(x::Float32) = (abs(x) < $(box(Float32,unbox(Uint32,0x00800000)))) & (x!=0)
issubnormal(x::Float64) = (abs(x) < $(box(Float64,unbox(Uint64,0x0010000000000000)))) & (x!=0)

typemin(::Type{Float16}) = $(box(Float16,unbox(Uint16,0xfc00)))
typemax(::Type{Float16}) = $(Inf16)
typemin(::Type{Float32}) = $(-Inf32)
typemax(::Type{Float32}) = $(Inf32)
typemin(::Type{Float64}) = $(-Inf)
@@ -241,6 +260,7 @@ const NaN = box(Float64,unbox(Uint64,0x7ff8000000000000))
eps() = eps(Float64)
end

sizeof(::Type{Float16}) = 2
sizeof(::Type{Float32}) = 4
sizeof(::Type{Float64}) = 8

44 changes: 44 additions & 0 deletions base/float16.jl
Original file line number Diff line number Diff line change
@@ -41,3 +41,47 @@ function convert(::Type{Float32}, val::Float16)
return reinterpret(Float32, ret)
end

# Float32 -> Float16 algorithm from:
# "Fast Half Float Conversion" by Jeroen van der Zijp
# ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf

basetable = Array(Uint16, 512)
shifttable = Array(Uint8, 512)

for i = 0:255
e = i - 127
if e < -24 # Very small numbers map to zero
basetable[i|0x000+1]=0x0000;
basetable[i|0x100+1]=0x8000;
shifttable[i|0x000+1]=24;
shifttable[i|0x100+1]=24;
elseif e < -14 # Small numbers map to denorms
basetable[i|0x000+1]=(0x0400>>(-e-14));
basetable[i|0x100+1]=(0x0400>>(-e-14)) | 0x8000;
shifttable[i|0x000+1]=-e-1;
shifttable[i|0x100+1]=-e-1;
elseif e <= 15 # Normal numbers just lose precision
basetable[i|0x000+1]=((e+15)<<10);
basetable[i|0x100+1]=((e+15)<<10) | 0x8000;
shifttable[i|0x000+1]=13;
shifttable[i|0x100+1]=13;
elseif e < 128 # Large numbers map to Infinity
basetable[i|0x000+1]=0x7C00;
basetable[i|0x100+1]=0xFC00;
shifttable[i|0x000+1]=24;
shifttable[i|0x100+1]=24;
else # Infinity and NaN's stay Infinity and NaN's
basetable[i|0x000+1]=0x7C00;
basetable[i|0x100+1]=0xFC00;
shifttable[i|0x000+1]=13;
shifttable[i|0x100+1]=13;
end
end

function convert(::Type{Float16}, val::Float32)
f = reinterpret(Uint32, val)
i = (f >> 23) & 0x1ff + 1
h = basetable[i] + ((f & 0x007fffff) >> shifttable[i])
reinterpret(Float16, uint16(h))
end

3 changes: 3 additions & 0 deletions base/floatfuncs.jl
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ copysign(x::Float64, y::Real) = copysign(x, float64(y))

signbit(x::Float64) = signbit(reinterpret(Int64,x))
signbit(x::Float32) = signbit(reinterpret(Int32,x))
signbit(x::Float16) = signbit(reinterpret(Int16,x))

maxintfloat(::Type{Float64}) = 9007199254740992.
maxintfloat(::Type{Float32}) = float32(16777216.)
@@ -35,9 +36,11 @@ isfloat64(::Float64) = true
isfloat64(::Float32) = true

## precision, as defined by the effective number of bits in the mantissa ##
get_precision(::Float16) = 11
get_precision(::Float32) = 24
get_precision(::Float64) = 53

num2hex(x::Float16) = hex(reinterpret(Uint16,x), 4)
num2hex(x::Float32) = hex(box(Uint32,unbox(Float32,x)),8)
num2hex(x::Float64) = hex(box(Uint64,unbox(Float64,x)),16)

2 changes: 1 addition & 1 deletion base/io.jl
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ else
end

write(s::IO, x::Bool) = write(s, uint8(x))
#write(s::IO, x::Float16) = write(s, reinterpret(Int16,x))
write(s::IO, x::Float16) = write(s, reinterpret(Int16,x))
write(s::IO, x::Float32) = write(s, reinterpret(Int32,x))
write(s::IO, x::Float64) = write(s, reinterpret(Int64,x))

2 changes: 1 addition & 1 deletion base/number.jl
Original file line number Diff line number Diff line change
@@ -42,4 +42,4 @@ map(f::Callable, x::Number) = f(x)
const _numeric_conversion_func_names =
(:int,:integer,:signed,:int8,:int16,:int32,:int64,:int128,
:uint,:unsigned,:uint8,:uint16,:uint32,:uint64,:uint128,
:float,:float32,:float64)
:float,:float16,:float32,:float64)
1 change: 1 addition & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
@@ -94,6 +94,7 @@ include("floatfuncs.jl")
include("math.jl")
importall .Math
include("primes.jl")
include("float16.jl")

# concurrency and parallelism
include("serialize.jl")
4 changes: 4 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
@@ -132,6 +132,8 @@ static Type *T_uint64;
static Type *T_char;
static Type *T_size;
static Type *T_psize;
static Type *T_float16;
static Type *T_pfloat16;
static Type *T_float32;
static Type *T_pfloat32;
static Type *T_float64;
@@ -3077,6 +3079,8 @@ static void init_julia_llvm_env(Module *m)
else
T_size = T_uint32;
T_psize = PointerType::get(T_size, 0);
T_float16 = Type::getHalfTy(getGlobalContext());
T_pfloat16 = PointerType::get(T_float16, 0);
T_float32 = Type::getFloatTy(getGlobalContext());
T_pfloat32 = PointerType::get(T_float32, 0);
T_float64 = Type::getDoubleTy(getGlobalContext());
1 change: 1 addition & 0 deletions src/init.c
Original file line number Diff line number Diff line change
@@ -870,6 +870,7 @@ void jl_get_builtin_hooks(void)
jl_uint32_type = (jl_datatype_t*)core("Uint32");
jl_uint64_type = (jl_datatype_t*)core("Uint64");

jl_float16_type = (jl_datatype_t*)core("Float16");
jl_float32_type = (jl_datatype_t*)core("Float32");
jl_float64_type = (jl_datatype_t*)core("Float64");
jl_floatingpoint_type = (jl_datatype_t*)core("FloatingPoint");
1 change: 1 addition & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ jl_datatype_t *jl_int32_type;
jl_datatype_t *jl_uint32_type;
jl_datatype_t *jl_int64_type;
jl_datatype_t *jl_uint64_type;
jl_datatype_t *jl_float16_type;
jl_datatype_t *jl_float32_type;
jl_datatype_t *jl_float64_type;
jl_datatype_t *jl_floatingpoint_type;
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
@@ -380,6 +380,7 @@ extern jl_datatype_t *jl_int32_type;
extern jl_datatype_t *jl_uint32_type;
extern jl_datatype_t *jl_int64_type;
extern jl_datatype_t *jl_uint64_type;
extern jl_datatype_t *jl_float16_type;
extern jl_datatype_t *jl_float32_type;
extern jl_datatype_t *jl_float64_type;
extern jl_datatype_t *jl_floatingpoint_type;
@@ -523,6 +524,7 @@ void *allocobj(size_t sz);
#define jl_is_uint64(v) jl_typeis(v,jl_uint64_type)
#define jl_is_float(v) jl_subtype(v,(jl_value_t*)jl_floatingpoint_type,true)
#define jl_is_floattype(v) jl_subtype(v,(jl_value_t*)jl_floatingpoint_type,false)
#define jl_is_float16(v) jl_typeis(v,jl_float16_type)
#define jl_is_float32(v) jl_typeis(v,jl_float32_type)
#define jl_is_float64(v) jl_typeis(v,jl_float64_type)
#define jl_is_bool(v) jl_typeis(v,jl_bool_type)
2 changes: 1 addition & 1 deletion test/Makefile
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@ TESTS = all core keywordargs numbers strings unicode collections hashing \
remote iostring arrayops linalg blas fft dsp sparse bitarray random \
math functional bigint sorting statistics spawn parallel arpack file \
git pkg pkg2 resolve resolve2 suitesparse complex version pollfd mpfr \
broadcast socket floatapprox priorityqueue readdlm regex
broadcast socket floatapprox priorityqueue readdlm regex float16

$(TESTS) ::
@$(PRINT_JULIA) $(call spawn,$(JULIA_EXECUTABLE)) ./runtests.jl $@
21 changes: 21 additions & 0 deletions test/float16.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

f = float16(2.)
g = float16(1.)

@test -f === float16(-2.)

@test f+g === 3f0
@test f-g === 1f0
@test f*g === 2f0
@test f/g === 2f0

@test f + 2 === 4f0
@test f - 2 === 0f0
@test f*2 === 4f0
@test f/2 === 1f0
@test f + 2. === 4.
@test f - 2. === 0.
@test f*2. === 4.
@test f/2. === 1.

@test_approx_eq sin(f) sin(2f0)

1 comment on commit 135a6b9

@StefanKarpinski
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. If this works, I say merge it and we can ditch my branch.

Please sign in to comment.