Skip to content
This repository was archived by the owner on May 4, 2019. It is now read-only.

Fix broadcast() and map() with constructors #177

Merged
merged 1 commit into from
Feb 12, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Base: _default_eltype
using Compat

if VERSION >= v"0.6.0-dev.693"
Expand All @@ -10,6 +9,8 @@ else
end

if VERSION < v"0.6.0-dev" # Old approach needed for inference to work
using Base: _default_eltype

ftype(f, A) = typeof(f)
ftype(f, A...) = typeof(a -> f(a...))
ftype(T::DataType, A) = Type{T}
Expand All @@ -20,12 +21,12 @@ if VERSION < v"0.6.0-dev" # Old approach needed for inference to work
else
using Base: Zip2
end
ziptype(A) = Tuple{eltype(A)}
ziptype(A, B) = Zip2{Tuple{eltype(A)}, Tuple{eltype(B)}}
ziptype(A) = Tuple{eltype(eltype(A))}
ziptype(A, B) = Zip2{Tuple{eltype(eltype(A))}, Tuple{eltype(eltype(B))}}
@inline ziptype(A, B, C, D...) = Zip{Tuple{eltype(A)}, ziptype(B, C, D...)}

nullable_broadcast_eltype(f, As...) =
eltype(_default_eltype(Base.Generator{ziptype(As...), ftype(f, As...)}))
_default_eltype(Base.Generator{ziptype(As...), ftype(f, As...)})
else
Base.@pure nullable_eltypestuple(a) = Tuple{eltype(eltype(a))}
Base.@pure nullable_eltypestuple(T::Type) = Tuple{Type{eltype(T)}}
Expand Down
22 changes: 19 additions & 3 deletions src/lift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ eltypes(x) = Tuple{eltype_nullable(x)}
eltypes(x, xs...) = Tuple{eltype_nullable(x), eltypes(xs...).parameters...}

"""
lift(f, xs...)
lift(f, xs...)

Lift function `f`, passing it arguments `xs...`, using standard lifting semantics:
for a function call `f(xs...)`, return null if any `x` in `xs` is null; otherwise,
Expand All @@ -17,13 +17,13 @@ return `f` applied to values of `xs`.
N = nfields(xs)
args = (:(unsafe_get(xs[$i])) for i in 1:N)
checknull = (:(!isnull(xs[$i])) for i in 1:N)
if null_safe_op(f.instance, map(eltype_nullable, xs.parameters)...)
if isdefined(f, :instance) && null_safe_op(f.instance, map(eltype_nullable, xs.parameters)...)
return quote
val = f($(args...))
nonull = (&)($(checknull...))
@compat Nullable(val, nonull)
end
else
elseif VERSION >= v"0.6.0-dev"
return quote
U = Core.Inference.return_type(f, eltypes(xs...))
if (&)($(checknull...))
Expand All @@ -32,6 +32,22 @@ return `f` applied to values of `xs`.
return isleaftype(U) ? Nullable{U}() : Nullable()
end
end
else # Inference fails with the previous branch on Julia 0.5
if isdefined(f, :instance)
U = Core.Inference.return_type(f.instance, map(eltype_nullable, xs.parameters))
isleaftype(U) || (U = Union{})
elseif F === DataType # Function is a constructor
U = f.parameters[1]
else
U = Union{}
end
return quote
if (&)($(checknull...))
return Nullable(f($(args...)))
else
return Nullable{$U}()
end
end
end
end

Expand Down
15 changes: 15 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ module TestBroadcast
f(x::Real, y::Real) = x * y
f(x::Real, y::Real, z::Real) = x * y * z

@inferred NullableArrays.lift(f, (1,))
@inferred NullableArrays.lift(f, (1, 2.0))
@inferred NullableArrays.lift(f, (1, 2.0, 3.0))

for (dests, arrays, nullablearrays, mask) in
( ((C2, Z2), (A1, A2), (U1, U2), ()),
((C3, Z3), (A2, A3), (U2, U3), ()),
Expand Down Expand Up @@ -106,4 +110,15 @@ module TestBroadcast
@test isequal(broadcast(&, X, Y), NullableArray(A .& B, M1 .| M2))
@test isequal(broadcast(|, X, Y), NullableArray(A .| B, M1 .| M2))

# Test broadcasting with constructor
immutable SurvEvent
time::Float64
censored::Bool
end
t = NullableArray(rand(3))
c = NullableArray(rand(Bool, 3))
@test isequal(SurvEvent.(t, c), NullableArray([SurvEvent(get(t[i]), get(c[i])) for i in 1:3]))
@test isa(SurvEvent.(t, c), NullableVector{SurvEvent})
@inferred NullableArrays.lift(SurvEvent, (1, true))

end # module