From b3d08ad8106c58ecc77ed0e3411829a72783071f Mon Sep 17 00:00:00 2001
From: kaixih
Date: Sat, 16 Mar 2024 05:12:35 +0000
Subject: [PATCH] Add config to use cudnn attention

---
 paxml/tasks/lm/model_params.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/paxml/tasks/lm/model_params.py b/paxml/tasks/lm/model_params.py
index a4267f1f3..c002c05f2 100644
--- a/paxml/tasks/lm/model_params.py
+++ b/paxml/tasks/lm/model_params.py
@@ -33,6 +33,7 @@
 from praxis import schedules
 from praxis.layers import activations
 from praxis.layers import embedding_softmax
+from praxis.layers import gpu_fast_attention
 from praxis.layers import models
 from praxis.layers import transformer_models
 from praxis.layers.injection import fp8_nvidia_gpu as fp8_ops
@@ -536,6 +537,7 @@ class TransformerLmSpmdAdafactor(base_experiment.BaseExperiment):
   USE_GATED_ACTIVATION = False
   DECAY_END = 100000
   USE_FP8 = False
+  USE_CUDNN_FLASH_ATTENTION = False
 
   # optimizer related
   DROPOUT_PROB = 0.0
@@ -652,6 +654,16 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     if self.USE_ROTARY_POSITION_EMB:
       transformer_layer_p.tr_atten_tpl.use_rotary_position_emb = True
 
+    if self.USE_CUDNN_FLASH_ATTENTION:
+      assert transformer_layer_p.tr_atten_tpl.cls == layers.DotProductAttention
+      assert model_p.lm_tpl.model_type == transformer_models.LanguageModelType.CAUSAL
+      fused_tr_atten_tpl = pax_fiddle.Config(
+          gpu_fast_attention.GpuCudnnFusedDotProductAttention,
+          is_causal=True,
+      )
+      fused_tr_atten_tpl.copy_fields_from(transformer_layer_p.tr_atten_tpl)
+      transformer_layer_p.tr_atten_tpl = fused_tr_atten_tpl
+
     if self.USE_REPEATED_LAYER:
       model_p.lm_tpl.stacked_transformer_tpl = pax_fiddle.Config(
           layers.StackedTransformerRepeated
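
Usage note (not part of the patch): the sketch below shows how the new flag would be enabled from a Paxml experiment config once this patch is applied. The experiment class name LmCudnnFlashAttention is hypothetical; TransformerLmSpmdAdafactor and USE_CUDNN_FLASH_ATTENTION come from the diff above, and the model-size hyperparameters an actual experiment must set (NUM_LAYERS, MODEL_DIMS, etc.) are omitted for brevity.

# Minimal usage sketch, assuming this patch is applied.
# "LmCudnnFlashAttention" is a hypothetical experiment name.
from paxml import experiment_registry
from paxml.tasks.lm import model_params


@experiment_registry.register
class LmCudnnFlashAttention(model_params.TransformerLmSpmdAdafactor):
  """Causal LM experiment using the cuDNN fused attention kernel on GPU."""

  # With this flag set, task() swaps transformer_layer_p.tr_atten_tpl from
  # layers.DotProductAttention to a GpuCudnnFusedDotProductAttention config
  # (is_causal=True), copying the existing attention fields via
  # copy_fields_from().
  USE_CUDNN_FLASH_ATTENTION = True

Because copy_fields_from() carries over the fields already set on the default DotProductAttention template, other experiment-level attention settings (e.g. rotary position embeddings configured just above the new block) continue to apply to the fused template.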