diff --git a/src/Missings.jl b/src/Missings.jl index 84d6dee..63a8efc 100644 --- a/src/Missings.jl +++ b/src/Missings.jl @@ -1,7 +1,7 @@ module Missings export allowmissing, disallowmissing, ismissing, missing, missings, - Missing, MissingException, levels, coalesce + Missing, MissingException, levels, coalesce, passmissing using Base: ismissing, missing, Missing, MissingException @@ -165,4 +165,57 @@ function levels(x) levs end +struct PassMissing{F} <: Function + f::F +end + +function (f::PassMissing{F})(x) where {F} + if @generated + return x === Missing ? missing : :(f.f(x)) + else + return x === missing ? missing : f.f(x) + end +end + +function (f::PassMissing{F})(xs...) where {F} + if @generated + for T in xs + T === Missing && return missing + end + return :(f.f(xs...)) + else + return any(ismissing, xs) ? missing : f.f(xs...) + end +end + +""" + passmissing(f) + +Return a function that returns `missing` if any of its positional arguments +are `missing` (even if their number or type is not consistent with any of the +methods defined for `f`) and otherwise applies `f` to these arguments. + +`passmissing` does not support passing keyword arguments to the `f` function. + +# Examples +```jldoctest +julia> passmissing(sqrt)(4) +2.0 + +julia> passmissing(sqrt)(missing) +missing + +julia> passmissing(sqrt).([missing, 4]) +2-element Array{Union{Missing, Float64},1}: + missing + 2.0 + +julia> passmissing((x,y)->"\$x \$y")(1, 2) +"1 2" + +julia> passmissing((x,y)->"\$x \$y")(missing) +missing +""" +passmissing(f::Base.Callable) = PassMissing{typeof(f)}(f) + end # module diff --git a/test/runtests.jl b/test/runtests.jl index c95aa45..576c67c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -355,4 +355,12 @@ using Test, Dates, InteractiveUtils, SparseArrays, Missings # MissingException @test sprint(showerror, MissingException("test")) == "MissingException: test" + + # Lifting + @test passmissing(sqrt)(4) == 2.0 + @test isequal(passmissing(sqrt)(missing), missing) + @test isequal(passmissing(sqrt).([missing, 4]), [missing, 2.0]) + @test passmissing((x,y)->"$x $y")(1, 2) == "1 2" + @test isequal(passmissing((x,y)->"$x $y")(missing), missing) + @test_throws ErrorException passmissing(string)(missing, base=2) end