Avoid cartesian iteration where possible. #454
Conversation
This doesn't seem to help much, as most broadcast objects require cartesian indexing:

julia> IndexStyle(Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4))))
IndexCartesian()

An alternative is to provide the CartesianIndices as a constant, which makes the _copyto! implementation look like this:

@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc′ = Broadcast.preprocess(dest, bc)

    # grid-stride kernel
    function broadcast_kernel(ctx, dest, ::Val{idx}, bc′, nelem) where {idx}
        i = 0
        while i < nelem
            i += 1

            # the CartesianIndices are passed as a constant value,
            # to prevent expensive integer divisions on non-constant values
            j = @linearidx(dest, i)
            J = @inbounds idx[j]
            @inbounds dest[j] = bc′[J]
        end
        return
    end

    elements = length(dest)
    elements_per_thread = typemax(Int)

    idx = CartesianIndices(dest)
    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(idx), bc′, 1;
                                 elements, elements_per_thread)
    config = launch_configuration(backend(dest), heuristic;
                                  elements, elements_per_thread)
    gpu_call(broadcast_kernel, dest, Val(idx), bc′, config.elements_per_thread;
             threads=config.threads, blocks=config.blocks)

    return dest
end

This generates:

pass: ; preds = %L5
%51 = add i64 %50, -1
%52 = sdiv i64 %51, 2
%.neg = mul i64 %52, -2
%53 = add i64 %.neg, %50
%54 = mul i64 %12, %52
%55 = add i64 %53, -1
%56 = add i64 %55, %54
%57 = getelementptr inbounds float, float addrspace(1)* %13, i64 %56
%58 = load float, float addrspace(1)* %57, align 4
%59 = getelementptr inbounds float, float addrspace(1)* %.unpack9, i64 %51
store float %58, float addrspace(1)* %59, align 4
%.not = icmp slt i64 %48, %4
br i1 %.not, label %L5, label %common.ret
}

... instead of:

pass.us: ; preds = %L5.us
%32 = add i64 %31, -1
%33 = sdiv i64 %32, %.unpack5.unpack
%34 = mul i64 %33, %.unpack5.unpack
%35 = sub i64 %32, %34
%36 = add i64 %33, 1
%37 = select i1 %.not11, i64 %.fca.0.0.2.1.extract, i64 %36
%38 = add i64 %37, -1
%39 = mul i64 %13, %38
%40 = add i64 %35, %39
%41 = getelementptr inbounds float, float addrspace(1)* %14, i64 %40
%42 = load float, float addrspace(1)* %41, align 4
%43 = getelementptr inbounds float, float addrspace(1)* %.unpack9, i64 %32
store float %42, float addrspace(1)* %43, align 4
%.not.us = icmp slt i64 %29, %4
br i1 %.not.us, label %L5.us, label %common.ret
common.ret: ; preds = %pass.us, %L5.us, %pass.us.us, %L5.us.us, %L5.lr.ph.L5.lr.ph.split_crit_edge, %conversion
ret void
fail: ; preds = %L5.lr.ph.L5.lr.ph.split_crit_edge
call fastcc void @gpu_report_exception() #2
call fastcc void @gpu_signal_exception() #2
call void @llvm.trap()
unreachable
}

FWIW, a driver script to test this:

using Metal
function kernel_copy!(a, b)
    (i, j) = thread_position_in_grid_2d()
    @inbounds a[i, j] = b[i, j]
    return
end

function benchmark(n=2^14, nsample=10)
    test(n)

    function measure(f, name)
        a = MtlArray(rand(Float32, n, n))
        b = similar(a)
        ts = zeros(nsample)
        for i ∈ 1:nsample
            ts[i] = @elapsed Metal.@sync begin
                f(a, b)
            end
        end
        tmin = minimum(ts)
        size_in_bytes = 2 * length(a) * sizeof(Float32)  # 1R + 1W
        byte_per_ns = size_in_bytes / (tmin * 1.e9)
        println("$name performance: $(round(byte_per_ns; digits=3)) GB/s")
    end

    threads = (32, 32)
    grid_size = cld.(n, threads)
    measure("kernel") do a, b
        @metal threads=threads grid=grid_size kernel_copy!(a, b)
    end
    measure("broadcast") do a, b
        a .= b
    end
end

function test(n=2^14)
    a = MtlArray(rand(Float32, n, n))
    b = similar(a)

    threads = (32, 32)
    grid_size = cld.(n, threads)
    @metal threads=threads grid=grid_size kernel_copy!(a, b)
    @assert Array(a) == Array(b)

    b = similar(a)
    a .= b
    @assert Array(a) == Array(b)
end

function codegen()
    a = MtlArray(rand(Float32, 2, 2))
    b = MtlArray(rand(Float32, 2, 2))
    #@device_code_llvm debuginfo=:none @metal kernel_copy!(a, b)
    @device_code_llvm debuginfo=:none a .= b
end

On my M1 Pro, the kernel gives about 180 GB/s, the current broadcast does 10 GB/s, and the optimized one here does 24 GB/s.
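To make the underlying issue easy to poke at outside the GPU stack, here is a hedged CPU-side sketch (divide_runtime and divide_constant are hypothetical names, not part of this PR): linear indexing into a CartesianIndices goes through ind2sub, and only when the iterator lives in the type domain do the divisors become compile-time constants that LLVM can strength-reduce.

using InteractiveUtils

# divisor is a runtime value taken from the CartesianIndices object
divide_runtime(idx, j) = @inbounds idx[j]

# divisor comes from the type domain, so LLVM sees a literal constant
divide_constant(::Val{idx}, j) where {idx} = @inbounds idx[j]

idx = CartesianIndices((16383, 16383))
@code_llvm debuginfo=:none divide_runtime(idx, 123)
@code_llvm debuginfo=:none divide_constant(Val(idx), 123)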
... and (incorrectly) forcing the broadcast to use linear indexing all the way brings performance to 180 GB/s, so the problem still is the cartesian indexing. EDIT: but passing a volatile CartesianIndex to the kernel is fast, so the problem is squarely with the indexing of the CartesianIndices and not with its use in Broadcast. That's good news. I might have an idea for a fix.
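For reference, a minimal sketch of what the "(incorrectly) force linear indexing" experiment could look like, modelled on the broadcast_kernel from the patch above (the kernel name is made up, and this is only correct when dest and the Broadcasted happen to agree on shape and layout, hence "incorrectly"):

function broadcast_kernel_linear(ctx, dest, bc′, nelem)
    i = 0
    while i < nelem
        i += 1
        j = @linearidx(dest, i)
        # index the Broadcasted with the linear index directly,
        # skipping the linear-to-cartesian conversion entirely
        @inbounds dest[j] = bc′[j]
    end
    return
end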
Turns out that wasn't true. I implemented my idea for a fix at https://github.com/maleadt/StaticCartesian.jl; essentially, this not only puts the CartesianIndices iterator in the type domain (exposing the constant divisors to LLVM), but also implements the bit-twiddling optimizations I was talking about in Julia. This results in

In an attempt to improve this, I used the AIR intrinsic for mul_hi:

mulhi(x::Int32, y::Int32) = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
mulhi(x::UInt32, y::UInt32) = ccall("extern air.mul_hi.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
mulhi(x::Int64, y::Int64) = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
mulhi(x::UInt64, y::UInt64) = ccall("extern air.mul_hi.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
@device_override StaticCartesian.mulhi(x::T, y::T) where T <: Union{Int32,UInt32,Int64,UInt64} = mulhi(x, y)

That only brought performance to 30 GB/s, still way too low. Disappointingly, going back and "just" putting the original CartesianIndices iterator in the type domain (thus emitting

However, in trying all this, I noticed something very weird: removing the

Or maybe the measurements here are off; I wish we had a decent profiler...
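For context, the bit-twiddling optimization referred to here is the standard rewrite of division by a fixed constant into a multiply-high plus shift. A small self-contained sketch, using the textbook magic number for dividing a UInt32 by 5 (illustrative only, not code from StaticCartesian.jl):

# floor(x / 5) == (x * ceil(2^34 / 5)) >> 34 for all UInt32 x
function div_by_5(x::UInt32)
    m = 0xcccccccd                # ceil(2^34 / 5)
    hi = (UInt64(x) * m) >> 32    # this is exactly what mul_hi(x, m) computes
    return UInt32(hi >> 2)        # remaining shift, for a total of 34 bits
end

@assert all(div_by_5(x) == x ÷ 5 for x in UInt32[0, 1, 4, 5, 24, 25, 0xffffffff])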
Looks like with

slow (no bounds-check, sdiv):

; @ /Users/tim/Julia/pkg/GPUArrays/src/host/broadcast.jl:74 within `broadcast_kernel`
; ┌ @ abstractarray.jl:1241 within `getindex`
; │┌ @ abstractarray.jl:1286 within `_getindex`
; ││┌ @ abstractarray.jl:1293 within `_to_subscript_indices`
; │││┌ @ abstractarray.jl:1315 within `_unsafe_ind2sub`
; ││││┌ @ abstractarray.jl:2639 within `_ind2sub` @ abstractarray.jl:2677
; │││││┌ @ int.jl:86 within `-`
%31 = add nsw i64 %29, -1
; │││││└
; │││││┌ @ abstractarray.jl:2690 within `_ind2sub_recurse`
; ││││││┌ @ abstractarray.jl:2697 within `_div`
; │││││││┌ @ int.jl:288 within `div`
%32 = sdiv i64 %31, 16383
; └└└└└└└└

fast (bounds-check, udiv):

; @ /Users/tim/Julia/pkg/GPUArrays/src/host/broadcast.jl:74 within `broadcast_kernel`
; ┌ @ abstractarray.jl:1241 within `getindex`
; │┌ @ abstractarray.jl:1285 within `_getindex`
; ││┌ @ abstractarray.jl:668 within `checkbounds` @ abstractarray.jl:653
; │││┌ @ abstractarray.jl:727 within `checkindex`
; ││││┌ @ bool.jl:38 within `&`
%.off.us = add nsw i64 %29, -1
%31 = icmp ugt i64 %.off.us, 268402688
; │││└└
; │││ @ abstractarray.jl:668 within `checkbounds`
br i1 %31, label %L52, label %pass.us
pass.us: ; preds = %L25.us
; ││└
; ││ @ abstractarray.jl:1286 within `_getindex`
; ││┌ @ abstractarray.jl:1293 within `_to_subscript_indices`
; │││┌ @ abstractarray.jl:1315 within `_unsafe_ind2sub`
; ││││┌ @ abstractarray.jl:2639 within `_ind2sub` @ abstractarray.jl:2677
; │││││┌ @ abstractarray.jl:2690 within `_ind2sub_recurse`
; ││││││┌ @ abstractarray.jl:2697 within `_div`
; │││││││┌ @ int.jl:288 within `div`
%.lhs.trunc.us = trunc i64 %.off.us to i32
%32 = udiv i32 %.lhs.trunc.us, 16383
%.zext.us = zext i32 %32 to i64
; └└└└└└└└
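The effect is reproducible on the CPU as well; a hedged sketch (the function names are made up, and 268_402_689 = 16383^2 matches the bound in the icmp above): once the branch establishes that the operand is non-negative and small, LLVM's correlated-value propagation is allowed to turn the 64-bit sdiv into a narrower unsigned division, as in the "fast" IR.

using InteractiveUtils

div_unchecked(i::Int) = (i - 1) ÷ 16383

function div_checked(i::Int)
    # the explicit bounds check hands LLVM a value range for i
    0 < i <= 268_402_689 || throw(BoundsError())
    return (i - 1) ÷ 16383
end

@code_llvm debuginfo=:none div_unchecked(123)
@code_llvm debuginfo=:none div_checked(123)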
Seems like the div is a red herring, forcing

EDIT: ah, adding the upper bound too makes it fast 🚀
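One way to hand LLVM both bounds without paying for an actual bounds check is an llvm.assume; the following is a generic llvmcall sketch of that idea (an assumption on my part, not necessarily what this PR ends up doing):

# tell LLVM that cond holds at this point; behaviour is undefined if it does not
@inline function assume(cond::Bool)
    Base.llvmcall(("declare void @llvm.assume(i1)",
        """
        %cond = trunc i8 %0 to i1
        call void @llvm.assume(i1 %cond)
        ret void
        """), Cvoid, Tuple{Bool}, cond)
end

function div_assumed(i::Int)
    assume(0 < i <= 268_402_689)   # lower *and* upper bound, as noted above
    return (i - 1) ÷ 16383
end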
That's really cool. Indeed the backend generates bit-twiddling code (obtained using https://github.com/dougallj/applegpu) for these kernels: https://gist.github.com/maxwindiff/a1850531f72c20ff5c922ac3743f2093

kernel void div2(device uint *a) { *a /= 2; }
kernel void div3(device uint *a) { *a /= 3; }
kernel void div5(device uint *a) { *a /= 5; }
kernel void div7(device uint *a) { *a /= 7; }
kernel void div11(device uint *a) { *a /= 11; }
kernel void div13(device uint *a) { *a /= 13; }
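For comparison, a hedged sketch of inspecting the same constant-divisor case from the Julia side, mirroring the codegen() helper earlier in the thread (div7! is a made-up kernel; note that @device_code_llvm shows the LLVM IR Metal.jl hands to the backend, where the division is still a plain udiv; the multiply/shift rewrite happens later in Apple's compiler, which is what the disassembled kernels in the gist show):

using Metal

function div7!(a)
    i = thread_position_in_grid_1d()
    @inbounds a[i] = a[i] ÷ UInt32(7)   # divisor is a literal constant inside the kernel
    return
end

a = MtlArray(UInt32.(1:64))
@device_code_llvm debuginfo=:none @metal threads=64 div7!(a)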
Cool, I didn't know about that disassembler! We should integrate that with Metal.jl. Did you use it with Metal C code, or how did you get a binary dump of the code generated by Julia? I'm working on wrapping MtlBinaryArchive, but that's a fair bit of work, so I am wondering if I missed something.
I used Metal C + Xcode. I was thinking of trying to add it to Metal.jl but you are too fast 😂
No need for that anymore: JuliaGPU/Metal.jl#96 🎉
This may have broken Flux, FluxML/Flux.jl#2214. I will try to create an MWE next week if nobody gets to it first.
It also seems to cause a huge performance issue with Transformers. I don't have an MWE yet, but it looks like the kernel is being recompiled over and over again. Most of the time is spent on the CPU; the GPU is barely running.
@vchuravy Does KA.jl handle iterator selection better?