Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nondifferentiable macro #207

Merged
merged 11 commits into from
Sep 2, 2020
Merged

nondifferentiable macro #207

merged 11 commits into from
Sep 2, 2020

Conversation

oxinabox
Copy link
Member

@oxinabox oxinabox commented Aug 24, 2020

Closes #150

This relies on DoesNotExist() iterating to give itself, since we don't knopw how many returns a forward mode rule needs.
Which i thought it did, but maybe we removed that at one point?
Its also useful for use with Composites.

@willtebbutt
Copy link
Member

This relies on NonDifferentiable iterating to give itself, since we don't knopw how many returns a forward mode rule needs. Which i thought it did, but maybe we removed that at one point?

I'm not really clear what this means. Could you expand?

@oxinabox
Copy link
Member Author

This relies on NonDifferentiable iterating to give itself

Ah, typo. That should have read This relies on DoesNotExist() iterating to give itself.

I mean julia functions only have 1 output.
Sometimes that output is a Tuple (or other iterable).
The differential type for a Tuple is a Composte{Tuple...}, though I suspect some time for frules we have been a bit lax about enforcing that and sometimes just a matching Tuple is used.
We should check what we are doing for the sincos function.

Anyway, if your primal type is an iterable then your differential type should be also.
It should be fine for use to have a nondifferentiable version of say sincos
that has frule(_, ::typeof(nd_sincos), x) = ((nd_sin(x), nd_cos(x)), DoesNotExist())

@oxinabox
Copy link
Member Author

supporting varargs is hard.
Here is some code i have for it.
But i think for now i will just make it error.

"convert `x...` into `x::Vararg{Any}` and `(x::Int...)` into `x::Vararg{Int}`"
_constain_vararg(x::Symbol) = x
function _constain_vararg(expr::Expr, default_constraint)
    Meta.isexpr(expr, :..., 1) || return expr
    ret = _constrain_and_name(expr.args[1], default_constraint)
    ret.args[2] = :(Vararg{$(ret.args[2])})
    return ret
end

We also need ot unconstrain it back into x... for invoking the primal

@nickrobinson251
Copy link
Contributor

i think for now i will just make it error.

I'm totally fine with us kicking that can down the round

@oxinabox
Copy link
Member Author

This should be good to go

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking really good. Just style / documentation and testing related things that need addressing.

src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
src/rule_definition_tools.jl Outdated Show resolved Hide resolved
@@ -0,0 +1,65 @@
@testset "rule_definition_tools.jl" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there's quite a lot of repeated code here. Did you consider writing a function to test that something has been successfully "non_differentiable"d?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeating code mades it read straight forward, and each is different enough that abstracting the tests would be make them harder to read.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me that that is true. Something like the following ought to do the majority of the work:

function test_nondifferentiable(foo, args, dargs, dy)
    @test frule(dargs, foo, args...) == foo(args...)

    y, pb = rrule(foo, args...)
    @test y == foo(args...)
    @test pb(dy) == (Zero(), map(_ -> DoesNotExist(), args)...)
end

To my mind this is more readable.

I'm not going to object to merging this over this though -- I'm happy to stick with what you've done if you feel strongly that it's more readable.

test/rule_definition_tools.jl Outdated Show resolved Hide resolved
test/rule_definition_tools.jl Outdated Show resolved Hide resolved
test/rule_definition_tools.jl Outdated Show resolved Hide resolved
@willtebbutt
Copy link
Member

Also, we need a patch bump :)

Copy link
Member

@willtebbutt willtebbutt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm happy with this.

@nickrobinson251
Copy link
Contributor

This will appear on the API page (under Rule Definition Tools) -- Do we need/want to add it to anywhere else in the docs?

(Also for some reason i don't see preview docs for this PR)

Otherwise, all LGTM!

@oxinabox
Copy link
Member Author

oxinabox commented Sep 1, 2020

This will appear on the API page (under Rule Definition Tools) -- Do we need/want to add it to anywhere else in the docs?

I am happy enough with it being their for now.

@oxinabox
Copy link
Member Author

oxinabox commented Sep 1, 2020

I will merge this once CI passes

@oxinabox oxinabox merged commit 95e094e into master Sep 2, 2020
@oxinabox oxinabox deleted the ox/nondiff branch September 2, 2020 10:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

@non_differentiable
4 participants