diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 4926512e85b28..9ddfc7ad10ac0 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from parameterized import parameterized from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -56,6 +57,14 @@ def __init__(self, b, s, s2, sp, n, n2, h): self.kv_num_heads = n2 self.head_size = h + def __repr__(self): + return ( + f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " + + f"kv_sequence_length={self.kv_sequence_length}, past_sequence_length={self.past_sequence_length}, " + + f"past_sequence_length={self.past_sequence_length}, num_heads={self.num_heads}, " + + f"kv_num_heads={self.kv_num_heads}, head_size={self.head_size})" + ) + class PromptConfig: batch_size = 0 @@ -75,6 +84,13 @@ def __init__(self, b, sq, skv, sb, n, n2, h): self.kv_num_heads = n2 self.head_size = h + def __repr__(self): + return ( + f"PromptConfig(batch_size={self.batch_size}, q_sequence_length={self.q_sequence_length}, " + + f"kv_sequence_length={self.kv_sequence_length}, buffer_sequence_length={self.buffer_sequence_length}, " + + f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size})" + ) + def create_packed_multihead_attention_graph(config): nodes = [ @@ -1974,293 +1990,357 @@ def parity_check_gqa_past_no_buff( return all_close +def packed_mha_test_cases(): + batches = [2] if pipeline_mode else [1, 5] + seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s, 0, n, n, h) + yield str(config), config + + +def mha_test_cases(): + batches = [2] if pipeline_mode else [1, 5] + seqs = ( + [(1, 128), (113, 211), (2048, 2048)] + if pipeline_mode + else [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ] + ) + num_h = [1, 3] if pipeline_mode else [1, 6, 16] + h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + for b in batches: + for s, s2 in seqs: + for n in num_h: + for h in h_sizes: + config = Config(b, s, s2, 0, n, n, h) + yield str(config), config + + class TestMHA(unittest.TestCase): - def test_packed_mha(self): + @parameterized.expand(packed_mha_test_cases()) + def test_packed_mha(self, _, config): if not torch.cuda.is_available() or platform.system() != "Linux": return major, _ = torch.cuda.get_device_capability() if major < 8: return print("-------- TEST PACKED MHA ---------") - batches = [2] if pipeline_mode else [1, 5] - seqs = [8, 97, 256, 1024] if pipeline_mode else [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048] - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - for b in batches: - for s in seqs: - for n in num_h: - for h in h_sizes: - config = Config(b, s, s, 0, n, n, h) - all_close = parity_check_mha(config, True) - self.assertTrue(all_close) - - def test_mha(self): + all_close = parity_check_mha(config, True) + self.assertTrue(all_close) + + @parameterized.expand(mha_test_cases()) + def test_mha(self, _, config): if not torch.cuda.is_available() or platform.system() != "Linux": return major, _ = torch.cuda.get_device_capability() if major < 8: return print("-------- TEST MHA ---------") - batches = [2] if pipeline_mode else [1, 5] - seqs = ( - [(1, 128), (113, 211), (2048, 2048)] - if pipeline_mode - else [ - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), - (512, 256), - (1024, 1024), - (1023, 1024), - (1024, 1023), - (2048, 2048), - ] - ) - num_h = [1, 3] if pipeline_mode else [1, 6, 16] - h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - for b in batches: - for s, s2 in seqs: - for n in num_h: - for h in h_sizes: - config = Config(b, s, s2, 0, n, n, h) - all_close = parity_check_mha(config, False) - self.assertTrue(all_close) + all_close = parity_check_mha(config, False) + self.assertTrue(all_close) -class TestGQA(unittest.TestCase): - def test_gqa_no_past_memory_efficient(self): - if not torch.cuda.is_available(): - return - major, minor = torch.cuda.get_device_capability() - torch.manual_seed(69) - batches = [3] if pipeline_mode else [1, 3, 5] - seqs = ( - [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - if pipeline_mode - else [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - if major < 5 or (major == 5 and minor < 3): - return - print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - for b in batches: - for sq, skv in seqs: - for n, n2 in num_h: - for h in h_sizes: +def gqa_no_past_memory_efficient_test_cases(): + batches = [3] if pipeline_mode else [1, 3, 5] + seqs = ( + [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + if pipeline_mode + else [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + torch.manual_seed(69) + + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + ) + + +def gqa_no_past_flash_attention_test_cases(): + batches = [3] if pipeline_mode else [1, 3, 5] + seqs = ( + [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + if pipeline_mode + else [ + (127, 127), + (35, 35), + (2000, 2000), + (200, 200), + (240, 240), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + torch.manual_seed(69) + + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: for packed in [False, True]: config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - all_close = parity_check_gqa_prompt( + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", config, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, + local, + rotary, + rotary_interleaved, + packed, ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( + + +def gqa_past_memory_efficient_test_cases(): + batches = [5] if pipeline_mode else [1, 3, 5] + seqs = ( + [(1, 128), (1, 1024), (1, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + ) + + +def gqa_past_flash_attention_test_cases(): + batches = [5] if pipeline_mode else [1, 3, 5] + seqs = ( + [(1, 128), (1, 1024), (1, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for packed in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", config, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, + local, + rotary, + rotary_interleaved, + packed, ) - self.assertTrue(all_close) - def test_gqa_no_past_flash_attention(self): + +class TestGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) + def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed): if not torch.cuda.is_available(): return - major, _ = torch.cuda.get_device_capability() - torch.manual_seed(69) - batches = [3] if pipeline_mode else [1, 3, 5] - seqs = ( - [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - if pipeline_mode - else [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] + major, minor = torch.cuda.get_device_capability() + if major < 5 or (major == 5 and minor < 3): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") + + all_close = parity_check_gqa_prompt( + config, + rtol=5e-3, + atol=5e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + rtol=5e-3, + atol=5e-3, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + self.assertTrue(all_close) + + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + if not torch.cuda.is_available(): + return + major, _ = torch.cuda.get_device_capability() if major < 8 or platform.system() != "Linux": return print("------- FLASH ATTENTION (PROMPT CASE) --------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - for b in batches: - for sq, skv in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for packed in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - all_close = parity_check_gqa_prompt( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - - def test_gqa_past_memory_efficient(self): + + all_close = parity_check_gqa_prompt( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + + @parameterized.expand(gqa_past_memory_efficient_test_cases()) + def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed): if not torch.cuda.is_available(): return major, minor = torch.cuda.get_device_capability() if major < 5 or (major == 5 and minor < 3): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - batches = [5] if pipeline_mode else [1, 3, 5] - seqs = ( - [(1, 128), (1, 1024), (1, 2048)] - if pipeline_mode - else [ - (1, 128), - (1, 339), - (1, 1024), - (1, 5000), - (1, 800), - (1, 256), - (1, 799), - (1, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for packed in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - all_close = parity_check_gqa_past( - config, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - def test_gqa_past_flash_attention(self): + all_close = parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + + @parameterized.expand(gqa_past_flash_attention_test_cases()) + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): if not torch.cuda.is_available(): return major, _ = torch.cuda.get_device_capability() - batches = [5] if pipeline_mode else [1, 3, 5] - seqs = ( - [(1, 128), (1, 1024), (1, 2048)] - if pipeline_mode - else [ - (1, 128), - (1, 339), - (1, 1024), - (1, 5000), - (1, 800), - (1, 256), - (1, 799), - (1, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) if major < 8 or platform.system() != "Linux": return print("------- FLASH ATTENTION (TOKEN GEN) -------") os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: - for packed in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - all_close = parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + + all_close = parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) if __name__ == "__main__":