Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor multi branch #8634

Merged
merged 27 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8fb235a
refactor dataflow
hhaAndroid Aug 23, 2022
d17d1e9
fix docstr
hhaAndroid Aug 23, 2022
ce1e068
fix commit
hhaAndroid Aug 23, 2022
d2780bf
fix commit
hhaAndroid Aug 24, 2022
f33fe2b
fix visualizer hook
hhaAndroid Aug 24, 2022
3cbdf31
fix UT
hhaAndroid Aug 24, 2022
9723433
fix UT
hhaAndroid Aug 24, 2022
1e661bb
resolve conflicts
hhaAndroid Aug 24, 2022
0b90fb2
fix UT error
hhaAndroid Aug 24, 2022
f8c5e7b
fix bug
hhaAndroid Aug 24, 2022
6d21f79
Refactor semi data flow
Czm369 Aug 24, 2022
485117d
Merge branch 'data_flow' of github.com:hhaAndroid/mmdetection into re…
Czm369 Aug 24, 2022
adc8ffa
update to mmengine main
hhaAndroid Aug 25, 2022
3ef2233
update typehint
hhaAndroid Aug 25, 2022
53777ab
replace data preprocess output type to dict
hhaAndroid Aug 25, 2022
b133fd8
update
hhaAndroid Aug 25, 2022
2f1add5
fix typehint
hhaAndroid Aug 25, 2022
d4ccf31
Merge branch 'data_flow' of github.com:hhaAndroid/mmdetection into re…
Czm369 Aug 25, 2022
f57acb6
Refactor MultiBranchDataPreprocessor again
Czm369 Aug 25, 2022
bca7bde
Solve thr conflict
Czm369 Aug 25, 2022
d4e67ee
Add some docstring
Czm369 Aug 25, 2022
8e3097a
Add some examples
Czm369 Aug 26, 2022
e267a81
Fix some commits
Czm369 Aug 26, 2022
cc418f4
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into refa…
Czm369 Aug 26, 2022
88880a9
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into refa…
Czm369 Aug 26, 2022
c7c4390
fix some commits
Czm369 Aug 26, 2022
afa69c5
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into refa…
Czm369 Aug 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion configs/_base_/datasets/semi_coco_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

scale = [(1333, 400), (1333, 1200)]

branch_field = ['sup', 'unsup_teacher', 'unsup_student']
# pipeline used to augment labeled data,
# which will be sent to student model for supervised training.
sup_pipeline = [
Expand All @@ -41,7 +42,10 @@
dict(type='RandomFlip', prob=0.5),
dict(type='RandAugment', aug_space=color_space, aug_num=1),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
dict(type='MultiBranch', sup=dict(type='PackDetInputs'))
dict(
type='MultiBranch',
branch_field=branch_field,
sup=dict(type='PackDetInputs'))
]

# pipeline used to augment unlabeled data weakly,
Expand Down Expand Up @@ -82,6 +86,7 @@
dict(type='LoadEmptyAnnotations'),
dict(
type='MultiBranch',
branch_field=branch_field,
unsup_teacher=weak_pipeline,
unsup_student=strong_pipeline,
)
Expand Down
26 changes: 22 additions & 4 deletions mmdet/datasets/transforms/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,53 @@ class MultiBranch(BaseTransform):
Generate multiple data-augmented versions of the same image.

Args:
branch_field (list): List of branch names.
branch_pipelines (dict): Dict of different pipeline configs
to be composed.

"""

def __init__(self, **branch_pipelines: dict) -> None:
def __init__(self, branch_field: List[str],
**branch_pipelines: dict) -> None:
self.branch_field = branch_field
self.branch_pipelines = {
branch: Compose(pipeline)
for branch, pipeline in branch_pipelines.items()
}

def transform(self, results: dict) -> Optional[List[dict]]:
def transform(self, results: dict) -> dict:
"""Transform function to apply transforms sequentially.

Args:
results (dict): Result dict from loading pipeline.

Returns:
list[dict]: Results from different pipeline.
dict:

