-
-
Notifications
You must be signed in to change notification settings - Fork 124
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reduce BLAS threads while parallelizing over GEMM #580
Conversation
Judge resultBenchmark Report for /home/runner/work/FluxMLBenchmarks.jl/FluxMLBenchmarks.jl/benchmark/script/..Job Properties
ResultsA ratio greater than
Benchmark Group ListHere's a list of all the benchmark groups executed by this job:
Julia versioninfoTarget
Baseline
Target resultBenchmark Report for /home/runner/work/FluxMLBenchmarks.jl/FluxMLBenchmarks.jl/benchmark/script/..Job Properties
ResultsBelow is a table of this job's results, obtained by running the benchmarks.
Benchmark Group ListHere's a list of all the benchmark groups executed by this job:
Julia versioninfo
Baseline resultBenchmark Report for /home/runner/work/FluxMLBenchmarks.jl/FluxMLBenchmarks.jl/benchmark/script/..Job Properties
ResultsBelow is a table of this job's results, obtained by running the benchmarks.
Benchmark Group ListHere's a list of all the benchmark groups executed by this job:
Julia versioninfo
Runtime information
|
This kind of reporting is way too lengthy. @avik-pal what do you think, should we merge? |
Seems like the opposite of what I expected, it seems to slow down the |
We've thought about doing this before even if it doesn't result in perf gains, just to make multithreading less unwieldy with Flux. The main challenge has always been that Is there really no way to disable multithreading for just the context of a conv routine? Perhaps some lower-level call we can make which is guaranteed to run single-threaded? |
Is there a specific graph we should look at? Just looking for apples-to-apples, I think Since we haven't used this benchmark tool recently, I did just double check that it checked out and ran the right versions of the baseline and target. It seems to all be correct (you can check the raw action log here: https://github.com/FluxML/FluxMLBenchmarks.jl/actions/runs/8889259577/job/24407415613). |
Try this script. using Lux, Random
import Flux, Metalhead
using Zygote
using UnicodePlots
using LinearAlgebra
using ThreadPinning
pinthreads(:cores)
BLAS.set_num_threads(min(4, Threads.nthreads()))
@info "BLAS Threads: $(BLAS.get_num_threads())"
threadinfo()
versioninfo()
flux_model = Metalhead.VGG(19)
lux_model = FromFluxAdaptor()(flux_model.layers);
ps, st = Lux.setup(Xoshiro(), lux_model);
st_test = Lux.testmode(st);
bsizes = 2 .^ (0:8)
lux_timings = zeros(Float64, length(bsizes))
flux_timings = zeros(Float64, length(bsizes))
for (i, bsize) in enumerate(bsizes)
x_input = rand(Float32, 224, 224, 3, bsize)
lux_timings[i] = @belapsed $lux_model($x_input, $ps, $st_test)
flux_timings[i] = @belapsed $flux_model($x_input)
@info "Batch size: $bsize" Lux=lux_timings[i] Flux=flux_timings[i] ratio=(lux_timings[i] /
flux_timings[i])
end
display(lineplot(bsizes, hcat(lux_timings, flux_timings); name=["Lux" "Flux"], color=[:blue :red]))
bsizes = 2 .^ (0:8)
lux_backward_timings = zeros(Float64, length(bsizes))
flux_backward_timings = zeros(Float64, length(bsizes))
f1 = (m, x) -> sum(abs2, m(x))
f2 = (m, x, ps, st) -> sum(abs2, first(m(x, ps, st)))
for (i, bsize) in enumerate(bsizes)
x_input = rand(Float32, 224, 224, 3, bsize)
lux_backward_timings[i] = @belapsed Zygote.gradient(
f2, $lux_model, $x_input, $ps, $st)
flux_backward_timings[i] = @belapsed Zygote.gradient(
f1, $flux_model, $x_input)
@info "Batch size: $bsize" Lux=lux_backward_timings[i] Flux=flux_backward_timings[i] ratio=(lux_backward_timings[i] /
flux_backward_timings[i])
end
display(lineplot(bsizes, hcat(lux_backward_timings, flux_backward_timings);
name=["Lux" "Flux"], color=[:blue :red])) [ Info: BLAS Threads: 4
System: 64 cores (2-way SMT), 2 sockets, 2 NUMA domains
| 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,
16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,
64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,
80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95 |
| 32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,
48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,
96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,
112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127 |
# = Julia thread, # = HT, # = Julia thread on HT, | = Socket seperator
Julia threads: 16
├ Occupied CPU-threads: 16
└ Mapping (Thread => CPUID): 1 => 0, 2 => 1, 3 => 2, 4 => 3, 5 => 4, ...
Julia Version 1.10.2
Commit bd47eca2c8a (2024-03-01 10:14 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 128 virtual cores) For the Forward Pass For the Backward Pass |
Can someone add the benchmark label?