Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DLIGHT][GPU] Improved gemv outer fallback schedule #16973

Merged
merged 3 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions python/tvm/dlight/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,8 @@ def apply(
TS, TR = 4, 64
else:
TS, TR = 16, 32
else:
TS, TR = 1, 64
elif target.kind.name == "metal":
# Note that the following tile size is tuned on M2 Ultra for 7B
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
Expand All @@ -476,6 +478,8 @@ def apply(
TS, TR = 4, 16
else:
TS, TR = 2, 64
else:
TS, TR = 1, 64
elif target.kind.name == "rocm":
VEC_C = 4
# TODO: set LOAD_V_SHARED = False for now
Expand All @@ -489,13 +493,15 @@ def apply(
TS, TR = 1, 128
else:
TS, TR = 8, 64
else:
TS, TR = 1, 64
elif target.kind.name == "opencl" and "android" in str(target.host):
TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
VEC_C = 8
LOAD_V_SHARED = False
LOAD_V_VEC = -1
UNROLL = 8
TS, TR = 2, 64
TS, TR = 2, 32
elif target.kind.name == "vulkan":
VEC_C = 4
LOAD_V_SHARED = True
Expand All @@ -506,6 +512,8 @@ def apply(
TS, TR = 4, 32
else:
TS, TR = 16, 32
else:
TS, TR = 1, 64
elif target.kind.name == "opencl" and "mali" in str(target.attrs):
VEC_C = 8
LOAD_V_SHARED = False
Expand All @@ -519,9 +527,6 @@ def apply(
UNROLL = 64
TS, TR = 1, 64

if not isinstance(len_S, int):
TS, TR = 1, 64

while TS * TR > target.max_num_threads:
if TS > 1:
TS //= 2
Expand Down Expand Up @@ -709,7 +714,11 @@ def apply(
if not isinstance(len_r, int):
return None

if isinstance(len_s, int) and len_s > 32000:
if not isinstance(len_s, int):
TS, TR = 256, 1
LOAD_V_SHARED = True

if isinstance(len_s, int) and len_s > 96000:
return None

_, TILE_R = (
Expand Down Expand Up @@ -754,7 +763,8 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid
len_s = get_extent(sch, s)

# The config is designed for Adreno
tx_len = 64
LOAD_V_SHARED = 1
tx_len = 128
vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1
inner_r = 4

Expand All @@ -768,16 +778,23 @@ def sch_outer_reduction_fallback( # pylint: disable=too-many-arguments, invalid
sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8)
sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1)

cache_v = sch.cache_read(block, vector_input_buffers[0], "local")
sch.compute_at(cache_v, r1, preserve_unit_loops=True)
sch.vectorize(sch.get_loops(cache_v)[-1])
if LOAD_V_SHARED:
V_shared = sch.cache_read(block, vector_input_buffers[0], storage_scope="shared")
sch.compute_at(V_shared, bx, preserve_unit_loops=True)
l = sch.get_loops(block=V_shared)[-1]
_, tx, vec_r = sch.split(l, factors=[None, tx_len, 8], preserve_unit_iters=True)
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec_r)

sch.vectorize(vec)

# Schedule epilogue
if epilogue_info is not None:
sch.reverse_compute_at(epilogue_info.block_rv, tx)

sch.reverse_compute_at(epilogue_info.block_rv, bx, preserve_unit_loops=True)
ts_tile_s = sch.get_loops(epilogue_info.block_rv)[-1]
ts, vec = sch.split(ts_tile_s, factors=[tx_len, vec_len], preserve_unit_iters=True)
sch.bind(ts, "threadIdx.x")
sch.vectorize(vec)
sch.set_scope(block, 0, "local")

sch.decompose_reduction(block, r0)
Expand Down
Loading
Loading