Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable operator-sensitive extension of element-type promotion #12292

Merged
merged 3 commits into from
Jul 29, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions base/arraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ end

## Binary arithmetic operators ##

promote_array_type{Scalar, Arry}(::Type{Scalar}, ::Type{Arry}) = promote_type(Scalar, Arry)
promote_array_type{S<:Real, A<:FloatingPoint}(::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer, A<:Integer}(::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S
promote_array_type{Scalar, Arry}(F, ::Type{Scalar}, ::Type{Arry}) = promote_op(F, Scalar, Arry)
promote_array_type{S<:Real, A<:FloatingPoint}(F, ::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer, A<:Integer}(F, ::Type{S}, ::Type{A}) = A
promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}) = S

# Handle operations that return different types
./(x::Number, Y::AbstractArray) =
Expand All @@ -57,10 +57,16 @@ promote_array_type{S<:Integer}(::Type{S}, ::Type{Bool}) = S
.^(X::AbstractArray, y::Number ) =
reshape([ x ^ y for x in X ], size(X))

for f in (:+, :-, :div, :mod, :&, :|, :$)
for (f,F) in ((:+, AddFun()),
(:-, SubFun()),
(:div, IDivFun()),
(:mod, ModFun()),
(:&, AndFun()),
(:|, OrFun()),
(:$, XorFun()))
@eval begin
function ($f){S,T}(A::Range{S}, B::Range{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for (a,b) in zip(A,B)
@inbounds F[i] = ($f)(a, b)
Expand All @@ -69,7 +75,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::AbstractArray{S}, B::Range{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for b in B
@inbounds F[i] = ($f)(A[i], b)
Expand All @@ -78,7 +84,7 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::Range{S}, B::AbstractArray{T})
F = similar(B, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(B, promote_op($F,S,T), promote_shape(size(A),size(B)))
i = 1
for a in A
@inbounds F[i] = ($f)(a, B[i])
Expand All @@ -87,25 +93,36 @@ for f in (:+, :-, :div, :mod, :&, :|, :$)
return F
end
function ($f){S,T}(A::AbstractArray{S}, B::AbstractArray{T})
F = similar(A, promote_type(S,T), promote_shape(size(A),size(B)))
F = similar(A, promote_op($F,S,T), promote_shape(size(A),size(B)))
for i in eachindex(A,B)
@inbounds F[i] = ($f)(A[i], B[i])
end
return F
end
end
end
for f in (:.+, :.-, :.*, :.%, :.<<, :.>>, :div, :mod, :rem, :&, :|, :$)
for (f,F) in ((:.+, DotAddFun()),
(:.-, DotSubFun()),
(:.*, DotMulFun()),
(:.%, DotRemFun()),
(:.<<, DotLSFun()),
(:.>>, DotRSFun()),
(:div, IDivFun()),
(:mod, ModFun()),
(:rem, RemFun()),
(:&, AndFun()),
(:|, OrFun()),
(:$, XorFun()))
@eval begin
function ($f){T}(A::Number, B::AbstractArray{T})
F = similar(B, promote_array_type(typeof(A),T))
F = similar(B, promote_array_type($F,typeof(A),T))
for i in eachindex(B)
@inbounds F[i] = ($f)(A, B[i])
end
return F
end
function ($f){T}(A::AbstractArray{T}, B::Number)
F = similar(A, promote_array_type(typeof(B),T))
F = similar(A, promote_array_type($F,typeof(B),T))
for i in eachindex(A)
@inbounds F[i] = ($f)(A[i], B)
end
Expand Down
39 changes: 21 additions & 18 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -855,21 +855,23 @@ for f in (:+, :-)
return r
end
end
for f in (:.+, :.-),
(arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)),
(:(B::BitArray), :(x::Number) , :(promote_array_type(typeof(x), Bool)), :(b, x)),
(:(x::Bool) , :(B::BitArray), Int , :(x, b)),
(:(x::Number) , :(B::BitArray), :(promote_array_type(typeof(x), Bool)), :(x, b)))
@eval function ($f)($arg1, $arg2)
r = Array($T, size(B))
bi = start(B)
ri = 1
while !done(B, bi)
b, bi = next(B, bi)
@inbounds r[ri] = ($f)($fargs...)
ri += 1
for (f,F) in ((:.+, DotAddFun()),
(:.-, DotSubFun()))
for (arg1, arg2, T, fargs) in ((:(B::BitArray), :(x::Bool) , Int , :(b, x)),
(:(B::BitArray), :(x::Number) , :(promote_array_type($F, typeof(x), Bool)), :(b, x)),
(:(x::Bool) , :(B::BitArray), Int , :(x, b)),
(:(x::Number) , :(B::BitArray), :(promote_array_type($F, typeof(x), Bool)), :(x, b)))
@eval function ($f)($arg1, $arg2)
r = Array($T, size(B))
bi = start(B)
ri = 1
while !done(B, bi)
b, bi = next(B, bi)
@inbounds r[ri] = ($f)($fargs...)
ri += 1
end
return r
end
return r
end
end

Expand Down Expand Up @@ -897,7 +899,7 @@ function div(x::Bool, B::BitArray)
end
function div(x::Number, B::BitArray)
all(B) || throw(DivideError())
pt = promote_array_type(typeof(x), Bool)
pt = promote_array_type(IDivFun(), typeof(x), Bool)
y = div(x, true)
reshape(pt[ y for i = 1:length(B) ], size(B))
end
Expand All @@ -918,15 +920,16 @@ function mod(x::Bool, B::BitArray)
end
function mod(x::Number, B::BitArray)
all(B) || throw(DivideError())
pt = promote_array_type(typeof(x), Bool)
pt = promote_array_type(ModFun(), typeof(x), Bool)
y = mod(x, true)
reshape(pt[ y for i = 1:length(B) ], size(B))
end

for f in (:div, :mod)
for (f,F) in ((:div, IDivFun()),
(:mod, ModFun()))
@eval begin
function ($f)(B::BitArray, x::Number)
F = Array(promote_array_type(typeof(x), Bool), size(B))
F = Array(promote_array_type($F, typeof(x), Bool), size(B))
for i = 1:length(F)
F[i] = ($f)(B[i], x)
end
Expand Down
2 changes: 1 addition & 1 deletion base/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ big{T<:FloatingPoint,N}(A::AbstractArray{Complex{T},N}) = convert(AbstractArray{

## promotion to complex ##

promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(::Type{S}, ::Type{Complex{AT}}) = Complex{AT}
promote_array_type{S<:Union{Complex, Real}, AT<:FloatingPoint}(F, ::Type{S}, ::Type{Complex{AT}}) = Complex{AT}

function complex{S<:Real,T<:Real}(A::Array{S}, B::Array{T})
if size(A) != size(B); throw(DimensionMismatch()); end
Expand Down
27 changes: 27 additions & 0 deletions base/functors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,27 @@ call(::AndFun, x, y) = x & y
immutable OrFun <: Func{2} end
call(::OrFun, x, y) = x | y

immutable XorFun <: Func{2} end
call(::XorFun, x, y) = x $ y

immutable AddFun <: Func{2} end
call(::AddFun, x, y) = x + y

immutable DotAddFun <: Func{2} end
call(::DotAddFun, x, y) = x .+ y

immutable SubFun <: Func{2} end
call(::SubFun, x, y) = x - y

immutable DotSubFun <: Func{2} end
call(::DotSubFun, x, y) = x .- y

immutable MulFun <: Func{2} end
call(::MulFun, x, y) = x * y

immutable DotMulFun <: Func{2} end
call(::DotMulFun, x, y) = x .* y

immutable RDivFun <: Func{2} end
call(::RDivFun, x, y) = x / y

Expand All @@ -52,6 +64,15 @@ call(::LDivFun, x, y) = x \ y
immutable IDivFun <: Func{2} end
call(::IDivFun, x, y) = div(x, y)

immutable ModFun <: Func{2} end
call(::ModFun, x, y) = mod(x, y)

immutable RemFun <: Func{2} end
call(::RemFun, x, y) = rem(x, y)

immutable DotRemFun <: Func{2} end
call(::RemFun, x, y) = x .% y

immutable PowFun <: Func{2} end
call(::PowFun, x, y) = x ^ y

Expand All @@ -67,6 +88,12 @@ call(::LessFun, x, y) = x < y
immutable MoreFun <: Func{2} end
call(::MoreFun, x, y) = x > y

immutable DotLSFun <: Func{2} end
call(::DotLSFun, x, y) = x .<< y

immutable DotRSFun <: Func{2} end
call(::DotRSFun, x, y) = x .>> y

# a fallback unspecialized function object that allows code using
# function objects to not care whether they were able to specialize on
# the function value or not
Expand Down
6 changes: 6 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ checked_add(x::Integer, y::Integer) = checked_add(promote(x,y)...)
checked_sub(x::Integer, y::Integer) = checked_sub(promote(x,y)...)
checked_mul(x::Integer, y::Integer) = checked_mul(promote(x,y)...)

# "Promotion" that takes a Functor into account. You can override this
# as needed. For example, if you need to provide a custom result type
# for the multiplication of two types,
# promote_op{R<:MyType,S<:MyType}(::MulFun, ::Type{R}, ::Type{S}) = MyType{multype(R,S)}
promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = promote_type(R, S)

## catch-alls to prevent infinite recursion when definitions are missing ##

no_op_err(name, T) = error(name," not defined for ",T)
Expand Down
38 changes: 38 additions & 0 deletions doc/devdocs/promote-op.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
.. currentmodule:: Base

.. _devdocs-promote-op:

Operator-sensitive promotion
============================

In certain cases, the :ref:`simple rules for promotion
<man-promotion-rules>` may not be sufficient. For example, consider a
type that can represent an object with physical units, here restricted
to a single unit like "meter"::

immutable MeterUnits{T,P} <: Number
val::T
end
MeterUnits{T}(val::T, pow::Int) = MeterUnits{T,pow}(val)

m = MeterUnits(1.0, 1) # 1.0 meter, i.e. units of length
m2 = MeterUnits(1.0, 2) # 1.0 meter^2, i.e. units of area

Now let's define the operations ``+`` and ``*`` for these objects:
``m+m`` should have the type of ``m`` but ``m*m`` should have the type
of ``m2``. When the result type depends on the operation, and not
just the input types, ``promote_rule`` will be inadequate.

Fortunately, it's possible to provide such definitions via ``promote_op``::

Base.promote_op{R,S}(::Base.AddFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),1}
Base.promote_op{R,S}(::Base.MulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}
Base.promote_op{R,S}(::Base.DotMulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}

The first one defines the promotion rule for ``+``, and the second one
for ``*``. ``AddFun``, ``MulFun``, and ``DotMulFun`` are "functor
types" defined in `functor.jl
<https://github.com/JuliaLang/julia/blob/master/base/functors.jl>`_.

It's worth noting that as julia's internal representation of functions
evolves, this interface may change in a future version of Julia.
5 changes: 5 additions & 0 deletions doc/manual/conversion-and-promotion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,9 @@ the the catch-all method definitions given in
*(x::Number, y::Number) = *(promote(x,y)...)
/(x::Number, y::Number) = /(promote(x,y)...)

In certain cases, the result type also depends on the operator; how to
handle such scenarios is described :ref:`elsewhere <devdocs-promote-op>`.

These method definitions say that in the absence of more specific rules
for adding, subtracting, multiplying and dividing pairs of numeric
values, promote the values to a common type and then try again. That's
Expand Down Expand Up @@ -309,6 +312,8 @@ programmers to supply the expected types to constructor functions
explicitly, but sometimes, especially for numeric problems, it can be
convenient to do promotion automatically.

.. _man-promotion-rules:

Defining Promotion Rules
~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
26 changes: 26 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1212,3 +1212,29 @@ b = rand(6,7)
@test_throws BoundsError copy!(a,b)
@test_throws ArgumentError copy!(a,2:3,1:3,b,1:5,2:7)
@test_throws ArgumentError Base.copy_transpose!(a,2:3,1:3,b,1:5,2:7)

# return type declarations (promote_op)
module RetTypeDecl
using Base.Test

immutable MeterUnits{T,P} <: Number
val::T
end
MeterUnits{T}(val::T, pow::Int) = MeterUnits{T,pow}(val)

m = MeterUnits(1.0, 1) # 1.0 meter, i.e. units of length
m2 = MeterUnits(1.0, 2) # 1.0 meter^2, i.e. units of area

(+){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,1}(x.val+y.val)
(*){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,2}(x.val*y.val)
(.*){T}(x::MeterUnits{T,1}, y::MeterUnits{T,1}) = MeterUnits{T,2}(x.val*y.val)

Base.promote_op{R,S}(::Base.AddFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),1}
Base.promote_op{R,S}(::Base.MulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}
Base.promote_op{R,S}(::Base.DotMulFun, ::Type{MeterUnits{R,1}}, ::Type{MeterUnits{S,1}}) = MeterUnits{promote_type(R,S),2}

@test @inferred(m+[m,m]) == [m+m,m+m]
@test @inferred([m,m]+m) == [m+m,m+m]
@test @inferred(m.*[m,m]) == [m2,m2]
@test @inferred([m,m].*m) == [m2,m2]
end