@non_differentiable #150
Comments
Why not a shorter name, like
My feeling is that To my ear
True that It would be nice to also have a macro that shields a block of code from being differentiated (instead of just a function).
Hmm yeah, this would be nice. Might be possible to do at the CR level with a closure... not sure if we would run into world-age issues or something though (as in, I really have no idea whether this is doable at the CR level or whether it has to be done at the AD level).
Shouldn't that be the AD tool's job though, not ChainRules'? I don't really see how something like that would go into our rule definitions.
You could just see it as a special case of a user-defined rule, and we generally want users to be able to define custom rules and expect AD to play nicely with them, regardless of which AD tool they use.
I wonder if we can automatically determine which should be, e.g. we know that
Probably worth having a look at Zygote's utils file. Things like They are really easy to define (using Zygote syntax here):

hook(f, x) = x
@tangent hook(f, x) = x, ẋ -> f(ẋ)
@adjoint hook(f, x) = x, x̄ -> (nothing, f(x̄))

ignore(block) = block()
@tangent ignore(block) = block(), _ -> nothing
@adjoint ignore(block) = block(), _ -> nothing

ignore() do  # can also use a thin macro layer like `@ignore begin ...`
    @info "ignored logging code"
end

hook(dx -> @show(x, dx), x)  # print x along with its gradient/derivative

Zygote and co could then re-export these. FYI,
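To make the intended semantics of `hook` and `ignore` concrete without depending on Zygote, here is a toy forward-mode sketch built on hand-rolled dual numbers. `Dual`, `tangent_hook` and `tangent_ignore` are hypothetical names made up for this example; they mirror the `@tangent` lines above, but this is not how Zygote (or any ChainRules-based tool) actually implements them.

```julia
# Toy forward mode (illustration only): each Dual carries a value and its tangent.
struct Dual
    val::Float64
    der::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)

# Primal definitions from the comment above: both behave like the identity.
hook(f, x) = x
ignore(block) = block()

# Hypothetical toy tangent rules mirroring the @tangent lines:
# `hook` maps the incoming tangent through `f`; `ignore` drops the block's tangent.
tangent_hook(f, x::Dual) = Dual(x.val, f(x.der))
function tangent_ignore(block)
    y = block()
    return y isa Dual ? Dual(y.val, 0.0) : y
end

x = Dual(3.0, 1.0)                          # seed: dx/dx = 1
y = x * x + tangent_ignore(() -> x * x)     # second term is treated as a constant
z = tangent_hook(dx -> 2 * dx, x) * x       # hook doubles x's tangent on the way in
```

Here `y` ends up as `Dual(18.0, 6.0)` (only the first `x * x` contributes to the derivative) and `z` as `Dual(9.0, 9.0)`.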
While I remember: this would also be a great opportunity to implement a number of very specific rules for higher-order functions. For example, if
👍 on this. I'm seeing lots of unnecessary work being done by my new AD pass because it's trying to work on things that definitely don't need to be AD'd.
This is my script to get the start of an initial list of things to mark.

using Random

# Types whose values (and arrays of them) we treat as non-differentiable.
const NDT = Union{
    AbstractString, AbstractChar, Symbol, Nothing,
    Cstring, Cwstring, Regex, RegexMatch,
    IO, AbstractDisplay, Cmd, RawFD,
    Exception, Condition,
    IndexStyle, Colon,
    Random.AbstractRNG,
}

# Let's work out which types we should be thinking about.
module_types(mod) = filter(x -> x isa DataType, getfield.(Ref(mod), sort(names(mod))))
for t in module_types(Base)
    t <: NDT && continue            # already got it
    t <: Number && continue         # probably differentiable
    t <: AbstractArray && continue  # probably differentiable
    println(t)                      # worth consideration
end

# A method signature has the shape Tuple{typeof(f), ArgTypes...}.
function is_nondiff(sig)
    if sig <: Tuple{Any}
        return false  # need 1 or more arguments
    elseif sig <: Tuple{Any, Vararg{Union{NDT, AbstractArray{<:NDT}}}}
        return true   # arguments are all nondiff types
    else
        return false
    end
end

module_functions(mod) = filter(x -> x isa Function, getfield.(Ref(mod), names(mod)))

function nondiff_overloads(mod)
    funcs = module_functions(mod)
    sigs = [mm.sig for func in funcs for mm in methods(func) if is_nondiff(mm.sig)]
    isempty(sigs) && return String[]
    # Now we need to get rid of any signatures that are just special cases of ones that
    # we already have, e.g. not having both `foo(::String)` and `foo(::AbstractString)`.
    # We can do this using `Union`, which automatically simplifies its elements in this way;
    # as a side effect it also normalizes anonymous type vars to have names.
    usigs = Base.uniontypes(Union{sigs...})
    return map(usigs) do sig
        sig_plain = Base.unwrap_unionall(sig)
        function_type, arg_types = Iterators.peel(sig_plain.parameters)
        func_name = nameof(function_type.instance)
        modname = if mod == Base
            ""
        elseif parentmodule(mod) == Base
            "$(nameof(mod))."
        else
            "$mod."
        end
        args_text = join(("::$arg_type" for arg_type in arg_types), ", ")
        return "@non_differentiable $modname$func_name($args_text)"
    end
end

foreach(println, nondiff_overloads(Base))
println()
foreach(println, nondiff_overloads(Base.Threads))
println()
foreach(println, nondiff_overloads(Base.Iterators))
println()
foreach(println, nondiff_overloads(Base.Libc))
println()
foreach(println, nondiff_overloads(Base.Broadcast))
println()
foreach(println, nondiff_overloads(Base.Sys))
println("\nusing Random")
foreach(println, nondiff_overloads(Random))
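To see which signatures a filter like `is_nondiff` accepts, here is a trimmed, self-contained copy with a smaller type union, checked against a few hand-written method signatures (a method's `sig` is `Tuple{typeof(f), ArgTypes...}`). `SmallNDT` and `is_nondiff_small` are illustrative names, not part of the script above.

```julia
# Trimmed copy of the script's signature filter (illustrative names only).
const SmallNDT = Union{AbstractString, AbstractChar, Symbol, Nothing}

# `sig` has the shape Tuple{typeof(f), ArgTypes...}: the first slot is the function.
function is_nondiff_small(sig)
    sig <: Tuple{Any} && return false  # zero arguments: nothing to declare
    # Every argument must be a SmallNDT or an array of SmallNDTs.
    return sig <: Tuple{Any, Vararg{Union{SmallNDT, AbstractArray{<:SmallNDT}}}}
end

is_nondiff_small(Tuple{typeof(uppercase), AbstractString})     # true
is_nondiff_small(Tuple{typeof(join), Vector{String}, String})  # true
is_nondiff_small(Tuple{typeof(sin), Float64})                  # false
is_nondiff_small(Tuple{typeof(rand)})                          # false: no arguments
is_nondiff_small(Tuple{typeof(repeat), String, Int})           # false: Int is not in SmallNDT
```

The last case shows why the real list still needs a human pass: a method mixing string and integer arguments is filtered out even though it may well be non-differentiable.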
This is one of the rare cases with
Similarly to `@nograd` in Zygote, we need a `@non_differentiable` macro to make it straightforward for someone to declare that a particular method of a function should have an `frule` and `rrule` that return zeros / related.

Proposed Usage:

should be expanded to something like

I think the `frule` and `rrule` that something should expand to, once we've established the API, are pretty clear (I'm not sure I've actually got the `rrule` correct for the vararg in the above). The same goes for what the `kwargs` bit of the API should look like.

So the point of discussion is: does the above form of `@non_differentiable` do what we want it to do? In particular, do we want to insist / not insist that people provide argument names?
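As a rough illustration of the kind of expansion being discussed, here is a self-contained toy version of such a macro. It does not use ChainRulesCore: `toy_frule`, `toy_rrule`, `ToyZero` and `@toy_non_differentiable` are hypothetical stand-ins, their signatures only loosely follow the `frule`/`rrule` shapes, and keyword arguments are not handled. It takes the "don't insist on argument names" side of the question by accepting only `::T` arguments and generating names itself.

```julia
# Toy stand-ins; NOT ChainRulesCore's API.
struct ToyZero end        # marker for a structurally-zero tangent

function toy_frule end    # toy_frule(tangents, f, args...) -> (f(args...), output tangent)
function toy_rrule end    # toy_rrule(f, args...) -> (f(args...), pullback)

macro toy_non_differentiable(sig)
    Meta.isexpr(sig, :call) || error("expected a signature such as `f(::T1, ::T2)`")
    f = sig.args[1]
    # Only anonymous `::T` arguments are accepted; names are generated below.
    types = map(sig.args[2:end]) do a
        (Meta.isexpr(a, :(::)) && length(a.args) == 1) || error("expected `::T`, got `$a`")
        a.args[1]
    end
    argnames = [Symbol(:x, i) for i in eachindex(types)]
    typed = [:($n::$T) for (n, T) in zip(argnames, types)]
    nout = length(argnames) + 1   # one cotangent for `f` itself plus one per argument
    # The expansion is fully escaped, so it assumes `toy_frule`, `toy_rrule`
    # and `ToyZero` are visible where the macro is called.
    return esc(quote
        function toy_frule(_, ::typeof($f), $(typed...))
            return $f($(argnames...)), ToyZero()
        end
        function toy_rrule(::typeof($f), $(typed...))
            pullback(_) = ntuple(_ -> ToyZero(), $nout)
            return $f($(argnames...)), pullback
        end
    end)
end

# Usage: declare a method non-differentiable, then call the generated toy rules.
@toy_non_differentiable uppercase(::AbstractString)

y, ydot = toy_frule(nothing, uppercase, "abc")   # ("ABC", ToyZero())
y2, back = toy_rrule(uppercase, "abc")
back(1.0)                                        # (ToyZero(), ToyZero())
```

Generating the names keeps the call syntax minimal; also accepting `x::T` would only mean handling the two-argument form of `::` when building `typed`.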