Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jun 21, 2024
1 parent e7bc827 commit 00cc8c3
Showing 1 changed file with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@ using GPUArraysCore: AbstractGPUArray
using KernelAbstractions

function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)
# kab = get_backend(dx)

# if KA.supports_atomics(kab)
# gids = GPUArrays.to_indices(dx, inds)
# idims = map(length, gids)
# Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)
# scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
# else
kab = get_backend(dx)

if KA.supports_atomics(kab)
gids = GPUArrays.to_indices(dx, inds)
idims = map(length, gids)
Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)
scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))
else
dx_cpu = Adapt.adapt(Array, dx)
view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy)
copyto!(dx, dx_cpu)
# end
end
return dx
end

Expand Down

0 comments on commit 00cc8c3

Please sign in to comment.