Skip to content

Commit

Permalink
feat(ml): DinoV2 feature-based projected discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Oct 25, 2023
1 parent 7f7472d commit c67ffa8
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
49 changes: 47 additions & 2 deletions models/modules/projected_d/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def _make_depth(model):
return model


def _make_dinov2(model):
    """Prepare a DinoV2 backbone for use as a projected-discriminator
    feature extractor by attaching a ``get_feats`` method to it.

    Parameters
    ----------
    model : torch.nn.Module
        A DinoV2 model as returned by ``create_dinov2_model``.

    Returns
    -------
    torch.nn.Module
        The same model instance, modified in place.
    """
    configure_get_feats_dinov2(model)
    return model


def configure_forward_network(net):
def forward(x):
out0 = net.layer0(x)
Expand Down Expand Up @@ -144,7 +149,6 @@ def get_feats(x):

def configure_get_feats_depth(net):
def get_feats(x):

x = net.transform(x)

if net.channels_last == True:
Expand Down Expand Up @@ -184,6 +188,23 @@ def get_feats(x):
net.get_feats = get_feats


def configure_get_feats_dinov2(net, model_name="dinov2_vits14"):
    """Attach a ``get_feats`` method to a DinoV2 backbone.

    ``get_feats(x)`` returns the outputs of a fixed set of intermediate
    transformer blocks (via DinoV2's ``get_intermediate_layers``), to be
    consumed by the projected discriminator.

    Parameters
    ----------
    net : torch.nn.Module
        DinoV2 model exposing ``get_intermediate_layers``.
    model_name : str
        Which DinoV2 variant ``net`` is; selects the block indices to tap.
        Defaults to ``"dinov2_vits14"``, which preserves the previous
        hard-coded behavior (``[2, 5, 8, 11]``).
    """
    # Intermediate transformer-block indices to tap, per DinoV2 variant.
    # NOTE(review): dinov2_vitb14 has 12 blocks, so indices 12 and 17 look
    # out of range — confirm these values against the DinoV2 model configs.
    dino_layers = {
        "dinov2_vits14": [2, 5, 8, 11],
        "dinov2_vitb14": [3, 8, 12, 17],
        "dinov2_vitl14": [4, 10, 16, 23],
        "dinov2_vitg14": [6, 16, 26, 39],
    }

    # Previously the table above was unused and [2, 5, 8, 11] was always
    # passed; fall back to that same default for unknown model names.
    layers = dino_layers.get(model_name, dino_layers["dinov2_vits14"])

    def get_feats(x):
        # Tuple of per-block feature tensors (no class token).
        feats = net.get_intermediate_layers(
            x, n=layers, return_class_token=False
        )
        return feats

    net.get_feats = get_feats


def calc_channels(pretrained, inp_res=224):
channels = []
feats = []
Expand Down Expand Up @@ -216,8 +237,12 @@ def create_clip_model(model_name, config_path, weight_path, img_size):
return model[0].visual.float().cpu()


def create_segformer_model(model_name, config_path, weight_path, img_size):
def create_dinov2_model(model_name, config_path, weight_path, img_size):
    """Fetch a pretrained DinoV2 backbone from the torch hub.

    Parameters
    ----------
    model_name : str
        DinoV2 hub entry point, e.g. ``"dinov2_vits14"``.
    config_path, weight_path, img_size
        Unused here; accepted so this factory matches the signature of the
        other ``create_*_model`` functions in this module.

    Returns
    -------
    torch.nn.Module
        The pretrained DinoV2 model.
    """
    return torch.hub.load("facebookresearch/dinov2", model_name)


def create_segformer_model(model_name, config_path, weight_path, img_size):
cfg = load_config_file(config_path)
try:
weights = torch.jit.load(weight_path).state_dict()
Expand Down Expand Up @@ -303,6 +328,26 @@ def create_depth_model(model_name, config_path, weight_path, img_size):
"create_model_function": create_depth_model,
"make_function": _make_depth,
},
"dinov2_vits14": {
"model_name": "dinov2_vits14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
"dinov2_vitb14": {
"model_name": "dinov2_vitb14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
"dinov2_vitl14": {
"model_name": "dinov2_vitl14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
"dinov2_vitg14": {
"model_name": "dinov2_vitg14",
"create_model_function": create_dinov2_model,
"make_function": _make_dinov2,
},
}


Expand Down
4 changes: 4 additions & 0 deletions options/base_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ def initialize(self, parser):
"vitclip16",
"vitclip14",
"depth",
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vitg14",
],
help="projected discriminator architecture",
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_run_semantic_mask_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

G_netG = ["mobile_resnet_attn", "segformer_attn_conv"]

D_proj_network_type = ["efficientnet", "vitsmall"]
D_proj_network_type = ["efficientnet", "vitsmall", "dinov2_vits14"]

D_netDs = [
["basic", "projected_d"],
Expand Down

0 comments on commit c67ffa8

Please sign in to comment.