Merge pull request #52 from CliMA/ck/fix_kernel_splitting
Dynamically split kernels based on parameter memory
charleskawczynski authored Nov 1, 2024
2 parents 041fdee + 6cb6d4b commit 5633389
Showing 8 changed files with 331 additions and 33 deletions.
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
@@ -5,7 +5,6 @@ agents:

env:
  JULIA_LOAD_PATH: "${JULIA_LOAD_PATH}:${BUILDKITE_BUILD_CHECKOUT_PATH}/.buildkite"
-  JULIA_DEPOT_PATH: "${BUILDKITE_BUILD_PATH}/${BUILDKITE_PIPELINE_SLUG}/depot/default"
  JULIA_MAX_NUM_PRECOMPILE_FILES: 100
  JULIA_CPU_TARGET: 'broadwell;skylake'
  JULIA_NVTX_CALLBACKS: gc
82 changes: 67 additions & 15 deletions ext/MultiBroadcastFusionCUDAExt.jl
@@ -6,23 +6,75 @@ import MultiBroadcastFusion: fused_copyto!

MBF.device(x::CUDA.CuArray) = MBF.MBF_CUDA()

include("parameter_memory.jl")

"""
partition_kernels(fmb;
fused_broadcast_constructor = MBF.FusedMultiBroadcast,
args_func::Function =
)
Splits fused broadcast kernels into a vector
of kernels, based on parameter memory limitations.
We first attempt to fuse
1:N, 1:N-1, 1:N-2, ... until we fuse 1:N-k
Next, we attempt to fuse
N-k+1:N, N-k+1:N-1, N-k+1:N-2, ...
And so forth.
"""
function partition_kernels(
    fmb,
    fused_broadcast_constructor = MBF.FusedMultiBroadcast,
    args_func::Function = fused_multibroadcast_args,
)
    plim = get_param_lim()
    usage = param_usage_args(args_func(fmb))
    n_bins = 1
    fmbs = (fmb,)
    usage ≤ plim && return fmbs
    fmbs_split = []
    N = length(fmb.pairs)
    i_start = 1
    i_stop = N
    while i_stop ≥ i_start
        ith_pairs = fmb.pairs[i_start:i_stop]
        ith_fmb = fused_broadcast_constructor(ith_pairs)
        if param_usage_args(args_func(ith_fmb)) ≤ plim # the fully fused attempt (1:N) will likely exceed the limit
            push!(fmbs_split, ith_fmb)
            i_stop == N && break
            i_start = i_stop + 1 # N on first iteration
            i_stop = N # reset i_stop
        else
            i_stop = i_stop - 1
        end
    end
    return fmbs_split
end
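For intuition, the greedy strategy can be sketched standalone. This is not part of the commit: `greedy_partition`, `cost`, and `lim` are hypothetical stand-ins for the pair count, `param_usage_args`, and the device limit.

```julia
# Toy model of the greedy split above: shrink 1:N from the right until the
# group fits under the limit, emit it, then restart from the next index.
function greedy_partition(n::Int, cost, lim)
    groups = UnitRange{Int}[]
    i_start, i_stop = 1, n
    while i_stop >= i_start
        if cost(i_start:i_stop) <= lim
            push!(groups, i_start:i_stop)
            i_stop == n && break
            i_start = i_stop + 1  # start the next group
            i_stop = n            # and again try to fuse through N
        else
            i_stop -= 1           # group too large: drop the last kernel
        end
    end
    return groups
end

# 10 kernels, each "costing" 3 units, limit 7 → groups of at most 2:
greedy_partition(10, r -> 3 * length(r), 7)  # [1:2, 3:4, 5:6, 7:8, 9:10]
```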

function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
-    (; pairs) = fmb
-    dest = first(pairs).first
-    destinations = map(p -> p.first, pairs)
-    all(a -> axes(a) == axes(dest), destinations) ||
-        error("Cannot fuse broadcast expressions with unequal broadcast axes")
-    nitems = length(parent(dest))
-    CI = CartesianIndices(axes(dest))
-    kernel =
-        CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
-            fmb,
-            CI,
-        )
-    config = CUDA.launch_configuration(kernel.fun)
-    threads = min(nitems, config.threads)
-    blocks = cld(nitems, threads)
-    kernel(fmb, CI; threads, blocks)
+    destinations = map(p -> p.first, fmb.pairs)
+    fmbs = partition_kernels(fmb)
+    for fmb in fmbs
+        (; pairs) = fmb
+        dest = first(pairs).first
+        dests = map(p -> p.first, pairs)
+        all(a -> axes(a) == axes(dest), dests) || error(
+            "Cannot fuse broadcast expressions with unequal broadcast axes",
+        )
+        nitems = length(parent(dest))
+        CI = CartesianIndices(axes(dest))
+        kernel =
+            CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
+                fmb,
+                CI,
+            )
+        config = CUDA.launch_configuration(kernel.fun)
+        threads = min(nitems, config.threads)
+        blocks = cld(nitems, threads)
+        kernel(fmb, CI; threads, blocks)
+    end
    return destinations
end
import Base.Broadcast
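With partitioning in place, an oversized fused broadcast should run as several smaller kernels instead of erroring. A hedged end-to-end sketch (array sizes and variable names are illustrative; `@fused_direct` is the package's fusion entry point):

```julia
import MultiBroadcastFusion as MBF
import CUDA

x = CUDA.rand(1000)
y = similar(x)
z = similar(x)
# If the fused kernel's arguments exceed the parameter-memory limit,
# fused_copyto! now launches several smaller kernels transparently.
MBF.@fused_direct begin
    @. y = x + 1
    @. z = x * 2
end
```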
168 changes: 168 additions & 0 deletions ext/parameter_memory.jl
@@ -0,0 +1,168 @@
function get_param_lim()
config = CUDA.compiler_config(CUDA.device())
(; ptx, cap) = config.params
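    # Devices of compute capability ≥ 7.0, targeted at PTX ISA ≥ 8.1
    # (CUDA 12.1+), allow 32 KiB of kernel parameter memory; earlier
    # combinations are limited to the classic 4 KiB.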
return cap >= v"7.0" && ptx >= v"8.1" ? 32764 : 4096
end
param_usage(arg) = sizeof(typeof(CUDA.cudaconvert(arg)))
param_usage_args(args) =
sum(x -> param_usage(x), args) + param_usage(CUDA.KernelState)

function fused_multibroadcast_args(fmb::MBF.FusedMultiBroadcast)
dest = first(fmb.pairs).first
CI = CartesianIndices(axes(dest))
return (fmb, CI)
end
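As a hedged illustration of what `param_usage` measures: `cudaconvert` turns a `CuArray` into a small, fixed-size device-side struct (pointer plus dims), so a kernel's parameter usage grows with the number of arguments, not with array lengths.

```julia
import CUDA

x = CUDA.zeros(Float32, 1024, 1024)
# The device-side representation (a CuDeviceArray) is a few tens of bytes,
# independent of the ~4 MiB of element data behind it.
sizeof(typeof(CUDA.cudaconvert(x)))
```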

"""
Options(
[types...];
match_only::Bool = false
print_types::Bool = false
recursion_types = (UnionAll,DataType)
recursion_depth = 1000
)
Printing options for `@rprint_parameter_memory`:
- `match_only`: only print properties that match the given types
- `print_types`: print types (e.g., `prop::typeof(prop)`)
- `recursion_types`: skip recursing through recursion types (e.g., `UnionAll` and `DataType`)
to avoid infinite recursion
- `recursion_depth`: limit recursion depth (to avoid infinite recursion)
"""
struct Options{T}
types::T
match_only::Bool
print_types::Bool
recursion_types::Tuple
recursion_depth::Int
size_threshhold::Int
max_type_depth::Int
function Options(
types...;
match_only = false,
print_types = true,
recursion_types = (UnionAll, DataType),
recursion_depth = 1000,
size_threshhold = 10,
max_type_depth = 1,
)
if (types isa AbstractArray || types isa Tuple) && length(types) > 0
types = types[1]
else
types = (Union{},)
end
return new{typeof(types)}(
types,
match_only,
print_types,
recursion_types,
recursion_depth,
size_threshhold,
max_type_depth,
)
end
end
Options(type::Type; kwargs...) = Options((type,); kwargs...)

Options() = Options(())
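A hedged usage sketch, assuming CUDA.jl is loaded:

```julia
import CUDA
# Report only CuArray-valued properties, and print each with its
# (depth-limited) type.
o = Options(CUDA.CuArray; match_only = true, print_types = true)
```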

function type_string(io, obj; maxdepth)
sz = get(io, :displaysize, displaysize(io))::Tuple{Int, Int}
S = max(sz[2], 120)
slim = Base.type_depth_limit(string(typeof(obj)), S; maxdepth)
return slim
end
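A hedged illustration of the depth limiting, assuming `Base.type_depth_limit` behaves as in recent Julia versions:

```julia
io = IOContext(stdout, :displaysize => (24, 120))
# Nested type parameters past `maxdepth` are elided with `…`:
println(type_string(io, ((1,),); maxdepth = 2))  # e.g. "Tuple{Tuple{…}}"
```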

function _rprint_parameter_memory(io, obj, pc; o::Options, name, counter = 0)
counter > o.recursion_depth && return
for pn in propertynames(obj)
prop = getproperty(obj, pn)
pc_full = (pc..., ".", pn)
pc_string = name * string(join(pc_full))
if any(map(type -> prop isa type, o.types))
suffix =
o.print_types ?
"::$(type_string(io, prop; maxdepth=o.max_type_depth))" : ""
s = sizeof(typeof(CUDA.cudaconvert(prop)))
if s > o.size_threshhold
println(io, "size: $s, $pc_string$suffix")
end
if !any(map(x -> prop isa x, o.recursion_types))
_rprint_parameter_memory(
io,
prop,
pc_full;
o,
name,
counter = counter + 1,
)
counter > o.recursion_depth && return
end
else
if !o.match_only
suffix =
o.print_types ?
"::$(type_string(io, prop; maxdepth=o.max_type_depth))" : ""
s = sizeof(typeof(CUDA.cudaconvert(prop)))
if s > o.size_threshhold
println(io, "size: $s, $pc_string$suffix")
end
end
if !any(map(x -> prop isa x, o.recursion_types))
_rprint_parameter_memory(
io,
prop,
pc_full;
o,
name,
counter = counter + 1,
)
end
counter > o.recursion_depth && return
end
end
end

print_name(io, name, o) = o.match_only || println(io, name)

function rprint_parameter_memory(io, obj, name, o::Options = Options())
print_name(io, name, o)
_rprint_parameter_memory(
io,
obj,
(); # pc
o,
name,
)
println(io, "")
end

"""
@rprint_parameter_memory obj options
Recursively print out propertynames and
parameter memory of `obj` given options
`options`. See [`Options`](@ref) for more
information on available options.
"""
macro rprint_parameter_memory(obj, o)
return :(rprint_parameter_memory(
stdout,
$(esc(obj)),
$(string(obj)),
$(esc(o)),
))
end

"""
@rprint_parameter_memory obj options
Recursively print out propertynames and
parameter memory of `obj` given options
`options`. See [`Options`](@ref) for more
information on available options.
"""
macro rprint_parameter_memory(obj)
return :(rprint_parameter_memory(stdout, $(esc(obj)), $(string(obj))))
end
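A hedged usage sketch (`fmb` stands for any fused multi-broadcast object, e.g. one produced by the extension above):

```julia
# Print every property of `fmb` whose converted size exceeds 100 bytes,
# to find which arguments dominate the kernel's parameter memory.
# (The keyword is spelled `size_threshhold`, matching the struct above.)
@rprint_parameter_memory fmb Options(; size_threshhold = 100)
```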
42 changes: 42 additions & 0 deletions src/collection/macros.jl
@@ -55,3 +55,45 @@ macro make_fused(fusion_style, type_name, fused_name)
end
end
end

"""
@make_get_fused fusion_style type_name fused_named
This macro
- Defines a type type_name
- Defines a macro, `@fused_name`, using the fusion type `fusion_style`
This allows users to flexibility
to customize their broadcast fusion.
# Example
```julia
import MultiBroadcastFusion as MBF
using Test
MBF.@make_type MyFusedBroadcast
MBF.@make_get_fused MBF.fused_direct MyFusedBroadcast get_fused
x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)
# when executed fused: 1 read (x1), 2 writes (y1, y2)
fmb = @get_fused begin
@. y1 = x1
@. y2 = x1
end
@test fmb isa MyFusedBroadcast
```
"""
macro make_get_fused(fusion_style, type_name, fused_name)
t = esc(type_name)
f = esc(fused_name)
return quote
macro $f(expr)
_pairs = esc($(fusion_style)(expr))
t = $t
quote
$t($_pairs)
end
end
end
end
1 change: 1 addition & 0 deletions src/execution/fused_kernels.jl
@@ -1,6 +1,7 @@
@make_type FusedMultiBroadcast
@make_fused fused_direct FusedMultiBroadcast fused_direct
@make_fused fused_assemble FusedMultiBroadcast fused_assemble
+@make_get_fused fused_direct FusedMultiBroadcast get_fused_direct

struct MBF_CPU end
struct MBF_CUDA end
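Because `get_fused_direct` only collects the broadcast pairs, the returned object can be inspected (e.g. with the partitioning and printing utilities above) before anything executes. A hedged sketch:

```julia
import MultiBroadcastFusion as MBF

x = rand(3); y = similar(x); z = similar(x)
# Capture the fused broadcast as a data structure instead of running it.
fmb = MBF.@get_fused_direct begin
    @. y = x + 1
    @. z = x * 2
end
@assert fmb isa MBF.FusedMultiBroadcast
```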
test/execution/kernel_splitting.jl
@@ -2,7 +2,7 @@
using TestEnv
TestEnv.activate()
using CUDA # (optional)
-using Revise; include(joinpath("test", "execution", "parameter_memory.jl"))
+using Revise; include(joinpath("test", "execution", "kernel_splitting.jl"))
=#

include("utils_test.jl")
@@ -49,9 +49,7 @@ function perf_kernel_shared_reads_fused!(X, Y)
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory
end
end
#! format: on
@@ -66,18 +64,8 @@ problem_size = (50, 5, 5, 6, 5400)
array_size = problem_size # array
X = get_arrays(:x, AType, bm.float_type, array_size)
Y = get_arrays(:y, AType, bm.float_type, array_size)
-@testset "Test breaking case with parameter memory" begin
-    if use_cuda
-        try
-            perf_kernel_shared_reads_fused!(X, Y)
-            error("The above kernel should error")
-        catch e
-            @test startswith(
-                e.msg,
-                "Kernel invocation uses too much parameter memory.",
-            )
-        end
-    end
+@testset "Test kernel splitting with too much parameter memory" begin
+    use_cuda && perf_kernel_shared_reads_fused!(X, Y)
end

nothing