From caad61f7b032078d8dc7288126bca4b4348967a2 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 18 Nov 2018 09:11:07 -0600 Subject: [PATCH] Make `broadcast_axes` inferrable even if axes are of different types (cherry picked from commit 423040776131a64564fb1770bb092eac37bd966d) --- base/broadcast.jl | 10 +++++++++- test/offsetarray.jl | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index a95cc8bbfed29d..712aaa5947b660 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -436,11 +436,19 @@ end _bcs1(a::Integer, b::Integer) = a == 1 ? b : (b == 1 ? a : (a == b ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size")))) _bcs1(a::Integer, b) = a == 1 ? b : (first(b) == 1 && last(b) == a ? b : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) _bcs1(a, b::Integer) = _bcs1(b, a) -_bcs1(a, b) = _bcsm(b, a) ? b : (_bcsm(a, b) ? a : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) +_bcs1(a, b) = _bcsm(b, a) ? _sametype(b, a) : (_bcsm(a, b) ? _sametype(a, b) : throw(DimensionMismatch("arrays could not be broadcast to a common size"))) # _bcsm tests whether the second index is consistent with the first _bcsm(a, b) = a == b || length(b) == 1 _bcsm(a, b::Number) = b == 1 _bcsm(a::Number, b::Number) = a == b || b == 1 +# Ensure inferrability when dealing with axes of different AbstractUnitRange types +# (We may not want to define general promotion rules between, say, OneTo and Slice, but if +# we get here we know the axes are at least consistent) +_sametype(a::T, b::T) where T = a +_sametype(a::OneTo, b::OneTo) = OneTo{Int}(a) +_sametype(a::OneTo, b) = OneTo{Int}(a) +_sametype(a, b::OneTo) = OneTo{Int}(a) +_sametype(a, b) = UnitRange{Int}(a) ## Check that all arguments are broadcast compatible with shape # comparing one input against a shape diff --git a/test/offsetarray.jl b/test/offsetarray.jl index b6078f41bab8db..5aa58853179ebd 100644 --- a/test/offsetarray.jl +++ b/test/offsetarray.jl @@ -466,6 +466,11 @@ A = OffsetArray(rand(4,4), (-3,5)) A = OffsetArray(view(rand(4,4), 1:4, 4:-1:1), (-3,5)) @test vec(A) == reshape(A, :) == reshape(A, 16) == reshape(A, Val(1)) == A[:] == vec(A.parent) +# broadcast +a = [1] +b = OffsetArray(a, (0,)) +@test @inferred(a .+ b) == [2] + end # let # Check that similar throws a MethodError rather than a