diff --git a/examples/asr/asr_adapters/train_asr_adapter.py b/examples/asr/asr_adapters/train_asr_adapter.py
index fb55ac18d24f..a71643760044 100644
--- a/examples/asr/asr_adapters/train_asr_adapter.py
+++ b/examples/asr/asr_adapters/train_asr_adapter.py
@@ -29,7 +29,7 @@
     model.train_ds.batch_size=16 \
     model.validation_ds.manifest_filepath= \
     model.validation_ds.batch_size=16 \
-    model.optim.lr=0.5 \
+    model.optim.lr=0.001 \
    model.optim.weight_decay=0.0 \
    model.optim.sched.warmup_steps=100 \
    trainer.max_steps=300 \
@@ -37,6 +37,26 @@
     trainer.precision=32 \
     exp_manager.exp_dir=
 
+# Hyper Parameter Search
+
+python train_asr_adapter.py \
+    --config-path="../conf/asr_adapters" \
+    --config-name="asr_adaptation_hp.yaml" \
+    -m \
+    model.pretrained_model=null \
+    model.nemo_model=null \
+    model.adapter.adapter_name= \
+    model.adapter.adapter_module_name= \
+    model.adapter.in_features= \
+    model.train_ds.manifest_filepath= \
+    model.train_ds.batch_size=16 \
+    model.validation_ds.manifest_filepath= \
+    model.validation_ds.batch_size=16 \
+    exp_manager.exp_dir="" \
+    exp_manager.create_wandb_logger=true \
+    exp_manager.wandb_logger_kwargs.project="" \
+    ++delete_ckpt_after_train=True
+
 # Fine-tune a model
 
 While adaptation is very efficient for low-resource datasets, it imposes several restrictions -
@@ -59,7 +79,7 @@
 https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/results.html
 
 """
-
+import glob
 import os
 from dataclasses import is_dataclass
 
@@ -211,6 +231,16 @@ def main(cfg):
     # Save the adapter modules in a seperate file
     model.save_adapters(str(state_path))
+    if 'delete_ckpt_after_train' in cfg:
+        delete_ckpt_after_train = cfg.delete_ckpt_after_train
+        if delete_ckpt_after_train:
+            logging.info("Deleting *.ckpt files after training to preserve storage space...")
+
+            ckpt_files = glob.glob(os.path.join(exp_log_dir, "checkpoints", "*.ckpt"))
+            for filepath in ckpt_files:
+                logging.info(f"Deleting file : {filepath}")
+                os.remove(filepath)
+
 
 
 if __name__ == '__main__':
     main()  # noqa pylint: disable=no-value-for-parameter
diff --git a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml
new file mode 100644
index 000000000000..3762e8ae71d8
--- /dev/null
+++ b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml
@@ -0,0 +1,226 @@
+# Config to perform ASR adaptation using any pre-trained model (local nemo model or pre-trained checkpoint).
+############################################################################################################
+# This config is special in that it is used alongside the scripts in the asr_adapters examples directory,
+# but does not directly construct a model itself. Instead it mimics the usual ASR model configs, and initializes
+# a pre-trained model (either local or via network), and overrides its many data loaders / optimizer / scheduler
+# and other arguments.
+#
+# **Note**: This config does *not* get stored in the adapted model, since this config is merely used to set up
+# the adapter training / inference script. This file can be considered a config not for the model, but for the
+# script that will adapt the model or infer an adapted model.
+#
+# You can therefore call this script multiple times to add as many adapters as you need in a single model,
+# by providing the previous adapted checkpoint as `model.nemo_model`.
+#
+# **Note**: Any config value in this yaml file *overrides* the equivalent config inside the model !
+#
+# There are some important parameters of this config that must be updated by the user :
+# - model.pretrained_model or model.nemo_model: str name or path to some pretrained model. Only one of the
+#     two should be passed. Selects the pre-trained model to be loaded and adapted.
+#
+# - model.adapter.adapter_name: Globally unique name, assigned to the adapter itself. Every adapter of a
+#     model must have a unique name.
+#
+# - model.adapter.in_features: The output dimension of each block of the model. This is model dependent.
+#     For example, Conformer dimension can be found via `model.encoder.d_model` in its config.
+#     For Citrinets/ContextNets, the dimension can usually be found in `model.encoder.jasper.0.filters`.
+#
+# - model.train_ds.manifest_filepath / model.validation_ds.manifest_filepath: Data filepaths to train the
+#     adapter module.
+############################################################################################################
+# The recommendations for training adapters are significantly different from general ASR training or
+# fine-tuning. Below are some recommended configuration values.
+#
+# - model.adapter.dim: Usually we choose a small bottleneck dim here. 16 to 32 is generally enough.
+#
+# - model.optim.lr: We generally choose a very small LR, and a very short training schedule of just a few hundred
+#     steps - depending on the size of the dataset. Usually just a few epochs over the dataset with a low LR is
+#     sufficient for adaptation.
+#
+# - model.optim.weight_decay: We find that strong weight decay prevents significant degradation of prior training,
+#     but also limits the capacity of the model to learn the adapted domain. Usually we use 0.0 as a baseline.
+#
+# - model.optim.sched.warmup_steps: We encourage warmup steps to be modified to suit the smaller training schedule.
+#
+# - trainer.max_steps: We recommend using trainer.max_steps to limit the training duration to just 10-20 epochs.
+#     Adapters converge very fast, and prolonged training may cause overfitting to the new domain and, consequently,
+#     catastrophic forgetting of the old domain. You can equivalently use a small number of epochs via
+#     trainer.max_epochs.
+#
+# - trainer.check_val_every_n_epoch: Since the training run is short and usually very fast, it is recommended to
+#     reduce the amount of validation to once every few epochs, rather than after every epoch, to speed up training.
+
+name: "ASR-Adapter-hp"
+
+model:
+  # One of the below two values must be set !
+  pretrained_model: null # name of a pretrained model
+  nemo_model: null # path to an ASR model file (.nemo)
+
+  log_prediction: false # enables logging sample predictions in the output during training
+
+  adapter:
+    # Config of the adapter training/eval script.
+    adapter_name: ??? # Name of the adapter, used by the script
+    adapter_module_name: null # Name of the adapter module. Combine multiple modules with '+' between module names.
+    adapter_state_dict_name: "adapters.pt" # If the individual adapters must be saved, a file name can be provided here. null disables this.
+
+    # Config of the adapter module itself
+    _target_: nemo.collections.common.parts.adapter_modules.LinearAdapter
+    in_features: ??? # User must provide the output dimension of the layers of the model, which is the input dimension of this adapter.
+    dim: 32 # The hidden dimension of the adapter, as chosen by the user, but small values are preferred to reduce param count.
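+    # A rough parameter-count sketch (illustrative values, not from the original config): with the default
+    # LinearAdapter layout of LayerNorm + Linear(in_features, dim) + activation + Linear(dim, in_features),
+    # one adapter adds roughly 2 * in_features * dim + dim + 3 * in_features weights per adapted block.
+    # For example, assuming in_features=512 and dim=32, that is about 2*512*32 + 32 + 3*512 = 34,336 parameters.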
+    activation: swish
+    norm_position: 'pre' # Can be `pre` or `post`
+    dropout: 0.0 # float, dropout for the adapter
+
+    # Adapter strategy config
+    adapter_strategy:
+      _target_: nemo.core.classes.mixins.adapter_mixin_strategies.ResidualAddAdapterStrategy
+      stochastic_depth: 0.0 # float, setting to > 0 will enable stochastic depth for each adapter block.
+      l2_lambda: 0.0 # float, setting to > 0 will enable l2 norm auxiliary loss for each adapter's output.
+
+    # Optional global config available to all adapters at a global level.
+    # A global config is shared across every layer of the adapters, defining global properties rather
+    # than properties local to the adapter (as defined above).
+    # This can be useful in order to select *which type of adapter* is added, *what adapters to enable*,
+    # and further global operations that can decide dynamically how to support the requested adapter.
+    global_cfg:
+      check_encoder_adapter: True # ASR adapter key, determines whether to check if encoder adapter modules are supported
+      check_decoder_adapter: True # ASR adapter key, determines whether to check if decoder adapter modules are supported
+      check_joint_adapter: True # ASR adapter key, determines whether to check if joint adapter modules are supported
+
+  # Overrides the model's internal spec augment configuration
+  spec_augment:
+    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
+    freq_masks: 0
+    time_masks: 0
+    freq_width: 27
+    time_width: 0.05
+
+  train_ds:
+    # train dataset + dataloader config
+    # sample_rate will be merged with model config
+    # use_start_end_token will be merged with model config
+    # trim_silence will be merged with model config
+    # max_duration will be merged with model config
+    # min_duration will be merged with model config
+    manifest_filepath: ???
+    batch_size: 16 # you may increase batch_size if your memory allows
+    shuffle: true
+    num_workers: 8
+    pin_memory: true
+    # tarred datasets
+    is_tarred: false
+    tarred_audio_filepaths: null
+    shuffle_n: 2048
+    # bucketing params
+    bucketing_strategy: "synced_randomized"
+    bucketing_batch_size: null
+
+  validation_ds:
+    # sample_rate will be merged with model config
+    # use_start_end_token will be merged with model config
+    manifest_filepath: ???
+    batch_size: 16
+    shuffle: false
+    num_workers: 8
+    pin_memory: true
+
+  test_ds:
+    # sample_rate will be merged with model config
+    # use_start_end_token will be merged with model config
+    manifest_filepath: null
+    batch_size: 16
+    shuffle: false
+    num_workers: 8
+    pin_memory: true
+
+  optim:
+    # optimizer arguments
+    name: adamw
+    betas: [0.9, 0.98]
+    lr: 0.001 # LR depends on the scheduler used by the base model. Noam prefers 0.5, Cosine Annealing prefers 0.02
+    weight_decay: 0 # During adaptation, since the training run is short, WD is not required. Can be set if needed.
+
+    # scheduler setup
+    sched:
+      name: CosineAnnealing
+
+      # scheduler config override
+      warmup_steps: null # Warmup steps should be set, and smaller than the trainer.max_steps set below.
+      warmup_ratio: 0.1 # Warmup steps will be 10% of the training steps.
+      min_lr: 1e-5
+      last_epoch: -1
+
+trainer:
+  devices: -1 # number of GPUs, -1 would use all available GPUs
+  num_nodes: 1
+  max_epochs: 1
+  max_steps: -1 # computed at runtime if not set
+  val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
+  accelerator: auto
+  strategy: ddp
+  accumulate_grad_batches: 1
+  gradient_clip_val: null
+  precision: 32 # Should be set to 16 for O1 and O2 to enable AMP.
+  log_every_n_steps: 10  # Interval of logging.
+  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
+  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before starting training, setting to 0 disables it
+  check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
+  sync_batchnorm: true
+  enable_checkpointing: False # Provided by exp_manager
+  logger: false # Provided by exp_manager
+  benchmark: false # needs to be false for models with variable-length speech input as it slows down training
+
+
+exp_manager:
+  exp_dir: null
+
+  # Add a unique name for all hyperparameter arguments to allow continued training.
+  # NOTE: It is necessary to add all hyperparameter arguments to the name !
+  # This ensures successful restoration of model runs in case HP search crashes.
+  name: ${name}-lr-${model.optim.lr}-adim-${model.adapter.dim}-sd-${model.adapter.adapter_strategy.stochastic_depth}
+
+  create_tensorboard_logger: true
+  create_checkpoint_callback: true
+  checkpoint_callback_params:
+    # in case of multiple validation sets, first one is used
+    monitor: "val_wer"
+    mode: "min"
+    save_top_k: 1
+    always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints
+
+  create_wandb_logger: false
+  wandb_logger_kwargs:
+    name: ${exp_manager.name}
+    project: null
+    entity: null
+    save_dir: null
+    offline: false # If true, wandb logging will be done offline and would require manual syncing.
+    tags: null # List of tags to assign to the run
+
+  # HP Search may crash due to various reasons, best to attempt continuation in order to
+  # resume from where the last failure case occurred.
+  resume_if_exists: true
+  resume_ignore_no_checkpoint: true
+
+# Required for Hydra launch of hyperparameter search
+defaults:
+  - override hydra/launcher: nemo_launcher
+
+# Hydra arguments necessary for hyperparameter optimization
+hydra:
+  sweep:
+    dir: "."
+    subdir: "."
+
+  sweeper:
+    params: # place all the parameters you wish to search over here (corresponding to the rest of the config)
+      model.optim.lr: 0.001,0.0001
+      model.adapter.dim: 32,64,96,128
+      model.adapter.adapter_strategy.stochastic_depth: 0.0,0.5,0.6,0.7,0.8,0.9
+
+  # Arguments to the hyperparameter runner
+  launcher:
+    num_gpus: -1 # Number of gpus to use. Each run works on a single GPU.
+    jobs_per_gpu: 1 # If each GPU has large memory, you can run multiple jobs on the same GPU for faster results (until OOM).
diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py
index e4a798dd5fae..6631de153036 100644
--- a/nemo/collections/asr/modules/conv_asr.py
+++ b/nemo/collections/asr/modules/conv_asr.py
@@ -934,7 +934,7 @@ class JasperEncoderConfig:
 @dataclass
 class ConvASREncoderConfig:
     _target_: str = 'nemo.collections.asr.modules.ConvASREncoder'
-    jasper: Optional[JasperEncoderConfig] = field(default_factory=list)
+    jasper: Optional[List[JasperEncoderConfig]] = field(default_factory=list)
     activation: str = MISSING
     feat_in: int = MISSING
     normalization_mode: str = "batch"
diff --git a/nemo/collections/common/parts/adapter_modules.py b/nemo/collections/common/parts/adapter_modules.py
index 7e985ee33b97..ad80b84c3f02 100644
--- a/nemo/collections/common/parts/adapter_modules.py
+++ b/nemo/collections/common/parts/adapter_modules.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass, is_dataclass
-from typing import Optional
+from typing import Any, Optional
 
 from hydra.utils import instantiate
 from omegaconf import OmegaConf
@@ -57,6 +57,7 @@ def setup_adapter_strategy(self, adapter_strategy: Optional[adapter_mixin_strateg
 
 
 class LinearAdapter(AbstractAdapterModule):
+
     """
     Simple Linear Feedforward Adapter module with LayerNorm and singe hidden layer with activation function.
     Note: The adapter explicitly initializes its final layer with all zeros in order to avoid affecting the
@@ -66,7 +67,7 @@ class LinearAdapter(AbstractAdapterModule):
         in_features: Input dimension of the module. Note that for adapters, input_dim == output_dim.
         dim: Hidden dimension of the feed forward network.
         activation: Str name for an activation function.
-        norm_position: Str, can be `pre` or `post`. Defaults to `post`. Determines whether the normalization
+        norm_position: Str, can be `pre` or `post`. Defaults to `pre`. Determines whether the normalization
             will occur in the first layer or the last layer. Certain architectures may prefer one over the other.
         dropout: float value, whether to perform dropout on the output of the last layer of the adapter.
         adapter_strategy: By default, ResidualAddAdapterStrategyConfig. An adapter composition function object.
@@ -77,7 +78,7 @@ def __init__(
         in_features: int,
         dim: int,
         activation: str = 'swish',
-        norm_position: str = "post",
+        norm_position: str = 'pre',
         dropout: float = 0.0,
         adapter_strategy: adapter_mixin_strategies.ResidualAddAdapterStrategyConfig = None,
     ):
@@ -142,7 +143,7 @@ class LinearAdapterConfig:
     in_features: int
     dim: int
     activation: str = 'swish'
-    norm_position: str = 'post'
+    norm_position: str = 'pre'
     dropout: float = 0.0
-    adapter_strategy: Optional[dict] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
+    adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
     _target_: str = "{0}.{1}".format(LinearAdapter.__module__, LinearAdapter.__name__)
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
index f83d12a8667c..b8cbb3ef0c24 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
@@ -17,7 +17,7 @@
 import enum
 import logging
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 import torch.nn as nn
@@ -76,7 +76,7 @@ class MLPInfusedAdapter(InfusedAdapter):
 @dataclass
 class InfusedAdapterConfig:
     in_features: int
-    adapter_strategy: Optional[dict] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
+    adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
     _target_: str = "{0}.{1}".format(InfusedAdapter.__module__, InfusedAdapter.__name__)
 
 
@@ -167,5 +167,5 @@ class ParallelLinearAdapterConfig:
     column_init_method: str = 'xavier'
     row_init_method: str = 'zero'
     dropout: float = 0.0
-    adapter_strategy: Optional[dict] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
+    adapter_strategy: Optional[Any] = adapter_mixin_strategies.ResidualAddAdapterStrategyConfig()
     _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)
diff --git a/nemo/core/config/hydra_runner.py b/nemo/core/config/hydra_runner.py
index 382a346eed81..41d4557d6f36 100644
--- a/nemo/core/config/hydra_runner.py
+++ b/nemo/core/config/hydra_runner.py
@@ -91,18 +91,19 @@ def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
 
             # Wrap a callable object with name `parse_args`
             # This is to mimic the ArgParser.parse_args() API.
-            class _argparse_wrapper:
-                def __init__(self, arg_parser):
-                    self.arg_parser = arg_parser
-                    self._actions = arg_parser._actions
+            def parse_args(self, args=None, namespace=None):
+                return parsed_args
 
-                def parse_args(self, args=None, namespace=None):
-                    return parsed_args
+            parsed_args.parse_args = parse_args
 
             # no return value from run_hydra() as it may sometime actually run the task_function
             # multiple times (--multirun)
+            # argparse_wrapper = _argparse_wrapper(args)
+            argparse_wrapper = parsed_args
+
             _run_hydra(
-                args_parser=_argparse_wrapper(args),
+                args=argparse_wrapper,
+                args_parser=args,
                 task_function=task_function,
                 config_path=config_path,
                 config_name=config_name,
diff --git a/nemo/core/utils/__init__.py b/nemo/core/utils/__init__.py
index aca379342d14..6c68bb720e4a 100644
--- a/nemo/core/utils/__init__.py
+++ b/nemo/core/utils/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from nemo.core.utils import k2_utils, numba_utils
+from nemo.core.utils import k2_utils, numba_utils, process_launcher
diff --git a/nemo/core/utils/process_launcher/__init__.py b/nemo/core/utils/process_launcher/__init__.py
new file mode 100644
index 000000000000..b189cea498a7
--- /dev/null
+++ b/nemo/core/utils/process_launcher/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nemo.core.utils.process_launcher.launcher import ProcessLauncher, ProcessLauncherConfig
diff --git a/nemo/core/utils/process_launcher/launcher.py b/nemo/core/utils/process_launcher/launcher.py
new file mode 100644
index 000000000000..54d0cbd211ef
--- /dev/null
+++ b/nemo/core/utils/process_launcher/launcher.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
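+#
+# Usage sketch (illustrative, relying only on the ConfigStore registration at the bottom of this file):
+# a config selects this launcher for multirun sweeps via
+#
+#   defaults:
+#     - override hydra/launcher: nemo_launcher
+#
+# and the training script is then invoked with the `-m` (multirun) flag, as done by
+# examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml together with
+# examples/asr/asr_adapters/train_asr_adapter.py. Each job in the sweep is executed as a separate
+# subprocess pinned to a single GPU (see execute_job below).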
+
+import copy
+import hashlib
+import os
+import subprocess
+import sys
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence
+
+import torch
+from hydra.core.config_store import ConfigStore
+from hydra.core.hydra_config import HydraConfig
+from hydra.core.plugins import Plugins
+from hydra.core.singleton import Singleton
+from hydra.core.utils import JobReturn, JobStatus, configure_log, filter_overrides, setup_globals
+from hydra.plugins.launcher import Launcher
+from hydra.types import HydraContext, TaskFunction
+from omegaconf import DictConfig, OmegaConf, open_dict
+
+from nemo.utils import logging
+
+
+# monkey-patch hydra func
+def is_in_toplevel_plugins_module(*args, **kwargs) -> bool:
+    return True
+
+
+# Monkey-patch Hydra
+Plugins.instance().is_in_toplevel_plugins_module = is_in_toplevel_plugins_module
+
+
+@dataclass
+class ProcessLauncherConfig:
+    _target_: str = "nemo.core.utils.process_launcher.launcher.ProcessLauncher"
+    num_gpus: int = -1
+    jobs_per_gpu: int = 1
+
+
+def execute_job(
+    idx: int,
+    overrides: Sequence[str],
+    hydra_context: HydraContext,
+    config: DictConfig,
+    singleton_state: Dict[Any, Any],
+    gpu_idx: int,
+):
+    """
+    Creates a process that launches a "single" job that is identical in config + updated with sweep hyperparams.
+    Since a different process is being used, CUDA can work in non-ddp mode without issue.
+    Attempting ddp when using this script will not work as ddp cannot be used in shared contexts.
+
+    Args:
+        idx: Global index of the job.
+        overrides: List of str overrides that correspond to this job
+        hydra_context: Hydra Context used to load the sweep params into the global config
+        config: Global config that will be updated with sweep hyper parameters.
+        singleton_state: Hydra state.
+        gpu_idx: The GPU ID on which this process will be run.
+
+    Returns:
+        - The Process object that corresponds to this sweep
+        - The JobReturn object holding some metadata about this run
+    """
+    # Required by Hydra (lookup other Hydra Launchers for details)
+    setup_globals()
+    Singleton.set_state(singleton_state)
+
+    # Update base config with overrides to create sweep config
+    sweep_config = hydra_context.config_loader.load_sweep_config(config, list(overrides))
+    with open_dict(sweep_config):
+        sweep_config.hydra.job.id = "{}_{}".format(sweep_config.hydra.job.name, idx)
+        sweep_config.hydra.job.num = idx
+    HydraConfig.instance().set_config(sweep_config)
+
+    # Setup a directory where the config will temporarily be stored.
+    script_path = os.path.join(os.getcwd(), sys.argv[0])
+    script_path = os.path.abspath(script_path)
+
+    hash_salt = "|".join([script_path, str(OmegaConf.to_yaml(config))]).encode('utf-8')
+    hash_val = hashlib.sha256(hash_salt).hexdigest()
+
+    config_dir = os.path.join(os.getcwd(), "hydra_cfg", str(hash_val))
+    if not os.path.exists(config_dir):
+        os.makedirs(config_dir, exist_ok=True)
+
+    task_cfg = copy.deepcopy(sweep_config)
+
+    # Remove hydra from sweep config
+    # This is done to prevent recursive call to multirun launcher in Hydra.
+    with open_dict(task_cfg):
+        task_cfg.pop('hydra', '')
+
+    # Save the current job's config to the directory
+    temp_config_name = f"config_{idx}.yaml"
+    temp_config = os.path.join(config_dir, temp_config_name)
+    OmegaConf.save(task_cfg, temp_config)
+
+    # Compute the overrides as a dict
+    overrides = OmegaConf.to_container(config.hydra.overrides.task)
+
+    # Check and replace trainer.devices in config with gpu_idx
+    found_devices = False
+    gpu_override = f'trainer.devices=[{gpu_idx}]'
+    for oidx, val in enumerate(overrides):
+        if 'trainer.devices' in val:
+            overrides[oidx] = gpu_override
+            found_devices = True
+
+    if not found_devices:
+        overrides.append(gpu_override)
+
+    # Build launch command
+    # Note: We depend on PTL doing the right thing since this command has global visibility of all CUDA_VISIBLE_DEVICES
+    cmd = [
+        'python',
+        script_path,
+        "--config-path",
+        config_dir,
+        "--config-name",
+        temp_config_name,
+        *overrides,
+    ]
+
+    # Launch the subprocess; pipe the stderr
+    # NOTE: If this hangs due to some reason after prolonged training, it means that the stderr pipe buffer
+    # has become full at the OS level, and we need to explicitly empty it (either in a parallel thread or by manually
+    # calling proc.communicate()). It should not happen in the general case, as stderr is filled only when retcode != 0.
+    # If it does happen though, implement the code here:
+    # https://stackoverflow.com/questions/39607172/python-subprocess-popen-poll-seems-to-hang-but-communicate-works
+    proc = subprocess.Popen(cmd, stderr=subprocess.PIPE)
+
+    # Construct JobReturn object for Hydra
+    res = JobReturn()
+    res.cfg = task_cfg
+    res.overrides = overrides
+    res.hydra_cfg = config
+    res.working_dir = os.getcwd()
+    res.return_value = None
+
+    return proc, res
+
+
+def launch(launcher, job_overrides: Sequence[Sequence[str]], initial_job_idx: int,) -> Sequence[JobReturn]:
+    """
+    Args:
+        launcher: Reference to the Launcher subclass
+        job_overrides: A List of Lists, where each inner list contains the arguments for one job run
+        initial_job_idx: Initial job idx in batch
+
+    Returns:
+        A list of JobReturn objects.
+    """
+    # Needed for Hydra, look up JoblibLauncher in Hydra
+    setup_globals()
+    assert launcher.config is not None
+    assert launcher.task_function is not None
+    assert launcher.hydra_context is not None
+
+    configure_log(launcher.config.hydra.hydra_logging, launcher.config.hydra.verbose)
+    sweep_dir = Path(str(launcher.config.hydra.sweep.dir))
+    sweep_dir.mkdir(parents=True, exist_ok=True)
+
+    # Extract the runner's config (it's actually a DictConfig, but the type is used for autocomplete)
+    runner_cfg = launcher.runner  # type: ProcessLauncherConfig
+
+    logging.info(
+        "ProcessLauncher({}) is launching {} jobs".format(
+            ",".join([f"{k}={v}" for k, v in runner_cfg.items()]), len(job_overrides),
+        )
+    )
+    logging.info("Launching jobs, sweep output dir : {}".format(sweep_dir))
+    for idx, overrides in enumerate(job_overrides):
+        logging.info("\t#{} : {}".format(idx, " ".join(filter_overrides(overrides))))
+
+    # Needed by Hydra
+    singleton_state = Singleton.get_state()
+
+    # Process the runner's config to build up the multiplex config
+    num_gpus = runner_cfg.get('num_gpus', -1)
+    jobs_per_gpu = runner_cfg.get('jobs_per_gpu', 1)
+
+    # Only GPUs are supported for now.
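+    # Illustrative example of the multiplexing below (values assumed, not taken from any config): with
+    # num_gpus=4 and jobs_per_gpu=2, up to 4 * 2 = 8 subprocesses run concurrently, and job i is pinned
+    # to GPU (i % num_gpus), so jobs 0..7 would be placed on GPUs 0, 1, 2, 3, 0, 1, 2, 3.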
+    if num_gpus <= 0:
+        if torch.cuda.is_available():
+            num_gpus = torch.cuda.device_count()
+        else:
+            raise ValueError(f"{launcher.__class__.__name__} only supports GPU operations.")
+
+    # Setup arguments for multiplex runner
+    overrides = list(job_overrides)
+    num_overrides = len(overrides)
+
+    job_idx = 0
+    batch_size = num_gpus * jobs_per_gpu
+    gpu_idx = 0
+
+    ret = []  # List of returned JobResult
+    subprocess_list = []  # Buffer of subprocesses
+    results = []  # Buffer of JobResult
+
+    # Run over all job combinations
+    while job_idx < num_overrides:
+        # Fill up subprocess buffer while its size is smaller than multiplex batch size
+        while len(subprocess_list) < batch_size:
+            # If we run out of jobs, stop trying to submit more jobs
+            if job_idx >= num_overrides:
+                break
+
+            # Submit a job as a new process
+            process, res = execute_job(
+                initial_job_idx + job_idx,
+                overrides[job_idx],
+                launcher.hydra_context,
+                launcher.config,
+                singleton_state,
+                gpu_idx % num_gpus,  # This will evenly distribute GPU load
+            )
+
+            # Store the subprocesses and JobResults
+            subprocess_list.append(process)
+            results.append(res)
+
+            job_idx += 1
+            gpu_idx += 1
+
+        # Poll for the jobs in the batch to finish.
+        if len(subprocess_list) > 0:
+            finished_processes = [0] * len(subprocess_list)
+
+            # Check if all processes are completed or not
+            # TODO: This is busy waiting, need to check if it's really needed or we can do a one-time communicate()
+            while sum(finished_processes) < len(subprocess_list):
+                # Check all processes to make sure they have a retcode (doesn't matter yet if 0 or not)
+                for proc_idx, proc in enumerate(subprocess_list):
+                    # poll() is a cheaper op than communicate()
+                    retcode = proc.poll()
+
+                    if retcode is not None:
+                        # Log that the process with some ID has finished
+                        if finished_processes[proc_idx] == 0:
+                            logging.info(f"Processed job : {len(ret) + proc_idx}")
+
+                        finished_processes[proc_idx] = 1
+
+                time.sleep(1.0)
+
+            # Process all the subprocess results
+            for proc, res in zip(subprocess_list, results):
+                # Wait until completion of process
+                output, error = proc.communicate()
+
+                # 0 is for successful run
+                if proc.returncode == 0:
+                    res.status = JobStatus.COMPLETED
+                else:
+                    # > 0 is for error, log the error.
+                    # Note: While we log the error and raise an exception, the exception is raised only
+                    # for the first failed job among all the jobs. If multiple jobs fail, every job is
+                    # still executed before the error for the first failure is raised.
+                    # This is done so that even if some jobs fail (say OOM or something),
+                    # other jobs can still run.
+                    error_msg = (
+                        f"\nHyperparameter Arguments : {proc.args}\n"
+                        f"Process Return code : {proc.returncode}\n"
+                        f"Error Trace :\n"
+                        f"{str(error, encoding='utf-8').encode('utf-8').decode('utf-8')}"
+                    )
+                    res.return_value = Exception(error_msg)
+                    res.status = JobStatus.FAILED
+
+                logging.info(f"Finished executing job : {len(ret)}. Return Code = {proc.returncode}")
+                ret.append(res)
+
+            # Reset for next batch
+            subprocess_list.clear()
+            results.clear()
+
+    return ret
+
+
+class ProcessLauncher(Launcher):
+    def __init__(self, **kwargs: Any) -> None:
+        """Process Launcher
+        Based on the JoblibLauncher, but uses processes to scatter jobs in a multiplexed manner across
+        some number of GPUs on a single machine.
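+
+        A minimal config sketch (illustrative, based on the ConfigStore registration at the bottom of
+        this file) showing how a config can select this launcher and set its options:
+
+            defaults:
+              - override hydra/launcher: nemo_launcher
+
+            hydra:
+              launcher:
+                num_gpus: -1     # -1 resolves to all available GPUs
+                jobs_per_gpu: 1  # increase to run multiple jobs per GPU if memory allows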
+        """
+        self.config: Optional[DictConfig] = None
+        self.task_function: Optional[TaskFunction] = None
+        self.hydra_context: Optional[HydraContext] = None
+
+        self.runner = kwargs  # type: ProcessLauncherConfig
+
+    def setup(self, *, hydra_context: HydraContext, task_function: TaskFunction, config: DictConfig,) -> None:
+        self.config = config
+        self.task_function = task_function
+        self.hydra_context = hydra_context
+
+    def launch(self, job_overrides: Sequence[Sequence[str]], initial_job_idx: int) -> Sequence[JobReturn]:
+
+        return launch(launcher=self, job_overrides=job_overrides, initial_job_idx=initial_job_idx)
+
+
+ConfigStore.instance().store(
+    group="hydra/launcher", name="nemo_launcher", node=ProcessLauncherConfig, provider="nemo_process_launcher",
+)
+
+Plugins.instance().register(ProcessLauncher)
diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt
index b883f847b51f..9bc7736e13d2 100644
--- a/requirements/requirements_lightning.txt
+++ b/requirements/requirements_lightning.txt
@@ -2,7 +2,7 @@ pytorch-lightning>=1.7.0,<1.8
 torchmetrics>=0.4.1rc0
 transformers>=4.0.1,<=4.21.2
 webdataset>=0.1.48,<=0.1.62
-omegaconf>=2.1.2,<2.2
-hydra-core>=1.1.0,<1.2
+omegaconf>=2.2,<2.3
+hydra-core>=1.2.0,<1.3
 pyyaml<6 # Pinned until omegaconf works with pyyaml>=6
 wandb