From ae5813190354d23b06aa56042727aae6a3929504 Mon Sep 17 00:00:00 2001
From: Dmytro Pykhtar <37850217+dimapihtar@users.noreply.github.com>
Date: Mon, 15 Jul 2024 13:13:57 +0300
Subject: [PATCH] fix legacy ds padding bug (#9716)

* fix legacy ds padding bug

Signed-off-by: dimapihtar

* Apply isort and black reformatting

Signed-off-by: dimapihtar

* avoid code repetition

Signed-off-by: dimapihtar

* fix typo

Signed-off-by: dimapihtar

---------

Signed-off-by: dimapihtar
Signed-off-by: dimapihtar
Co-authored-by: dimapihtar
Signed-off-by: Hainan Xu
---
 .../data/language_modeling/megatron/data_samplers.py      | 10 +++++++++-
 .../nlp/models/language_modeling/megatron_gpt_model.py    |  8 +++++++-
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
index 4a8b989a7b6d..622e2d759266 100644
--- a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
+++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py
@@ -100,13 +100,16 @@ def get_start_end_idx(self):
         end_idx = start_idx + self.micro_batch_size
         return start_idx, end_idx
 
+    def _get_padding_indices(self, pad_samples_num):
+        return range(-1, -pad_samples_num - 1, -1)
+
     def __iter__(self):
         batch = []
         # Last batch will be dropped if drop_last is not set False
         indices = range(self.consumed_samples, self.total_samples)
         if (not self.drop_last) and self.pad_samples_to_global_batch_size:
             pad_samples_num = -len(indices) % self.global_batch_size
-            pad_indices = [None] * pad_samples_num
+            pad_indices = self._get_padding_indices(pad_samples_num)
             indices = chain(indices, pad_indices)
 
         for idx in indices:
@@ -125,6 +128,11 @@ def __iter__(self):
             yield batch[start_idx:end_idx]
 
 
+class MegatronCorePretrainingSampler(MegatronPretrainingSampler):
+    def _get_padding_indices(self, pad_samples_num):
+        return [None] * pad_samples_num
+
+
 class MegatronPretrainingRandomSampler(BaseMegatronSampler):
     def __init__(
         self,
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 69cd06021f50..e4cab6cec26f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -32,6 +32,7 @@
 from nemo.collections.common.parts.utils import extend_instance
 from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import (
+    MegatronCorePretrainingSampler,
     MegatronPretrainingRandomSampler,
     MegatronPretrainingSampler,
 )
@@ -1605,8 +1606,13 @@ def build_pretraining_data_loader(
         logging.info(f'Building dataloader with consumed samples: {consumed_samples}')
         # Megatron sampler
         if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None:
+            data_sampler = (
+                MegatronPretrainingSampler
+                if self.cfg.data.get('legacy_dataset', False)
+                else MegatronCorePretrainingSampler
+            )
             if self.cfg.data.dataloader_type == 'single':
-                batch_sampler = MegatronPretrainingSampler(
+                batch_sampler = data_sampler(
                     total_samples=len(dataset),
                     consumed_samples=consumed_samples,
                     micro_batch_size=self.cfg.micro_batch_size,
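
Illustrative note (not part of the patch): the change hinges on the two `_get_padding_indices` implementations shown above. When the remaining samples do not fill a global batch, the legacy sampler (`MegatronPretrainingSampler`) pads with negative indices, i.e. it re-reads the last real samples of the dataset, while the new `MegatronCorePretrainingSampler` keeps the previous behavior of padding with `None` so the dataset can supply explicit padding samples. The standalone functions below are a minimal sketch of that difference only; the function names and the demo values are hypothetical and the surrounding sampler machinery is omitted.

```python
def legacy_padding_indices(pad_samples_num):
    # Legacy path: index the dataset from the end with negative indices,
    # reusing the last real samples as padding.
    return range(-1, -pad_samples_num - 1, -1)


def core_padding_indices(pad_samples_num):
    # Megatron-Core path: emit None placeholders so padding samples can be
    # constructed downstream instead of duplicating real data.
    return [None] * pad_samples_num


if __name__ == "__main__":
    print(list(legacy_padding_indices(3)))  # [-1, -2, -3]
    print(core_padding_indices(3))          # [None, None, None]
```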