diff --git a/InnerEye/Azure/azure_config.py b/InnerEye/Azure/azure_config.py index 468defb0f..ae2616937 100755 --- a/InnerEye/Azure/azure_config.py +++ b/InnerEye/Azure/azure_config.py @@ -130,6 +130,8 @@ class AzureConfig(GenericConfig): _workspace: Workspace = param.ClassSelector(class_=Workspace, doc="The cached workspace object that has been created in the first" "call to get_workspace") + workers_per_node: int = param.Integer(1, doc="The number of workers to assign per machine") + node_count: int = param.Integer(1, doc="The number of machines to run distributed training across") def __init__(self, **params: Any) -> None: super().__init__(**params) diff --git a/InnerEye/Azure/azure_runner.py b/InnerEye/Azure/azure_runner.py index faf285a3d..85fea3bdc 100644 --- a/InnerEye/Azure/azure_runner.py +++ b/InnerEye/Azure/azure_runner.py @@ -12,11 +12,13 @@ from pathlib import Path from typing import Any, Dict, List, Optional -from azureml.core import Dataset, Experiment, Run +from azureml.core import Dataset, Experiment, Run, ComputeTarget from azureml.core.conda_dependencies import CondaDependencies from azureml.core.datastore import Datastore from azureml.core.workspace import WORKSPACE_DEFAULT_BLOB_STORE_NAME from azureml.data.dataset_consumption_config import DatasetConsumptionConfig +from azureml.train._distributed_training import Mpi + from azureml.train.dnn import PyTorch from InnerEye.Azure import azure_util @@ -288,15 +290,27 @@ def create_estimator_from_configs(azure_config: AzureConfig, framework_version = pytorch_version_from_conda_dependencies(conda_dependencies) assert framework_version is not None, "The AzureML SDK is behind PyTorch, it does not yet know the version we use." logging.info(f"PyTorch framework version: {framework_version}") + + if azure_config.node_count > 1: + source_config.script_params = source_config.script_params or {} + source_config.script_params.update({'--distributed_training_init_method': 'tcp://' + '$AZ_BATCH_MASTER_NODE'}) + distributed_training_backend = Mpi() + distributed_training_backend.process_count_per_node = azure_config.workers_per_node + else: + distributed_training_backend = None + max_run_duration = None if azure_config.max_run_duration: max_run_duration = run_duration_string_to_seconds(azure_config.max_run_duration) + workspace = azure_config.get_workspace() + compute_target = ComputeTarget(workspace, azure_config.cluster) + estimator = PyTorch( source_directory=str(source_config.root_folder), entry_script=entry_script_relative_path, script_params=source_config.script_params, - compute_target=azure_config.cluster, + compute_target=compute_target, # Use blob storage for storing the source, rather than the FileShares section of the storage account. source_directory_data_store=workspace.datastores.get(WORKSPACE_DEFAULT_BLOB_STORE_NAME), inputs=estimator_inputs, @@ -305,7 +319,10 @@ def create_estimator_from_configs(azure_config: AzureConfig, use_docker=True, use_gpu=True, framework_version=framework_version, - max_run_duration_seconds=max_run_duration + node_count=azure_config.node_count, + distributed_training=distributed_training_backend, + pip_packages=['azureml-dataprep[pandas,fuse]'], + max_run_duration_seconds=max_run_duration, ) estimator.run_config.environment.python.conda_dependencies = conda_dependencies # We'd like to log the estimator config, but conversion to string fails when the Estimator has some inputs. diff --git a/InnerEye/Common/common_util.py b/InnerEye/Common/common_util.py index c2dd9f4b3..94c72cf5c 100644 --- a/InnerEye/Common/common_util.py +++ b/InnerEye/Common/common_util.py @@ -61,7 +61,8 @@ def flush(self, log_info: bool = False) -> None: """ import pandas as pd if not self.csv_path.parent.is_dir(): - self.csv_path.parent.mkdir(parents=True) + # exist_ok necessary when running multiple processes + self.csv_path.parent.mkdir(parents=True, exist_ok=True) # Specifying columns such that the order in which columns appear matches the order in which # columns were added in the code. columns = self.records[0].keys() if len(self.records) > 0 else None diff --git a/InnerEye/Common/generic_parsing.py b/InnerEye/Common/generic_parsing.py index 0b72c8af1..cea2b4eb2 100644 --- a/InnerEye/Common/generic_parsing.py +++ b/InnerEye/Common/generic_parsing.py @@ -75,15 +75,19 @@ def get_cuda_devices(self) -> Optional[List[Any]]: else: return None - def get_gpu_tensor_if_possible(self, data: T) -> Any: + def get_gpu_tensor_if_possible(self, data: T, device: Optional[Any] = None) -> Any: """" Get a cuda tensor if this transform was cuda enabled and a GPU is available, otherwise return the input. + :param data: The data to send to device + :param device: Torch device to allocate to """ import torch if isinstance(data, torch.Tensor): if self.use_gpu and not is_gpu_tensor(data): - return data.cuda() + # use default CUDA device if not specified + device = device or torch.device('cuda') + return data.to(device) else: return data else: diff --git a/InnerEye/ML/config.py b/InnerEye/ML/config.py index 9432caf82..2d19a4657 100644 --- a/InnerEye/ML/config.py +++ b/InnerEye/ML/config.py @@ -670,13 +670,13 @@ def get_output_size(self, execution_mode: ModelExecutionMode = ModelExecutionMod return self._test_output_size raise ValueError("Unknown execution mode '{}' for function 'get_output_size'".format(execution_mode)) - def adjust_after_mixed_precision_and_parallel(self, model: Any) -> None: + def adjust_after_mixed_precision_and_parallel(self, model: Any, device: Any) -> None: """ Updates the model config parameters (e.g. output patch size). If testing patch stride size is unset then its value is set by the output patch size """ - self._train_output_size = model.get_output_shape(input_shape=self.crop_size) - self._test_output_size = model.get_output_shape(input_shape=self.test_crop_size) + self._train_output_size = model.get_output_shape(input_shape=self.crop_size, device=device) + self._test_output_size = model.get_output_shape(input_shape=self.test_crop_size, device=device) if self.inference_stride_size is None: self.inference_stride_size = self._test_output_size else: diff --git a/InnerEye/ML/configs/classification/DummyClassification.py b/InnerEye/ML/configs/classification/DummyClassification.py index a1ef9bd40..02568654b 100644 --- a/InnerEye/ML/configs/classification/DummyClassification.py +++ b/InnerEye/ML/configs/classification/DummyClassification.py @@ -29,7 +29,7 @@ def __init__(self) -> None: num_dataload_workers=0, test_start_epoch=num_epochs, use_mixed_precision=True, - subject_column="subjectID" + subject_column="subjectID", ) self.conv_in_3d = True self.expected_image_size_zyx = (4, 5, 7) diff --git a/InnerEye/ML/dataset/full_image_dataset.py b/InnerEye/ML/dataset/full_image_dataset.py index 949b73485..b707dcd11 100644 --- a/InnerEye/ML/dataset/full_image_dataset.py +++ b/InnerEye/ML/dataset/full_image_dataset.py @@ -11,7 +11,8 @@ import pandas as pd import torch.utils.data from torch._six import container_abcs -from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, Sampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, Sampler, SequentialSampler, \ + DistributedSampler from torch.utils.data.dataloader import default_collate # type: ignore from InnerEye.ML.config import SegmentationModelBase @@ -175,10 +176,21 @@ def as_data_loader(self, num_dataload_workers: Optional[int] = None, use_imbalanced_sampler: bool = False, drop_last_batch: bool = False, - max_repeats: Optional[int] = None) -> DataLoader: + max_repeats: Optional[int] = None, + distribute: bool = False + ) -> DataLoader: num_dataload_workers = num_dataload_workers or self.args.num_dataload_workers batch_size = batch_size or self.args.train_batch_size - if self.args.avoid_process_spawn_in_data_loaders: + if distribute: + # distributed data loader + sampler: Optional[Sampler] = DistributedSampler(self) + return DataLoader(self, + batch_size=batch_size, + shuffle=False, + num_workers=num_dataload_workers, + collate_fn=collate_with_metadata, + sampler=sampler) + elif self.args.avoid_process_spawn_in_data_loaders: if max_repeats is None: max_repeats = self.args.get_total_number_of_training_epochs() return RepeatDataLoader( @@ -193,9 +205,10 @@ def as_data_loader(self, use_imbalanced_sampler=use_imbalanced_sampler, drop_last=drop_last_batch ) + else: if use_imbalanced_sampler: - sampler: Optional[Sampler] = ImbalancedSampler(self) + sampler = ImbalancedSampler(self) shuffle = False else: sampler = None diff --git a/InnerEye/ML/deep_learning_config.py b/InnerEye/ML/deep_learning_config.py index abe776294..fb17af82f 100644 --- a/InnerEye/ML/deep_learning_config.py +++ b/InnerEye/ML/deep_learning_config.py @@ -202,7 +202,6 @@ class DeepLearningConfig(GenericConfig, CudaAwareConfig): doc="The high-level model category described by this config.") _model_name: str = param.String(None, doc="The human readable name of the model (for example, Liver). This is " "usually set from the class name.") - random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.") azure_dataset_id: str = param.String(doc="If provided, the ID of the dataset to use. This dataset must exist as a " "folder of the same name in the 'datasets' " @@ -398,6 +397,14 @@ class DeepLearningConfig(GenericConfig, CudaAwareConfig): "initialization, " "when training is running outside Azure.") + distributed_training_backend: str = param.String(default='nccl', + doc="Communication package to use for distributed training") + distributed_training_init_method: str = param.String(default='env://', + doc="URL specifying where to find peer processes") + + num_workers_per_node: int = param.Integer(1, doc="The number of workers to assign per machine") + num_nodes: int = param.Integer(1, doc="The number of machines to run distributed training across") + def __init__(self, **params: Any) -> None: self._model_name = type(self).__name__ # This should be annotated as torch.utils.data.Dataset, but we don't want to import torch here. @@ -694,11 +701,27 @@ def use_gpu(self, value: bool) -> None: @property def use_data_parallel(self) -> bool: """ - Data parallel is used if GPUs are usable and the number of CUDA devices are greater than 1. + Data parallel is used if GPUs are usable and the number of CUDA devices are greater than 1 and + DistributedDataParallel is False (i.e. is_windows is True) + :return: + """ + _devices = self.get_cuda_devices() + return _devices is not None and len(_devices) > 1 and not self.use_distributed_data_parallel + + @property + def use_distributed_data_parallel(self) -> bool: + """ + Distributed Data Parallel may used if GPUs are usable and the number of CUDA devices is greater than 1 + and the package torch.distributed is available. Additionally, the product of num_nodes and workers_per_node + as set in the config must be greater than 1. :return: """ + import torch.distributed as dist _devices = self.get_cuda_devices() - return _devices is not None and len(_devices) > 1 + if _devices is None: + return False + world_size = self.num_nodes * self.num_workers_per_node + return (len(_devices) > 1) & dist.is_available() & (world_size > 1) def write_args_file(self, root: Optional[Path] = None) -> None: """ diff --git a/InnerEye/ML/model_config_base.py b/InnerEye/ML/model_config_base.py index 7beee36d2..83739fd63 100644 --- a/InnerEye/ML/model_config_base.py +++ b/InnerEye/ML/model_config_base.py @@ -68,7 +68,7 @@ def get_model_train_test_dataset_splits(self, dataset_df: pd.DataFrame) -> Datas def create_and_set_torch_datasets(self, for_training: bool = True, for_inference: bool = True) -> None: """ - Creats and sets torch datasets for training and validation, and stores them in the self._datasets_for_training + Creates and sets torch datasets for training and validation, and stores them in the self._datasets_for_training field. Similarly, create torch datasets in the form required for model inference, for all of the 3 splits of the full data, and stored them in the self._datasets_for_training and/or self._datasets_for_inference fields. @@ -120,12 +120,14 @@ def create_data_loaders(self) -> Dict[ModelExecutionMode, Any]: .as_data_loader(shuffle=self.shuffle, use_imbalanced_sampler=self.use_imbalanced_sampler_for_training, drop_last_batch=self.drop_last_batch_in_training, - max_repeats=self.get_total_number_of_training_epochs()) + max_repeats=self.get_total_number_of_training_epochs(), + distribute=self.use_distributed_data_parallel) logging.info("Creating the data loader for the validation set.") val_loader = self._datasets_for_training[ModelExecutionMode.VAL].as_data_loader( shuffle=False, - max_repeats=self.get_total_number_of_validation_epochs() + max_repeats=self.get_total_number_of_validation_epochs(), + distribute=False # validation step is not distributed ) logging.info("Finished creating the data loaders.") return { @@ -234,12 +236,13 @@ def write_dataset_files(self, root: Optional[Path] = None) -> None: dst = root / STORED_CSV_FILE_NAMES[mode] dataframe.to_csv(dst, mode='w', index=False) - def adjust_after_mixed_precision_and_parallel(self, model: Any) -> None: + def adjust_after_mixed_precision_and_parallel(self, model: Any, device: Any) -> None: """ A hook to adjust the model configuration that is stored in the present object to match the torch model given in the argument. This hook is called after adjusting the model for mixed precision and parallel training. :param model: The torch model. + :param device: The Torch device to allocate to. """ pass diff --git a/InnerEye/ML/model_training.py b/InnerEye/ML/model_training.py index 9a4211b9d..6d7914dfa 100644 --- a/InnerEye/ML/model_training.py +++ b/InnerEye/ML/model_training.py @@ -3,10 +3,9 @@ # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. # ------------------------------------------------------------------------------------------ import logging +import os from time import time -from typing import Tuple, TypeVar - -from torch.cuda.amp import GradScaler +from typing import Any, Optional, Tuple, TypeVar, Union from InnerEye.Azure.azure_util import RUN_CONTEXT, is_offline_run_context from InnerEye.Common.common_util import logging_section @@ -23,12 +22,14 @@ from InnerEye.ML.scalar_config import ScalarModelBase from InnerEye.ML.sequence_config import SequenceModelBase from InnerEye.ML.utils import ml_util +from InnerEye.ML.utils.aml_distributed_utils import get_global_rank, get_global_size, get_local_rank, get_local_size, \ + is_aml_mpi_run from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler from InnerEye.ML.utils.lr_scheduler import SchedulerWithWarmUp from InnerEye.ML.utils.metrics_util import create_summary_writers from InnerEye.ML.utils.ml_util import RandomStateSnapshot from InnerEye.ML.utils.model_util import ModelAndInfo, generate_and_print_model_summary -from InnerEye.ML.utils.training_util import ModelOutputsAndMetricsForEpoch, ModelTrainingResults +from InnerEye.ML.utils.training_util import ModelOutputsAndMetricsForEpoch, ModelTrainingResults, determine_device from InnerEye.ML.visualizers.patch_sampling import visualize_random_crops_for_dataset MAX_ITEM_LOAD_TIME_SEC = 0.5 @@ -37,7 +38,7 @@ T = TypeVar('T') -def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) -> ModelTrainingResults: +def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) -> Optional[ModelTrainingResults]: """ The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training. @@ -46,7 +47,9 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization :raises TypeError: If the arguments are of the wrong type. :raises ValueError: When there are issues loading a previous checkpoint. + :return: ModelTrainingResult, unless model called with torch.mp.spawn (which must return None) """ + from torch.multiprocessing import spawn # Save the dataset files for later use in cross validation analysis config.write_dataset_files() @@ -58,11 +61,75 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) visualize_random_crops_for_dataset(config) ml_util.set_random_seed(config.get_effective_random_seed(), "Model training") - logging.debug("Creating the PyTorch model.") + if config.use_distributed_data_parallel: + + world_size = get_global_size(config) + print(f"Starting distributed training with {world_size} process(es)") + + if is_aml_mpi_run(config): + # AzureML MPI has been instantiated - configuration handles rank + train(None, config, checkpoint_handler) + model_training_results = None + + else: + # either offline run, or AML but not an MPI job + # set the environment variable for master node address + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + # spawn processes + spawn(train, args=(config, checkpoint_handler), nprocs=world_size) + model_training_results = None + + else: + single_process_rank = 0 + model_training_results = train(single_process_rank, config, checkpoint_handler) + + return model_training_results + + +def train(rank: Optional[int], config: ModelConfigBase, checkpoint_handler: CheckpointHandler + ) -> Optional[ModelTrainingResults]: + """ + :param rank: The global rank of the current process (for DistributedDataParallel). For single process, rank=0 + :param config: The arguments which specify all required information. + :param checkpoint_handler: Checkpoint handler object to find checkpoint paths for model initialization + :return: + """ + from torch.cuda.amp import GradScaler + + global_rank = get_global_rank() if rank is None else rank + + local_rank = get_local_rank() if is_aml_mpi_run(config) else global_rank # For 1 machine, global_rank = local_rank + device = determine_device(local_rank) + + if config.use_distributed_data_parallel: + + from torch.distributed import init_process_group, destroy_process_group + + world_size = get_global_size(config) + print(f"Running distributed training on device with global rank {global_rank} and local rank {local_rank}") + init_process_group( + backend=config.distributed_training_backend, + init_method=config.distributed_training_init_method, + world_size=world_size, + rank=global_rank) + + n_gpus_per_node = get_local_size(config) + config.train_batch_size = int(config.train_batch_size // n_gpus_per_node) + config.num_dataload_workers = int((config.num_dataload_workers + n_gpus_per_node - 1) / n_gpus_per_node) + + print(f'Updated batch size for mutiple GPUs: train_batch_size={config.train_batch_size},' + f' num_dataload_workers={config.num_dataload_workers}') # Create the train loader and validation loader to load images from the dataset data_loaders = config.create_data_loaders() + if config.use_distributed_data_parallel: + train_dataset = data_loaders[ModelExecutionMode.TRAIN].dataset + len_dataset = len(train_dataset) + assert 2 * len_dataset >= world_size, f"2* len(dataset) (={2 * len_dataset}) must be >= num GPUs (=" \ + f"{world_size})" + # Get the path to the checkpoint to recover from checkpoint_path = checkpoint_handler.get_recovery_path_train() @@ -78,10 +145,10 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) .format(config.start_epoch)) # Print out a detailed breakdown of layers, memory consumption and time. - generate_and_print_model_summary(config, models_and_optimizer.model) + generate_and_print_model_summary(config, models_and_optimizer.model, device) # Move model to GPU and adjust for multiple GPUs - models_and_optimizer.adjust_model_for_gpus() + models_and_optimizer.adjust_model_for_gpus(rank=local_rank) # Create the mean teacher model and move to GPU if config.compute_mean_teacher_model: @@ -101,7 +168,8 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) logging.info(f"Models are saved at {config.checkpoint_folder}") # Create the SummaryWriters for Tensorboard - writers = create_summary_writers(config) + writers = create_summary_writers(config, global_rank=global_rank) + config.create_dataframe_loggers() # Create LR scheduler @@ -126,6 +194,10 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) logging.info("Starting epoch {}".format(epoch)) save_epoch = config.should_save_epoch(epoch) and models_and_optimizer.optimizer is not None + if config.use_distributed_data_parallel: + # set epoch for DistributedSampler to make shuffling work properly + data_loaders[ModelExecutionMode.TRAIN].sampler.set_epoch(epoch) + # store the learning rates used for each epoch epoch_lrs = l_rate_scheduler.get_last_lr() learning_rates_per_epoch.append(epoch_lrs) @@ -141,8 +213,9 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) summary_writers=writers, dataframe_loggers=config.metrics_data_frame_loggers, in_training_mode=True) + training_steps = create_model_training_steps(config, train_val_params) - train_epoch_results = train_or_validate_epoch(training_steps) + train_epoch_results = train_or_validate_epoch(training_steps, local_rank, device) train_results_per_epoch.append(train_epoch_results.metrics) metrics.validate_and_store_model_parameters(writers.train, epoch, models_and_optimizer.model) @@ -155,19 +228,19 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) train_val_params.save_metrics = not (save_epoch and config.temperature_scaling_config) training_steps = create_model_training_steps(config, train_val_params) - val_epoch_results = train_or_validate_epoch(training_steps) + + val_epoch_results = train_or_validate_epoch(training_steps, local_rank, device) val_results_per_epoch.append(val_epoch_results.metrics) if config.is_segmentation_model: metrics.store_epoch_stats_for_segmentation(config.outputs_folder, epoch, epoch_lrs, train_epoch_results.metrics, val_epoch_results.metrics) - - if save_epoch: + if save_epoch and global_rank == 0: # perform temperature scaling if required if isinstance(config, SequenceModelBase) and config.temperature_scaling_config: optimal_temperature, scaled_val_results = \ - temperature_scaling_steps(config, train_val_params, val_epoch_results) + temperature_scaling_steps(config, train_val_params, val_epoch_results, local_rank, device) optimal_temperature_scale_values.append(optimal_temperature) # overwrite the metrics for the epoch with the metrics from the temperature scaled model val_results_per_epoch[-1] = scaled_val_results.metrics @@ -207,12 +280,17 @@ def model_train(config: ModelConfigBase, checkpoint_handler: CheckpointHandler) RUN_CONTEXT.log(name, value) resource_monitor.kill() - return model_training_results + if config.use_distributed_data_parallel: + destroy_process_group() + + # return model_training_results + return None if (config.use_distributed_data_parallel and is_aml_mpi_run(config)) else model_training_results def temperature_scaling_steps(config: SequenceModelBase, train_val_params: TrainValidateParameters, - val_results_for_epoch: ModelOutputsAndMetricsForEpoch) -> \ + val_results_for_epoch: ModelOutputsAndMetricsForEpoch, + rank: int, device: Any) -> \ Tuple[float, ModelOutputsAndMetricsForEpoch]: """ Perform the steps required for temperature scaling: @@ -222,6 +300,8 @@ def temperature_scaling_steps(config: SequenceModelBase, :param config: Config for a sequence model. :param train_val_params: Train/Validate parameters to use. :param val_results_for_epoch: results from the validation epoch to use in order to perform temperature scaling. + :param rank: Rank of the current process + :param device: The Torch device to allocate to :return: the optimal temperature value and the validation results after scaling has been performed. """ # re-create the training steps for the repeat pass, but with metrics saving enabled @@ -235,21 +315,59 @@ def temperature_scaling_steps(config: SequenceModelBase, labels = val_results_for_epoch.get_labels() temperature_value = training_steps.learn_temperature_scale_parameter(logits, labels) # recompute the validation set results for the temperature scaled model - val_epoch_results = train_or_validate_epoch(training_steps) + val_epoch_results = train_or_validate_epoch(training_steps, rank, device) return temperature_value, val_epoch_results -def train_or_validate_epoch(training_steps: ModelTrainingStepsBase) -> ModelOutputsAndMetricsForEpoch: +def train_or_validate_epoch(training_steps: ModelTrainingStepsBase, rank: int, + device: Any) -> ModelOutputsAndMetricsForEpoch: """ Trains or validates the model for one epoch. :param training_steps: Training pipeline to use. + :param rank: The rank of the current process + :param device: The Torch device to allocate to :returns: The results for training or validation. Result type depends on the type of model that is trained. """ - epoch_start_time = time() + import torch + FloatOrCudaEvent = Union[float, torch.cuda.streams.Event] + + def record_time() -> FloatOrCudaEvent: + """ + Record current time. For CUDA devices, where operations are asynchronous, return a CUDA event + :return: + """ + if torch.cuda.is_available(): + recorded_time: FloatOrCudaEvent = torch.cuda.Event(enable_timing=True) + assert isinstance(recorded_time, torch.cuda.streams.Event) # for mypy + recorded_time.record() + else: + recorded_time = time() + return recorded_time + + def calculate_time_difference(start_time: FloatOrCudaEvent, end_time: FloatOrCudaEvent) -> float: + """ + Calculate the difference between two timestamps. For CUDA devices, where operations are asynchronous + we call synchronize to ensure all events complete before returning the time + :param start_time: + :param end_time: + :return: + """ + if isinstance(start_time, torch.cuda.streams.Event) & isinstance(end_time, torch.cuda.streams.Event): + torch.cuda.synchronize() + return start_time.elapsed_time(end_time) # type: ignore + elif isinstance(start_time, float) & isinstance(end_time, float): + return end_time - start_time # type: ignore + else: + raise ValueError(f"Incompatible start and end times: {type(start_time)} & {type(end_time)}") + training_random_state = None train_val_params = training_steps.train_val_params config = training_steps.model_config + + item_start_time = record_time() + epoch_start_time = record_time() + if not train_val_params.in_training_mode: # take the snapshot of the existing random state training_random_state = RandomStateSnapshot.snapshot_random_state() @@ -257,7 +375,7 @@ def train_or_validate_epoch(training_steps: ModelTrainingStepsBase) -> ModelOutp ml_util.set_random_seed(config.get_effective_random_seed(), "Model validation") status_string = "training" if train_val_params.in_training_mode else "validation" - item_start_time = time() + num_load_time_warnings = 0 num_load_time_exceeded = 0 num_batches = 0 @@ -265,12 +383,15 @@ def train_or_validate_epoch(training_steps: ModelTrainingStepsBase) -> ModelOutp total_load_time = 0.0 model_outputs_epoch = [] for batch_index, sample in enumerate(train_val_params.data_loader): - item_finish_time = time() - item_load_time = item_finish_time - item_start_time + + item_finish_time = record_time() + item_load_time = calculate_time_difference(item_start_time, item_finish_time) + # Having slow minibatch loading is OK in the very first batch of the every epoch, where processes # are spawned. Later, the load time should be zero. if batch_index == 0: logging.info(f"Loaded the first minibatch of {status_string} data in {item_load_time:0.2f} sec.") + elif item_load_time > MAX_ITEM_LOAD_TIME_SEC: num_load_time_exceeded += 1 total_extra_load_time += item_load_time @@ -282,25 +403,33 @@ def train_or_validate_epoch(training_steps: ModelTrainingStepsBase) -> ModelOutp f"{MAX_LOAD_TIME_WARNINGS} times.") num_load_time_warnings += 1 model_outputs_minibatch = training_steps.forward_and_backward_minibatch( - sample, batch_index, train_val_params.epoch) + sample, batch_index, train_val_params.epoch, rank=rank, device=device) model_outputs_epoch.append(model_outputs_minibatch) - train_finish_time = time() + + train_finish_time = record_time() + status_time = calculate_time_difference(item_finish_time, train_finish_time) + logging.debug(f"Epoch {train_val_params.epoch} {status_string} batch {batch_index}: " f"Loaded in {item_load_time:0.2f}sec, " - f"{status_string} in {(train_finish_time - item_finish_time):0.2f}sec. " + f"{status_string} in {status_time:0.2f}sec. " f"Loss = {model_outputs_minibatch.loss}") - total_load_time += item_finish_time - item_start_time + + total_load_time = calculate_time_difference(item_start_time, item_finish_time) + item_start_time = record_time() + num_batches += 1 - item_start_time = time() # restore the training random state when validation has finished if training_random_state is not None: training_random_state.restore_random_state() - epoch_time_seconds = time() - epoch_start_time + epoch_end_time = record_time() + epoch_time_seconds = calculate_time_difference(epoch_start_time, epoch_end_time) + logging.info(f"Epoch {train_val_params.epoch} {status_string} took {epoch_time_seconds:0.2f} sec, " f"of which waiting for next minibatch took {total_load_time:0.2f} sec total. {num_batches} " "minibatches in total.") + if num_load_time_exceeded > 0: logging.warning("The dataloaders were not fast enough to always supply the next batch in less than " f"{MAX_ITEM_LOAD_TIME_SEC}sec.") diff --git a/InnerEye/ML/model_training_steps.py b/InnerEye/ML/model_training_steps.py index 890e6eee6..8640f0a5d 100644 --- a/InnerEye/ML/model_training_steps.py +++ b/InnerEye/ML/model_training_steps.py @@ -14,7 +14,6 @@ import torch.utils.data from torch import Tensor from torch.cuda import amp -from torch.cuda.amp import GradScaler from torch.nn import MSELoss from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -75,7 +74,7 @@ class TrainValidateParameters(param.Parameterized, Generic[M]): in_training_mode: bool = param.Boolean(default=True) dataframe_loggers: MetricsDataframeLoggers = param.ClassSelector(class_=MetricsDataframeLoggers, instantiate=False) save_metrics: bool = param.Boolean(default=True) - gradient_scaler = param.ClassSelector(class_=GradScaler, instantiate=False) + gradient_scaler = param.ClassSelector(class_=amp.GradScaler, instantiate=False) class ModelTrainingStepsBase(Generic[C, M], ABC): @@ -114,12 +113,15 @@ def in_training_mode(self) -> bool: @abstractmethod def forward_and_backward_minibatch(self, sample: Dict[str, Any], - batch_index: int, epoch: int) -> ModelForwardAndBackwardsOutputs: + batch_index: int, epoch: int, rank: int, device: torch.device + ) -> ModelForwardAndBackwardsOutputs: """ Runs training for a single minibatch of training data, and returns the loss. :param sample: The batched sample on which the model should be trained. :param batch_index: The index of the present batch (supplied only for diagnostics). :param epoch: The number of the present epoch. + :param rank: The global rank of the current process. + :param device: The Torch device to allocate to. """ raise NotImplementedError("forward_minibatch must be implemented by derived class.") @@ -146,6 +148,7 @@ def create_criterion(self) -> torch.nn.Module: if use_data_parallel is enabled or the loss function module otherwise. """ loss_function = self.create_loss_function() + if self.model_config.use_data_parallel: return DataParallelCriterion(module=loss_function, device_ids=self.model_config.get_cuda_devices(), # type:ignore @@ -153,15 +156,16 @@ def create_criterion(self) -> torch.nn.Module: else: return loss_function - def compute_loss(self, model_output: torch.Tensor, labels: NumpyOrTorch) -> torch.Tensor: + def compute_loss(self, model_output: torch.Tensor, labels: NumpyOrTorch, device: torch.device) -> torch.Tensor: """ Provided model outputs (logits) applies the criterion function and returns the loss tensor. If data parallel is used, then the independent loss values are aggregated by averaging. :param model_output: Model output logits (unnormalised) :param labels: A tensor or numpy array of labels. + :param device: The Torch device to allocate to. If None, sets to default GPU if available, else CPU. """ # ensure that the labels are loaded into the GPU - labels = self.model_config.get_gpu_tensor_if_possible(labels) + labels = self.model_config.get_gpu_tensor_if_possible(labels, device=device) loss = self.forward_criterion_with_autocast(model_output, labels) if self.model_config.use_data_parallel: # Aggregate the loss values for each parallelized batch element. @@ -219,6 +223,7 @@ def get_scalar_model_inputs_and_labels(model_config: ScalarModelBase, :param model_config: The configuration object for the model. :param model: The instantiated PyTorch model. :param sample: A training sample, as returned by a PyTorch data loader (dictionary mapping from field name to value) + :param device: The Torch device to allocate to. If None, sets to default GPU if available, else CPU. :return: An instance of ScalarModelInputsAndLabels, containing the list of model input tensors, label tensor, subject IDs, and the data item reconstructed from the data loader output """ @@ -300,16 +305,17 @@ def create_loss_function(self) -> torch.nn.Module: else: raise NotImplementedError("Loss type {} is not implemented".format(self.model_config.loss_type)) - def get_label_tensor(self, labels: torch.Tensor) -> torch.Tensor: + def get_label_tensor(self, labels: torch.Tensor, device: torch.device) -> torch.Tensor: """ Converts the given tensor to the right data format, depending on the chosen loss function. :param labels: The label tensor that should be converted. + :param device: the Torch device to allocate to """ try: - labels = labels.to(dtype=self.label_tensor_dtype) + labels = labels.to(device, dtype=self.label_tensor_dtype) except ValueError as ex: raise ValueError(f"Unable to convert tensor {labels} to data type {self.label_tensor_dtype}: {str(ex)}") - return self.model_config.get_gpu_tensor_if_possible(labels) + return self.model_config.get_gpu_tensor_if_possible(labels, device=device) def get_logits_and_posteriors(self, *model_inputs: torch.Tensor, use_mean_teacher_model: bool = False) \ -> Tuple[torch.Tensor, torch.Tensor]: @@ -329,7 +335,8 @@ def get_logits_and_posteriors(self, *model_inputs: torch.Tensor, use_mean_teache posteriors = self.model_config.get_post_loss_logits_normalization_function()(gather_tensor(logits)) return logits, posteriors - def _compute_model_output_and_loss(self, model_inputs_and_labels: ScalarModelInputsAndLabels) -> \ + def _compute_model_output_and_loss(self, model_inputs_and_labels: ScalarModelInputsAndLabels, rank: int, + device: torch.device) -> \ Tuple[Tensor, Tensor, Tensor]: """ Computes the output of the model for a given set of inputs and labels. @@ -337,7 +344,7 @@ def _compute_model_output_and_loss(self, model_inputs_and_labels: ScalarModelInp as a list. """ model = self.train_val_params.model - label_gpu = self.get_label_tensor(model_inputs_and_labels.labels) + label_gpu = self.get_label_tensor(model_inputs_and_labels.labels, device=device) if self.model_config.use_mixed_precision and self.model_config.use_gpu: label_gpu = label_gpu.to(dtype=torch.float16) @@ -346,30 +353,39 @@ def compute() -> Tuple[Tensor, Tensor, Tensor]: model.train() logits, posteriors = self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs) else: - model.eval() - with torch.no_grad(): - logits, posteriors = self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs) - model.train() - loss = self.compute_loss(logits, label_gpu) + if rank == 0: + model.eval() + with torch.no_grad(): + logits, posteriors = self.get_logits_and_posteriors(*model_inputs_and_labels.model_inputs) + model.train() + loss = self.compute_loss(logits, label_gpu, device) return logits, posteriors, loss return execute_within_autocast_if_needed(func=compute, use_autocast=self.model_config.use_mixed_precision) def forward_and_backward_minibatch(self, sample: Dict[str, Any], - batch_index: int, epoch: int) -> ModelForwardAndBackwardsOutputs: + batch_index: int, epoch: int, rank: int, + device: torch.device + ) -> ModelForwardAndBackwardsOutputs: """ Runs training for a single minibatch of training data, and computes all metrics. :param sample: The batched sample on which the model should be trained. :param batch_index: The index of the present batch (supplied only for diagnostics). :param epoch: The number of the present epoch. + :param rank: The global rank of the current process. + :param device: The Torch device to allocate to """ start_time = time.time() model = self.train_val_params.model mean_teacher_model = self.train_val_params.mean_teacher_model + model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model_config, model, sample) - label_gpu = self.get_label_tensor(model_inputs_and_labels.labels) - logits, posteriors, loss = self._compute_model_output_and_loss(model_inputs_and_labels) - gathered_logits = gather_tensor(logits) + label_gpu = self.get_label_tensor(model_inputs_and_labels.labels, device) + logits, posteriors, loss = self._compute_model_output_and_loss(model_inputs_and_labels, rank, device) + if self.model_config.use_distributed_data_parallel: + gathered_logits = logits + else: + gathered_logits = gather_tensor(logits) if self.in_training_mode: single_optimizer_step(loss, self.train_val_params.optimizer, @@ -385,7 +401,10 @@ def forward_and_backward_minibatch(self, sample: Dict[str, Any], logits, posteriors = self.get_logits_and_posteriors( *model_inputs_and_labels.model_inputs, use_mean_teacher_model=True) - gathered_logits = gather_tensor(logits) + if self.model_config.use_distributed_data_parallel: + gathered_logits = logits + else: + gathered_logits = gather_tensor(logits) # Autocast may have returned float16 tensors. Documentation suggests to simply cast back to float32. # If tensor was already float32, no overhead is incurred. @@ -403,7 +422,8 @@ def forward_and_backward_minibatch(self, sample: Dict[str, Any], self.metrics.add_metric(MetricType.LOSS, loss_scalar) self.update_metrics(model_inputs_and_labels.subject_ids, posteriors, label_gpu) logging.debug(f"Batch {batch_index}: {self.metrics.to_string()}") - minibatch_time = time.time() - start_time + minibatch_end_time = time.time() + minibatch_time = minibatch_end_time - start_time self.metrics.add_metric(MetricType.SECONDS_PER_BATCH, minibatch_time) return ModelForwardAndBackwardsOutputs( @@ -638,20 +658,26 @@ def construct_non_mixture_loss_function(cls, raise NotImplementedError("Loss type {} is not implemented".format(loss_type)) def forward_and_backward_minibatch(self, sample: Dict[str, Any], - batch_index: int, epoch: int) -> ModelForwardAndBackwardsOutputs: + batch_index: int, epoch: int, rank: int, device: torch.device + ) -> ModelForwardAndBackwardsOutputs: """ Runs training for a single minibatch of training data, and computes all metrics. :param sample: The batched sample on which the model should be trained. :param batch_index: The index of the present batch (supplied only for diagnostics). :param epoch: The number of the present epoch. + :param rank: The global rank of the current process. + :param device: The Torch device to allocate to """ + cropped_sample: CroppedSample = CroppedSample.from_dict(sample=sample) - labels = self.model_config.get_gpu_tensor_if_possible(cropped_sample.labels_center_crop) + labels = self.model_config.get_gpu_tensor_if_possible(cropped_sample.labels_center_crop, device=device) mask = None if self.train_val_params.in_training_mode else cropped_sample.mask_center_crop forward_pass_result = self.pipeline.forward_pass_patches(patches=cropped_sample.image, labels=labels, - mask=mask) + mask=mask, + rank=rank, + device=device) # Clear the GPU cache between forward and backward passes to avoid possible out-of-memory torch.cuda.empty_cache() dice_for_all_classes = metrics.compute_dice_across_patches( diff --git a/InnerEye/ML/models/architectures/base_model.py b/InnerEye/ML/models/architectures/base_model.py index 9b1a60b6d..75299ca23 100644 --- a/InnerEye/ML/models/architectures/base_model.py +++ b/InnerEye/ML/models/architectures/base_model.py @@ -13,6 +13,7 @@ from InnerEye.Common.common_util import any_pairwise_larger, initialize_instance_variables from InnerEye.Common.type_annotations import IntOrTuple3, TupleInt2, TupleInt3 from InnerEye.ML.utils.device_aware_module import DeviceAwareModule +from InnerEye.ML.utils.training_util import determine_device from InnerEye.ML.visualizers.model_summary import ModelSummary, forward_preserve_state @@ -138,14 +139,18 @@ def __init__(self, crop_size_constraints = CropSizeConstraints(multiple_of=1) self.crop_size_constraints = crop_size_constraints - def get_output_shape(self, input_shape: Union[TupleInt2, TupleInt3]) -> Tuple[int, ...]: + def get_output_shape(self, input_shape: Union[TupleInt2, TupleInt3], device: Optional[torch.device] = None) \ + -> Tuple[int, ...]: """ Computes model's output tensor shape for given input tensor shape. The argument is expected to be either a 2-tuple or a 3-tuple. A batch dimension (1) and the number of channels are added as the first dimensions. The result tuple has batch and channel dimension stripped off. :param input_shape: A tuple (2D or 3D) representing incoming tensor shape. + :param device: The Torch device to allocate to. """ + device = device or determine_device() + # Create a sample tensor for inference batch_size = 1 if len(input_shape) not in [2, 3]: @@ -154,7 +159,7 @@ def get_output_shape(self, input_shape: Union[TupleInt2, TupleInt3]) -> Tuple[in [torch.zeros(batch_size, self.input_channels, *input_shape, dtype=torch.float)] # Perform a forward pass then restore the state of the module - output_shape = forward_preserve_state(module=self, inputs=input_tensors).size() + output_shape = forward_preserve_state(module=self, inputs=input_tensors, device=device).size() return tuple(output_shape[2:]) def partition_model(self, devices: List[torch.device]) -> None: diff --git a/InnerEye/ML/models/parallel/distributed_data_parallel.py b/InnerEye/ML/models/parallel/distributed_data_parallel.py new file mode 100644 index 000000000..ef9f29860 --- /dev/null +++ b/InnerEye/ML/models/parallel/distributed_data_parallel.py @@ -0,0 +1,31 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +from typing import List + +from torch import device +from torch.nn.parallel import DistributedDataParallel + +from InnerEye.Common.type_annotations import T +from InnerEye.ML.utils.device_aware_module import DeviceAwareModule +from InnerEye.ML.utils.device_aware_module import E + + +class DistributedDataParallelModel(DistributedDataParallel, DeviceAwareModule): + def get_input_tensors(self, item: T) -> List[E]: + _module: DeviceAwareModule = self.get_module() + return _module.get_input_tensors(item) + + def get_module(self) -> DeviceAwareModule: + module = self.module + if not isinstance(module, DeviceAwareModule): + raise ValueError(f"Expecting DeviceAwareModule. Instead found {module}") + return module + + def get_devices(self) -> List[device]: + """ + Gets the numeric indices of the CUDA devices that the present object is using. + :return: + """ + return [device(x) if isinstance(x, int) else x for x in self.device_ids] diff --git a/InnerEye/ML/pipelines/forward_pass.py b/InnerEye/ML/pipelines/forward_pass.py index 722443838..52b23ed2c 100644 --- a/InnerEye/ML/pipelines/forward_pass.py +++ b/InnerEye/ML/pipelines/forward_pass.py @@ -62,7 +62,9 @@ def __init__(self, def forward_pass_patches(self, patches: torch.Tensor, labels: Optional[torch.Tensor] = None, - mask: Optional[torch.Tensor] = None) -> \ + mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + rank: Optional[int] = None) -> \ SegmentationForwardPass.Result: """ Wrapper function to handle model forward pass, including updating of the optimizer_type with loss gradients @@ -70,7 +72,10 @@ def forward_pass_patches(self, patches: torch.Tensor, :param patches: Images patches to be passed to the model in format Batches x Channels x Z x Y x X. :param labels: Labels for image patches to be used for loss computation: Batches x Classes x Z x Y x X :param mask: optional mask patches channel in shape Batches x Z x Y x X to be applied to the predictions. + :param rank: The global rank of the current process. + :param device: The Torch device to allocate to. """ + # check that the patches are as expected w.r.t to the configuration if patches is None: raise Exception("Patches for forward pass cannot be None.") @@ -100,28 +105,32 @@ def forward_pass_patches(self, patches: torch.Tensor, # handle model modes if self.in_training_mode: self.model.train() - result = self._forward_pass_with_anomaly_detection(patches=patches, mask=mask, labels=labels) + result = self._forward_pass_with_anomaly_detection(patches=patches, mask=mask, + labels=labels, device=device) else: self.model.eval() # turn off autograd for memory optimizations with torch.no_grad(): - result = self._forward_pass_with_anomaly_detection(patches=patches, mask=mask, labels=labels) + result = self._forward_pass_with_anomaly_detection(patches=patches, mask=mask, + labels=labels, device=device) self.model.train() return result def _forward_pass_with_anomaly_detection(self, patches: torch.Tensor, mask: torch.Tensor = None, - labels: torch.Tensor = None) -> SegmentationForwardPass.Result: + labels: torch.Tensor = None, + device: torch.device = None) -> SegmentationForwardPass.Result: if self.detect_anomaly: with autograd.detect_anomaly(): - result = self._forward_pass(patches, mask, labels) + result = self._forward_pass(patches, mask, labels, device=device) if result.loss is not None and (math.isnan(result.loss) or math.isinf(result.loss)): raise RuntimeError(f"The loss computation returned {result.loss}") return result - return self._forward_pass(patches, mask, labels) + return self._forward_pass(patches, mask, labels, device=device) - def _compute_loss(self, patches: Tensor, labels: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]: + def _compute_loss(self, patches: Tensor, labels: Optional[Tensor], device: Optional[torch.device] = None + ) -> Tuple[Tensor, Optional[Tensor]]: """ Do a forward pass on the model with the patches as input. If labels are provided, compute the loss. Return a tuple of (logits, loss). @@ -133,7 +142,7 @@ def compute() -> Tuple[Any, Optional[Tensor]]: # If labels *is* None, loss will also be None, which will stop the code below working (and # currently correctly triggers mypy errors). if labels is not None and self.criterion_fn is not None: - loss = self.criterion_fn(logits, labels) + loss = self.criterion_fn(logits, labels, device=device) return logits, loss return execute_within_autocast_if_needed(func=compute, use_autocast=True if self.gradient_scaler else False) @@ -141,15 +150,16 @@ def compute() -> Tuple[Any, Optional[Tensor]]: def _forward_pass(self, patches: torch.Tensor, mask: torch.Tensor = None, - labels: torch.Tensor = None) -> SegmentationForwardPass.Result: + labels: torch.Tensor = None, + device: Optional[torch.device] = None) -> SegmentationForwardPass.Result: # ensure that we always have float tensors as the model is defined over floats # and transfer the tensors to the GPU if possible before the forward pass - patches = self.config.get_gpu_tensor_if_possible(patches) + patches = self.config.get_gpu_tensor_if_possible(patches, device=device) if mask is not None: - mask = self.config.get_gpu_tensor_if_possible(mask) + mask = self.config.get_gpu_tensor_if_possible(mask, device=device) - logits, loss = self._compute_loss(patches, labels) + logits, loss = self._compute_loss(patches, labels, device=device) if self.in_training_mode: if loss is None: @@ -158,7 +168,7 @@ def _forward_pass(self, single_optimizer_step(loss, self.optimizer, self.gradient_scaler) # Aggregate data parallel logits if multiple hardware are used in forward pass - if isinstance(logits, list): + if isinstance(logits, list) and not self.config.use_distributed_data_parallel: # When using multiple GPUs, logits is a list of tensors. Gather will concatenate them # across the first dimension, and move them to GPU0. logits = torch.nn.parallel.gather(logits, target_device=0) diff --git a/InnerEye/ML/pipelines/inference.py b/InnerEye/ML/pipelines/inference.py index 97c52fd3f..aa5fa27de 100644 --- a/InnerEye/ML/pipelines/inference.py +++ b/InnerEye/ML/pipelines/inference.py @@ -25,6 +25,7 @@ from InnerEye.ML.utils.image_util import compute_uncertainty_map_from_posteriors, gaussian_smooth_posteriors, \ posteriors_to_segmentation from InnerEye.ML.utils.device_aware_module import DeviceAwareModule +from InnerEye.ML.utils.training_util import determine_device class InferencePipelineBase: @@ -398,6 +399,9 @@ def predict(self) -> InferenceBatch: """ model_config = self.get_configs() + rank = 0 # Assume we only perform inference on rank 0 device + device = determine_device(rank) + # extract patches for each image channel: Num patches x Channels x Z x Y x X patches = self._extract_patches_for_image_channels() @@ -409,7 +413,7 @@ def predict(self) -> InferenceBatch: # slice over the batches to prepare batch batch = patches[batch_idx: batch_idx + batch_size, ...] # perform the forward pass - batch_predictions = self._model_fn(batch) + batch_predictions = self._model_fn(batch, rank, device) image_util.check_array_range(batch_predictions, expected_range=InferencePipeline.MODEL_OUTPUT_POSTERIOR_RANGE, # type: ignore error_prefix="Model predictions for current batch") @@ -506,7 +510,7 @@ def _extract_patches_for_image_channels(self) -> np.ndarray: return np.stack(patches, axis=1) - def _model_fn(self, patches: np.ndarray) -> np.ndarray: + def _model_fn(self, patches: np.ndarray, rank: int, device: torch.device) -> np.ndarray: """ Wrapper function to handle the model forward pass :param patches: Image patches to be passed to the model in format Patches x Channels x Z x Y x X @@ -518,6 +522,8 @@ def _model_fn(self, patches: np.ndarray) -> np.ndarray: # get the model from the pipeline environment model = self.pipeline.get_variable(InferencePipeline.Variables.Model) + model.to(device) + # convert patches to Torch tensor patches = torch.from_numpy(patches).float() @@ -527,4 +533,4 @@ def _model_fn(self, patches: np.ndarray) -> np.ndarray: batch_size=model_config.inference_batch_size, optimizer=None, in_training_mode=False - ).forward_pass_patches(patches=patches).posteriors + ).forward_pass_patches(patches=patches, rank=rank, device=device).posteriors diff --git a/InnerEye/ML/pipelines/scalar_inference.py b/InnerEye/ML/pipelines/scalar_inference.py index 59e4b088d..7fc89c1be 100644 --- a/InnerEye/ML/pipelines/scalar_inference.py +++ b/InnerEye/ML/pipelines/scalar_inference.py @@ -17,6 +17,7 @@ from InnerEye.ML.utils import model_util from InnerEye.ML.utils.device_aware_module import DeviceAwareModule from InnerEye.ML.common import ModelExecutionMode +from InnerEye.ML.utils.training_util import determine_device class ScalarInferencePipelineBase(InferencePipelineBase): @@ -116,11 +117,14 @@ def predict(self, sample: Dict[str, Any]) -> ScalarInferencePipelineBase.Result: :return: Returns ScalarInferencePipelineBase.Result with the subject ids, ground truth labels and predictions. """ assert isinstance(self.model_config, ScalarModelBase) + + device = determine_device() model_inputs_and_labels = get_scalar_model_inputs_and_labels(self.model_config, self.model, sample) subject_ids = model_inputs_and_labels.subject_ids - labels = self.model_config.get_gpu_tensor_if_possible(model_inputs_and_labels.labels) + + labels = self.model_config.get_gpu_tensor_if_possible(model_inputs_and_labels.labels, device=device) model_output: torch.Tensor = self.model.forward(*model_inputs_and_labels.model_inputs) - if isinstance(model_output, list): + if isinstance(model_output, list) and not self.model_config.use_distributed_data_parallel: # Model output is a list if we are using data parallel. Here, this will be a degenerate list with # only 1 element model_output = torch.nn.parallel.gather(model_output, target_device=0) diff --git a/InnerEye/ML/run_ml.py b/InnerEye/ML/run_ml.py index 92ddf4460..2ac7100c7 100644 --- a/InnerEye/ML/run_ml.py +++ b/InnerEye/ML/run_ml.py @@ -37,6 +37,7 @@ from InnerEye.ML.runner import ModelDeploymentHookSignature, Runner, get_all_environment_files from InnerEye.ML.scalar_config import ScalarModelBase from InnerEye.ML.utils import ml_util +from InnerEye.ML.utils.aml_distributed_utils import get_global_rank, is_aml_mpi_run from InnerEye.ML.utils.blobxfer_util import download_blobs from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler from InnerEye.ML.utils.ml_util import make_pytorch_reproducible @@ -312,6 +313,12 @@ def run(self) -> None: # log the number of epochs used for model training RUN_CONTEXT.log(name="Train epochs", value=self.model_config.num_epochs) + # When training with DDP on AML, multiple processes will be running here. We only want to run inference + # once. We don't have this problem with offline DDP since training is spawned within model_training + if is_aml_mpi_run(self.model_config): + if get_global_rank() > 0: + return + # We specify the ModelProcessing as DEFAULT here even if the run_recovery points to an ensemble run, because # the current run is a single one. See the documentation of ModelProcessing for more details. best_epoch = self.run_inference_and_register_model(checkpoint_handler, ModelProcessing.DEFAULT) @@ -703,8 +710,10 @@ def run_model_test(data_split: ModelExecutionMode) -> Optional[InferenceMetrics] model_proc=model_proc) if config.perform_validation_and_test_set_inference: + torch.cuda.empty_cache() # perform inference on test set test_metrics = run_model_test(ModelExecutionMode.TEST) + torch.cuda.empty_cache() # perform inference on validation set val_metrics = run_model_test(ModelExecutionMode.VAL) diff --git a/InnerEye/ML/runner.py b/InnerEye/ML/runner.py index 6e806d44f..a1b0c7880 100755 --- a/InnerEye/ML/runner.py +++ b/InnerEye/ML/runner.py @@ -279,6 +279,8 @@ def parse_and_load_model(self) -> ParserResult: model_config = model_config_loader.create_model_config_from_name( model_name=azure_config.model ) + model_config.num_nodes = azure_config.node_count + model_config.num_workers_per_node = azure_config.workers_per_node # This model will be either a classification model or a segmentation model. Those have different # fields that could be overridden on the command line. Create a parser that understands the fields we need # for the actual model type. We feed this parser will the YAML settings and commandline arguments that the @@ -346,6 +348,7 @@ def submit_to_azureml(self) -> Run: upload_timeout_seconds=86400, ) source_config.set_script_params_except_submit_flag() + assert self.model_config.azure_dataset_id is not None # to stop mypy complaining about next line azure_run = submit_to_azureml(self.azure_config, source_config, model_config_overrides, self.model_config.azure_dataset_id) diff --git a/InnerEye/ML/utils/aml_distributed_utils.py b/InnerEye/ML/utils/aml_distributed_utils.py new file mode 100644 index 000000000..cf6ed3815 --- /dev/null +++ b/InnerEye/ML/utils/aml_distributed_utils.py @@ -0,0 +1,72 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +import os +from typing import Optional + +import torch + +from InnerEye.ML.model_config_base import ModelConfigBase + + +def get_local_rank() -> int: + """Returns the local rank of the current process for AML (online) runs.""" + rank = os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') + assert isinstance(rank, str), "Expected env var 'OMPI_COMM_WORLD_LOCAL_RANK' - perhaps this isn't an MPI run?" + return int(rank) + + +def get_global_rank() -> int: + """Returns the global rank of the current process for AML (online) runs.""" + rank = os.environ.get("OMPI_COMM_WORLD_RANK") + assert isinstance(rank, str), "Expected env var 'OMPI_COMM_WORLD_RANK' - perhaps this isn't an MPI run?" + return int(rank) + + +def get_global_size(config: ModelConfigBase) -> int: + """ + If running in AML, global size is the total number of devices across all machines. Otherwise, + assumes 1 machine only, and will set global size as all devices on current machine. In both cases, + global size is the maximum possible number of devices, but we may use fewer, if specified in the config + :return: + """ + max_world_size_from_config = config.num_workers_per_node * config.num_nodes + if is_aml_mpi_run(config): + global_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + else: + global_size = torch.cuda.device_count() + return min(global_size, max_world_size_from_config) + + +def get_local_size(config: ModelConfigBase) -> int: + """ + Get the number of devices on current machine (whether running in AML or locally).In both cases, + local size is the maximum possible number of devices, but we may use fewer, if specified in the config + """ + if is_aml_mpi_run(config): + local_size = int(os.environ['OMPI_COMM_WORLD_LOCAL_SIZE']) + else: + local_size = torch.cuda.device_count() + return min(local_size, config.num_workers_per_node) + + +def get_az_batch_master_node() -> Optional[str]: + """ + If AML MPI job, environment variable named should exist + :return: + """ + master_node_addr = os.environ['$AZ_BATCH_MASTER_NODE'] + return master_node_addr + + +def is_aml_mpi_run(config: ModelConfigBase) -> bool: + """ + Proxy for whether run is an AML MPI job (in which case init_method will be replaced with + tcp communication address instead of using environment vars) + Another proxy could be whether the environment variable $AZ_BATCH_MASTER_NODE has been set? (see above) + :param config: + :return: + """ + return (config.distributed_training_init_method.startswith('tcp://')) and (not config.is_offline_run) diff --git a/InnerEye/ML/utils/metrics_util.py b/InnerEye/ML/utils/metrics_util.py index ba68b2c98..5e2bd8d35 100644 --- a/InnerEye/ML/utils/metrics_util.py +++ b/InnerEye/ML/utils/metrics_util.py @@ -133,19 +133,27 @@ def to_data_frame(self) -> DataFrame: return df -def create_summary_writers(args: ModelConfigBase) -> SummaryWriters: +def create_summary_writers(args: ModelConfigBase, global_rank: int = -1) -> SummaryWriters: """ Creates two tensorboard writers, one for training and one for validation. Stored in a SummaryWriters objects. - - :param args: config of the model + :param args: config of the model. + :param global_rank: the global rank of the current process. :return: SummaryWriters with tensorboard summary writers. """ # Disable tensorboardX's logs logging.getLogger().disabled = True - writer_train = tensorboardX.SummaryWriter(str(args.logs_folder / "train")) - writer_val = tensorboardX.SummaryWriter(str(args.logs_folder / "val")) + train_summary_path = str(args.logs_folder / "train") + val_summary_path = str(args.logs_folder / "val") + + # create additional logs for distributed training + if global_rank > -1: + train_summary_path += f'_proc{global_rank}' + val_summary_path += f'_proc{global_rank}' + + writer_train = tensorboardX.SummaryWriter(train_summary_path) + writer_val = tensorboardX.SummaryWriter(val_summary_path) # Reset logger logging.getLogger().disabled = False diff --git a/InnerEye/ML/utils/model_util.py b/InnerEye/ML/utils/model_util.py index a028efbb3..0f7f5b94d 100644 --- a/InnerEye/ML/utils/model_util.py +++ b/InnerEye/ML/utils/model_util.py @@ -4,6 +4,7 @@ # ------------------------------------------------------------------------------------------ import logging import os +from collections import OrderedDict from pathlib import Path from typing import Any, Dict, Optional, Union @@ -24,12 +25,14 @@ from InnerEye.ML.models.architectures.unet_3d import UNet3D from InnerEye.ML.models.layers.basic import BasicLayer from InnerEye.ML.models.parallel.data_parallel import DataParallelModel +from InnerEye.ML.models.parallel.distributed_data_parallel import DistributedDataParallelModel from InnerEye.ML.scalar_config import ScalarModelBase from InnerEye.ML.sequence_config import SequenceModelBase from InnerEye.ML.utils.device_aware_module import DeviceAwareModule from InnerEye.ML.utils.metrics_constants import LoggingColumns from InnerEye.ML.utils.ml_util import RandomStateSnapshot from InnerEye.ML.utils.temperature_scaling import ModelWithTemperature +from InnerEye.ML.utils.training_util import determine_device from InnerEye.ML.visualizers.model_summary import ModelSummary @@ -48,7 +51,6 @@ class ModelAndInfo: checkpoint_epoch: the training epoch this model was created, if loaded from disk model_execution_mode: mode this model will be run in """ - MODEL_STATE_DICT_KEY = 'state_dict' OPTIMIZER_STATE_DICT_KEY = 'opt_dict' MEAN_TEACHER_STATE_DICT_KEY = 'mean_teacher_state_dict' @@ -102,7 +104,7 @@ def read_checkpoint(path_to_checkpoint: Path, use_gpu: bool) -> Dict[str, Any]: @classmethod def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path, - key_in_state_dict: str, use_gpu: bool) -> int: + key_in_state_dict: str, use_gpu: bool, use_distributed_data_parallel: bool) -> int: """ Loads a checkpoint of a model, may be the model or the mean teacher model. Assumes the model has already been created, and the checkpoint exists. This does not set checkpoint epoch. @@ -123,8 +125,15 @@ def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path, logging.error(f"Key {key_in_state_dict} not found in checkpoint") return False - if isinstance(model, torch.nn.DataParallel): - result = model.module.load_state_dict(state_dict, strict=False) + if use_distributed_data_parallel: + # Model was stored with DistributedDataParallel which stores the model in module, now loading without + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + name = k.replace('module.', '') # remove `module.` + new_state_dict[name] = v + result = model.load_state_dict(new_state_dict, strict=False) + elif isinstance(model, torch.nn.DataParallel): + result = model.module.load_state_dict(checkpoint[key_in_state_dict], strict=False) else: result = model.load_state_dict(state_dict, strict=False) @@ -137,7 +146,7 @@ def _load_checkpoint(cls, model: DeviceAwareModule, checkpoint_path: Path, @classmethod def _adjust_for_gpus(cls, model: DeviceAwareModule, config: ModelConfigBase, - model_execution_mode: ModelExecutionMode) -> DeviceAwareModule: + model_execution_mode: ModelExecutionMode, rank: int = 0) -> DeviceAwareModule: """ Updates a torch model so that input mini-batches are parallelized across the batch dimension to utilise multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to @@ -147,6 +156,18 @@ def _adjust_for_gpus(cls, model: DeviceAwareModule, config: ModelConfigBase, or adjust_mean_teacher_model_for_gpus :returns Adjusted model """ + device = determine_device(rank) + + use_distributed_data_parallel = config.use_gpu and config.use_distributed_data_parallel + if use_distributed_data_parallel: + if model_execution_mode == ModelExecutionMode.TRAIN: + model = model.to(device) + model = DistributedDataParallelModel(model, device_ids=[rank]) + else: + config.adjust_after_mixed_precision_and_parallel(model, device) + logging.debug("Skipping model update") + return model + if config.use_gpu: model = model.cuda() logging.info("Adjusting the model to use mixed precision training.") @@ -159,7 +180,7 @@ def _adjust_for_gpus(cls, model: DeviceAwareModule, config: ModelConfigBase, logging.info("Making no adjustments to the model because no GPU was found.") # Update model related config attributes (After Model Parallel Activated) - config.adjust_after_mixed_precision_and_parallel(model) + config.adjust_after_mixed_precision_and_parallel(model, device) # DataParallel enables running the model with multiple gpus by splitting samples across GPUs # If the model is used in training mode, data parallel is activated by default. @@ -199,13 +220,14 @@ def try_load_checkpoint_for_model(self) -> bool: epoch = ModelAndInfo._load_checkpoint(model=self._model, checkpoint_path=self.checkpoint_path, key_in_state_dict=ModelAndInfo.MODEL_STATE_DICT_KEY, - use_gpu=self.config.use_gpu) + use_gpu=self.config.use_gpu, + use_distributed_data_parallel=self.config.use_distributed_data_parallel) logging.info(f"Loaded model from checkpoint (epoch: {epoch})") self.checkpoint_epoch = epoch return True - def adjust_model_for_gpus(self) -> None: + def adjust_model_for_gpus(self, rank: Optional[int] = 0) -> None: """ Updates the torch model so that input mini-batches are parallelized across the batch dimension to utilise multiple gpus. If model parallel is set to True and execution is in test mode, then model is partitioned to @@ -223,7 +245,8 @@ def adjust_model_for_gpus(self) -> None: self._model = ModelAndInfo._adjust_for_gpus(model=self._model, config=self.config, - model_execution_mode=self.model_execution_mode) + model_execution_mode=self.model_execution_mode, + rank=rank) self.is_model_adjusted = True logging.debug("model_and_info.is_model_adjusted set to True") @@ -289,7 +312,8 @@ def try_load_checkpoint_for_mean_teacher_model(self) -> bool: epoch = ModelAndInfo._load_checkpoint(model=self._mean_teacher_model, checkpoint_path=self.checkpoint_path, key_in_state_dict=ModelAndInfo.MEAN_TEACHER_STATE_DICT_KEY, - use_gpu=self.config.use_gpu) + use_gpu=self.config.use_gpu, + use_distributed_data_parallel=self.config.use_distributed_data_parallel) logging.info(f"Loaded mean teacher model from checkpoint (epoch: {epoch})") self.checkpoint_epoch = epoch @@ -527,7 +551,7 @@ def summary_for_segmentation_models(config: ModelConfigBase, model: DeviceAwareM logging.warning(f"summary_for_segmentation_models failed with exception {e}") -def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAwareModule) -> None: +def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAwareModule, device: torch.device) -> None: """ Writes a human readable summary of the present model to logging.info, and logs the number of trainable parameters to AzureML. @@ -541,7 +565,7 @@ def generate_and_print_model_summary(config: ModelConfigBase, model: DeviceAware # https://github.com/NVIDIA/apex/issues/694 # Hence, move the model to the GPU before doing model summary. if config.use_gpu: - model = model.cuda() + model = model.to(device) if isinstance(config, ScalarModelBase): # To generate the model summary, read the first item of the dataset. Then use the model's own # get_model_input function to convert the dataset item to input tensors, and feed them through the model. diff --git a/InnerEye/ML/utils/training_util.py b/InnerEye/ML/utils/training_util.py index 52b41b4a1..7a70f93a7 100644 --- a/InnerEye/ML/utils/training_util.py +++ b/InnerEye/ML/utils/training_util.py @@ -82,3 +82,14 @@ def gather_tensor(tensor: Union[torch.Tensor, List[torch.Tensor]], return torch.nn.parallel.gather(tensor, target_device=target_device) else: return tensor + + +def determine_device(rank: int = 0) -> torch.device: + """ + If CUDA is available, returns a CUDA device (if multiple devices available, select the one + corresponding to the integer "rank". Otherwise, use cpu + :param rank: If multiple CUDA devices available, this param specifies which one to use + :return: + """ + device = torch.device('cuda', rank) if torch.cuda.is_available() else torch.device('cpu') + return device diff --git a/InnerEye/ML/visualizers/model_summary.py b/InnerEye/ML/visualizers/model_summary.py index 6c423348f..28a87e219 100644 --- a/InnerEye/ML/visualizers/model_summary.py +++ b/InnerEye/ML/visualizers/model_summary.py @@ -159,6 +159,7 @@ def _generate_summary(self, input_tensors: List[torch.Tensor]) -> None: and intermediate tensor size. :param input_tensors: A list of tensors which are fed into the torch model. """ + device = self._get_device(self.model) def print_summary() -> None: logging.info("-------------------------------------------------------------------------------") @@ -189,7 +190,7 @@ def print_summary() -> None: # Register the forward-pass hooks, profile the model, and restore its state self.model.apply(self._register_hook) with torchprof.Profile(self.model, use_cuda=self.use_gpu) as prof: - forward_preserve_state(self.model, input_tensors) # type: ignore + forward_preserve_state(self.model, input_tensors, device=device) # type: ignore # Log the model summary: tensor shapes, num of parameters, memory requirement, and forward pass time logging.info(self.model) @@ -201,20 +202,23 @@ def print_summary() -> None: h.remove() -def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor]) -> torch.Tensor: +def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor], device: torch.device + ) -> torch.Tensor: """ Perform forward pass on input module with given list of torch tensors. The function preserves the random state of the backend libraries to avoid reproducibility issues. Additionally, it temporarily sets the model in evaluation mode for inference and then restores its previous state. :param module: Callable torch module :param inputs: List of input torch tensors + :param device: The Torch device to allocate tensors to. :return output: Output torch tensors """ + if not isinstance(inputs, list): raise RuntimeError("Inputs object has to be a list of torch tensors") if module.is_model_on_gpu(): - inputs = [input_tensor.cuda() for input_tensor in inputs] + inputs = [input_tensor.to(device) for input_tensor in inputs] # collect the current state of the model is_train = module.training diff --git a/Tests/ML/configs/ClassificationModelForTesting.py b/Tests/ML/configs/ClassificationModelForTesting.py index 799a72e94..e67d39617 100644 --- a/Tests/ML/configs/ClassificationModelForTesting.py +++ b/Tests/ML/configs/ClassificationModelForTesting.py @@ -30,7 +30,9 @@ def __init__(self, conv_in_3d: bool = True, mean_teacher_model: bool = False) -> num_dataload_workers=0, test_start_epoch=num_epochs, subject_column="subjectID", - mean_teacher_alpha=mean_teacher_alpha + mean_teacher_alpha=mean_teacher_alpha, + num_nodes=1, + num_workers_per_node=1 ) self.expected_image_size_zyx = (4, 5, 7) self.conv_in_3d = conv_in_3d diff --git a/Tests/ML/configs/DummyModel.py b/Tests/ML/configs/DummyModel.py index d4b8b1e47..39b1b620c 100644 --- a/Tests/ML/configs/DummyModel.py +++ b/Tests/ML/configs/DummyModel.py @@ -65,6 +65,8 @@ def __init__(self, **kwargs: Any) -> None: test_start_epoch=1, test_diff_epochs=1, test_step_epochs=1, + num_nodes=1, + num_workers_per_node=1 ) self.add_and_validate(kwargs) diff --git a/Tests/ML/datasets/test_sequence_dataset.py b/Tests/ML/datasets/test_sequence_dataset.py index 0b25cbca6..9c6415442 100644 --- a/Tests/ML/datasets/test_sequence_dataset.py +++ b/Tests/ML/datasets/test_sequence_dataset.py @@ -553,7 +553,7 @@ def test_sequence_dataset_all(test_output_dirs: OutputFolderForTests) -> None: num_dataload_workers=0, train_batch_size=2, should_validate=False, - shuffle=False + shuffle=False, ) config.read_dataset_if_needed() df = config.dataset_data_frame diff --git a/Tests/ML/models/architectures/sequential/test_rnn_classifier.py b/Tests/ML/models/architectures/sequential/test_rnn_classifier.py index 236ddc426..d227e0ee8 100644 --- a/Tests/ML/models/architectures/sequential/test_rnn_classifier.py +++ b/Tests/ML/models/architectures/sequential/test_rnn_classifier.py @@ -214,6 +214,7 @@ def test_rnn_classifier_via_config_1(use_combined_model: bool, with mock.patch('InnerEye.ML.utils.io_util.load_image_in_known_formats', return_value=image_and_seg): results = model_train(config, get_default_checkpoint_handler(model_config=config, project_root=test_output_dirs.root_dir)) + assert results is not None # mypy assert len(results.optimal_temperature_scale_values_per_checkpoint_epoch) \ == config.get_total_number_of_save_epochs() @@ -381,7 +382,7 @@ def test_rnn_classifier_via_config_2(test_output_dirs: OutputFolderForTests) -> config.dataset_data_frame = _get_mock_sequence_dataset(dataset_contents) results = model_train(config, get_default_checkpoint_handler(model_config=config, project_root=test_output_dirs.root_dir)) - + assert results is not None # mypy actual_train_loss = results.train_results_per_epoch[-1].values()[MetricType.LOSS.value][0] actual_val_loss = results.val_results_per_epoch[-1].values()[MetricType.LOSS.value][0] print(f"Training loss after {config.num_epochs} epochs: {actual_train_loss}") diff --git a/Tests/ML/models/architectures/test_image_encoder_with_mlp.py b/Tests/ML/models/architectures/test_image_encoder_with_mlp.py index 07d72db3e..25bac15d7 100644 --- a/Tests/ML/models/architectures/test_image_encoder_with_mlp.py +++ b/Tests/ML/models/architectures/test_image_encoder_with_mlp.py @@ -183,7 +183,7 @@ def test_image_encoder(test_output_dirs: OutputFolderForTests, encode_channels_j categorical_feature_encoder=config_for_dataset.categorical_feature_encoder, encoder_dimensionality_reduction_factor=reduction_factor, aggregation_type=aggregation_type, - scan_size=(6, 64, 60) + scan_size=(6, 64, 60), ) if kernel_size_per_encoding_block: diff --git a/Tests/ML/models/test_instantiate_models.py b/Tests/ML/models/test_instantiate_models.py index afc06dbb7..ad297a8a2 100644 --- a/Tests/ML/models/test_instantiate_models.py +++ b/Tests/ML/models/test_instantiate_models.py @@ -6,6 +6,7 @@ from typing import List import pytest +from torch import device from InnerEye.Common.common_util import logging_to_stdout, namespace_to_path from InnerEye.ML.config import SegmentationModelBase @@ -56,7 +57,7 @@ def test_load_all_configs(model_name: str) -> None: config.feature_channels = [minimal_feature_channels] * len(config.feature_channels) print("Model architecture after restricting to 2 feature channels only:") model = create_model_with_temperature_scaling(config) - generate_and_print_model_summary(config, model) + generate_and_print_model_summary(config, model, device('cuda:0')) else: # For classification models, we can't always print a model summary: The model could require arbitrary # numbers of input tensors, and we'd only know once we load the training data. diff --git a/Tests/ML/pipelines/test_forward_pass.py b/Tests/ML/pipelines/test_forward_pass.py index 22bbd61dc..faee4b0a3 100644 --- a/Tests/ML/pipelines/test_forward_pass.py +++ b/Tests/ML/pipelines/test_forward_pass.py @@ -89,7 +89,7 @@ def test_anomaly_detection(value_to_insert: float, in_training_mode: bool) -> No image_channels=["ct"], ground_truth_ids=ground_truth_ids, should_validate=False, - detect_anomaly=True + detect_anomaly=True, ) model_and_info = ModelAndInfo(config=config, model_execution_mode=ModelExecutionMode.TRAIN, @@ -103,13 +103,15 @@ def test_anomaly_detection(value_to_insert: float, in_training_mode: bool) -> No optimizer = model_and_info.optimizer # Create the loss criterion - criterion = lambda x, y: torch.tensor(value_to_insert, requires_grad=True) + def _criterion(logits: torch.Tensor, labels: torch.Tensor, device: torch.device) -> torch.Tensor: + return torch.tensor(value_to_insert, requires_grad=True) + pipeline = SegmentationForwardPass(model, config, batch_size=1, optimizer=optimizer, in_training_mode=in_training_mode, - criterion=criterion) + criterion=_criterion) image[0, 0, 0, 0, 0] = value_to_insert if np.isnan(value_to_insert) or np.isinf(value_to_insert): with pytest.raises(RuntimeError) as ex: @@ -172,13 +174,16 @@ def test_amp_activated(use_model_parallel: bool, if use_data_parallel: assert isinstance(model, DataParallelModel) gradient_scaler = GradScaler() if use_mixed_precision else None - criterion = lambda x, y: torch.tensor([0.0], requires_grad=True).cuda() + + def _criterion(logits: torch.Tensor, labels: torch.Tensor, device: torch.device) -> torch.Tensor: + return torch.tensor([0.0], requires_grad=True).to(device) + pipeline = SegmentationForwardPass(model, model_config, batch_size=1, optimizer=optimizer, gradient_scaler=gradient_scaler, - criterion=criterion) + criterion=_criterion) logits, _ = pipeline._compute_loss(image, labels) # When using DataParallel, we expect to get a list of tensors back, one per GPU. if use_data_parallel: @@ -335,7 +340,7 @@ def create_model(self) -> Any: training_steps = ModelTrainingStepsForScalarModel(config, train_val_parameters) sample = list(data_loaders[execution_mode])[0] model_input = get_scalar_model_inputs_and_labels(config, model, sample) - logits, posteriors, loss = training_steps._compute_model_output_and_loss(model_input) + logits, posteriors, loss = training_steps._compute_model_output_and_loss(model_input, 0, torch.device('cuda:0')) # When using DataParallel, we expect to get a list of tensors back, one per GPU. if use_data_parallel: assert isinstance(logits, list) @@ -353,4 +358,4 @@ def create_model(self) -> Any: assert loss.dtype == torch.float32 # Verify that forward pass does not throw. It would for example if it fails to gather tensors or not convert # float16 to float32 - _, _, _ = training_steps._compute_model_output_and_loss(model_input) + _, _, _ = training_steps._compute_model_output_and_loss(model_input, 0, torch.device('cpu')) diff --git a/Tests/ML/pipelines/test_inference.py b/Tests/ML/pipelines/test_inference.py index 9e783e093..94aec7a3e 100644 --- a/Tests/ML/pipelines/test_inference.py +++ b/Tests/ML/pipelines/test_inference.py @@ -82,7 +82,7 @@ def test_inference_identity(image_size: Any, # instantiate the model model = PyTorchMockModel(shrink_by) - model_config.adjust_after_mixed_precision_and_parallel(model) + model_config.adjust_after_mixed_precision_and_parallel(model, torch.device('cpu')) # create single or ensemble inference pipeline inference_pipeline = InferencePipeline(model=model, model_config=model_config) diff --git a/Tests/ML/test_config_helpers.py b/Tests/ML/test_config_helpers.py index e2d6a5a07..d8e93a6ba 100644 --- a/Tests/ML/test_config_helpers.py +++ b/Tests/ML/test_config_helpers.py @@ -26,17 +26,18 @@ def test_inference_stride_size_setter() -> None: test_output_size = (7, 3, 5) test_stride_size = (3, 3, 3) test_fail_stride_size = (1, 1, 9) + device = torch.device('cpu') model = IdentityModel() model_config = SegmentationModelBase(test_crop_size=test_output_size, should_validate=False) model_config.inference_stride_size = test_stride_size assert model_config.inference_stride_size == test_stride_size - model_config.adjust_after_mixed_precision_and_parallel(model) + model_config.adjust_after_mixed_precision_and_parallel(model, device) assert model_config.inference_stride_size == test_stride_size model_config.inference_stride_size = None - model_config.adjust_after_mixed_precision_and_parallel(model) + model_config.adjust_after_mixed_precision_and_parallel(model, device) assert model_config.inference_stride_size == test_output_size with pytest.raises(ValueError): @@ -58,7 +59,7 @@ def test_set_model_config_attributes() -> None: test_crop_size=test_output_size, should_validate=False) - model_config.adjust_after_mixed_precision_and_parallel(model) + model_config.adjust_after_mixed_precision_and_parallel(model, torch.device('cpu')) assert model_config.inference_stride_size == test_output_size @@ -75,7 +76,7 @@ def test_get_output_size() -> None: assert model_config.get_output_size(execution_mode=ModelExecutionMode.TEST) is None model = IdentityModel() - model_config.adjust_after_mixed_precision_and_parallel(model) + model_config.adjust_after_mixed_precision_and_parallel(model, torch.device('cpu')) assert model_config.get_output_size(execution_mode=ModelExecutionMode.TRAIN) == train_output_size assert model_config.get_output_size(execution_mode=ModelExecutionMode.TEST) == test_output_size diff --git a/Tests/ML/test_model_train_test_and_recovery.py b/Tests/ML/test_model_train_test_and_recovery.py index d95c8a5d4..1569066c7 100644 --- a/Tests/ML/test_model_train_test_and_recovery.py +++ b/Tests/ML/test_model_train_test_and_recovery.py @@ -38,6 +38,7 @@ def test_recover_testing_from_run_recovery(mean_teacher_model: bool, checkpoint_handler = get_default_checkpoint_handler(model_config=config, project_root=test_output_dirs.root_dir) train_results = model_train(config, checkpoint_handler=checkpoint_handler) + assert train_results is not None # mypy assert len(train_results.learning_rates_per_epoch) == config.num_epochs # Run inference on this diff --git a/Tests/ML/test_model_training.py b/Tests/ML/test_model_training.py index 00d6d658d..13440e46e 100644 --- a/Tests/ML/test_model_training.py +++ b/Tests/ML/test_model_training.py @@ -190,10 +190,10 @@ def _check_patch_centers(epoch_results: List[MetricsDict], should_equal: bool) - assert train_config.logs_folder.is_dir() # The train and val folder should contain Tensorflow event files - assert (train_config.logs_folder / "train").is_dir() - assert (train_config.logs_folder / "val").is_dir() - assert len([(train_config.logs_folder / "train").glob("*")]) == 1 - assert len([(train_config.logs_folder / "val").glob("*")]) == 1 + assert (train_config.logs_folder / "train_proc0").is_dir() + assert (train_config.logs_folder / "val_proc0").is_dir() + assert len([(train_config.logs_folder / "train_proc0").glob("*")]) == 1 + assert len([(train_config.logs_folder / "val_proc0").glob("*")]) == 1 # Checkpoint folder # With these settings, we should see a checkpoint only at epoch 2: diff --git a/Tests/ML/utils/test_aml_distributed_utils.py b/Tests/ML/utils/test_aml_distributed_utils.py new file mode 100644 index 000000000..02524ebc3 --- /dev/null +++ b/Tests/ML/utils/test_aml_distributed_utils.py @@ -0,0 +1,150 @@ +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ +import os +import pytest +from typing import Optional +from unittest import mock + +from torch.cuda import device_count + +from InnerEye.ML.utils.aml_distributed_utils import get_local_rank, get_global_rank, get_global_size, get_local_size, \ + is_aml_mpi_run + + +@pytest.mark.parametrize("local_rank_env_var", [None, 1, 10]) +def test_get_local_rank(local_rank_env_var: Optional[int]) -> None: + """ + Test that get_local_rank returns the correct environment variable value if it exists + (e.g. for an AML MPI run) and otherwise raises a TypeError + :param local_rank_env_var: + :return: + """ + if local_rank_env_var is None: + with pytest.raises(AssertionError): + get_local_rank() + else: + with mock.patch.dict(os.environ, {'OMPI_COMM_WORLD_LOCAL_RANK': str(local_rank_env_var)}): + rank = get_local_rank() + assert rank == local_rank_env_var + + +@pytest.mark.parametrize("global_rank_env_var", [None, 5, 10]) +def test_get_global_rank(global_rank_env_var: Optional[int]) -> None: + """ + Test that get_global_rank returns the correct environment variable value if it exists + (e.g. for an AML MPI run) and otherwise raises a TypeError + :param global_rank_env_var: + :return: + """ + if global_rank_env_var is None: + with pytest.raises(AssertionError): + get_global_rank() + else: + with mock.patch.dict(os.environ, {'OMPI_COMM_WORLD_RANK': str(global_rank_env_var)}): + rank = get_global_rank() + assert rank == global_rank_env_var + + +@pytest.mark.parametrize(["num_nodes", "num_workers_per_node"], [(1, 1), (5, 4), (2, 3)]) +def test_get_global_size_offline(num_nodes: int, num_workers_per_node: int) -> None: + """ + Assert that, for an offline run, get_global_size returns the number of cuda devices + on the current machine + :param num_nodes: + :param num_workers_per_node: + :return: + """ + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.num_nodes = num_nodes + MockConfig.return_value.num_workers_per_node = num_workers_per_node + mock_config = MockConfig() + available_local_workers = min(device_count(), num_workers_per_node) + # for offline run, we assume 1 node regardless of config + expected_global_size = available_local_workers + global_size = get_global_size(mock_config) + assert global_size == expected_global_size + + +@pytest.mark.parametrize("expected_global_size", [1, 5, 10]) +def test_get_global_size_aml(expected_global_size: int) -> None: + """ + Assert that, for an AML run, get_global_size returns the value of the appropriate + environment variable + :param expected_global_size: + :return: + """ + with mock.patch("os.environ", {'OMPI_COMM_WORLD_SIZE': expected_global_size}): + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.is_offline_run = False + MockConfig.return_value.num_workers_per_node = expected_global_size + MockConfig.return_value.num_nodes = 1 + config = MockConfig() + global_size = get_global_size(config) + assert global_size == expected_global_size + + +@pytest.mark.parametrize("requested_num_workers_per_node", [1, 2, 5, 10]) +def test_get_local_size_offline(requested_num_workers_per_node: int) -> None: + """ + Assert that, for an offline run, get_local_size returns the number of cuda devices + on the current machine + :param requested_num_workers_per_node: + :return: + """ + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.num_workers_per_node = requested_num_workers_per_node + mock_config = MockConfig() + available_local_workers = device_count() + local_size = get_local_size(mock_config) + assert local_size == min(requested_num_workers_per_node, available_local_workers) + + +@pytest.mark.parametrize("requested_num_workers_per_node", [1, 2, 3]) +def test_get_local_size_aml(requested_num_workers_per_node: int) -> None: + """ + Assert that, for an AML run, get_local_size returns the value of the appropriate + environment variable + :param requested_num_workers_per_node: + :return: + """ + with mock.patch.dict(os.environ, {'OMPI_COMM_WORLD_LOCAL_SIZE': str(requested_num_workers_per_node)}): + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.is_offline_run = False + MockConfig.return_value.num_workers_per_node = requested_num_workers_per_node + config = MockConfig() + local_size = get_local_size(config) + assert local_size == requested_num_workers_per_node + + +def test_is_aml_mpi_run() -> None: + """ + Assert that is_aml_mpi_run returns False, unless we have an AML run where the init_method uses TCP + (by default it would use environment variables, but Azure's MPI job alters this). + :return: + """ + # By default, is_offline = True and init_method = "env://", so expect is_aml_mpi_run = False + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.is_offline_run = True + MockConfig.return_value.distributed_training_init_method = "env://" + config1 = MockConfig() + assert is_aml_mpi_run(config1) is False + # if is_offline = True, still expect is_aml_mpi_run = False, due to init_method + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.is_offline_run = False + MockConfig.return_value.distributed_training_init_method = "env://" + config2 = MockConfig() + assert is_aml_mpi_run(config2) is False + # if init_method starts with "tcp://" but is_offline = True, still expect is_aml_mpi_run = False + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.is_offline_run = True + MockConfig.return_value.distributed_training_init_method = "tcp://" + config = MockConfig() + assert is_aml_mpi_run(config) is False + # if init_method starts with "tcp://" and is_offline = False, expect is_aml_mpi_run = True + with mock.patch("Tests.ML.configs.DummyModel") as MockConfig: + MockConfig.return_value.is_offline_run = False + MockConfig.return_value.distributed_training_init_method = "tcp://" + config = MockConfig() + assert is_aml_mpi_run(config) is True