-
Notifications
You must be signed in to change notification settings - Fork 5
handle AbstractZeros #203
base: master
Are you sure you want to change the base?
handle AbstractZeros #203
Conversation
test/core.jl
Outdated
return y, myfunc_pullback | ||
end | ||
|
||
Nabla.generate_overload(Tuple{typeof(myfunc), Any, Any}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add a generic manually triggered refresh?
Does
ChainRulesOverloadGeneration.refresh_rules()
work?
I think it should, it does have tests. (might not play so well with precompilation though, that i haven't tested)
https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable/api.html#ChainRulesOverloadGeneration.refresh_rules-Tuple{}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tried it, but that doesn't seem to work. I will investigate more closely tomorrow, we are not in a rush for this anyway
@@ -145,7 +145,11 @@ function propagate(y::Branch, rvs_tape::Tape) | |||
kwargs = getfield(y, :kwargs) | |||
xs = map(unbox, args) | |||
xids = map(pos, args) | |||
|
|||
ȳ isa AbstractZero && return nothing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should think about introducing a function barrier here.
But I think profiling around that is out of scope for this.
but maybe we can move this to right after ȳ
first is unpacked?
|
||
r_ = getindex(a_, myfunc(i_, j_)) | ||
∇r_ = ∇(r_) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't actually have any tests.
It might also be good to have ones that have many references to the variable, some of which are NoTangent and others are not
function foo(x)
k = 1.0:1000.0
i = round(Int, 10.0 * x)
return k[i] * x + 2.0 * x
end
@test ∇(foo, 1.6) == 2.0
@@ -7,6 +7,9 @@ | |||
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/). | |||
To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`. | |||
See the [ChainRules project's documentation for more information](https://www.juliadiff.org/ChainRulesCore.jl/stable/). | |||
|
|||
After you define an `rrule`, e.g. for `myfunc(i, j)`, you also need to refresh the list of rule, `ChainRulesOverloadGeneration.refresh_rules()`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good improvement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
once tests are sorted merge when happy
Is this all that needs to be done?
helps with JuliaDiff/ChainRules.jl#442