@non_differentiable #150

Closed
willtebbutt opened this issue Apr 24, 2020 · 12 comments · Fixed by #207

@willtebbutt
Member

willtebbutt commented Apr 24, 2020

Similarly to @nograd in Zygote, we need a @non_differentiable macro to make it straightforward for someone to declare that a particular method of a function should have an frule and rrule that return zeros (or the related zero-like differentials).

Proposed Usage:

@non_differentiable f(x::Tx, y...)

should be expanded to something like

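# The first argument is the tuple of input tangents: one for f, one for x, then one per element of y.
function frule(::Tuple{Any, Any, Vararg}, ::typeof(f), x::Tx, y...)
    return f(x, y...), Zero()
end

function rrule(::typeof(f), x::Tx, y...)
    function pullback_f_non_differentiable(::Any)
        # One entry for f itself, one for x, and one for each element of y.
        return (DoesNotExist(), DoesNotExist(), map(_ -> DoesNotExist(), y)...)
    end
    return f(x, y...), pullback_f_non_differentiable
end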

I think it's pretty clear what frule and rrule a given declaration should expand to once we've established the API (although I'm not sure I've actually got the rrule correct for the vararg in the above). The same goes for what the kwargs part of the API should look like.

So the point of discussion is: does the above form of @non_differentiable do what we want it to do? In particular, do we want to insist / not insist that people provide argument names?
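
A minimal, self-contained sketch of the semantics the proposed expansion is meant to have. This is only an illustration under assumptions: `DoesNotExist` and `Zero` are local stand-ins for the ChainRulesCore differentials, `frule` and `rrule` are plain local functions rather than the real ones, `f` is a hypothetical example function, and the vararg handling is one possible reading rather than a settled API.

```julia
# Local stand-ins (assumptions): not the real ChainRulesCore types or functions.
struct DoesNotExist end
struct Zero end

# Hypothetical non-differentiable function: counts its arguments.
f(x::Symbol, y...) = 1 + length(y)

# Forward rule: primal output plus a zero tangent, whatever the input tangents are.
function frule(::Tuple{Any, Any, Vararg}, ::typeof(f), x::Symbol, y...)
    return f(x, y...), Zero()
end

# Reverse rule: one DoesNotExist for f itself, one for x, and one per element of y.
function rrule(::typeof(f), x::Symbol, y...)
    function pullback_f_non_differentiable(::Any)
        return (DoesNotExist(), DoesNotExist(), map(_ -> DoesNotExist(), y)...)
    end
    return f(x, y...), pullback_f_non_differentiable
end

primal, pullback = rrule(f, :a, "b", "c")
cotangents = pullback(1.0)
```

Here `cotangents` has one entry per input (including `f` itself), which is what the splatted `map` over `y` is there to guarantee.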

@cossio

cossio commented Apr 26, 2020

Why not a shorter name, like @nograd?

@willtebbutt
Member Author

My feeling is that @non_differentiable is more descriptive of what's going on.

To my ear @nograd feels a bit specific to single-output functions, and it's not really clear what it means, i.e. why is there no gradient?

@cossio

cossio commented Apr 26, 2020

True that grad is specific to single-output functions.
Maybe @nodiff?

It would be nice to also have a macro that shields a block of code from being differentiated (instead of just a function).

@willtebbutt
Member Author

willtebbutt commented Apr 26, 2020

> It would be nice to also have a macro that shields a block of code from being differentiated (instead of just a function).

Hmm yeah, this would be nice. Might be possible to do at the CR level with a closure... not sure if we would run into world-age issues or something though (as in, I really have no idea whether this is doable at the CR level or whether it has to be done at the AD level)

@simeonschaub
Member

Shouldn't that be the AD tool's job though, not ChainRules'? I don't really see how something like that would go into our rule definitions.

@willtebbutt
Member Author

willtebbutt commented Apr 27, 2020

You could just see it as a special case of a user-defined rule, and we generally want users to be able to define custom rules and expect AD to play nicely with them, regardless of which AD tool they use.

edit: reworded

@oxinabox oxinabox added the enhancement New feature or request label Apr 27, 2020
@oxinabox
Member

oxinabox commented May 4, 2020

I wonder if we can automatically determine which should be DoesNotExist and which should be Zero.

E.g. we know that Symbols, AbstractStrings and Chars have a corresponding differential of DoesNotExist.
For everything else, Zero I guess.
Alternatively, we could introduce a NonSpecificNonexistance <: AbstractZero.
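
The type-based choice suggested above could be sketched with dispatch. This is only an illustration: `DoesNotExist` and `Zero` are local stand-ins for the ChainRulesCore differentials, and `default_differential` is a hypothetical helper name, not an existing function.

```julia
# Local stand-ins (assumptions): not the real ChainRulesCore types.
struct DoesNotExist end
struct Zero end

# Hypothetical helper: choose a differential from the argument's type.
# Symbols, strings and characters have no meaningful differential at all.
default_differential(::Union{Symbol, AbstractString, AbstractChar}) = DoesNotExist()
# Everything else falls back to a plain zero.
default_differential(::Any) = Zero()

differentials = map(default_differential, (:a, "b", 'c', 1.0, [1, 2]))
```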

@MikeInnes

MikeInnes commented Jun 2, 2020

Probably worth having a look at Zygote's utils file. Things like dropgrad, ignore (@cossio's suggestion), hook and @showgrad could all be easily incorporated into CR and at least ignore and hook seem sufficiently useful to do so. [Some utils have the name grad in them for historical reasons but work just as well in forward and reverse in principle.]

They are really easy to define (using Zygote syntax here):

hook(f, x) = x
@tangent hook(f, x) = x, ẋ -> f(ẋ)
@adjoint hook(f, x) = x, x̄ -> (nothing, f(x̄))

ignore(block) = block()
@tangent ignore(block) = block(), _ -> nothing
@adjoint ignore(block) = block(), _ -> nothing
ignore() do # can also use a thin macro layer like `@ignore begin ...`
  @info "ignored logging code"
end

hook(dx -> @show(x, dx), x) # print x along with its gradient/derivative

Zygote and co could then re-export these.
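
The "thin macro layer" mentioned in the code comment could look roughly like the sketch below. It assumes only the plain `ignore(block) = block()` definition from this comment; the `@ignore` macro here is a hypothetical illustration, and without the AD rules shown above it does nothing beyond running the block.

```julia
# `ignore` as defined in the comment above: just call the block.
# An AD tool would also need the tangent/adjoint rules to skip differentiating it.
ignore(block) = block()

# Hypothetical thin macro layer: wraps the expression in a zero-argument closure
# and passes it to `ignore`. `esc` keeps the user's variables in their own scope.
macro ignore(ex)
    return :(ignore(() -> $(esc(ex))))
end

x = 2.0
y = @ignore begin
    x^2 + 1
end
```

The macro only saves the `() do ... end` boilerplate; the value of the block is still returned to the caller.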

FYI, @nograd as it exists in Zygote is kind of a hack from very early in Zygote's existence and should probably be removed at some point. It's way too easy to define over-general rules, so it's better to encourage people to use signatures in general. In the @nograd case we had issues with things like rand which are non-differentiable in their Base form, but packages add differentiable methods. This proposal is already well-aligned with that but I just wanted to make it explicit.

@willtebbutt
Member Author

While I remember: this would also be a great opportunity to implement a number of very specific rules for higher-order functions. For example, if foo isn't differentiable, you know that map(foo, data) also isn't.
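
One way to picture the higher-order case is as a trait that propagates through `map`. Everything named here is hypothetical and for illustration only: `is_non_differentiable` is not an existing ChainRules function, and `foo` is a made-up example.

```julia
# Hypothetical trait (assumption): true when calling `f` with these arguments
# has been declared non-differentiable. Default: assume differentiable.
is_non_differentiable(f, args...) = false

# Hypothetical example: a string-length function marked as non-differentiable.
foo(s::AbstractString) = length(s)
is_non_differentiable(::typeof(foo), ::AbstractString) = true

# Mapping a non-differentiable function over a collection is itself
# non-differentiable, provided it is non-differentiable for every element.
function is_non_differentiable(::typeof(map), f, data)
    return all(x -> is_non_differentiable(f, x), data)
end
```

With these definitions, `is_non_differentiable(map, foo, ["a", "bc"])` holds because `foo` is marked for each string element, while mapping an unmarked function such as `sin` is not affected.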

@Keno
Contributor

Keno commented Aug 13, 2020

👍 On this. Seeing lots of unnecessary work being done by my new AD pass because it's trying to work on things that definitely don't need to be AD'd.

@oxinabox
Member

This is my script to get the start of an initial list of things to mark.
It generates a few hundred things to skip.

using Random

NDT = Union{
    AbstractString, AbstractChar, Symbol, Nothing, 
    Cstring, Cwstring, Regex, RegexMatch,
    IO, AbstractDisplay, Cmd, RawFD,
    Exception, Condition, 
    IndexStyle, Colon,
    Random.AbstractRNG,
}
# Let's work out which types we should be thinking about.
module_types(mod) = filter(x->x isa DataType, getfield.(Ref(mod), sort(names(mod))))
for t in  module_types(Base)
    t <: NDT && continue  # already got it
    t <: Number && continue  # probably differentiable
    t <: AbstractArray && continue  # probably differentiable
    println(t)  # worth consideration
end

function is_nondiff(sig)
    if sig <: Tuple{Any}
        return false  # only the function itself: need 1 or more arguments
    elseif sig <: Tuple{Any, Vararg{Union{NDT, AbstractArray{<:NDT}}}}
        return true  # arguments are all nondiff types
    else
        return false
    end
end

module_functions(mod) = filter(x->x isa Function, getfield.(Ref(mod), names(mod)))

function nondiff_overloads(mod)
    funcs = module_functions(mod)

    sigs = [mm.sig for func in funcs for mm in methods(func) if is_nondiff(mm.sig)]
    isempty(sigs) &&  return String[]
    # Now we need to get rid of any signatures that are just special cases of ones
    # that we already have, e.g. not having both `foo(::String)` and `foo(::AbstractString)`.
    # We can do this using `Union`, which automatically simplifies its elements in this way.
    # As a side effect, it also normalizes anonymous type vars to have names.
    usigs = Base.uniontypes(Union{sigs...})

    return map(usigs) do sig
        sig_plain = Base.unwrap_unionall(sig)
        function_type, arg_types = Iterators.peel(sig_plain.parameters)
        func_name = nameof(function_type.instance)
        modname = if mod == Base
            ""
        elseif parentmodule(mod) == Base 
            "$(nameof(mod))."
        else 
            "$mod."
        end
        args_text = join(("::$arg_type" for arg_type in arg_types), ", ")
        return "@non_differentiable $modname$func_name($args_text)"
    end
end

foreach(println, nondiff_overloads(Base))
println("\n")
println()
foreach(println, nondiff_overloads(Base.Threads))
println()
foreach(println, nondiff_overloads(Base.Iterators))
println()
foreach(println, nondiff_overloads(Base.Libc))
println()
foreach(println, nondiff_overloads(Base.Broadcast))
println()
foreach(println, nondiff_overloads(Base.Sys))

println("\nusing Random")
foreach(println, nondiff_overloads(Random))

@oxinabox
Member

This is one of the rare cases where DoesNotExist/Zero iterating to give itself is useful,
since we can just return it as the whole differential and not have to worry about whether a function returns a tuple (for frule).

simeonschaub pushed a commit to simeonschaub/ChainRulesCore.jl that referenced this issue Nov 16, 2020
* Remove TODO sections

* Move info about chainrules vs chainrulescore to FAQ
7 participants