[NeMo-UX] Add distributed checkpointing unit tests (#9922)
* [NeMo-UX] Add distributed checkpointing unit tests (#9794)

* add dist checkpointing tests

Signed-off-by: ashors1 <[email protected]>

* fix recursion bug

Signed-off-by: ashors1 <[email protected]>

* raise original AttributeError

Signed-off-by: ashors1 <[email protected]>

* dist checkpoint test fixes

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* set async_save from strategy to make checkpoint_io more robust

Signed-off-by: ashors1 <[email protected]>

* update async save test

Signed-off-by: ashors1 <[email protected]>

* fixes

Signed-off-by: ashors1 <[email protected]>

* clean up and address comments

Signed-off-by: ashors1 <[email protected]>

* fix mock datamodule

Signed-off-by: ashors1 <[email protected]>

* fix qk layer scaling setting

Signed-off-by: ashors1 <[email protected]>

* fix microbatch calculator

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: ashors1 <[email protected]>

* fix test

Signed-off-by: ashors1 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: ashors1 <[email protected]>

* remove unused imports

Signed-off-by: ashors1 <[email protected]>

---------

Signed-off-by: ashors1 <[email protected]>
Signed-off-by: ashors1 <[email protected]>
Co-authored-by: Anna Shors <[email protected]>
Co-authored-by: ashors1 <[email protected]>
Co-authored-by: ashors1 <[email protected]>
4 people authored and monica-sekoyan committed Oct 11, 2024
1 parent a72b35c commit 0ce878c
Showing 3 changed files with 172 additions and 2 deletions.
4 changes: 3 additions & 1 deletion nemo/lightning/megatron_parallel.py
@@ -667,7 +667,9 @@ def wrapped(self, *args):
 def getattr_proxy(self, item: Any) -> Any:
     try:
         return super(self.__class__, self).__getattr__(item)
-    except AttributeError:
+    except AttributeError as e:
+        if item == 'module':  ## this is a hacky WAR and may cause misleading error messages
+            raise e
         try:
             return getattr(self.module, item)
         except AttributeError:
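
The `raise e` guard added here is the fix for the recursion bug mentioned in the commit message: when the wrapped object has no `module` attribute yet, falling back to `getattr(self.module, item)` invokes `__getattr__('module')` again and recurses until the interpreter errors out. Below is a minimal, self-contained sketch of that proxy pattern; the class name and simplified structure are illustrative and not taken from megatron_parallel.py.

import torch.nn as nn


class _ProxySketch(nn.Module):
    """Illustrative stand-in for a wrapper class that uses a getattr proxy."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module  # registered as a submodule by nn.Module

    def __getattr__(self, item):
        try:
            # nn.Module.__getattr__ resolves registered parameters, buffers, and submodules
            return super().__getattr__(item)
        except AttributeError as e:
            if item == 'module':
                # Without this early re-raise, the fallback below would evaluate
                # self.module, call __getattr__('module') again, and recurse forever.
                raise e
            return getattr(self.module, item)  # proxy everything else to the inner module


# e.g. _ProxySketch(nn.Linear(4, 4)).weight resolves through the inner module,
# while a missing 'module' attribute now fails fast with the original AttributeError.
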
3 changes: 2 additions & 1 deletion nemo/lightning/pytorch/strategies.py
@@ -232,7 +232,8 @@ def setup(self, trainer: pl.Trainer) -> None:
         if not self.data_sampler and hasattr(datamodule, "data_sampler"):
             self.data_sampler = datamodule.data_sampler
             self.data_sampler.setup(self.cluster_environment.global_rank())
-            datamodule.reconfigure_limit_batches()
+            if hasattr(datamodule, "reconfigure_limit_batches"):
+                datamodule.reconfigure_limit_batches()
 
         if self.data_sampler:
             self.data_sampler.connect(trainer)
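
The `hasattr` guard keeps `MegatronStrategy.setup` from assuming a NeMo-specific datamodule hook: `reconfigure_limit_batches` is not part of the base `LightningDataModule` API, so the unguarded call raised AttributeError for datamodules that do not define it. A small sketch of the behavior, using an assumed `PlainDataModule` that is not part of this commit:

import pytorch_lightning as pl


class PlainDataModule(pl.LightningDataModule):
    """A datamodule that defines no NeMo-specific hooks."""


dm = PlainDataModule()
# The guarded call is a no-op here; the old unguarded call would raise AttributeError.
if hasattr(dm, "reconfigure_limit_batches"):
    dm.reconfigure_limit_batches()
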
167 changes: 167 additions & 0 deletions tests/lightning/test_dist_ckpt.py
@@ -0,0 +1,167 @@
import os
from pathlib import Path

import pytest
import pytorch_lightning as pl
import torch
from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator

import nemo.lightning as nl
from nemo.collections import llm
from nemo.lightning.io.pl import MegatronCheckpointIO
from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO, AsyncFinalizerCallback


def _get_strategy():
    strategy = nl.MegatronStrategy(
        enable_nemo_ckpt_io=False,
    )
    return strategy


def _get_last_checkpoint_dir(model: pl.LightningModule, suffix: str = '') -> str:
    return f'epoch={model.trainer.current_epoch - 1}-step={model.trainer.max_steps - 1}{suffix}'


def get_model_and_data():
    micro_batch_size = 2
    global_batch_size = 2
    seq_length = 128
    data = llm.MockDataModule(
        seq_length=seq_length, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size
    )

    config = llm.GPTConfig(
        num_layers=2,
        hidden_size=64,
        ffn_hidden_size=256,
        num_attention_heads=4,
        seq_length=seq_length,
        apply_query_key_layer_scaling=1,
    )
    reconfigure_num_microbatches_calculator(
        0,
        None,
        global_batch_size,
        micro_batch_size,
        data_parallel_size=1,
    )
    return llm.GPTModel(config, tokenizer=data.tokenizer), data


class TestDistCkptIO:

    @pytest.mark.run_only_on('GPU')
    def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path):

        model, data = get_model_and_data()

        strategy = _get_strategy()

        trainer = nl.Trainer(
            devices=1,
            accelerator="gpu",
            strategy=strategy,
            enable_checkpointing=True,
            max_steps=2,
            default_root_dir=str(tmp_path),
            logger=False,
        )

        trainer.fit(model, data)

        assert isinstance(trainer.strategy.checkpoint_io, MegatronCheckpointIO)
        # Ckpt path doesn't contain the .ckpt suffix
        ckpts = os.listdir(Path(tmp_path / "checkpoints"))
        assert len(ckpts) == 1
        ckpt = ckpts[0]
        assert str(ckpt) == _get_last_checkpoint_dir(model)

    @pytest.mark.run_only_on('GPU')
    def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path):

        model, data = get_model_and_data()

        sync_ckpt_dir = tmp_path / 'sync_checkpoints'
        async_ckpt_dir = tmp_path / 'async_checkpoints'

        sync_checkpoint_io = MegatronCheckpointIO('torch_dist')
        async_checkpoint_io = AsyncFinalizableCheckpointIO(MegatronCheckpointIO('torch_dist', async_save=True))

        # dummy_trainer just to initialize NCCL
        dummy_trainer = pl.Trainer(
            devices=1,
            logger=False,
            max_steps=2,
            strategy=_get_strategy(),
        )
        dummy_trainer.fit(model, data)
        strategy = _get_strategy()
        tmp_path = strategy.broadcast(tmp_path)

        ## reset the model and data and train with sync checkpointing
        model, data = get_model_and_data()
        sync_test_trainer = pl.Trainer(
            devices=1,
            enable_checkpointing=True,
            logger=False,
            max_steps=2,
            strategy=_get_strategy(),
            plugins=[sync_checkpoint_io],
            default_root_dir=str(sync_ckpt_dir),
        )
        sync_test_trainer.fit(model, data)

        ## reset the model and data and train with async checkpointing
        model, data = get_model_and_data()
        async_test_trainer = pl.Trainer(
            devices=1,
            enable_checkpointing=True,
            logger=False,
            max_steps=2,
            strategy=_get_strategy(),
            plugins=[async_checkpoint_io],
            callbacks=AsyncFinalizerCallback(),
            default_root_dir=str(async_ckpt_dir),
        )
        async_test_trainer.fit(model, data)

        checkpoint = {'sharded_state_dict': model.sharded_state_dict()}

        sync_state_dict = sync_checkpoint_io.load_checkpoint(
            Path(f"{sync_ckpt_dir}/checkpoints/{_get_last_checkpoint_dir(model)}"), sharded_state_dict=checkpoint
        )

        async_state_dict = async_checkpoint_io.load_checkpoint(
            Path(f"{async_ckpt_dir}/checkpoints/{_get_last_checkpoint_dir(model)}"), sharded_state_dict=checkpoint
        )

        ## one of the keys is a _io.BytesIO object
        for k in sync_state_dict['sharded_state_dict'].keys():
            if isinstance(sync_state_dict['sharded_state_dict'][k], torch.Tensor):
                assert torch.all(sync_state_dict['sharded_state_dict'][k] == async_state_dict['sharded_state_dict'][k])

    def test_sharded_strategies(self):

        model_checkpoint = nl.ModelCheckpoint()

        strategy = nl.MegatronStrategy(
            enable_nemo_ckpt_io=False,
            save_ckpt_format='torch_dist',
            ckpt_parallel_save=True,
            ckpt_load_directly_on_device=False,
            ckpt_async_save=True,
        )
        trainer = nl.Trainer(
            callbacks=[model_checkpoint],
            strategy=strategy,
        )

        assert isinstance(strategy.checkpoint_io, AsyncFinalizableCheckpointIO)
        assert isinstance(strategy.checkpoint_io._checkpoint_io, MegatronCheckpointIO)

        base_checkpoint_io = strategy.checkpoint_io._checkpoint_io

        assert base_checkpoint_io.save_ckpt_format == 'torch_dist'
        assert base_checkpoint_io.parallel_save
        assert base_checkpoint_io.load_directly_on_device == False
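
For context, `test_sharded_strategies` mirrors how a training script would opt into async, parallel `torch_dist` checkpointing through `MegatronStrategy`. The sketch below uses only keyword arguments exercised in the test; the inline comments are a reading of the flag names, and the model, data, and `fit` call are assumptions rather than part of this commit.

import nemo.lightning as nl

# Configure sharded (torch_dist) distributed checkpoints with async finalization,
# matching the settings asserted in test_sharded_strategies above.
strategy = nl.MegatronStrategy(
    save_ckpt_format='torch_dist',       # Megatron distributed checkpoint format
    ckpt_parallel_save=True,             # ranks write their checkpoint shards in parallel
    ckpt_load_directly_on_device=False,  # do not load shards directly onto the GPU
    ckpt_async_save=True,                # wraps the IO in AsyncFinalizableCheckpointIO
)

trainer = nl.Trainer(
    devices=1,
    accelerator="gpu",
    strategy=strategy,
    callbacks=[nl.ModelCheckpoint()],
)
# trainer.fit(model, data)  # model/data would come from the user's own recipe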
