From 30939fd8e18d94de1121660f4f58a9d5a9ff5198 Mon Sep 17 00:00:00 2001 From: fcard Date: Fri, 19 Jun 2015 22:19:14 -0300 Subject: [PATCH] Make `any` and `all` short-circuiting fix #11750 --- base/deprecated.jl | 14 ++++++ base/functors.jl | 14 ++++++ base/reduce.jl | 107 +++++++++++++++++++++++++++------------------ test/reduce.jl | 12 +++++ 4 files changed, 105 insertions(+), 42 deletions(-) diff --git a/base/deprecated.jl b/base/deprecated.jl index 9fb537327b40fc..b0b3a01ffd7de6 100644 --- a/base/deprecated.jl +++ b/base/deprecated.jl @@ -635,3 +635,17 @@ end @deprecate mmap_bitarray{N}(::Type{Bool}, dims::NTuple{N,Integer}, s::IOStream, offset::FileOffset=position(s)) mmap(s, BitArray, dims, offset) @deprecate mmap_bitarray{N}(dims::NTuple{N,Integer}, s::IOStream, offset=position(s)) mmap(s, BitArray, dims, offset) + +# 11774 +# when removing these deprecations, move them to reduce.jl, remove the depwarns and uncomment the errors. +function nonboolean_any(itr) + depwarn("using non-boolean collections with any(itr) is deprecated, use reduce(|, itr) instead.", :nonboolean_any) + #throw(ArgumentError("any(itr) only accepts boolean collections. Use reduce(|, itr) instead.")) + reduce(|, itr) +end + +function nonboolean_all(itr) + depwarn("using non-boolean collections with all(itr) is deprecated, use reduce(&, itr) instead.", :nonboolean_all) + #throw(ArgumentError("all(itr) only accepts boolean collections. Use reduce(|, itr) instead.")) + reduce(&, itr) +end diff --git a/base/functors.jl b/base/functors.jl index 646facee9d7499..a44c26af9f279d 100644 --- a/base/functors.jl +++ b/base/functors.jl @@ -59,6 +59,20 @@ end call(f::UnspecializedFun{1}, x) = f.f(x) call(f::UnspecializedFun{2}, x, y) = f.f(x,y) +# Special purpose functors + +type Predicate <: Func{1} + f::Function +end +call(pred::Predicate, x) = pred.f(x)::Bool + + +immutable EqX{T} <: Func{1} + x::T +end +EqX{T}(x::T) = EqX{T}(x) + +call(f::EqX, y) = f.x == y #### Bitwise operators #### diff --git a/base/reduce.jl b/base/reduce.jl index a4850339135ec3..1cff2edd2711cb 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -155,12 +155,60 @@ end mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, A) mapreduce(f, op, a::Number) = f(a) -mapreduce(f, op::Function, A::AbstractArray) = _mapreduce(f, specialized_binary(op), A) +mapreduce(f, op::Function, A::AbstractArray) = mapreduce(f, specialized_binary(op), A) reduce(op, v0, itr) = mapreduce(IdFun(), op, v0, itr) reduce(op, itr) = mapreduce(IdFun(), op, itr) reduce(op, a::Number) = a +### short-circuiting specializations of mapreduce + +## conditions and results of short-circuiting + +const ShortCircuiting = Union{AndFun, OrFun} +const ReturnsBool = Union{EqX, Predicate} + +shortcircuits(::AndFun, x::Bool) = !x +shortcircuits(::OrFun, x::Bool) = x + +shorted(::AndFun) = false +shorted(::OrFun) = true + +sc_finish(::AndFun) = true +sc_finish(::OrFun) = false + +## mapreduce definitions + +function mapreduce_sc_impl(f, op, itr::AbstractArray) + @inbounds for x in itr + shortcircuits(op, f(x)) && return shorted(op) + end + return sc_finish(op) +end + +function mapreduce_sc_impl(f, op, itr) + for x in itr + shortcircuits(op, f(x)) && return shorted(op) + end + return sc_finish(op) +end + +mapreduce_no_sc(f, op, itr::Any) = mapfoldl(f, op, itr) +mapreduce_no_sc(f, op, itr::AbstractArray) = _mapreduce(f, op, itr) + +mapreduce_sc(f::Function, op, itr) = mapreduce_sc(specialized_unary(f), op, itr) +mapreduce_sc(f::ReturnsBool, op, itr) = mapreduce_sc_impl(f, op, itr) +mapreduce_sc(f::Func{1}, op, itr) = mapreduce_no_sc(f, op, itr) + +mapreduce_sc(f::IdFun, op, itr) = + eltype(itr) <: Bool? + mapreduce_sc_impl(f, op, itr) : + mapreduce_no_sc(f, op, itr) + +mapreduce(f, op::ShortCircuiting, n::Number) = n +mapreduce(f, op::ShortCircuiting, itr::AbstractArray) = mapreduce_sc(f,op,itr) +mapreduce(f, op::ShortCircuiting, itr::Any) = mapreduce_sc(f,op,itr) + ###### Specific reduction functions ###### @@ -298,53 +346,28 @@ end ## all & any -function mapfoldl(f, ::AndFun, itr) - for x in itr - !f(x) && return false - end - return true -end +# make sure that the identity function is defined before `any` or `all` are used +function identity end -function mapfoldl(f, ::OrFun, itr) - for x in itr - f(x) && return true - end - return false -end +any(itr) = any(IdFun(), itr) +all(itr) = all(IdFun(), itr) -function mapreduce_impl(f, op::AndFun, A::AbstractArray{Bool}, ifirst::Int, ilast::Int) - while ifirst <= ilast - @inbounds x = A[ifirst] - !f(x) && return false - ifirst += 1 - end - return true -end - -function mapreduce_impl(f, op::OrFun, A::AbstractArray{Bool}, ifirst::Int, ilast::Int) - while ifirst <= ilast - @inbounds x = A[ifirst] - f(x) && return true - ifirst += 1 - end - return false -end - -all(a) = mapreduce(IdFun(), AndFun(), a) -any(a) = mapreduce(IdFun(), OrFun(), a) - -all(pred::Union{Callable,Func{1}}, a) = mapreduce(pred, AndFun(), a) -any(pred::Union{Callable,Func{1}}, a) = mapreduce(pred, OrFun(), a) +any(f::Function, itr) = any(f === identity? IdFun() : Predicate(f), itr) +any(f::Func{1}, itr) = mapreduce_sc_impl(f, OrFun(), itr) +any(f::IdFun, itr) = + eltype(itr) <: Bool? + mapreduce_sc_impl(f, OrFun(), itr) : + nonboolean_any(itr) +all(f::Function, itr) = all(f === identity? IdFun() : Predicate(f), itr) +all(f::Func{1}, itr) = mapreduce_sc_impl(f, AndFun(), itr) +all(f::IdFun, itr) = + eltype(itr) <: Bool? + mapreduce_sc_impl(f, AndFun(), itr) : + nonboolean_all(itr) ## in & contains -immutable EqX{T} <: Func{1} - x::T -end -EqX{T}(x::T) = EqX{T}(x) - -call(f::EqX, y) = f.x == y in(x, itr) = any(EqX(x), itr) const ∈ = in diff --git a/test/reduce.jl b/test/reduce.jl index 078b597f1ecaf6..e7bda83a020eb8 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -201,6 +201,18 @@ prod2(itr) = invoke(prod, Tuple{Any}, itr) @test reduce(&, fill(trues(5), 24)) == trues(5) @test reduce(&, fill(falses(5), 24)) == falses(5) +@test_throws TypeError any(x->0, [false]) +@test_throws TypeError all(x->0, [false]) + +# short-circuiting any and all + +let c = [0, 0], A = 1:1000 + any(x->(c[1]=x; x==10), A) + all(x->(c[2]=x; x!=10), A) + + @test c == [10,10] +end + # in @test in(1, Int[]) == false