Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dynamically split kernels based on parameter memory #52

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ agents:

env:
JULIA_LOAD_PATH: "${JULIA_LOAD_PATH}:${BUILDKITE_BUILD_CHECKOUT_PATH}/.buildkite"
JULIA_DEPOT_PATH: "${BUILDKITE_BUILD_PATH}/${BUILDKITE_PIPELINE_SLUG}/depot/default"
JULIA_MAX_NUM_PRECOMPILE_FILES: 100
JULIA_CPU_TARGET: 'broadwell;skylake'
JULIA_NVTX_CALLBACKS: gc
Expand Down
82 changes: 67 additions & 15 deletions ext/MultiBroadcastFusionCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,75 @@ import MultiBroadcastFusion: fused_copyto!

MBF.device(x::CUDA.CuArray) = MBF.MBF_CUDA()

include("parameter_memory.jl")

"""
partition_kernels(fmb;
fused_broadcast_constructor = MBF.FusedMultiBroadcast,
args_func::Function =
)

Splits fused broadcast kernels into a vector
of kernels, based on parameter memory limitations.

We first attempt to fuse
1:N, 1:N-1, 1:N-2, ... until we fuse 1:N-k
Next, we attempt to fuse
N-k+1:N, N-k+1:N-1, N-k+1:N-2, ...

And so forth.
"""
function partition_kernels(
fmb,
fused_broadcast_constructor = MBF.FusedMultiBroadcast,
args_func::Function = fused_multibroadcast_args,
)
plim = get_param_lim()
usage = param_usage_args(args_func(fmb))
n_bins = 1
fmbs = (fmb,)
usage ≤ plim && return fmbs
fmbs_split = []
N = length(fmb.pairs)
i_start = 1
i_stop = N
while i_stop ≠ i_start
ith_pairs = fmb.pairs[i_start:i_stop]
ith_fmb = fused_broadcast_constructor(ith_pairs)
if param_usage_args(args_func(ith_fmb)) ≤ plim # first iteration will likely fail (ambitious)
push!(fmbs_split, ith_fmb)
i_stop == N && break
i_start = i_stop + 1 # N on first iteration
i_stop = N # reset i_stop
else
i_stop = i_stop - 1
end
end
return fmbs_split
end

