[NeMo Megatron] use worker processes for data preprocessing (#3665)
* nvtx ranges

Signed-off-by: Masaki Kozuki <[email protected]>

* non_blocking

Signed-off-by: Masaki Kozuki <[email protected]>

* more workload on dataset, local global batch batchsampler

Signed-off-by: Masaki Kozuki <[email protected]>

* fix typo

Signed-off-by: Masaki Kozuki <[email protected]>

* remove nvtx ranges

Signed-off-by: Masaki Kozuki <[email protected]>

* cosmetic

Signed-off-by: Masaki Kozuki <[email protected]>

* pass tokenizer

Signed-off-by: Masaki Kozuki <[email protected]>

* fix type

Signed-off-by: Masaki Kozuki <[email protected]>

* style fix

Signed-off-by: Masaki Kozuki <[email protected]>

* remove GlobalBatchDataFetcher and NLPDataConnector

Signed-off-by: Masaki Kozuki <[email protected]>

* `_create_ltor_masks_and_position_ids` docstring

Signed-off-by: Masaki Kozuki <[email protected]>

* fix `BatchSampler`'s dunder len method

Originally (i.e., in Megatron's MegatronPretrainingSampler and MegatronPretrainingRandomSampler), `__len__` was implemented as follows:

```python
    def __len__(self):
        return (self.total_samples - self.consumed_samples - 1) // self.micro_batch_times_data_parallel_size + 1
```

The counterpart of `self.micro_batch_times_data_parallel_size` in
MegatronPretraining(|Random)BatchSampler is
`num_micro_batch_times_micro_batch_size_times_data_parallel_size`.
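A minimal sketch of what the corrected `__len__` could look like in the new batch samplers, assuming that divisor; only the attribute name `num_micro_batch_times_micro_batch_size_times_data_parallel_size` is taken from this message, the rest is an assumption, not code from this commit.

```python
# Hypothetical sketch of the fixed __len__; only the divisor's name comes from the
# commit message above, everything else is assumed for illustration.
def __len__(self):
    num_available_samples = self.total_samples - self.consumed_samples
    return (
        num_available_samples - 1
    ) // self.num_micro_batch_times_micro_batch_size_times_data_parallel_size + 1
```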

Signed-off-by: Masaki Kozuki <[email protected]>

* remove unused imports

Signed-off-by: Masaki Kozuki <[email protected]>

* rename and license

Signed-off-by: Masaki Kozuki <[email protected]>

* isort
crcrpar authored Feb 25, 2022
1 parent 2ce4536 commit bc6215f
Showing 6 changed files with 331 additions and 144 deletions.
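At a high level, the commit moves per-sample preprocessing (building `attention_mask`, `loss_mask`, and `position_ids`) from the training step into `GPTDataset.__getitem__`, so a regular PyTorch `DataLoader` can run it in worker processes. The sketch below only illustrates that idea; the helper name, the sampler, and the `num_workers`/`pin_memory` settings are assumptions, not code from this commit.

```python
import torch

# Hypothetical illustration (not taken from this commit): since __getitem__ now
# returns a fully preprocessed sample dict, a DataLoader with num_workers > 0 runs
# that preprocessing in worker processes instead of the main training process.
def build_pretraining_dataloader(dataset, batch_sampler, num_workers=4):
    return torch.utils.data.DataLoader(
        dataset,                      # e.g. a GPTDataset instance
        batch_sampler=batch_sampler,  # e.g. one of the new Megatron batch samplers
        num_workers=num_workers,      # worker processes execute GPTDataset.__getitem__
        pin_memory=True,
    )
```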
5 changes: 0 additions & 5 deletions examples/nlp/language_modeling/megatron_gpt_pretraining.py
@@ -23,7 +23,6 @@
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDataConnector,
NLPDDPPlugin,
PipelineMixedPrecisionPlugin,
)
@@ -59,10 +58,6 @@ def main(cfg) -> None:

trainer = Trainer(plugins=plugins, **cfg.trainer)

# NLPDataConnector used to provide global batches which are needed
# for Apex fwd/bwd functions
trainer._data_connector = NLPDataConnector(trainer)

exp_manager(trainer, cfg.exp_manager)

# update resume from checkpoint found by exp_manager
103 changes: 98 additions & 5 deletions nemo/collections/nlp/data/language_modeling/megatron/gpt_dataset.py
@@ -40,7 +40,16 @@


def build_train_valid_test_datasets(
cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
cfg,
trainer,
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
skip_warmup,
tokenizer,
):
"""Build train, valid, and test datasets."""

@@ -56,6 +65,7 @@ def build_train_valid_test_datasets(
seq_length,
seed,
skip_warmup,
tokenizer,
)

# Blending dataset.
@@ -78,6 +88,7 @@
seq_length,
seed,
skip_warmup,
tokenizer,
)
if train_ds:
train_datasets.append(train_ds)
@@ -101,7 +112,16 @@


def _build_train_valid_test_datasets(
cfg, trainer, data_prefix, data_impl, splits_string, train_valid_test_num_samples, seq_length, seed, skip_warmup
cfg,
trainer,
data_prefix,
data_impl,
splits_string,
train_valid_test_num_samples,
seq_length,
seed,
skip_warmup,
tokenizer,
):
"""Build train, valid, and test datasets."""

@@ -132,6 +152,7 @@ def build_dataset(index, name):
dataset = GPTDataset(
cfg,
trainer,
tokenizer,
name,
data_prefix,
documents,
@@ -162,7 +183,9 @@ def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):


class GPTDataset(MegatronDataset):
def __init__(self, cfg, trainer, name, data_prefix, documents, indexed_dataset, num_samples, seq_length, seed):
def __init__(
self, cfg, trainer, tokenizer, name, data_prefix, documents, indexed_dataset, num_samples, seq_length, seed,
):
if not HAVE_APEX:
raise ImportError(
"Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
@@ -176,6 +199,11 @@ def __init__(self, cfg, trainer, name, data_prefix, documents, indexed_dataset,
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]

self.reset_position_ids = cfg.data.get('reset_position_ids', False)
self.reset_attention_mask = cfg.data.get('reset_attention_mask', False)
self.eod_mask_loss = cfg.data.get('eod_mask_loss', False)
self.eos_id = tokenizer.eos_id

# Build index mappings.
self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings(
self.name, data_prefix, documents, self.indexed_dataset.sizes, num_samples, seq_length, seed
@@ -186,7 +214,8 @@ def __len__(self):
# sample i --> [sample_idx[i], sample_idx[i+1])
return self.sample_idx.shape[0] - 1

def __getitem__(self, idx):
def _get_text(self, idx: int) -> np.ndarray:

# Get the shuffled index.
idx = self.shuffle_idx[idx]
# Start and end documents and offsets.
@@ -208,8 +237,72 @@ def __getitem__(self, idx):
# And finally add the relevant portion of last document.
sample_list.append(self.indexed_dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1))
sample = np.concatenate(sample_list)
return sample.astype(np.int64)

def __getitem__(self, idx):
text = torch.from_numpy(self._get_text(idx))
tokens = text[:-1].contiguous()
labels = text[1:].contiguous()
attention_mask, loss_mask, position_ids = _create_ltor_masks_and_position_ids(
tokens, self.eos_id, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss,
)

return {
'tokens': tokens,
'labels': labels,
'attention_mask': attention_mask,
'loss_mask': loss_mask,
'position_ids': position_ids,
}

return {'text': np.array(sample, dtype=np.int64)}

@torch.no_grad()
def _create_ltor_masks_and_position_ids(
tokens: torch.Tensor, eod_token: int, reset_position_ids: bool, reset_attention_mask: bool, eod_mask_loss: bool,
):
"""Create `attention_mask`, `loss_mask`, and `position_ids`.
This function is modified :func:`get_ltor_masks_and_position_ids` in nemo/collections/nlp/modules/common/megatron/utils.py:
`get_ltor_masks_and_position_ids` assumes a microbatch of ``tokens``, i.e. 2D tensor while
this function assumes ``tokens`` to be 1D tensor.
Args:
tokens: A 1D tensor that holds the indices of tokens.
eod_token:
reset_position_ids:
reset_attention_mask:
eod_mask_loss
"""
assert tokens.ndim == 1
seq_length = tokens.numel()
# `attention_mask` has the shape of [1, seq_length, seq_length]
attention_mask = torch.tril(torch.ones((seq_length, seq_length))).unsqueeze(0)
loss_mask = torch.ones(seq_length, dtype=torch.float)
if eod_mask_loss:
loss_mask[tokens == eod_token] = 0.0

position_ids = torch.arange(seq_length, dtype=torch.int64)
if reset_position_ids:
position_ids = position_ids.clone()

if reset_position_ids or reset_attention_mask:
# Find indices where EOD token is.
eod_index = position_ids[tokens == eod_token]
# Detach indices from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
prev_index = 0
for j in range(eod_index.numel()):
i = eod_index[j]
if reset_attention_mask:
attention_mask[0, (i + 1) :, : (i + 1)] = 0
if reset_position_ids:
position_ids[(i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary.
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids


def _build_index_mappings(name, data_prefix, documents, sizes, num_samples, seq_length, seed):
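For reference, a small usage sketch of the new `_create_ltor_masks_and_position_ids` helper on a single 1D token sequence; the token values and the `eod_token` id below are made up for illustration.

```python
import torch

# Hypothetical example (not part of this commit): build masks for one 1D sequence.
eod_token = 2
tokens = torch.tensor([5, 9, eod_token, 7, 3], dtype=torch.int64)

attention_mask, loss_mask, position_ids = _create_ltor_masks_and_position_ids(
    tokens,
    eod_token,
    reset_position_ids=False,
    reset_attention_mask=False,
    eod_mask_loss=True,
)

assert attention_mask.shape == (1, 5, 5)      # True marks masked-out (future) positions
assert loss_mask[2] == 0.0                    # the EOD position is excluded from the loss
assert position_ids.tolist() == [0, 1, 2, 3, 4]
```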
@@ -32,6 +32,7 @@ def __init__(
super().__init__(
cfg,
trainer,
tokenizer,
name,
data_prefix,
documents,
@@ -42,8 +43,7 @@
)

def __getitem__(self, idx):
text = super().__getitem__(idx)
text = text['text']
text = super()._get_text(idx)

# Split text sequence into encoder and decoder inputs
tokens_enc = text[: self.seq_length_encoder]