
MaxVIT Model - BatchNorm momentum is incorrect #8250

Closed
hassonofer opened this issue Feb 2, 2024 · 6 comments · Fixed by #8312

@hassonofer

🐛 Describe the bug

The current BatchNorm momentum is set to 0.99 here.
As noted, this value was taken from the original implementation here.

But due to the difference between the PyTorch and TensorFlow BatchNorm conventions, the momentum in the TorchVision implementation should be 1 - momentum, i.e. 0.01.
This is done (correctly, to my understanding) in the MnasNet implementation.
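As a minimal sketch of the two conventions (toy channel count, not the actual MaxViT code):

```python
import torch.nn as nn

# TensorFlow-style BatchNorm uses `momentum` as the weight kept on the *running*
# statistics (typically 0.99); torch.nn.BatchNorm2d uses `momentum` as the weight
# given to the *new* batch statistics. A TF momentum of 0.99 therefore corresponds
# to a PyTorch momentum of 1 - 0.99 = 0.01.
tf_momentum = 0.99
bn = nn.BatchNorm2d(64, momentum=1.0 - tf_momentum)  # i.e. momentum=0.01
```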

Versions

Collecting environment information...
PyTorch version: 2.1.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.31

Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-6.1.0-0.deb11.13-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000

Nvidia driver version: 545.23.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-pep585==0.1.7
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] onnx==1.15.0
[pip3] torch==2.1.2+cu118
[pip3] torch-model-archiver==0.9.0
[pip3] torch-workflow-archiver==0.2.11
[pip3] torchaudio==2.1.2+cu118
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.0.post0
[pip3] torchserve==0.9.0
[pip3] torchvision==0.16.2+cu118
[pip3] triton==2.1.0
[conda] Could not collect

@NicolasHug
Member

Thank you for the report @hassonofer.

Just checking with @TeodorPoncu before moving forward: Teodor do you remember discussing this during reviews?

@TeodorPoncu
Contributor

TeodorPoncu commented Feb 5, 2024

Hey @NicolasHug! I personally was not aware of that difference in parameter specification between PyTorch and TensorFlow.

I do not recall that coming up during reviews (I've double-checked with the original PR).

I assume that detail might've flown under the radar because we obtained results comparable to the tiny variant from the paper (83.7 for the torchvision weights vs. 83.62 in the paper).

@NicolasHug
Member

Thank you for your quick reply, @TeodorPoncu!

Since the momentum parameter only affects training of the model, and not inference (right?), we can probably fix the default from 0.99 to 0.01, and that would still keep the pre-trained weights' performance the same (i.e. it would still be 83.7). WDYT?

@TeodorPoncu
Contributor

TeodorPoncu commented Feb 5, 2024

@NicolasHug, yes. At inference time the momentum parameter has no effect on batch norm, since it uses the stored running means and variances, so the evaluation performance will be exactly the same. The momentum parameter only affects how these statistics are estimated during training.
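
A minimal sketch illustrating this (toy 3-channel layers with default-initialised running stats):

```python
import torch
import torch.nn as nn

# In eval() mode BatchNorm normalises with the stored running statistics, so two
# layers that differ only in momentum produce identical outputs.
x = torch.randn(4, 3, 8, 8)
bn_a = nn.BatchNorm2d(3, momentum=0.99).eval()
bn_b = nn.BatchNorm2d(3, momentum=0.01).eval()
print(torch.allclose(bn_a(x), bn_b(x)))  # True
```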

Momentum was introduced in this paper to counteract small batch sizes relative to the dataset size. The reason is that, without a momentum value, the running mean and variance of the BatchNorm layer are computed with a non-stationary momentum of 1 / gamma (torch reference implementation here), where gamma is incremented by 1 at every forward pass during training.
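
A minimal sketch of this cumulative-average behaviour, using momentum=None (which selects the cumulative moving average in torch.nn.BatchNorm2d):

```python
import torch
import torch.nn as nn

# With momentum=None, PyTorch uses an effective momentum of 1 / num_batches_tracked,
# so each new batch contributes less and less to the running statistics.
bn = nn.BatchNorm2d(3, momentum=None)
bn.train()
for step in range(5):
    bn(torch.randn(8, 3, 16, 16))
    n = int(bn.num_batches_tracked)
    print(f"batches tracked: {n}, effective momentum: {1 / n:.3f}")
```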

As such, the longer the training run goes, the less a batch contributes to the statistics update when no momentum value is set. Depending on the underlying implementation in torch (I could not find whether it does something like running_mean = momentum * running_mean + (1 - momentum) * batch_mean, as the reference implementation falls back to an F binding), changing the momentum to 0.01 might affect users performing subsequent fine-tuning runs.

For instance, if the default torch implementation does the above (which is the same way the paper describes it in eq. 8 and algorithm 1), users might notice unfavourable results when fine-tuning, depending on how they configure the momentum parameter.

On a small dataset, if they do not change the momentum inside the batch-norm layers, the learned statistics from the ImageNet solution space will be washed away almost immediately (since we would be assigning a weight of only 0.01 to them), thus missing out on the benefits of transfer learning.

I would recommend changing the default value to 0.01 if and only if the actual torch implementation does running_mean = (1 - momentum) * running_mean + momentum * batch_mean, given the side effects it could otherwise lead to in fine-tuning scenarios.
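
A minimal sketch that checks this update rule directly against torch.nn.BatchNorm2d (toy shapes, freshly initialised running stats):

```python
import torch
import torch.nn as nn

# PyTorch's BatchNorm updates running stats as:
#   running = (1 - momentum) * running + momentum * batch_stat,
# so with momentum=0.01 the pre-existing running statistics keep a weight of 0.99.
momentum = 0.01
bn = nn.BatchNorm2d(3, momentum=momentum)
bn.train()

x = torch.randn(32, 3, 8, 8)
bn(x)

batch_mean = x.mean(dim=(0, 2, 3))
# running_mean is initialised to zeros, so after one step it should equal
# (1 - momentum) * 0 + momentum * batch_mean.
expected = momentum * batch_mean
print(torch.allclose(bn.running_mean, expected, atol=1e-6))  # True
```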

@NicolasHug
Member

Thanks @TeodorPoncu - yeah, as far as I can tell from the Note in https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html, the formula is as you described.

@TeodorPoncu
Contributor

@NicolasHug, in that case yes, the appropriate default momentum value should be set to 0.01, and it shouldn't have any side effects (inference- or fine-tuning-wise).
