From 66238edf5f5507ed5df8d1102c561c2bf62039dd Mon Sep 17 00:00:00 2001 From: Mikhail Kardash Date: Tue, 29 Oct 2024 13:29:56 -0700 Subject: [PATCH] deepspeed docs --- .../apis-howto/deepspeed/deepspeed.rst | 235 ++++++++++++++++++ .../training/api-deepspeed-reference.rst | 13 + examples/deepspeed/dcgan/README.md | 8 +- examples/deepspeed/dcgan/model.py | 18 +- examples/deepspeed/gpt_neox/det_utils.py | 33 +-- examples/deepspeed/gpt_neox/gpt2_trial.py | 61 +++-- examples/deepspeed/gpt_neox/trainer.py | 37 +++ examples/deepspeed/gpt_neox/zero1.yaml | 4 +- .../determined/pytorch/deepspeed/_trainer.py | 2 + 9 files changed, 350 insertions(+), 61 deletions(-) create mode 100644 examples/deepspeed/gpt_neox/trainer.py diff --git a/docs/model-dev-guide/api-guides/apis-howto/deepspeed/deepspeed.rst b/docs/model-dev-guide/api-guides/apis-howto/deepspeed/deepspeed.rst index 6c5f25bc8f9..93afd712c1b 100644 --- a/docs/model-dev-guide/api-guides/apis-howto/deepspeed/deepspeed.rst +++ b/docs/model-dev-guide/api-guides/apis-howto/deepspeed/deepspeed.rst @@ -365,6 +365,241 @@ profiling batches 3 and 4. rendering times for TensorBoard and memory issues. For long-running experiments, it is recommended to configure a profiling schedule. +******************* + DeepSpeed Trainer +******************* + +With the DeepSpeed Trainer API, you can implement and iterate on model training code locally before +running on cluster. When you are satisfied with your model code, you configure and submit the code +on cluster. + +The DeepSpeed Trainer API lets you do the following: + +- Work locally, iterating on your model code. +- Debug models in your favorite debug environment (e.g., directly on your machine, IDE, or Jupyter + notebook). +- Run training scripts without needing to use an experiment configuration file. +- Load previously saved checkpoints directly into your model. + +Initializing the Trainer +======================== + +After defining the PyTorch Trial, initialize the trial and the trainer. +:meth:`~determined.pytorch.deepspeed.init` returns a +:class:`~determined.pytorch.deepspeed.DeepSpeedTrialContext` for instantiating +:class:`~determined.pytorch.deepspeed.DeepSpeedTrial`. Initialize +:class:`~determined.pytorch.deepspeed.Trainer` with the trial and context. + +.. code:: python + + from determined.pytorch import deepspeed as det_ds + + def main(): + with det_ds.init() as train_context: + trial = MyTrial(train_context) + trainer = det_ds.Trainer(trial, train_context) + + if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + main() + +Training is configured with a call to :meth:`~determined.pytorch.deepspeed.Trainer.fit` with +training loop arguments, such as checkpointing periods, validation periods, and checkpointing +policy. + +.. code:: diff + + from determined import pytorch + from determined.pytorch import deepspeed as det_ds + + def main(): + with det_ds.init() as train_context: + trial = MyTrial(train_context) + trainer = det_ds.Trainer(trial, train_context) + + trainer.fit( + + max_length=pytorch.Epoch(10), + + checkpoint_period=pytorch.Batch(100), + + validation_period=pytorch.Batch(100), + + checkpoint_policy="all" + + ) + + + if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + main() + +Run Your Training Script Locally +================================ + +Run training scripts locally without submitting to a cluster or defining an experiment configuration +file. + +.. code:: python + + from determined import pytorch + from determined.pytorch import deepspeed as det_ds + + def main(): + with det_ds.init() as train_context: + trial = MyTrial(train_context) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit( + max_length=pytorch.Epoch(10), + checkpoint_period=pytorch.Batch(100), + validation_period=pytorch.Batch(100), + checkpoint_policy="all", + ) + + + if __name__ == "__main__": + # Configure logging + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + main() + +You can run this Python script directly (``python3 train.py``), or in a Jupyter notebook. This code +will train for ten epochs, and checkpoint and validate every 100 batches. + +Local Distributed Training +========================== + +Local training can utilize multiple GPUs on a single node with a few modifications to the above +code. + +.. code:: diff + + import deepspeed + + def main(): + + # Initialize distributed backend before det_ds.init() + + deepspeed.init_distributed() + + # Set flag used by internal PyTorch training loop + + os.environ["DET_MANUAL_INIT_DISTRIBUTED"] = "true" + + # Initialize DistributedContext + with det_ds.init( + + distributed=core.DistributedContext.from_deepspeed() + ) as train_context: + trial = MyTrial(train_context) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit( + max_length=pytorch.Epoch(10), + checkpoint_period=pytorch.Batch(100), + validation_period=pytorch.Batch(100), + checkpoint_policy="all" + ) + +This code can be directly invoked with your distributed backend's launcher: ``deepspeed --num_gpus=4 +trainer.py --deepspeed --deepspeed_config ds_config.json`` + +Test Mode +========= + +Trainer accepts a test_mode parameter which, if true, trains and validates your training code for +only one batch, checkpoints, then exits. This is helpful for debugging code or writing automated +tests around your model code. + +.. code:: diff + + trainer.fit( + max_length=pytorch.Epoch(10), + checkpoint_period=pytorch.Batch(100), + validation_period=pytorch.Batch(100), + + test_mode=True + ) + +Prepare Your Training Code for Deploying to a Determined Cluster +================================================================ + +Once you are satisfied with the results of training the model locally, you submit the code to a +cluster. This example allows for distributed training locally and on cluster without having to make +code changes. + +Example workflow of frequent iterations between local debugging and cluster deployment: + +.. code:: diff + + def main(): + + local = det.get_cluster_info() is None + + if local: + + # Local: configure local distributed training. + + deepspeed.init_distributed() + + # Set flag used by internal PyTorch training loop + + os.environ["DET_MANUAL_INIT_DISTRIBUTED"] = "true" + + distributed_context = core.DistributedContext.from_deepspeed() + + latest_checkpoint = None + + else: + + # On-cluster: Determined will automatically detect distributed context. + + distributed_context = None + + # On-cluster: configure the latest checkpoint for pause/resume training functionality. + + latest_checkpoint = det.get_cluster_info().latest_checkpoint + + + with det_ds.init( + + distributed=distributed_context + ) as train_context: + trial = DCGANTrial(train_context) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit( + max_length=pytorch.Epoch(11), + checkpoint_period=pytorch.Batch(100), + validation_period=pytorch.Batch(100), + + latest_checkpoint=latest_checkpoint, + ) + +To run Trainer API solely on-cluster, the code is much simpler: + +.. code:: python + + def main(): + with det_ds.init() as train_context: + trial_inst = gan_model.DCGANTrial(train_context) + trainer = det_ds.Trainer(trial_inst, train_context) + trainer.fit( + max_length=pytorch.Epoch(11), + checkpoint_period=pytorch.Batch(100), + validation_period=pytorch.Batch(100), + latest_checkpoint=det.get_cluster_info().latest_checkpoint, + ) + +Submit Your Trial for Training on Cluster +========================================= + +To run your experiment on cluster, you'll need to create an experiment configuration (YAML) file. +Your experiment configuration file must contain searcher configuration and entrypoint. + +.. code:: python + + name: dcgan_deepspeed_mnist + searcher: + name: single + metric: validation_loss + resources: + slots_per_trial: 2 + entrypoint: python3 -m determined.launch.deepspeed python3 train.py + +Submit the trial to the cluster: + +.. code:: bash + + det e create det.yaml . + +If your training code needs to read some values from the experiment configuration, +``pytorch.deepspeed.init()`` accepts an ``exp_conf`` argument which allows calling +``context.get_experiment_config()`` from ``DeepSpeedTrialContext``. + +Profiling +========= + +When training on cluster, you can enable the system metrics profiler by adding a parameter to your +``fit()`` call: + +.. code:: diff + + trainer.fit( + ..., + + profiling_enabled=True + ) + ***************************** Known DeepSpeed Constraints ***************************** diff --git a/docs/reference/training/api-deepspeed-reference.rst b/docs/reference/training/api-deepspeed-reference.rst index 0fa7fbe8f87..6d00a8253ec 100644 --- a/docs/reference/training/api-deepspeed-reference.rst +++ b/docs/reference/training/api-deepspeed-reference.rst @@ -48,3 +48,16 @@ documentation): - :ref:`determined.pytorch.samplers ` - :ref:`determined.pytorch.MetricReducer ` - :ref:`determined.pytorch.PyTorchCallback ` + +****************************************** + ``determined.pytorch.deepspeed.Trainer`` +****************************************** + +.. autoclass:: determined.pytorch.deepspeed.Trainer + :members: + +***************************************** + ``determined.pytorch.deepspeed.init()`` +***************************************** + +.. autofunction:: determined.pytorch.deepspeed.init diff --git a/examples/deepspeed/dcgan/README.md b/examples/deepspeed/dcgan/README.md index f0b9811b9c9..31481d432c3 100644 --- a/examples/deepspeed/dcgan/README.md +++ b/examples/deepspeed/dcgan/README.md @@ -25,10 +25,16 @@ After installing docker and pulling an image, users can launch a container via Install necessary dependencies via `pip install determined mpi4py` -Then, run the following command: +Then, run the following command if running on a single node and GPU: ``` python trainer.py ``` +For multiple nodes GPUs, use the following: +``` +deepspeed --num_nodes= --num_gpus= trainer.py --deepspeed --deepspeed_config ds_config.json +``` +Where `num_nodes` corresponds to the number of nodes on your local cluster and `num_gpus` corresponds to +the number of GPUs per node. Any additional configs can be specified in `mnist.yaml` and `ds_config.json` accordingly. diff --git a/examples/deepspeed/dcgan/model.py b/examples/deepspeed/dcgan/model.py index 8ceab93dc6a..00071e2dbc2 100644 --- a/examples/deepspeed/dcgan/model.py +++ b/examples/deepspeed/dcgan/model.py @@ -46,7 +46,7 @@ def __init__(self, context: det_ds.DeepSpeedTrialContext, self.discriminator = self.context.wrap_model_engine(discriminator) self.fixed_noise = self.context.to_device( torch.randn( - self.context.train_micro_batch_size_per_gpu, self.hparams["noise_length"], 1, 1 + self.context.get_train_micro_batch_size_per_gpu(), self.hparams["noise_length"], 1, 1 ) ) self.criterion = nn.BCELoss() @@ -62,7 +62,7 @@ def _get_noise(self, dtype: torch.dtype) -> torch.Tensor: torch.Tensor, self.context.to_device( torch.randn( - self.context.train_micro_batch_size_per_gpu, + self.context.get_train_micro_batch_size_per_gpu(), self.hparams["noise_length"], 1, 1, @@ -93,7 +93,7 @@ def train_batch( else: dtype = torch.float32 real_label, fake_label = self._get_label_constants( - self.context.train_micro_batch_size_per_gpu, dtype + self.context.get_train_micro_batch_size_per_gpu(), dtype ) ############################ # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) @@ -106,7 +106,7 @@ def train_batch( D_x = 0.0 D_G_z1 = 0.0 fake_sample_count = ( - self.context.train_micro_batch_size_per_gpu * self.gradient_accumulation_steps + self.context.get_train_micro_batch_size_per_gpu() * self.gradient_accumulation_steps ) for i in range(self.gradient_accumulation_steps): @@ -132,7 +132,7 @@ def train_batch( output = self.discriminator(fake.detach()) errD_fake = self.criterion(output, fake_label) self.discriminator.backward(errD_fake) - errD_fake_sum += errD_fake * self.context.train_micro_batch_size_per_gpu + errD_fake_sum += errD_fake * self.context.get_train_micro_batch_size_per_gpu() D_G_z1 += output.sum().item() # update self.discriminator.step() @@ -153,7 +153,7 @@ def train_batch( output = self.discriminator(fake) errG = self.criterion(output, real_label) # fake labels are real for generator cost self.generator.backward(errG) - errG_sum += errG * self.context._train_micro_batch_size_per_gpu + errG_sum += errG * self.context.get_train_micro_batch_size_per_gpu() D_G_z2_sum += output.sum().item() self.generator.step() @@ -188,7 +188,7 @@ def build_training_data_loader(self) -> Any: dataset = data.get_dataset(self.data_config) return DataLoader( dataset, - batch_size=self.context.train_micro_batch_size_per_gpu, + batch_size=self.context.get_train_micro_batch_size_per_gpu(), shuffle=True, num_workers=int(self.hparams["data_workers"]), ) @@ -200,9 +200,9 @@ def build_validation_data_loader(self) -> Any: dataset, list( range( - self.context.train_micro_batch_size_per_gpu + self.context.get_train_micro_batch_size_per_gpu() * self.context.distributed.get_size() ) ), ) - return DataLoader(dataset, batch_size=self.context.train_micro_batch_size_per_gpu) + return DataLoader(dataset, batch_size=self.context.get_train_micro_batch_size_per_gpu()) diff --git a/examples/deepspeed/gpt_neox/det_utils.py b/examples/deepspeed/gpt_neox/det_utils.py index 3a6eac44f1c..a6c8d251fd7 100644 --- a/examples/deepspeed/gpt_neox/det_utils.py +++ b/examples/deepspeed/gpt_neox/det_utils.py @@ -1,18 +1,17 @@ import logging import os +import attrdict import numpy as np -from attrdict import AttrMap -from eval_tasks.eval_adapter import run_eval_harness -from megatron.neox_arguments import NeoXArgs -from torch.utils.tensorboard import SummaryWriter +from eval_tasks import eval_adapter +from megatron import neox_arguments +from torch.utils import tensorboard -from determined.pytorch import MetricReducer, PyTorchCallback +from determined import pytorch -def get_neox_args(context): - args = AttrMap(context.get_hparams()) - exp_config = context.get_experiment_config() +def get_neox_args(exp_config: dict, hparams: dict, trial_seed: int): + args = attrdict.AttrMap(hparams) # Gather overrides. overwrite_values = args.pop("overwrite_values", {}) @@ -30,19 +29,21 @@ def get_neox_args(context): "checkpoint_factor": exp_config["min_validation_period"]["batches"], "eval_interval": exp_config["min_validation_period"]["batches"], "hostfile": os.environ.get("DET_DEEPSPEED_HOSTFILE_PATH"), - "seed": context.get_trial_seed(), + "seed": trial_seed, } ) for k, v in overwrite_values.items(): logging.info(f"Setting neox_args.{k} to {v}") # Build neox args. - neox_args = NeoXArgs.process_parsed_deepy_args(args, overwrite_values=overwrite_values) + neox_args = neox_arguments.NeoXArgs.process_parsed_deepy_args( + args, overwrite_values=overwrite_values + ) return neox_args -class TensorboardWriter(PyTorchCallback): - def __init__(self, writer: SummaryWriter): +class TensorboardWriter(pytorch.PyTorchCallback): + def __init__(self, writer: tensorboard.SummaryWriter): self.tb_writer = writer def on_validation_end(self, metrics): @@ -53,7 +54,7 @@ def trial_cleanup(self) -> None: self.tb_writer.close() -class EarlyStoppingCallback(PyTorchCallback): +class EarlyStoppingCallback(pytorch.PyTorchCallback): def __init__(self, trial): self.trial = trial @@ -62,7 +63,7 @@ def on_validation_start(self): self.trial.context.set_stop_requested(True) -class LMReducers(MetricReducer): +class LMReducers(pytorch.MetricReducer): def __init__(self, neox_args): self.char_level_ppl = neox_args.char_level_ppl self.token_count = 0 @@ -95,7 +96,7 @@ def cross_slot_reduce(self, per_slot_metrics): return metrics -class EvalHarness(PyTorchCallback): +class EvalHarness(pytorch.PyTorchCallback): def __init__(self, model, forward_step_fn, neox_args): self.model = model self.forward_step_fn = forward_step_fn @@ -104,7 +105,7 @@ def __init__(self, model, forward_step_fn, neox_args): def on_validation_end(self, metrics): # TODO: This hangs with pipeline parallel. metrics.update( - run_eval_harness( + eval_adapter.run_eval_harness( self.model, self.forward_step_fn, self.neox_args, diff --git a/examples/deepspeed/gpt_neox/gpt2_trial.py b/examples/deepspeed/gpt_neox/gpt2_trial.py index eb80221fb32..17e8f89d3c3 100644 --- a/examples/deepspeed/gpt_neox/gpt2_trial.py +++ b/examples/deepspeed/gpt_neox/gpt2_trial.py @@ -3,41 +3,36 @@ import traceback from datetime import datetime +import attrdict import deepspeed +import det_utils import megatron.training as megatron_train import megatron.utils as megatron_utils import torch -from attrdict import AttrMap -from det_utils import ( - EarlyStoppingCallback, - EvalHarness, - LMReducers, - TensorboardWriter, - get_neox_args, -) -from megatron import mpu -from megatron.checkpointing import load_checkpoint, save_checkpoint -from megatron.data.data_utils import build_datasets_from_neox_args +from megatron import checkpointing, mpu +from megatron.data import data_utils -from determined import LOG_FORMAT, InvalidHP -from determined.pytorch import DataLoader -from determined.pytorch.deepspeed import DeepSpeedTrial, DeepSpeedTrialContext, ModelParallelUnit +import determined as det +from determined import pytorch +from determined.pytorch import deepspeed as det_ds -logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) +logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) -class GPT2Trial(DeepSpeedTrial): - def __init__(self, context: DeepSpeedTrialContext) -> None: +class GPT2Trial(det_ds.DeepSpeedTrial): + def __init__( + self, context: det_ds.DeepSpeedTrialContext, hparams: dict, trial_seed: int + ) -> None: self.context = context - self.exp_config = self.context.get_experiment_config() - self.args = AttrMap(self.context.get_hparams()) + self.exp_config = context.get_experiment_config() + self.args = attrdict.AttrMap(hparams) # Initalize and get arguments, timers, and Tensorboard writer. try: - self.neox_args = get_neox_args(self.context) + self.neox_args = det_utils.get_neox_args(self.exp_config, hparams, trial_seed) except: traceback.print_exc() - raise InvalidHP("Could not parse neox_args.") + raise det.InvalidHP("Could not parse neox_args.") logging.info(self.neox_args) self.writer = self.context.get_tensorboard_writer() self.neox_args.tensorboard_writer = self.writer @@ -60,7 +55,7 @@ def __init__(self, context: DeepSpeedTrialContext) -> None: ) = megatron_train.setup_model_and_optimizer(neox_args=self.neox_args) self.model = self.context.wrap_model_engine(model) self.context.set_mpu( - ModelParallelUnit( + det_ds.ModelParallelUnit( mpu.get_data_parallel_rank(), mpu.get_data_parallel_world_size(), should_report_metrics=True, @@ -77,7 +72,7 @@ def __init__(self, context: DeepSpeedTrialContext) -> None: # For tracking. if not self.args.search_world_size: self.reducer = self.context.wrap_reducer( - LMReducers(self.neox_args), for_training=False, for_validation=True + det_utils.LMReducers(self.neox_args), for_training=False, for_validation=True ) self.report_memory_flag = True self.total_train_loss_dict = {} @@ -98,13 +93,13 @@ def should_build_data_loader(self): return mpu.get_model_parallel_rank() == 0 and pipe_load def build_callbacks(self): - callbacks = {"tb": TensorboardWriter(self.writer)} + callbacks = {"tb": det_utils.TensorboardWriter(self.writer)} if self.neox_args.eval_tasks: - callbacks["eval_tasks"] = EvalHarness( + callbacks["eval_tasks"] = det_utils.EvalHarness( self.model, megatron_train.forward_step, self.neox_args ) if self.args.search_world_size: - callbacks["early_stopping"] = EarlyStoppingCallback(self) + callbacks["early_stopping"] = det_utils.EarlyStoppingCallback(self) return callbacks def train_batch(self, data_iterator, epoch_idx, batch_idx): @@ -241,10 +236,10 @@ def build_training_data_loader(self): self.train_data, self.valid_data, self.test_data, - ) = build_datasets_from_neox_args(self.neox_args) + ) = data_utils.build_datasets_from_neox_args(self.neox_args) self.timers("train/valid/test data dataset").stop() self.timers.log(["train/valid/test data dataset"]) - return DataLoader( + return pytorch.DataLoader( self.train_data, batch_size=self.neox_args.train_micro_batch_size_per_gpu, shuffle=True, @@ -254,7 +249,7 @@ def build_training_data_loader(self): ) def build_validation_data_loader(self): - return DataLoader( + return pytorch.DataLoader( self.valid_data, batch_size=self.neox_args.train_micro_batch_size_per_gpu, num_workers=self.neox_args.num_workers, @@ -262,9 +257,9 @@ def build_validation_data_loader(self): pin_memory=False, ) - def save(self, context: DeepSpeedTrialContext, path: pathlib.Path) -> None: + def save(self, context: det_ds.DeepSpeedTrialContext, path: pathlib.Path) -> None: self.neox_args.save = str(path) - save_checkpoint( + checkpointing.save_checkpoint( neox_args=self.neox_args, iteration=self.neox_args.iteration, model=self.model, @@ -272,9 +267,9 @@ def save(self, context: DeepSpeedTrialContext, path: pathlib.Path) -> None: lr_scheduler=self.lr_scheduler, ) - def load(self, context: DeepSpeedTrialContext, path: pathlib.Path) -> None: + def load(self, context: det_ds.DeepSpeedTrialContext, path: pathlib.Path) -> None: self.neox_args.load = str(path) - self.neox_args.iteration = load_checkpoint( + self.neox_args.iteration = checkpointing.load_checkpoint( neox_args=self.neox_args, model=self.model, optimizer=self.optimizer, diff --git a/examples/deepspeed/gpt_neox/trainer.py b/examples/deepspeed/gpt_neox/trainer.py new file mode 100644 index 00000000000..ee330373860 --- /dev/null +++ b/examples/deepspeed/gpt_neox/trainer.py @@ -0,0 +1,37 @@ +import logging + +import gpt2_trial +import yaml + +import determined as det +from determined import pytorch +from determined.pytorch import deepspeed as det_ds + + +def main(config_file: str, local: bool = True): + info = det.get_cluster_info() + + if local: + # For convenience, use hparams from const.yaml for local mode. + with open(config_file, "r") as f: + experiment_config = yaml.load(f, Loader=yaml.SafeLoader) + hparams = experiment_config["hyperparameters"] + latest_checkpoint = None + else: + hparams = info.trial.hparams + latest_checkpoint = ( + info.latest_checkpoint + ) # (Optional) Configure checkpoint for pause/resume functionality. + + with det_ds.init() as train_context: + trial_seed = train_context.get_trial_seed() + trial = gpt2_trial.GPT2Trial(train_context, hparams, trial_seed) + trainer = det_ds.Trainer(trial, train_context) + trainer.fit(max_length=pytorch.Batch(200), latest_checkpoint=latest_checkpoint) + + +if __name__ == "__main__": + local = det.get_cluster_info() is None + # Configure logging + logging.basicConfig(level=logging.INFO, format=det.LOG_FORMAT) + main(config_file="zero1.yaml", local=local) diff --git a/examples/deepspeed/gpt_neox/zero1.yaml b/examples/deepspeed/gpt_neox/zero1.yaml index 2b8b264c498..3b4a048c47a 100644 --- a/examples/deepspeed/gpt_neox/zero1.yaml +++ b/examples/deepspeed/gpt_neox/zero1.yaml @@ -43,5 +43,5 @@ entrypoint: - python3 - -m - determined.launch.deepspeed - - --trial - - gpt2_trial:GPT2Trial + - python3 + - trainer.py diff --git a/harness/determined/pytorch/deepspeed/_trainer.py b/harness/determined/pytorch/deepspeed/_trainer.py index 8e36f345235..587a1b41999 100644 --- a/harness/determined/pytorch/deepspeed/_trainer.py +++ b/harness/determined/pytorch/deepspeed/_trainer.py @@ -65,7 +65,9 @@ def fit( max_length: The maximum number of steps to train for. This is a ``TrainUnit`` type (``Batch`` or ``Epoch``) which takes an ``int``. For example, ``Epoch(1)`` would train for a maximum length of one epoch. + .. note:: + If using an ASHA searcher, this value should match the searcher config values in the experiment config (i.e. ``Epoch(1)`` = `max_time: 1` and `time_metric: "epochs"`).