Skip to content

Commit

Permalink
Make _reshape_uncolon easier on inference
Browse files Browse the repository at this point in the history
Fixes #20848
  • Loading branch information
martinholters committed Mar 3, 2017
1 parent 017f5b4 commit 37cff21
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,19 @@ reshape(parent::AbstractArray, dims::Dims) = _reshape(parent, dims)
reshape(parent::AbstractArray, dims::Int...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Union{Int,Colon}...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Int,Colon}}}) = _reshape(parent, _reshape_uncolon(parent, dims))
# Recursively move dimensions to pre and post tuples, splitting on the Colon
@inline _reshape_uncolon(A, dims) = _reshape_uncolon(A, (), nothing, (), dims)
@inline _reshape_uncolon(A, pre, c::Void, post, dims::Tuple{Any, Vararg{Any}}) =
_reshape_uncolon(A, (pre..., dims[1]), c, post, tail(dims))
@inline _reshape_uncolon(A, pre, c::Void, post, dims::Tuple{Colon, Vararg{Any}}) =
_reshape_uncolon(A, pre, dims[1], post, tail(dims))
@inline _reshape_uncolon(A, pre, c::Colon, post, dims::Tuple{Any, Vararg{Any}}) =
_reshape_uncolon(A, pre, c, (post..., dims[1]), tail(dims))
_reshape_uncolon(A, pre, c::Colon, post, dims::Tuple{Colon, Vararg{Any}}) =
throw(DimensionMismatch("new dimensions $((pre..., c, post..., dims...)) may only have at most one omitted dimension specified by Colon()"))
@inline function _reshape_uncolon(A, pre, c::Colon, post, dims::Tuple{})
@inline function _reshape_uncolon(A, dims)
pre, post = _split_at_colon((), dims)
if any(d -> d isa Colon, post)
throw(DimensionMismatch("new dimensions $(dims) may only have at most one omitted dimension specified by Colon()"))
end
sz, remainder = divrem(length(A), prod(pre)*prod(post))
remainder == 0 || _throw_reshape_colon_dimmismatch(A, pre, post)
(pre..., sz, post...)
end
@inline _split_at_colon(pre, dims::Tuple{Any, Vararg{Any}}) =
_split_at_colon((pre..., dims[1]), tail(dims))
@inline _split_at_colon(pre, dims::Tuple{Colon, Vararg{Any}}) =
(pre, tail(dims))
_throw_reshape_colon_dimmismatch(A, pre, post) =
throw(DimensionMismatch("array size $(length(A)) must be divisible by the product of the new dimensions $((pre..., :, post...))"))

Expand Down

0 comments on commit 37cff21

Please sign in to comment.