diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM
index 0bda578b7f..bc8c4f3562 160000
--- a/3rdparty/Megatron-LM
+++ b/3rdparty/Megatron-LM
@@ -1 +1 @@
-Subproject commit 0bda578b7f96bdf0958ca38764441cf220f04105
+Subproject commit bc8c4f356240ea4ccadce426251171e6e430c9d3
diff --git a/3rdparty/NeMo b/3rdparty/NeMo
index e35a6592f5..0c8af28c92 160000
--- a/3rdparty/NeMo
+++ b/3rdparty/NeMo
@@ -1 +1 @@
-Subproject commit e35a6592f53ee34b1ec2fc3f1e009dd1ebc79e65
+Subproject commit 0c8af28c92b1acc21cce534188c6d215702808de
diff --git a/Dockerfile b/Dockerfile
index 2a74224ddf..d4c55f0d66 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,7 +17,7 @@ RUN git clone https://github.com/NVIDIA/apex.git && \
     --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm"

 # Transformer Engine pre-1.7.0. 1.7 standardizes the meaning of bits in the attention mask to match
-ARG TE_COMMIT=7d576ed25266a17a7b651f2c12e8498f67e0baea
+ARG TE_COMMIT=c27ee60ec746210bcea4ec33958dbbff06706506
 RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \
     cd TransformerEngine && \
     git fetch origin ${TE_COMMIT} && \
diff --git a/docs/docs/user-guide/examples/bionemo-esm2/pretrain.md b/docs/docs/user-guide/examples/bionemo-esm2/pretrain.md
index 6845785966..edc1198cb2 100644
--- a/docs/docs/user-guide/examples/bionemo-esm2/pretrain.md
+++ b/docs/docs/user-guide/examples/bionemo-esm2/pretrain.md
@@ -253,7 +253,6 @@ checkpoint_callback = nl_callbacks.ModelCheckpoint(
     save_last=True,
     monitor="val_loss",
     save_top_k=1,
-    every_n_train_steps=100,
     always_save_context=True,
 )

diff --git a/scripts/gpt-pretrain.py b/scripts/gpt-pretrain.py
index 2e1f9fbd7a..42faa09677 100644
--- a/scripts/gpt-pretrain.py
+++ b/scripts/gpt-pretrain.py
@@ -187,7 +187,7 @@ def main() -> None:
     devices, seq_length = 1, 2048

     strategy = nl.MegatronStrategy(
-        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_dtype=torch.float32
+        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_dtype=torch.float32, ckpt_async_save=False,
     )
     trainer = nl.Trainer(
         devices=devices,
diff --git a/scripts/protein/esm2/esm2_pretrain.py b/scripts/protein/esm2/esm2_pretrain.py
index 25741e7bd2..6f0c348936 100644
--- a/scripts/protein/esm2/esm2_pretrain.py
+++ b/scripts/protein/esm2/esm2_pretrain.py
@@ -143,6 +143,7 @@ def main(
         ddp="megatron",
         find_unused_parameters=True,
         ckpt_include_optimizer=True,
+        ckpt_async_save=False,
     )

     # for wandb integration
@@ -243,10 +244,10 @@ def main(
     # Configure our custom Checkpointer
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=save_last_checkpoint,
-        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
+        monitor=metric_to_monitor_for_checkpoints,
         save_top_k=save_top_k,
-        every_n_train_steps=val_check_interval,
         always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+        filename="{epoch}-{step}-{" + metric_to_monitor_for_checkpoints + ":.2f}",
     )

     # Setup the logger and train the model
diff --git a/scripts/protein/esm2/test_esm2_pretrain.py b/scripts/protein/esm2/test_esm2_pretrain.py
index c2253d46da..353652ba0b 100644
--- a/scripts/protein/esm2/test_esm2_pretrain.py
+++ b/scripts/protein/esm2/test_esm2_pretrain.py
@@ -30,7 +30,6 @@
 from bionemo.testing import megatron_parallel_state_utils


-@pytest.mark.skip("duplicate unittest")
 @pytest.fixture
 def dummy_protein_dataset(tmp_path):
     """Create a mock protein dataset."""
@@ -62,7 +61,6 @@ def dummy_protein_dataset(tmp_path):
     return db_file


-@pytest.mark.skip("duplicate unittest")
 @pytest.fixture
 def dummy_parquet_train_val_inputs(tmp_path):
     """Create a mock protein train and val cluster parquet."""
@@ -104,7 +102,7 @@ def test_main_runs(monkeypatch, tmpdir, dummy_protein_dataset, dummy_parquet_tra
         result_dir=result_dir,
         wandb_project=None,
         wandb_offline=True,
-        num_steps=55,
+        num_steps=10,
         warmup_steps=5,
         limit_val_batches=1,
         val_check_interval=1,
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
index 3f6688ba53..63d93448e1 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
@@ -30,14 +30,12 @@
 from megatron.core.transformer.dot_product_attention import DotProductAttention
 from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.transformer_config import TransformerConfig
-from pkg_resources import packaging
+from megatron.core.utils import get_te_version, is_te_min_version
 from torch import Tensor


 __all__: Sequence[str] = ("ESM2DotProductAttention", "ESM2TEDotProductAttention")

-from megatron.core.extensions.transformer_engine import _te_version
-

 class ESM2TEDotProductAttention(TEDotProductAttention):
     """ESM2-Specific transformer engine core attention.
@@ -52,6 +50,10 @@ def __init__(
         attn_mask_type: AttnMaskType,
         attention_type: str,
         attention_dropout: float | None = None,
+        softmax_scale: float = 1.0,
+        k_channels: int | None = None,
+        v_channels: int | None = None,
+        cp_comm_type: str = "p2p",
     ):
         """Initialize ESM2TEDotProductAttention."""
         self.config = config
@@ -67,30 +69,35 @@ def __init__(
         )

         extra_kwargs = {}
-        if _te_version >= packaging.version.Version("0.11.0"):
+        if is_te_min_version("0.11.0"):
             extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
         elif self.config.num_query_groups != self.config.num_attention_heads:
             raise ValueError(
-                f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
+                f"Transformer Engine v{get_te_version()} does not support Grouped Query Attention, "
                 f"use a newer version of Transformer Engine. "
                 f"(num_query_groups ({self.config.num_query_groups}) != "
                 f"num_attention_heads ({self.config.num_attention_heads}))"
             )

-        if _te_version >= packaging.version.Version("0.10.0"):
+        if is_te_min_version("0.10.0"):
             extra_kwargs["attention_type"] = attention_type  # older version don't need attention_type

-        if _te_version > packaging.version.Version("0.12.0"):
+        if is_te_min_version("0.12.0", check_equality=False):
             self.te_forward_mask_type = True

         # Only Transformer-Engine version >= 1.0.0 supports context parallelism
-        if _te_version >= packaging.version.Version("1.0.0"):
+        if is_te_min_version("1.0.0"):
             if getattr(TEDotProductAttention, "cp_stream") is None:
                 TEDotProductAttention.cp_stream = torch.cuda.Stream()
             extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
             extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
             extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
+            if is_te_min_version("1.10.0"):
+                if cp_comm_type is None:
+                    extra_kwargs["cp_comm_type"] = "p2p"
+                else:
+                    extra_kwargs["cp_comm_type"] = cp_comm_type
         else:
             assert (
                 self.config.context_parallel_size == 1
@@ -106,15 +113,26 @@ def __init__(
         if config.window_size is not None:
             # Check version
-            assert _te_version >= packaging.version.Version("1.2.0"), (
-                f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support"
-                "sliding window attention."
+            assert is_te_min_version("1.2.0"), (
+                f"Transformer-Engine v{get_te_version()} must be >= 1.2.0 to support sliding window attention."
             )
             extra_kwargs["window_size"] = config.window_size

+        if is_te_min_version("1.10.0"):
+            # TE 1.10.0 introduces the ability to set the different k and v channels
+            kv_channels = (
+                (k_channels, v_channels)
+                if k_channels is not None and v_channels is not None
+                else self.config.kv_channels
+            )
+        else:
+            kv_channels = self.config.kv_channels
+
+        extra_kwargs["softmax_scale"] = softmax_scale
+
         super(TEDotProductAttention, self).__init__(
             num_attention_heads=self.config.num_attention_heads,
-            kv_channels=self.config.kv_channels,
+            kv_channels=kv_channels,
             attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
             attn_mask_type=attn_mask_type.name,
             sequence_parallel=self.config.sequence_parallel,
@@ -122,7 +140,6 @@
             get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
             tp_group=get_tensor_model_parallel_group(check_initialized=False),
             layer_number=layer_number,
-            softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
             **extra_kwargs,
         )
diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
index e2bce88976..fdaacf7a1d 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/finetune/train.py
@@ -70,8 +70,7 @@ def train_model(
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=True,
         save_on_train_epoch_end=True,
-        monitor="reduced_train_loss",  # TODO find out how to get val_loss logged and use "val_loss",
-        every_n_train_steps=n_steps_train // 2,
+        monitor="val_loss",
         always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
     )

diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
index b5b53b036c..c6d64b94e8 100644
--- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
+++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py
@@ -324,7 +324,16 @@ class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):

     def __post_init__(self):  # TODO, as a validator?
-        """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
+        """Check configuration compatibility."""
+        # reset moe_token_dispatcher_type when variable_seq_lengths is True.
+        # must be performed before super().__post_init__()
+        if self.variable_seq_lengths and self.moe_token_dispatcher_type in ["allgather", "alltoall_seq"]:
+            logging.warning(
+                "MoE token dispatcher types 'allgather' and 'alltoall_seq' are not supported with variable sequence lengths. Setting moe_token_dispatcher_type to 'alltoall'."
+            )
+            self.moe_token_dispatcher_type = "alltoall"
+
+        # reset apply_query_key_layer_scaling based on biobert_spec_option
         super().__post_init__()
         if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
             self.apply_query_key_layer_scaling = False
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_tokenizer.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_tokenizer.py
index cef099f9b1..457771c868 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_tokenizer.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_tokenizer.py
@@ -98,7 +98,7 @@ def test_tokenize_with_empty_string(tokenizer):


 def test_tokenizer_serialization(tokenizer, tmp_path):
-    tokenizer.io_dump(tmp_path / "tokenizer")
+    tokenizer.io_dump(tmp_path / "tokenizer", yaml_attrs=[])  # BioNeMoESMTokenizer takes no __init__ arguments
     deserialized_tokenizer = io.load(tmp_path / "tokenizer", tokenizer.__class__)
     our_tokens = deserialized_tokenizer.encode("K A I S Q", add_special_tokens=False)
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py
index 18be7eccf3..637c1196e3 100644
--- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py
+++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_stop_and_go.py
@@ -92,7 +92,7 @@ def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataMo
                 adam_beta2=0.98,
             ),
             lr_scheduler=WarmupAnnealDecayHoldScheduler(
-                warmup_steps=50, max_steps=cls.num_steps, max_lr=cls.lr, min_lr=cls.lr / 10.0, anneal_percentage=0.10
+                warmup_steps=50, max_steps=cls.num_steps, max_lr=cls.lr, min_lr=0.0, anneal_percentage=0.10
             ),
         )

diff --git a/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning/lightning_basic.py b/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning/lightning_basic.py
index 2a09636998..b7795feaa2 100644
--- a/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning/lightning_basic.py
+++ b/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning/lightning_basic.py
@@ -649,8 +649,7 @@ def loss_reduction_class(self) -> Type[MegatronLossReduction]:
 checkpoint_callback = nl_callbacks.ModelCheckpoint(
     save_last=True,
     save_on_train_epoch_end=True,
-    monitor="reduced_train_loss",
-    every_n_train_steps=25,
+    monitor="val_loss",
     always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
 )

diff --git a/sub-packages/bionemo-example_model/src/bionemo/example_model/training_scripts/finetune_mnist.py b/sub-packages/bionemo-example_model/src/bionemo/example_model/training_scripts/finetune_mnist.py
index ef423aa4f1..f03b6856f5 100644
--- a/sub-packages/bionemo-example_model/src/bionemo/example_model/training_scripts/finetune_mnist.py
+++ b/sub-packages/bionemo-example_model/src/bionemo/example_model/training_scripts/finetune_mnist.py
@@ -44,8 +44,7 @@ def run_finetune(checkpoint_dir: str, name: str, directory_name: str):
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=True,
         save_on_train_epoch_end=True,
-        monitor="reduced_train_loss",
-        every_n_train_steps=25,
+        monitor="val_loss",
         always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
     )

diff --git a/sub-packages/bionemo-example_model/tests/bionemo/example_model/lightning/test_lightning_basic.py b/sub-packages/bionemo-example_model/tests/bionemo/example_model/lightning/test_lightning_basic.py
index 5195287ab2..651c87887c 100644
--- a/sub-packages/bionemo-example_model/tests/bionemo/example_model/lightning/test_lightning_basic.py
+++ b/sub-packages/bionemo-example_model/tests/bionemo/example_model/lightning/test_lightning_basic.py
@@ -52,10 +52,10 @@ def _train_model_get_ckpt(
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=True,
         save_on_train_epoch_end=True,
-        monitor="reduced_train_loss",  # TODO find out how to get val_loss logged and use "val_loss",
-        every_n_train_steps=5,
+        monitor="val_loss",
         always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
         # async_save=False,  # Tries to save asynchronously, previously led to race conditions.
+        filename="{epoch}-{step}-{val_loss:.2f}",
     )
     save_dir = root_dir / name
     tb_logger = TensorBoardLogger(save_dir=save_dir, name=name)
diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
index 2fa9e4a68f..1ea112430b 100644
--- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
+++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/scripts/train_geneformer.py
@@ -311,10 +311,10 @@ def main(
     # Configure our custom Checkpointer
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=save_last_checkpoint,
-        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
+        monitor=metric_to_monitor_for_checkpoints,
         save_top_k=save_top_k,
-        every_n_train_steps=val_check_interval,
         always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+        filename="{epoch}-{step}-{" + metric_to_monitor_for_checkpoints + ":.2f}",
     )

     # Setup the logger and train the model
diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py
index 14ba0b03c0..40cbecf7d8 100644
--- a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py
+++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/scripts/test_train_geneformer.py
@@ -37,7 +37,6 @@ def test_bionemo2_rootdir():
     assert data_path.is_dir(), "Test data directory is supposed to be a directory."


-@pytest.mark.skip("duplicate unittest")
 def test_main_runs(tmpdir):
     result_dir = Path(tmpdir.mkdir("results"))

diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py
index a561cb5b87..f5471470cc 100644
--- a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py
+++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py
@@ -821,8 +821,7 @@ def _train_model_get_ckpt(
     checkpoint_callback = nl_callbacks.ModelCheckpoint(
         save_last=True,
         save_on_train_epoch_end=True,
-        monitor="reduced_train_loss",  # TODO find out how to get val_loss logged and use "val_loss",
-        every_n_train_steps=n_steps_train // 2,
+        monitor="val_loss",
         always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
     )
     save_dir = root_dir / name
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
index 2ee35bb38b..01b4142a04 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
@@ -375,7 +375,11 @@ def forward(
         rotary_pos_emb = None
         if self.position_embedding_type == "rope":
             rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
-                inference_params, self.encoder, encoder_input, self.config
+                inference_params,
+                self.encoder,
+                encoder_input,
+                self.config,
+                packed_seq_params=None,  # TODO @sichu: upstream to Megatron-LM
             )
             rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/train.py b/sub-packages/bionemo-llm/src/bionemo/llm/train.py
index 18ec5b1b83..f54fc6dea5 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/train.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/train.py
@@ -109,6 +109,7 @@ def setup_trainer(
         ddp="megatron",
         find_unused_parameters=True,
         ckpt_include_optimizer=True,
+        ckpt_async_save=False,
     )
     if callbacks is None:
         callbacks = [
diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/utils/weight_utils.py b/sub-packages/bionemo-llm/src/bionemo/llm/utils/weight_utils.py
index cc201d342c..5ee4218c77 100644
--- a/sub-packages/bionemo-llm/src/bionemo/llm/utils/weight_utils.py
+++ b/sub-packages/bionemo-llm/src/bionemo/llm/utils/weight_utils.py
@@ -145,7 +145,7 @@ def load_weights_sharded_inplace_nemo2_to_mcore(
     sharded_state_dict = {
         _munge_key_megatron_to_nemo2(k): _munge_sharded_tensor_key_megatron_to_nemo2(v)
         for k, v in model.sharded_state_dict().items()
-        if not _key_in_filter(k, skip_keys_with_these_prefixes)
+        if not _key_in_filter(k, skip_keys_with_these_prefixes) and "_extra_state" not in k
     }
     dist_checkpointing.load(
         sharded_state_dict=sharded_state_dict,
diff --git a/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py b/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py
index 8bf1b242e4..802cbeb94e 100644
--- a/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py
+++ b/sub-packages/bionemo-llm/tests/bionemo/llm/test_lightning.py
@@ -169,6 +169,8 @@ def forward(self, x):
         return self.other(x)


+# TODO rewrite unittest and potentially LightningPassthroughPredictionMixin
+@pytest.mark.xfail(reason="MegatronStrategy no longer has '_get_loss_reduction' attribute")
 def test_mixin_strategy_contract_get_loss_reduction():
     with megatron_parallel_state_utils.clean_parallel_state_context():
         strategy = nl.MegatronStrategy(
diff --git a/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_iomixin_utils.py b/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_iomixin_utils.py
index 2513e36d27..b24e820b54 100644
--- a/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_iomixin_utils.py
+++ b/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_iomixin_utils.py
@@ -45,6 +45,12 @@ class OverrideModelDataClass2(BaseDataClass, iom.IOMixinWithGettersSetters):


 class TestIOMixin:
+    """TestCase on IOMixin.
+
+    Notes:
+        IOMixin only captures non-default __init__ arguments into self.__io__, to avoid compatibility issues when loading older mcore configs in newer versions.
+    """
+
     def test_dataclasses_two_versions(self):
         _ = OverrideModelDataClass1(b=2)
         v1 = OverrideModelDataClass2(b=4)
@@ -80,10 +86,10 @@ def test_dataclass_out_of_sync(self):
         with pytest.raises(KeyError):
             v1.get_hparam("q")

-        # Make sure we can get all hyper-parameters that are not defaultfactory objects
-        assert v1.get_hparams() == {"b": 7, "c": 3}
+        # Make sure we can get all hyper-parameters that are **non-default** non-defaultfactory objects
+        assert v1.get_hparams() == {"b": 7}

-        # Make sure by default we can change botht he hyper-parameter and the attribute.
+        # Make sure by default we can change both the hyper-parameter and the attribute.
         v1_copy.set_hparam("b", 8)
         assert v1_copy.b == 8
         assert v1_copy.get_hparam("b") == 8
@@ -92,8 +98,8 @@ def test_dataclass_hparam_modify_parent_default(self):
         v1 = OverrideModelDataClass1()
         v1.set_hparam("a", 7)
         assert v1.a == 7
-        # Make sure we can get all hyper-parameters
-        assert v1.get_hparams() == {"a": 7, "b": 3, "c": 3}
+        # Make sure we can get all **non-default** **non-defaultfactory** hyper-parameters
+        assert v1.get_hparams() == {"a": 7}

         v1_copy = io.reinit(v1)
         assert v1_copy.a == 7, "V1 should re-initialize with the updated hyper-parameter."
diff --git a/sub-packages/bionemo-testing/src/bionemo/testing/harnesses/stop_and_go.py b/sub-packages/bionemo-testing/src/bionemo/testing/harnesses/stop_and_go.py
index d811c997c4..c14b520f78 100644
--- a/sub-packages/bionemo-testing/src/bionemo/testing/harnesses/stop_and_go.py
+++ b/sub-packages/bionemo-testing/src/bionemo/testing/harnesses/stop_and_go.py
@@ -171,6 +171,7 @@ def setup_trainer(
         ddp="megatron",
         find_unused_parameters=True,
         ckpt_include_optimizer=True,
+        ckpt_async_save=False,
     )

     trainer = nl.Trainer(
@@ -231,13 +232,13 @@ def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
         callbacks[Mode.STOP].update(
             {
-                testing_callbacks.RaiseAfterMetadataCallback: testing_callbacks.RaiseAfterMetadataCallback(),
+                testing_callbacks.StopAfterValidEpochEndCallback: testing_callbacks.StopAfterValidEpochEndCallback(),
                 nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
                     save_last=True,
-                    monitor="reduced_train_loss",
+                    monitor="val_loss",
                     save_top_k=2,
-                    every_n_train_steps=cls.val_check_interval,
                     always_save_context=True,
+                    filename="{epoch}-{step}-{val_loss:.2f}",
                 ),
             }
         )
@@ -262,20 +263,17 @@ def stop(cls) -> None:
         model, data, opt = cls.setup_model(mode=Mode.STOP)
         trainer = cls.setup_trainer(Mode.STOP)
         with distributed_model_parallel_state():
-            try:
-                llm.train(
-                    model=model,
-                    data=data,
-                    trainer=trainer,
-                    log=cls.nemo_logger,
-                    optim=opt,
-                    resume=resume.AutoResume(
-                        resume_if_exists=False,  # Looks for the -last checkpoint to continue training.
-                        resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
-                    ),
-                )
-            except testing_callbacks.StopAndGoException:
-                return
+            llm.train(
+                model=model,
+                data=data,
+                trainer=trainer,
+                log=cls.nemo_logger,
+                optim=opt,
+                resume=resume.AutoResume(
+                    resume_if_exists=False,  # Looks for the -last checkpoint to continue training.
+                    resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+                ),
+            )

     @classmethod
     def resume(cls) -> None:
@@ -327,6 +325,9 @@ def run_stop_and_go(cls):
             testing_callbacks.TrainInputCallback,
             testing_callbacks.TrainOutputCallback,
             testing_callbacks.TrainLossCallback,
+            testing_callbacks.ValidInputCallback,
+            testing_callbacks.ValidOutputCallback,
+            testing_callbacks.ValidLossCallback,
         ],
     )
     def test_stop_and_go_consistency(self, callback_type):
@@ -335,12 +336,27 @@ def test_stop_and_go_consistency(self, callback_type):
         continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
         assert interrupted_callback.data, f"No data found for {callback_type}"

-        if callback_type == testing_callbacks.TrainOutputCallback:
-            atol = 1e-3
+        if callback_type in {testing_callbacks.TrainOutputCallback, testing_callbacks.ValidOutputCallback}:
+            atol, rtol = 1e-3, 1e-4
         else:
-            atol = 1e-4
+            atol, rtol = 1e-4, 1e-4

-        recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)
+        if callback_type in (
+            testing_callbacks.ValidInputCallback,
+            testing_callbacks.ValidOutputCallback,
+            testing_callbacks.ValidLossCallback,
+        ):
+            if len(interrupted_callback.data) != len(continuous_callback.data):
+                pytest.xfail(
+                    "NeMo will run extra validation batch(s) in resumption and NeMo team is working on fixing it."
+                )
+
+        recursive_assert_approx_equal(
+            interrupted_callback.data,
+            continuous_callback.data,
+            atol=atol,
+            rtol=rtol,
+        )

     def test_train_val_init_consumed_samples(self):
         """Tests the initial consumed samples in stop-and-go scenario."""
@@ -356,10 +372,9 @@ def test_train_val_init_consumed_samples(self):
         assert train_consumed_stop == 0
         assert train_consumed_go > 0

-    # TODO: For some reason, validation in NeMo runs an extra batch in the case when the training is stopped and
-    # resumed. Hopefully we can fix this upstream and remove the indexing based on the length of the continuous
-    # validation batches.
-    @pytest.mark.xfail(reason="Validation runs an extra batch in the case when training is stopped and resumed.")
+    @pytest.mark.xfail(
+        reason="NeMo will run extra validation batch(s) in resumption and NeMo team is working on fixing it."
+    )
     def test_identical_number_of_validation_batches(self):
         """Ensures that the input tensors for training are identical for the interrupted and continuous tests."""
         callback_type = testing_callbacks.ValidInputCallback
@@ -368,28 +383,3 @@ def test_identical_number_of_validation_batches(self):
         assert interrupted_callback.data, f"No data found for {callback_type}"
         recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data)
         assert len(interrupted_callback.data) == len(continuous_callback.data)
-
-    @pytest.mark.parametrize(
-        "callback_type",
-        [
-            testing_callbacks.ValidInputCallback,
-            testing_callbacks.ValidOutputCallback,
-            testing_callbacks.ValidLossCallback,
-        ],
-    )
-    def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_type):
-        """Ensures that the input tensors for training are identical for the interrupted and continuous tests."""
-        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
-        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
-        assert interrupted_callback.data, f"No data found for {callback_type}"
-
-        # Hack: Validation seems to run an extra batch in the case when training is stopped and resumed, but we can
-        # still test the rest of the data to ensure consistency.
-        interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]
-
-        if callback_type == testing_callbacks.ValidOutputCallback:
-            atol = 1e-3
-        else:
-            atol = 1e-4
-
-        recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)
diff --git a/sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py b/sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py
index 10c278f7cd..c591d723ff 100644
--- a/sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py
+++ b/sub-packages/bionemo-testing/src/bionemo/testing/testing_callbacks.py
@@ -29,11 +29,7 @@
 from bionemo.testing.torch import recursive_detach


-class StopAndGoException(Exception):  # noqa: D101
-    pass
-
-
-class RaiseAfterMetadataCallback(Callback):
+class StopAfterValidEpochEndCallback(Callback):
     """A callback that raises a StopAndGoException after the validation epoch.

     Use this callback for pytest based Stop and go tests.
@@ -42,7 +38,7 @@ class StopAfterValidEpochEndCallback(Callback):
     def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
         if trainer.sanity_checking:
             return
-        raise StopAndGoException()
+        trainer.should_stop = True


 class BaseInterruptedVsContinuousCallback(Callback, CallbackMethods, io.IOMixin):
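
Note on the recurring configuration change above: most of the training scripts and test harnesses in this diff converge on the same ModelCheckpoint/MegatronStrategy shape. Checkpoints are selected on "val_loss" (or the configured metric) with an explicit filename pattern instead of every_n_train_steps, and asynchronous checkpoint saving is disabled. The snippet below is a minimal standalone sketch of that end state, not part of the patch: argument values are illustrative, and the import paths assume the nemo.lightning API used by these scripts.

from nemo import lightning as nl
from nemo.lightning.pytorch import callbacks as nl_callbacks

# Checkpointing keyed to validation loss; saves follow the Trainer's validation
# cadence (val_check_interval) rather than a separate every_n_train_steps setting.
checkpoint_callback = nl_callbacks.ModelCheckpoint(
    save_last=True,
    monitor="val_loss",
    save_top_k=1,
    always_save_context=True,  # .nemo-style checkpointing with serialized IOMixins
    filename="{epoch}-{step}-{val_loss:.2f}",
)

# Synchronous checkpoint saving, matching ckpt_async_save=False set throughout this diff.
strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    ddp="megatron",
    find_unused_parameters=True,
    ckpt_include_optimizer=True,
    ckpt_async_save=False,
)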