Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable Bring-your-own-Lightning-model #417

Merged
merged 128 commits into from
Apr 19, 2021
Merged
Changes from 1 commit
Commits
Show all changes
128 commits
Select commit Hold shift + click to select a range
d13e689
todo
ant0nsc Feb 8, 2021
a98bca0
dead code
ant0nsc Feb 8, 2021
45ac2e6
more notes
ant0nsc Feb 9, 2021
5fac2c0
comments
ant0nsc Feb 9, 2021
6357feb
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Feb 10, 2021
96f6019
updates
ant0nsc Feb 10, 2021
9a24671
changes to inference structure
ant0nsc Feb 23, 2021
d1b4e81
report
ant0nsc Feb 23, 2021
cd17286
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Mar 12, 2021
87c92bd
update notes
ant0nsc Mar 12, 2021
92004a1
remove overrides from loader
ant0nsc Mar 12, 2021
8d720a5
Remove buildinformation.json
ant0nsc Mar 12, 2021
4b06bc0
formatting
ant0nsc Mar 12, 2021
50e17f9
first loader test
ant0nsc Mar 12, 2021
6dbb0ff
first loader test
ant0nsc Mar 12, 2021
34b246b
notes
ant0nsc Mar 12, 2021
a8ed8a9
Loader tested and working
ant0nsc Mar 15, 2021
283afa5
docu
ant0nsc Mar 15, 2021
4983d7c
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Mar 15, 2021
a939112
more tests
ant0nsc Mar 15, 2021
10c4b0a
refactor loader
ant0nsc Mar 15, 2021
06bacbf
better tests
ant0nsc Mar 15, 2021
8928462
work in progress
ant0nsc Mar 15, 2021
2ba11fd
fixed and tested mount_or_download
ant0nsc Mar 15, 2021
7158013
Training works
ant0nsc Mar 15, 2021
3f459b7
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Mar 15, 2021
c2f666c
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Mar 16, 2021
5b1081f
updates to latest main
ant0nsc Mar 16, 2021
4a70b55
fix import problems
ant0nsc Mar 16, 2021
76be9e6
fix import problems
ant0nsc Mar 16, 2021
56fc73b
fix import problems
ant0nsc Mar 16, 2021
def1145
removing generics from loader
ant0nsc Mar 16, 2021
25268e2
DRY
ant0nsc Mar 16, 2021
81e565a
Inference is running
ant0nsc Mar 18, 2021
6ce4fb6
refactoring
ant0nsc Mar 19, 2021
3f97299
Inference finished and tested
ant0nsc Mar 19, 2021
e686d6b
fixing broken tests
ant0nsc Mar 19, 2021
2e07e48
param refactoring
ant0nsc Mar 22, 2021
81f8383
Many test fixes
ant0nsc Mar 22, 2021
0bbb342
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Mar 23, 2021
180d3c9
More test fixes
ant0nsc Mar 23, 2021
bb8543d
More test fixes
ant0nsc Mar 23, 2021
0e44e3f
test fixes
ant0nsc Mar 23, 2021
292655b
getting varnet module to start training
ant0nsc Mar 23, 2021
47b4670
docu
ant0nsc Mar 24, 2021
a1e88e0
Refactoring
ant0nsc Mar 25, 2021
a698b4d
test fixes
ant0nsc Mar 25, 2021
ed6351f
test fixes
ant0nsc Mar 26, 2021
f4ec27f
flake8
ant0nsc Mar 26, 2021
9d13e86
exclude
ant0nsc Mar 26, 2021
3e58360
mypy
ant0nsc Mar 26, 2021
016c737
avoid storinglogger
ant0nsc Mar 26, 2021
bddd724
changelog
ant0nsc Mar 29, 2021
9479c97
torch.count_nonzero
ant0nsc Mar 29, 2021
0928133
No longer need ModelTrainingResults
ant0nsc Mar 29, 2021
aa02a76
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Mar 29, 2021
20289f9
flake
ant0nsc Mar 29, 2021
acab458
test fixes
ant0nsc Mar 29, 2021
4407f6e
fix spawn test failure
ant0nsc Mar 29, 2021
bace860
fix for unset local dataset path in Azure
ant0nsc Mar 30, 2021
11382cd
checkout submodules
ant0nsc Mar 30, 2021
a88850c
cleanup
ant0nsc Mar 30, 2021
e8a87a1
cleanup
ant0nsc Mar 30, 2021
fa2dd8d
import error
ant0nsc Mar 30, 2021
7a4f009
remove fastmri
ant0nsc Mar 31, 2021
4f58b1c
fastMRI models runs up to inference
ant0nsc Mar 31, 2021
56d6c87
test fixes
ant0nsc Mar 31, 2021
2349fe8
tests
ant0nsc Mar 31, 2021
7eea299
test fixes
ant0nsc Mar 31, 2021
d521ec2
cleanup
ant0nsc Apr 1, 2021
27da1ca
print argv path
ant0nsc Apr 1, 2021
cfab6b0
fix argv problem
ant0nsc Apr 1, 2021
8b01564
update
ant0nsc Apr 1, 2021
1e26232
update sub
ant0nsc Apr 1, 2021
5a3a385
final model
ant0nsc Apr 1, 2021
8d2bfdb
flake
ant0nsc Apr 1, 2021
0765238
update to fastmri master
ant0nsc Apr 7, 2021
15fe1ab
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 7, 2021
d564119
test fixes
ant0nsc Apr 7, 2021
87267f9
import fixes
ant0nsc Apr 7, 2021
2063383
test fixes
ant0nsc Apr 9, 2021
58a6dc6
fix import problems
ant0nsc Apr 9, 2021
193e9c9
fix checkpoint test
ant0nsc Apr 9, 2021
3231e45
test fix
ant0nsc Apr 12, 2021
88e29fa
Fix test_recover_training_mean_teacher_model
Shruthi42 Apr 12, 2021
448550d
Flake8
Shruthi42 Apr 12, 2021
b6a3575
test fix
ant0nsc Apr 12, 2021
101d42d
test fix
ant0nsc Apr 12, 2021
3cd1b17
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
ant0nsc Apr 12, 2021
ede2829
Remove main() from patch_sampling.py
Shruthi42 Apr 12, 2021
987e83d
mypy
Shruthi42 Apr 12, 2021
9fad817
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
Shruthi42 Apr 12, 2021
76e5f02
docu
ant0nsc Apr 12, 2021
5dae881
ensure correct seeding in tests
ant0nsc Apr 12, 2021
b744abb
fix test_invalid_trainer_args
ant0nsc Apr 13, 2021
60f5e32
fix downloading problems in AzureML
ant0nsc Apr 13, 2021
2ae170a
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 13, 2021
dc01b61
adding reports for container models
ant0nsc Apr 13, 2021
59e2ca1
fixing test loss values
ant0nsc Apr 13, 2021
8758973
fix bug with running the unit tests in AzureML
ant0nsc Apr 13, 2021
c514ca0
docu, flake
ant0nsc Apr 13, 2021
ab9b4c2
mypy and flake
ant0nsc Apr 13, 2021
7988af9
downgrading mypy
ant0nsc Apr 13, 2021
4e5b7f6
test fixes
ant0nsc Apr 14, 2021
f70884f
test fixes, reduce logging noise
ant0nsc Apr 14, 2021
069ca6c
mypy and flake
ant0nsc Apr 14, 2021
731483f
simplify mypy runner
ant0nsc Apr 14, 2021
1f5afc0
remove comments
ant0nsc Apr 14, 2021
a362e70
Update InnerEye/ML/model_training.py
ant0nsc Apr 14, 2021
7f17e83
Update docs/bring_your_own_model.md
ant0nsc Apr 14, 2021
a8024d1
PR comments
ant0nsc Apr 14, 2021
ed94083
Merge branch 'antonsc/byol' of https://github.com/microsoft/InnerEye-…
ant0nsc Apr 14, 2021
b6a7d01
fixes
ant0nsc Apr 15, 2021
de15468
updated docu and design as per PR feedback.
ant0nsc Apr 15, 2021
5656489
docu and mypy
ant0nsc Apr 15, 2021
efd9231
update doc, add report test
ant0nsc Apr 16, 2021
2f4c36a
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 16, 2021
268b265
HelloContainer model
ant0nsc Apr 16, 2021
8649681
HelloContainer data
ant0nsc Apr 16, 2021
422c790
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 16, 2021
7e1e516
HelloWorld running
ant0nsc Apr 16, 2021
be0948b
docu
ant0nsc Apr 16, 2021
79795e0
fixes
ant0nsc Apr 16, 2021
3cde0ee
fixes
ant0nsc Apr 16, 2021
6f9e37c
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 16, 2021
9df5025
updated main, PR comments
ant0nsc Apr 16, 2021
00bb084
PR comments
ant0nsc Apr 16, 2021
181302f
Merge remote-tracking branch 'origin/main' into antonsc/byol
ant0nsc Apr 19, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
updates to latest main
ant0nsc committed Mar 16, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 5b1081f2cb764b7eb00349a4c8d2b1ece3dcd2de
4 changes: 2 additions & 2 deletions InnerEye/ML/lightning_base.py
Original file line number Diff line number Diff line change
@@ -34,7 +34,7 @@ class TrainingAndValidationDataLightning(LightningDataModule):
A class that wraps training and validation data from an InnerEye model configuration to a Lightning data module.
"""

def _init__(self, config: ModelConfigBase) -> None:
def __init__(self, config: ModelConfigBase) -> None:
super().__init__()
self.config = config
self.data_loaders: Dict[ModelExecutionMode, DataLoader] = {}
@@ -68,7 +68,7 @@ def setup(self) -> None:
# loaded (typically only during tests)
if self.config.dataset_data_frame is None:
assert self.config.local_dataset is not None
validate_dataset_paths(self.config.local_dataset)
validate_dataset_paths(self.config.local_dataset, self.config.dataset_csv)
self.config.read_dataset_if_needed()

def get_training_data_module(self, crossval_index: int, crossval_count: int) -> LightningDataModule:
3 changes: 3 additions & 0 deletions InnerEye/ML/lightning_container.py
Original file line number Diff line number Diff line change
@@ -158,6 +158,9 @@ def val_diagnostics(self) -> Any:
"""
return None

