diff --git a/src/vector_of_array.jl b/src/vector_of_array.jl index 42c24089..933c542d 100644 --- a/src/vector_of_array.jl +++ b/src/vector_of_array.jl @@ -663,10 +663,25 @@ end bc = Broadcast.flatten(bc) N = narrays(bc) @inbounds for i in 1:N - if dest[:, i] isa AbstractArray && !isa(dest[:, i], StaticArraysCore.SArray) + if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) copyto!(dest[:, i], unpack_voa(bc, i)) else - dest[:, i] = copy(unpack_voa(bc, i)) + unpacked = unpack_voa(bc, i) + dest[:, i] = unpacked.f(unpacked.args...) + end + end + dest +end + +@inline function Base.copyto!(dest::AbstractVectorOfArray, + bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}) + bc = Broadcast.flatten(bc) + @inbounds for i in 1:length(dest.u) + if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i]) + copyto!(dest[:, i], unpack_voa(bc, i)) + else + unpacked = unpack_voa(bc, i) + dest[:, i] = unpacked.f(unpacked.args...) end end dest diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 4f79c3e6..cba4727a 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -125,3 +125,27 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) z .= x .+ y @test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) + +u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})]) +u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) +u3 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})]) + +function f(u1,u2,u3) + u3 .= u1 .+ u2 +end +f(u1,u2,u3) +@test (@allocated f(u1,u2,u3)) == 0 + +yy = [2.0 1.0; 2.0 1.0] +zz = x .+ yy +@test zz == [4.0 2.0; 4.0 2.0] + +z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})]) +z .= zz +@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})]) + +function f!(z,zz) + z .= zz +end +f!(z,zz) +@test (@allocated f!(z,zz)) == 0