Skip to content

Commit

Permalink
Merge pull request #308 from SciML/recursivecopy_simplification
Browse files Browse the repository at this point in the history
Simplify and better test recursivecopy!
  • Loading branch information
ChrisRackauckas authored Dec 21, 2023
2 parents 4dc8305 + 777311c commit dd5c756
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecursiveArrayTools"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
authors = ["Chris Rackauckas <[email protected]>"]
version = "3.2.1"
version = "3.2.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
58 changes: 33 additions & 25 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,39 +42,47 @@ like `copy!` on arrays of scalars.
"""
function recursivecopy! end

for type in [AbstractArray, AbstractVectorOfArray]
@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: StaticArraysCore.StaticArray,
T2 <: StaticArraysCore.StaticArray,
N}
@inbounds for i in eachindex(a)
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[i] = copy(a[i])
end
function recursivecopy!(b::AbstractArray{T, N}, a::AbstractArray{T2, N}) where {T <: StaticArraysCore.StaticArray,
T2 <: StaticArraysCore.StaticArray,
N}
@inbounds for i in eachindex(a)
# TODO: Check for `setindex!`` and use `copy!(b[i],a[i])` or `b[i] = a[i]`, see #19
b[i] = copy(a[i])
end
end

@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: Enum, T2 <: Enum, N}
copyto!(b, a)
end
function recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {T <: Enum, T2 <: Enum, N}
copyto!(b, a)
end

function recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {T <: Number, T2 <: Number, N}
copyto!(b, a)
end

@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: Number, T2 <: Number, N}
function recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray},
T2 <: Union{AbstractArray, AbstractVectorOfArray}, N}
if ArrayInterface.ismutable(T)
@inbounds for i in eachindex(b, a)
recursivecopy!(b[i], a[i])
end
else
copyto!(b, a)
end
return b
end

@eval function recursivecopy!(b::$type{T, N},
a::$type{T2, N}) where {T <: Union{AbstractArray, AbstractVectorOfArray},
T2 <: Union{AbstractArray, AbstractVectorOfArray}, N}
if ArrayInterface.ismutable(T)
@inbounds for i in eachindex(b, a)
recursivecopy!(b[i], a[i])
end
else
copyto!(b, a)
function recursivecopy!(b::AbstractVectorOfArray, a::AbstractVectorOfArray)
if ArrayInterface.ismutable(eltype(b.u))
@inbounds for i in eachindex(b.u, a.u)
recursivecopy!(b.u[i], a.u[i])
end
return b
else
copyto!(b.u, a.u)
end
return b
end

"""
Expand Down
8 changes: 8 additions & 0 deletions test/utils_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,12 @@ end
@test u1.u[2] == [2.0,2.0]
@test u1.u[1] isa MVector
@test u1.u[2] isa MVector

u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})])
u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})])
recursivecopy!(u1,u2)
@test u1.u[1] == [4.0,4.0]
@test u1.u[2] == [2.0,2.0]
@test u1.u[1] isa SVector
@test u1.u[2] isa SVector
end

0 comments on commit dd5c756

Please sign in to comment.