Update transformers.py
zwxandy authored Jun 30, 2023
1 parent 97e8db5 commit 5a8f88e
Showing 1 changed file with 54 additions and 10 deletions.
64 changes: 54 additions & 10 deletions src/utils/transformers.py
@@ -14,7 +14,10 @@
from src.utils.sparsemax import Sparsemax
from src.utils.activation import Learnable_Relu
import numpy as np
import inference

args, _ = inference._parse_args()
lg = args.linear_gelu
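# lg toggles the beta-gated GELU path in TransformerEncoderLayer.forward below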

class Attention(Module):
"""
Expand Down Expand Up @@ -56,6 +59,28 @@ def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0
        # self.t1 = Linear(1, 64)
        # self.t2 = Linear(64, 128)
        # self.t3 = Linear(128, 197)

        # Disabled experimental variants, kept for reference:
        """
        self.qkv_expand = Sequential(
            Linear(dim, dim * 6, bias=False),
            Linear(dim * 6, dim * 3, bias=False)
            # intermediate nonlinear function (relu)
        )
        self.proj_expand = Sequential(
            Linear(dim, 4 * dim),
            Linear(4 * dim, 1 * dim)
            # 1 -> 4 -> 1
            # intermediate nonlinear function (relu)
        )
        self.dw_conv = Conv2d(in_channels=256, out_channels=256, kernel_size=3,
                              stride=1, padding=1, groups=256, bias=False)
        self.W = Parameter(torch.zeros(size=(1, 256, 64, 64)), requires_grad=True)  # (batch_size, channel, HW, HW)
        # batch_size=128 -> CUDA out of memory; set the batch dim to 1 and broadcast to batch_size (fewer parameters)
        init.xavier_uniform_(self.W.data, gain=1.414)
        self.kernel_s = Parameter(torch.zeros(self.num_heads, 1, 3, 3), requires_grad=True)
        # follow the depthwise-conv rule: weight shape (out_channels, 1, K, K)
        self.kernel_o = Parameter(torch.zeros(1, 1, 3, 3), requires_grad=False)
        """

    def forward(self, x):
        B, N, C = x.shape  # x: (B, HW + 1, C)
@@ -195,22 +220,30 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
        # self.prelu = nn.PReLU(init=0.5)
        self.relu6 = F.relu6

        # self.beta = nn.Parameter(torch.ones(1, 257, 1), requires_grad=True)  # token-wise
        if lg:
            self.beta = nn.Parameter(torch.ones(1, 65, 1), requires_grad=False)
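            # per-token gate of shape (1, num_tokens, 1), initialized to 1 and kept fixed here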

"""
ratio_expand = int(2)
self.linear_expand = Sequential(
Linear(d_model, ratio_expand * d_model),
Linear(ratio_expand * d_model, dim_feedforward)
)
"""

    def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        src = src + self.drop_path(self.self_attn(self.pre_norm(src)))
        src = self.norm1(src)

        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))  # original gelu

        # src2 = self.linear1(src)
        # src2 = self.beta * self.activation(src2) + (1 - self.beta) * src2  # search gelu
        # src2 = self.relu(src2)
        # self.beta.data = torch.clamp(self.beta.data, min=0., max=1.)

        # src2 = self.linear2(self.dropout1(src2))
        if not lg:
            src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))  # original gelu

        # src2 = (1 - self.beta) * self.activation(src2) + self.beta * src2  # relu after linear layers
        elif lg:
            src2 = self.linear1(src)
            src2 = self.beta * self.activation(src2) + (1 - self.beta) * src2  # search gelu: blend GELU output with identity via beta
            # src2 = self.relu(src2)
            src2 = self.linear2(self.dropout1(src2))
            src2 = (1 - self.beta) * self.relu(src2) + self.beta * src2  # relu after linear layers, blended the same way

        src = src + self.drop_path(self.dropout2(src2))
        return src
@@ -238,10 +271,13 @@ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,

        self.activation = F.gelu


    def forward(self, src: torch.Tensor, mask=None, *args, **kwargs) -> torch.Tensor:
        src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask))
        src = self.norm1(src)

        src2 = self.linear2(self.dropout1(self.activation(self.linear1(src))))

        src = src + self.drop_path(self.dropout2(src2))
        return src

@@ -302,6 +338,14 @@ def __init__(self,

        self.fc = Linear(embedding_dim, num_classes)
        self.apply(self.init_weight)

        """
        ratio_expand = int(2)
        self.fc_expand = Sequential(
            Linear(embedding_dim, ratio_expand * embedding_dim),
            Linear(ratio_expand * embedding_dim, num_classes)
        )
        """

    def forward(self, x):
        # print('x:', x.shape)