function fused_copyto!(fmb::MBF.FusedMultiBroadcast, ::MBF.MBF_CUDA)
(; pairs) = fmb
dest = first(pairs).first
destinations = map(p -> p.first, pairs)
all(a -> axes(a) == axes(dest), destinations) ||
error("Cannot fuse broadcast expressions with unequal broadcast axes")
nitems = length(parent(dest))
CI = CartesianIndices(axes(dest))
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
destinations = map(p -> p.first, fmb.pairs)
fmbs = partition_kernels(fmb)
for fmb in fmbs
(; pairs) = fmb
dest = first(pairs).first
dests = map(p -> p.first, pairs)
all(a -> axes(a) == axes(dest), dests) || error(
"Cannot fuse broadcast expressions with unequal broadcast axes",
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
nitems = length(parent(dest))
CI = CartesianIndices(axes(dest))
kernel =
CUDA.@cuda always_inline = true launch = false fused_copyto_kernel!(
fmb,
CI,
)
config = CUDA.launch_configuration(kernel.fun)
threads = min(nitems, config.threads)
blocks = cld(nitems, threads)
kernel(fmb, CI; threads, blocks)
end
return destinations
end
import Base.Broadcast
Expand Down
168 changes: 168 additions & 0 deletions ext/parameter_memory.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
function get_param_lim()
config = CUDA.compiler_config(CUDA.device())
(; ptx, cap) = config.params
return cap >= v"7.0" && ptx >= v"8.1" ? 32764 : 4096
end
param_usage(arg) = sizeof(typeof(CUDA.cudaconvert(arg)))
param_usage_args(args) =
sum(x -> param_usage(x), args) + param_usage(CUDA.KernelState)

function fused_multibroadcast_args(fmb::MBF.FusedMultiBroadcast)
dest = first(fmb.pairs).first
CI = CartesianIndices(axes(dest))
return (fmb, CI)
end

"""
Options(
[types...];
match_only::Bool = false
print_types::Bool = false
recursion_types = (UnionAll,DataType)
recursion_depth = 1000
)

Printing options for `@rprint_parameter_memory`:

- `match_only`: only print properties that match the given types
- `print_types`: print types (e.g., `prop::typeof(prop)`)
- `recursion_types`: skip recursing through recursion types (e.g., `UnionAll` and `DataType`)
to avoid infinite recursion
- `recursion_depth`: limit recursion depth (to avoid infinite recursion)
"""
struct Options{T}
types::T
match_only::Bool
print_types::Bool
recursion_types::Tuple
recursion_depth::Int
size_threshhold::Int
max_type_depth::Int
function Options(
types...;
match_only = false,
print_types = true,
recursion_types = (UnionAll, DataType),
recursion_depth = 1000,
size_threshhold = 10,
max_type_depth = 1,
)
if (types isa AbstractArray || types isa Tuple) && length(types) > 0
types = types[1]
else
types = (Union{},)
end
return new{typeof(types)}(
types,
match_only,
print_types,
recursion_types,
recursion_depth,
size_threshhold,
max_type_depth,
)
end
end
Options(type::Type; kwargs...) = Options((type,); kwargs...)

Options() = Options(();)

function type_string(io, obj; maxdepth)
sz = get(io, :displaysize, displaysize(io))::Tuple{Int, Int}
S = max(sz[2], 120)
slim = Base.type_depth_limit(string(typeof(obj)), S; maxdepth)
return slim
end

function _rprint_parameter_memory(io, obj, pc; o::Options, name, counter = 0)
counter > o.recursion_depth && return
for pn in propertynames(obj)
prop = getproperty(obj, pn)
pc_full = (pc..., ".", pn)
pc_string = name * string(join(pc_full))
if any(map(type -> prop isa type, o.types))
suffix =
o.print_types ?
"::$(type_string(io, prop; maxdepth=o.max_type_depth))" : ""
s = sizeof(typeof(CUDA.cudaconvert(prop)))
if s > o.size_threshhold
println(io, "size: $s, $pc_string$suffix")
end
if !any(map(x -> prop isa x, o.recursion_types))
_rprint_parameter_memory(
io,
prop,
pc_full;
o,
name,
counter = counter + 1,
)
counter > o.recursion_depth && return
end
else
if !o.match_only
suffix =
o.print_types ?
"::$(type_string(io, prop; maxdepth=o.max_type_depth))" : ""
s = sizeof(typeof(CUDA.cudaconvert(prop)))
if s > o.size_threshhold
println(io, "size: $s, $pc_string$suffix")
end
end
if !any(map(x -> prop isa x, o.recursion_types))
_rprint_parameter_memory(
io,
prop,
pc_full;
o,
name,
counter = counter + 1,
)
end
counter > o.recursion_depth && return
end
end
end

print_name(io, name, o) = o.match_only || println(io, name)

function rprint_parameter_memory(io, obj, name, o::Options = Options())
print_name(io, name, o)
_rprint_parameter_memory(
io,
obj,
(); # pc
o,
name,
)
println(io, "")
end

"""
@rprint_parameter_memory obj options

Recursively print out propertynames and
parameter memory of `obj` given options
`options`. See [`Options`](@ref) for more
information on available options.
"""
macro rprint_parameter_memory(obj, o)
return :(rprint_parameter_memory(
stdout,
$(esc(obj)),
$(string(obj)),
$(esc(o)),
))
end

"""
@rprint_parameter_memory obj options

Recursively print out propertynames and
parameter memory of `obj` given options
`options`. See [`Options`](@ref) for more
information on available options.
"""
macro rprint_parameter_memory(obj)
return :(rprint_parameter_memory(stdout, $(esc(obj)), $(string(obj))))
end
42 changes: 42 additions & 0 deletions src/collection/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,45 @@ macro make_fused(fusion_style, type_name, fused_name)
end
end
end

"""
@make_get_fused fusion_style type_name fused_named

This macro
- Defines a type type_name
- Defines a macro, `@fused_name`, using the fusion type `fusion_style`

This allows users to flexibility
to customize their broadcast fusion.

# Example
```julia
import MultiBroadcastFusion as MBF
MBF.@make_type MyFusedBroadcast
MBF.@make_get_fused MBF.fused_direct MyFusedBroadcast get_fused

x1 = rand(3,3)
y1 = rand(3,3)
y2 = rand(3,3)

# 4 reads, 2 writes
fmb = @get_fused begin
@. y1 = x1
@. y2 = x1
end
@test fmb isa MyFusedBroadcast
```
"""
macro make_get_fused(fusion_style, type_name, fused_name)
t = esc(type_name)
f = esc(fused_name)
return quote
macro $f(expr)
_pairs = esc($(fusion_style)(expr))
t = $t
quote
$t($_pairs)
end
end
end
end
1 change: 1 addition & 0 deletions src/execution/fused_kernels.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
@make_type FusedMultiBroadcast
@make_fused fused_direct FusedMultiBroadcast fused_direct
@make_fused fused_assemble FusedMultiBroadcast fused_assemble
@make_get_fused fused_direct FusedMultiBroadcast get_fused_direct

struct MBF_CPU end
struct MBF_CUDA end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using TestEnv
TestEnv.activate()
using CUDA # (optional)
using Revise; include(joinpath("test", "execution", "parameter_memory.jl"))
using Revise; include(joinpath("test", "execution", "kernel_splitting.jl"))
=#

include("utils_test.jl")
Expand Down Expand Up @@ -49,9 +49,7 @@ function perf_kernel_shared_reads_fused!(X, Y)
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y1 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2
@. y3 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 + x3
@. y4 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4
@. y2 = x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 * x2 + x3 + x4 + x1 + x2 # breaks on A100 due to too much parameter memory
end
end
#! format: on
Expand All @@ -66,18 +64,8 @@ problem_size = (50, 5, 5, 6, 5400)
array_size = problem_size # array
X = get_arrays(:x, AType, bm.float_type, array_size)
Y = get_arrays(:y, AType, bm.float_type, array_size)
@testset "Test breaking case with parameter memory" begin
if use_cuda
try
perf_kernel_shared_reads_fused!(X, Y)
error("The above kernel should error")
catch e
@test startswith(
e.msg,
"Kernel invocation uses too much parameter memory.",
)
end
end
@testset "Test kernel splitting with too much parameter memory" begin
use_cuda && perf_kernel_shared_reads_fused!(X, Y)
end

nothing
Loading
Loading