MistralAttention: where is the sliding window #29777

Closed
fteufel opened this issue Mar 21, 2024 · 9 comments

Comments

fteufel (Contributor) commented Mar 21, 2024

Hi,

I'm trying to understand the implementation of Mistral's attention in MistralAttention.
https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L195
My understanding is that it should always use local (sliding-window) attention. In MistralFlashAttention2 this is explicit, since config.sliding_window is used there.

However, I'm not sure where the sliding window is used in the base MistralAttention without flash attention:

```python
class MistralAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """
```

but the forward pass simply reads

```python
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
```

which I understand as full self-attention.
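
For concreteness, here is a minimal sketch (not the transformers implementation) of the additive band mask a sliding window would impose on these scores; this eager path would only respect the window if the attention_mask passed in already encodes it:

```python
import torch

# Illustrative only: an additive causal mask restricted to a sliding window.
# Position i may attend to position j only if j <= i and i - j < window.
def sliding_window_causal_mask(seq_len: int, window: int, dtype=torch.float32):
    idx = torch.arange(seq_len)
    allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype)
    mask.masked_fill_(~allowed, torch.finfo(dtype).min)
    return mask  # would be added to attn_weights before the softmax

print(sliding_window_causal_mask(5, window=3))
```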

Is the sliding window only used when running with Flash Attention, or am I missing something?
Thanks!

amyeroberts (Collaborator)

cc @ArthurZucker @younesbelkada

PenutChen (Contributor) commented Mar 22, 2024

fteufel (Contributor, Author) commented Mar 22, 2024

Thanks, I see. But wouldn't this throw away any computational efficiency gains expected from using a sliding window in the first place?

PenutChen (Contributor)

I have the same question. I think the sliding window has two aspects:

  1. From the attention-mask perspective, it acts as a token-level sliding window that limits each token's view of the context.
  2. From the KV-cache perspective, truncating the cache to the window can improve computational efficiency.

This is just my guess; a rough sketch of the second point is below.
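
A purely hypothetical illustration of the KV-cache point (the function name and shapes are made up for the example, not the transformers cache API):

```python
import torch

# Hypothetical sketch: during generation, keep only the last `window` positions
# of the key/value cache, so per-step attention cost stays bounded by the window
# instead of growing with the full sequence length.
def append_and_trim_kv(k_cache, v_cache, k_new, v_new, window: int):
    # caches: (batch, heads, seq, head_dim); k_new/v_new: one new decoding step
    k_cache = torch.cat([k_cache, k_new], dim=2)[:, :, -window:]
    v_cache = torch.cat([v_cache, v_new], dim=2)[:, :, -window:]
    return k_cache, v_cache
```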

ArthurZucker (Collaborator) commented Mar 25, 2024

Yes, this throws away the gains, and that is expected: the best way to use sliding_window is through the SDPA or flash-attention API, unless a rotating buffer is used.
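
For intuition, a rotating (ring) buffer for a sliding-window KV cache could look like the following; the class and its layout are purely illustrative, not the transformers cache API:

```python
import torch

# Illustrative rotating (ring) buffer for a sliding-window KV cache; the name
# and layout are hypothetical, not the transformers cache API.
class RotatingKVBuffer:
    def __init__(self, batch, heads, window, head_dim, dtype=torch.float32):
        self.window = window
        self.pos = 0  # total tokens written so far
        self.k = torch.zeros(batch, heads, window, head_dim, dtype=dtype)
        self.v = torch.zeros(batch, heads, window, head_dim, dtype=dtype)

    def append(self, k_new, v_new):
        # k_new/v_new: (batch, heads, 1, head_dim) for one decoding step
        slot = self.pos % self.window
        self.k[:, :, slot] = k_new[:, :, 0]
        self.v[:, :, slot] = v_new[:, :, 0]
        self.pos += 1
        n = min(self.pos, self.window)
        # Once full, entries come back in ring order; this is fine for the
        # softmax as long as positional information (e.g. RoPE) was applied
        # to the keys before they were cached.
        return self.k[:, :, :n], self.v[:, :, :n]
```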

Closing as expected, feel free to discuss! 🤗

fteufel (Contributor, Author) commented Mar 25, 2024

Hi @ArthurZucker, interesting. So SDPA actually exploits the local-window structure of the attention mask in the backend?

ArthurZucker (Collaborator)

It should, if the mask is passed correctly. The new SDPA path has a sliding_window argument anyway. I'm not sure the mask was prepared correctly before; the important PR is #29407.

ehuaa (Contributor) commented Mar 27, 2024

> It should, if the mask is passed correctly. The new SDPA path has a sliding_window argument anyway. I'm not sure the mask was prepared correctly before; the important PR is #29407.

@ArthurZucker Did you mean this PR, pytorch/pytorch#114823? It does not use a sliding_window parameter explicitly, but it can handle the sliding-window mask inside the SDPA function, am I right?
So if we pass the right mask through _prepare_4d_causal_attention_mask_for_sdpa, as you mentioned in #29407, we can use the local-window feature of Mistral. But I think we could still gain some computational efficiency from local attention even without the rotating buffer, because of the sparsity of the sliding-window attention mask.
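
As a sketch of that idea (illustrative only, not the transformers code path), one can hand a boolean sliding-window causal mask directly to torch.nn.functional.scaled_dot_product_attention; whether a given backend actually skips the masked-out blocks, and thus realizes the sparsity savings, depends on the kernel:

```python
import torch
import torch.nn.functional as F

# Sketch: SDPA with a boolean sliding-window causal mask (True = may attend).
def sdpa_with_sliding_window(q, k, v, window: int):
    seq = q.shape[-2]
    idx = torch.arange(seq, device=q.device)
    allowed = (idx[None, :] <= idx[:, None]) & (idx[:, None] - idx[None, :] < window)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=allowed)

q = k = v = torch.randn(1, 8, 16, 64)
print(sdpa_with_sliding_window(q, k, v, window=4).shape)  # torch.Size([1, 8, 16, 64])
```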

ArthurZucker (Collaborator)

Sorry, I meant the new SDPA code path in transformers, which is not merged yet. Yes, as you say, it handles the mask.
