This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Adding Distributed Data Parallel #261

Closed
wants to merge 52 commits from mebristo/ddp into master
Changes from 1 commit
Commits (52)
cb0ee83
added distributeddataparallel
mebristo Sep 30, 2020
90488d6
fix bug introduced in code cleanup
mebristo Sep 30, 2020
c1f5191
merge changes from master
mebristo Sep 30, 2020
1e24971
Get DDP working with model parallel on single machine
mebristo Oct 4, 2020
411d4b9
run validation and testing on single device
mebristo Oct 5, 2020
2e30ba4
add pytorch to azure_runner yaml
mebristo Oct 12, 2020
1755e69
update azure runner
mebristo Oct 12, 2020
cb47c4d
Merge recent changes from master
mebristo Oct 13, 2020
464f701
fix bugs to run on AML
mebristo Oct 13, 2020
e6f7744
remove config
mebristo Oct 13, 2020
983afce
Merge branch 'master' into mebristo/ddp
javier-alvarez Oct 16, 2020
94c4fde
switch global rank for local
mebristo Oct 16, 2020
50403dc
Merge branch 'mebristo/ddp' of https://github.com/microsoft/InnerEye-…
mebristo Oct 16, 2020
55b1297
undo changes to rank
mebristo Oct 16, 2020
fdc67cd
debug error with rank
mebristo Oct 19, 2020
0d1b8a5
checkpoint only saves for 1 rank and distributed timing is different
mebristo Oct 19, 2020
c2cebf6
fix sync bug
mebristo Oct 19, 2020
9ac02b0
fix bug in output_size
mebristo Oct 20, 2020
32c8597
Refactor
mebristo Oct 20, 2020
4f8efbd
bug fix
mebristo Oct 20, 2020
6417c42
debugging mem loss in inference on val set
mebristo Oct 21, 2020
82ca448
debug mem error in inference for val set
mebristo Oct 21, 2020
b326ee3
debug cuda memory error in inference
mebristo Oct 21, 2020
f7c58a3
temporarily make val set smaller for debugging
mebristo Oct 21, 2020
00134a3
debug memory error in inference on val set
mebristo Oct 21, 2020
a42d9a2
debug cuda mem error in inference
mebristo Oct 21, 2020
c0c86a7
debug slow inference
mebristo Oct 22, 2020
dd5d904
save epoch only one device
mebristo Oct 22, 2020
2dad14d
compare time doing inference on gpu 0
mebristo Oct 22, 2020
ceed0ef
tidy up
mebristo Oct 22, 2020
32b4d0e
tidy up
mebristo Oct 22, 2020
ba218e6
tidy up and fix tests
mebristo Oct 22, 2020
9c69809
restore model config after debugging finished
mebristo Oct 22, 2020
78f51c1
tidy up
mebristo Oct 23, 2020
a3027b2
merge recent changes from master
mebristo Oct 23, 2020
f4c4e65
tidy up
mebristo Oct 23, 2020
2f6904f
work on 1 device
mebristo Oct 27, 2020
a5980e4
Address PR comments
mebristo Oct 28, 2020
1bc6653
address PR comments
mebristo Oct 28, 2020
179f1bf
bug fix in inference
mebristo Oct 29, 2020
972c794
debug inference mem error: try clearing cache
mebristo Oct 29, 2020
7aef7d2
Destroy process group after training complete
mebristo Oct 29, 2020
726ad5f
address PR comments
mebristo Oct 29, 2020
7c4bcd8
attempt to fix bug in import
mebristo Oct 29, 2020
42d6cba
fix problem with importing torch
mebristo Oct 29, 2020
393547a
override global and local size with command line args
mebristo Oct 30, 2020
7acdc98
address PR comments
mebristo Oct 30, 2020
d7bf3ba
fix tests
mebristo Nov 12, 2020
1130c9b
merge recent changes from master
mebristo Nov 12, 2020
f7508de
fix test
mebristo Nov 12, 2020
ea195fd
Merge remote-tracking branch 'origin/master' into mebristo/ddp
ant0nsc Nov 16, 2020
0435b6c
Merge branch 'master' into mebristo/ddp
javier-alvarez Nov 17, 2020
1 change: 1 addition & 0 deletions InnerEye/Azure/azure_config.py
@@ -84,6 +84,7 @@ class AzureConfig(GenericConfig):
is_train: bool = param.Boolean(True,
doc="If True, train a new model. If False, run inference on an existing model.")
model: str = param.String(doc="The name of the model to train/test.")
use_distributed_data_parallel: bool = param.Boolean(default=False)
register_model_only_for_epoch: Optional[int] = param.Integer(None,
doc="If set, and run_recovery_id is also set, "
"register the model for this epoch and do no "
18 changes: 15 additions & 3 deletions InnerEye/Azure/azure_runner.py
@@ -12,12 +12,13 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from azureml.core import Dataset, Experiment, Run, Workspace
from azureml.core import Dataset, Experiment, Run, Workspace, ComputeTarget
from azureml.core.conda_dependencies import CondaDependencies
from azureml.core.datastore import Datastore
from azureml.core.workspace import WORKSPACE_DEFAULT_BLOB_STORE_NAME
from azureml.data.dataset_consumption_config import DatasetConsumptionConfig
from azureml.exceptions import WorkspaceException
from azureml.train._distributed_training import Mpi
from azureml.train.dnn import PyTorch

from InnerEye.Azure import azure_util
@@ -286,19 +287,30 @@ def create_estimator_from_configs(workspace: Workspace, azure_config: AzureConfi
# create Estimator environment
framework_version = pytorch_version_from_conda_dependencies(conda_dependencies)
logging.info(f"PyTorch framework version: {framework_version}")

if azure_config.use_distributed_data_parallel:
distributed_training_backend = Mpi(azure_config.workers_per_node)
else:
distributed_training_backend = None

compute_target = ComputeTarget(workspace, azure_config.gpu_cluster_name)

estimator = PyTorch(
source_directory=source_config.root_folder,
entry_script=entry_script_relative_path,
script_params=source_config.script_params,
compute_target=azure_config.gpu_cluster_name,
compute_target=compute_target,
# Use blob storage for storing the source, rather than the FileShares section of the storage account.
source_directory_data_store=workspace.datastores.get(WORKSPACE_DEFAULT_BLOB_STORE_NAME),
inputs=estimator_inputs,
environment_variables=environment_variables,
shm_size=azure_config.docker_shm_size,
use_docker=True,
use_gpu=True,
framework_version=framework_version
framework_version=framework_version,
node_count=azure_config.node_count,
distributed_training=distributed_training_backend,
pip_packages=['azureml-dataprep[pandas,fuse]']
)
estimator.run_config.environment.python.conda_dependencies = conda_dependencies
# We'd like to log the estimator config, but conversion to string fails when the Estimator has some inputs.
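
The Mpi backend above launches azure_config.workers_per_node processes on each of the node_count nodes, and every process then discovers its rank and the total world size from the launcher's environment. The helpers later imported from InnerEye.ML.utils.aml_distributed_utils are not part of this diff, so the following is only a guessed sketch of their shape, assuming the standard Open MPI environment variables that AzureML MPI runs expose:

import os

import torch


def get_global_rank() -> int:
    # Open MPI (the launcher behind AzureML's Mpi backend) sets this for every process.
    return int(os.environ.get("OMPI_COMM_WORLD_RANK", "0"))


def get_global_size(is_offline_run: bool) -> int:
    # Offline (non-AzureML) runs spawn one process per visible GPU; AzureML runs read
    # the world size that the MPI launcher derived from node_count * workers_per_node.
    if is_offline_run:
        return max(1, torch.cuda.device_count())
    return int(os.environ.get("OMPI_COMM_WORLD_SIZE", "1"))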
2 changes: 1 addition & 1 deletion InnerEye/Common/common_util.py
@@ -61,7 +61,7 @@ def flush(self, log_info: bool = False) -> None:
"""
import pandas as pd
if not self.csv_path.parent.is_dir():
self.csv_path.parent.mkdir(parents=True)
self.csv_path.parent.mkdir(parents=True, exist_ok=True)
# Specifying columns such that the order in which columns appear matches the order in which
# columns were added in the code.
columns = self.records[0].keys() if len(self.records) > 0 else None
7 changes: 5 additions & 2 deletions InnerEye/Common/generic_parsing.py
@@ -11,6 +11,7 @@
import param
from azureml.core import Run
from azureml.core.run import _OfflineRun
from torch import device as torch_device

from InnerEye.Common.common_util import is_gpu_tensor, is_private_field_name
from InnerEye.Common.type_annotations import T
@@ -75,15 +76,17 @@ def get_cuda_devices(self) -> Optional[List[Any]]:
else:
return None

def get_gpu_tensor_if_possible(self, data: T) -> Any:
def get_gpu_tensor_if_possible(self, data: T, device: Optional[torch_device] = None) -> Any:
""""
Get a cuda tensor if this transform was cuda enabled and a GPU is available, otherwise
return the input.
"""
import torch
if isinstance(data, torch.Tensor):
if self.use_gpu and not is_gpu_tensor(data):
return data.cuda()
# use default CUDA device if not specified
device = torch.device('cuda') if device is None else device
return data.to(device)
else:
return data
else:
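
As a standalone illustration of the pattern above (not the project's API), moving a tensor onto the CUDA device that belongs to the current DDP rank looks roughly like this; to_rank_device is a hypothetical helper:

from typing import Optional

import torch


def to_rank_device(data: torch.Tensor, rank: int, device: Optional[torch.device] = None) -> torch.Tensor:
    # Mirrors get_gpu_tensor_if_possible: default to this rank's CUDA device when none is given.
    if not torch.cuda.is_available():
        return data
    device = torch.device("cuda", rank) if device is None else device
    return data.to(device)


# Rank 0 keeps its tensors on cuda:0, rank 1 on cuda:1, and so on.
x = to_rank_device(torch.zeros(2, 3), rank=0)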
7 changes: 4 additions & 3 deletions InnerEye/ML/configs/classification/DummyClassification.py
@@ -15,7 +15,7 @@ class DummyClassification(ScalarModelBase):
"A config file for dummy image classification model for debugging purposes"

def __init__(self) -> None:
num_epochs = 4
num_epochs = 10
super().__init__(
local_dataset=full_ml_test_data_path("classification_data"),
image_channels=["image"],
@@ -28,9 +28,10 @@ def __init__(self) -> None:
num_epochs=num_epochs,
num_dataload_workers=0,
test_start_epoch=num_epochs,
use_mixed_precision=True,
use_mixed_precision=False,
subject_column="subjectID",
conv_in_3d=True
conv_in_3d=True,
use_distributed_data_parallel=True
)
self.expected_image_size_zyx = (4, 5, 7)

16 changes: 14 additions & 2 deletions InnerEye/ML/dataset/full_image_dataset.py
@@ -11,7 +11,8 @@
import pandas as pd
import torch.utils.data
from torch._six import container_abcs
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, Sampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler, Sampler, SequentialSampler, \
DistributedSampler
from torch.utils.data.dataloader import default_collate # type: ignore

from InnerEye.Common.type_annotations import IntOrString, TupleFloat3
@@ -177,7 +178,9 @@ def as_data_loader(self,
num_dataload_workers: Optional[int] = None,
use_imbalanced_sampler: bool = False,
drop_last_batch: bool = False,
max_repeats: Optional[int] = None) -> DataLoader:
max_repeats: Optional[int] = None,
distribute=False
) -> DataLoader:
num_dataload_workers = num_dataload_workers or self.args.num_dataload_workers
batch_size = batch_size or self.args.train_batch_size
if self.args.avoid_process_spawn_in_data_loaders:
@@ -195,6 +198,15 @@
use_imbalanced_sampler=use_imbalanced_sampler,
drop_last=drop_last_batch
)
elif distribute:
# distributed data loader
sampler: DistributedSampler = DistributedSampler(self)
return DataLoader(self,
batch_size=batch_size,
shuffle=False,
num_workers=0,
Contributor Author: Does this need to match node_count?

Member: Each device will call this, so setting num_workers to 0 will create 1 process on each device (preventing too many processes being spawned on each device, which was leading to CUDA memory errors). However, this really slows down the data loading, so another way to do it is to use int((config.num_dataload_workers + n_gpus_per_node - 1) / n_gpus_per_node) (see the sketch after this hunk).

collate_fn=collate_with_metadata,
sampler=sampler)
else:
if use_imbalanced_sampler:
sampler: Optional[Sampler] = ImbalancedSampler(self)
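
A sketch of the alternative mentioned in the review comment above: instead of forcing num_workers=0, spread the configured data loader workers over the GPUs of a node. Only the ceiling-division formula comes from the comment; the surrounding helper and loader construction are illustrative (and DistributedSampler needs torch.distributed to be initialised before use):

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler


def workers_per_process(num_dataload_workers: int, n_gpus_per_node: int) -> int:
    # Ceiling division from the review comment: share the configured workers across the node's GPUs.
    return int((num_dataload_workers + n_gpus_per_node - 1) / n_gpus_per_node)


def make_distributed_loader(dataset: Dataset, batch_size: int, num_dataload_workers: int) -> DataLoader:
    n_gpus = max(1, torch.cuda.device_count())
    sampler = DistributedSampler(dataset)  # partitions the dataset across ranks
    return DataLoader(dataset,
                      batch_size=batch_size,
                      shuffle=False,  # the sampler already shuffles and shards per rank
                      num_workers=workers_per_process(num_dataload_workers, n_gpus),
                      sampler=sampler)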
19 changes: 17 additions & 2 deletions InnerEye/ML/deep_learning_config.py
@@ -181,7 +181,9 @@ class DeepLearningConfig(GenericConfig, CudaAwareConfig):
doc="The high-level model category described by this config.")
_model_name: str = param.String(None, doc="The human readable name of the model (for example, Liver). This is "
"usually set from the class name.")

use_distributed_data_parallel: bool = param.Boolean(False,
doc="If True, will attempt to train with "
"DistributedDataParallel")
random_seed: int = param.Integer(42, doc="The seed to use for all random number generators.")
azure_dataset_id: Optional[str] = param.String(None, allow_None=True,
doc="The ID of the dataset to use. This dataset must exist as a "
@@ -356,6 +358,9 @@ class DeepLearningConfig(GenericConfig, CudaAwareConfig):
"weight = alpha * (mean_teacher_weight) "
" + (1-alpha) * (current_student_weights). ")

dist_backend: str = param.String(default='nccl', doc="Communication package to use for distributed training")
init_method: str = param.String(default='env://', doc="URL specifying where to find peer processes")

def __init__(self, **params: Any) -> None:
self._model_name = type(self).__name__
# This should be annotated as torch.utils.data.Dataset, but we don't want to import torch here.
@@ -629,7 +634,17 @@ def use_data_parallel(self) -> bool:
:return:
"""
_devices = self.get_cuda_devices()
return _devices is not None and len(_devices) > 1
return _devices is not None and len(_devices) > 1 and not self.use_distributed_data_parallel

@property
def use_ddp(self) -> bool:
Contributor Author: I'd spell out distributed_data_parallel.

Contributor Author: Just from the names it is not clear what the difference between use_data_parallel and use_ddp is.

"""
Distributed data parallel may be used if GPUs are usable, the number of CUDA devices is greater than 1,
and the OS is not Windows
:return:
"""
_devices = self.get_cuda_devices()
return (_devices is not None) & (len(_devices) > 1) & (not is_windows()) & self.use_distributed_data_parallel
Contributor Author: This is almost the same code as in use_data_parallel, DRY? (See the refactor sketch after this file's diff.)


def write_args_file(self, root: Optional[Path] = None) -> None:
"""
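
One way to address the naming and DRY comments above would be to factor the shared device check out of the two properties. This is a hypothetical refactor sketched on a stand-in class, not part of the PR:

import platform

import torch


class DistributedFlags:
    # Stand-in for DeepLearningConfig, showing how the two properties could share one check.
    def __init__(self, use_distributed_data_parallel: bool) -> None:
        self.use_distributed_data_parallel = use_distributed_data_parallel

    def _has_multiple_cuda_devices(self) -> bool:
        return torch.cuda.is_available() and torch.cuda.device_count() > 1

    @property
    def use_data_parallel(self) -> bool:
        # Plain DataParallel: several GPUs, and DDP not requested.
        return self._has_multiple_cuda_devices() and not self.use_distributed_data_parallel

    @property
    def use_ddp(self) -> bool:
        # DistributedDataParallel: several GPUs, not Windows, and the flag is set.
        return (self._has_multiple_cuda_devices()
                and platform.system() != "Windows"
                and self.use_distributed_data_parallel)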
9 changes: 6 additions & 3 deletions InnerEye/ML/model_config_base.py
@@ -11,6 +11,7 @@
from azureml.train.estimator import Estimator
from azureml.train.hyperdrive import GridParameterSampling, HyperDriveConfig, PrimaryMetricGoal, choice
from pandas import DataFrame
from torch import device as torch_device

from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY
from InnerEye.ML.common import DATASET_CSV_FILE_NAME, ModelExecutionMode, STORED_CSV_FILE_NAMES, TrackedMetrics
@@ -76,7 +77,7 @@ def create_torch_datasets(self, dataset_splits: DatasetSplits) -> Dict[ModelExec

def create_and_set_torch_datasets(self, for_training: bool = True, for_inference: bool = True) -> None:
"""
Creats and sets torch datasets for training and validation, and stores them in the self._datasets_for_training
Creates and sets torch datasets for training and validation, and stores them in the self._datasets_for_training
field. Similarly, create torch datasets in the form required for model inference, for all of the
3 splits of the full data, and stored them in the self._datasets_for_training and/or
self._datasets_for_inference fields.
@@ -128,12 +129,14 @@ def create_data_loaders(self, max_repeats: Optional[int] = None) -> Dict[ModelEx
.as_data_loader(shuffle=self.shuffle,
use_imbalanced_sampler=self.use_imbalanced_sampler_for_training,
drop_last_batch=self.drop_last_batch_in_training,
max_repeats=self.get_total_number_of_training_epochs())
max_repeats=self.get_total_number_of_training_epochs(),
distribute=self.use_ddp)
logging.info("Creating the data loader for the validation set.")

val_loader = self._datasets_for_training[ModelExecutionMode.VAL].as_data_loader(
shuffle=False,
max_repeats=self.get_total_number_of_validation_epochs()
max_repeats=self.get_total_number_of_validation_epochs(),
distribute=self.use_ddp
)
logging.info("Finished creating the data loaders.")
return {
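
Passing distribute=self.use_ddp means each rank's loader iterates only its shard of the training and validation sets. A small self-contained illustration of the per-rank epoch length under DistributedSampler (explicit num_replicas and rank so it runs without a process group):

import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(100))
# With a world size of 4, each rank draws a disjoint (shuffled) quarter of the dataset per epoch.
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
print(len(sampler))  # 25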
77 changes: 65 additions & 12 deletions InnerEye/ML/model_training.py
@@ -8,6 +8,8 @@
from time import time
from typing import Optional, Tuple, TypeVar

import torch

from InnerEye.Azure.azure_util import RUN_CONTEXT
from InnerEye.Common.common_util import empty_string_to_none
from InnerEye.Common.metrics_dict import MetricsDict
@@ -23,6 +25,7 @@
from InnerEye.ML.scalar_config import ScalarModelBase
from InnerEye.ML.sequence_config import SequenceModelBase
from InnerEye.ML.utils import ml_util, model_util
from InnerEye.ML.utils.aml_distributed_utils import get_local_rank, get_global_rank, get_global_size
from InnerEye.ML.utils.config_util import ModelConfigLoader
from InnerEye.ML.utils.lr_scheduler import LRScheduler
from InnerEye.ML.utils.metrics_util import create_summary_writers
@@ -50,7 +53,8 @@ def load_checkpoint_from_model_and_info(run_recovery: Optional[RunRecovery], con
return result


def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = None) -> ModelTrainingResults:
def model_train(config: ModelConfigBase,
run_recovery: Optional[RunRecovery] = None) -> ModelTrainingResults:
"""
The main training loop. It creates the model, dataset, optimizer_type, and criterion, then proceeds
to train the model. If a checkpoint was specified, then it loads the checkpoint before resuming training.
@@ -68,11 +72,55 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N

logging.debug("Creating the pytorch model.")

# create model
model = create_model_with_temperature_scaling(config)

if config.use_ddp:

world_size = get_global_size(config.is_offline_run)
Contributor Author: There is little logic in get_global_size. Normally I'd always go for "put things into functions", but here it could be clearer to handle world size for offline/AzureML runs right here (expand get_global_size here, and pass it as an argument into train()).


if config.is_offline_run:
# set the environment variable for master node address
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
# spawn processes
torch.multiprocessing.spawn(train,
args=(model, config),
nprocs=world_size)
Contributor Author: I'm confused. Shouldn't offline runs (outside AzureML) be going through the same codepath as "do not use ddp"?

Member: No, this is for cases such as a machine with multiple GPUs on it (see the spawn sketch after this file's diff).


else:
# AzureML MPI configuration handles spawn
rank = get_global_rank()
train(None, model, config)
else:
single_process_rank = 0
train(single_process_rank, model, config)


def train(rank: Optional[int], model, config, run_recovery: Optional[RunRecovery] = None):
"""

:param rank: The global rank of the current process (for DistributedDataParallel). For single process, rank=0
:param model:
:param config:
:param run_recovery:
:return:
"""
rank = get_global_rank() if rank is None else rank # get rank for AML run
device = torch.device('cuda', rank) if torch.cuda.is_available() else torch.device('cpu')

if config.use_ddp:
print(f"Running distributed training on device with global rank {rank}")
torch.distributed.init_process_group( # type: ignore
backend=config.dist_backend,
init_method=config.init_method,
world_size=get_global_size(config.is_offline_run),
rank=rank)

# Create the train loader and validation loader to load images from the dataset
data_loaders = config.create_data_loaders()

# Create models, optimizers, and whether is_mean_teacher
model = create_model_with_temperature_scaling(config)
models_and_optimizers = [ModelAndInfo(model, model_util.create_optimizer(config, model),
model_execution_mode=ModelExecutionMode.TRAIN)]
if config.compute_mean_teacher_model:
@@ -88,21 +136,22 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
else:
logging.info("Models are saved at {}".format(config.checkpoint_folder))
if not os.path.isdir(config.checkpoint_folder):
os.makedirs(config.checkpoint_folder)
os.makedirs(config.checkpoint_folder, exist_ok=True)

# Print out a detailed breakdown of layers, memory consumption and time.
generate_and_print_model_summary(config, model)

# Enable mixed precision training and data parallelization (no-op if already done).
# This relies on the information generated in the model summary.

# We only want to do this if we didn't call load_checkpoint above, because attempting updating twice
# causes an error.
models_and_optimizers = [model_util.update_model_for_mixed_precision_and_parallel(model_and_info, config)
models_and_optimizers = [model_util.update_model_for_mixed_precision_and_parallel(model_and_info, config,
rank=rank)
for model_and_info in models_and_optimizers]

# Create the SummaryWriters for Tensorboard
writers = create_summary_writers(config)
writers = create_summary_writers(config, rank=rank)

config.create_dataframe_loggers()

model = models_and_optimizers[0].model
@@ -144,7 +193,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
dataframe_loggers=config.metrics_data_frame_loggers,
in_training_mode=True)
training_steps = create_model_training_steps(config, train_val_params)
train_epoch_results = train_or_validate_epoch(training_steps)
train_epoch_results = train_or_validate_epoch(training_steps, device)
train_results_per_epoch.append(train_epoch_results)

metrics.validate_and_store_model_parameters(writers.train, epoch, model)
@@ -157,7 +206,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
train_val_params.save_metrics = not (save_epoch and config.temperature_scaling_config)

training_steps = create_model_training_steps(config, train_val_params)
val_epoch_results = train_or_validate_epoch(training_steps)
val_epoch_results = train_or_validate_epoch(training_steps, device)
if train_val_params.save_metrics:
val_results_per_epoch.append(val_epoch_results)

@@ -166,7 +215,7 @@ def model_train(config: ModelConfigBase, run_recovery: Optional[RunRecovery] = N
train_epoch_results.metrics,
val_epoch_results.metrics)

if save_epoch:
if save_epoch and rank==0:
# perform temperature scaling if required
if isinstance(config, SequenceModelBase) and config.temperature_scaling_config:
optimal_temperature, scaled_val_results = \
@@ -238,7 +287,7 @@ def temperature_scaling_steps(config: SequenceModelBase,
return temperature_value, val_epoch_results


def train_or_validate_epoch(training_steps: ModelTrainingStepsBase) -> ModelOutputsAndMetricsForEpoch:
def train_or_validate_epoch(training_steps: ModelTrainingStepsBase, device) -> ModelOutputsAndMetricsForEpoch:
"""
Trains or validates the model for one epoch.
:param training_steps: Training pipeline to use.
@@ -280,7 +329,7 @@ def train_or_validate_epoch(training_steps: ModelTrainingStepsBase) -> ModelOutp
f"{MAX_LOAD_TIME_WARNINGS} times.")
num_load_time_warnings += 1
model_outputs_minibatch = training_steps.forward_and_backward_minibatch(
sample, batch_index, train_val_params.epoch)
sample, batch_index, train_val_params.epoch, device)
model_outputs_epoch.append(model_outputs_minibatch)
train_finish_time = time()
logging.debug(f"Epoch {train_val_params.epoch} {status_string} batch {batch_index}: "
@@ -337,7 +386,11 @@ def main() -> None:
parser.add_argument("--model", help="The name of the model to train", type=empty_string_to_none,
required=True)
args = parser.parse_args()
model_train(ModelConfigLoader().create_model_config_from_name(args.model))

model_config = ModelConfigLoader().create_model_config_from_name(args.model)

not_distributed_rank = 0
model_train(not_distributed_rank, model_config)


if __name__ == '__main__':
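
Putting the offline branch of model_train together: on a single multi-GPU machine the code spawns one process per GPU, and each process joins the process group, builds its model on its own device, and wraps it in DistributedDataParallel. A condensed sketch of that flow; the MASTER_ADDR/MASTER_PORT values match the diff, while the toy model and the bare epoch-loop placeholder are illustrative:

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train(rank: int, world_size: int) -> None:
    # Each spawned process initialises the NCCL group with its own rank, then
    # builds the model on its GPU and wraps it so gradients are all-reduced.
    dist.init_process_group(backend="nccl", init_method="env://",
                            world_size=world_size, rank=rank)
    device = torch.device("cuda", rank)
    model = torch.nn.Linear(8, 2).to(device)  # placeholder for create_model_with_temperature_scaling
    ddp_model = DDP(model, device_ids=[rank])
    # ... run the epoch loop with ddp_model here, saving checkpoints only when rank == 0 ...
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)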