Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nondifferentiable macro #207

Merged
merged 11 commits into from
Sep 2, 2020
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MuladdMacro: @muladd

export on_new_rule, refresh_rules # generation tools
export frule, rrule # core function
export @scalar_rule, @thunk # definition helper macros
export @non_differentiable, @scalar_rule, @thunk # definition helper macros
export canonicalize, extern, unthunk # differential operations
# differentials
export Composite, DoesNotExist, InplaceableThunk, One, Thunk, Zero, AbstractZero, AbstractThunk
Expand Down
165 changes: 148 additions & 17 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,10 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
@assert Meta.isexpr(call, :call)

# Annotate all arguments in the signature as scalars
inputs = map(call.args[2:end]) do arg
esc(Meta.isexpr(arg, :(::)) ? arg : Expr(:(::), arg, :Number))
end

inputs = esc.(_constrain_and_name.(call.args[2:end], :Number))
# Remove annotations and escape names for the call
for (i, arg) in enumerate(call.args)
if Meta.isexpr(arg, :(::))
call.args[i] = esc(first(arg.args))
else
call.args[i] = esc(arg)
end
end
call.args[2:end] .= _unconstrain.(call.args[2:end])
call.args = esc.(call.args)

# For consistency in code that follows we make all partials tuple expressions
partials = map(partials) do partial
Expand All @@ -143,6 +135,7 @@ function _normalize_scalarrules_macro_input(call, maybe_setup, partials)
return call, setup_stmts, inputs, partials
end


function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
n_outputs = length(partials)
n_inputs = length(inputs)
Expand Down Expand Up @@ -178,7 +171,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is a pull-back there is one per output of function
Δs = [Symbol(string(:Δ, i)) for i in 1:n_outputs]
Δs = [Symbol(:Δ, i) for i in 1:n_outputs]

# 1 partial derivative per input
pullback_returns = map(1:n_inputs) do input_i
Expand All @@ -189,7 +182,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials)
# Multi-output functions have pullbacks with a tuple input that will be destructured
pullback_input = n_outputs == 1 ? first(Δs) : Expr(:tuple, Δs...)
pullback = quote
function $(propagator_name(f, :pullback))($pullback_input)
function $(esc(propagator_name(f, :pullback)))($pullback_input)
return (NO_FIELDS, $(pullback_returns...))
end
end
Expand All @@ -215,16 +208,14 @@ function propagation_expr(Δs, ∂s, _conj = false)
∂s = map(esc, ∂s)
n∂s = length(∂s)

# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression
# literals.
# Due to bugs in Julia 1.0, we can't use `.+` or `.*` inside expression literals.
∂_mul_Δs = if _conj
ntuple(i->:(conj($(∂s[i])) * $(Δs[i])), n∂s)
else
ntuple(i->:($(∂s[i]) * $(Δs[i])), n∂s)
end

# Avoiding the extra `+` operation, it is potentially expensive for vector
# mode AD.
# Avoiding the extra `+` operation, it is potentially expensive for vector mode AD.
sumed_∂_mul_Δs = if n∂s > 1
# we use `@.` to broadcast `*` and `+`
:(@. +($(∂_mul_Δs...)))
Expand Down Expand Up @@ -258,3 +249,143 @@ This is able to deal with fairly complex expressions for `f`:
propagator_name(f::Expr, propname::Symbol) = propagator_name(f.args[end], propname)
propagator_name(fname::Symbol, propname::Symbol) = Symbol(fname, :_, propname)
propagator_name(fname::QuoteNode, propname::Symbol) = propagator_name(fname.value, propname)

"""
@non_differentiable(signature_expression)

A helper to make it easier to declare that a method is not not differentiable.
This is a short-hand for defining an [`frule`](@ref) and [`rrule`](@ref) that
return [`DoesNotExist()`](@ref) for all partials (except for the function `s̄elf`-partial
itself which is `NO_FIELDS`)

Keyword arguments should not be included.

```jldoctest
julia> @non_differentiable Base.:(==)(a, b)

julia> _, pullback = rrule(==, 2.0, 3.0);

julia> pullback(1.0)
(Zero(), DoesNotExist(), DoesNotExist())
```

You can place type-constraints in the signature:
```jldoctest
julia> @non_differentiable Base.length(xs::Union{Number, Array})

julia> frule((Zero(), 1), length, [2.0, 3.0])
(2, DoesNotExist())
```
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved

!!! warning
This helper macro covers only the simple common cases.
It does not support Varargs, or `where`-clauses.
For these you can declare the `rrule` and `frule` directly

"""
macro non_differentiable(sig_expr)
Meta.isexpr(sig_expr, :call) || error("Invalid use of `@non_differentiable`")
for arg in sig_expr.args
_isvararg(arg) && error("@non_differentiable does not support Varargs like: $arg")
end

primal_name, orig_args = Iterators.peel(sig_expr.args)

constrained_args = _constrain_and_name.(orig_args, :Any)
primal_sig_parts = [:(::typeof($primal_name)), constrained_args...]

unconstrained_args = _unconstrain.(constrained_args)
primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...)

quote
$(_nondiff_frule_expr(primal_sig_parts, primal_invoke))
$(_nondiff_rrule_expr(primal_sig_parts, primal_invoke))
end
end

function _nondiff_frule_expr(primal_sig_parts, primal_invoke)
return Expr(
:(=),
Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...),
# Julia functions always only have 1 output, so just return a single DoesNotExist()
Expr(:tuple, primal_invoke, DoesNotExist()),
)
end

