Skip to content

Commit

Permalink
Merge pull request #217 from JuliaDiff/ox/nondiff_kw
Browse files Browse the repository at this point in the history
Add kwarg support to at-nondifferentiable
  • Loading branch information
oxinabox authored Sep 11, 2020
2 parents dc7e159 + 185e518 commit 39f1caf
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.8"
version = "0.9.9"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Expand Down
28 changes: 14 additions & 14 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ macro non_differentiable(sig_expr)
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]

unconstrained_args = _unconstrain.(constrained_args)
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)

primal_invoke = :($(primal_name)($(unconstrained_args...); kwargs...))

quote
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
Expand All @@ -304,28 +305,27 @@ macro non_differentiable(sig_expr)
end

function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
return Expr(
:(=),
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
# Julia functions always only have 1 output, so just return a single DoesNotExist()
Expr(:tuple, primal_invoke, DoesNotExist()),
)
return esc(:(
function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...)
# Julia functions always only have 1 output, so return a single DoesNotExist()
return ($primal_invoke, DoesNotExist())
end
))
end

function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
num_primal_inputs = length(primal_sig_parts) - 1
primal_name = first(primal_invoke.args)
pullback_expr = Expr(
:function,
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
Expr(:call, propagator_name(primal_name, :pullback), :_),
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
)
rrule_defn = Expr(
:(=),
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
Expr(:tuple, primal_invoke, pullback_expr),
)
return rrule_defn
return esc(:(
function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...)
return ($primal_invoke, $pullback_expr)
end
))
end


Expand Down
25 changes: 25 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,31 @@ end
@test rrule(pointy_identity, 2.0) == nothing
end

@testset "kwargs" begin
kw_demo(x; kw=2.0) = x + kw
@non_differentiable kw_demo(::Any)

@testset "not setting kw" begin
@assert kw_demo(1.5) == 3.5

res, pullback = rrule(kw_demo, 1.5)
@test res == 3.5
@test pullback(4.1) == (NO_FIELDS, DoesNotExist())

@test frule((Zero(), 11.1), kw_demo, 1.5) == (3.5, DoesNotExist())
end

@testset "setting kw" begin
@assert kw_demo(1.5; kw=3.0) == 4.5

res, pullback = rrule(kw_demo, 1.5; kw=3.0)
@test res == 4.5
@test pullback(1.1) == (NO_FIELDS, DoesNotExist())

@test frule((Zero(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, DoesNotExist())
end
end

@testset "Not supported (Yet)" begin
# Varargs are not supported
@test_macro_throws ErrorException @non_differentiable vararg1(xs...)
Expand Down

2 comments on commit 39f1caf

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/21249

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.9 -m "<description of version>" 39f1caf3e55c5dc5a5c85205aabcaca657bb2b0a
git push origin v0.9.9

Please sign in to comment.