
Hi, how do I replace Darknet53 in Yolov3 with MobileNetV2? #5450

Closed · GalSang17 opened this issue Jun 25, 2021 · 19 comments

@GalSang17
Hi, how do I replace Darknet53 in Yolov3 with MobileNetV2?

IMHO, just changing the config file should be enough. It works in the usual mmdetection way; there is nothing special.

[screenshot of the modified config]
Is it just like this? The neck section throws an exception.

emmm, there are several points:

For the backbone part, you should check the stride of each entry in out_indices. The length of out_indices should be 3 if you want to make only minimal changes to the neck and head.

For the neck part, you need to at least make sure that in_channels works with the backbone.

BTW, you can open a new issue and we can track the problem there.

Originally posted by @ElectronicElephant in #3083 (comment)
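The checks above can be sketched as plain assertions over a config-like dict. This is an editor's illustration, not mmdetection code: the dict is hypothetical (only its structure mirrors mmdetection configs), and the concrete numbers are the MobileNetV2 values that the thread eventually settles on.

```python
# Hypothetical config fragment for illustration; the keys follow the usual
# mmdetection layout, but the dict itself is made up for this sketch.
cfg = dict(
    backbone=dict(type='MobileNetV2', out_indices=(2, 4, 6)),
    neck=dict(type='YOLOV3Neck', num_scales=3,
              in_channels=[320, 96, 32], out_channels=[96, 96, 96]),
    bbox_head=dict(type='YOLOV3Head', in_channels=[96, 96, 96]),
)

# Check 1: the backbone must emit one feature map per neck scale.
assert len(cfg['backbone']['out_indices']) == cfg['neck']['num_scales'] == 3

# Check 2: the neck's in_channels must equal the channel counts of the
# selected backbone stages (deepest level first for YOLOV3Neck). These are
# the MobileNetV2 (widen_factor=1.0) stage widths discussed in this thread.
backbone_stage_channels = {2: 32, 4: 96, 6: 320}
expected = [backbone_stage_channels[i]
            for i in sorted(cfg['backbone']['out_indices'], reverse=True)]
assert cfg['neck']['in_channels'] == expected

# Check 3: the head must consume exactly what the neck emits.
assert cfg['bbox_head']['in_channels'] == cfg['neck']['out_channels']
print('config channels are self-consistent')
```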

@GalSang17 (Author)

Yes, I just want to modify the backbone part minimally, but I don't know where to start.

@ElectronicElephant (Contributor)

Hi,

I think the out_indices in backbone should be set to [4, 6, 7]

in_channels in neck should be [64, 160, 320]

Reference:

arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],


@GalSang17 (Author)

> I think the out_indices in backbone should be set to [4, 6, 7]
>
> in_channels in neck should be [64, 160, 320]

I tried the modification you suggested, but it doesn't seem to work, as in this screenshot:
[screenshots of the error]

Then I tried out_indices=(2, 4, 6) and the output feature maps matched the sizes produced by Darknet53, as in this screenshot:
[screenshots of the feature-map shapes]

@ElectronicElephant (Contributor)

Hi @GalSang17,

Sorry, I made some mistakes. (2, 4, 6) is OK for the backbone.

You may try in_channels=[320, 96, 32] in the neck, and make sure the out_channels in the neck is consistent with the in_channels in the head. This is more flexible thanks to #5218, compared with the one I originally wrote.

BTW, I think you should reduce the channels in the head, because you are using a lightweight model.
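As a sanity check (an editor's sketch, not part of mmdetection), the out_indices and channel counts can be derived from the arch_settings table that mmdet's MobileNetV2 uses: pick the last stage at each of the strides 8, 16 and 32 that YOLOV3Neck expects, and read off the channel widths.

```python
# Stage table as in mmdet's MobileNetV2 (widen_factor=1.0); each row is
# (expand_ratio, channels, num_blocks, stride of the stage's first block).
arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
                 [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
                 [6, 320, 1, 1]]

stride = 2  # the stem conv already downsamples by 2
stage_info = []
for idx, (_, channels, _, s) in enumerate(arch_settings):
    stride *= s  # only the first block of a stage can downsample
    stage_info.append((idx, channels, stride))

# Keep the last stage at each of the strides 8, 16 and 32.
picked = {st: (idx, ch) for idx, ch, st in stage_info if st in (8, 16, 32)}
out_indices = tuple(picked[st][0] for st in (8, 16, 32))
# YOLOV3Neck lists channels from the deepest (stride-32) level first.
neck_in_channels = [picked[st][1] for st in (32, 16, 8)]
print(out_indices, neck_in_channels)  # (2, 4, 6) [320, 96, 32]
```

This reproduces both the corrected out_indices=(2, 4, 6) and the neck's in_channels=[320, 96, 32].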

@GalSang17 (Author)

> Sorry, I made some mistakes. (2, 4, 6) is OK for the backbone. You may try in_channels=[320, 96, 32] in the neck, and make sure the out_channels in the neck is consistent with the in_channels in the head.

Thanks for your answer, it is now working properly. By the way, does Open-MMLab have pre-trained weights for MobileNetV2, like this:
[screenshot]

@ElectronicElephant (Contributor)

> Thanks for your answer, it is now working properly. By the way, does Open-MMLab have pre-trained weights for MobileNetV2?

Hi, I'm not sure about that. But since MobileNet is quite a small backbone, it should be OK to train from scratch~

@ElectronicElephant (Contributor)

@GalSang17 Hi, you can try https://github.com/open-mmlab/mmcv/blob/db097bd1e97fc446a7551c715970611d2fcc848d/mmcv/model_zoo/open_mmlab.json#L42

@tikitong commented Jul 1, 2021

> Thanks for your answer, it is now working properly.

Hi @GalSang17, does it work correctly for you? Can you tell me which out_channels you put in the neck? Thanks a lot.

@ElectronicElephant (Contributor)

> Can you tell me which out_channels you put in the neck?

Since #5218 is merged, you can use any out_channels. (BTW, you may want to pull and check out the latest code.) IMHO, [160, 48, 16] would be a good value to start with.
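To make the "any out_channels" point concrete, here is an editor's sketch (not mmdetection code) of the one chaining rule that still has to hold after #5218: the head's in_channels must match the neck's out_channels, scale by scale. The numbers are the values suggested in this thread.

```python
# After #5218, the neck's out_channels no longer need to mirror its
# in_channels; they can be chosen freely as long as the head consumes
# exactly what the neck emits at each scale.
neck = dict(in_channels=[320, 96, 32], out_channels=[160, 48, 16])
bbox_head = dict(in_channels=list(neck['out_channels']))

for scale, (n_out, h_in) in enumerate(zip(neck['out_channels'],
                                          bbox_head['in_channels'])):
    assert n_out == h_in, f'scale {scale}: neck emits {n_out}, head expects {h_in}'
print('neck -> head channel chaining OK')
```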

@tikitong commented Jul 2, 2021

@ElectronicElephant Thank you very much. I would like to use MobileNetV2/V3-small with YOLOv3 (I have a small dataset); it is also necessary to adapt the in_channels and out_channels of the bbox_head, right?

With this configuration I get this error: `RuntimeError: Given groups=1, weight of size [48, 96, 1, 1], expected input[8, 160, 32, 32] to have 96 channels, but got 160 channels instead`

cfg.model = dict(
    type='YOLOV3',
    pretrained='open-mmlab://MobileNetV2',
    backbone=dict(type='MobileNetV2', widen_factor=1.0, out_indices=(2, 4, 6)),
    neck=dict(
        type='YOLOV3Neck',
        num_scales=3,
        in_channels=[320, 96, 32],
        out_channels=[160, 48, 16]),
    bbox_head=dict(
        type='YOLOV3Head',
        num_classes=3,
        in_channels=[160, 48, 16],
        out_channels=[320, 96, 32],
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8]),
        bbox_coder=dict(type='YOLOBBoxCoder'),
        featmap_strides=[32, 16, 8],
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_conf=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_xy=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=2.0,
            reduction='sum'),
        loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')),
    train_cfg=dict(
        assigner=dict(
            type='GridAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0)),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        conf_thr=0.005,
        nms=dict(type='nms', iou_threshold=0.45),
        max_per_img=100))

@ElectronicElephant (Contributor) commented Jul 2, 2021

Hi @tikitong , it should work as expected. Are you using the latest version of mmdetection?

@tikitong commented Jul 2, 2021

Hi @ElectronicElephant, yes, I think I'm on the latest version; I did a git pull a week ago and updated my mmcv and mmdet versions.

import torch
print(torch.__version__, torch.cuda.is_available())

import mmdet
print(mmdet.__version__)

import mmcv
print(mmcv.__version__)

from mmcv.ops import get_compiler_version, get_compiling_cuda_version
print(get_compiling_cuda_version())
print(get_compiler_version())

1.8.1 True
2.13.0
1.3.3
10.1
GCC 7.3

@ElectronicElephant (Contributor)

> Hi @ElectronicElephant yes I think I use the last one I did a git pull a week ago..

Just wait a minute. I'll create a PR to add such a config.

@ElectronicElephant (Contributor)

Hi @tikitong , you can check #5510

@tikitong commented Jul 2, 2021

@ElectronicElephant Thank you very much for the example in #5510! Indeed, my case should also work... I don't understand. With the basic Darknet backbone it works fine. Here is my entire configuration file. Do you notice anything that could cause the error?

from mmcv import Config
from mmdet.apis import set_random_seed
from pathlib import Path

import os.path as osp
from mmcv import mkdir_or_exist
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector

base = 'configs/yolo/yolov3_d53_320_273e_coco.py'
cfg = Config.fromfile(base)
          
cfg.model = dict(
    type='YOLOV3',
    pretrained='open-mmlab://MobileNetV2',
    backbone=dict(type='MobileNetV2', widen_factor=1.0, out_indices=(2, 4, 6)),
    neck=dict(
        type='YOLOV3Neck',
        num_scales=3,
        in_channels=[320, 96, 32],
        out_channels=[160, 48, 16]),
    bbox_head=dict(
        type='YOLOV3Head',
        num_classes=3,
        in_channels=[160, 48, 16],
        out_channels=[320, 96, 32],
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8]),
        bbox_coder=dict(type='YOLOBBoxCoder'),
        featmap_strides=[32, 16, 8],
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_conf=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_xy=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=2.0,
            reduction='sum'),
        loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')),
    train_cfg=dict(
        assigner=dict(
            type='GridAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0)),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        conf_thr=0.005,
        nms=dict(type='nms', iou_threshold=0.45),
        max_per_img=100))

cfg.load_from = 'checkpoints/mobilenet_v2.pth'

cfg.dataset_type = 'MyDataset'
cfg.data_root = 'set'
cfg.classes = ('triangle', 'circle', 'square')

cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'image'
cfg.data.train.pipeline = cfg.train_pipeline

cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'image'

cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.ann_file = 'test.txt'
cfg.data.test.img_prefix = 'image'

cfg.optimizer = dict(type='SGD', lr=6e-4, momentum=0.9, weight_decay=0.0001)

cfg.optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))


cfg.runner = dict(type='EpochBasedRunner', max_epochs=40)
cfg.total_epochs = 40
cfg.evaluation = dict(interval=20, metric=['mAP'])

cfg.log_config = dict(
    interval=172,
    hooks=[dict(type='TextLoggerHook'),
           dict(type='TensorboardLoggerHook')])

cfg.checkpoint_config = dict(interval=50)

cfg.work_dir = f'work_dirs/{Path(base).stem}_'

cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

print(f'Config: {cfg.total_epochs} epochs\n{cfg.pretty_text}')

datasets = [build_dataset(cfg.data.train)]
model = build_detector(cfg.model)
mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)

@ElectronicElephant (Contributor)

Hi @tikitong, can you double-check whether your local code contains the modification from #5218? I found nothing unusual. Hmm, weird.

@tikitong commented Jul 2, 2021

@ElectronicElephant Thanks a lot, that was it! Both files mmdet/models/necks/yolo_neck.py and tests/test_models/test_necks.py were not up to date. Is that normal despite the git pull? Should I report it as a bug? It now runs, but the mAP scores have dropped to 0.00.. with both the Darknet and MobileNetV2 backbones.

@jshilong (Collaborator)

> Now it works fine but mAP scores have dropped to 0.00.. with the DarkNet and MobileNetV2 backbone.

Any update?

@jshilong (Collaborator) commented Sep 6, 2021

Feel free to reopen the issue if there is any problem

@jshilong jshilong closed this as completed Sep 6, 2021
5 participants