def trainer_hook(self, trainer) -> None:
pass


class LightningContainer:

29 changes: 13 additions & 16 deletions InnerEye/ML/model_training.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
import os
import sys
from pathlib import Path
from typing import Optional, Tuple, TypeVar
from typing import Any, Dict, Optional, Tuple, TypeVar

import torch
from pytorch_lightning import Trainer, seed_everything
@@ -65,14 +65,17 @@ def upload_output_file_as_temp(file_path: Path, outputs_folder: Path) -> None:

def create_lightning_trainer(config: DeepLearningConfig,
resume_from_checkpoint: Optional[Path] = None,
num_nodes: int = 1) -> Tuple[Trainer, StoringLogger]:
num_nodes: int = 1,
**kwargs: Dict[str, Any]) -> \
Tuple[Trainer, StoringLogger]:
"""
Creates a Pytorch Lightning Trainer object for the given model configuration. It creates checkpoint handlers
and loggers. That includes a diagnostic logger for use in unit tests, that is also returned as the second
return value.
:param config: The model configuration.
:param resume_from_checkpoint: If provided, training resumes from this checkpoint point.
:param num_nodes: The number of nodes to use in distributed training.
:param kwargs: Any additional keyword arguments will be passed to the constructor of Trainer.
:return: A tuple [Trainer object, diagnostic logger]
"""
# For now, stick with the legacy behaviour of always saving only the last epoch checkpoint. For large segmentation
@@ -106,14 +109,6 @@ def create_lightning_trainer(config: DeepLearningConfig,
storing_logger = StoringLogger()
tensorboard_logger = TensorBoardLogger(save_dir=str(config.logs_folder), name="Lightning", version="")
loggers = [storing_logger, tensorboard_logger, AzureMLLogger()]
# This leads to problems with run termination.
# if not is_offline_run_context(RUN_CONTEXT):
# mlflow_logger = MLFlowLogger(experiment_name=RUN_CONTEXT.experiment.name,
# tracking_uri=RUN_CONTEXT.experiment.workspace.get_mlflow_tracking_uri())
# # The MLFlow logger needs to get its ID from the AzureML run context, otherwise there will be two sets of
# # results for each run, one from native AzureML and one from the MLFlow logger.
# mlflow_logger._run_id = RUN_CONTEXT.id
# loggers.append(mlflow_logger)
# Use 32bit precision when running on CPU. Otherwise, make it depend on use_mixed_precision flag.
precision = 32 if num_gpus == 0 else 16 if config.use_mixed_precision else 32
# The next two flags control the settings in torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark
@@ -142,8 +137,8 @@ def create_lightning_trainer(config: DeepLearningConfig,
precision=precision,
sync_batchnorm=True,
terminate_on_nan=config.detect_anomaly,
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None
)
resume_from_checkpoint=str(resume_from_checkpoint) if resume_from_checkpoint else None,
**kwargs)
return trainer, storing_logger


