Use MegatronDataSampler in HfDatasetDataModule (#11274)
* Use MegatronDataSampler in HfDataset

Signed-off-by: Alexandros Koumparoulis <[email protected]>

* Apply isort and black reformatting

Signed-off-by: akoumpa <[email protected]>

---------

Signed-off-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: akoumpa <[email protected]>
Co-authored-by: akoumpa <[email protected]>
akoumpa and akoumpa authored Nov 14, 2024
1 parent 8b0c311 commit bf7cc64
Showing 1 changed file with 16 additions and 19 deletions.
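
The commit moves Megatron-style batch sampling out of train_dataloader() and into setup(): when use_mcore_sampler is set, the module now attaches a MegatronDataSampler as self.data_sampler instead of wrapping the DataLoader with add_megatron_sampler and hand-computing rank/world_size. A minimal usage sketch follows; it assumes the dataset is the first positional constructor argument and that use_mcore_sampler / mcore_dataloader_type are constructor keywords (only their self.* attributes appear in this diff), so treat it as illustrative rather than the module's documented API.

# Minimal sketch (not part of the commit) of using HfDatasetDataModule after this change.
from datasets import Dataset
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule

# Tiny stand-in for a pre-tokenized Hugging Face dataset.
dataset = Dataset.from_dict({
    "tokens": [[5, 6, 7], [8, 9]],
    "labels": [[6, 7, 0], [9, 0]],
    "loss_mask": [[1, 1, 1], [1, 1]],
})

datamodule = HfDatasetDataModule(
    dataset,
    seq_length=1024,                 # new argument introduced by this commit
    micro_batch_size=2,
    global_batch_size=8,
    pad_token_id=0,
    use_mcore_sampler=True,          # setup() now attaches a MegatronDataSampler
    mcore_dataloader_type="cyclic",  # assumed keyword; forwarded as dataloader_type
)
# When trainer.fit(...) calls datamodule.setup(stage), self.data_sampler becomes a
# MegatronDataSampler(seq_len=1024, micro_batch_size=2, global_batch_size=8, ...),
# which is why train_dataloader() can now return a plain DataLoader and the manual
# torch.distributed rank/world_size plumbing is removed below.
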
35 changes: 16 additions & 19 deletions nemo/collections/llm/gpt/data/hf_dataset.py
@@ -15,6 +15,7 @@
 import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader
+from nemo.lightning.pytorch.plugins import MegatronDataSampler
 
 
 class HfDatasetDataModule(pl.LightningDataModule):
@@ -24,6 +25,7 @@ def __init__(
         num_workers=2,
         pin_memory=True,
         persistent_workers=True,
+        seq_length=1024,
         micro_batch_size=2,
         global_batch_size=2,
         pad_token_id=0,
@@ -37,6 +39,7 @@ def __init__(
         self.num_workers = num_workers
         self.pin_memory = pin_memory
         self.persistent_workers = persistent_workers
+        self.seq_length = seq_length
         self.micro_batch_size = micro_batch_size
         self.global_batch_size = global_batch_size
         self.pad_token_id = pad_token_id
@@ -58,6 +61,7 @@ def pad_within_micro(batch, pad_token_id):
             max_len = max(map(len, batch))
             return [item + [pad_token_id] * (max_len - len(item)) for item in batch]
 
+        keys = list(filter(lambda x: x in batch[0], ['tokens', 'labels', 'position_ids', 'loss_mask']))
         return {
             key: batchify(
                 torch.LongTensor(
@@ -67,37 +71,30 @@ def pad_within_micro(batch, pad_token_id):
                     )
                 )
             )
-            for key in ['tokens', 'labels']
+            for key in keys
         }
 
+    def setup(self, stage: str):
+        if not self.use_mcore_sampler:
+            return
+        self.data_sampler = MegatronDataSampler(
+            seq_len=self.seq_length,
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=self.global_batch_size,
+            dataloader_type=self.mcore_dataloader_type,
+        )
+
     def train_dataloader(self, collate_fn=None):
-        from nemo.lightning.data import add_megatron_sampler
-
         if collate_fn is None:
             collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id)
 
-        dataloader = DataLoader(
+        return DataLoader(
             self.dataset,
             num_workers=self.num_workers,
             pin_memory=self.pin_memory,
             persistent_workers=self.persistent_workers,
             collate_fn=collate_fn,
             batch_size=self.micro_batch_size,
         )
-        if not self.use_mcore_sampler:
-            return dataloader
-
-        rank = 0
-        world_size = 1
-        if torch.distributed.is_initialized():
-            rank = torch.distributed.get_rank()
-            world_size = torch.distributed.get_world_size()
-
-        return add_megatron_sampler(
-            dataloader,
-            self.micro_batch_size,
-            self.global_batch_size,
-            dataloader_type=self.mcore_dataloader_type,
-            rank=rank,
-            world_size=world_size,
-        )

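For reference, the collation path the earlier hunks touch pads every example in a micro-batch to the longest sequence and, after this change, batches whichever of 'tokens', 'labels', 'position_ids', and 'loss_mask' are actually present in the examples. A self-contained sketch of that behavior, simplified from the module's collate_fn rather than taken from it verbatim:

# Stand-alone illustration of the padding/collation behavior shown in the diff.
import torch

# Two variable-length examples; only 'tokens' and 'loss_mask' are present, so only
# those keys are collated (mirroring the new keys = list(filter(...)) line).
batch = [
    {"tokens": [5, 6, 7], "loss_mask": [1, 1, 1]},
    {"tokens": [8, 9], "loss_mask": [1, 1]},
]
pad_token_id = 0

def pad_within_micro(values, pad_id):
    # Pad every list to the length of the longest one in this micro-batch.
    max_len = max(map(len, values))
    return [item + [pad_id] * (max_len - len(item)) for item in values]

keys = [k for k in ("tokens", "labels", "position_ids", "loss_mask") if k in batch[0]]
collated = {
    key: torch.LongTensor(pad_within_micro([ex[key] for ex in batch], pad_token_id))
    for key in keys
}
print({k: v.shape for k, v in collated.items()})
# {'tokens': torch.Size([2, 3]), 'loss_mask': torch.Size([2, 3])}
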