- 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
models from different branches.
- 'data_sample' (Dict[str,obj:`DetDataSample`]): The annotation
info of the sample from different branches.
"""
multi_results = {}
for branch in self.branch_field:
multi_results[branch] = {'inputs': None, 'data_samples': None}
for branch, pipeline in self.branch_pipelines.items():
branch_results = pipeline(copy.deepcopy(results))
# If one branch pipeline returns None,
# it will sample another data from dataset.
if branch_results is None:
return None
multi_results[branch] = branch_results
return multi_results

format_results = {}
for branch, results in multi_results.items():
for key in results.keys():
if format_results.get(key, None) is None:
format_results[key] = {branch: results[key]}
else:
format_results[key][branch] = results[key]
return format_results

def __repr__(self) -> str:
repr_str = self.__class__.__name__
Expand Down
59 changes: 41 additions & 18 deletions mmdet/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from numbers import Number
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -377,37 +377,60 @@ def __init__(self, data_preprocessor: ConfigType) -> None:
super().__init__()
self.data_preprocessor = MODELS.build(data_preprocessor)

def forward(
self,
data: dict,
training: bool = False
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Optional[list]]]:
def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization、padding and bgr2rgb conversion based on
``BaseDataPreprocessor`` for multi-branch data.

Args:
data (Sequence[dict]): data sampled from dataloader.
data (dict): Data sampled from dataloader.
training (bool): Whether to enable training time augmentation.

Returns:
Tuple[Dict[torch.Tensor], Dict[Optional[list]]]: Each tuple of
zip(dict, dict) is the data in the same format as the model input.
dict:

- 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
models from different branches.
- 'data_sample' (Dict[str,obj:`DetDataSample`]): The annotation
info of the sample from different branches.
"""

if training is False:
return self.data_preprocessor(data, training)

# Filter out branches with a value of None
for key in data.keys():
for branch in data[key].keys():
data[key][branch] = list(
filter(lambda x: x is not None, data[key][branch]))

# Group data by branch
multi_branch_data = {}
for multi_results in data:
for branch, results in multi_results.items():
for key in data.keys():
for branch in data[key].keys():
if multi_branch_data.get(branch, None) is None:
multi_branch_data[branch] = [results]
multi_branch_data[branch] = {key: data[key][branch]}
elif multi_branch_data[branch].get(key, None) is None:
multi_branch_data[branch][key] = data[key][branch]
else:
multi_branch_data[branch].append(results)
multi_batch_inputs, multi_batch_data_samples = {}, {}
for branch, data in multi_branch_data.items():
multi_batch_inputs[branch], multi_batch_data_samples[
branch] = self.data_preprocessor(data, training)
return multi_batch_inputs, multi_batch_data_samples
multi_branch_data[branch][key].append(data[key][branch])

# Preprocess data from different branches
for branch, _data in multi_branch_data.items():
multi_branch_data[branch] = self.data_preprocessor(_data, training)

# Format data by inputs and data_samples
format_data = {}
for branch in multi_branch_data.keys():
for key in multi_branch_data[branch].keys():
if format_data.get(key, None) is None:
format_data[key] = {branch: multi_branch_data[branch][key]}
elif format_data[key].get(branch, None) is None:
format_data[key][branch] = multi_branch_data[branch][key]
else:
format_data[key][branch].append(
multi_branch_data[branch][key])

return format_data

@property
def device(self):
Expand Down
31 changes: 17 additions & 14 deletions tests/test_datasets/test_transforms/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def setUp(self):
'ignore_flag': 1
}]
}
self.branch_field = ['sup', 'sup_teacher', 'sup_student']
self.weak_pipeline = [
dict(type='ShearX'),
dict(type='PackDetInputs', meta_keys=self.meta_keys)
Expand All @@ -70,6 +71,7 @@ def setUp(self):
dict(type='RandomFlip', prob=0.5),
dict(
type='MultiBranch',
branch_field=self.branch_field,
sup_teacher=self.weak_pipeline,
sup_student=self.strong_pipeline),
]
Expand All @@ -79,6 +81,7 @@ def setUp(self):
dict(type='RandomFlip', prob=0.5),
dict(
type='MultiBranch',
branch_field=self.branch_field,
unsup_teacher=self.weak_pipeline,
unsup_student=self.strong_pipeline),
]
Expand All @@ -92,39 +95,39 @@ def test_transform(self):
# test branch sup_teacher and sup_student
sup_branches = ['sup_teacher', 'sup_student']
for branch in sup_branches:
self.assertIn(branch, labeled_results)
self.assertIn(branch, labeled_results['data_samples'])
self.assertIn('homography_matrix',
labeled_results[branch]['data_samples'])
labeled_results['data_samples'][branch])
self.assertIn('labels',
labeled_results[branch]['data_samples'].gt_instances)
labeled_results['data_samples'][branch].gt_instances)
self.assertIn('bboxes',
labeled_results[branch]['data_samples'].gt_instances)
labeled_results['data_samples'][branch].gt_instances)
self.assertIn('masks',
labeled_results[branch]['data_samples'].gt_instances)
labeled_results['data_samples'][branch].gt_instances)
self.assertIn('gt_sem_seg',
labeled_results[branch]['data_samples'])

