Merge pull request #3744 from wonjuleee/support_dinov2
Support DINOv2 models (small, base, large, giant) via the VisionTransformer backbone
wonjuleee authored Jul 19, 2024
2 parents f977254 + 68fbf5e commit 920d128
Showing 2 changed files with 72 additions and 36 deletions.
94 changes: 65 additions & 29 deletions src/otx/algo/classification/backbones/vision_transformer.py
@@ -6,7 +6,7 @@
 from __future__ import annotations
 
 from functools import partial
-from typing import TYPE_CHECKING, Callable, Literal
+from typing import TYPE_CHECKING, Any, Callable, Literal
 
 import torch
 from timm.layers import (
@@ -90,7 +90,7 @@ class VisionTransformer(BaseModule):
         lora: Enable LoRA training.
     """
 
-    arch_zoo = {  # noqa: RUF012
+    arch_zoo: dict[str, dict] = {  # noqa: RUF012
         **dict.fromkeys(
             ["vit-t", "vit-tiny"],
             {
@@ -145,6 +145,9 @@ class VisionTransformer(BaseModule):
                 "embed_dim": 384,
                 "depth": 12,
                 "num_heads": 6,
+                "reg_tokens": 4,
+                "no_embed_class": True,
+                "init_values": 1e-5,
             },
         ),
         **dict.fromkeys(
@@ -154,6 +157,9 @@
                 "embed_dim": 768,
                 "depth": 12,
                 "num_heads": 12,
+                "reg_tokens": 4,
+                "no_embed_class": True,
+                "init_values": 1e-5,
             },
         ),
         **dict.fromkeys(
@@ -163,6 +169,9 @@
                 "embed_dim": 1024,
                 "depth": 24,
                 "num_heads": 16,
+                "reg_tokens": 4,
+                "no_embed_class": True,
+                "init_values": 1e-5,
             },
         ),
         **dict.fromkeys(
@@ -172,6 +181,9 @@
                 "embed_dim": 1536,
                 "depth": 40,
                 "num_heads": 24,
+                "reg_tokens": 4,
+                "no_embed_class": True,
+                "init_values": 1e-5,
                 "mlp_ratio": 2.66667 * 2,
                 "mlp_layer": SwiGLUPacked,
                 "act_layer": nn.SiLU,
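
The four hunks above give each DINOv2 preset (small 384, base 768, large 1024, giant 1536) the same trio of settings: reg_tokens: 4 adds register tokens, no_embed_class: True keeps positional embeddings off the class/register tokens, and init_values: 1e-5 enables LayerScale initialization. A minimal sketch of picking one of the presets; the attribute names follow the code in this file, and img_size=518 is an assumption based on DINOv2's pretraining resolution (518 / 14 = 37 patches per side):

```python
# A minimal sketch, not code from this PR: instantiating the backbone with
# one of the new DINOv2 presets.
from otx.algo.classification.backbones.vision_transformer import VisionTransformer

backbone = VisionTransformer(arch="dinov2-small", img_size=518)
print(backbone.embed_dim)          # 384 for the small preset
print(backbone.num_reg_tokens)     # 4 register tokens
print(backbone.num_prefix_tokens)  # 1 cls token + 4 register tokens = 5
```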
@@ -194,8 +206,8 @@ def __init__(  # noqa: PLR0913
         qk_norm: bool = False,
         init_values: float | None = None,
         class_token: bool = True,
-        no_embed_class: bool = False,
-        reg_tokens: int = 0,
+        no_embed_class: bool | None = None,
+        reg_tokens: int | None = None,
         pre_norm: bool = False,
         dynamic_img_size: bool = False,
         dynamic_img_pad: bool = False,
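
Why these two defaults become None: the constructor (next hunk) resolves every unset argument with a `value or preset or default` chain, and a concrete False/0 default would be indistinguishable from "caller passed nothing". A toy sketch of the pattern; the names here are illustrative, not from the PR:

```python
# Toy sketch of the fallback chain used in __init__ below.
arch_settings = {"no_embed_class": True, "reg_tokens": 4}  # e.g. a dinov2 preset

def resolve(value, key, default):
    """Mimics `value or arch_settings.get(key, default)` from the constructor."""
    return value or arch_settings.get(key, default)

assert resolve(None, "no_embed_class", default=False) is True  # preset wins
assert resolve(None, "reg_tokens", default=0) == 4
assert resolve(8, "reg_tokens", default=0) == 8                # explicit value wins
```

One caveat of the `or` idiom: an explicit False or 0 is falsy and also defers to the preset, so these settings can only be overridden with truthy values.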
@@ -216,21 +228,22 @@ def __init__(  # noqa: PLR0913
         if arch not in set(self.arch_zoo):
             msg = f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
             raise ValueError(msg)
-        arch_settings = self.arch_zoo[arch]
-
-        patch_size = patch_size or arch_settings["patch_size"]
-        embed_dim = embed_dim or arch_settings["embed_dim"]
-        depth = depth or arch_settings["depth"]
-        num_heads = num_heads or arch_settings["num_heads"]
-        mlp_layer = mlp_layer or arch_settings.get("mlp_layer", None) or Mlp
-        mlp_ratio = mlp_ratio or arch_settings.get("mlp_ratio", None) or 4.0
-        norm_layer = (
-            get_norm_layer(norm_layer) or arch_settings.get("norm_layer", None) or partial(nn.LayerNorm, eps=1e-6)
-        )
-        act_layer = get_act_layer(act_layer) or arch_settings.get("act_layer", None) or nn.GELU
+        arch_settings: dict[str, Any] = self.arch_zoo[arch]
+
+        self.img_size: int | tuple[int, int] = img_size
+        self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
+        self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
+        depth = depth or arch_settings.get("depth", 12)
+        num_heads = num_heads or arch_settings.get("num_heads", 12)
+        no_embed_class = no_embed_class or arch_settings.get("no_embed_class", False)
+        reg_tokens = reg_tokens or arch_settings.get("reg_tokens", 0)
+        init_values = init_values or arch_settings.get("init_values", None)
+        mlp_layer = mlp_layer or arch_settings.get("mlp_layer", Mlp)
+        mlp_ratio = mlp_ratio or arch_settings.get("mlp_ratio", 4.0)
+        norm_layer = get_norm_layer(norm_layer) or arch_settings.get("norm_layer", partial(nn.LayerNorm, eps=1e-6))
+        act_layer = get_act_layer(act_layer) or arch_settings.get("act_layer", nn.GELU)
 
         self.num_classes = num_classes
-        self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_prefix_tokens = 1 if class_token else 0
         self.num_prefix_tokens += reg_tokens
         self.num_reg_tokens = reg_tokens
@@ -244,21 +257,21 @@ def __init__(  # noqa: PLR0913
             # flatten deferred until after pos embed
             embed_args.update({"strict_img_size": False, "output_fmt": "NHWC"})
         self.patch_embed = embed_layer(
-            img_size=img_size,
-            patch_size=patch_size,
+            img_size=self.img_size,
+            patch_size=self.patch_size,
             in_chans=in_chans,
-            embed_dim=embed_dim,
+            embed_dim=self.embed_dim,
             bias=not pre_norm,  # disable bias if pre-norm is used (e.g. CLIP)
             dynamic_img_pad=dynamic_img_pad,
             **embed_args,
         )
         num_patches = self.patch_embed.num_patches
 
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
-        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if class_token else None
+        self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, self.embed_dim)) if reg_tokens else None
 
         embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
-        self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, embed_len, self.embed_dim))
 
         self.pos_drop = nn.Dropout(p=pos_drop_rate)
         if patch_drop_rate > 0:
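
With no_embed_class=True the positional table covers patch tokens only, which is exactly what the embed_len expression above computes; cls and register tokens then carry no positional embedding. Worked numbers for the DINOv2 geometry (a sketch; 518 and 14 are assumed from DINOv2's pretraining setup, not shown in this diff):

```python
# Worked numbers, not code from the PR.
img_size, patch_size = 518, 14
num_patches = (img_size // patch_size) ** 2  # 37 * 37 = 1369
num_prefix_tokens = 1 + 4                    # cls token + 4 register tokens

embed_len_no_embed_class = num_patches               # 1369 (dinov2 presets)
embed_len_default = num_patches + num_prefix_tokens  # 1374 (classic ViT)
print(embed_len_no_embed_class, embed_len_default)
```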
@@ -268,13 +281,13 @@ def __init__(  # noqa: PLR0913
             )
         else:
             self.patch_drop = nn.Identity()
-        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
+        self.norm_pre = norm_layer(self.embed_dim) if pre_norm else nn.Identity()
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(
             *[
                 block_fn(
-                    dim=embed_dim,
+                    dim=self.embed_dim,
                     num_heads=num_heads,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias,
@@ -291,7 +304,7 @@ def __init__(  # noqa: PLR0913
             ],
         )
 
-        self.norm = norm_layer(embed_dim)
+        self.norm = norm_layer(self.embed_dim)
 
         self.lora = lora
         if self.lora:
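
The blanket rename from embed_dim to self.embed_dim across these hunks is not cosmetic: after the fallback resolution, the local embed_dim still holds whatever the caller passed, which may be None, while self.embed_dim holds the resolved value. A toy illustration (names illustrative, not from the PR):

```python
# Why later uses must switch to the resolved attribute.
arch_settings = {"embed_dim": 384}  # illustrative preset entry
embed_dim = None                    # caller omitted the argument
resolved_embed_dim = embed_dim or arch_settings.get("embed_dim", 768)

print(embed_dim)           # None -> nn.LayerNorm(embed_dim) would raise
print(resolved_embed_dim)  # 384  -> safe for sizing modules
```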
@@ -323,10 +336,33 @@ def init_weights(self) -> None:
     def load_pretrained(self, checkpoint_path: Path, prefix: str = "") -> None:
         """Loads the pretrained weight to the VisionTransformer."""
         checkpoint_ext = checkpoint_path.suffix
-        if checkpoint_ext == ".npz":
+        if checkpoint_ext == ".npz":  # deit models
             self._load_npz_weights(self, checkpoint_path, prefix)
-        elif checkpoint_ext == ".pth":
-            self.load_state_dict(torch.load(checkpoint_path), strict=False)
+        elif checkpoint_ext == ".pth":  # dinov2 models
+
+            def resize_positional_embeddings(pos_embed: torch.Tensor, new_shape: tuple[int, int]) -> torch.Tensor:
+                # Resize the embeddings using bilinear interpolation.
+                pos_embed = pos_embed.permute(0, 2, 1).reshape(1, -1, 37, 37)  # 518 (pretrain img_size) / 14 (patch_size) = 37
+                pos_embed_resized = nn.functional.interpolate(
+                    pos_embed,
+                    size=(new_shape[0], new_shape[1]),
+                    mode="bilinear",
+                )
+                return pos_embed_resized.reshape(1, -1, new_shape[0] * new_shape[1]).permute(0, 2, 1)
+
+            # convert dinov2 pretrained weights
+            state_dict = torch.load(checkpoint_path)
+            state_dict.pop("mask_token", None)
+            state_dict["reg_token"] = state_dict.pop("register_tokens")
+            state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]
+
+            img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
+            patch_size = (self.patch_size, self.patch_size) if isinstance(self.patch_size, int) else self.patch_size
+            state_dict["pos_embed"] = resize_positional_embeddings(
+                state_dict.pop("pos_embed")[:, 1:],
+                (img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
+            )
+            self.load_state_dict(state_dict, strict=False)
         else:
             msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
             raise ValueError(msg)
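
The new .pth branch converts a torch.hub-style DINOv2 state dict: it drops mask_token, renames register_tokens to reg_token, folds the class token's positional embedding into the token itself (under no_embed_class it will never receive one at runtime), and bilinearly resizes the stored 37x37 patch-position grid to this model's grid. A hedged usage sketch; the checkpoint filename is illustrative, not from the PR:

```python
# A usage sketch, assuming a locally saved dinov2 "register" checkpoint.
from pathlib import Path

from otx.algo.classification.backbones.vision_transformer import VisionTransformer

model = VisionTransformer(arch="dinov2-small", img_size=224)

ckpt = Path("dinov2_vits14_reg4_pretrain.pth")  # hypothetical local path
model.load_pretrained(ckpt)  # ".pth" routes into the conversion branch above
# pos_embed is resized from the stored 37x37 grid to 224 / 14 = 16 per side.
```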
14 changes: 7 additions & 7 deletions src/otx/recipe/classification/multi_class_cls/dino_v2.yaml
@@ -1,22 +1,22 @@
 model:
-  class_path: otx.algo.classification.dino_v2.DINOv2RegisterClassifier
+  class_path: otx.algo.classification.vit.VisionTransformerForMulticlassCls
   init_args:
     label_info: 1000
     freeze_backbone: False
-    backbone: dinov2_vits14_reg
+    arch: "dinov2-small"
 
 optimizer:
   class_path: torch.optim.AdamW
   init_args:
-    lr: 1e-5
+    lr: 0.0001
     weight_decay: 0.05
 
 scheduler:
   class_path: lightning.pytorch.cli.ReduceLROnPlateau
   init_args:
-    mode: min
+    mode: max
     factor: 0.5
-    patience: 9
-    monitor: train/loss
+    patience: 1
+    monitor: val/accuracy
 
 engine:
   task: MULTI_CLASS_CLS
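
The scheduler flip from mode: min on train/loss to mode: max on val/accuracy matters because ReduceLROnPlateau's notion of "no longer improving" depends on the direction of the monitored metric. A standalone sketch with the plain PyTorch scheduler (the recipe itself goes through Lightning's wrapper):

```python
# Sketch: why mode must be "max" when monitoring an accuracy-like metric.
import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="max", factor=0.5, patience=1)

for accuracy in [0.70, 0.71, 0.71, 0.71]:  # plateaus after the second epoch
    sched.step(accuracy)  # with mode="max", stalled accuracy counts as a plateau

print(opt.param_groups[0]["lr"])  # 5e-05: halved from 1e-4 after the plateau
```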
