# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.core import scaled_dot_product_attention


@dataclass
class LinformerSelfAttentionConfig(AttentionConfig):
    seq_len: int  # length of the input sequence
    k: Optional[int]  # dimension of the internal (projected) space


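# Because of the @register_attention decorator below, this class is exposed under the
# name "linformer" in xformers' attention registry, so an instance can also be built
# from a config dict. A minimal sketch, assuming the standard registry helper:
#
#   from xformers.components.attention import build_attention
#   attn = build_attention({"name": "linformer", "dropout": 0.1, "seq_len": 512})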
@register_attention("linformer", LinformerSelfAttentionConfig)
class LinformerAttention(Attention):
def __init__(
self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs
):
"""
Linformer attention mechanism,
from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020).
The original notation is kept as is.
.. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2
"""
super().__init__()
if k is None:
k = seq_len // 4
self.k = k
self.E = nn.Linear(seq_len, k, bias=False)
self.F = nn.Linear(seq_len, k, bias=False)
self.attn_drop = nn.Dropout(dropout, inplace=False)
self.seq_len = seq_len
# MHA related flags:
# kq need to have the same dimension
self.requires_same_k_q_dimensions = True
# This attention does not support attention masks
self.supports_attention_mask = False
def forward(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs
):
# Handle a smaller dimension than expected
padding = 0
if q.shape[1] < self.seq_len:
padding = self.seq_len - q.shape[1]
pad_dims = (0, 0, 0, padding)
q = torch.nn.functional.pad(q, pad_dims)
k = torch.nn.functional.pad(k, pad_dims)
v = torch.nn.functional.pad(v, pad_dims)
k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1)
v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1)
y = scaled_dot_product_attention(
q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop
)
y = self.attn_drop(y)
return y[:, :-padding, :] if padding > 0 else y
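

# ---------------------------------------------------------------------------
# Minimal usage sketch (a non-authoritative example, assuming xformers is installed
# and that q, k, v are shaped (batch, seq_len, dim) as expected by
# scaled_dot_product_attention above; the shapes below are illustrative only).
if __name__ == "__main__":
    SEQ_LEN, DIM = 256, 64

    # k defaults to seq_len // 4 = 64 projected positions
    attention = LinformerAttention(dropout=0.1, seq_len=SEQ_LEN)

    q = torch.randn(2, SEQ_LEN, DIM)
    k = torch.randn(2, SEQ_LEN, DIM)
    v = torch.randn(2, SEQ_LEN, DIM)

    out = attention(q, k, v)
    print(out.shape)  # expected: torch.Size([2, 256, 64])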