I am incorporating some simple non-standard elements into a larger DNN.
To achieve this, I had to manually define some reverse rules using ChainRulesCore's `rrule` syntax, `ChainRulesCore.rrule(::typeof(f), args...)`.
This works well on the CPU, but for some reason it does not carry over to the GPU.
I narrowed it down to an MWE in which GPU compilation of the pullback fails:
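```julia
using Zygote
using CUDA

# Elementwise identity written with `map` and a do-block
f(A) = map(A) do a
    a
end

A = rand(Float32, 128, 128)
gA = cu(A)                # move to the GPU (`gpu` is Flux's helper and is not exported by Zygote or CUDA)

f(A)                      # works
f(gA)                     # works
y, pb = pullback(f, A)    # works; returns the primal output and the pullback closure
y, pb = pullback(f, gA)   # fails (GPU compilation error)
```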
Any insights?
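For reference, a custom reverse rule of the kind described above has roughly the following shape. This is a minimal CPU-only sketch, not the author's actual code: `myscale` is a hypothetical stand-in for one of the non-standard elements. The pullback is written with whole-array broadcasts and reductions rather than scalar loops; the assumption (not verified here) is that this style is easier to keep compatible with `CuArray` inputs, since scalar indexing of GPU arrays is disallowed by default in CUDA.jl.

```julia
using ChainRulesCore
using Zygote

# Hypothetical stand-in for a "non-standard element": scale an array by a scalar.
myscale(x::AbstractArray, s::Real) = s .* x

# Custom reverse rule: return the primal result together with a pullback closure.
function ChainRulesCore.rrule(::typeof(myscale), x::AbstractArray, s::Real)
    y = myscale(x, s)
    function myscale_pullback(dy)
        dy = unthunk(dy)             # materialize a possibly lazy cotangent
        dx = s .* dy                 # cotangent w.r.t. x (whole-array broadcast, no scalar indexing)
        ds = sum(dy .* x)            # cotangent w.r.t. the scalar s
        return NoTangent(), dx, ds   # NoTangent() for the function object itself
    end
    return y, myscale_pullback
end

# Zygote picks up ChainRulesCore rrules automatically.
x = Float32[1, 2, 3]
y, back = Zygote.pullback(myscale, x, 2f0)
dx, ds = back(ones(Float32, 3))
```

Returning `NoTangent()` as the first element reflects that `myscale` itself has no differentiable fields; Zygote drops that entry, so `back` yields one cotangent per argument.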