Fix bug when using IterBasedRunner with 'val' workflow #542

Merged: 4 commits, Nov 22, 2021
Changes from 3 commits
26 changes: 21 additions & 5 deletions mmcls/models/classifiers/base.py
@@ -118,7 +118,7 @@ def _parse_losses(self, losses):

         return loss, log_vars

-    def train_step(self, data, optimizer):
+    def train_step(self, data, optimizer=None, **kwargs):
"""The iteration step during training.

This method defines an iteration step during training, except for the
@@ -129,9 +129,9 @@ def train_step(self, data, optimizer):

         Args:
             data (dict): The output of dataloader.
-            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
-                runner is passed to ``train_step()``. This argument is unused
-                and reserved.
+            optimizer (:obj:`torch.optim.Optimizer` | dict | optional): The
Member comment on the line above:

Well, the right usage should be

Suggested change:
-            optimizer (:obj:`torch.optim.Optimizer` | dict | optional): The
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The

Because optional is not a type, but a kind of annotation.
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.

         Returns:
             dict: Dict of outputs. The following fields are contained.
@@ -151,12 +151,28 @@

         return outputs

-    def val_step(self, data, optimizer):
+    def val_step(self, data, optimizer=None, **kwargs):
         """The iteration step during validation.

         This method shares the same signature as :func:`train_step`, but used
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
+
+        Args:
+            data (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict | optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.
+
+        Returns:
+            dict: Dict of outputs. The following fields are contained.
+            - loss (torch.Tensor): A tensor for back propagation, which \
+                can be a weighted sum of multiple losses.
+            - log_vars (dict): Dict contains all the variables to be sent \
+                to the logger.
+            - num_samples (int): Indicates the batch size (when the model \
+                is DDP, it means the batch size on each GPU), which is \
+                used for averaging the logs.
         """
         losses = self(**data)
         loss, log_vars = self._parse_losses(losses)
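For context, the reason both signatures now default optimizer to None: under an IterBasedRunner configured with a 'val' workflow, the runner calls the model's val_step with only the data batch, so the old required optimizer argument raised a TypeError. Below is a minimal standalone sketch of the difference, not mmcls code; the exact runner call shown in the comment is an assumption based on mmcv's behavior at the time.

    # Toy sketch, not mmcls code. Assumption: IterBasedRunner's 'val' workflow
    # effectively calls model.val_step(data_batch, **kwargs) without an optimizer.

    def val_step_old(data, optimizer):                  # old signature: optimizer is required
        return {'num_samples': len(data['img'])}

    def val_step_new(data, optimizer=None, **kwargs):   # new signature: optimizer is optional
        return {'num_samples': len(data['img'])}

    batch = {'img': [0, 1, 2]}

    try:
        val_step_old(batch)                  # what the 'val' workflow effectively does
    except TypeError as exc:
        print('old signature fails:', exc)   # missing required argument: 'optimizer'

    print(val_step_new(batch))               # works: {'num_samples': 3}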
10 changes: 10 additions & 0 deletions tests/test_models/test_classifiers.py
@@ -44,11 +44,21 @@ def test_image_classifier():
     assert outputs['loss'].item() > 0
     assert outputs['num_samples'] == 16

+    # test train_step without optimizer
+    outputs = model.train_step({'img': imgs, 'gt_label': label})
+    assert outputs['loss'].item() > 0
+    assert outputs['num_samples'] == 16
+
     # test val_step
     outputs = model.val_step({'img': imgs, 'gt_label': label}, None)
     assert outputs['loss'].item() > 0
     assert outputs['num_samples'] == 16

+    # test val_step without optimizer
+    outputs = model.val_step({'img': imgs, 'gt_label': label})
+    assert outputs['loss'].item() > 0
+    assert outputs['num_samples'] == 16
+
     # test forward
     losses = model(imgs, return_loss=True, gt_label=label)
     assert losses['loss'].item() > 0
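For reference, the setup that previously triggered the bug is a config that interleaves validation with training under IterBasedRunner. A hypothetical config fragment follows, with illustrative values using common mmcv conventions; it is not taken from this PR.

    # Hypothetical config fragment (not from this PR): with IterBasedRunner, the
    # numbers in `workflow` are iterations; the ('val', 50) phase is what used to
    # crash because val_step() required an optimizer argument.
    runner = dict(type='IterBasedRunner', max_iters=10000)
    workflow = [('train', 500), ('val', 50)]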