
Wandb logging bug using Iteration based runner #2069

Open
levan92 opened this issue Jun 22, 2022 · 12 comments

@levan92

levan92 commented Jun 22, 2022

Describe the Issue
Validation metrics are not reported/logged to wandb when using IterBasedRunner.

Reproduction

Here's a simple reproduction of the bug.

Using mmdetection, modify the config file faster_rcnn/faster_rcnn_r50_caffe_c4_1x_coco.py with the following edits:

max_iters = 100 
runner = dict(
    _delete_=True, 
    type='IterBasedRunner', 
    max_iters=max_iters
)

lr_config = dict(
    policy='step',
    gamma=0.1,
    by_epoch=False,
    warmup='linear',
    warmup_by_epoch=False,
    warmup_ratio=1.0,  # no warmup
    warmup_iters=10
    )

interval = 10
workflow = [('train', interval)]
checkpoint_config = dict(
    by_epoch=False, interval=interval)

evaluation = dict(
    interval=interval,
    metric=['bbox'])

log_config = dict(
    interval=5,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(
            type='WandbLoggerHook',
            init_kwargs=dict(
                project='train-tests',
                name='short'
                ),
            out_suffix=('.log.json', '.log', '.py'),
            by_epoch=False,
            ),
        ]
    )

Environment
Python 3.8.8
mmcv 1.4.5
mmdet 2.25.0
wandb 0.12.0

Error traceback

2022-06-22 13:55:37,523 - mmdet - INFO - Iter [5/100]   lr: 2.000e-02, eta: 0:01:18, time: 0.831, data_time: 0.035, memory: 7614, loss_rpn_cls: 0.1022, loss_rpn_bbox: 0.2308, loss_cls: 0.1789, acc: 94.2188, loss_bbox: 0.1995, loss: 0.7115
2022-06-22 13:55:41,200 - mmdet - INFO - Saving checkpoint at 10 iterations
2022-06-22 13:55:43,173 - mmdet - INFO - Iter [10/100]  lr: 2.000e-03, eta: 0:01:32, time: 1.234, data_time: 0.111, memory: 7653, loss_rpn_cls: 0.1946, loss_rpn_bbox: 0.2274, loss_cls: 0.2475, acc: 92.6758, loss_bbox: 0.2635, loss: 0.9330
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 100/100, 5.3 task/s, elapsed: 19s, ETA:     0s

2022-06-22 13:56:02,333 - mmdet - INFO - Evaluating bbox...
Loading and preparing results...
DONE (t=0.00s)
creating index...
index created!
Running per image evaluation...
Evaluate annotation type *bbox*
DONE (t=0.73s).
Accumulating evaluation results...
DONE (t=0.50s).
2022-06-22 13:56:03,582 - mmdet - INFO -
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.275
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=1000 ] = 0.429
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=1000 ] = 0.295
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.106
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.302
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.448
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.362
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=300 ] = 0.362
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=1000 ] = 0.362
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.162
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.382
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.535

2022-06-22 13:56:03,590 - mmdet - INFO - Iter(val) [100]        bbox_mAP: 0.2750, bbox_mAP_50: 0.4290, bbox_mAP_75: 0.2950, bbox_mAP_s: 0.1060, bbox_mAP_m: 0.3020, bbox_mAP_l: 0.4480, bbox_mAP_copypaste: 0.275 0.429 0.295 0.106 0.302 0.448
wandb: WARNING Step must only increase in log calls.  Step 10 < 11; dropping {'val/bbox_mAP': 0.275, 'val/bbox_mAP_50': 0.429, 'val/bbox_mAP_75': 0.295, 'val/bbox_mAP_s': 0.106, 'val/bbox_mAP_m': 0.302, 'val/bbox_mAP_l': 0.448, 'learning_rate': 0.002, 'momentum': 0.9}.
2022-06-22 13:56:07,778 - mmdet - INFO - Iter [15/100]  lr: 2.000e-04, eta: 0:03:14, time: 4.811, data_time: 4.090, memory: 7653, loss_rpn_cls: 0.1167, loss_rpn_bbox: 0.2035, loss_cls: 0.2403, acc: 94.2578, loss_bbox: 0.1729, loss: 0.7333
2022-06-22 13:56:11,395 - mmdet - INFO - Saving checkpoint at 20 iterations
2022-06-22 13:56:13,374 - mmdet - INFO - Iter [20/100]  lr: 2.000e-04, eta: 0:02:42, time: 1.229, data_time: 0.117, memory: 7653, loss_rpn_cls: 0.0972, loss_rpn_bbox: 0.1758, loss_cls: 0.2176, acc: 92.2461, loss_bbox: 0.2259, loss: 0.7165

Bug fix
In the wandb hook's log method, self.wandb.log is always called with commit=True (the default). The log call for the last training step before validation therefore causes wandb to increment its internal step by one. When wandb.log is then called for the validation metrics, wandb's internal step is ahead of the step passed in (the validation iteration) by one, so the metrics are dropped.
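To make the failure mode concrete, here is a minimal, self-contained simulation of this step/commit behaviour. FakeWandb is a toy stand-in written for this illustration, not the real wandb client: every committed log call advances an internal step counter, and a later call whose explicit step is behind that counter is dropped.

```python
class FakeWandb:
    """Toy model of wandb's step handling (illustration only, not the real client)."""

    def __init__(self):
        self._step = 0    # wandb's internal step counter
        self.history = {} # step -> logged metrics

    def log(self, data, step=None, commit=True):
        if step is not None and step < self._step:
            # Mirrors: "WARNING Step must only increase in log calls ... dropping {...}"
            print(f"dropping {data} (step {step} < internal step {self._step})")
            return
        if step is not None:
            self._step = step
        self.history.setdefault(self._step, {}).update(data)
        if commit:
            self._step += 1  # commit=True advances the internal step


wandb = FakeWandb()
# Training log at iteration 10; commit=True bumps the internal step to 11.
wandb.log({'train/loss': 0.93}, step=10)
# Validation log, also at iteration 10: internal step is already 11, so it is dropped.
wandb.log({'val/bbox_mAP': 0.275}, step=10)
```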

Is there a good way to commit only after each validation is done?

@levan92
Author

levan92 commented Jun 22, 2022

Tagging @xvjiarui as I saw that the commit enhancement for wandb.log was added by them in #659. Any ideas? Thank you!

@morganmcg1

@ayulockin can probably help here :)

@ayulockin

Hey @levan92, is the issue that the WandbLoggerHook is not logging validation metrics, or that the steps (x-axis) are not correct (meaningful)?

Also, since you are using MMDetection, can you give MMDetWandbHook a try? It solves the issue of an incorrect x-axis.

@levan92
Author

levan92 commented Jun 23, 2022

Thanks @ayulockin! Will try out the MMDetWandbHook.

Yup, the issue is that wandb is not logging any of the val values at all, because wandb's step is higher than the given step (for the reason I gave at the end of my original post).

@levan92
Author

levan92 commented Jun 23, 2022

@ayulockin I've tried out MMDetWandbHook with IterBasedRunner, the validation metrics are still not getting logged due to the same error.

@ayulockin

Hey @levan92, I have faced the same issue but didn't dig deeper. I have a hunch that it has something to do with MMDetection's workflow. The MMDetWandbHook subclasses MMCV's WandbLoggerHook to log the train/val metrics.

I can confirm that both WandbLoggerHook and MMDetWandbHook can log validation metrics if it's available to them. Is it a possibility that you can share your code as a colab notebook so that I can reproduce the issue?

@levan92
Author

levan92 commented Jun 24, 2022

Yup, in mmcv's WandbLoggerHook, the log method calls self.wandb.log with commit=True by default. According to wandb's documentation, that increments wandb's internal step by 1 every time it is called. Therefore, when it is validation's turn to log, wandb is already ahead of the current step by 1.

A solution is to call log for the training step right before validation with commit=False, then call log for the subsequent validation with commit=True.
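A toy sketch of wandb-like step/commit semantics (illustration only, not the real client) shows why deferring the commit works: with commit=False the internal step is not advanced, so the validation metrics logged immediately afterwards at the same step merge into the same row instead of being dropped.

```python
class FakeWandb:
    """Toy stand-in for wandb's step/commit handling (illustration only)."""

    def __init__(self):
        self._step = 0    # internal step counter
        self.history = {} # step -> logged metrics

    def log(self, data, step=None, commit=True):
        if step is not None and step < self._step:
            return  # wandb drops logs whose step is behind its internal step
        if step is not None:
            self._step = step
        self.history.setdefault(self._step, {}).update(data)
        if commit:
            self._step += 1  # only a committed call advances the step


wandb = FakeWandb()
# Training log right before validation: defer the commit.
wandb.log({'train/loss': 0.93}, step=10, commit=False)
# Validation log at the same iteration: both dicts are committed together at step 10.
wandb.log({'val/bbox_mAP': 0.275}, step=10, commit=True)
```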

To reproduce, you can use the config snippet I provided above in the original post (append it to the existing cfg file I referenced).

@levan92
Author

levan92 commented Jun 24, 2022

Here's an ugly fix that works:

In mmcv/mmcv/runner/hooks/logger/wandb.py (From Line 90):

    @master_only
    def log(self, runner) -> None:
        tags = self.get_loggable_tags(runner)
        if tags:
            step = self.get_iter(runner)
            # On an eval iteration, hold back the commit so that the
            # upcoming validation metrics land on the same wandb step.
            # Note: self.eval_interval and self.eval_step are attributes
            # added for this fix (self.eval_step initialised to False).
            if (not self.by_epoch and step % self.eval_interval == 0
                    and not self.eval_step):
                commit = False
                self.eval_step = True  # the next log call is the eval step
            else:
                commit = self.commit
                self.eval_step = False
            if self.with_step:
                self.wandb.log(tags, step=step, commit=commit)
            else:
                tags['global_step'] = step
                self.wandb.log(tags, commit=commit)

What do you think? There's probably a more elegant way of doing this, but this works for me for now.

@ayulockin

based on wandb's documentation, it increments wandb's step by 1 every time it is called. Therefore, when it is validation's turn to log, wandb will be ahead of the current step by 1.

Hey @levan92, I don't think that should be an issue. W&B doesn't care about the step mismatch. In the UI, the validation metric will appear at the nth step, where n = actual_x + lag.

For a more elegant solution you should check out wandb.define_metric.
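For context, wandb.define_metric lets a family of metrics be plotted against a custom step metric, so wandb's internal step no longer matters for them. A hedged sketch (it requires a live wandb run, so it is not tested here; the project name and 'global_step' key are illustrative):

```python
import wandb

run = wandb.init(project='train-tests')
# Plot everything under val/ against a user-supplied 'global_step'
# instead of wandb's internal step counter.
wandb.define_metric('global_step')
wandb.define_metric('val/*', step_metric='global_step')
# The validation metrics then carry their own x value in the payload.
wandb.log({'val/bbox_mAP': 0.275, 'global_step': 10})
```

With this setup there is no step argument to fall behind, so the "Step must only increase" drop cannot occur for the val/ metrics.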

@levan92
Author

levan92 commented Jun 24, 2022

@ayulockin Ah I see. However, when I originally ran it, this warning message was shown after each validation:

wandb: WARNING Step must only increase in log calls.  Step 20 < 21; dropping {'val/bbox_mAP': 0.358, 'val/bbox_mAP_50': 0.539, 'val/bbox_mAP_75': 0.378, 'val/bbox_mAP_s': 0.149, 'val/bbox_mAP_m': 0.431, 'val/bbox_mAP_l': 0.556, 'learning_rate': 0.00020000000000000004, 'momentum': 0.9}.

And also none of these val metrics were logged to the wandb run. See original wandb run here.

However, after applying my hot-fix above, the warning message no longer shows up and the val metrics are successfully logged to wandb. See the wandb run after the fix here.

[Update] I tried the same experiments with an upgraded wandb (pip version 0.12.0 → 0.12.9). The warning message no longer shows on the newer version, but the behaviour remains the same: val metrics are only logged to wandb after applying the fix above.

@ayulockin

Thanks for checking it out @levan92. I will investigate more in this direction and make a PR to fix it.

@zhouzaida
Collaborator

Hi @levan92, as a workaround you can set with_step to False. More discussion about the argument can be found in #913.

with_step (bool): If True, the step will be logged from ``self.get_iters``. Otherwise, step will not be logged.
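Applied to the config from the original post, the workaround is a one-line change to the hook dict (a sketch, reusing the same illustrative project/run names):

```python
log_config = dict(
    interval=5,
    hooks=[
        dict(type='TextLoggerHook', by_epoch=False),
        dict(
            type='WandbLoggerHook',
            init_kwargs=dict(
                project='train-tests',
                name='short'),
            out_suffix=('.log.json', '.log', '.py'),
            # Log 'global_step' as an ordinary metric instead of passing
            # step= to wandb.log, so wandb's internal step cannot get ahead.
            with_step=False,
            by_epoch=False),
    ])
```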
