Skip to content

Commit

Permalink
Add sdpa for Vivit (#33757)
Browse files Browse the repository at this point in the history
* chore:add sdpa to vivit

* fix:failing slow test_inference_interpolate_pos_encoding(failing on main branch too)

* chore:fix nits

* ci:fix repo consistency failure

* chore:add info and benchmark to model doc

* [run_slow] vivit

* chore:revert interpolation test fix for new issue

* [run_slow] vivit

* [run_slow] vivit

* [run_slow] vivit

* chore:add fallback for output_attentions being True

* [run_slow] vivit

* style:make fixup

* [run_slow] vivit
  • Loading branch information
RUFFY-369 authored Oct 15, 2024
1 parent 23874f5 commit 293e627
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 1 deletion.
37 changes: 37 additions & 0 deletions docs/source/en/model_doc/vivit.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,43 @@ The abstract from the paper is the following:

This model was contributed by [jegormeister](https://huggingface.co/jegormeister). The original code (written in JAX) can be found [here](https://github.com/google-research/scenic/tree/main/scenic/projects/vivit).

### Using Scaled Dot Product Attention (SDPA)

PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.

SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.

```
from transformers import VivitModel
model = VivitModel.from_pretrained("google/vivit-b-16x2-kinetics400", attn_implementation="sdpa", torch_dtype=torch.float16)
...
```

For the best speedups, we recommend loading the model in half-precision (e.g. `torch.float16` or `torch.bfloat16`).

On a local benchmark (A100-40GB, PyTorch 2.3.0, OS Ubuntu 22.04) with `float32` and `google/vivit-b-16x2-kinetics400` model, we saw the following speedups during inference.

### Training
| num_training_steps | batch_size | is cuda | Speedup (%) | Eager peak mem (MB) | sdpa peak mem (MB) | Mem saving (%) |
|---------------------:|-------------:|----------:|--------------:|----------------------:|---------------------:|-----------------:|
| 100 | 1 | True | 7.122 | 2575.28 | 5932.54 | 130.364 |



### Inference
| num_batches | batch_size | is cuda | is half | Speedup (%) | Mem eager (MB) | Mem BT (MB) | Mem saved (%) |
|---------------|--------------|-----------|-----------|---------------|------------------|---------------|-----------------|
| 20 | 1 | True | False | 15.422 | 715.807 | 317.079 | 125.75 |
| 20 | 2 | True | False | 17.146 | 1234.75 | 447.175 | 176.122 |
| 20 | 4 | True | False | 18.093 | 2275.82 | 709.864 | 220.6 |
| 20 | 8 | True | False | 19.284 | 4358.19 | 1233.24 | 253.393 |

## VivitConfig

[[autodoc]] VivitConfig
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell)
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
Expand Down
61 changes: 60 additions & 1 deletion src/transformers/models/vivit/modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,51 @@ def forward(
return outputs


# Adapted from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->Vivit
class VivitSdpaSelfAttention(VivitSelfAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"VivitSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support"
" `output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but specifying"
" the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be"
' removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states,
head_mask,
output_attentions,
)

mixed_query_layer = self.query(hidden_states)

key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
query_layer = self.transpose_for_scores(mixed_query_layer)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
head_mask,
self.attention_probs_dropout_prob if self.training else 0.0,
is_causal=False,
scale=None,
)

context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->Vivit
class VivitSelfOutput(nn.Module):
"""
Expand Down Expand Up @@ -286,6 +331,13 @@ def forward(
return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->Vivit
class VivitSdpaAttention(VivitAttention):
def __init__(self, config: VivitConfig) -> None:
super().__init__(config)
self.attention = VivitSdpaSelfAttention(config)


class VivitIntermediate(nn.Module):
def __init__(self, config):
super().__init__()
Expand Down Expand Up @@ -320,14 +372,20 @@ def forward(self, hidden_states, input_tensor):
return hidden_states


VIVIT_ATTENTION_CLASSES = {
"eager": VivitAttention,
"sdpa": VivitSdpaAttention,
}


class VivitLayer(nn.Module):
"""This corresponds to the EncoderBlock class in the scenic/vivit implementation."""

def __init__(self, config):
super().__init__()
self.chunk_size_feed_forward = config.chunk_size_feed_forward
self.seq_len_dim = 1
self.attention = VivitAttention(config)
self.attention = VIVIT_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = VivitIntermediate(config)
self.output = VivitOutput(config)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
Expand Down Expand Up @@ -436,6 +494,7 @@ class VivitPreTrainedModel(PreTrainedModel):
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []
_supports_sdpa = True

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
6 changes: 6 additions & 0 deletions tests/models/vivit/test_modeling_vivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def __init__(
layer_norm_eps=1e-06,
qkv_bias=True,
scope=None,
attn_implementation="eager",
mask_ratio=0.5,
):
self.parent = parent
self.batch_size = batch_size
Expand All @@ -86,12 +88,15 @@ def __init__(
self.layer_norm_eps = layer_norm_eps
self.qkv_bias = qkv_bias
self.scope = scope
self.attn_implementation = attn_implementation

self.seq_length = (
(self.image_size // self.tubelet_size[2])
* (self.image_size // self.tubelet_size[1])
* (self.num_frames // self.tubelet_size[0])
) + 1 # CLS token
self.mask_ratio = mask_ratio
self.num_masks = int(mask_ratio * self.seq_length)

def prepare_config_and_inputs(self):
pixel_values = floats_tensor(
Expand Down Expand Up @@ -122,6 +127,7 @@ def get_config(self):
initializer_range=self.initializer_range,
layer_norm_eps=self.layer_norm_eps,
qkv_bias=self.qkv_bias,
attn_implementation=self.attn_implementation,
)
config.num_labels = self.num_labels
return config
Expand Down

0 comments on commit 293e627

Please sign in to comment.