labeled_results['data_samples'][branch])
# test branch unsup_teacher and unsup_student
unsup_branches = ['unsup_teacher', 'unsup_student']
for branch in unsup_branches:
self.assertIn(branch, unlabeled_results)
self.assertIn(branch, unlabeled_results['data_samples'])
self.assertIn('homography_matrix',
unlabeled_results[branch]['data_samples'])
unlabeled_results['data_samples'][branch])
self.assertNotIn(
'labels',
unlabeled_results[branch]['data_samples'].gt_instances)
unlabeled_results['data_samples'][branch].gt_instances)
self.assertNotIn(
'bboxes',
unlabeled_results[branch]['data_samples'].gt_instances)
unlabeled_results['data_samples'][branch].gt_instances)
self.assertNotIn(
'masks',
unlabeled_results[branch]['data_samples'].gt_instances)
unlabeled_results['data_samples'][branch].gt_instances)
self.assertNotIn('gt_sem_seg',
unlabeled_results[branch]['data_samples'])
unlabeled_results['data_samples'][branch])

def test_repr(self):
pipeline = [dict(type='PackDetInputs', meta_keys=())]
transform = MultiBranch(sup=pipeline, unsup=pipeline)
transform = MultiBranch(
branch_field=self.branch_field, sup=pipeline, unsup=pipeline)
self.assertEqual(
repr(transform),
("MultiBranch(branch_pipelines=['sup', 'unsup'])"))
Expand Down
69 changes: 31 additions & 38 deletions tests/test_models/test_data_preprocessors/test_data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,34 +314,26 @@ def setUp(self):
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=32)
self.multi_data = [
{
'sup': {
'inputs': torch.randint(0, 256, (3, 224, 224)),
'data_samples': DetDataSample()
}
self.multi_data = {
'inputs': {
'sup': [torch.randint(0, 256, (3, 224, 224))],
'unsup_teacher': [
torch.randint(0, 256, (3, 400, 600)),
torch.randint(0, 256, (3, 600, 400))
],
'unsup_student': [
torch.randint(0, 256, (3, 700, 500)),
torch.randint(0, 256, (3, 500, 700))
]
},
{
'unsup_teacher': {
'inputs': torch.randint(0, 256, (3, 400, 600)),
'data_samples': DetDataSample()
},
'unsup_student': {
'inputs': torch.randint(0, 256, (3, 700, 500)),
'data_samples': DetDataSample()
}
},
{
'unsup_teacher': {
'inputs': torch.randint(0, 256, (3, 600, 400)),
'data_samples': DetDataSample()
},
'unsup_student': {
'inputs': torch.randint(0, 256, (3, 500, 700)),
'data_samples': DetDataSample()
}
},
]
'data_samples': {
'sup': [DetDataSample()],
'unsup_teacher': [DetDataSample(),
DetDataSample()],
'unsup_student': [DetDataSample(),
DetDataSample()],
}
}
self.data = {
'inputs': [torch.randint(0, 256, (3, 224, 224))],
'data_samples': [DetDataSample()]
Expand All @@ -350,15 +342,16 @@ def setUp(self):
def test_multi_data_preprocessor(self):
processor = MultiBranchDataPreprocessor(self.data_preprocessor)
# test processing multi_data when training
multi_inputs, multi_data_samples = processor(
self.multi_data, training=True)
self.assertEqual(multi_inputs['sup'].shape, (1, 3, 224, 224))
self.assertEqual(multi_inputs['unsup_teacher'].shape, (2, 3, 608, 608))
self.assertEqual(multi_inputs['unsup_student'].shape, (2, 3, 704, 704))
self.assertEqual(len(multi_data_samples['sup']), 1)
self.assertEqual(len(multi_data_samples['unsup_teacher']), 2)
self.assertEqual(len(multi_data_samples['unsup_student']), 2)
multi_data = processor(self.multi_data, training=True)
self.assertEqual(multi_data['inputs']['sup'].shape, (1, 3, 224, 224))
self.assertEqual(multi_data['inputs']['unsup_teacher'].shape,
(2, 3, 608, 608))
self.assertEqual(multi_data['inputs']['unsup_student'].shape,
(2, 3, 704, 704))
self.assertEqual(len(multi_data['data_samples']['sup']), 1)
self.assertEqual(len(multi_data['data_samples']['unsup_teacher']), 2)
self.assertEqual(len(multi_data['data_samples']['unsup_student']), 2)
# test processing data when testing
inputs, data_samples = processor(self.data)
self.assertEqual(inputs.shape, (1, 3, 224, 224))
self.assertEqual(len(data_samples), 1)
data = processor(self.data)
self.assertEqual(data['inputs'].shape, (1, 3, 224, 224))
self.assertEqual(len(data['data_samples']), 1)