MaxVIT Model - BatchNorm momentum is incorrect #8250
Thank you for the report @hassonofer. Just checking with @TeodorPoncu before moving forward: Teodor, do you remember discussing this during reviews?
Hey @NicolasHug! I personally was not aware of that difference in parameter specification between PyTorch and TensorFlow. I do not recall it coming up during reviews (I've double-checked with the original PR). I assume that detail might've flown under the radar because we obtained results comparable to the tiny variant from the paper.
Thank you for your quick reply @TeodorPoncu! Since the momentum parameter only affects how the running statistics are estimated during training, am I right to assume that the evaluation accuracy of the released pre-trained weights is unaffected?
@NicolasHug, yes. During inference time the momentum parameter has no effect on batch-norm, as it uses the running means and variances for inference, so the evaluation performance will be exactly the same. The momentum parameter only affects how these statistics are estimated during training.

Momentum was introduced in this paper to counteract small batch sizes relative to the dataset size. The reason is that the default way of computing the running mean and variance for the Batch Norm layer is a non-stationary, cumulative moving average, in which the weight given to each new batch shrinks as training goes on. As such, the longer the training run goes, the less a batch will contribute to the statistics update when no momentum value is set. With a momentum value, the update instead becomes an exponential moving average, depending on how the underlying implementation in torch is written, e.g. `running_stat = (1 - momentum) * running_stat + momentum * batch_stat`.

For instance, if the default torch implementation does the above (which is the same way the paper describes it in eq. 8 and algorithm 1), users might notice unfavourable results when performing fine-tuning, depending on how they configure the momentum parameter. On a small dataset, if they do not change the momentum inside the batch-norm layers, the learned statistics for the ImageNet solution space will be washed away almost immediately, since we would be assigning them a weight of only `1 - 0.99 = 0.01` at every update.

I would recommend changing the default value to 0.01 if and only if the actual torch implementation does `running_stat = (1 - momentum) * running_stat + momentum * batch_stat`.
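To make the difference between the two conventions concrete, here is a minimal plain-Python sketch (variable names are illustrative, not from either library) contrasting the TensorFlow `decay` convention with the PyTorch `momentum` convention for the same nominal value of 0.99:

```python
# TensorFlow convention: `decay` is the weight KEPT on the old running stats:
#     running = decay * running + (1 - decay) * batch
# PyTorch convention: `momentum` is the weight GIVEN to the new batch stats:
#     running = (1 - momentum) * running + momentum * batch

running_tf, running_pt = 0.0, 0.0
batch_mean = 5.0  # pretend every batch has this mean

for _ in range(3):
    running_tf = 0.99 * running_tf + (1 - 0.99) * batch_mean  # moves slowly
    running_pt = (1 - 0.99) * running_pt + 0.99 * batch_mean  # jumps to ~5

print(running_tf)  # ~0.149: old statistics dominate, as intended
print(running_pt)  # ~5.0: old statistics washed away after one step
```

So passing the TensorFlow value 0.99 straight through as a PyTorch `momentum` produces exactly the washing-away behaviour described above.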
Thanks @TeodorPoncu - yeah, as far as I can tell from the Note in https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html, the formula is as you described: `x_hat_new = (1 - momentum) * x_hat + momentum * x_t`.
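For what it's worth, the docs formula is easy to check empirically; a small sketch (the shapes and values here are arbitrary, chosen only for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3, momentum=0.99)  # the value MaxViT currently passes
bn.train()

x = torch.randn(8, 3, 4, 4) + 5.0  # batch whose per-channel mean is ~5
bn(x)  # one forward pass in train mode updates the running stats

# Docs formula: x_hat_new = (1 - momentum) * x_hat + momentum * x_t,
# with running_mean initialised to zeros.
expected = (1 - 0.99) * torch.zeros(3) + 0.99 * x.mean(dim=(0, 2, 3))
print(torch.allclose(bn.running_mean, expected, atol=1e-5))  # True
```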
@NicolasHug, in that case yes, the appropriate default value would be `0.01`.
🐛 Describe the bug
The current BatchNorm momentum is set to 0.99 here; as noted, this was taken from the original implementation here. But due to the differences between the PyTorch and TensorFlow implementations of BatchNorm, the momentum should be `1 - momentum` (i.e. 0.01) in the TorchVision implementation, as done (correctly, to my understanding) in the MnasNet implementation. A sketch of the conversion is shown below.
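A minimal sketch of the proposed change (illustrative names only, not torchvision's actual code): the value from the TensorFlow implementation is a decay and must be converted before being passed to `nn.BatchNorm2d`:

```python
from functools import partial

import torch.nn as nn

# Value from the original TensorFlow implementation, where it acts as a
# decay, i.e. the weight kept on the OLD running statistics.
TF_BN_DECAY = 0.99

# In PyTorch, `momentum` is the weight given to the NEW batch statistics,
# so the TF decay must be converted, as the MnasNet port already does:
norm_layer = partial(nn.BatchNorm2d, momentum=1 - TF_BN_DECAY)  # 0.01
```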
Versions
Collecting environment information...
PyTorch version: 2.1.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.31
Python version: 3.9.2 (default, Feb 28 2021, 17:03:44) [GCC 10.2.1 20210110] (64-bit runtime)
Python platform: Linux-6.1.0-0.deb11.13-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A5000
GPU 1: NVIDIA RTX A5000
Nvidia driver version: 545.23.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] flake8==7.0.0
[pip3] flake8-pep585==0.1.7
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] onnx==1.15.0
[pip3] torch==2.1.2+cu118
[pip3] torch-model-archiver==0.9.0
[pip3] torch-workflow-archiver==0.2.11
[pip3] torchaudio==2.1.2+cu118
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.3.0.post0
[pip3] torchserve==0.9.0
[pip3] torchvision==0.16.2+cu118
[pip3] triton==2.1.0
[conda] Could not collect