Skip to content

Commit

Permalink
Put the cartesian iterator in the type domain.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Feb 15, 2023
1 parent 8b26c7c commit 4072ea0
Showing 1 changed file with 34 additions and 13 deletions.
47 changes: 34 additions & 13 deletions src/host/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,50 @@ end
bc′ = Broadcast.preprocess(dest, bc)

# grid-stride kernel
function broadcast_kernel(ctx, dest, bc′, nelem)
i = 0
while i < nelem
i += 1
# HACK: cartesian iteration is slow, so avoid it if possible
# TODO: generalize this, much like how `eachindex` picks an appropriate iterator
# (::AnyGPUArray methods shouldn't be hardcoded to use linear indexing)
I = if isa(IndexStyle(dest), IndexLinear) && isa(IndexStyle(bc′), IndexLinear)
@linearidx(dest, i)
function broadcast_kernel(ctx, dest, ::Val{Is}, bc′, nelem) where Is
j = 0
while j < nelem
j += 1

i = @linearidx(dest, j)

# cartesian indexing is slow, so avoid it if possible
if isa(IndexStyle(dest), IndexCartesian) || isa(IndexStyle(bc′), IndexCartesian)
# this performs an integer division, which is expensive. to make it possible
# for the compiler to optimize it away, we put the iterator in the type
# domain so that the indices are available at compile time. note that LLVM
# only seems to replace pow2 divisions (with bitshifts), but other back-ends
# may be smarted and replace arbitrary divisions by bit operations.
#
# also see maleadt/StaticCartesian.jl, which implements this in Julia,
# but does not result in an additional speed-up on tested back-ends.
#
# XXX: why is this *faster* on Metal.jl when not using @inbounds?
I = #=@inbounds=# Is[i]
end

val = if isa(IndexStyle(bc′), IndexCartesian)
@inbounds bc′[I]
else
@inbounds bc′[i]
end

if isa(IndexStyle(dest), IndexCartesian)
@inbounds dest[I] = val
else
@cartesianidx(dest, i)
@inbounds dest[i] = val
end
@inbounds dest[I] = bc′[I]
end
return
end
elements = length(dest)
elements_per_thread = typemax(Int)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, bc′, 1;
Is = CartesianIndices(dest)
heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(Is), bc′, 1;
elements, elements_per_thread)
config = launch_configuration(backend(dest), heuristic;
elements, elements_per_thread)
gpu_call(broadcast_kernel, dest, bc′, config.elements_per_thread;
gpu_call(broadcast_kernel, dest, Val(Is), bc′, config.elements_per_thread;
threads=config.threads, blocks=config.blocks)

return dest
Expand Down

0 comments on commit 4072ea0

Please sign in to comment.