From b6d81110e7eea70844fa353eab2803ff191b8ff6 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 17 Jun 2021 14:40:12 +0100 Subject: [PATCH 1/5] parsing bug --- InnerEye/Azure/azure_config.py | 14 +++++++++++--- Tests/Azure/test_parsing.py | 19 +++++++++++++------ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/InnerEye/Azure/azure_config.py b/InnerEye/Azure/azure_config.py index 210bf259f..724144536 100755 --- a/InnerEye/Azure/azure_config.py +++ b/InnerEye/Azure/azure_config.py @@ -332,9 +332,17 @@ def set_script_params_except_submit_flag(self) -> None: arg = args[i] if arg.startswith(submit_flag): if len(arg) == len(submit_flag): - # The argument list contains something like ["--azureml", "True]: Skip 2 entries - i = i + 1 - elif arg[len(submit_flag)] != "=": + # The commandline argument is "--azureml", with something possibly following: This can either be + # "--azureml True" or "--azureml --some_other_param" + if i < (len(args) - 1): + # If the next argument starts with a "-" then assume that it does not belong to the --azureml + # flag. If there is no "-", assume it belongs to the --azureml flag, and skip both + if not args[i + 1].startswith("-"): + i = i + 1 + elif arg[len(submit_flag)] == "=": + # The commandline argument is "--azureml=True" or "--azureml=False": Continue with next arg + pass + else: # The argument list contains a flag like "--azureml_foo": Keep that. retained_args.append(arg) else: diff --git a/Tests/Azure/test_parsing.py b/Tests/Azure/test_parsing.py index 2cbc30064..7f5562273 100644 --- a/Tests/Azure/test_parsing.py +++ b/Tests/Azure/test_parsing.py @@ -146,16 +146,23 @@ def assert_has_params(expected_args: str) -> None: # Arguments are in the keys of the dictionary only, and should have been added in the right order assert " ".join(s.script_params) == expected_args - with mock.patch("sys.argv", ["", "some", "--param", "1", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}=True", "more"]): + with mock.patch("sys.argv", ["", "some", "--param", "1", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}=True", "--more"]): s.set_script_params_except_submit_flag() - assert_has_params("some --param 1 more") - with mock.patch("sys.argv", ["", "some", "--param", "1", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}", "False", "more"]): + assert_has_params("some --param 1 --more") + with mock.patch("sys.argv", ["", "some", "--param", "1", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}", "False", "--more"]): s.set_script_params_except_submit_flag() - assert_has_params("some --param 1 more") + assert_has_params("some --param 1 --more") + # Using the new syntax for boolean flags + with mock.patch("sys.argv", ["", "some", "--param", "1", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}", "--more"]): + s.set_script_params_except_submit_flag() + assert_has_params("some --param 1 --more") + with mock.patch("sys.argv", ["", "some", "--param", "1", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}"]): + s.set_script_params_except_submit_flag() + assert_has_params("some --param 1") # Arguments where azureml is just the prefix should not be removed. 
- with mock.patch("sys.argv", ["", "some", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}foo", "False", "more"]): + with mock.patch("sys.argv", ["", "some", f"--{AZURECONFIG_SUBMIT_TO_AZUREML}foo", "False", "--more"]): s.set_script_params_except_submit_flag() - assert_has_params(f"some --{AZURECONFIG_SUBMIT_TO_AZUREML}foo False more") + assert_has_params(f"some --{AZURECONFIG_SUBMIT_TO_AZUREML}foo False --more") @pytest.mark.parametrize(["s", "expected"], From cf84d0aa0aad6f4f06c6bac573f239e3d1420eef Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 17 Jun 2021 15:21:14 +0100 Subject: [PATCH 2/5] doc --- docs/bring_your_own_model.md | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/docs/bring_your_own_model.md b/docs/bring_your_own_model.md index a3c9455f6..b8b56b26c 100644 --- a/docs/bring_your_own_model.md +++ b/docs/bring_your_own_model.md @@ -62,13 +62,17 @@ class MyDataModule(LightningDataModule): # All data should be read from the folder given in self.root_path self.root_path = root_path def train_dataloader(self, *args, **kwargs) -> DataLoader: - ... + # The data should be read off self.root_path + train_dataset = ... + return DataLoader(train_dataset, batch_size=5, num_workers=5) def val_dataloader(self, *args, **kwargs) -> DataLoader: # The data should be read off self.root_path - ... + val_dataset = ... + return DataLoader(val_dataset, batch_size=5, num_workers=5) def test_dataloader(self, *args, **kwargs) -> DataLoader: # The data should be read off self.root_path - ... + test_dataset = ... + return DataLoader(test_dataset, batch_size=5, num_workers=5) class MyContainer(LightningContainer): def __init__(self): @@ -97,6 +101,18 @@ In the above example, training is done for 42 epochs. After the model is trained via PyTorch Lightning's [built-in test functionality](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html?highlight=trainer.test#test). See below for an alternative way of running the evaluation on the test set. +### Data loaders +The example above creates `DataLoader` objects from a dataset. When creating those, you need to specify a batch size +(how many samples from your dataset will go into one minibatch), and a number of worker processes. Note that, by +default, data loading will happen in the main process, meaning that your GPU will sit idle while the CPU reads data +from disk. When specifying a number of workers, it will spawn processes that pre-fetch data from disk, and put them +into a queue, ready for the GPU to pick it up when it is done processing the current minibatch. + +For more details, please see the documentation for +[DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). There is also a +[tutorial describing the foundations of datasets and +data loaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) + ### Outputting files during training The Lightning model returned by `create_model` needs to write its output files to the current working directory. 
From bba2a7324ed887fe826c698c266c120ac0faba89 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 17 Jun 2021 15:21:52 +0100 Subject: [PATCH 3/5] bug fix --- InnerEye/ML/configs/other/HelloContainer.py | 18 +++++++++++++++--- InnerEye/ML/run_ml.py | 3 +++ InnerEye/ML/runner.py | 9 +++------ 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/InnerEye/ML/configs/other/HelloContainer.py b/InnerEye/ML/configs/other/HelloContainer.py index f4d406352..027d04a45 100644 --- a/InnerEye/ML/configs/other/HelloContainer.py +++ b/InnerEye/ML/configs/other/HelloContainer.py @@ -11,6 +11,7 @@ from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import StepLR, _LRScheduler from torch.utils.data import DataLoader, Dataset +from pytorch_lightning.metrics import MeanAbsoluteError from InnerEye.Common import fixed_paths_for_tests from InnerEye.ML.lightning_container import LightningContainer @@ -75,6 +76,7 @@ def __init__(self) -> None: super().__init__() self.model = torch.nn.Linear(in_features=1, out_features=1, bias=True) self.test_mse: List[torch.Tensor] = [] + self.test_mae = MeanAbsoluteError() def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore """ @@ -142,6 +144,7 @@ def on_test_epoch_start(self) -> None: test set (that is done in the test_step). """ self.test_mse = [] + self.test_mae.reset() def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: # type: ignore """ @@ -153,8 +156,15 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Ten :param batch_idx: The index (0, 1, ...) of the batch when the data loader is enumerated. :return: The loss on the test data. """ - loss = self.shared_step(batch) + input = batch["x"] + target = batch["y"] + prediction = self.forward(input) + # This illustrates two ways of computing metrics: Using standard torch + loss = torch.nn.functional.mse_loss(prediction, target) self.test_mse.append(loss) + # Metrics computed using PyTorch Lightning objects. Note that these will, by default, attempt + # to synchronize across GPUs. + self.test_mae.update(preds=prediction, target=target) return loss def on_test_epoch_end(self) -> None: @@ -166,6 +176,7 @@ def on_test_epoch_end(self) -> None: """ average_mse = torch.mean(torch.stack(self.test_mse)) Path("test_mse.txt").write_text(str(average_mse.item())) + Path("test_mae.txt").write_text(str(self.test_mae.compute())) class HelloContainer(LightningContainer): @@ -196,7 +207,8 @@ def get_data_module(self) -> LightningDataModule: # training, and cook them into a nice looking report. Here, the report is a simple text file. def create_report(self) -> None: # This just prints out the test MSE, but you could also generate a Jupyter notebook here, for example. 
- test_mse = float(Path("test_mse.txt").read_text()) - report = f"Performance on test set: MSE = {test_mse}" + test_mse = Path("test_mse.txt").read_text().strip() + test_mae = Path("test_mae.txt").read_text().strip() + report = f"Performance on test set: MSE = {test_mse}, MAE = {test_mae}" print(report) Path("report.txt").write_text(report) diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 743c9256b..2d22e5f99 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -447,6 +447,9 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> No # searching for Horovod if ENV_OMPI_COMM_WORLD_RANK in os.environ: del os.environ[ENV_OMPI_COMM_WORLD_RANK] + # From the training setup, torch still thinks that it should run in a distributed manner, + # and would block on some GPU operations. Hence, clean up distributed training. + torch.distributed.destroy_process_group() trainer, _ = create_lightning_trainer(self.container, num_nodes=1) # When training models that are not built-in InnerEye models, we have no guarantee that they write # files to the right folder. Best guess is to change the current working directory to where files should go. diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py index 6fb6209db..34f4159c5 100755 --- a/InnerEye/ML/runner.py +++ b/InnerEye/ML/runner.py @@ -288,12 +288,9 @@ def run_in_situ(self) -> None: pytest_failures = f"Not all PyTest tests passed. See {results_file_path}" raise ValueError(pytest_failures) else: - # Set environment variables for multi-node training if needed. - # In particular, the multi-node environment variables should NOT be set in single node - # training, otherwise this might lead to errors with the c10 distributed backend - # (https://github.com/microsoft/InnerEye-DeepLearning/issues/395) - if self.azure_config.num_nodes > 1: - set_environment_variables_for_multi_node() + # Set environment variables for multi-node training if needed. This function will terminate early + # if it detects that it is not in a multi-node environment. + set_environment_variables_for_multi_node() ml_runner = self.create_ml_runner() ml_runner.setup() ml_runner.start_logging_to_file() From 04d9c5d0aed1d2cfba7ab473485379d459efe437 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 17 Jun 2021 15:35:17 +0100 Subject: [PATCH 4/5] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bbe10f2cd..95c36bcd0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ created. - ([#475](https://github.com/microsoft/InnerEye-DeepLearning/pull/475)) Bug in AML SDK meant that we could not train any large models anymore because data loaders ran out of memory. - ([#472](https://github.com/microsoft/InnerEye-DeepLearning/pull/472)) Correct model path for moving ensemble models. +- ([#494](https://github.com/microsoft/InnerEye-DeepLearning/pull/494)) Fix an issue where multi-node jobs for +LightningContainer models can get stuck at test set inference. 
### Removed From 927dfa9b86aba59665d318174c3eec16cc87b357 Mon Sep 17 00:00:00 2001 From: Anton Schwaighofer Date: Thu, 17 Jun 2021 15:42:45 +0100 Subject: [PATCH 5/5] fix --- InnerEye/ML/run_ml.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 2d22e5f99..f423dd5f0 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -449,7 +449,8 @@ def run_inference_for_lightning_models(self, checkpoint_paths: List[Path]) -> No del os.environ[ENV_OMPI_COMM_WORLD_RANK] # From the training setup, torch still thinks that it should run in a distributed manner, # and would block on some GPU operations. Hence, clean up distributed training. - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() trainer, _ = create_lightning_trainer(self.container, num_nodes=1) # When training models that are not built-in InnerEye models, we have no guarantee that they write # files to the right folder. Best guess is to change the current working directory to where files should go.
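Note on [PATCH 5/5]: the guard exists because torch.distributed.destroy_process_group() raises an error when no process group was ever initialized, which is the case for single-GPU or CPU runs. A minimal sketch of the safe cleanup pattern, using a hypothetical helper name:

import torch.distributed as dist

def cleanup_distributed() -> None:
    # Hypothetical helper; the actual code in run_ml.py inlines this check.
    # Only tear down the process group if distributed training actually created one.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

cleanup_distributed()  # safe to call whether or not DDP was ever set up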
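Note on [PATCH 3/5]: the MeanAbsoluteError metric added to HelloRegression follows a reset/update/compute lifecycle. The sketch below shows that lifecycle outside of a LightningModule; the import path pytorch_lightning.metrics matches the Lightning version used in this repository, while newer versions expose the same class from the torchmetrics package.

import torch
from pytorch_lightning.metrics import MeanAbsoluteError

mae = MeanAbsoluteError()
mae.reset()  # clear accumulated state, as done in on_test_epoch_start
for _ in range(3):  # stand-in for iterating over test batches
    prediction = torch.rand(4, 1)
    target = torch.rand(4, 1)
    # Accumulates state per batch; by default this will attempt to synchronize across GPUs.
    mae.update(preds=prediction, target=target)
print(mae.compute())  # aggregated MAE over all updates, as done in on_test_epoch_end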