Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor multi branch #8634

Merged
merged 27 commits into from
Aug 26, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
8fb235a
refactor dataflow
hhaAndroid Aug 23, 2022
d17d1e9
fix docstr
hhaAndroid Aug 23, 2022
ce1e068
fix commit
hhaAndroid Aug 23, 2022
d2780bf
fix commit
hhaAndroid Aug 24, 2022
f33fe2b
fix visualizer hook
hhaAndroid Aug 24, 2022
3cbdf31
fix UT
hhaAndroid Aug 24, 2022
9723433
fix UT
hhaAndroid Aug 24, 2022
1e661bb
resolve conflicts
hhaAndroid Aug 24, 2022
0b90fb2
fix UT error
hhaAndroid Aug 24, 2022
f8c5e7b
fix bug
hhaAndroid Aug 24, 2022
6d21f79
Refactor semi data flow
Czm369 Aug 24, 2022
485117d
Merge branch 'data_flow' of github.com:hhaAndroid/mmdetection into re…
Czm369 Aug 24, 2022
adc8ffa
update to mmengine main
hhaAndroid Aug 25, 2022
3ef2233
update typehint
hhaAndroid Aug 25, 2022
53777ab
replace data preprocess output type to dict
hhaAndroid Aug 25, 2022
b133fd8
update
hhaAndroid Aug 25, 2022
2f1add5
fix typehint
hhaAndroid Aug 25, 2022
d4ccf31
Merge branch 'data_flow' of github.com:hhaAndroid/mmdetection into re…
Czm369 Aug 25, 2022
f57acb6
Refactor MultiBranchDataPreprocessor again
Czm369 Aug 25, 2022
bca7bde
Solve thr conflict
Czm369 Aug 25, 2022
d4e67ee
Add some docstring
Czm369 Aug 25, 2022
8e3097a
Add some examples
Czm369 Aug 26, 2022
e267a81
Fix some commits
Czm369 Aug 26, 2022
cc418f4
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into refa…
Czm369 Aug 26, 2022
88880a9
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into refa…
Czm369 Aug 26, 2022
c7c4390
fix some commits
Czm369 Aug 26, 2022
afa69c5
Merge branch 'dev-3.x' of github.com:open-mmlab/mmdetection into refa…
Czm369 Aug 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions mmdet/datasets/transforms/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,62 @@ class MultiBranch(BaseTransform):
r"""Multiple branch pipeline wrapper.

Generate multiple data-augmented versions of the same image.
`MultiBranch` needs to specify the branch names of all
pipelines of the dataset, perform corresponding data augmentation
for the current branch, and return None for other branches,
which ensures the uniformity of return value.

Args:
branch_field (list): List of branch names.
branch_pipelines (dict): Dict of different pipeline configs
to be composed.

Examples:
>>> branch_field = ['sup', 'unsup_teacher', 'unsup_student']
>>> sup_pipeline = [
>>> dict(type='LoadImageFromFile',
>>> file_client_args=file_client_args),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(type='RandomResize', scale=scale, keep_ratio=True),
>>> dict(type='RandomFlip', prob=0.5),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> sup=dict(type='PackDetInputs'))
>>> ]
>>> weak_pipeline = [
>>> dict(type='LoadImageFromFile',
>>> file_client_args=file_client_args),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(type='RandomResize', scale=scale, keep_ratio=True),
>>> dict(type='RandomFlip', prob=0.0),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> sup=dict(type='PackDetInputs'))
>>> ]
>>> strong_pipeline = [
>>> dict(type='LoadImageFromFile',
>>> file_client_args=file_client_args),
>>> dict(type='LoadAnnotations', with_bbox=True),
>>> dict(type='RandomResize', scale=scale, keep_ratio=True),
>>> dict(type='RandomFlip', prob=1.0),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> sup=dict(type='PackDetInputs'))
>>> ]
>>> unsup_pipeline = [
>>> dict(type='LoadImageFromFile',
>>> file_client_args=file_client_args),
>>> dict(type='LoadEmptyAnnotations'),
>>> dict(
>>> type='MultiBranch',
>>> branch_field=branch_field,
>>> unsup_teacher=weak_pipeline,
>>> unsup_student=strong_pipeline)
>>> ]

"""

def __init__(self, branch_field: List[str],
Expand Down
83 changes: 82 additions & 1 deletion mmdet/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,87 @@ def forward(
class MultiBranchDataPreprocessor(BaseDataPreprocessor):
"""DataPreprocessor wrapper for multi-branch data.

Take semi-supervised object detection as an example, assume that
the ratio of labeled data and unlabeled data in a batch is 1:2,
`sup` indicates the branch where the labeled data is augmented,
`unsup_teacher` and `unsup_student` indicate the branches where
the unlabeled data is augmented by different pipeline. Therefore,
the input format of multi-branch data is shown as follows:

.. code-block:: none
{
'inputs':
{
'sup': [Tensor, None, None],
'unsup_teacher': [None, Tensor, Tensor],
'unsup_student': [None, Tensor, Tensor],
},
'data_sample':
{
'sup': [DetDataSample, None, None],
'unsup_teacher': [None, DetDataSample, DetDataSample],
'unsup_student': [NOne, DetDataSample, DetDataSample],
}
}

Filter out branches with a value of None:

.. code-block:: none
{
'inputs':
{
'sup': [Tensor],
'unsup_teacher': [Tensor, Tensor],
'unsup_student': [Tensor, Tensor],
},
'data_sample':
{
'sup': [DetDataSample],
'unsup_teacher': [DetDataSample, DetDataSample],
'unsup_student': [DetDataSample, DetDataSample],
}
}

Group data by branch:

.. code-block:: none
{
'sup':
{
'inputs': [Tensor]
'data_sample': [DetDataSample, DetDataSample]
},
'unsup_teacher':
{
'inputs': [Tensor, Tensor]
'data_sample': [DetDataSample, DetDataSample]
},
'unsup_student':
{
'inputs': [Tensor, Tensor]
'data_sample': [DetDataSample, DetDataSample]
},
}

After preprocessing data from different branches,
the multi-branch data needs to be reformatted as:

.. code-block:: none
{
'inputs':
{
'sup': [Tensor],
'unsup_teacher': [Tensor, Tensor],
'unsup_student': [Tensor, Tensor],
},
'data_sample':
{
'sup': [DetDataSample],
'unsup_teacher': [DetDataSample, DetDataSample],
'unsup_student': [DetDataSample, DetDataSample],
}
}

Args:
data_preprocessor (:obj:`ConfigDict` or dict): Config of
:class:`DetDataPreprocessor` to process the input data.
Expand All @@ -390,7 +471,7 @@ def forward(self, data: dict, training: bool = False) -> dict:

- 'inputs' (Dict[str, obj:`torch.Tensor`]): The forward data of
models from different branches.
- 'data_sample' (Dict[str,obj:`DetDataSample`]): The annotation
- 'data_sample' (Dict[str, obj:`DetDataSample`]): The annotation
info of the sample from different branches.
"""

Expand Down