Hi, how do I replace Darknet53 in YOLOv3 with MobileNetV2? #5450
Comments
Yes, I just want to modify the backbone part minimally, but I don't know how to start.
Hi, I think `out_indices` in the backbone should be set to [4, 6, 7] and `in_channels` in the neck should be [64, 160, 320]. Reference:
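The channel numbers being discussed here can be sanity-checked without loading any weights. Below is a minimal sketch, assuming MobileNetV2's standard per-stage widths of [16, 24, 32, 64, 96, 160, 320] and the usual divisible-by-8 rounding; the helper names are illustrative, not mmdet API:

```python
def make_divisible(v, divisor=8):
    """Round a channel count to a multiple of `divisor`, MobileNet-style."""
    new_v = max(divisor, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # never round down by more than 10%
        new_v += divisor
    return new_v

# Output widths of MobileNetV2's inverted-residual stages, 0-indexed.
STAGE_CHANNELS = [16, 24, 32, 64, 96, 160, 320]

def neck_in_channels(out_indices, widen_factor=1.0):
    """Channels the YOLOV3Neck would see, deepest feature first."""
    chans = [make_divisible(STAGE_CHANNELS[i] * widen_factor)
             for i in out_indices]
    return chans[::-1]  # YOLOV3Neck expects the coarsest level first

print(neck_in_channels((2, 4, 6)))  # [320, 96, 32]
```

With `out_indices=(2, 4, 6)` this reproduces the `in_channels=[320, 96, 32]` used in the configs later in this thread.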
Hi @GalSang17, sorry, I made some mistakes above. You may try that. BTW, I think you should reduce the channels in the head, because you are using a lightweight model.
Thanks for your answer, it is now working properly. By the way, does Open-MMLab have pre-trained weights for MobileNetV2, like this
Hi, I'm not sure about it. But since MobileNet is quite a small backbone, it should be okay to start from scratch.
Hi @GalSang17, does it work correctly for you? Can you tell me which
Since #5218 is merged, you can use any `out_channels`. (BTW, you may want to pull and check out the latest code.) IMHO,
@ElectronicElephant Thank you very much. I would like to use MobileNetV2/V3_small with YOLOv3 (I have a small dataset); it is also necessary to adapt the With this configuration I get this error:

```python
cfg.model = dict(
    type='YOLOV3',
    pretrained='open-mmlab://MobileNetV2',
    backbone=dict(type='MobileNetV2', widen_factor=1.0, out_indices=(2, 4, 6)),
    neck=dict(
        type='YOLOV3Neck',
        num_scales=3,
        in_channels=[320, 96, 32],
        out_channels=[160, 48, 16]),
    bbox_head=dict(
        type='YOLOV3Head',
        num_classes=3,
        in_channels=[160, 48, 16],
        out_channels=[320, 96, 32],
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8]),
        bbox_coder=dict(type='YOLOBBoxCoder'),
        featmap_strides=[32, 16, 8],
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_conf=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_xy=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=2.0,
            reduction='sum'),
        loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')),
    train_cfg=dict(
        assigner=dict(
            type='GridAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0)),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        conf_thr=0.005,
        nms=dict(type='nms', iou_threshold=0.45),
        max_per_img=100))
```
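As a quick sanity check on the multi-scale setup in the config above: YOLOv3 pairs the largest anchors with the largest stride (the coarsest feature map), so the three `base_sizes` groups should have decreasing average anchor area in the same order as `strides=[32, 16, 8]`. A small illustrative check (not an mmdet utility):

```python
# Anchor groups copied from the config; the inline stride comments are
# the pairing implied by strides=[32, 16, 8].
base_sizes = [
    [(116, 90), (156, 198), (373, 326)],  # stride 32 (coarsest map)
    [(30, 61), (62, 45), (59, 119)],      # stride 16
    [(10, 13), (16, 30), (33, 23)],       # stride 8 (finest map)
]
strides = [32, 16, 8]

def mean_area(group):
    return sum(w * h for w, h in group) / len(group)

# Anchor areas should shrink together with the strides.
areas = [mean_area(g) for g in base_sizes]
assert all(a > b for a, b in zip(areas, areas[1:]))
assert all(s > t for s, t in zip(strides, strides[1:]))

# Feature-map sizes these strides imply for a 320x320 input:
print([320 // s for s in strides])  # [10, 20, 40]
```

If the anchor groups and strides are ever listed in opposite orders, the large anchors end up assigned to the fine feature map, which silently hurts training.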
Hi @tikitong, it should work as expected. Are you using the latest version of mmdetection?
Hi @ElectronicElephant, yes, I think I use the latest one. I did a

```python
import torch
print(torch.__version__, torch.cuda.is_available())
import mmdet
print(mmdet.__version__)
import mmcv
print(mmcv.__version__)
from mmcv.ops import get_compiler_version, get_compiling_cuda_version
print(get_compiling_cuda_version())
print(get_compiler_version())
```

which printed:

```
1.8.1 True
2.13.0
1.3.3
10.1
GCC 7.3
```
Just wait a minute. I'll create a PR to add such a config.
@ElectronicElephant Thank you very much for the example #5510! Indeed, my case should also work... I don't understand. With the basic DarkNet backbone it works fine. Here is my entire configuration file. Do you notice anything that could cause the error?

```python
from mmcv import Config
from mmdet.apis import set_random_seed
from pathlib import Path
import os.path as osp
from mmcv import mkdir_or_exist
from mmdet.datasets import build_dataset
from mmdet.models import backbones, build_detector, necks
from mmdet.apis import train_detector

base = 'configs/yolo/yolov3_d53_320_273e_coco.py'
cfg = Config.fromfile(base)
cfg.model = dict(
    type='YOLOV3',
    pretrained='open-mmlab://MobileNetV2',
    backbone=dict(type='MobileNetV2', widen_factor=1.0, out_indices=(2, 4, 6)),
    neck=dict(
        type='YOLOV3Neck',
        num_scales=3,
        in_channels=[320, 96, 32],
        out_channels=[160, 48, 16]),
    bbox_head=dict(
        type='YOLOV3Head',
        num_classes=3,
        in_channels=[160, 48, 16],
        out_channels=[320, 96, 32],
        anchor_generator=dict(
            type='YOLOAnchorGenerator',
            base_sizes=[[(116, 90), (156, 198), (373, 326)],
                        [(30, 61), (62, 45), (59, 119)],
                        [(10, 13), (16, 30), (33, 23)]],
            strides=[32, 16, 8]),
        bbox_coder=dict(type='YOLOBBoxCoder'),
        featmap_strides=[32, 16, 8],
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_conf=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,
            reduction='sum'),
        loss_xy=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=2.0,
            reduction='sum'),
        loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')),
    train_cfg=dict(
        assigner=dict(
            type='GridAssigner',
            pos_iou_thr=0.5,
            neg_iou_thr=0.5,
            min_pos_iou=0)),
    test_cfg=dict(
        nms_pre=1000,
        min_bbox_size=0,
        score_thr=0.05,
        conf_thr=0.005,
        nms=dict(type='nms', iou_threshold=0.45),
        max_per_img=100))
cfg.load_from = 'checkpoints/mobilenet_v2.pth'
cfg.dataset_type = 'MyDataset'
cfg.data_root = 'set'
cfg.classes = ('triangle', 'circle', 'square')
cfg.data.train.type = cfg.dataset_type
cfg.data.train.data_root = cfg.data_root
cfg.data.train.ann_file = 'train.txt'
cfg.data.train.img_prefix = 'image'
cfg.data.train.pipeline = cfg.train_pipeline
cfg.data.val.type = cfg.dataset_type
cfg.data.val.data_root = cfg.data_root
cfg.data.val.ann_file = 'val.txt'
cfg.data.val.img_prefix = 'image'
cfg.data.test.type = cfg.dataset_type
cfg.data.test.data_root = cfg.data_root
cfg.data.test.ann_file = 'test.txt'
cfg.data.test.img_prefix = 'image'
cfg.optimizer = dict(type='SGD', lr=6e-4, momentum=0.9, weight_decay=0.0001)
cfg.optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
cfg.runner = dict(type='EpochBasedRunner', max_epochs=40)
cfg.total_epochs = 40
cfg.evaluation = dict(interval=20, metric=['mAP'])
cfg.log_config = dict(
    interval=172,
    hooks=[dict(type='TextLoggerHook'),
           dict(type='TensorboardLoggerHook')])
cfg.checkpoint_config = dict(interval=50)
cfg.work_dir = f'work_dirs/{Path(base).stem}_'
cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)
print(f'Config: {cfg.total_epochs} epochs\n{cfg.pretty_text}')
datasets = [build_dataset(cfg.data.train)]
model = build_detector(cfg.model)
mkdir_or_exist(osp.abspath(cfg.work_dir))
train_detector(model, datasets, cfg, distributed=False, validate=True)
```
@ElectronicElephant Thanks a lot, that was it! Both files mmdet/models/necks/yolo_neck.py and tests/test_models/test_necks.py were not up to date. Despite the
Any update? |
Feel free to reopen the issue if there is any problem |
Emmm, there are several points:

- For the backbone part, you should check the strides for each of the `out_indices`. The length of `out_indices` should be 3 if you want a minimal change to the neck or head.
- For the neck part, you need to at least make sure that `in_channels` works with the backbone.

BTW, you can open a new issue and we can track the problem there.

Originally posted by @ElectronicElephant in #3083 (comment)
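The two points above can be captured in a tiny config checker. This is a hypothetical helper (not an official mmdet utility) that validates a plain `dict` config of the shape posted in this thread, assuming MobileNetV2's stage widths at `widen_factor=1.0`:

```python
def check_yolo_cfg(cfg, backbone_stage_channels):
    """Check the backbone/neck/head wiring of a YOLOv3-style config dict."""
    out_indices = cfg['backbone']['out_indices']
    # A minimal change to the neck/head needs exactly 3 feature levels.
    assert len(out_indices) == 3, 'out_indices must have length 3'
    # The neck's in_channels must match the backbone outputs, deepest first.
    expected = [backbone_stage_channels[i] for i in out_indices][::-1]
    assert cfg['neck']['in_channels'] == expected, (
        f"neck in_channels {cfg['neck']['in_channels']} != {expected}")
    # The head consumes exactly what the neck emits.
    assert cfg['bbox_head']['in_channels'] == cfg['neck']['out_channels']
    return True

# Stage output widths of MobileNetV2 at widen_factor=1.0, 0-indexed.
mobilenet_v2_stages = [16, 24, 32, 64, 96, 160, 320]

cfg = dict(
    backbone=dict(out_indices=(2, 4, 6)),
    neck=dict(in_channels=[320, 96, 32], out_channels=[160, 48, 16]),
    bbox_head=dict(in_channels=[160, 48, 16]),
)
print(check_yolo_cfg(cfg, mobilenet_v2_stages))  # True
```

Running a check like this before `build_detector` turns a confusing shape-mismatch traceback into an explicit assertion message.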