Skip to content

Commit

Permalink
Merge pull request #309 from AayushSabharwal/as/broadcast-from-array
Browse files Browse the repository at this point in the history
feat: add ability to set VectorOfArray with Array using broadcast
  • Loading branch information
ChrisRackauckas authored Dec 22, 2023
2 parents dd5c756 + 300f692 commit 8c87d85
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
19 changes: 17 additions & 2 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,25 @@ end
bc = Broadcast.flatten(bc)
N = narrays(bc)
@inbounds for i in 1:N
if dest[:, i] isa AbstractArray && !isa(dest[:, i], StaticArraysCore.SArray)
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
copyto!(dest[:, i], unpack_voa(bc, i))
else
dest[:, i] = copy(unpack_voa(bc, i))
unpacked = unpack_voa(bc, i)
dest[:, i] = unpacked.f(unpacked.args...)
end
end
dest
end

@inline function Base.copyto!(dest::AbstractVectorOfArray,
bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
bc = Broadcast.flatten(bc)
@inbounds for i in 1:length(dest.u)
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
copyto!(dest[:, i], unpack_voa(bc, i))
else
unpacked = unpack_voa(bc, i)
dest[:, i] = unpacked.f(unpacked.args...)
end
end
dest
Expand Down
24 changes: 24 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,27 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})])
z .= x .+ y

@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])

u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})])
u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})])
u3 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})])

function f(u1,u2,u3)
u3 .= u1 .+ u2
end
f(u1,u2,u3)
@test (@allocated f(u1,u2,u3)) == 0

yy = [2.0 1.0; 2.0 1.0]
zz = x .+ yy
@test zz == [4.0 2.0; 4.0 2.0]

z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})])
z .= zz
@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])

function f!(z,zz)
z .= zz
end
f!(z,zz)
@test (@allocated f!(z,zz)) == 0

0 comments on commit 8c87d85

Please sign in to comment.