Skip to content

Commit

Permalink
fix: GPU tests, CuArray conversion, autodiff
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 28, 2023
1 parent 87ef7d5 commit c1249b3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,17 @@ end
VectorOfArray(u),
y -> begin
y isa Ref && (y = VectorOfArray(y[].u))
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
for i in 1:size(y.u)[end]]),)
(VectorOfArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]

Check warning on line 98 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L98

Added line #L98 was not covered by tests
for i in 1:size(y)[end]]),)
end
end

@adjoint function DiffEqArray(u, t)
DiffEqArray(u, t),
y -> begin
y isa Ref && (y = VectorOfArray(y[].u))
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y.u) - 1)..., i]
for i in 1:size(y.u)[end]],
(DiffEqArray([y[ntuple(x -> Colon(), ndims(y) - 1)..., i]

Check warning on line 107 in ext/RecursiveArrayToolsZygoteExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsZygoteExt.jl#L107

Added line #L107 was not covered by tests
for i in 1:size(y)[end]],
t), nothing)
end
end
Expand Down
1 change: 1 addition & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end

import GPUArraysCore
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray) = T(Array(VA))

Check warning on line 31 in src/RecursiveArrayTools.jl

View check run for this annotation

Codecov / codecov/patch

src/RecursiveArrayTools.jl#L31

Added line #L31 was not covered by tests

import Requires
@static if !isdefined(Base, :get_extension)
Expand Down

0 comments on commit c1249b3

Please sign in to comment.