Skip to content

Commit

Permalink
Adds a Knob for OnlineSampling by introducing 'global_sample_mapping'…
Browse files Browse the repository at this point in the history
… in the SFT config.yaml (NVIDIA#9913)

* Add 'global_sample_mapping' to the config to turn OnlineSampleMapping on/off.

Signed-off-by: conver334 <[email protected]>

* black

Signed-off-by: Alexandros Koumparoulis <[email protected]>

---------

Signed-off-by: conver334 <[email protected]>
Signed-off-by: Alexandros Koumparoulis <[email protected]>
Co-authored-by: Simiao Zhang <[email protected]>
Co-authored-by: Alexandros Koumparoulis <[email protected]>
Signed-off-by: Vivian Chen <[email protected]>
  • Loading branch information
3 people authored and Vivian Chen committed Aug 1, 2024
1 parent 6f925d3 commit 6980146
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ model:
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
global_sample_mapping: False # Whether to shuffle the replicated data all together, or shuffle the dataset within each epoch
validation_ds:
file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
names: null # Names of the corresponding datasets used to log metrics.
Expand All @@ -181,6 +182,7 @@ model:
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
global_sample_mapping: False # Whether to shuffle the replicated data all together, or shuffle the dataset within each epoch
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
Expand Down Expand Up @@ -208,6 +210,7 @@ model:
prompt_template: ${model.data.train_ds.prompt_template}
tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
global_sample_mapping: False # Whether to shuffle the replicated data all together, or shuffle the dataset within each epoch
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def create_sft_dataset(
truncation_method: str = 'right',
memmap_workers: int = 2,
hf_dataset: bool = False,
global_sample_mapping: bool = False,
**kwargs,
) -> "GPTSFTDataset":
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
Expand All @@ -42,6 +43,7 @@ def create_sft_dataset(
max_seq_length=seq_length,
memmap_workers=memmap_workers,
hf_dataset=hf_dataset,
global_sample_mapping=global_sample_mapping,
add_bos=add_bos,
add_eos=add_eos,
add_sep=add_sep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
tokens_to_generate: int = 0,
memmap_workers: Optional[int] = None,
hf_dataset: bool = False,
global_sample_mapping: bool = False,
truncation_method: str = 'right',
special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictionary of {token_type: token}
is_test: bool = False,
Expand All @@ -83,6 +84,7 @@ def __init__(
index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset.
prompt_template: Prompt template to inject via an fstring. Formatted like Q: {context_key}\n\nA: {label_key}
hf_dataset: Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset.
global_sample_mapping: Whether to shuffle all data together, or shuffle the dataset within each epoch
truncation_method: Truncation from which position. Options: ['left', 'right']
special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
is_test: Whether this dataset is the test split.
Expand All @@ -109,6 +111,7 @@ def __init__(
self.tokens_to_generate = tokens_to_generate
self.memmap_workers = memmap_workers
self.hf_dataset = hf_dataset
self.global_sample_mapping = global_sample_mapping
self.truncation_method = truncation_method
self.is_test = is_test
self.output_original_text = output_original_text
Expand Down Expand Up @@ -176,7 +179,11 @@ def _maybe_validate_prompt_template(self):

def _build_samples_mapping(self):
if self.max_num_samples is not None:
osm = OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
osm = (
OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
if not self.global_sample_mapping
else None
)
self.samples_mapping = get_samples_mapping(
indexed_dataset=self.indexed_dataset,
data_prefix=self.file_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def _build_dataset(self, data_cfg, is_train=True):
prompt_template=data_cfg.get('prompt_template', None),
ceil_to_power_2=data_cfg.get('ceil_to_power_2', False),
get_attention_mask_from_fusion=data_cfg.get('get_attention_mask_from_fusion', False),
global_sample_mapping=data_cfg.get('global_sample_mapping', False),
virtual_tokens=self.virtual_tokens,
tokens_to_generate=data_cfg.get(
'tokens_to_generate', 0
Expand Down
1 change: 1 addition & 0 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def tokenize_dataset(cfg: 'DictConfig'):
tokens_to_generate=data_cfg.get('tokens_to_generate', 0),
memmap_workers=data_cfg.get('memmap_workers', None),
hf_dataset=data_cfg.get('hf_dataset', False),
global_sample_mapping=data_cfg.get('global_sample_mapping', False),
truncation_method=data_cfg.get('truncation_method', 'right'),
special_tokens=data_cfg.get('chat_prompt_tokens', None),
is_test=True,
Expand Down

0 comments on commit 6980146

Please sign in to comment.