function _nondiff_rrule_expr(primal_sig_parts, primal_invoke)
num_primal_inputs = length(primal_sig_parts) - 1
primal_name = first(primal_invoke.args)
pullback_expr = Expr(
:function,
Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)),
Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...)
)
rrule_defn = Expr(
:(=),
Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...),
Expr(:tuple, primal_invoke, pullback_expr),
)
return rrule_defn
end


###########
# Helpers

"""
_isvararg(expr)

returns true if the expression could represent a vararg

```jldoctest
julia> ChainRulesCore._isvararg(:(x...))
true

julia> ChainRulesCore._isvararg(:(x::Int...))
true

julia> ChainRulesCore._isvararg(:(::Int...))
true

julia> ChainRulesCore._isvararg(:(x::Vararg))
true

julia> ChainRulesCore._isvararg(:(x::Vararg{Int}))
true

julia> ChainRulesCore._isvararg(:(::Vararg))
true

julia> ChainRulesCore._isvararg(:(::Vararg{Int}))
true

julia> ChainRulesCore._isvararg(:(x))
false
````
"""
_isvararg(expr) = false
function _isvararg(expr::Expr)
Meta.isexpr(expr, :...) && return true
if Meta.isexpr(expr, :(::))
constraint = last(expr.args)
constraint == :Vararg && return true
Meta.isexpr(constraint, :curly) && first(constraint.args) == :Vararg && return true
end
return false
end


"turn both `a` and `a::S` into `a`"
_unconstrain(arg::Symbol) = arg
function _unconstrain(arg::Expr)
Meta.isexpr(arg, :(::), 2) && return arg.args[1] # drop constraint.
error("malformed arguments: $arg")
end

"turn both `a` and `::constraint` into `a::constraint` etc"
function _constrain_and_name(arg::Expr, _)
Meta.isexpr(arg, :(::), 2) && return arg # it is already fine.
Meta.isexpr(arg, :(::), 1) && return Expr(:(::), gensym(), arg.args[1]) #add name
error("malformed arguments: $arg")
end
_constrain_and_name(name::Symbol, constraint) = Expr(:(::), name, constraint) # add type
64 changes: 64 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
@testset "rule_definition_tools.jl" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there's quite a lot of repeated code here. Did you consider writing a function to test that something has been successfully "non_differentiable"d?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Repeating code mades it read straight forward, and each is different enough that abstracting the tests would be make them harder to read.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me that that is true. Something like the following ought to do the majority of the work:

function test_nondifferentiable(foo, args, dargs, dy)
    @test frule(dargs, foo, args...) == foo(args...)

    y, pb = rrule(foo, args...)
    @test y == foo(args...)
    @test pb(dy) == (Zero(), map(_ -> DoesNotExist(), args)...)
end

To my mind this is more readable.

I'm not going to object to merging this over this though -- I'm happy to stick with what you've done if you feel strongly that it's more readable.

@testset "@non_differentiable" begin
@testset "two input one output function" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
@test frule((Zero(), 1.2, 2.3), nondiff_2_1, 3, 2) == (7.5, DoesNotExist())
res, pullback = rrule(nondiff_2_1, 3, 2)
@test res == 7.5
@test pullback(4.5) == (NO_FIELDS, DoesNotExist(), DoesNotExist())
end

@testset "one input, 2-tuple output function" begin
nondiff_1_2(x) = (5.0, 3.0)
@non_differentiable nondiff_1_2(::Any)
@test frule((Zero(), 1.2), nondiff_1_2, 3.1) == ((5.0, 3.0), DoesNotExist())
res, pullback = rrule(nondiff_1_2, 3.1)
@test res == (5.0, 3.0)
@test isequal(
pullback(Composite{Tuple{Float64, Float64}}(1.2, 3.2)),
(NO_FIELDS, DoesNotExist()),
)
end

@testset "constrained signature" begin
nonembed_identity(x) = x
@non_differentiable nonembed_identity(::Integer)

@test frule((Zero(), 1.2), nonembed_identity, 2) == (2, DoesNotExist())
@test frule((Zero(), 1.2), nonembed_identity, 2.0) == nothing

res, pullback = rrule(nonembed_identity, 2)
@test res == 2
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())

@test rrule(nonembed_identity, 2.0) == nothing
end

@testset "Pointy UnionAll constraints" begin
pointy_identity(x) = x
@non_differentiable pointy_identity(::Vector{<:AbstractString})

@test frule((Zero(), 1.2), pointy_identity, ["2"]) == (["2"], DoesNotExist())
@test frule((Zero(), 1.2), pointy_identity, 2.0) == nothing

res, pullback = rrule(pointy_identity, ["2"])
@test res == ["2"]
@test pullback(1.2) == (NO_FIELDS, DoesNotExist())

@test rrule(pointy_identity, 2.0) == nothing
end

@testset "Not supported (Yet)" begin
# Varargs are not supported
@test_throws Exception @macroexpand(@non_differentiable vararg1(xs...))|
oxinabox marked this conversation as resolved.
Show resolved Hide resolved
@test_throws Exception @macroexpand(@non_differentiable vararg1(xs::Vararg))

# Where clauses are not supported.
@test_throws Exception @macroexpand(
@non_differentiable where_identity(::Vector{T}) where T<:AbstractString
)
end

end
end