From 7f35d079f67daffcfb877153c81a526743aa7cf2 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 24 Aug 2020 19:05:00 +0100 Subject: [PATCH 01/11] WIP outline nondifferentiable macro (untested) --- src/rule_definition_tools.jl | 54 ++++++++++++++++++++++++++++++----- test/rule_definition_tools.jl | 6 ++++ 2 files changed, 53 insertions(+), 7 deletions(-) create mode 100644 test/rule_definition_tools.jl diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 2675c7598..38ef8577b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -122,13 +122,8 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) end # Remove annotations and escape names for the call - for (i, arg) in enumerate(call.args) - if Meta.isexpr(arg, :(::)) - call.args[i] = esc(first(arg.args)) - else - call.args[i] = esc(arg) - end - end + call = _without_constraints(call) + call.args = esc.(call.args) # For consistency in code that follows we make all partials tuple expressions partials = map(partials) do partial @@ -143,6 +138,15 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) return call, setup_stmts, inputs, partials end +"turn `foo(a, b::S)` into `foo(a, b)`" +function _without_constraints(call_expr) + return Expr( + :call, + (Meta.isexpr(arg, :(::)) ? first(arg.args) : arg for arg in call_expr.args)... + ) +end + + function scalar_frule_expr(f, call, setup_stmts, inputs, partials) n_outputs = length(partials) n_inputs = length(inputs) @@ -258,3 +262,39 @@ This is able to deal with fairly complex expressions for `f`: propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname) propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) + + +macro @non_differentiable(call_expr) + Meta.isexpr(:call, call_expr) || error("Invalid use of `@non_differentiable`") + + primal_call = _without_constraints(call_expr) + primal_call.args = esc.(primal_call.args) + + # TODO Move to frule helper + frule_defn = Expr( + :(=), + Expr(:call, :(ChainRulesCore.frule), :_, call_expr.args...), + # How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that + # returns `DoesNotExist()` for every position. + Expr(:tuple, primal_call, DoesNotExist()) + ) + + # TODO Move to rrule helper + primal_name = first(primal_call.args) + pullback_expr = Expr( + :(=), + Expr(:call, propagator_name(primal_name, :pullback), :_), + Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in primal_call.args[2:end])...) + ) + rrule_defn = Expr( + :(=), + Expr(:call, :(ChainRulesCore.rrule), call_expr.args...), + Expr(:tuple, primal_call, pullback_expr), + ) + + quote + $frule_defn + $rrule_defn + end +end + diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl new file mode 100644 index 000000000..796215960 --- /dev/null +++ b/test/rule_definition_tools.jl @@ -0,0 +1,6 @@ +@testset "rule_definition_tools.jl" begin + + @testset "@nondifferentiable" begin + + end +end From fc426498536f93565ed46728515fb2ad740d864f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 25 Aug 2020 20:14:42 +0100 Subject: [PATCH 02/11] fix code generated --- src/ChainRulesCore.jl | 2 +- src/rule_definition_tools.jl | 53 ++++++++++++++++++++--------------- test/rule_definition_tools.jl | 5 ++++ 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index ca2b0a3ce..b16972e51 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,7 +4,7 @@ using MuladdMacro: @muladd export on_new_rule, refresh_rules # generation tools export frule, rrule # core function -export @scalar_rule, @thunk # definition helper macros +export @non_differentiable, @scalar_rule, @thunk # definition helper macros export canonicalize, extern, unthunk # differential operations # differentials export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 38ef8577b..35e0302a0 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -117,12 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) @assert Meta.isexpr(call, :call) # Annotate all arguments in the signature as scalars - inputs = map(call.args[2:end]) do arg - esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number)) - end + inputs = _constrain_and_name.(call.args[2:end], :Number) # Remove annotations and escape names for the call - call = _without_constraints(call) + call.args = _unconstrain.(call.args) call.args = esc.(call.args) # For consistency in code that follows we make all partials tuple expressions @@ -138,14 +136,20 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) return call, setup_stmts, inputs, partials end -"turn `foo(a, b::S)` into `foo(a, b)`" -function _without_constraints(call_expr) - return Expr( - :call, - (Meta.isexpr(arg, :(::)) ? first(arg.args) : arg for arg in call_expr.args)... - ) +"turn both `a` and `a::S` into `a`" +_unconstrain(arg::Symbol) = arg +function _unconstrain(arg::Expr) + Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint. + error("malformed arguments: $arg") end +"turn both `a` and `::Number` into `a::Number` into `a::Number` etc" +function _constrain_and_name(arg::Expr, default_constraint) + Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. + Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name + error("malformed arguments: $arg") +end +_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type function scalar_frule_expr(f, call, setup_stmts, inputs, partials) n_outputs = length(partials) @@ -264,32 +268,37 @@ propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) -macro @non_differentiable(call_expr) - Meta.isexpr(:call, call_expr) || error("Invalid use of `@non_differentiable`") +macro non_differentiable(call_expr) + Meta.isexpr(call_expr, :call) || error("Invalid use of `@non_differentiable`") + primal_name, orig_args = Iterators.peel(call_expr.args) - primal_call = _without_constraints(call_expr) - primal_call.args = esc.(primal_call.args) + constrained_args = _constrain_and_name.(orig_args, :Any) + unconstrained_args = _unconstrain.(constrained_args) + primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...) + + + primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] # TODO Move to frule helper frule_defn = Expr( :(=), - Expr(:call, :(ChainRulesCore.frule), :_, call_expr.args...), + Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...), # How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that # returns `DoesNotExist()` for every position. - Expr(:tuple, primal_call, DoesNotExist()) + Expr(:tuple, primal_invoke, DoesNotExist()) ) # TODO Move to rrule helper - primal_name = first(primal_call.args) + pullback_expr = Expr( - :(=), - Expr(:call, propagator_name(primal_name, :pullback), :_), - Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in primal_call.args[2:end])...) + :function, + Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)), + Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in constrained_args)...) ) rrule_defn = Expr( :(=), - Expr(:call, :(ChainRulesCore.rrule), call_expr.args...), - Expr(:tuple, primal_call, pullback_expr), + Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...), + Expr(:tuple, primal_invoke, pullback_expr), ) quote diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 796215960..152a97f39 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -4,3 +4,8 @@ end end + + +Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO)) + +@non_differentiable println(io::IO) \ No newline at end of file From 0d5b1d8077a6ae5ca7672ada70cffa8b70004d9b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 26 Aug 2020 14:47:35 +0100 Subject: [PATCH 03/11] Finish testing and cleaning code on non_differentiable macro --- src/rule_definition_tools.jl | 50 +++++++++++++++++------------------ test/rule_definition_tools.jl | 39 +++++++++++++++++++++++---- 2 files changed, 58 insertions(+), 31 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 35e0302a0..7d4baed21 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -117,10 +117,9 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) @assert Meta.isexpr(call, :call) # Annotate all arguments in the signature as scalars - inputs = _constrain_and_name.(call.args[2:end], :Number) - + inputs = esc.(_constrain_and_name.(call.args[2:end], :Number)) # Remove annotations and escape names for the call - call.args = _unconstrain.(call.args) + call.args[2:end] .= _unconstrain.(call.args[2:end]) call.args = esc.(call.args) # For consistency in code that follows we make all partials tuple expressions @@ -186,7 +185,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) # Δs is the input to the propagator rule # because this is a pull-back there is one per output of function - Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs] + Δs = [Symbol(:Δ, i) for i in 1:n_outputs] # 1 partial derivative per input pullback_returns = map(1:n_inputs) do input_i @@ -197,7 +196,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) # Multi-output functions have pullbacks with a tuple input that will be destructured pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...) pullback = quote - function $(propagator_name(f, :pullback))($pullback_input) + function $(esc(propagator_name(f, :pullback)))($pullback_input) return (NO_FIELDS, $(pullback_returns...)) end end @@ -223,16 +222,14 @@ function propagation_expr(Δs, ∂s, _conj = false) ∂s = map(esc, ∂s) n∂s = length(∂s) - # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression - # literals. + # Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals. ∂_mul_Δs = if _conj ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s) else ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s) end - # Avoiding the extra `+` operation, it is potentially expensive for vector - # mode AD. + # Avoiding the extra `+` operation, it is potentially expensive for vector mode AD. sumed_∂_mul_Δs = if n∂s > 1 # we use `@.` to broadcast `*` and `+` :(@. +($(∂_mul_Δs...))) @@ -273,37 +270,38 @@ macro non_differentiable(call_expr) primal_name, orig_args = Iterators.peel(call_expr.args) constrained_args = _constrain_and_name.(orig_args, :Any) + primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] + unconstrained_args = _unconstrain.(constrained_args) primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...) - - - primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] + + quote + $(_nondiff_frule_expr(primal_sig_parts, primal_invoke)) + $(_nondiff_rrule_expr(primal_sig_parts, primal_invoke)) + end +end - # TODO Move to frule helper - frule_defn = Expr( +function _nondiff_frule_expr(primal_sig_parts, primal_invoke) + return Expr( :(=), Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...), - # How many outputs we have it doesn't matter: `DoesNotExist()` is a iterator that - # returns `DoesNotExist()` for every position. + # Julia functions always only have 1 output, so just return a single DoesNotExist() Expr(:tuple, primal_invoke, DoesNotExist()) ) +end - # TODO Move to rrule helper - +function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) + num_primal_inputs = length(primal_sig_parts) - 1 + primal_name = first(primal_invoke.args) pullback_expr = Expr( :function, Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)), - Expr(:tuple, NO_FIELDS, (DoesNotExist() for _ in constrained_args)...) + Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) ) rrule_defn = Expr( :(=), Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...), Expr(:tuple, primal_invoke, pullback_expr), ) - - quote - $frule_defn - $rrule_defn - end -end - + return rrule_defn +end \ No newline at end of file diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 152a97f39..2e060e0e8 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -1,11 +1,40 @@ @testset "rule_definition_tools.jl" begin - @testset "@nondifferentiable" begin + @testset "@non_differentiable" begin + @testset "nondiff_2_1" begin + nondiff_2_1(x, y) = fill(7.5, 100)[x + y] + @non_differentiable nondiff_2_1(::Any, ::Any) + @test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist()) + res, pullback = rrule(nondiff_2_1, 3, 2) + @test res == 7.5 + @test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist()) + end - end -end + @testset "nondiff_1_2" begin + nondiff_1_2(x) = (5.0, 3.0) + @non_differentiable nondiff_1_2(::Any) + @test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist()) + res, pullback = rrule(nondiff_1_2, 3.1) + @test res == (5.0, 3.0) + @test isequal( + pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)), + (NO_FIELDS, DoesNotExist()), + ) + end + + @testset "specific signature" begin + nonembed_identity(x) = x + @non_differentiable nonembed_identity(::Integer) + @test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist()) + @test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing -Base.remove_linenums!(@macroexpand @non_differentiable println(io::IO)) + res, pullback = rrule(nonembed_identity, 2) + @test res == 2 + @test pullback(1.2) == (NO_FIELDS, DoesNotExist()) + + @test rrule(nonembed_identity, 2.0) == nothing + end + end +end -@non_differentiable println(io::IO) \ No newline at end of file From 15577d311f2d86c5fd3d080296b7602539089ce2 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 28 Aug 2020 16:24:32 +0100 Subject: [PATCH 04/11] Document and finish --- src/rule_definition_tools.jl | 121 +++++++++++++++++++++++++++++----- test/rule_definition_tools.jl | 27 +++++++- 2 files changed, 129 insertions(+), 19 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 7d4baed21..fa96452e4 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -135,20 +135,6 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials) return call, setup_stmts, inputs, partials end -"turn both `a` and `a::S` into `a`" -_unconstrain(arg::Symbol) = arg -function _unconstrain(arg::Expr) - Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint. - error("malformed arguments: $arg") -end - -"turn both `a` and `::Number` into `a::Number` into `a::Number` etc" -function _constrain_and_name(arg::Expr, default_constraint) - Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. - Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name - error("malformed arguments: $arg") -end -_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type function scalar_frule_expr(f, call, setup_stmts, inputs, partials) n_outputs = length(partials) @@ -264,10 +250,47 @@ propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propna propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname) propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname) +""" + @non_differentiable(signature_expression) + +A helper to make it easier to declare that a method is not not differentiable. +This is a short-hand for defining a [`frule`](@ref) and an [`rrule`](@ref)+ pullback that +returns [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial +itself which is `NO_FIELDS`) -macro non_differentiable(call_expr) - Meta.isexpr(call_expr, :call) || error("Invalid use of `@non_differentiable`") - primal_name, orig_args = Iterators.peel(call_expr.args) +The usage is to put the macro before a function signature. +Keyword arguments should not be included. + +```jldoctest +julia> @non_differentiable Base.:(==)(a, b) + +julia> _, pullback = rrule(==, 2.0, 3.0); + +julia> pullback(1.0) +(Zero(), DoesNotExist(), DoesNotExist()) +``` + +You can place type-constraints in the signature: +```jldoctest +julia> @non_differentiable Base.length(xs::Union{Number, Array}) + +julia> frule((Zero(), 1), length, [2.0, 3.0]) +(2, DoesNotExist()) +``` + +!!! warning + This helper macro covers only the simple common cases. + It does not support Varargs, or `where`-clauses. + For these you can declare the `rrule` and `frule` directly + +""" +macro non_differentiable(sig_expr) + Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`") + for arg in sig_expr.args + _isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg") + end + + primal_name, orig_args = Iterators.peel(sig_expr.args) constrained_args = _constrain_and_name.(orig_args, :Any) primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] @@ -304,4 +327,66 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) Expr(:tuple, primal_invoke, pullback_expr), ) return rrule_defn -end \ No newline at end of file +end + + +########### +# Helpers + +""" + _isvararg(expr) + +returns true if the expression could represent a vararg + +```jldoctest +julia> ChainRulesCore._isvararg(:(x...)) +true + +julia> ChainRulesCore._isvararg(:(x::Int...)) +true + +julia> ChainRulesCore._isvararg(:(::Int...)) +true + +julia> ChainRulesCore._isvararg(:(x::Vararg)) +true + +julia> ChainRulesCore._isvararg(:(x::Vararg{Int})) +true + +julia> ChainRulesCore._isvararg(:(::Vararg)) +true + +julia> ChainRulesCore._isvararg(:(::Vararg{Int})) +true + +julia> ChainRulesCore._isvararg(:(x)) +false +```` +""" +_isvararg(expr) = false +function _isvararg(expr::Expr) + Meta.isexpr(expr, :...) && return true + if Meta.isexpr(expr, :(::)) + constraint = last(expr.args) + constraint == :Vararg && return true + Meta.isexpr(constraint, :curly) && first(constraint.args) == :Vararg && return true + end + return false +end + + +"turn both `a` and `a::S` into `a`" +_unconstrain(arg::Symbol) = arg +function _unconstrain(arg::Expr) + Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint. + error("malformed arguments: $arg") +end + +"turn both `a` and `::Number` into `a::Number` into `a::Number` etc" +function _constrain_and_name(arg::Expr, default_constraint) + Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. + Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name + error("malformed arguments: $arg") +end +_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type \ No newline at end of file diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 2e060e0e8..3e444dbd7 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -1,5 +1,4 @@ @testset "rule_definition_tools.jl" begin - @testset "@non_differentiable" begin @testset "nondiff_2_1" begin nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @@ -35,6 +34,32 @@ @test rrule(nonembed_identity, 2.0) == nothing end + + @testset "Pointy UnionAll constraints" begin + pointy_identity(x) = x + @non_differentiable pointy_identity(::Vector{<:AbstractString}) + + @test frule((Zero(), 1.2), pointy_identity, ["2"]) == (["2"], DoesNotExist()) + @test frule((Zero(), 1.2), pointy_identity, 2.0) == nothing + + res, pullback = rrule(pointy_identity, ["2"]) + @test res == ["2"] + @test pullback(1.2) == (NO_FIELDS, DoesNotExist()) + + @test rrule(pointy_identity, 2.0) == nothing + end + + @testset "Not supported (Yet)" begin + # Varargs are not supported + @test_throws Exception @macroexpand(@non_differentiable vararg1(xs...))| + @test_throws Exception @macroexpand(@non_differentiable vararg1(xs::Vararg)) + + # Where clauses are not supported. + @test_throws Exception @macroexpand( + @non_differentiable where_identity(::Vector{T}) where T<:AbstractString + ) + end + end end From 88459bd5def2ade1cc42df48a36d65cdd6263b3b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 28 Aug 2020 18:41:48 +0100 Subject: [PATCH 05/11] Apply suggestions from code review Co-authored-by: willtebbutt --- src/rule_definition_tools.jl | 15 +++++++-------- test/rule_definition_tools.jl | 3 +-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index fa96452e4..ab05b3a41 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -254,11 +254,10 @@ propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.valu @non_differentiable(signature_expression) A helper to make it easier to declare that a method is not not differentiable. -This is a short-hand for defining a [`frule`](@ref) and an [`rrule`](@ref)+ pullback that -returns [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial +This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that +return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial itself which is `NO_FIELDS`) -The usage is to put the macro before a function signature. Keyword arguments should not be included. ```jldoctest @@ -309,7 +308,7 @@ function _nondiff_frule_expr(primal_sig_parts, primal_invoke) :(=), Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...), # Julia functions always only have 1 output, so just return a single DoesNotExist() - Expr(:tuple, primal_invoke, DoesNotExist()) + Expr(:tuple, primal_invoke, DoesNotExist()), ) end @@ -379,14 +378,14 @@ end "turn both `a` and `a::S` into `a`" _unconstrain(arg::Symbol) = arg function _unconstrain(arg::Expr) - Meta.isexpr(arg, :(::), 2) && return arg.args[1] # dop constraint. + Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint. error("malformed arguments: $arg") end -"turn both `a` and `::Number` into `a::Number` into `a::Number` etc" -function _constrain_and_name(arg::Expr, default_constraint) +"turn both `a` and `::constraint` into `a::constraint` etc" +function _constrain_and_name(arg::Expr, _) Meta.isexpr(arg, :(::), 2) && return arg # it is already fine. Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name error("malformed arguments: $arg") end -_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type \ No newline at end of file +_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 3e444dbd7..433c9c742 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -21,7 +21,7 @@ ) end - @testset "specific signature" begin + @testset "constrained signature" begin nonembed_identity(x) = x @non_differentiable nonembed_identity(::Integer) @@ -62,4 +62,3 @@ end end - From 083518762a6109bed1ef0cf426c01cfccfcb5316 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 31 Aug 2020 18:44:24 +0100 Subject: [PATCH 06/11] Update test/rule_definition_tools.jl --- test/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 433c9c742..eec3bca53 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -1,6 +1,6 @@ @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin - @testset "nondiff_2_1" begin + @testset "two input one output function" begin nondiff_2_1(x, y) = fill(7.5, 100)[x + y] @non_differentiable nondiff_2_1(::Any, ::Any) @test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist()) From 125250918ab95fb1ee21099e85e06329d5950c66 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 31 Aug 2020 18:44:49 +0100 Subject: [PATCH 07/11] Update test/rule_definition_tools.jl --- test/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index eec3bca53..909ccdd78 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -9,7 +9,7 @@ @test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist()) end - @testset "nondiff_1_2" begin + @testset "one input, 2-tuple output function" begin nondiff_1_2(x) = (5.0, 3.0) @non_differentiable nondiff_1_2(::Any) @test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist()) From 41886592b6ced326ba2f5ae200b3546e99611269 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 1 Sep 2020 14:12:35 +0100 Subject: [PATCH 08/11] Include the new tests in runtests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index 8f995b354..306f846d2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,7 @@ using Test include("ruleset_loading.jl") include("rules.jl") + include("rule_definition_tools.jl") @testset "demos" begin From bb2d4638993750bf376ab889480d1daf56d46067 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 1 Sep 2020 14:12:49 +0100 Subject: [PATCH 09/11] Improve testing of not supported cases --- test/rule_definition_tools.jl | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 909ccdd78..601200f0e 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -1,3 +1,27 @@ +""" +Along same lines as `@test_throws` but to test if a macro throw an exception when it is +expanded. +""" +macro test_macro_throws(err_expr, expr) + quote + err = nothing + try + @macroexpand($(esc(expr))) + catch load_err + # all errors thrown at macro expansion time are LoadErrors, we need to unwrap + @assert load_err isa LoadError + err = load_err.error + end + # Reuse `@test_throws` logic + if err!==nothing + @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) + else + @test_throws $(esc(err_expr)) $(Meta.quot(expr)) + end + end +end + + @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin @testset "two input one output function" begin @@ -51,14 +75,14 @@ @testset "Not supported (Yet)" begin # Varargs are not supported - @test_throws Exception @macroexpand(@non_differentiable vararg1(xs...))| - @test_throws Exception @macroexpand(@non_differentiable vararg1(xs::Vararg)) + @test_macro_throws ErrorException @non_differentiable vararg1(xs...) + @test_macro_throws ErrorException @non_differentiable vararg1(xs::Vararg) # Where clauses are not supported. - @test_throws Exception @macroexpand( - @non_differentiable where_identity(::Vector{T}) where T<:AbstractString + @test_macro_throws( + ErrorException, + (@non_differentiable where_identity(::Vector{T}) where T<:AbstractString) ) end - end end From 00b959de87637c8ab135209ccab558605ef648c3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 1 Sep 2020 14:23:49 +0100 Subject: [PATCH 10/11] strip whitespace --- src/rule_definition_tools.jl | 6 +++--- test/rule_definition_tools.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index ab05b3a41..b686ffc0d 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -254,7 +254,7 @@ propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.valu @non_differentiable(signature_expression) A helper to make it easier to declare that a method is not not differentiable. -This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that +This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial itself which is `NO_FIELDS`) @@ -296,7 +296,7 @@ macro non_differentiable(sig_expr) unconstrained_args = _unconstrain.(constrained_args) primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...) - + quote $(_nondiff_frule_expr(primal_sig_parts, primal_invoke)) $(_nondiff_rrule_expr(primal_sig_parts, primal_invoke)) @@ -330,7 +330,7 @@ end ########### -# Helpers +# Helpers """ _isvararg(expr) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 601200f0e..7a05cbfec 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -1,5 +1,5 @@ """ -Along same lines as `@test_throws` but to test if a macro throw an exception when it is +Along same lines as `@test_throws` but to test if a macro throw an exception when it is expanded. """ macro test_macro_throws(err_expr, expr) @@ -17,7 +17,7 @@ macro test_macro_throws(err_expr, expr) @test_throws $(esc(err_expr)) ($(Meta.quot(expr)); throw(err)) else @test_throws $(esc(err_expr)) $(Meta.quot(expr)) - end + end end end @@ -72,7 +72,7 @@ end @test rrule(pointy_identity, 2.0) == nothing end - + @testset "Not supported (Yet)" begin # Varargs are not supported @test_macro_throws ErrorException @non_differentiable vararg1(xs...) From 9ab6955cf0f02459a8ba8ef2fb360cd08a88f4c3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 1 Sep 2020 14:24:05 +0100 Subject: [PATCH 11/11] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0d786877d..e586ea3ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.6" +version = "0.9.7" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"