Skip to content

Commit

Permalink
Merge pull request #12292 from JuliaLang/teh/promotion
Browse files Browse the repository at this point in the history
Enable operator-sensitive extension of element-type promotion
  • Loading branch information
timholy committed Jul 29, 2015
2 parents 9b14a7c + 2b8a8d3 commit d7351cf
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 31 deletions.
41 changes: 29 additions & 12 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ end

## Binary arithmetic operators ##

promote_array_type{Scalar, Arry}(::Type{Scalar}, ::Type{Arry}) = promote_type(Scalar, Arry)
promote_array_type{S<:Real, A<:FloatingPoint}(::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer, A<:Integer}(::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S
promote_array_type{Scalar, Arry}(F, ::Type{Scalar}, ::Type{Arry}) = promote_op(F, Scalar, Arry)
promote_array_type{S<:Real, A<:FloatingPoint}(F, ::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer, A<:Integer}(F, ::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}) = S

# Handle operations that return different types
./(x::Number, Y::AbstractArray) =
Expand All @@ -57,10 +57,16 @@ promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S
.^(X::AbstractArray, y::Number ) =
reshape([ x ^ y for x in X ], size(X))

for f in (:+, :-, :div, :mod, :&, :|, :$)
for (f,F) in ((:+, AddFun()),
(:-, SubFun()),
(:div, IDivFun()),
(:mod, ModFun()),
(:&, AndFun()),
(:|, OrFun()),
(:$, XorFun()))
@eval begin
function ($f){S,T}(A::Range{S}, B::Range{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for (a,b) in zip(A,B)
@inbounds F[i] = ($f)(a, b)
Expand All @@ -69,7 +75,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::AbstractArray{S}, B::Range{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for b in B
@inbounds F[i] = ($f)(A[i], b)
Expand All @@ -78,7 +84,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::Range{S}, B::AbstractArray{T})
F = similar(B, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(B, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for a in A
@inbounds F[i] = ($f)(a, B[i])
Expand All @@ -87,25 +93,36 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::AbstractArray{S}, B::AbstractArray{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
for i in eachindex(A,B)
@inbounds F[i] = ($f)(A[i], B[i])
end
return F
end
end
end
for f in (:.+, :.-, :.*, :.%, :.<<, :.>>, :div, :mod, :rem, :&, :|, :$)
for (f,F) in ((:.+, DotAddFun()),
(:.-, DotSubFun()),
(:.*, DotMulFun()),
(:.%, DotRemFun()),
(:.<<, DotLSFun()),
(:.>>, DotRSFun()),
(:div, IDivFun()),
(:mod, ModFun()),
(:rem, RemFun()),
(:&, AndFun()),
(:|, OrFun()),
(:$, XorFun()))
@eval begin
function ($f){T}(A::Number, B::AbstractArray{T})
F = similar(B, promote_array_type(typeof(A),T))
F = similar(B, promote_array_type($F,typeof(A),T))
for i in eachindex(B)
@inbounds F[i] = ($f)(A, B[i])
end
return F
end
function ($f){T}(A::AbstractArray{T}, B::Number)
F = similar(A, promote_array_type(typeof(B),T))
F = similar(A, promote_array_type($F,typeof(B),T))
for i in eachindex(A)
@inbounds F[i] = ($f)(A[i], B)
end
Expand Down
39 changes: 21 additions & 18 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,21 +855,23 @@ for f in (:+, :-)
return r
end
end
for f in (:.+, :.-),
(arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)),
(:(B::BitArray), :(x::Number) , :(promote_array_type(typeof(x), Bool)), :(b, x)),
(:(x::Bool) , :(B::BitArray), Int , :(x, b)),
(:(x::Number) , :(B::BitArray), :(promote_array_type(typeof(x), Bool)), :(x, b)))
@eval function ($f)($arg1, $arg2)
r = Array($T, size(B))
bi = start(B)
ri = 1
while !done(B, bi)
b, bi = next(B, bi)
@inbounds r[ri] = ($f)($fargs...)
ri += 1
for (f,F) in ((:.+, DotAddFun()),
(:.-, DotSubFun()))
for (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)),
(:(B::BitArray), :(x::Number) , :(promote_array_type($F, typeof(x), Bool)), :(b, x)),
(:(x::Bool) , :(B::BitArray), Int , :(x, b)),
(:(x::Number) , :(B::BitArray), :(promote_array_type($F, typeof(x), Bool)), :(x, b)))
@eval function ($f)($arg1, $arg2)
r = Array($T, size(B))
bi = start(B)
ri = 1
while !done(B, bi)
b, bi = next(B, bi)
@inbounds r[ri] = ($f)($fargs...)
ri += 1
end
return r
end
return r
end
end

Expand Down Expand Up @@ -897,7 +899,7 @@ function div(x::Bool, B::BitArray)
end
function div(x::Number, B::BitArray)
all(B) || throw(DivideError())
pt = promote_array_type(typeof(x), Bool)
pt = promote_array_type(IDivFun(), typeof(x), Bool)
y = div(x, true)
reshape(pt[ y for i = 1:length(B) ], size(B))
end
Expand All @@ -918,15 +920,16 @@ function mod(x::Bool, B::BitArray)
end
function mod(x::Number, B::BitArray)
all(B) || throw(DivideError())
pt = promote_array_type(typeof(x), Bool)
pt = promote_array_type(ModFun(), typeof(x), Bool)
y = mod(x, true)
reshape(pt[ y for i = 1:length(B) ], size(B))
end

for f in (:div, :mod)
for (f,F) in ((:div, IDivFun()),
(:mod, ModFun()))
@eval begin
function ($f)(B::BitArray, x::Number)
F = Array(promote_array_type(typeof(x), Bool), size(B))
F = Array(promote_array_type($F, typeof(x), Bool), size(B))
for i = 1:length(F)
F[i] = ($f)(B[i], x)
end
Expand Down
2 changes: 1 addition & 1 deletion base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ big{T<:FloatingPoint,N}(A::AbstractArray{Complex{T},N}) = convert(AbstractArray{

## promotion to complex ##

promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(::Type{S}, ::Type{Complex{AT}}) = Complex{AT}
promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(F, ::Type{S}, ::Type{Complex{AT}}) = Complex{AT}

function complex{S<:Real,T<:Real}(A::Array{S}, B::Array{T})
if size(A) != size(B); throw(DimensionMismatch()); end
Expand Down
27 changes: 27 additions & 0 deletions base/functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,27 @@ call(::AndFun, x, y) = x & y
immutable OrFun <: Func{2} end
call(::OrFun, x, y) = x | y

immutable XorFun <: Func{2} end
call(::XorFun, x, y) = x $ y

immutable AddFun <: Func{2} end
call(::AddFun, x, y) = x + y

immutable DotAddFun <: Func{2} end
call(::DotAddFun, x, y) = x .+ y

immutable SubFun <: Func{2} end
call(::SubFun, x, y) = x - y

immutable DotSubFun <: Func{2} end
call(::DotSubFun, x, y) = x .- y

immutable MulFun <: Func{2} end
call(::MulFun, x, y) = x * y

immutable DotMulFun <: Func{2} end
call(::DotMulFun, x, y) = x .* y

immutable RDivFun <: Func{2} end
call(::RDivFun, x, y) = x / y

Expand All @@ -52,6 +64,15 @@ call(::LDivFun, x, y) = x \ y
immutable IDivFun <: Func{2} end
call(::IDivFun, x, y) = div(x, y)

immutable ModFun <: Func{2} end
call(::ModFun, x, y) = mod(x, y)

immutable RemFun <: Func{2} end
call(::RemFun, x, y) = rem(x, y)

immutable DotRemFun <: Func{2} end
call(::RemFun, x, y) = x .% y

immutable PowFun <: Func{2} end
call(::PowFun, x, y) = x ^ y

Expand All @@ -67,6 +88,12 @@ call(::LessFun, x, y) = x < y
immutable MoreFun <: Func{2} end
call(::MoreFun, x, y) = x > y

immutable DotLSFun <: Func{2} end
call(::DotLSFun, x, y) = x .<< y

immutable DotRSFun <: Func{2} end
call(::DotRSFun, x, y) = x .>> y

# a fallback unspecialized function object that allows code using
# function objects to not care whether they were able to specialize on
# the function value or not
Expand Down
6 changes: 6 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...)
checked_sub(x::Integer, y::Integer) = checked_sub(promote(x,y)...)
checked_mul(x::Integer, y::Integer) = checked_mul(promote(x,y)...)

# "Promotion" that takes a Functor into account. You can override this
# as needed. For example, if you need to provide a custom result type
# for the multiplication of two types,
# promote_op{R<:MyType,S<:MyType}(::MulFun, ::Type{R}, ::Type{S}) = MyType{multype(R,S)}
promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = promote_type(R, S)

## catch-alls to prevent infinite recursion when definitions are missing ##

no_op_err(name, T) = error(name," not defined for ",T)
Expand Down
38 changes: 38 additions & 0 deletions doc/devdocs/promote-op.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
.. currentmodule:: Base

.. _devdocs-promote-op:

Operator-sensitive promotion
============================

In certain cases, the :ref:`simple rules for promotion
<man-promotion-rules>` may not be sufficient. For example, consider a
type that can represent an object with physical units, here restricted
to a single unit like "meter"::

immutable MeterUnits{T,P} <: Number
val::T
end
MeterUnits{T}(val::T, pow::Int) = MeterUnits{T,pow}(val)

m = MeterUnits(1.0, 1) # 1.0 meter, i.e. units of length
m2 = MeterUnits(1.0, 2) # 1.0 meter^2, i.e. units of area

Now let's define the operations ``+`` and ``*`` for these objects:
``m+m`` should have the type of ``m`` but ``m*m`` should have the type
of ``m2``. When the result type depends on the operation, and not
just the input types, ``promote_rule`` will be inadequate.

Fortunately, it's possible to provide such definitions via ``promote_op``::

Base.promote_op{R,S}(::Base.AddFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),1}
Base.promote_op{R,S}(::Base.MulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}
Base.promote_op{R,S}(::Base.DotMulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}

The first one defines the promotion rule for ``+``, and the second one
for ``*``. ``AddFun``, ``MulFun``, and ``DotMulFun`` are "functor
types" defined in `functor.jl
<https://github.com/JuliaLang/julia/blob/master/base/functors.jl>`_.

It's worth noting that as julia's internal representation of functions
evolves, this interface may change in a future version of Julia.
5 changes: 5 additions & 0 deletions doc/manual/conversion-and-promotion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ the the catch-all method definitions given in
*(x::Number, y::Number) = *(promote(x,y)...)
/(x::Number, y::Number) = /(promote(x,y)...)

In certain cases, the result type also depends on the operator; how to
handle such scenarios is described :ref:`elsewhere <devdocs-promote-op>`.

These method definitions say that in the absence of more specific rules
for adding, subtracting, multiplying and dividing pairs of numeric
values, promote the values to a common type and then try again. That's
Expand Down Expand Up @@ -309,6 +312,8 @@ programmers to supply the expected types to constructor functions
explicitly, but sometimes, especially for numeric problems, it can be
convenient to do promotion automatically.

.. _man-promotion-rules:

Defining Promotion Rules
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
26 changes: 26 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,29 @@ b = rand(6,7)
@test_throws BoundsError copy!(a,b)
@test_throws ArgumentError copy!(a,2:3,1:3,b,1:5,2:7)
@test_throws ArgumentError Base.copy_transpose!(a,2:3,1:3,b,1:5,2:7)

# return type declarations (promote_op)
module RetTypeDecl
using Base.Test

immutable MeterUnits{T,P} <: Number
val::T
end
MeterUnits{T}(val::T, pow::Int) = MeterUnits{T,pow}(val)

m = MeterUnits(1.0, 1) # 1.0 meter, i.e. units of length
m2 = MeterUnits(1.0, 2) # 1.0 meter^2, i.e. units of area

(+){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,1}(x.val+y.val)
(*){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,2}(x.val*y.val)
(.*){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,2}(x.val*y.val)

Base.promote_op{R,S}(::Base.AddFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),1}
Base.promote_op{R,S}(::Base.MulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}
Base.promote_op{R,S}(::Base.DotMulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}

@test @inferred(m+[m,m]) == [m+m,m+m]
@test @inferred([m,m]+m) == [m+m,m+m]
@test @inferred(m.*[m,m]) == [m2,m2]
@test @inferred([m,m].*m) == [m2,m2]
end

0 comments on commit d7351cf

Please sign in to comment.