Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatically split fused kernels by parameter memory limits #2075

Merged
merged 2 commits into from
Nov 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -67,7 +67,7 @@ LinearAlgebra = "1"
LazyBroadcast = "0.1"
Logging = "1"
MPI = "0.20"
MultiBroadcastFusion = "0.3"
MultiBroadcastFusion = "0.3, 0.4"
NVTX = "0.3"
OrderedCollections = "1"
PkgVersion = "0.1, 0.2, 0.3"
101 changes: 64 additions & 37 deletions ext/cuda/data_layouts_fused_copyto.jl
Original file line number Diff line number Diff line change
@@ -72,56 +72,83 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us)
end
return nothing
end

import MultiBroadcastFusion
const MBFCUDA =
Base.get_extension(MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt)
# https://github.com/JuliaLang/julia/issues/56295
# Julia 1.11's Base.Broadcast currently requires
# multiple integer indexing, whereas Julia 1.10 did not.
# This means that we cannot reserve linear indexing to
# special-case fixes for https://github.com/JuliaLang/julia/issues/28126
# (including the GPU-variant related issue resolution efforts:
# JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464).

"""
    fused_multibroadcast_args(fmb::FusedMultiBroadcast)

Build the argument tuple `(fmb, us)` handed to the fused-copyto kernels,
where `us` is the `DataLayouts.UniversalSize` of the first destination in
`fmb.pairs`.
"""
function fused_multibroadcast_args(fmb::FusedMultiBroadcast)
    first_dest = first(fmb.pairs).first
    return (fmb, DataLayouts.UniversalSize(first_dest))
end

import MultiBroadcastFusion
# NOTE(review): this span is unified-diff residue — the removed (pre-PR) body and
# the added (post-PR) body of `fused_copyto!` are interleaved without +/- markers,
# so the text below is NOT valid Julia as written (the old `if` blocks at the two
# `if` sites lose their closing `end`s). The post-PR function consists of the
# signature with the `fmb` parameter plus the lines from the short-circuit return
# onward; the old inline launch logic was moved into `launch_fused_copyto!`.
# Reconstruct from the PR before editing — TODO confirm against the merged file.
function fused_copyto!(
# Old (removed) parameter name:
fmbc::FusedMultiBroadcast,
# New (added) parameter name:
fmb::FusedMultiBroadcast,
dest1::DataLayouts.AbstractData,
::ToCUDA,
)
(_, _, Nv, _, Nh) = DataLayouts.universal_size(dest1)
# --- begin old (removed) body: inline launch logic, superseded below ---
if Nv > 0 && Nh > 0
bcs = map(p -> p.second, fmbc.pairs)
destinations = map(p -> p.first, fmbc.pairs)
if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) &&
all(d -> d isa DataLayouts.EndsWithField, destinations) &&
!(VERSION ≥ v"1.11.0-beta")
pairs′ = map(fmbc.pairs) do p
bc′ = DataLayouts.to_non_extruded_broadcasted(p.second)
Pair(p.first, Base.Broadcast.instantiate(bc′))
end
us = DataLayouts.UniversalSize(dest1)
fmbc′ = FusedMultiBroadcast(pairs′)
args = (fmbc′, us)
threads = threads_via_occupancy(knl_fused_copyto_linear!, args)
n_max_threads = min(threads, get_N(us))
p = linear_partition(prod(size(dest1)), n_max_threads)
auto_launch!(
knl_fused_copyto_linear!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
always_inline = false,
)
else
us = DataLayouts.UniversalSize(dest1)
args = (fmbc, dest1, us)
threads = threads_via_occupancy(knl_fused_copyto!, args)
n_max_threads = min(threads, get_N(us))
p = partition(dest1, n_max_threads)
auto_launch!(
knl_fused_copyto!,
args;
threads_s = p.threads,
blocks_s = p.blocks,
)
# --- end old (removed) body; new (added) body follows ---
(Nv > 0 && Nh > 0) || return nothing # short circuit

if pkgversion(MultiBroadcastFusion) >= v"0.3.3"
# Automatically split kernels by available parameter memory space:
fmbs = MBFCUDA.partition_kernels(
fmb,
FusedMultiBroadcast,
fused_multibroadcast_args,
)
for fmb in fmbs
launch_fused_copyto!(fmb)
end
else
# Older MultiBroadcastFusion has no partitioner; launch the single fused kernel.
launch_fused_copyto!(fmb)
end
return nothing
end

"""
    launch_fused_copyto!(fmb::FusedMultiBroadcast)

Launch one fused-copyto CUDA kernel for `fmb`. The linear-indexed kernel
(`knl_fused_copyto_linear!`) is used only when every broadcasted object has
uniform data layouts, every destination `isa DataLayouts.EndsWithField`, and
the Julia version predates 1.11; otherwise the Cartesian-indexed kernel
(`knl_fused_copyto!`) is launched.
"""
function launch_fused_copyto!(fmb::FusedMultiBroadcast)
    first_dest = first(fmb.pairs).first
    us = DataLayouts.UniversalSize(first_dest)
    dests = map(pair -> pair.first, fmb.pairs)
    broadcasted = map(pair -> pair.second, fmb.pairs)
    can_index_linearly =
        all(bc -> DataLayouts.has_uniform_datalayouts(bc), broadcasted) &&
        all(d -> d isa DataLayouts.EndsWithField, dests) &&
        !(VERSION ≥ v"1.11.0-beta")
    if can_index_linearly
        # Strip extruded wrappers and re-instantiate so the kernel may index linearly.
        instantiated = map(fmb.pairs) do pair
            stripped = DataLayouts.to_non_extruded_broadcasted(pair.second)
            Pair(pair.first, Base.Broadcast.instantiate(stripped))
        end
        kernel_args = (FusedMultiBroadcast(instantiated), us)
        occupancy = threads_via_occupancy(knl_fused_copyto_linear!, kernel_args)
        lin = linear_partition(prod(size(first_dest)), min(occupancy, get_N(us)))
        auto_launch!(
            knl_fused_copyto_linear!,
            kernel_args;
            threads_s = lin.threads,
            blocks_s = lin.blocks,
            always_inline = false,
        )
    else
        kernel_args = (fmb, first_dest, us)
        occupancy = threads_via_occupancy(knl_fused_copyto!, kernel_args)
        part = partition(first_dest, min(occupancy, get_N(us)))
        auto_launch!(
            knl_fused_copyto!,
            kernel_args;
            threads_s = part.threads,
            blocks_s = part.blocks,
        )
    end
    return nothing
end