F.scaled_dot_product_attention support (#26572)
* add sdpa

* wip

* cleaning

* add ref

* yet more cleaning

* and more :)

* wip llama

* working llama

* add output_attentions=True support

* bigcode sdpa support

* fixes

* gpt-bigcode support, require torch>=2.1.1

* add falcon support

* fix conflicts falcon

* style

* fix attention_mask definition

* remove output_attentions from attnmaskconverter

* support whisper without removing any Copied from statement

* fix mbart default to eager renaming

* fix typo in falcon

* fix is_causal in SDPA

* check is_flash_attn_2_available in the models init as well in case the model is not initialized through from_pretrained

* add warnings when falling back on the manual implementation

* precise doc

* wip replace _flash_attn_enabled by config.attn_implementation

* fix typo

* add tests

* style

* add a copy.deepcopy on the config in from_pretrained, as we do not want to modify it inplace

* obey config.attn_implementation if a config is passed in from_pretrained

* fix is_torch_sdpa_available when torch is not installed

* remove dead code

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/bart/modeling_bart.py

Co-authored-by: Arthur <[email protected]>

* remove duplicate pretraining_tp code

* add dropout in llama

* precise comment on attn_mask

* add fmt: off for _unmask_unattended docstring

* precise num_masks comment

* nuke pretraining_tp in LlamaSDPAAttention following Arthur's suggestion

* cleanup modeling_utils

* backward compatibility

* fix style as requested

* style

* improve documentation

* test pass

* style

* add _unmask_unattended tests

* skip meaningless tests for idefics

* hard_check SDPA requirements when specifically requested

* standardize the use of XXX_ATTENTION_CLASSES

* fix SDPA bug with mem-efficient backend on CUDA when using fp32

* fix test

* rely on SDPA is_causal parameter to handle the causal mask in some cases

* fix FALCON_ATTENTION_CLASSES

* remove _flash_attn_2_enabled occurrences

* fix test

* add OPT to the list of supported flash models

* improve test

* properly test on different SDPA backends, on different dtypes & properly handle separately the pad tokens in the test

* remove remaining _flash_attn_2_enabled occurrence

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/modeling_attn_mask_utils.py

Co-authored-by: Arthur <[email protected]>

* Update docs/source/en/perf_infer_gpu_one.md

Co-authored-by: Arthur <[email protected]>

* remove use_attn_implementation

* fix docstring & slight bug

* make attn_implementation internal (_attn_implementation)

* typos

* fix tests

* deprecate use_flash_attention_2=True

* fix test

* add back llama that was removed by mistake

* fix tests

* remove _flash_attn_2_enabled occurrences bis

* add check & test that passed attn_implementation is valid

* fix falcon torchscript export

* fix device of mask in tests

* add tip about torch.jit.trace and move bt doc below sdpa

* fix parameterized.expand order

* move tests from test_modeling_attn_mask_utils to test_modeling_utils as a relevant test class is already there

* update sdpaattention class with the new cache

* Update src/transformers/configuration_utils.py

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/bark/modeling_bark.py

* address review comments

* WIP torch.jit.trace fix. left: test both eager & sdpa

* add test for torch.jit.trace for both eager/sdpa

* fix falcon with torch==2.0 that needs to use sdpa

* fix doc

* hopefully last fix

* fix key_value_length that has no default now in mask converter

* is it flaky?

* fix speculative decoding bug

* tests do pass

* fix following #27907

---------

Co-authored-by: Arthur <[email protected]>
fxmarty and ArthurZucker authored Dec 8, 2023
1 parent ce0bbd5 commit 80377eb
Showing 54 changed files with 2,225 additions and 452 deletions.
2 changes: 1 addition & 1 deletion docs/source/en/llm_tutorial_optimization.md
@@ -441,7 +441,7 @@ flush()
```

For comparison, let's run the same function, but enable Flash Attention instead.
To do so, we convert the model to [BetterTransformers](https://huggingface.co/docs/optimum/bettertransformer/overview) and by doing so enabling PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention) which in turn is based on Flash Attention.
To do so, we convert the model to [BetterTransformer](https://huggingface.co/docs/optimum/bettertransformer/overview), which enables PyTorch's [SDPA self-attention](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), which in turn can use Flash Attention.

```python
model.to_bettertransformer()
6 changes: 3 additions & 3 deletions docs/source/en/model_doc/bark.md
@@ -83,10 +83,10 @@ pip install -U flash-attn --no-build-isolation

##### Usage

To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the `attn_implementation="flash_attention_2"` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:

```python
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
```

##### Performance comparison
@@ -114,7 +114,7 @@ import torch
device = "cuda" if torch.cuda.is_available() else "cpu"

# load in fp16 and use Flash Attention 2
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = BarkModel.from_pretrained("suno/bark-small", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)

# enable CPU offload
model.enable_cpu_offload()
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/distilbert.md
@@ -153,7 +153,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> device = "cuda" # the device to load the model onto

>>> tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModel.from_pretrained("distilbert-base-uncased", torch_dtype=torch.float16, attn_implementation="flash_attention_2")

>>> text = "Replace me by any text you'd like."

2 changes: 1 addition & 1 deletion docs/source/en/model_doc/gpt_bigcode.md
@@ -59,7 +59,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("bigcode/gpt_bigcode-santacoder", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("bigcode/gpt_bigcode-santacoder")

>>> prompt = "def hello_world():"
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/gpt_neo.md
@@ -67,7 +67,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-2.7B")

>>> prompt = "def hello_world():"
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/gpt_neox.md
@@ -77,12 +77,12 @@ pip install -U flash-attn --no-build-isolation

### Usage

To load a model using Flash Attention 2, we can pass the `use_flash_attention_2` flag to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation to audio quality but significantly lower memory usage and faster inference:
To load a model using Flash Attention 2, we can pass the argument `attn_implementation="flash_attention_2"` to [`.from_pretrained`](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained). We'll also load the model in half-precision (e.g. `torch.float16`), since it results in almost no degradation in output quality but significantly lower memory usage and faster inference:

```python
>>> from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, use_flash_attention_2=True).to(device)
model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to(device)
...
```

2 changes: 1 addition & 1 deletion docs/source/en/model_doc/mistral.md
@@ -99,7 +99,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

>>> prompt = "My favourite condiment is"
2 changes: 1 addition & 1 deletion docs/source/en/model_doc/opt.md
@@ -80,7 +80,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import OPTForCausalLM, GPT2Tokenizer
>>> device = "cuda" # the device to load the model onto

>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True)
>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, attn_implementation="flash_attention_2")
>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the "
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/phi.md
@@ -111,7 +111,7 @@ To load and run a model using Flash Attention 2, refer to the snippet below:
>>> from transformers import PhiForCausalLM, AutoTokenizer

>>> # define the model and tokenizer and push the model and tokens to the GPU.
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, use_flash_attention_2=True).to("cuda")
>>> model = PhiForCausalLM.from_pretrained("susnato/phi-1_5_dev", torch_dtype=torch.float16, attn_implementation="flash_attention_2").to("cuda")
>>> tokenizer = AutoTokenizer.from_pretrained("susnato/phi-1_5_dev")

>>> # feel free to change the prompt to your liking.
@@ -163,4 +163,4 @@ Below is an expected speedup diagram that compares pure inference time between t
- forward

</pt>
</frameworkcontent>
</frameworkcontent>
107 changes: 71 additions & 36 deletions docs/source/en/perf_infer_gpu_one.md
@@ -36,13 +36,29 @@ FlashAttention-2 is experimental and may change considerably in future versions.
1. additionally parallelizing the attention computation over sequence length
2. partitioning the work between GPU threads to reduce communication and shared memory reads/writes between them

FlashAttention-2 supports inference with Llama, Mistral, Falcon and Bark models. You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
FlashAttention-2 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)

You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-2 installed. For NVIDIA GPUs, the library is installable through pip: `pip install flash-attn --no-build-isolation`. We strongly suggest referring to the [detailed installation instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

FlashAttention-2 is also supported on AMD GPUs, with current support limited to **Instinct MI210 and Instinct MI250**. We strongly suggest using the following [Dockerfile](https://github.com/huggingface/optimum-amd/tree/main/docker/transformers-pytorch-amd-gpu-flash/Dockerfile) to run FlashAttention-2 on AMD GPUs.

To enable FlashAttention-2, add the `use_flash_attention_2` parameter to [`~AutoModelForCausalLM.from_pretrained`]:
To enable FlashAttention-2, pass the argument `attn_implementation="flash_attention_2"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import torch
@@ -54,13 +70,15 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

<Tip>

FlashAttention-2 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-2.

Note that `use_flash_attention_2=True` can also be used to enable Flash Attention 2, but is deprecated in favor of `attn_implementation="flash_attention_2"`.

</Tip>
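
You can verify which attention implementation a loaded model ended up with by inspecting its config. A minimal sketch, assuming a CUDA GPU with flash-attn installed; the Falcon checkpoint is only an example, and `config._attn_implementation` is an internal attribute rather than public API:

```python
# Minimal sketch: check which attention implementation was selected at load time.
# Assumes flash-attn is installed; the checkpoint is only an example, and
# `config._attn_implementation` is an internal attribute, not public API.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
print(model.config._attn_implementation)  # expected: "flash_attention_2"
```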

@@ -77,14 +95,14 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

@@ -124,41 +142,21 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>
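
To see the memory effect on your own hardware, a rough sketch (not a rigorous benchmark) is to compare the peak GPU memory of a long-sequence forward pass under the eager and FlashAttention-2 implementations; the checkpoint and sequence length below are arbitrary choices, and a GPU with enough memory plus flash-attn are assumed:

```python
# Rough sketch (not a rigorous benchmark): compare peak GPU memory of a single
# long-sequence forward pass with eager attention vs. FlashAttention-2.
# Assumes a CUDA GPU with enough memory for the chosen checkpoint and flash-attn installed.
import torch
from transformers import AutoModelForCausalLM

def peak_memory_gb(attn_implementation, seq_len=4096):
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint (gated on the Hub)
        torch_dtype=torch.float16,
        attn_implementation=attn_implementation,
    ).to("cuda")
    input_ids = torch.randint(0, model.config.vocab_size, (1, seq_len), device="cuda")
    with torch.no_grad():
        model(input_ids)
    del model
    return torch.cuda.max_memory_allocated() / 1e9

print("eager:", peak_memory_gb("eager"))
print("flash_attention_2:", peak_memory_gb("flash_attention_2"))
```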

## BetterTransformer

<Tip>

Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.

</Tip>

BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:

1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors

BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.

Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).

Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:

```python
model = model.to_bettertransformer()
```

You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:
## FlashAttention and memory-efficient attention through PyTorch's scaled_dot_product_attention

```py
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers, and is used by default for `torch>=2.1.1` when an implementation is available.

### FlashAttention
For now, Transformers supports inference and training through SDPA for the following architectures:
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)

SDPA can also call FlashAttention kernels under the hood. FlashAttention can only be used for models using the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it.
Note that FlashAttention can only be used for models with the `fp16` or `bf16` dtype, so make sure to cast your model to the appropriate dtype before using it.
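
Instead of relying on the default selection, you can also request SDPA explicitly. A minimal sketch, assuming `torch>=2.1.1` and a CUDA GPU; the Falcon checkpoint is only an example of a supported architecture:

```python
# Minimal sketch: explicitly request PyTorch's SDPA attention implementation.
# Assumes torch>=2.1.1 and a CUDA GPU; the checkpoint is only an example.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("cuda")

inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```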

To enable FlashAttention or to check whether it is available in a given setting (hardware, problem size), use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:
By default, `torch.nn.functional.scaled_dot_product_attention` selects the most performant kernel available, but to check whether a backend is available in a given setting (hardware, problem size), you can use [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager:

```diff
import torch
@@ -187,6 +185,43 @@ RuntimeError: No available kernel. Aborting execution.
pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
```

## BetterTransformer

<Tip warning={true}>

Some BetterTransformer features are being upstreamed into Transformers, with default support for native `torch.nn.functional.scaled_dot_product_attention`. BetterTransformer still has wider coverage than the Transformers SDPA integration, but you can expect more and more architectures to natively support SDPA in Transformers.

</Tip>


<Tip>

Check out our benchmarks with BetterTransformer and scaled dot product attention in the [Out of the box acceleration and memory savings of 🤗 decoder models with PyTorch 2.0](https://pytorch.org/blog/out-of-the-box-acceleration/) and learn more about the fastpath execution in the [BetterTransformer](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2) blog post.

</Tip>

BetterTransformer accelerates inference with its fastpath (native PyTorch specialized implementation of Transformer functions) execution. The two optimizations in the fastpath execution are:

1. fusion, which combines multiple sequential operations into a single "kernel" to reduce the number of computation steps
2. skipping the inherent sparsity of padding tokens to avoid unnecessary computation with nested tensors

BetterTransformer also converts all attention operations to use the more memory-efficient [scaled dot product attention (SDPA)](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention), and it calls optimized kernels like [FlashAttention](https://huggingface.co/papers/2205.14135) under the hood.

Before you start, make sure you have 🤗 Optimum [installed](https://huggingface.co/docs/optimum/installation).

Then you can enable BetterTransformer with the [`PreTrainedModel.to_bettertransformer`] method:

```python
model = model.to_bettertransformer()
```

You can return the original Transformers model with the [`~PreTrainedModel.reverse_bettertransformer`] method. You should use this before saving your model to use the canonical Transformers modeling:

```py
model = model.reverse_bettertransformer()
model.save_pretrained("saved_model")
```
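
Since BetterTransformer dispatches attention to SDPA, it can be combined with the backend-selection context manager shown earlier. A sketch, assuming a causal LM `model` in `fp16`/`bf16` and tokenized `inputs` already placed on a CUDA device (neither is defined by the snippets above):

```python
# Sketch: run a BetterTransformer-converted model while forcing SDPA's
# FlashAttention backend. Assumes `model` is a causal LM with fp16/bf16 weights
# and `inputs` is a dict of tokenized tensors, both on a CUDA GPU.
import torch

model = model.to_bettertransformer()
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = model.generate(**inputs, max_new_tokens=20)
```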

## bitsandbytes

bitsandbytes is a quantization library that includes support for 4-bit and 8-bit quantization. Quantization reduces your model size compared to its native full precision version, making it easier to fit large models onto GPUs with limited memory.
2 changes: 1 addition & 1 deletion docs/source/en/quantization.md
@@ -82,7 +82,7 @@ AWQ quantization can also be combined with [FlashAttention-2](perf_infer_gpu_one
```py
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", attn_implementation="flash_attention_2", device_map="cuda:0")
```
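
A quick usage sketch for the model loaded above; the prompt and generation settings are arbitrary:

```py
# Quick usage sketch for the AWQ + FlashAttention-2 model loaded above;
# the prompt and generation settings are arbitrary.
tokenizer = AutoTokenizer.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ")
inputs = tokenizer("What is quantization?", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```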


10 changes: 5 additions & 5 deletions docs/source/ja/perf_infer_gpu_one.md
@@ -44,7 +44,7 @@ Flash Attention 2 can only be used when the model's dtype is `fp16` or `bf16`

### Quick usage

To enable Flash Attention 2 for a model, add `use_flash_attention_2` to the arguments of `from_pretrained`.
To enable Flash Attention 2 for a model, add `attn_implementation="flash_attention_2"` to the arguments of `from_pretrained`.


```python
@@ -57,7 +57,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

@@ -114,7 +114,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

@@ -132,7 +132,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)
```

@@ -151,7 +151,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
use_flash_attention_2=True,
attn_implementation="flash_attention_2",
)

lora_config = LoraConfig(

