Skip to content

Commit

Permalink
add correct name
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed Apr 29, 2024
1 parent e261316 commit 3381fa1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

For now, Transformers supports SDPA inference and training for the following architectures:
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio_spectrogram_transformer#transformers.ASTModel)
* [Audio Spectrogram Transformer](https://huggingface.co/docs/transformers/model_doc/audio-spectrogram-transformer#transformers.ASTModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Bert](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
Expand Down
2 changes: 1 addition & 1 deletion utils/check_support_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def check_sdpa_support_list():
archs_supporting_sdpa.append(model_name)

for arch in archs_supporting_sdpa:
if arch not in doctext:
if arch not in doctext and arch not in doctext.replace("-", "_"):
raise ValueError(
f"{arch} should be in listed in the SDPA documentation but is not. Please update the documentation."
)
Expand Down

0 comments on commit 3381fa1

Please sign in to comment.