Add config to use cudnn attention
kaixih committed Mar 22, 2024
1 parent 378c6ad commit b3d08ad
Showing 1 changed file with 12 additions and 0 deletions.
paxml/tasks/lm/model_params.py (12 additions, 0 deletions)
@@ -33,6 +33,7 @@
 from praxis import schedules
 from praxis.layers import activations
 from praxis.layers import embedding_softmax
+from praxis.layers import gpu_fast_attention
 from praxis.layers import models
 from praxis.layers import transformer_models
 from praxis.layers.injection import fp8_nvidia_gpu as fp8_ops
@@ -536,6 +537,7 @@ class TransformerLmSpmdAdafactor(base_experiment.BaseExperiment):
   USE_GATED_ACTIVATION = False
   DECAY_END = 100000
   USE_FP8 = False
+  USE_CUDNN_FLASH_ATTENTION = False
 
   # optimizer related
   DROPOUT_PROB = 0.0
@@ -652,6 +654,16 @@ def task(self) -> pax_fiddle.Config[tasks_lib.SingleTask]:
     if self.USE_ROTARY_POSITION_EMB:
       transformer_layer_p.tr_atten_tpl.use_rotary_position_emb = True
 
+    if self.USE_CUDNN_FLASH_ATTENTION:
+      assert transformer_layer_p.tr_atten_tpl.cls == layers.DotProductAttention
+      assert model_p.lm_tpl.model_type == transformer_models.LanguageModelType.CAUSAL
+      fused_tr_atten_tpl = pax_fiddle.Config(
+          gpu_fast_attention.GpuCudnnFusedDotProductAttention,
+          is_causal=True,
+      )
+      fused_tr_atten_tpl.copy_fields_from(transformer_layer_p.tr_atten_tpl)
+      transformer_layer_p.tr_atten_tpl = fused_tr_atten_tpl
+
     if self.USE_REPEATED_LAYER:
       model_p.lm_tpl.stacked_transformer_tpl = pax_fiddle.Config(
           layers.StackedTransformerRepeated
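For reference, enabling the new path from an experiment config would look roughly like the sketch below. The experiment class name and its use of the registration decorator are illustrative assumptions; USE_CUDNN_FLASH_ATTENTION and TransformerLmSpmdAdafactor come from the diff above.

# Sketch only: MyCudnnAttentionLm is a hypothetical experiment; the flag and
# the base class are the ones added/shown in this commit.
from paxml import experiment_registry
from paxml.tasks.lm import model_params


@experiment_registry.register
class MyCudnnAttentionLm(model_params.TransformerLmSpmdAdafactor):
  """Causal LM experiment that opts into the cuDNN fused attention template."""

  # Setting the new flag makes task() replace the stock DotProductAttention
  # template with gpu_fast_attention.GpuCudnnFusedDotProductAttention
  # (is_causal=True), after copying the existing attention fields via
  # copy_fields_from().
  USE_CUDNN_FLASH_ATTENTION = True

Note that the new block asserts the existing attention template is layers.DotProductAttention and that the model type is CAUSAL, so a config using a custom attention class or a non-causal model type will fail the assertion rather than silently skip the fused path.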
