Skip to content

Commit

Permalink
Fix StructArray broadcast in VectorOfArray
Browse files Browse the repository at this point in the history
Fixes #410

This specializes so that if `u.u` is not a vector, it will convert the broadcast to fix that. I couldn't find a nice generic way to use `map` so the fallback is to build the vector and convert, which seems to not be a big performance issue. For StructArrays, `convert(typeof(x), Vector(x))` fails, and so this case is specialized.
  • Loading branch information
ChrisRackauckas committed Nov 20, 2024
1 parent 9c8387c commit 83c990a
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -33,6 +34,7 @@ RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
RecursiveArrayToolsSparseArraysExt = ["SparseArrays"]
RecursiveArrayToolsStructArraysExt = "StructArrays"
RecursiveArrayToolsTrackerExt = "Tracker"
RecursiveArrayToolsZygoteExt = "Zygote"

Expand Down
6 changes: 6 additions & 0 deletions ext/RecursiveArrayToolsStructArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
module RecursiveArrayToolsStructArraysExt

import RecursiveArrayTools, StructArrays
RecursiveArrayTools.rewrap(::StructArrays.StructArray, u) = StructArrays.StructArray(u)

end
17 changes: 11 additions & 6 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,28 +849,33 @@ end

@inline function Base.copy(bc::Broadcast.Broadcasted{<:VectorOfArrayStyle})
bc = Broadcast.flatten(bc)

parent = find_VoA_parent(bc.args)

if parent isa AbstractVector
u = if parent isa AbstractVector
# this is the default behavior in v3.15.0
N = narrays(bc)
return VectorOfArray(map(1:N) do i
map(1:N) do i
copy(unpack_voa(bc, i))
end)
end
else # if parent isa AbstractArray
return VectorOfArray(map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
map(enumerate(Iterators.product(axes(parent)...))) do (i, _)
copy(unpack_voa(bc, i))
end)
end
end
VectorOfArray(rewrap(parent, u))
end

rewrap(::Array,u) = u
rewrap(parent, u) = convert(typeof(parent), u)

for (type, N_expr) in [
(Broadcast.Broadcasted{<:VectorOfArrayStyle}, :(narrays(bc))),
(Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle}, :(length(dest.u)))
]
@eval @inline function Base.copyto!(dest::AbstractVectorOfArray,
bc::$type)
@show typeof(dest)
error()
bc = Broadcast.flatten(bc)
N = $N_expr
@inbounds for i in 1:N
Expand Down
7 changes: 7 additions & 0 deletions test/copy_static_array_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,10 @@ a_voa = VectorOfArray(a)
a_voa .= 1.0
@test a_voa[1] == SVector(1.0, 1.0)
@test a_voa[2] == SVector(1.0, 1.0)

#Broadcast Copy of StructArray
x = StructArray{SVector{2, Float64}}((randn(2), randn(2)))
vx = VectorOfArray(x)
vx2 = copy(vx) .+ 1
ans = vx .+ vx2
@test ans.u isa StructArray

0 comments on commit 83c990a

Please sign in to comment.