Performance issue when calculating loss #1255
Comments
Can you post some benchmarks? |
That was done to ensure type stability and prevent unnecessary type promotions that can occur by mixing |
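One way to read the type-stability point above, as a minimal sketch in plain Julia (no Flux needed): `1 // n` is an exact `Rational`, and multiplying a `Float32` by it keeps the result a `Float32`, whereas scaling by a `Float64` value widens the result. The names `s` and `n` are hypothetical stand-ins for `sum(abs.(m(x) .- y))` and `length(y)`; this illustrates the promotion rules only and does not claim to reproduce the slowdown discussed in this issue.

```julia
# Base Julia only. `s` and `n` are hypothetical stand-ins for
# sum(abs.(m(x) .- y)) and length(y) from this thread.
s = 3.0f0   # a Float32 scalar
n = 32      # an Int

scale        = 1 // n          # exact Rational{Int}
via_rational = s * 1 // n      # `//` binds tighter than `*`, so this is s * (1 // n)
via_division = s / n           # Float32 divided by Int
via_float64  = s * (1.0 / n)   # the Float64 factor promotes the result

println(typeof(scale))
println(typeof(via_rational))
println(typeof(via_division))
println(typeof(via_float64))
```

Note that for a plain scalar, `s / n` also stays `Float32`; the widening only appears once a `Float64` value enters the expression.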
I'm afraid it's not related to
Here's the example:
m = Dense(32, 32)
x = rand(Float32, 32, 32)
y = rand(Float32, 32)
@benchmark gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) / length(y)
end
@benchmark gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) * 1 // length(y)
end
Note that as the model grows, the performance difference quickly becomes very large (in my case, about two orders of magnitude). |
Hmm, I cannot reproduce:
julia> @btime gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) / length(y)
end
74.090 μs (3192 allocations: 97.61 KiB)
Grads(...)
julia> @btime gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) * 1 // length(y)
end
74.381 μs (3195 allocations: 97.73 KiB)
Grads(...) |
Actually, my last measurement was on Zygote 0.4.20. On newer Zygote versions I can reproduce the performance difference. @oxinabox could this be due to ChainRules? |
This is quite relevant, since I removed all |
Hard to say without a lot more information. |
Using
julia> using Flux, BenchmarkTools
julia> m = Dense(32, 32)
Dense(32, 32)
julia> x = rand(Float32, 32, 32);
julia> y = rand(Float32, 32);
julia> @btime gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) / length(y)
end
124.031 μs (3197 allocations: 114.19 KiB)
Grads(...)
julia> @btime gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) * 1 // length(y)
end
73.434 μs (3196 allocations: 97.72 KiB)
Grads(...)
julia> @btime gradient(Flux.params(m)) do
mean(abs.(m(x) .- y))
end
73.004 μs (3190 allocations: 105.67 KiB)
Grads(...) |
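For reference, the three formulations timed above compute the same value when the prediction and target have the same shape. The sketch below checks this in plain Julia; the helper names `mae_div`, `mae_rational`, and `mae_mean` are hypothetical and not part of Flux. In the benchmark itself, `m(x)` is 32×32 while `y` has 32 elements, so `m(x) .- y` broadcasts and the `mean` variant divides by 1024 rather than 32. The timings remain comparable, but the returned values differ.

```julia
using Statistics  # standard library, provides `mean`

# Hypothetical helper names; yhat and y are assumed to have the same shape.
mae_div(yhat, y)      = sum(abs.(yhat .- y)) / length(y)
mae_rational(yhat, y) = sum(abs.(yhat .- y)) * 1 // length(y)
mae_mean(yhat, y)     = mean(abs.(yhat .- y))

yhat = Float32[0.5, 1.0, 2.0, 4.0]
y    = Float32[1.0, 1.0, 1.0, 1.0]

# Absolute errors are 0.5, 0.0, 1.0, 3.0, so each variant should give 4.5 / 4.
println(mae_div(yhat, y))
println(mae_rational(yhat, y))
println(mae_mean(yhat, y))
```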
The integer division problem has been fixed in ChainRules
julia> using Flux, BenchmarkTools
[ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
julia> m = Dense(32, 32)
Dense(32, 32)
julia> x = rand(Float32, 32, 32);
julia> y = rand(Float32, 32);
julia> @btime gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) / length(y)
end
75.177 μs (3193 allocations: 105.61 KiB)
Grads(...)
julia> @btime gradient(Flux.params(m)) do
sum(abs.(m(x) .- y)) * 1 // length(y)
end
75.825 μs (3196 allocations: 105.72 KiB)
Grads(...)
julia> @btime gradient(Flux.params(m)) do
mean(abs.(m(x) .- y))
end
74.869 μs (3190 allocations: 113.67 KiB)
Grads(...) |
I notice that many loss functions in this package are written like this:
Flux.jl/src/layers/stateless.jl
Line 8 in 942d5f6
Flux.jl/src/layers/stateless.jl
Line 23 in 942d5f6
Flux.jl/src/layers/stateless.jl
Line 35 in 942d5f6
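They all have a * 1 // factor. Originally I thought it was redundant. Only recently did I run into a performance issue and find that writing the loss this way avoids it. Can anyone explain why we need this?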
. Originally I thought it was redundant. Only recently I got a performance issue and found that doing so will avoid a performance issue. Can anyone explain why we need this?The text was updated successfully, but these errors were encountered: