Skip to content

Commit

Permalink
Merge pull request #2075 from CliMA/ck/auto_kernel_splitting
Browse files Browse the repository at this point in the history
Automatically split fused kernels by parameter memory limits
  • Loading branch information
charleskawczynski authored Nov 8, 2024
2 parents d0a9f9d + 5b200f5 commit 99132e0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ LinearAlgebra = "1"
LazyBroadcast = "0.1"
Logging = "1"
MPI = "0.20"
MultiBroadcastFusion = "0.3"
MultiBroadcastFusion = "0.3, 0.4"
NVTX = "0.3"
OrderedCollections = "1"
PkgVersion = "0.1, 0.2, 0.3"
Expand Down
101 changes: 64 additions & 37 deletions ext/cuda/data_layouts_fused_copyto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,56 +72,83 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us)
end
return nothing
end

import MultiBroadcastFusion
const MBFCUDA =
Base.get_extension(MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt)
# https://github.com/JuliaLang/julia/issues/56295
# Julia 1.11's Base.Broadcast currently requires
# multiple integer indexing, wheras Julia 1.10 did not.
# This means that we cannot reserve linear indexing to
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).

function fused_multibroadcast_args(fmb::FusedMultiBroadcast)
dest = first(fmb.pairs).first
us = DataLayouts.UniversalSize(dest)
return (fmb, us)
end

import MultiBroadcastFusion
function fused_copyto!(
fmbc::FusedMultiBroadcast,
fmb::FusedMultiBroadcast,
dest1::DataLayouts.AbstractData,
::ToCUDA,
)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest1)
if Nv > 0 && Nh > 0
bcs = map(p -> p.second, fmbc.pairs)
destinations = map(p -> p.first, fmbc.pairs)
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
!(VERSION v"1.11.0-beta")
pairs′ = map(fmbc.pairs) do p
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
Pair(p.first, Base.Broadcast.instantiate(bc′))
end
us = DataLayouts.UniversalSize(dest1)
fmbc′ = FusedMultiBroadcast(pairs′)
args = (fmbc′, us)
threads = threads_via_occupancy(knl_fused_copyto_linear!, args)
n_max_threads = min(threads, get_N(us))
p = linear_partition(prod(size(dest1)), n_max_threads)
auto_launch!(
knl_fused_copyto_linear!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
always_inline = false,
)
else
us = DataLayouts.UniversalSize(dest1)
args = (fmbc, dest1, us)
threads = threads_via_occupancy(knl_fused_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = partition(dest1, n_max_threads)
auto_launch!(
knl_fused_copyto!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
(Nv > 0 && Nh > 0) || return nothing # short circuit

if pkgversion(MultiBroadcastFusion) >= v"0.3.3"
# Automatically split kernels by available parameter memory space:
fmbs = MBFCUDA.partition_kernels(
fmb,
FusedMultiBroadcast,
fused_multibroadcast_args,
)
for fmb in fmbs
launch_fused_copyto!(fmb)
end
else
launch_fused_copyto!(fmb)
end
return nothing
end

function launch_fused_copyto!(fmb::FusedMultiBroadcast)
dest1 = first(fmb.pairs).first
us = DataLayouts.UniversalSize(dest1)
destinations = map(p -> p.first, fmb.pairs)
bcs = map(p -> p.second, fmb.pairs)
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
!(VERSION v"1.11.0-beta")
pairs′ = map(fmb.pairs) do p
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
Pair(p.first, Base.Broadcast.instantiate(bc′))
end
fmb′ = FusedMultiBroadcast(pairs′)
args = (fmb′, us)
threads = threads_via_occupancy(knl_fused_copyto_linear!, args)
n_max_threads = min(threads, get_N(us))
p = linear_partition(prod(size(dest1)), n_max_threads)
auto_launch!(
knl_fused_copyto_linear!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
always_inline = false,
)
else
args = (fmb, dest1, us)
threads = threads_via_occupancy(knl_fused_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = partition(dest1, n_max_threads)
auto_launch!(
knl_fused_copyto!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
end
return nothing
end

0 comments on commit 99132e0

Please sign in to comment.