Commit

tmp
wonjuleee committed Jul 16, 2024
1 parent 577e59d commit adf1f39
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/otx/algo/classification/backbones/vision_transformer.py
@@ -230,8 +230,8 @@ def __init__( # noqa: PLR0913
             raise ValueError(msg)
         arch_settings: dict[str, Any] = self.arch_zoo[arch]

-        self.img_size = img_size
-        self.patch_size = patch_size or arch_settings.get("patch_size", 16)
+        self.img_size: int | tuple[int, int] = img_size
+        self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
         self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
         depth = depth or arch_settings.get("depth", 12)
         num_heads = num_heads or arch_settings.get("num_heads", 12)
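
Aside (not part of the diff): the `patch_size or arch_settings.get("patch_size", 16)` lines follow a resolve-with-fallback pattern: an explicit argument wins, otherwise the architecture preset supplies the value, otherwise a hard-coded default applies. A minimal sketch, with assumed preset values rather than the real arch_zoo entries:

from typing import Any

# Assumed preset for illustration only; the real arch_zoo entries differ.
arch_settings: dict[str, Any] = {"patch_size": 14, "embed_dim": 1024}

def resolve(explicit: int | None, key: str, default: int) -> int:
    # `explicit or ...` treats None (and 0) as "unset"; acceptable here,
    # since a zero patch size or embedding dim is never a legal value.
    return explicit or arch_settings.get(key, default)

print(resolve(None, "patch_size", 16))  # 14 -- preset beats the hard default
print(resolve(32, "patch_size", 16))    # 32 -- explicit argument wins
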
@@ -257,21 +257,21 @@ def __init__( # noqa: PLR0913
             # flatten deferred until after pos embed
             embed_args.update({"strict_img_size": False, "output_fmt": "NHWC"})
         self.patch_embed = embed_layer(
-            img_size=img_size,
-            patch_size=patch_size,
+            img_size=self.img_size,
+            patch_size=self.patch_size,
             in_chans=in_chans,
-            embed_dim=embed_dim,
+            embed_dim=self.embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
             dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches

-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, self.embed_dim)) if reg_tokens else None

         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, self.embed_dim))

         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
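
Aside (illustrative, values assumed): the `embed_len` computed above sizes the positional embedding from the patch grid plus any prefix tokens. A quick worked check with ViT-B/16-style numbers:

# Assumed 224x224 input and 16x16 patches; one cls token, no reg tokens.
img_size, patch_size = 224, 16
num_patches = (img_size // patch_size) ** 2  # 14 * 14 = 196 patch tokens

num_prefix_tokens = 1      # the cls token
no_embed_class = False     # prefix tokens also receive positional embeddings

embed_len = num_patches if no_embed_class else num_patches + num_prefix_tokens
print(embed_len)  # 197 -> self.pos_embed has shape (1, 197, embed_dim)
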
@@ -281,13 +281,13 @@ def __init__( # noqa: PLR0913
             )
         else:
             self.patch_drop = nn.Identity()
-        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(self.embed_dim) if pre_norm else nn.Identity()

         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(
             *[
                 block_fn(
-                    dim=embed_dim,
+                    dim=self.embed_dim,
                     num_heads=num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
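
Aside (assumed rate and depth, not from the diff): the `dpr` line above is the stochastic depth decay rule, giving each transformer block its own drop-path probability, rising linearly from 0 to drop_path_rate:

import torch

drop_path_rate, depth = 0.1, 12  # assumed values for illustration
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(p, 3) for p in dpr])
# [0.0, 0.009, 0.018, ..., 0.091, 0.1] -- early blocks are almost never
# skipped; the last block is skipped with probability drop_path_rate.
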
@@ -304,7 +304,7 @@ def __init__( # noqa: PLR0913
             ],
         )

-        self.norm = norm_layer(embed_dim)
+        self.norm = norm_layer(self.embed_dim)

         self.lora = lora
         if self.lora:
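
Aside (hypothetical ToyViT class, not the OTX code): the common thread of this diff is that every layer is now built from the resolved self.* attributes rather than the raw constructor arguments. Once an arch preset can override an argument left as None, the local name and the attribute may diverge; a minimal sketch of the failure mode the change avoids:

import torch
from torch import nn

class ToyViT(nn.Module):
    arch_settings = {"embed_dim": 1024}  # assumed preset

    def __init__(self, embed_dim: int | None = None) -> None:
        super().__init__()
        self.embed_dim = embed_dim or self.arch_settings["embed_dim"]
        # Built from the resolved attribute, so it matches the preset even
        # when the caller passed embed_dim=None; using the raw `embed_dim`
        # here would call nn.LayerNorm(None) and raise a TypeError.
        self.norm = nn.LayerNorm(self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

model = ToyViT()                    # embed_dim argument left unset
print(model.norm.normalized_shape)  # (1024,) -- consistent with the preset
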
