Skip to content
This repository has been archived by the owner on Apr 18, 2023. It is now read-only.

handle AbstractZeros #203

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

handle AbstractZeros #203

wants to merge 7 commits into from

Conversation

mzgubic
Copy link

@mzgubic mzgubic commented Jul 15, 2021

Is this all that needs to be done?

helps with JuliaDiff/ChainRules.jl#442

src/core.jl Outdated Show resolved Hide resolved
src/core.jl Show resolved Hide resolved
test/core.jl Outdated
return y, myfunc_pullback
end

Nabla.generate_overload(Tuple{typeof(myfunc), Any, Any})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add a generic manually triggered refresh?
Does
ChainRulesOverloadGeneration.refresh_rules() work?
I think it should, it does have tests. (might not play so well with precompilation though, that i haven't tested)
https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable/api.html#ChainRulesOverloadGeneration.refresh_rules-Tuple{}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it, but that doesn't seem to work. I will investigate more closely tomorrow, we are not in a rush for this anyway

@mzgubic mzgubic changed the title propagate AbstractZeros gracefully handle AbstractZeros Jul 15, 2021
test/core.jl Show resolved Hide resolved
@@ -145,7 +145,11 @@ function propagate(y::Branch, rvs_tape::Tape)
kwargs = getfield(y, :kwargs)
xs = map(unbox, args)
xids = map(pos, args)

ȳ isa AbstractZero && return nothing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should think about introducing a function barrier here.
But I think profiling around that is out of scope for this.
but maybe we can move this to right after first is unpacked?


r_ = getindex(a_, myfunc(i_, j_))
∇r_ = ∇(r_)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't actually have any tests.
It might also be good to have ones that have many references to the variable, some of which are NoTangent and others are not

function foo(x)
    k = 1.0:1000.0
    i = round(Int, 10.0 * x)
    return k[i] * x + 2.0 * x
end
@test ∇(foo, 1.6) == 2.0

@@ -7,6 +7,9 @@
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/).
To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`.
See the [ChainRules project's documentation for more information](https://www.juliadiff.org/ChainRulesCore.jl/stable/).

After you define an `rrule`, e.g. for `myfunc(i, j)`, you also need to refresh the list of rule, `ChainRulesOverloadGeneration.refresh_rules()`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good improvement

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once tests are sorted merge when happy

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants