Skip to content

Commit

Permalink
Make broadcast_axes inferrable even if axes are of different types
Browse files Browse the repository at this point in the history
(cherry picked from commit 4230407)
  • Loading branch information
timholy authored and KristofferC committed Dec 12, 2018
1 parent a1f4471 commit 6174770
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
10 changes: 9 additions & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,19 @@ end
_bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))))
_bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
_bcs1(a, b::Integer) = _bcs1(b, a)
_bcs1(a, b) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
_bcs1(a, b) = _bcsm(b, a) ? _sametype(b, a) : (_bcsm(a, b) ? _sametype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size")))
# _bcsm tests whether the second index is consistent with the first
_bcsm(a, b) = a == b || length(b) == 1
_bcsm(a, b::Number) = b == 1
_bcsm(a::Number, b::Number) = a == b || b == 1
# Ensure inferrability when dealing with axes of different AbstractUnitRange types
# (We may not want to define general promotion rules between, say, OneTo and Slice, but if
# we get here we know the axes are at least consistent)
_sametype(a::T, b::T) where T = a
_sametype(a::OneTo, b::OneTo) = OneTo{Int}(a)
_sametype(a::OneTo, b) = OneTo{Int}(a)
_sametype(a, b::OneTo) = OneTo{Int}(a)
_sametype(a, b) = UnitRange{Int}(a)

## Check that all arguments are broadcast compatible with shape
# comparing one input against a shape
Expand Down
5 changes: 5 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,11 @@ A = OffsetArray(rand(4,4), (-3,5))
A = OffsetArray(view(rand(4,4), 1:4, 4:-1:1), (-3,5))
@test vec(A) == reshape(A, :) == reshape(A, 16) == reshape(A, Val(1)) == A[:] == vec(A.parent)

# broadcast
a = [1]
b = OffsetArray(a, (0,))
@test @inferred(a .+ b) == [2]

end # let

# Check that similar throws a MethodError rather than a
Expand Down

0 comments on commit 6174770

Please sign in to comment.