From 22488faffac1612823df4ed61f1574b13a56d8e1 Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Thu, 3 Oct 2024 09:16:22 -0700 Subject: [PATCH] Added ranks 96 and 128 to BGMV kernel --- server/lorax_server/utils/graph.py | 4 ++-- server/lorax_server/utils/sgmv.py | 2 +- server/punica_kernels/punica_kernels/bgmv/bgmv_config.h | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/server/lorax_server/utils/graph.py b/server/lorax_server/utils/graph.py index 27c345d00..0bafc1559 100644 --- a/server/lorax_server/utils/graph.py +++ b/server/lorax_server/utils/graph.py @@ -26,7 +26,7 @@ MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 96)) -MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", BGMV_MAX_RANK)) +MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64)) SLOT_PAD_VALUE = -1 SEGMENT_PAD_VALUE = -1 @@ -43,7 +43,7 @@ # Include 0 to ensure we can use cuda graphs without adapters # TODO(travis): use padding to allow for more ranks without increasing memory usage -CACHED_MAX_RANKS = [0, 8, 16, 32, 64] +CACHED_MAX_RANKS = [0, 8, 16, 32, 64, 96, 128] CACHED_MAX_RANKS = [r for r in CACHED_MAX_RANKS if r <= MAX_RANK] _allowed_ranks = set(CACHED_MAX_RANKS) diff --git a/server/lorax_server/utils/sgmv.py b/server/lorax_server/utils/sgmv.py index 2a4701d09..6efb2647f 100644 --- a/server/lorax_server/utils/sgmv.py +++ b/server/lorax_server/utils/sgmv.py @@ -20,7 +20,7 @@ MIN_RANK_CUSTOM = 16 MAX_RANK_CUSTOM = 128 SGMV_BLOCK_SIZE = 16 -BGMV_MAX_RANK = 64 +BGMV_MAX_RANK = 128 def has_sgmv() -> bool: diff --git a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h index 00af5a3e0..ecc42bcb9 100644 --- a/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h +++ b/server/punica_kernels/punica_kernels/bgmv/bgmv_config.h @@ -77,6 +77,8 @@ void bgmv_kernel(T *__restrict__ Y, const T *__restrict__ X, FOR_BGMV_WIDE(f, T, 8) \ FOR_BGMV_WIDE(f, T, 16) \ FOR_BGMV_WIDE(f, T, 32) \ - FOR_BGMV_WIDE(f, T, 64) + FOR_BGMV_WIDE(f, T, 64) \ + FOR_BGMV_WIDE(f, T, 96) \ + FOR_BGMV_WIDE(f, T, 128) // clang-format on