This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Bug fix: SSL on multiple nodes used wrong LR scheduler #628

Merged: 3 commits, Jan 6, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -101,6 +101,7 @@ in inference-only runs when using lightning containers.
- ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model
weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs.
- ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided.
- ([#628](https://github.com/microsoft/InnerEye-DeepLearning/pull/628)) SSL SimCLR was using the wrong LR schedule when running on multiple nodes
- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training

### Removed
3 changes: 2 additions & 1 deletion InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -151,9 +151,10 @@ def create_model(self) -> LightningModule:
 model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
                                          dataset_name=self.ssl_training_dataset_name.value,
                                          use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
-                                         gpus=self.total_num_gpus,
                                          num_samples=self.data_module.num_samples,
                                          batch_size=self.data_module.batch_size,
+                                         gpus=self.num_gpus_per_node(),
+                                         num_nodes=self.num_nodes,
                                          learning_rate=self.l_rate,
                                          max_epochs=self.num_epochs)
 elif self.ssl_training_type == SSLTrainingType.BYOL:
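
Why both arguments matter: the LightningBolts SimCLR module that SimCLRInnerEye builds on derives its cosine LR schedule from the number of optimizer steps per epoch, which is in turn computed from the global batch size. The sketch below is a minimal paraphrase of that behaviour, consistent with what the new test asserts; the helper names (train_iters_per_epoch, cosine_schedule_steps) are illustrative and the exact code differs between LightningBolts versions.

from typing import Tuple


def train_iters_per_epoch(num_samples: int, batch_size: int, gpus: int, num_nodes: int) -> int:
    # The effective (global) batch size grows with every GPU on every node;
    # CPU-only runs (gpus == 0) fall back to the plain per-process batch size.
    global_batch_size = num_nodes * gpus * batch_size if gpus > 0 else batch_size
    return num_samples // global_batch_size


def cosine_schedule_steps(iters_per_epoch: int, warmup_epochs: int, max_epochs: int) -> Tuple[int, int]:
    # The linear-warmup/cosine-decay scheduler is stepped once per batch, so
    # its warm-up length and total length are measured in optimizer steps.
    return warmup_epochs * iters_per_epoch, max_epochs * iters_per_epoch

Passing the per-node GPU count together with num_nodes keeps the global batch size, and hence the schedule length, correct on multi-node runs, which is what the change above does.
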
28 changes: 26 additions & 2 deletions Tests/SSL/test_ssl_containers.py
@@ -2,10 +2,11 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
+import math
from pathlib import Path
+from typing import Dict
from unittest import mock

-import math
import numpy as np
import pandas as pd
import pytest
@@ -15,7 +16,6 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import Module
from torch.optim.lr_scheduler import _LRScheduler
-from typing import Dict

from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import is_windows
@@ -29,6 +29,7 @@
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
+from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier
from InnerEye.ML.runner import Runner
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
@@ -402,3 +403,26 @@ def test_online_evaluator_distributed() -> None:
# We still need to mock DDP here because the constructor relies on having a process group available
mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device])
assert callback.evaluator == mock_ddp_result


def test_simclr_batch_size() -> None:
"""
Test if the number of nodes is correctly passed through to the SIMCLR model. After an update of the semantics of
the "gpus" argument in LightningBolts, we had a regression, leading to incorrect use of the cosine
LR scheduler.
"""
with mock.patch("InnerEye.ML.deep_learning_config.TrainerParams.num_gpus_per_node", return_value=1):
with mock.patch("InnerEye.ML.SSL.lightning_containers.ssl_container.get_encoder_output_dim", return_value=1):
container = CIFAR10SimCLR()
num_samples = 100
batch_size = 10
container.data_module = mock.MagicMock(num_samples=num_samples, batch_size=batch_size)
assert container.num_nodes == 1
model1 = container.create_model()
old_iters_per_epoch = model1.train_iters_per_epoch
assert old_iters_per_epoch == num_samples / batch_size
# Increasing the number of nodes should increase effective batch size, and hence reduce number of
# iterations per epoch
container.num_nodes = 2
model2 = container.create_model()
assert model2.train_iters_per_epoch == old_iters_per_epoch // container.num_nodes # type:ignore
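
To make the expected numbers concrete, this is the same arithmetic as in the test, spelled out with the hypothetical train_iters_per_epoch helper sketched after the ssl_container.py diff above:

# 100 samples, batch size 10, 1 GPU per node, 1 node: 10 optimizer steps per epoch.
assert train_iters_per_epoch(num_samples=100, batch_size=10, gpus=1, num_nodes=1) == 10
# Doubling the node count doubles the effective batch size and halves the
# number of steps per epoch, which is what the test asserts for model2.
assert train_iters_per_epoch(num_samples=100, batch_size=10, gpus=1, num_nodes=2) == 5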