
Freezing layer parameters still computes all gradients #1688

Closed
vladtkachuk4 opened this issue Aug 4, 2021 · 4 comments

vladtkachuk4 commented Aug 4, 2021

I have been looking at passing a subset of the model (e.g. m[3:end]) to Flux.params as a way to train only some of the layer parameters while freezing the others. This does work, but I noticed that the runtime was, surprisingly, the same regardless of how many layers I froze.

Below is a simple example showing this behaviour:

julia> using Flux, BenchmarkTools

julia> m = Chain(Dense(100, 50, relu), Dense(50, 2), softmax);

julia> opt = Descent(0.01);

julia> data, labels = rand(Float32, 100, 100), zeros(Float32, 2, 100);

julia> loss(x, y) = sum(Flux.crossentropy(m(x), y));

julia> function get_grads(s)
           gs = gradient(Flux.params(m[s:end])) do
               l = loss(data, labels)
           end
       end
get_grads (generic function with 1 method)

julia> @benchmark get_grads(1)
BenchmarkTools.Trial:
  memory estimate:  199.59 KiB
  allocs estimate:  772
  --------------
  minimum time:     258.748 μs (0.00% GC)
  median time:      296.283 μs (0.00% GC)
  mean time:        347.441 μs (9.07% GC)
  maximum time:     24.036 ms (97.39% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark get_grads(2)
BenchmarkTools.Trial:
  memory estimate:  199.91 KiB
  allocs estimate:  787
  --------------
  minimum time:     260.033 μs (0.00% GC)
  median time:      300.003 μs (0.00% GC)
  mean time:        360.708 μs (8.68% GC)
  maximum time:     13.053 ms (96.72% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark get_grads(3)
BenchmarkTools.Trial:
  memory estimate:  199.97 KiB
  allocs estimate:  786
  --------------
  minimum time:     257.480 μs (0.00% GC)
  median time:      304.113 μs (0.00% GC)
  mean time:        524.964 μs (6.56% GC)
  maximum time:     64.466 ms (0.00% GC)
  --------------
  samples:          9327
  evals/sample:     1

This seems to be happening because the gradients for all the layers are still being computed regardless of what is passed to Flux.params. For example:

julia> get_grads(1).grads
IdDict{Any, Any} with 6 entries:
  Float32[0.0, 0.0]                                  => Float32[0.0, 0.0]
  Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0… => Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0…
  Float32[0.0719976 -0.132163 … 0.121904 0.132747; … => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
  Float32[-0.0377398 0.132804 … -0.334346 -0.068553… => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
  :(Main.labels)                                     => Float32[0.00645833 0.00652768 … 0.00303443 0.00228067; 0.00742811 0.00735225 … 0.0134045 0.0158998]
  :(Main.data)                                       => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

julia> get_grads(2).grads
IdDict{Any, Any} with 5 entries:
  Float32[0.0, 0.0]                                  => Float32[0.0, 0.0]
  :(Main.m)                                          => (layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 ……
  Float32[-0.0377398 0.132804 … -0.334346 -0.068553… => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
  :(Main.labels)                                     => Float32[0.00645833 0.00652768 … 0.00303443 0.00228067; 0.00742811 0.00735225 … 0.0134045 0.0158998]
  :(Main.data)                                       => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

julia> get_grads(3).grads
IdDict{Any, Any} with 3 entries:
  :(Main.m)      => (layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, 0.0, …
  :(Main.labels) => Float32[0.00645833 0.00652768 … 0.00303443 0.00228067; 0.00742811 0.00735225 … 0.0134045 0.0158998]
  :(Main.data)   => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]

The gradients move to :(Main.m), but are still computed.

This behaviour is far from ideal for runtime, especially when only the last layer is being updated.

Is this intended, and is there a way to stop the gradients of the frozen layers from being computed?

@DhairyaLGandhi
Member

Zygote needs to compute partials over all the parameters, because those partials may be needed elsewhere in the computation to return the gradients of the parameter set that was requested. To see the final results more clearly, it is recommended to avoid global variables and accesses to them, both in Julia generally and in differentiation in particular. Removing the global accesses also comes with a performance bump and means the correct gradients are returned.

julia> loss(m, x, y) = sum(Flux.crossentropy(m(x), y));

julia> function get_grads(m, data, labels, s)
           gs = gradient(Flux.params(m[s:end])) do
               l = loss(m, data, labels)
           end
       end

julia> get_grads(m, data, labels, 1).grads
IdDict{Any, Any} with 4 entries:
  Float32[-0.16087 -0.0740108 … 0.122463 0.… => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.…
  Float32[-0.279719 0.265606 … 0.0243676 0.… => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
  Float32[0.0, 0.0]                          => Float32[0.0, 0.0]
  Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0… => Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, …

julia> get_grads(m, data, labels, 2).grads
IdDict{Any, Any} with 2 entries:
  Float32[-0.279719 0.265606 … 0.0243676 0.0286161; -0.148928 -0.0405258 … 0.295… => Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0]
  Float32[0.0, 0.0]                                                               => Float32[0.0, 0.0]

julia> get_grads(m, data, labels, 3).grads
IdDict{Any, Any}()

@vladtkachuk4
Author

Great advice!
However, this still leaves the problem that the runtime is the same regardless of how many layers are frozen.

julia> function get_grads_2(m, data, labels, s)
           gs = gradient(Flux.params(m[s:end])) do
               l = loss(m, data, labels)
           end
       end
get_grads_2 (generic function with 1 method)

julia> @benchmark get_grads_2(m, data, labels, 1)
BenchmarkTools.Trial:
  memory estimate:  200.52 KiB
  allocs estimate:  774
  --------------
  minimum time:     247.308 μs (0.00% GC)
  median time:      282.521 μs (0.00% GC)
  mean time:        338.753 μs (8.55% GC)
  maximum time:     12.615 ms (96.10% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark get_grads_2(m, data, labels, 2)
BenchmarkTools.Trial:
  memory estimate:  200.94 KiB
  allocs estimate:  790
  --------------
  minimum time:     248.344 μs (0.00% GC)
  median time:      283.671 μs (0.00% GC)
  mean time:        334.742 μs (9.23% GC)
  maximum time:     14.135 ms (97.35% GC)
  --------------
  samples:          10000
  evals/sample:     1

julia> @benchmark get_grads_2(m, data, labels, 3)
BenchmarkTools.Trial:
  memory estimate:  201.09 KiB
  allocs estimate:  789
  --------------
  minimum time:     242.827 μs (0.00% GC)
  median time:      277.384 μs (0.00% GC)
  mean time:        319.943 μs (9.27% GC)
  maximum time:     15.763 ms (95.98% GC)
  --------------
  samples:          10000
  evals/sample:     1

Which, as you mentioned, is because the gradients are still being computed in the background.
For anyone doing transfer learning with a large network who only needs to train the final layer, this seems like it would be a big issue. Is there any way to address this currently (i.e. a way to tell Zygote not to take derivatives of the other layers)?

@ToucheSir
Member

gradient only runs AD over operations that happen inside the closure passed to it, so a simple change is to move the frozen part of the network outside:

loss(m, x, y) = Flux.crossentropy(m(x), y) # no need for sum, crossentropy aggregates by default

function get_grads(trunk, head, data, labels)
  z = trunk(data)
  gradient(Flux.params(head)) do
    loss(head, z, labels)
  end
end

get_grads(m[1:end-1], m[end], data, labels)
get_grads(m[1:end-2], m[end-1:end], data, labels)
get_grads(m[1:end-3], m[end-2:end], data, labels)

You see this pattern quite a bit with torch.no_grad as well. In a source-to-source system like Zygote, removing anything that doesn't need to go through AD is generally beneficial for compilation latency too.
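
For completeness, here is a minimal sketch (not from this thread) of how the trunk/head split above could be plugged into a training step that updates only the final Dense layer. It assumes the same m, data, labels, and Descent optimiser defined earlier in the issue:

# Minimal sketch: freeze the first Dense layer and train only the rest of the chain.
# Assumes m, data, labels and the Descent optimiser from the examples above.
using Flux

trunk, head = m[1:1], m[2:end]   # frozen part vs. trainable part
opt = Descent(0.01)
ps = Flux.params(head)

for epoch in 1:10
    z = trunk(data)              # forward pass through the frozen layers, outside AD
    gs = gradient(ps) do
        Flux.crossentropy(head(z), labels)
    end
    Flux.Optimise.update!(opt, ps, gs)   # only the head's parameters are updated
end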

@vladtkachuk4
Author

@ToucheSir this is exactly what I was looking for! Thanks so much for the help :)
Closing issue.
