-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nondifferentiable macro #207
Conversation
I'm not really clear what this means. Could you expand? |
Ah, typo. That should have read This relies on I mean julia functions only have 1 output. Anyway, if your primal type is an iterable then your differential type should be also. |
supporting "convert `x...` into `x::Vararg{Any}` and `(x::Int...)` into `x::Vararg{Int}`"
_constain_vararg(x::Symbol) = x
function _constain_vararg(expr::Expr, default_constraint)
Meta.isexpr(expr, :..., 1) || return expr
ret = _constrain_and_name(expr.args[1], default_constraint)
ret.args[2] = :(Vararg{$(ret.args[2])})
return ret
end We also need ot unconstrain it back into |
I'm totally fine with us kicking that can down the round |
This should be good to go |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is looking really good. Just style / documentation and testing related things that need addressing.
@@ -0,0 +1,65 @@ | |||
@testset "rule_definition_tools.jl" begin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like there's quite a lot of repeated code here. Did you consider writing a function to test that something has been successfully "non_differentiable"d?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Repeating code mades it read straight forward, and each is different enough that abstracting the tests would be make them harder to read.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear to me that that is true. Something like the following ought to do the majority of the work:
function test_nondifferentiable(foo, args, dargs, dy)
@test frule(dargs, foo, args...) == foo(args...)
y, pb = rrule(foo, args...)
@test y == foo(args...)
@test pb(dy) == (Zero(), map(_ -> DoesNotExist(), args)...)
end
To my mind this is more readable.
I'm not going to object to merging this over this though -- I'm happy to stick with what you've done if you feel strongly that it's more readable.
Also, we need a patch bump :) |
Co-authored-by: willtebbutt <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm happy with this.
This will appear on the API page (under Rule Definition Tools) -- Do we need/want to add it to anywhere else in the docs? (Also for some reason i don't see preview docs for this PR) Otherwise, all LGTM! |
I am happy enough with it being their for now. |
I will merge this once CI passes |
Closes #150
This relies on
DoesNotExist()
iterating to give itself, since we don't knopw how many returns a forward mode rule needs.Which i thought it did, but maybe we removed that at one point?
Its also useful for use with Composites.