batched_vec >1000X slower than batched_mul #462
Comments
Can reproduce. Note also:

```julia
julia> let A = CUDA.zeros(64, 32, 128)  # benchmark without Zygote involvement
           x = CUDA.zeros(32, 128)
           y1 = @btime CUDA.@sync simple_batched_vec($A, $x)
           y2 = @btime CUDA.@sync NNlib.batched_vec($A, $x)
           y1 ≈ y2
       end
  21.854 μs (13 allocations: 496 bytes)
  28.889 μs (30 allocations: 1.20 KiB)
true
```
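The helper `simple_batched_vec` is never defined in the thread. A plausible reconstruction (my assumption, built only on `NNlib.batched_mul`, which is what makes the head-to-head comparison meaningful) would be:

```julia
using NNlib: batched_mul

# Hypothetical reconstruction of simple_batched_vec (not shown in the thread):
# view x as a batch of 1-column matrices, multiply with one batched_mul call,
# then drop the singleton dimension from the result.
simple_batched_vec(A::AbstractArray{T,3} where T, x::AbstractMatrix) =
    reshape(batched_mul(A, reshape(x, size(x, 1), 1, size(x, 2))),
            size(A, 1), size(A, 3))
```

With this definition, `simple_batched_vec(A, x)[:, k] ≈ A[:, :, k] * x[:, k]` for each batch index `k`.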
```julia
julia> let A = CUDA.zeros(640, 320, 1280)  # bigger version with GPU allocations
           x = CUDA.zeros(320, 1280)
           CUDA.@time Zygote.pullback(simple_batched_vec, A, x)
           CUDA.@time Zygote.pullback(NNlib.batched_vec, A, x)
       end;
  0.111997 seconds (74 CPU allocations: 3.906 KiB) (1 GPU allocation: 3.125 MiB, 0.03% memmgmt time)
  0.155808 seconds (507 CPU allocations: 25.391 KiB) (1 GPU allocation: 3.125 MiB, 0.03% memmgmt time)
```

The culprit here almost has to be this: NNlib.jl/src/batched/batchedmul.jl, lines 175 to 181 in d8b9b41.
I can confirm this; splitting it into two methods, e.g.

```julia
batched_vec(A::AbstractArray{T,3} where T, B::AbstractMatrix) =
    reshape(batched_mul(A, reshape(B, size(B,1), 1, size(B,2))), size(A,1), size(A,3))

# If B is transposed, then stride=1 is the batch dim, so we will end up copying anyway:
batched_vec(A::AbstractArray{T,3} where T, B::AdjOrTransAbsMat{<:BlasFloat, <:StridedMatrix}) =
    batched_vec(A, copy(B))
```

fixes the issue. I can make a quick PR if this looks right to you @mcabbott.
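As a quick CPU sanity check that reshaping into a batch of 1-column matrices preserves the semantics of `batched_vec`, it should agree with an explicit per-slice matvec (a minimal sketch, not taken from the thread):

```julia
using NNlib

A = randn(4, 3, 5)
x = randn(3, 5)

# batched_vec multiplies each slice A[:, :, k] by the column x[:, k]
y = NNlib.batched_vec(A, x)

# Reference: 5 small matvecs, concatenated column-by-column
y_loop = reduce(hcat, [A[:, :, k] * x[:, k] for k in 1:5])

@assert y ≈ y_loop
```

The benchmarks above suggest the two paths differ only in overhead (extra CPU allocations and dispatch work under Zygote), not in the result.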
That looks great, I can't picture why it wasn't written that way the first time.
Branching on the type of the second argument caused a subtle performance bug when differentiating via `Zygote`; see FluxML#462
Ran into a pretty confounding performance bug: `batched_vec` is performing much slower than `batched_mul`, and getting slower each time you call it. I see there was previously a similar issue in #282, but I can't tell if that is related, or if this bug is due to NNlib, Flux, Zygote, or some combination thereof... Here's an MWE:
Here is the output:
CUDA version info:
Project.toml:

Manifest.toml: