Skip to content

Commit

Permalink
feat: add ability to set VectorOfArray with Array using broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 21, 2023
1 parent dd5c756 commit a30d7d0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,20 @@ end
bc = Broadcast.flatten(bc)
N = narrays(bc)
@inbounds for i in 1:N
if dest[:, i] isa AbstractArray && !isa(dest[:, i], StaticArraysCore.SArray)
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
copyto!(dest[:, i], unpack_voa(bc, i))

Check warning on line 667 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L666-L667

Added lines #L666 - L667 were not covered by tests
else
dest[:, i] = copy(unpack_voa(bc, i))

Check warning on line 669 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L669

Added line #L669 was not covered by tests
end
end
dest

Check warning on line 672 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L671-L672

Added lines #L671 - L672 were not covered by tests
end

@inline function Base.copyto!(dest::AbstractVectorOfArray,

Check warning on line 675 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L675

Added line #L675 was not covered by tests
bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
bc = Broadcast.flatten(bc)
@inbounds for i in 1:length(dest.u)
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])

Check warning on line 679 in src/vector_of_array.jl

View check run for this annotation

Codecov / codecov/patch

src/vector_of_array.jl#L677-L679

Added lines #L677 - L679 were not covered by tests
copyto!(dest[:, i], unpack_voa(bc, i))
else
dest[:, i] = copy(unpack_voa(bc, i))
Expand Down
8 changes: 8 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,11 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})])
z .= x .+ y

@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])

yy = [2.0 1.0; 2.0 1.0]
zz = x .+ yy
@test zz == [4.0, 2.0; 4.0, 2.0]

z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})])
z .= zz
@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])

0 comments on commit a30d7d0

Please sign in to comment.