[NeMo-UX] Fix logging of consumed samples in MegatronDataSampler (#10018)

* Fix logging of consumed samples in MegatronDataSampler

Signed-off-by: Hemil Desai <[email protected]>

* Apply isort and black reformatting

Signed-off-by: hemildesai <[email protected]>

* Remove unused import

Signed-off-by: Hemil Desai <[email protected]>

---------

Signed-off-by: Hemil Desai <[email protected]>
Signed-off-by: hemildesai <[email protected]>
Co-authored-by: hemildesai <[email protected]>
2 people authored and monica-sekoyan committed Oct 11, 2024
1 parent f284b08 commit 5ad049f
Showing 2 changed files with 11 additions and 26 deletions.
nemo/lightning/pytorch/callbacks/preemption.py (1 change: 0 additions & 1 deletion)
@@ -22,7 +22,6 @@
 from pytorch_lightning.trainer.trainer import Trainer
 
 from nemo.lightning.io.mixin import IOMixin
-from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
 from nemo.utils import logging
 from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
nemo/lightning/pytorch/plugins/data_sampler.py (36 changes: 11 additions & 25 deletions)
@@ -58,24 +58,12 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
         from nemo.lightning.pytorch.strategies import MegatronStrategy
         from nemo.utils import AppState
 
-        try:
-            from megatron.core.num_microbatches_calculator import get_current_global_batch_size
-
-        except (ImportError, ModuleNotFoundError):
-            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
-            from apex.transformer.pipeline_parallel.utils import get_current_global_batch_size
-
         if not isinstance(self.trainer.strategy, MegatronStrategy):
             return 0
 
         app_state = AppState()
 
         if self.rampup_batch_size is not None:
-            if get_current_global_batch_size():
-                current_global_batch_size = get_current_global_batch_size()
-            else:
-                current_global_batch_size = 1
-            consumed_samples = self.prev_consumed_samples + self.if_first_step * current_global_batch_size
+            consumed_samples = self.prev_consumed_samples + self.if_first_step * self.current_global_batch_size
         else:
             consumed_samples = (
                 self.init_consumed_samples
@@ -96,22 +84,16 @@ def on_megatron_step_start(self, trainer: pl.Trainer) -> None:
 
     def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
         try:
-            from megatron.core.num_microbatches_calculator import (
-                get_current_global_batch_size,
-                update_num_microbatches,
-            )
+            from megatron.core.num_microbatches_calculator import update_num_microbatches
 
         except (ImportError, ModuleNotFoundError):
            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
-            from apex.transformer.pipeline_parallel.utils import get_current_global_batch_size, update_num_microbatches
-
-        if self.rampup_batch_size is None:
-            return
+            from apex.transformer.pipeline_parallel.utils import update_num_microbatches
 
         self.prev_global_batch_size = self.current_global_batch_size
 
         # TODO: Add consumed samples
-        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
+        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_consumed_samples)
 
         pl_module.log(
             'consumed_samples',
@@ -127,10 +109,9 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
             consumed_samples=consumed_samples,
             consistency_check=False,
         )
-        current_global_batch_size = get_current_global_batch_size()
         pl_module.log(
             "global_batch_size",
-            current_global_batch_size,
+            self.current_global_batch_size,
             prog_bar=True,
             rank_zero_only=True,
             batch_size=1,
@@ -165,4 +146,9 @@ def current_global_batch_size(self) -> int:
             logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
             from apex.transformer.pipeline_parallel.utils import get_current_global_batch_size
 
-        return get_current_global_batch_size()
+        if get_current_global_batch_size():
+            current_global_batch_size = get_current_global_batch_size()
+        else:
+            current_global_batch_size = 1
+
+        return current_global_batch_size
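
For orientation, here is a minimal, self-contained sketch (not the NeMo implementation) of the pattern this diff converges on: a single current_global_batch_size property that falls back to 1 when the batch-size calculator reports nothing, reused by both the consumed-samples bookkeeping and the per-step logging. Centralizing the fallback in one property is what lets the diff drop the duplicated try/except imports and inline get_current_global_batch_size() calls. ToyDataSampler and the stub get_current_global_batch_size (returning a fixed 256) are hypothetical stand-ins, and the non-ramp-up branch is simplified to steps times global batch size (the real sampler derives that product from data-parallel size, micro-batch size, and number of microbatches).

from typing import Optional


def get_current_global_batch_size() -> Optional[int]:
    # Stand-in for the megatron.core / Apex helper of the same name; pretends
    # the calculator was initialized with a global batch size of 256.
    return 256


class ToyDataSampler:
    # Hypothetical, trimmed-down stand-in for MegatronDataSampler, only to show
    # how the pieces touched by this diff fit together.
    def __init__(self, rampup_batch_size=None, init_consumed_samples: int = 0):
        self.rampup_batch_size = rampup_batch_size
        self.init_consumed_samples = init_consumed_samples
        self.prev_consumed_samples = init_consumed_samples
        self.if_first_step = 0

    @property
    def current_global_batch_size(self) -> int:
        # Same shape as the property in the diff: fall back to 1 when the
        # calculator reports nothing.
        gbs = get_current_global_batch_size()
        return gbs if gbs else 1

    def compute_consumed_samples(self, steps_since_resume: int = 0) -> int:
        if self.rampup_batch_size is not None:
            # Ramp-up branch from the diff: previous count plus (at most) one
            # more global batch.
            return int(self.prev_consumed_samples + self.if_first_step * self.current_global_batch_size)
        # Simplified non-ramp-up branch: steps since resume times the global
        # batch size.
        return int(self.init_consumed_samples + steps_since_resume * self.current_global_batch_size)


sampler = ToyDataSampler()
print(sampler.compute_consumed_samples(steps_since_resume=10))  # -> 2560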
