Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds a Knob for OnlineSampling by introducing 'global_sample_mapping' in the SFT config.yaml #9913

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ model:
index_mapping_dir: null # Path to a directory to write index mapping files.
prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
  global_sample_mapping: False # If True, shuffle all replicated data together across epochs; if False, shuffle the dataset within each epoch
validation_ds:
file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
names: null # Names of the corresponding datasets used to log metrics.
Expand All @@ -181,6 +182,7 @@ model:
prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}"
tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
  global_sample_mapping: False # If True, shuffle all replicated data together across epochs; if False, shuffle the dataset within each epoch
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
Expand Down Expand Up @@ -208,6 +210,7 @@ model:
prompt_template: ${model.data.train_ds.prompt_template}
tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics
truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
  global_sample_mapping: False # If True, shuffle all replicated data together across epochs; if False, shuffle the dataset within each epoch
metric:
name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def create_sft_dataset(
truncation_method: str = 'right',
memmap_workers: int = 2,
hf_dataset: bool = False,
global_sample_mapping: bool = False,
**kwargs,
) -> "GPTSFTDataset":
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
Expand All @@ -42,6 +43,7 @@ def create_sft_dataset(
max_seq_length=seq_length,
memmap_workers=memmap_workers,
hf_dataset=hf_dataset,
global_sample_mapping=global_sample_mapping,
add_bos=add_bos,
add_eos=add_eos,
add_sep=add_sep,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
tokens_to_generate: int = 0,
memmap_workers: Optional[int] = None,
hf_dataset: bool = False,
global_sample_mapping: bool = False,
truncation_method: str = 'right',
    special_tokens: Optional[Mapping[str, str]] = None,  # special tokens, a dictionary of {token_type: token}
is_test: bool = False,
Expand All @@ -83,6 +84,7 @@ def __init__(
index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset.
prompt_template: Prompt template to inject via an fstring. Formatted like Q: {context_key}\n\nA: {label_key}
hf_dataset: Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset.
        global_sample_mapping: If True, shuffle all replicated data together across epochs; if False, shuffle the dataset within each epoch
truncation_method: Truncation from which position. Options: ['left', 'right']
special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
is_test: Whether this dataset is the test split.
Expand All @@ -109,6 +111,7 @@ def __init__(
self.tokens_to_generate = tokens_to_generate
self.memmap_workers = memmap_workers
self.hf_dataset = hf_dataset
self.global_sample_mapping = global_sample_mapping
self.truncation_method = truncation_method
self.is_test = is_test
self.output_original_text = output_original_text
Expand Down Expand Up @@ -176,7 +179,11 @@ def _maybe_validate_prompt_template(self):

def _build_samples_mapping(self):
if self.max_num_samples is not None:
osm = OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
osm = (
OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples)
if not self.global_sample_mapping
else None
)
self.samples_mapping = get_samples_mapping(
indexed_dataset=self.indexed_dataset,
data_prefix=self.file_path,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def _build_dataset(self, data_cfg, is_train=True):
prompt_template=data_cfg.get('prompt_template', None),
ceil_to_power_2=data_cfg.get('ceil_to_power_2', False),
get_attention_mask_from_fusion=data_cfg.get('get_attention_mask_from_fusion', False),
global_sample_mapping=data_cfg.get('global_sample_mapping', False),
virtual_tokens=self.virtual_tokens,
tokens_to_generate=data_cfg.get(
'tokens_to_generate', 0
Expand Down
1 change: 1 addition & 0 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def tokenize_dataset(cfg: 'DictConfig'):
tokens_to_generate=data_cfg.get('tokens_to_generate', 0),
memmap_workers=data_cfg.get('memmap_workers', None),
hf_dataset=data_cfg.get('hf_dataset', False),
global_sample_mapping=data_cfg.get('global_sample_mapping', False),
truncation_method=data_cfg.get('truncation_method', 'right'),
special_tokens=data_cfg.get('chat_prompt_tokens', None),
is_test=True,
Expand Down
Loading