Add `total(f, model)` to replace implicit `sum(f, Flux.params(model))` #57

Conversation
```diff
-if p isa ProjectTo # e.g. Array, NamedTuple
-    p(y)
-else # p === identity for unknown structs
+# if p isa ProjectTo # e.g. Array, NamedTuple
+#     p(y) # but for NamedTuple, this hits https://github.com/JuliaDiff/ChainRulesCore.jl/issues/538
+if x isa Union{Number, AbstractArray} # these don't use Tangent
+    ProjectTo(x)(unthunk(y))
+else
+    Tangent{typeof(x), typeof(y)}(y)
```
This is either a bug in the earlier `_Tangent_biwalk`, or in ChainRulesCore.
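For reference, a minimal sketch of what the two branches do — nothing here is from the PR beyond the two calls shown in the diff, and it assumes only ChainRulesCore:

```julia
using ChainRulesCore

x = [1.0, 2.0]
# Arrays (and numbers) have a ProjectTo, so the unthunked gradient is projected:
ProjectTo(x)(unthunk(@thunk([3.0, 4.0])))   # == [3.0, 4.0]

s = (a = 1.0, b = 2.0)
dy = (a = 3.0, b = 4.0)
# Other structs are wrapped as a structural Tangent instead:
Tangent{typeof(s), typeof(dy)}(dy)
```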
```julia
function total(f, x)
    values = []
    fmap(y -> push!(values, f(y)), x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z)))
    sum(values)
end
```
While `Any[]` doesn't seem great, this ends up about the same speed as my other idea:
```julia
const INIT = Base._InitialValue()

function total2(f, x; init = INIT)
    fmap(x; exclude = isnumeric, walk = (f, z) -> foreach(f, _trainable(z))) do y
        val = f(y)
        init = init === INIT ? val : (init + val)
    end
    init
end
```
```julia
julia> @btime total(norm, $model)  # Resnet from the docs
  min 23.863 ms, mean 23.995 ms (1541 allocations, 130.06 KiB)
730.5533f0

julia> @btime total2(norm, $model)
  min 23.834 ms, mean 23.982 ms (1538 allocations, 128.17 KiB)
730.5533f0

julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));

julia> @btime total(norm, $m)
  min 1.750 μs, mean 1.846 μs (16 allocations, 752 bytes)
10.0

julia> @btime total2(norm, $m)
  min 1.675 μs, mean 1.769 μs (15 allocations, 640 bytes)
10.0
```
Should this be more general, to allow computing the norm of the gradients as well?
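For what it's worth, a sketch of the gradient side — assuming explicit Zygote gradients, which mirror the model as NamedTuples/Tuples with `nothing` for non-differentiable fields; `total_norm` here is a hypothetical helper, not part of this PR:

```julia
using Zygote, LinearAlgebra

m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7))
g = gradient(m -> sum(m.x) + sum(m.y[2]), m)[1]   # (x = [1.0, 1.0], y = (nothing, [1.0]), z = nothing)

# Sum `norm` over every array leaf of the gradient, skipping `nothing` and functions:
total_norm(g::AbstractArray) = norm(g)
total_norm(g::Union{Tuple, NamedTuple}) = sum(total_norm, g; init = 0.0)
total_norm(_) = 0.0

total_norm(g)   # norm([1.0, 1.0]) + norm([1.0]) = √2 + 1
```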
```julia
julia> total(norm, m)
10.0

julia> total(length, m) == length(destructure(m)[1])
```
This would solve FluxML/Flux.jl#2043 (as long as trainable parameters are what you want). Or `total(Base.summarysize, m)` for bytes, `total(_ -> 1, m)` to count arrays.
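With the definition above, and the small `m` from the benchmarks, that looks like (outputs re-derived by hand, so treat them as approximate):

```julia
julia> m = (x = [3.0, 4.0], y = (sin, [5.0]), z = (6, 7));

julia> total(_ -> 1, m)    # counts the two trainable arrays
2

julia> total(length, m)    # total number of trainable parameters
3
```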
One idea for making this differentiable is to take the non-diff part (caching) out of …
Not sure I follow. If …
OK. I guess this PR's take is that since essentially nothing else about Functors.jl is type-stable and everything takes a few μs, there's not much point pushing hard here. Decomposing into …
For this PR, my main ask is that it not cut off any paths which could bring us better type stability in the future. That doesn't seem to be the case, but I don't understand the (un)thunking well enough to say for sure. Minor comments would be the inclusion of an …
```julia
    z, total_back
end

function _total_hobbit(config::RuleConfig, f, x)
```
A brief comment about what this and `_total_grad` do would help. "hobbit" in particular is alien terminology for anyone who hasn't read a couple of specific issues on the ChainRules repo 😛. Is there something more concise than `_total_value_and_inner_pullbacks`?
Indeed. I rebased this but realised I have no memory of how it worked. Will revise or re-write.
Now that people are starting to use explicit params, we've seen a few instances where it would be nice to have an easy method for adding regularization terms. I believe this function should be easier to implement in a post-FluxML/Functors.jl#43 world too.
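For concreteness, the kind of usage this would enable — a sketch assuming `total` from this PR plus ordinary Flux layers:

```julia
using Flux

model = Chain(Dense(3 => 5, relu), Dense(5 => 1))

# L2 penalty over trainable parameters, replacing sum(sum(abs2, p) for p in Flux.params(model)):
penalty(m) = total(x -> sum(abs2, x), m)
loss(m, x, y) = Flux.mse(m(x), y) + 0.01f0 * penalty(m)
```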
#143 and #57 (comment) hint that the signature here should probably allow for …
Makes sense, though thinking about this stuff is always rather mind-bending. I guess it would be future-proofing for models with nested differentiation? We could always kick this can down the road until someone needs this for Flux models.
In the case of …
Also, I don't think adding …
Thinking more, I do think they are always added. If I do something like … In some …
I'm trying to adapt the … example:

```julia
chain = Chain(Dense(3 => 5), Dense(5 => 1, relu))
f(A) = sum(abs2, A)

# old way of doing things:
sum(f.(Flux.params(chain)))
# 7.247263f0

# new way:
total(f, chain)
# ERROR: MethodError: no method matching isnumeric(::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}}})
```
No, it's not... I suspect you have `Base.isnumeric`; there's an unfortunate name clash? Or perhaps this has just rotted, sorry. I mean to revisit but have been busy.
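For anyone else who hits this: `Base.isnumeric` is defined on characters, so it's easy to pick up by accident. A quick check, assuming the predicate lives in Optimisers:

```julia
julia> Base.isnumeric('7')            # Base's version works on characters
true

julia> Base.isnumeric([1f0, 2f0])     # ... and has no array method
ERROR: MethodError: no method matching isnumeric(::Vector{Float32})

julia> Optimisers.isnumeric([1f0, 2f0])  # the internal predicate the walk expects
true
```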
Okay, never mind. I was using …
Thanks for the lightning reply ❤️
Bump on merging this? We still get regularization questions frequently.
Now rebased, but the tests tell me I need to remember what on earth …
I think that instead of introducing …, it could be implemented on top of `Functors.fleaves`.
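Something like the following, presumably — a sketch assuming `Functors.fleaves` collects every leaf, so this version needs its own array filter and is not restricted to trainable parameters:

```julia
using Functors, LinearAlgebra

total_fleaves(f, x) = sum(f, filter(y -> y isa AbstractArray{<:Number}, Functors.fleaves(x)))

total_fleaves(norm, (x = [3.0, 4.0], y = (sin, [5.0])))   # == 10.0
```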
This could be closed since we now have …
Are we still running into recompilation issues using …?
I don't think we have recompilation issues with …
This proposes to add some kind of differentiable `sum(f, trainable(x))` which walks the model. I'm not certain this is the right thing yet.

Right now this gets all trainable parameters. But perhaps a variant which takes a type, `total(f, Union{Dense, Conv}, model)`, might be a better explicit-parameters replacement for `modules`? Xref FluxML/Flux.jl#1863 (comment)

Closes FluxML/Functors.jl#35, probably.

Edit: since I couldn't find this, the big Flux issue about explicit parameters is FluxML/Flux.jl#1986, and a snippet with a quick way to write `total` is here: FluxML/Flux.jl#2040 (comment)
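One possible shape for that variant, purely as a sketch — it leans on `fmap`'s `exclude` to stop at whole layers, and the three-argument `total` is hypothetical:

```julia
using Functors

function total(f, T::Type, model)
    acc = Ref(0f0)
    fmap(model; exclude = x -> x isa T) do layer
        acc[] += f(layer)   # apply f to each matching layer, leave the layer unchanged
        layer
    end
    acc[]
end

# e.g. total(m -> sum(abs2, m.weight), Dense, model) to penalise only Dense weights
```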