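With
```julia
using ChainRulesCore, Zygote
struct Foo <: AbstractArray{Int,0} end
```
running
```julia
# The primal function must never run if Zygote uses the rrule below
myfunc(x::Foo) = throw(ErrorException("This should not be called"))
# rrule returns the primal result 99 and a pullback that yields gradient 42
ChainRulesCore.rrule(::typeof(myfunc), x::Foo) = (99, _ -> (NO_FIELDS, 42))
Zygote.gradient(myfunc, Foo())
```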
yields `(42,)`, as expected: `myfunc` is bypassed and never called. However, doing the same with `Base.getindex`
```julia
# Zero-argument indexing, i.e. x[], should be handled by the rrule below
Base.getindex(x::Foo) = throw(ErrorException("This should not be called"))
ChainRulesCore.rrule(::typeof(Base.getindex), x::Foo) = (99, _ -> (NO_FIELDS, 42))
Zygote.gradient(Base.getindex, Foo())
```
throws
```
ERROR: This should not be called
```
so Zygote seems to ignore that `rrule` and call `Base.getindex` instead.
This only happens when `Foo` is defined as
```julia
struct Foo <: AbstractArray{Int,0} end
```
When using
```julia
struct Foo end
```
Zygote does not call `Base.getindex`.
Also, keeping `Foo <: AbstractArray{Int,0}` but using `ZygoteRules.@adjoint` instead of `ChainRulesCore.rrule`
```julia
using Zygote, ZygoteRules
struct Foo <: AbstractArray{Int,0} end
Base.getindex(x::Foo) = throw(ErrorException("This should not be called"))
# Zygote-native adjoint: primal result 99, pullback returns gradient 42
ZygoteRules.@adjoint Base.getindex(x::Foo) = (99, _ -> (42,))
Zygote.gradient(Base.getindex, Foo())
```
works as expected.
oschulz changed the title from "ChainRulesCore.rrule for Base.getindex seems to be ignored with Zygote" to "Zygote ignores ChainRulesCore.rrule for Base.getindex on AbstractArray types" on Oct 21, 2020.