Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CatVector #577

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 117 additions & 0 deletions src/custom_collections/CatVector.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
$(TYPEDEF)

An `AbstractVector` subtype that acts as a lazy concatenation of a number
of subvectors.
"""
struct CatVector{T, N, V<:AbstractVector{T}} <: AbstractVector{T}
vecs::NTuple{N, V}
end

@inline Base.size(vec::CatVector) = (mapreduce(length, +, vec.vecs; init=0),)

# Note: getindex and setindex are pretty naive. Consider precomputing map from
# index to vector upon CatVector construction.
Base.@propagate_inbounds function Base.getindex(vec::CatVector, index::Int)
@boundscheck index >= 1 || throw(BoundsError(vec, index))
i = 1
j = index
@inbounds while true
subvec = vec.vecs[i]
l = length(subvec)
if j <= l
return subvec[eachindex(subvec)[j]]
else
j -= l
i += 1
end
end
throw(BoundsError(vec, index))
end

Base.@propagate_inbounds function Base.setindex!(vec::CatVector, val, index::Int)
@boundscheck index >= 1 || throw(BoundsError(vec, index))
i = 1
j = index
while true
subvec = vec.vecs[i]
l = length(subvec)
if j <= l
subvec[eachindex(subvec)[j]] = val
return val
else
j -= l
i += 1
end
end
throw(BoundsError(vec, index))
end

Base.@propagate_inbounds function Base.copyto!(dest::AbstractVector{T}, src::CatVector{T}) where {T}
@boundscheck length(dest) == length(src) || throw(DimensionMismatch())
dest_indices = eachindex(dest)
k = 1
@inbounds for i in eachindex(src.vecs)
vec = src.vecs[i]
for j in eachindex(vec)
dest[dest_indices[k]] = vec[j]
k += 1
end
end
return dest
end

Base.similar(vec::CatVector) = CatVector(map(similar, vec.vecs))
Base.similar(vec::CatVector, ::Type{T}) where {T} = CatVector(map(x -> similar(x, T), vec.vecs))

@noinline cat_vectors_line_up_error() = throw(ArgumentError("Subvectors must line up"))

@inline function check_cat_vectors_line_up(x::CatVector, y::CatVector)
length(x.vecs) == length(y.vecs) || cat_vectors_line_up_error()
for i in eachindex(x.vecs)
length(x.vecs[i]) == length(y.vecs[i]) || cat_vectors_line_up_error()
end
nothing
end

@inline check_cat_vectors_line_up(x::CatVector, y) = nothing
@inline function check_cat_vectors_line_up(x::CatVector, y, tail...)
check_cat_vectors_line_up(x, y)
check_cat_vectors_line_up(x, tail...)
end

@propagate_inbounds function Base.copyto!(dest::CatVector, src::CatVector)
for i in eachindex(dest.vecs)
copyto!(dest.vecs[i], src.vecs[i])
end
return dest
end

@inline function Base.map!(f::F, dest::CatVector, args::CatVector...) where F
@boundscheck check_cat_vectors_line_up(dest, args...)
@inbounds for i in eachindex(dest.vecs)
map!(f, dest.vecs[i], map(arg -> arg.vecs[i], args)...)
end
return dest
end

Base.@propagate_inbounds catvec_broadcast_vec(arg::CatVector, range::UnitRange, k::Int) = arg.vecs[k]
Base.@propagate_inbounds catvec_broadcast_vec(arg::AbstractVector, range::UnitRange, k::Int) = view(arg, range)
Base.@propagate_inbounds catvec_broadcast_vec(arg::Number, range::UnitRange, k::Int) = arg

@inline function Base.copyto!(dest::CatVector, bc::Broadcast.Broadcasted{Nothing})
flat = Broadcast.flatten(bc)
@boundscheck check_cat_vectors_line_up(dest, flat.args...)
offset = 1
@inbounds for i in eachindex(dest.vecs)
let i = i, f = flat.f, args = flat.args
dest′ = dest.vecs[i]
range = offset : offset + length(dest′) - 1
args′ = map(arg -> catvec_broadcast_vec(arg, range, i), args)
axes′ = (eachindex(dest′),)
copyto!(dest′, Broadcast.Broadcasted{Nothing}(f, args′, axes′))
offset = last(range) + 1
end
end
return dest
end
4 changes: 3 additions & 1 deletion src/custom_collections/custom_collections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ export
CacheIndexDict,
SegmentedVector,
SegmentedBlockDiagonalMatrix,
UnorderedPair
UnorderedPair,
CatVector

export
foreach_with_extra_args,
Expand All @@ -34,5 +35,6 @@ include("IndexDict.jl")
include("SegmentedVector.jl")
include("SegmentedBlockDiagonalMatrix.jl")
include("UnorderedPair.jl")
include("CatVector.jl")

end # module
69 changes: 69 additions & 0 deletions test/test_custom_collections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,73 @@ Base.axes(m::NonOneBasedMatrix) = ((1:m.m) .- 2, (1:m.n) .+ 1)
dict = Dict(p1 => 3)
@test dict[p2] == 3
end

@testset "CatVector" begin
CatVector = RigidBodyDynamics.CatVector
Random.seed!(52)
vecs = ntuple(i -> rand(rand(0 : 5)), Val(10))
l = sum(length, vecs)
x = zeros(l)
y = CatVector(vecs)

@test length(y) == l

x .= y
for i in eachindex(x)
@test x[i] == y[i]
end

x .= 0
@test x != y
copyto!(x, y)
for i in eachindex(x)
@test x[i] == y[i]
end
@test x == y

y .= 0
rand!(x)
y .= x .+ y .+ 1
@test x .+ 1 == y

allocs = let x=x, vecs=vecs
@allocated copyto!(x, RigidBodyDynamics.CatVector(vecs))
end
@test allocs == 0

y2 = similar(y)
@test eltype(y2) == eltype(y)
for i in eachindex(y.vecs)
@test length(y2.vecs[i]) == length(y.vecs[i])
@test y2.vecs[i] !== y.vecs[i]
end

y3 = similar(y, Int)
@test eltype(y3) == Int
for i in eachindex(y.vecs)
@test length(y3.vecs[i]) == length(y.vecs[i])
@test y3.vecs[i] !== y.vecs[i]
end

y4 = similar(y)
copyto!(y4, y)
@test y4 == y

y5 = similar(y)
map!(+, y5, y, y)
@test Vector(y5) == Vector(y) + Vector(y)

z = similar(y)
rand!(z)
yvec = Vector(y)
zvec = Vector(z)

z .= muladd.(1e-3, y, z)
zvec .= muladd.(1e-3, yvec, zvec)
@test zvec == z
allocs = let y=y, z=z
@allocated z .= muladd.(1e-3, y, z)
end
@test allocs == 0
end
end