# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# --------------------------------------------------------
from typing import Tuple

import torch
from timm.models.vision_transformer import Attention, Block, VisionTransformer

from tome.merge import bipartite_soft_matching, merge_source, merge_wavg
from tome.utils import parse_r

class ToMeBlock(Block):
    """
    Modifications:
    - Apply ToMe between the attention and mlp blocks
    - Compute and propagate token size and potentially the token sources.
    """

    def _drop_path1(self, x):
        return self.drop_path1(x) if hasattr(self, "drop_path1") else self.drop_path(x)

    def _drop_path2(self, x):
        return self.drop_path2(x) if hasattr(self, "drop_path2") else self.drop_path(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Note: this is copied from timm.models.vision_transformer.Block with modifications.
        attn_size = self._tome_info["size"] if self._tome_info["prop_attn"] else None
        x_attn, metric = self.attn(self.norm1(x), attn_size)
        x = x + self._drop_path1(x_attn)

        r = self._tome_info["r"].pop(0)
        if r > 0:
            # Apply ToMe here
            merge, _ = bipartite_soft_matching(
                metric,
                r,
                self._tome_info["class_token"],
                self._tome_info["distill_token"],
            )
            if self._tome_info["trace_source"]:
                self._tome_info["source"] = merge_source(
                    merge, x, self._tome_info["source"]
                )
            x, self._tome_info["size"] = merge_wavg(merge, x, self._tome_info["size"])

        x = x + self._drop_path2(self.mlp(self.norm2(x)))
        return x
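
# A worked example of the schedule above (an illustrative sketch, not part of
# the original file): with a 197-token ViT-B/16 input, 12 blocks, and
# model.r = 16, parse_r yields a constant per-block r, so successive
# ToMeBlocks see 197, 181, 165, ..., 37, 21 tokens; each block removes
# r tokens right after its attention layer.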

class ToMeAttention(Attention):
    """
    Modifications:
    - Apply proportional attention
    - Return the mean of k over heads from attention
    """

    def forward(
        self, x: torch.Tensor, size: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Note: this is copied from timm.models.vision_transformer.Attention with modifications.
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Apply proportional attention
        if size is not None:
            attn = attn + size.log()[:, None, None, :, 0]

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        # Return k as well here
        return x, k.mean(1)
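
# In equations (a clarifying note, not part of the original file): with token
# sizes s (the number of input patches each token represents), the forward
# above computes softmax(q @ k.T * scale + log s), which weights each key as
# if it occurred s times. Merged tokens thus keep an influence on the
# attention distribution proportional to how many patches they stand for.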

def make_tome_class(transformer_class):
    class ToMeVisionTransformer(transformer_class):
        """
        Modifications:
        - Initialize r, token size, and token sources.
        """

        def forward(self, *args, **kwdargs) -> torch.Tensor:
            self._tome_info["r"] = parse_r(len(self.blocks), self.r)
            self._tome_info["size"] = None
            self._tome_info["source"] = None

            return super().forward(*args, **kwdargs)

    return ToMeVisionTransformer

def apply_patch(
    model: VisionTransformer, trace_source: bool = False, prop_attn: bool = True
):
    """
    Applies ToMe to this transformer. Afterward, set r using model.r.

    If you want to know the source of each token (e.g., for visualization), set trace_source = True.
    The sources will be available at model._tome_info["source"] afterward.

    For proportional attention, set prop_attn to True. This is only needed when evaluating models off
    the shelf. For training and for evaluating MAE models off the shelf, set this to False.
    """
    ToMeVisionTransformer = make_tome_class(model.__class__)

    model.__class__ = ToMeVisionTransformer
    model.r = 0
    model._tome_info = {
        "r": model.r,
        "size": None,
        "source": None,
        "trace_source": trace_source,
        "prop_attn": prop_attn,
        "class_token": model.cls_token is not None,
        "distill_token": False,
    }

    if hasattr(model, "dist_token") and model.dist_token is not None:
        model._tome_info["distill_token"] = True

    for module in model.modules():
        if isinstance(module, Block):
            module.__class__ = ToMeBlock
            module._tome_info = model._tome_info
        elif isinstance(module, Attention):
            module.__class__ = ToMeAttention
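
# A minimal usage sketch (an illustration, not part of the original file):
# it assumes the timm package is installed and that "vit_base_patch16_224"
# is an available model name with downloadable pretrained weights.
if __name__ == "__main__":
    import timm

    model = timm.create_model("vit_base_patch16_224", pretrained=True)
    apply_patch(model, trace_source=True)
    model.r = 16  # merge 16 tokens in each block

    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # classification logits, e.g. torch.Size([1, 1000])
    # model._tome_info["source"] now maps each remaining token to its input patches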