Skip to content

Commit

Permalink
Make broadcast check the indices
Browse files Browse the repository at this point in the history
Fixes #19
  • Loading branch information
andyferris committed Jun 11, 2020
1 parent 0d34363 commit 5dc0545
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 18 deletions.
20 changes: 15 additions & 5 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct BroadcastedDictionary{I, T, F, Data <: Tuple} <: AbstractDictionary{I, T}
sharetokens::Bool
end

function BroadcastedDictionary(f, data)
@propagate_inbounds function BroadcastedDictionary(f, data)
dicts = _dicts(data...)
sharetokens = _sharetokens(dicts...)
I = keytype(dicts[1])
Expand All @@ -20,15 +20,22 @@ end
@inline Base.keys(d::BroadcastedDictionary) = _keys(d.data...)

@propagate_inbounds function Base.getindex(d::BroadcastedDictionary{I}, i::I) where {I}
return d.f(_getindex(i, d.data...)...)
if d.sharetokens
t = gettoken(d, i)
return d.f(_gettokenvalue(t, d.data...)...)
else
return d.f(_getindex(i, d.data...)...)
end
end

function Base.isassigned(d::BroadcastedDictionary{I}, i::I) where {I}
return _isassigned(i, d.data...)
end

istokenizable(d::BroadcastedDictionary) = d.sharetokens
tokens(d::BroadcastedDictionary) = _tokens(d.data...)
function tokens(d::BroadcastedDictionary)
_tokens(d.data...)
end

@propagate_inbounds function gettoken(d::BroadcastedDictionary{I}, i::I) where {I}
return gettoken(_tokens(d.data...), i)
Expand All @@ -53,10 +60,13 @@ _dicts() = ()
@inline _keys(d, ds...) = _keys(ds...)

_sharetokens(d) = true
@inline function _sharetokens(d, d2, ds...)
@propagate_inbounds function _sharetokens(d, d2, ds...)
if sharetokens(d, d2)
return _sharetokens(d, ds...)
else
@boundscheck if !isequal(keys(d), keys(d2))
throw(IndexError("Indices do not match"))
end
return false
end
end
Expand All @@ -72,7 +82,7 @@ _gettokenvalue(t) = ()
@propagate_inbounds _gettokenvalue(t, d, ds...) = (d[CartesianIndex()], _gettokenvalue(t, ds...)...)

_isassigned(i) = true
_isassigned(i, d, ds...) = _istokenassigned(i, ds...)
_isassigned(i, d, ds...) = _isassigned(i, ds...)
function _isassigned(i, d::AbstractDictionary, ds...)
if isassigned(d, i)
return _isassigned(i, ds...)
Expand Down
30 changes: 17 additions & 13 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
@testset "broadcast" begin
i = HashIndices([1,2,3,4,5])

@test issetequal(pairs((i .+ 1)::HashDictionary), [1=>2, 2=>3, 3=>4, 4=>5, 5=>6])
@test issetequal(pairs(Base.Broadcast.broadcasted(+, i, 1)::BroadcastedDictionary), [1=>2, 2=>3, 3=>4, 4=>5, 5=>6])
@test isequal((i .+ 1)::HashDictionary, dictionary([1=>2, 2=>3, 3=>4, 4=>5, 5=>6]))
@test isequal(Base.Broadcast.broadcasted(+, i, 1)::BroadcastedDictionary, dictionary([1=>2, 2=>3, 3=>4, 4=>5, 5=>6]))

@test issetequal(pairs((i .+ i)::HashDictionary), [1=>2, 2=>4, 3=>6, 4=>8, 5=>10])
@test issetequal(pairs(Base.Broadcast.broadcasted(+, i, i)::BroadcastedDictionary), [1=>2, 2=>4, 3=>6, 4=>8, 5=>10])
@test isequal((i .+ i)::HashDictionary, dictionary([1=>2, 2=>4, 3=>6, 4=>8, 5=>10]))
@test isequal(Base.Broadcast.broadcasted(+, i, i)::BroadcastedDictionary, dictionary([1=>2, 2=>4, 3=>6, 4=>8, 5=>10]))

@test issetequal(pairs((i .+ copy(i))::HashDictionary), [1=>2, 2=>4, 3=>6, 4=>8, 5=>10])
@test issetequal(pairs(Base.Broadcast.broadcasted(+, i, copy(i))::BroadcastedDictionary), [1=>2, 2=>4, 3=>6, 4=>8, 5=>10])
@test isequal((i .+ copy(i))::HashDictionary, dictionary([1=>2, 2=>4, 3=>6, 4=>8, 5=>10]))
@test isequal(Base.Broadcast.broadcasted(+, i, copy(i))::BroadcastedDictionary, dictionary([1=>2, 2=>4, 3=>6, 4=>8, 5=>10]))

@test_throws IndexError HashIndices([1,2]) .+ HashIndices([2,3])

d = i .+ 1

@test issetequal(pairs((d .+ 1)::HashDictionary), [1=>3, 2=>4, 3=>5, 4=>6, 5=>7])
@test issetequal(pairs(Base.Broadcast.broadcasted(+, d, 1)::BroadcastedDictionary), [1=>3, 2=>4, 3=>5, 4=>6, 5=>7])
@test isequal((d .+ 1)::HashDictionary, dictionary([1=>3, 2=>4, 3=>5, 4=>6, 5=>7]))
@test isequal(Base.Broadcast.broadcasted(+, d, 1)::BroadcastedDictionary, dictionary([1=>3, 2=>4, 3=>5, 4=>6, 5=>7]))

@test issetequal(pairs((d .+ d)::HashDictionary), [1=>4, 2=>6, 3=>8, 4=>10, 5=>12])
@test issetequal(pairs(Base.Broadcast.broadcasted(+, d, d)::BroadcastedDictionary), [1=>4, 2=>6, 3=>8, 4=>10, 5=>12])
@test isequal((d .+ d)::HashDictionary, dictionary([1=>4, 2=>6, 3=>8, 4=>10, 5=>12]))
@test isequal(Base.Broadcast.broadcasted(+, d, d)::BroadcastedDictionary, dictionary([1=>4, 2=>6, 3=>8, 4=>10, 5=>12]))

@test issetequal(pairs((d .+ copy(d))::HashDictionary), [1=>4, 2=>6, 3=>8, 4=>10, 5=>12])
@test issetequal(pairs(Base.Broadcast.broadcasted(+, d, copy(d))::BroadcastedDictionary), [1=>4, 2=>6, 3=>8, 4=>10, 5=>12])
@test isequal((d .+ copy(d))::HashDictionary, dictionary([1=>4, 2=>6, 3=>8, 4=>10, 5=>12]))
@test isequal(Base.Broadcast.broadcasted(+, d, copy(d))::BroadcastedDictionary, dictionary([1=>4, 2=>6, 3=>8, 4=>10, 5=>12]))

d2 = similar(d)
d2 .= d .+ d
@test issetequal(pairs(d2), [1=>4, 2=>6, 3=>8, 4=>10, 5=>12])
@test isequal(d2, dictionary([1=>4, 2=>6, 3=>8, 4=>10, 5=>12]))

@test_throws IndexError HashDictionary([1,2],[1,2]) .+ HashDictionary([2,3],[2,3])
end

0 comments on commit 5dc0545

Please sign in to comment.