-
Notifications
You must be signed in to change notification settings - Fork 4
/
module.py
executable file
·106 lines (87 loc) · 3.2 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from torch import nn, einsum
from einops import rearrange
class SepConv1d(torch.nn.Module):
    """Depthwise-separable 1-D convolution.

    Runs a per-channel (depthwise) conv, batch-normalizes, then mixes
    channels with a 1x1 (pointwise) conv.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,):
        super(SepConv1d, self).__init__()
        # groups == in_channels makes the conv depthwise: one filter per channel.
        self.depthwise = torch.nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
        )
        self.bn = torch.nn.BatchNorm1d(in_channels)
        # 1x1 conv combines channels into the requested output width.
        self.pointwise = torch.nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply depthwise conv -> batch norm -> pointwise conv."""
        return self.pointwise(self.bn(self.depthwise(x)))
class Residual(nn.Module):
    """Skip connection: returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out + x
class PreNorm(nn.Module):
    """Layer-normalize the input before delegating to the wrapped module."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Two-layer MLP: dim -> hidden_dim -> dim, with GELU and dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class ConvAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V projections are depthwise-separable
    1-D convolutions (``SepConv1d``) rather than linear layers.

    Args:
        dim: input/output channel width of the token sequence.
        img_size: stored on the module; not used by ``forward`` here.
        heads: number of attention heads.
        dim_head: channel width per head.
        kernel_size: kernel size of the Q/K/V convolutions.
        q_stride, k_stride, v_stride: strides of the Q/K/V convolutions.
        dropout: dropout probability on the output projection.
        last_stage: stored flag; not used by ``forward`` here.

    Fixes vs. original: removed leftover debug ``print``/``exit(1)`` that
    aborted the process on every call, and added the missing head
    split/merge — the einsum patterns are 4-D (``b h i d``) but q/k/v were
    left 3-D, so forward could never have run.
    """

    def __init__(self, dim, img_size, heads = 8, dim_head = 512, kernel_size=3, q_stride=1, k_stride=1, v_stride=1, dropout = 0.,
                 last_stage=False):
        super().__init__()
        self.last_stage = last_stage
        self.img_size = img_size
        inner_dim = dim_head * heads
        # Output projection is a no-op when a single head already has width dim.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        # "Same"-style padding for the chosen kernel/stride combination.
        pad = (kernel_size - q_stride)//2
        self.to_q = SepConv1d(dim, inner_dim, kernel_size, q_stride, pad)
        self.to_k = SepConv1d(dim, inner_dim, kernel_size, k_stride, pad)
        self.to_v = SepConv1d(dim, inner_dim, kernel_size, v_stride, pad)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over a token sequence.

        Args:
            x: tensor of shape (batch, seq_len, dim).

        Returns:
            Tensor of shape (batch, out_seq_len, dim); out_seq_len equals
            seq_len when the Q-projection stride is 1.
        """
        b, t, c = x.shape
        h = self.heads
        # Conv1d expects (batch, channels, seq); project, then restore layout.
        x = x.permute(0, 2, 1)
        q = self.to_q(x).permute(0, 2, 1)
        k = self.to_k(x).permute(0, 2, 1)
        v = self.to_v(x).permute(0, 2, 1)

        def _split_heads(u):
            # (b, n, h*d) -> (b, h, n, d); n comes from u since strides may
            # change the sequence length.
            return u.reshape(b, u.shape[1], h, -1).transpose(1, 2)

        q, k, v = _split_heads(q), _split_heads(k), _split_heads(v)
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = dots.softmax(dim=-1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = out.transpose(1, 2).flatten(2)
        return self.to_out(out)