diff --git a/models/modules/projected_d/projector.py b/models/modules/projected_d/projector.py
index 0d9a16f54..1f14a8fb9 100644
--- a/models/modules/projected_d/projector.py
+++ b/models/modules/projected_d/projector.py
@@ -79,6 +79,11 @@ def _make_depth(model):
     return model
 
 
+def _make_dinov2(model):
+    configure_get_feats_dinov2(model)
+    return model
+
+
 def configure_forward_network(net):
     def forward(x):
         out0 = net.layer0(x)
@@ -144,7 +149,6 @@ def get_feats(x):
 
 
 def configure_get_feats_depth(net):
     def get_feats(x):
-        x = net.transform(x)
 
         if net.channels_last == True:
@@ -184,6 +188,23 @@ def get_feats(x):
     net.get_feats = get_feats
 
 
+def configure_get_feats_dinov2(net):
+    # Blocks to tap, keyed by backbone depth (len(net.blocks) on hub models).
+    dino_layers = {
+        12: [2, 5, 8, 11],  # dinov2_vits14 / dinov2_vitb14
+        24: [4, 10, 16, 23],  # dinov2_vitl14
+        40: [6, 16, 26, 39],  # dinov2_vitg14
+    }
+
+    def get_feats(x):
+        feats = net.get_intermediate_layers(
+            x, n=dino_layers[len(net.blocks)], return_class_token=False
+        )
+        return feats
+
+    net.get_feats = get_feats
+
+
 def calc_channels(pretrained, inp_res=224):
     channels = []
     feats = []
@@ -216,8 +237,12 @@ def create_clip_model(model_name, config_path, weight_path, img_size):
     return model[0].visual.float().cpu()
 
 
-def create_segformer_model(model_name, config_path, weight_path, img_size):
+def create_dinov2_model(model_name, config_path, weight_path, img_size):
+    dinov2_model = torch.hub.load("facebookresearch/dinov2", model_name)
+    return dinov2_model
+
+def create_segformer_model(model_name, config_path, weight_path, img_size):
     cfg = load_config_file(config_path)
 
     try:
         weights = torch.jit.load(weight_path).state_dict()
@@ -303,6 +328,26 @@ def create_depth_model(model_name, config_path, weight_path, img_size):
         "create_model_function": create_depth_model,
         "make_function": _make_depth,
     },
+    "dinov2_vits14": {
+        "model_name": "dinov2_vits14",
+        "create_model_function": create_dinov2_model,
+        "make_function": _make_dinov2,
+    },
+    "dinov2_vitb14": {
+        "model_name": "dinov2_vitb14",
+        "create_model_function": create_dinov2_model,
+        "make_function": _make_dinov2,
+    },
+    "dinov2_vitl14": {
+        "model_name": "dinov2_vitl14",
+        "create_model_function": create_dinov2_model,
+        "make_function": _make_dinov2,
+    },
+    "dinov2_vitg14": {
+        "model_name": "dinov2_vitg14",
+        "create_model_function": create_dinov2_model,
+        "make_function": _make_dinov2,
+    },
 }
 
 
diff --git a/options/base_options.py b/options/base_options.py
index 2aba144ce..c0a44e0dd 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -443,6 +443,10 @@ def initialize(self, parser):
                 "vitclip16",
                 "vitclip14",
                 "depth",
+                "dinov2_vits14",
+                "dinov2_vitb14",
+                "dinov2_vitl14",
+                "dinov2_vitg14",
             ],
             help="projected discriminator architecture",
         )
diff --git a/tests/test_run_semantic_mask_online.py b/tests/test_run_semantic_mask_online.py
index 853a8b14c..933220d80 100644
--- a/tests/test_run_semantic_mask_online.py
+++ b/tests/test_run_semantic_mask_online.py
@@ -59,7 +59,7 @@
 
 G_netG = ["mobile_resnet_attn", "segformer_attn_conv"]
 
-D_proj_network_type = ["efficientnet", "vitsmall"]
+D_proj_network_type = ["efficientnet", "vitsmall", "dinov2_vits14"]
 
 D_netDs = [
     ["basic", "projected_d"],
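Note for reviewers: a minimal standalone sketch of the feature-extraction path this patch wires up. It is not part of the patch; it assumes the public `facebookresearch/dinov2` torch.hub entry points and the upstream `get_intermediate_layers` signature.

```python
import torch

# Load the smallest DINOv2 backbone from torch.hub, as create_dinov2_model does.
net = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
net.eval()

# DINOv2 uses a 14x14 patch size, so H and W must be multiples of 14 (224 = 16 * 14).
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # Tap the four blocks selected for the 12-block ViT-S variant.
    feats = net.get_intermediate_layers(x, n=[2, 5, 8, 11], return_class_token=False)

for f in feats:
    print(f.shape)  # torch.Size([1, 256, 384]): (batch, 16*16 patch tokens, embed_dim)
```

Since `calc_channels` probes the backbone at `inp_res=224`, a multiple of the 14-pixel patch size, the existing channel-probing path should work unchanged for these models.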
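For context on the registry entries at the bottom of `projector.py`: each entry pairs a constructor with a hook that attaches `get_feats`. A hypothetical consumer sketch follows; the function and its `registry` argument are illustrative, and only the three dict keys come from the patch.

```python
# Hypothetical consumer of a registry entry; only the keys "model_name",
# "create_model_function", and "make_function" appear in the patch.
def build_projector(registry, name, config_path=None, weight_path=None, img_size=224):
    entry = registry[name]  # e.g. the "dinov2_vits14" entry added above
    # Step 1: build the raw backbone (a torch.hub download for DINOv2;
    # config_path/weight_path are unused by create_dinov2_model).
    model = entry["create_model_function"](
        entry["model_name"], config_path, weight_path, img_size
    )
    # Step 2: attach the backbone-specific feature hook
    # (configure_get_feats_dinov2, via _make_dinov2).
    return entry["make_function"](model)
```

With the `base_options.py` change, the new backbones become selectable as projected-discriminator architectures, and the updated test exercises `dinov2_vits14` through `D_proj_network_type`.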