diff --git a/Project.toml b/Project.toml index e60e8d562f..d1cb55b330 100644 --- a/Project.toml +++ b/Project.toml @@ -67,7 +67,7 @@ LinearAlgebra = "1" LazyBroadcast = "0.1" Logging = "1" MPI = "0.20" -MultiBroadcastFusion = "0.3" +MultiBroadcastFusion = "0.3, 0.4" NVTX = "0.3" OrderedCollections = "1" PkgVersion = "0.1, 0.2, 0.3" diff --git a/ext/cuda/data_layouts_fused_copyto.jl b/ext/cuda/data_layouts_fused_copyto.jl index 518e81d3f7..e8c170ceeb 100644 --- a/ext/cuda/data_layouts_fused_copyto.jl +++ b/ext/cuda/data_layouts_fused_copyto.jl @@ -72,7 +72,9 @@ function knl_fused_copyto_linear!(fmbc::FusedMultiBroadcast, us) end return nothing end - +import MultiBroadcastFusion +const MBFCUDA = + Base.get_extension(MultiBroadcastFusion, :MultiBroadcastFusionCUDAExt) # https://github.com/JuliaLang/julia/issues/56295 # Julia 1.11's Base.Broadcast currently requires # multiple integer indexing, wheras Julia 1.10 did not. @@ -80,48 +82,73 @@ end # special-case fixes for https://github.com/JuliaLang/julia/issues/28126 # (including the GPU-variant related issue resolution efforts: # JuliaGPU/GPUArrays.jl#454, JuliaGPU/GPUArrays.jl#464). + +function fused_multibroadcast_args(fmb::FusedMultiBroadcast) + dest = first(fmb.pairs).first + us = DataLayouts.UniversalSize(dest) + return (fmb, us) +end + +import MultiBroadcastFusion function fused_copyto!( - fmbc::FusedMultiBroadcast, + fmb::FusedMultiBroadcast, dest1::DataLayouts.AbstractData, ::ToCUDA, ) (_, _, Nv, _, Nh) = DataLayouts.universal_size(dest1) - if Nv > 0 && Nh > 0 - bcs = map(p -> p.second, fmbc.pairs) - destinations = map(p -> p.first, fmbc.pairs) - if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) && - all(d -> d isa DataLayouts.EndsWithField, destinations) && - !(VERSION ≥ v"1.11.0-beta") - pairs′ = map(fmbc.pairs) do p - bc′ = DataLayouts.to_non_extruded_broadcasted(p.second) - Pair(p.first, Base.Broadcast.instantiate(bc′)) - end - us = DataLayouts.UniversalSize(dest1) - fmbc′ = FusedMultiBroadcast(pairs′) - args = (fmbc′, us) - threads = threads_via_occupancy(knl_fused_copyto_linear!, args) - n_max_threads = min(threads, get_N(us)) - p = linear_partition(prod(size(dest1)), n_max_threads) - auto_launch!( - knl_fused_copyto_linear!, - args; - threads_s = p.threads, - blocks_s = p.blocks, - always_inline = false, - ) - else - us = DataLayouts.UniversalSize(dest1) - args = (fmbc, dest1, us) - threads = threads_via_occupancy(knl_fused_copyto!, args) - n_max_threads = min(threads, get_N(us)) - p = partition(dest1, n_max_threads) - auto_launch!( - knl_fused_copyto!, - args; - threads_s = p.threads, - blocks_s = p.blocks, - ) + (Nv > 0 && Nh > 0) || return nothing # short circuit + + if pkgversion(MultiBroadcastFusion) >= v"0.3.3" + # Automatically split kernels by available parameter memory space: + fmbs = MBFCUDA.partition_kernels( + fmb, + FusedMultiBroadcast, + fused_multibroadcast_args, + ) + for fmb in fmbs + launch_fused_copyto!(fmb) + end + else + launch_fused_copyto!(fmb) + end + return nothing +end + +function launch_fused_copyto!(fmb::FusedMultiBroadcast) + dest1 = first(fmb.pairs).first + us = DataLayouts.UniversalSize(dest1) + destinations = map(p -> p.first, fmb.pairs) + bcs = map(p -> p.second, fmb.pairs) + if all(bc -> DataLayouts.has_uniform_datalayouts(bc), bcs) && + all(d -> d isa DataLayouts.EndsWithField, destinations) && + !(VERSION ≥ v"1.11.0-beta") + pairs′ = map(fmb.pairs) do p + bc′ = DataLayouts.to_non_extruded_broadcasted(p.second) + Pair(p.first, Base.Broadcast.instantiate(bc′)) end + fmb′ = FusedMultiBroadcast(pairs′) + args = (fmb′, us) + threads = threads_via_occupancy(knl_fused_copyto_linear!, args) + n_max_threads = min(threads, get_N(us)) + p = linear_partition(prod(size(dest1)), n_max_threads) + auto_launch!( + knl_fused_copyto_linear!, + args; + threads_s = p.threads, + blocks_s = p.blocks, + always_inline = false, + ) + else + args = (fmb, dest1, us) + threads = threads_via_occupancy(knl_fused_copyto!, args) + n_max_threads = min(threads, get_N(us)) + p = partition(dest1, n_max_threads) + auto_launch!( + knl_fused_copyto!, + args; + threads_s = p.threads, + blocks_s = p.blocks, + ) end return nothing end