# Patch summary (commentary placed before the `diff --git` header).
#
# Target file: ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl
# (blob 153de32e0 -> a55837dba, mode 100644). Single hunk: @@ -10,18 +10,18 @@.
#
# The hunk edits the method
#     ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...)
# Every removed (`-`) line is a commented-out statement and every added (`+`)
# line is the same statement with the comment marker removed, so the patch
# turns a previously disabled branch back on.
#
# Method body after the patch, as far as the hunk shows:
#   * `kab = get_backend(dx)` is queried first.
#   * If `KA.supports_atomics(kab)` is true:
#       - `gids  = GPUArrays.to_indices(dx, inds)`
#       - `idims = map(length, gids)`
#       - `Is    = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids)`
#       - `scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy))`
#     (presumably an atomic scatter-add of `dy` into `dx` -- `scatter!` is not
#     defined in this hunk; TODO confirm against its definition).
#   * Otherwise (the path that was previously unconditional): `dx`, `inds` and
#     `dy` are each passed through `Adapt.adapt(Array, ...)`, `dy` is added with
#     `.+=` into `view(dx_cpu, inds...)`, and the result is written back with
#     `copyto!(dx, dx_cpu)`.
#   * `dx` is returned in both cases.
#
# NOTE(review): `KA`, `GPUArrays`, `Adapt` and `scatter!` are referenced but not
# defined in the visible hunk (target-file lines 1-9 are outside it) -- confirm
# they are in scope in the extension module before relying on the new branch.
# NOTE(review): the patch text below sits on one physical line; its original
# line breaks and indentation are not recoverable from here -- verify against
# the upstream patch before applying.
diff --git a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl index 153de32e0..a55837dba 100644 --- a/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl +++ b/ext/ChainRulesKernelAbstractionsExt/ChainRulesKernelAbstractionsExt.jl @@ -10,18 +10,18 @@ using GPUArraysCore: AbstractGPUArray using KernelAbstractions function ChainRules.∇getindex!(dx::AbstractGPUArray, dy, inds...) - # kab = get_backend(dx) - - # if KA.supports_atomics(kab) - # gids = GPUArrays.to_indices(dx, inds) - # idims = map(length, gids) - # Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids) - # scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy)) - # else + kab = get_backend(dx) + + if KA.supports_atomics(kab) + gids = GPUArrays.to_indices(dx, inds) + idims = map(length, gids) + Is = map(Adapt.adapt(GPUArrays.ToGPU(dy)), gids) + scatter!(kab)(+, dx, dy, idims, Is...; ndrange=length(dy)) + else dx_cpu = Adapt.adapt(Array, dx) view(dx_cpu, Adapt.adapt(Array, inds)...) .+= Adapt.adapt(Array, dy) copyto!(dx, dx_cpu) - # end + end return dx end