diff --git a/base/arraymath.jl b/base/arraymath.jl index efd6b3b768d99..1b542f67e01cb 100644 --- a/base/arraymath.jl +++ b/base/arraymath.jl @@ -38,10 +38,10 @@ end ## Binary arithmetic operators ## -promote_array_type{Scalar, Arry}(::Type{Scalar}, ::Type{Arry}) = promote_type(Scalar, Arry) -promote_array_type{S<:Real, A<:FloatingPoint}(::Type{S}, ::Type{A}) = A -promote_array_type{S<:Integer, A<:Integer}(::Type{S}, ::Type{A}) = A -promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S +promote_array_type{Scalar, Arry}(F, ::Type{Scalar}, ::Type{Arry}) = promote_op(F, Scalar, Arry) +promote_array_type{S<:Real, A<:FloatingPoint}(F, ::Type{S}, ::Type{A}) = A +promote_array_type{S<:Integer, A<:Integer}(F, ::Type{S}, ::Type{A}) = A +promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}) = S # Handle operations that return different types ./(x::Number, Y::AbstractArray) = @@ -57,10 +57,16 @@ promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S .^(X::AbstractArray, y::Number ) = reshape([ x ^ y for x in X ], size(X)) -for f in (:+, :-, :div, :mod, :&, :|, :$) +for (f,F) in ((:+, AddFun()), + (:-, SubFun()), + (:div, IDivFun()), + (:mod, ModFun()), + (:&, AndFun()), + (:|, OrFun()), + (:$, XorFun())) @eval begin function ($f){S,T}(A::Range{S}, B::Range{T}) - F = similar(A, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B))) i = 1 for (a,b) in zip(A,B) @inbounds F[i] = ($f)(a, b) @@ -69,7 +75,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) return F end function ($f){S,T}(A::AbstractArray{S}, B::Range{T}) - F = similar(A, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B))) i = 1 for b in B @inbounds F[i] = ($f)(A[i], b) @@ -78,7 +84,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) return F end function ($f){S,T}(A::Range{S}, B::AbstractArray{T}) - F = similar(B, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(B, promote_op($F,S,T), promote_shape(size(A),size(B))) i = 1 for a in A @inbounds F[i] = ($f)(a, B[i]) @@ -87,7 +93,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) return F end function ($f){S,T}(A::AbstractArray{S}, B::AbstractArray{T}) - F = similar(A, promote_type(S,T), promote_shape(size(A),size(B))) + F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B))) for i in eachindex(A,B) @inbounds F[i] = ($f)(A[i], B[i]) end @@ -95,17 +101,28 @@ for f in (:+, :-, :div, :mod, :&, :|, :$) end end end -for f in (:.+, :.-, :.*, :.%, :.<<, :.>>, :div, :mod, :rem, :&, :|, :$) +for (f,F) in ((:.+, DotAddFun()), + (:.-, DotSubFun()), + (:.*, DotMulFun()), + (:.%, DotRemFun()), + (:.<<, DotLSFun()), + (:.>>, DotRSFun()), + (:div, IDivFun()), + (:mod, ModFun()), + (:rem, RemFun()), + (:&, AndFun()), + (:|, OrFun()), + (:$, XorFun())) @eval begin function ($f){T}(A::Number, B::AbstractArray{T}) - F = similar(B, promote_array_type(typeof(A),T)) + F = similar(B, promote_array_type($F,typeof(A),T)) for i in eachindex(B) @inbounds F[i] = ($f)(A, B[i]) end return F end function ($f){T}(A::AbstractArray{T}, B::Number) - F = similar(A, promote_array_type(typeof(B),T)) + F = similar(A, promote_array_type($F,typeof(B),T)) for i in eachindex(A) @inbounds F[i] = ($f)(A[i], B) end diff --git a/base/bitarray.jl b/base/bitarray.jl index b22e8a97a8942..b48de7858dd15 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -855,21 +855,23 @@ for f in (:+, :-) return r end end -for f in (:.+, :.-), - (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)), - (:(B::BitArray), :(x::Number) , :(promote_array_type(typeof(x), Bool)), :(b, x)), - (:(x::Bool) , :(B::BitArray), Int , :(x, b)), - (:(x::Number) , :(B::BitArray), :(promote_array_type(typeof(x), Bool)), :(x, b))) - @eval function ($f)($arg1, $arg2) - r = Array($T, size(B)) - bi = start(B) - ri = 1 - while !done(B, bi) - b, bi = next(B, bi) - @inbounds r[ri] = ($f)($fargs...) - ri += 1 +for (f,F) in ((:.+, DotAddFun()), + (:.-, DotSubFun())) + for (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)), + (:(B::BitArray), :(x::Number) , :(promote_array_type($F, typeof(x), Bool)), :(b, x)), + (:(x::Bool) , :(B::BitArray), Int , :(x, b)), + (:(x::Number) , :(B::BitArray), :(promote_array_type($F, typeof(x), Bool)), :(x, b))) + @eval function ($f)($arg1, $arg2) + r = Array($T, size(B)) + bi = start(B) + ri = 1 + while !done(B, bi) + b, bi = next(B, bi) + @inbounds r[ri] = ($f)($fargs...) + ri += 1 + end + return r end - return r end end @@ -897,7 +899,7 @@ function div(x::Bool, B::BitArray) end function div(x::Number, B::BitArray) all(B) || throw(DivideError()) - pt = promote_array_type(typeof(x), Bool) + pt = promote_array_type(IDivFun(), typeof(x), Bool) y = div(x, true) reshape(pt[ y for i = 1:length(B) ], size(B)) end @@ -918,15 +920,16 @@ function mod(x::Bool, B::BitArray) end function mod(x::Number, B::BitArray) all(B) || throw(DivideError()) - pt = promote_array_type(typeof(x), Bool) + pt = promote_array_type(ModFun(), typeof(x), Bool) y = mod(x, true) reshape(pt[ y for i = 1:length(B) ], size(B)) end -for f in (:div, :mod) +for (f,F) in ((:div, IDivFun()), + (:mod, ModFun())) @eval begin function ($f)(B::BitArray, x::Number) - F = Array(promote_array_type(typeof(x), Bool), size(B)) + F = Array(promote_array_type($F, typeof(x), Bool), size(B)) for i = 1:length(F) F[i] = ($f)(B[i], x) end diff --git a/base/complex.jl b/base/complex.jl index 339a40d7849d0..eda1a60f82894 100644 --- a/base/complex.jl +++ b/base/complex.jl @@ -747,7 +747,7 @@ big{T<:FloatingPoint,N}(A::AbstractArray{Complex{T},N}) = convert(AbstractArray{ ## promotion to complex ## -promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(::Type{S}, ::Type{Complex{AT}}) = Complex{AT} +promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(F, ::Type{S}, ::Type{Complex{AT}}) = Complex{AT} function complex{S<:Real,T<:Real}(A::Array{S}, B::Array{T}) if size(A) != size(B); throw(DimensionMismatch()); end diff --git a/base/functors.jl b/base/functors.jl index 8f239ea661c40..18df3e5e7d19c 100644 --- a/base/functors.jl +++ b/base/functors.jl @@ -34,15 +34,27 @@ call(::AndFun, x, y) = x & y immutable OrFun <: Func{2} end call(::OrFun, x, y) = x | y +immutable XorFun <: Func{2} end +call(::XorFun, x, y) = x $ y + immutable AddFun <: Func{2} end call(::AddFun, x, y) = x + y +immutable DotAddFun <: Func{2} end +call(::DotAddFun, x, y) = x .+ y + immutable SubFun <: Func{2} end call(::SubFun, x, y) = x - y +immutable DotSubFun <: Func{2} end +call(::DotSubFun, x, y) = x .- y + immutable MulFun <: Func{2} end call(::MulFun, x, y) = x * y +immutable DotMulFun <: Func{2} end +call(::DotMulFun, x, y) = x .* y + immutable RDivFun <: Func{2} end call(::RDivFun, x, y) = x / y @@ -52,6 +64,15 @@ call(::LDivFun, x, y) = x \ y immutable IDivFun <: Func{2} end call(::IDivFun, x, y) = div(x, y) +immutable ModFun <: Func{2} end +call(::ModFun, x, y) = mod(x, y) + +immutable RemFun <: Func{2} end +call(::RemFun, x, y) = rem(x, y) + +immutable DotRemFun <: Func{2} end +call(::RemFun, x, y) = x .% y + immutable PowFun <: Func{2} end call(::PowFun, x, y) = x ^ y @@ -67,6 +88,12 @@ call(::LessFun, x, y) = x < y immutable MoreFun <: Func{2} end call(::MoreFun, x, y) = x > y +immutable DotLSFun <: Func{2} end +call(::DotLSFun, x, y) = x .<< y + +immutable DotRSFun <: Func{2} end +call(::DotRSFun, x, y) = x .>> y + # a fallback unspecialized function object that allows code using # function objects to not care whether they were able to specialize on # the function value or not diff --git a/base/promotion.jl b/base/promotion.jl index 56c82062f54ea..63e3b9a3d7606 100644 --- a/base/promotion.jl +++ b/base/promotion.jl @@ -199,6 +199,12 @@ checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...) checked_sub(x::Integer, y::Integer) = checked_sub(promote(x,y)...) checked_mul(x::Integer, y::Integer) = checked_mul(promote(x,y)...) +# "Promotion" that takes a Functor into account. You can override this +# as needed. For example, if you need to provide a custom result type +# for the multiplication of two types, +# promote_op{R<:MyType,S<:MyType}(::MulFun, ::Type{R}, ::Type{S}) = MyType{multype(R,S)} +promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = promote_type(R, S) + ## catch-alls to prevent infinite recursion when definitions are missing ## no_op_err(name, T) = error(name," not defined for ",T) diff --git a/doc/devdocs/promote-op.rst b/doc/devdocs/promote-op.rst new file mode 100644 index 0000000000000..beb74fe07a56e --- /dev/null +++ b/doc/devdocs/promote-op.rst @@ -0,0 +1,38 @@ +.. currentmodule:: Base + +.. _devdocs-promote-op: + +Operator-sensitive promotion +============================ + +In certain cases, the :ref:`simple rules for promotion +` may not be sufficient. For example, consider a +type that can represent an object with physical units, here restricted +to a single unit like "meter":: + + immutable MeterUnits{T,P} <: Number + val::T + end + MeterUnits{T}(val::T, pow::Int) = MeterUnits{T,pow}(val) + + m = MeterUnits(1.0, 1) # 1.0 meter, i.e. units of length + m2 = MeterUnits(1.0, 2) # 1.0 meter^2, i.e. units of area + +Now let's define the operations ``+`` and ``*`` for these objects: +``m+m`` should have the type of ``m`` but ``m*m`` should have the type +of ``m2``. When the result type depends on the operation, and not +just the input types, ``promote_rule`` will be inadequate. + +Fortunately, it's possible to provide such definitions via ``promote_op``:: + + Base.promote_op{R,S}(::Base.AddFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),1} + Base.promote_op{R,S}(::Base.MulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2} + Base.promote_op{R,S}(::Base.DotMulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2} + +The first one defines the promotion rule for ``+``, and the second one +for ``*``. ``AddFun``, ``MulFun``, and ``DotMulFun`` are "functor +types" defined in `functor.jl +`_. + +It's worth noting that as julia's internal representation of functions +evolves, this interface may change in a future version of Julia. diff --git a/doc/manual/conversion-and-promotion.rst b/doc/manual/conversion-and-promotion.rst index 737b45f9c4fed..4d5cda7f8c59a 100644 --- a/doc/manual/conversion-and-promotion.rst +++ b/doc/manual/conversion-and-promotion.rst @@ -276,6 +276,9 @@ the the catch-all method definitions given in *(x::Number, y::Number) = *(promote(x,y)...) /(x::Number, y::Number) = /(promote(x,y)...) +In certain cases, the result type also depends on the operator; how to +handle such scenarios is described :ref:`elsewhere `. + These method definitions say that in the absence of more specific rules for adding, subtracting, multiplying and dividing pairs of numeric values, promote the values to a common type and then try again. That's @@ -309,6 +312,8 @@ programmers to supply the expected types to constructor functions explicitly, but sometimes, especially for numeric problems, it can be convenient to do promotion automatically. +.. _man-promotion-rules: + Defining Promotion Rules ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/test/arrayops.jl b/test/arrayops.jl index 217e3c7589035..d989ea7b46fd0 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -1212,3 +1212,29 @@ b = rand(6,7) @test_throws BoundsError copy!(a,b) @test_throws ArgumentError copy!(a,2:3,1:3,b,1:5,2:7) @test_throws ArgumentError Base.copy_transpose!(a,2:3,1:3,b,1:5,2:7) + +# return type declarations (promote_op) +module RetTypeDecl + using Base.Test + + immutable MeterUnits{T,P} <: Number + val::T + end + MeterUnits{T}(val::T, pow::Int) = MeterUnits{T,pow}(val) + + m = MeterUnits(1.0, 1) # 1.0 meter, i.e. units of length + m2 = MeterUnits(1.0, 2) # 1.0 meter^2, i.e. units of area + + (+){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,1}(x.val+y.val) + (*){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,2}(x.val*y.val) + (.*){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,2}(x.val*y.val) + + Base.promote_op{R,S}(::Base.AddFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),1} + Base.promote_op{R,S}(::Base.MulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2} + Base.promote_op{R,S}(::Base.DotMulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2} + + @test @inferred(m+[m,m]) == [m+m,m+m] + @test @inferred([m,m]+m) == [m+m,m+m] + @test @inferred(m.*[m,m]) == [m2,m2] + @test @inferred([m,m].*m) == [m2,m2] +end