@@ -183,11 +178,11 @@ def model_train(config: DeepLearningConfig,
assert isinstance(config, ModelConfigBase), "When using a built-in InnerEye model, the configuration should " \
"be an instance of ModelConfigBase"
lightning_container = InnerEyeContainer(config)
# When trying to store the config object in the constructor, it does not appear to get stored at all, later
# reference of the object simply fail. Hence, have to set explicitly here.
lightning_container.setup()
if is_rank_zero():
# Save the dataset files for later use in cross validation analysis
# TODO antonsc: Should we move that into TrainAndValidationDataLightning? The .prepare method
# of a data module is called only on rank zero
config.write_dataset_files()
lightning_model = create_lightning_model(config)
else:
@@ -200,7 +195,10 @@ def model_train(config: DeepLearningConfig,
# training in the unit tests.
old_environ = dict(os.environ)
seed_everything(config.get_effective_random_seed())
trainer, storing_logger = create_lightning_trainer(config, checkpoint_path, num_nodes=num_nodes)
trainer, storing_logger = create_lightning_trainer(config,
checkpoint_path,
num_nodes=num_nodes,
**lightning_container.get_trainer_arguments())

logging.info(f"GLOBAL_RANK: {os.getenv('GLOBAL_RANK')}, LOCAL_RANK {os.getenv('LOCAL_RANK')}. "
f"trainer.global_rank: {trainer.global_rank}")
@@ -240,7 +238,6 @@ def model_train(config: DeepLearningConfig,
trainer.logger.close() # type: ignore
lightning_model.close_all_loggers()
world_size = getattr(trainer, "world_size", 0)
# TODO antonsc
is_azureml_run = not config.is_offline_run
# Per-subject model outputs for regression models are written per rank, and need to be aggregated here.
# Each thread per rank will come here, and upload its files to the run outputs. Rank 0 will later download them.
44 changes: 21 additions & 23 deletions Tests/ML/configs/lightning_test_containers.py
Original file line number Diff line number Diff line change
@@ -14,39 +14,35 @@
from InnerEye.ML.lightning_container import LightningContainer, LightningWithInference


class DummyContainerWithAzureDataset(LightningContainer):
def __init__(self):
class DummyContainerWithDatasets(LightningContainer):
def __init__(self, has_local_dataset: bool = False, has_azure_dataset: bool = False):
super().__init__()
self.has_local_dataset = has_local_dataset
self.has_azure_dataset = has_azure_dataset

def create_lightning_module(self) -> LightningWithInference:
local_dataset = full_ml_test_data_path("lightning_module_data")
return LightningWithInference(azure_dataset_id="azure_dataset", local_dataset=local_dataset)
local_dataset = full_ml_test_data_path("lightning_module_data") if self.has_local_dataset else None
azure_dataset = "azure_dataset" if self.has_local_dataset else ""
return LightningWithInference(azure_dataset_id=azure_dataset, local_dataset=local_dataset)


class DummyContainerWithoutDataset(LightningContainer):
class DummyContainerWithAzureDataset(DummyContainerWithDatasets):
def __init__(self):
super().__init__()
super().__init__(has_azure_dataset=True)

def create_lightning_module(self) -> LightningWithInference:
return LightningWithInference()

class DummyContainerWithoutDataset(DummyContainerWithDatasets):
pass

class DummyContainerWithLocalDataset(LightningContainer):
def __init__(self):
super().__init__()

def create_lightning_module(self) -> LightningWithInference:
local_dataset = full_ml_test_data_path("lightning_module_data")
return LightningWithInference(local_dataset=local_dataset)
class DummyContainerWithLocalDataset(DummyContainerWithDatasets):
def __init__(self):
super().__init__(has_local_dataset=True)


class DummyContainerWithAzureAndLocalDataset(LightningContainer):
class DummyContainerWithAzureAndLocalDataset(DummyContainerWithDatasets):
def __init__(self):
super().__init__()

def create_lightning_module(self) -> LightningWithInference:
local_dataset = full_ml_test_data_path("lightning_module_data")
return LightningWithInference(azure_dataset_id="azure_dataset", local_dataset=local_dataset)
super().__init__(has_local_dataset=True, has_azure_dataset=True)


class DummyRegression(LightningWithInference):
@@ -102,11 +98,13 @@ def test_dataloader(self, *args, **kwargs) -> DataLoader:


class DummyContainerWithModel(LightningContainer):
def __init__(self):
super().__init__()

def create_lightning_module(self) -> LightningWithInference:
return DummyRegression()

def get_training_data_module(self, crossval_index: int, crossval_count: int) -> LightningDataModule:
return FixedRegressionData()


class DummyContainerWithInvalidTrainerArguments(DummyContainerWithModel):
def get_trainer_arguments(self):
return {"no_such_argument": 1}
19 changes: 16 additions & 3 deletions Tests/ML/models/test_instantiate_models.py
Original file line number Diff line number Diff line change
@@ -11,16 +11,18 @@

from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import logging_to_stdout, namespace_to_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.config import SegmentationModelBase
from InnerEye.ML.deep_learning_config import DeepLearningConfig
from InnerEye.ML.lightning_container import LightningContainer
from InnerEye.ML.model_training import generate_and_print_model_summary
from InnerEye.ML.model_training import generate_and_print_model_summary, model_train
from InnerEye.ML.runner import Runner
from InnerEye.ML.utils.config_loader import ModelConfigLoader
from InnerEye.ML.utils.model_util import create_model_with_temperature_scaling
from Tests.ML.configs.DummyModel import DummyModel
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
from Tests.ML.util import get_model_loader
from Tests.ML.configs.lightning_test_containers import DummyContainerWithInvalidTrainerArguments, \
DummyContainerWithModel
from Tests.ML.util import get_default_checkpoint_handler, get_model_loader


def find_models() -> List[str]:
@@ -159,3 +161,14 @@ def test_run_container_in_situ() -> None:
loaded_config, actual_run = runner.run()
assert actual_run is None
assert isinstance(runner.lightning_container, DummyContainerWithModel)


def test_run_model_with_invalid_trainer_arguments(test_output_dirs: OutputFolderForTests) -> None:
container = DummyContainerWithInvalidTrainerArguments()
config = container.create_lightning_module()
container.lightning_module = config
checkpoint_handler = get_default_checkpoint_handler(model_config=config,
project_root=test_output_dirs.root_dir)
with pytest.raises(Exception) as ex:
model_train(container.lightning_module, checkpoint_handler=checkpoint_handler, lightning_container=container)
assert "no_such_argument" in str(ex)