-
Notifications
You must be signed in to change notification settings - Fork 2.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Alpaca Finetune Datamodule (#11185)
* Add Alpaca Datamodule * Apply isort and black reformatting Signed-off-by: suiyoubi <[email protected]> --------- Signed-off-by: suiyoubi <[email protected]> Co-authored-by: suiyoubi <[email protected]>
- Loading branch information
1 parent
198e78b
commit a24ee77
Showing
3 changed files
with
129 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import shutil | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||
|
||
from datasets import load_dataset | ||
|
||
from nemo.collections.llm.gpt.data.core import get_dataset_root | ||
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule | ||
from nemo.lightning.io.mixin import IOMixin | ||
from nemo.utils import logging | ||
|
||
if TYPE_CHECKING: | ||
from nemo.collections.common.tokenizers import TokenizerSpec | ||
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs | ||
|
||
|
||
class AlpacaDataModule(FineTuningDataModule, IOMixin): | ||
"""A data module for fine-tuning on the Alpaca Python dataset. | ||
This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models | ||
on the "iamtarun/python_code_instructions_18k_alpaca" dataset. It handles data download, preprocessing, splitting, | ||
and preparing the data in a format suitable for training, validation, and testing. | ||
Args: | ||
force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally. | ||
Defaults to False. | ||
delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing. | ||
Defaults to True. | ||
See FineTuningDataModule for the other args | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seq_length: int = 2048, | ||
tokenizer: Optional["TokenizerSpec"] = None, | ||
micro_batch_size: int = 4, | ||
global_batch_size: int = 8, | ||
rampup_batch_size: Optional[List[int]] = None, | ||
force_redownload: bool = False, | ||
delete_raw: bool = True, | ||
seed: int = 1234, | ||
memmap_workers: int = 1, | ||
num_workers: int = 8, | ||
pin_memory: bool = True, | ||
persistent_workers: bool = False, | ||
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None, | ||
dataset_kwargs: Optional[Dict[str, Any]] = None, | ||
): | ||
self.force_redownload = force_redownload | ||
self.delete_raw = delete_raw | ||
|
||
super().__init__( | ||
dataset_root=get_dataset_root("alpaca"), | ||
seq_length=seq_length, | ||
tokenizer=tokenizer, | ||
micro_batch_size=micro_batch_size, | ||
global_batch_size=global_batch_size, | ||
rampup_batch_size=rampup_batch_size, | ||
seed=seed, | ||
memmap_workers=memmap_workers, | ||
num_workers=num_workers, | ||
pin_memory=pin_memory, | ||
persistent_workers=persistent_workers, | ||
packed_sequence_specs=packed_sequence_specs, | ||
dataset_kwargs=dataset_kwargs, | ||
) | ||
|
||
def prepare_data(self) -> None: | ||
# if train file is specified, no need to do anything | ||
if not self.train_path.exists() or self.force_redownload: | ||
dset = self._download_data() | ||
self._preprocess_and_split_data(dset) | ||
super().prepare_data() | ||
|
||
def _download_data(self): | ||
logging.info(f"Downloading {self.__class__.__name__}...") | ||
return load_dataset( | ||
"iamtarun/python_code_instructions_18k_alpaca", | ||
cache_dir=str(self.dataset_root), | ||
download_mode="force_redownload" if self.force_redownload else None, | ||
) | ||
|
||
def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15): | ||
logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...") | ||
|
||
test_ratio = 1 - train_ratio - val_ratio | ||
save_splits = {} | ||
dataset = dset.get('train') | ||
split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed) | ||
split_dataset2 = split_dataset['test'].train_test_split( | ||
test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed | ||
) | ||
save_splits['training'] = split_dataset['train'] | ||
save_splits['validation'] = split_dataset2['train'] | ||
save_splits['test'] = split_dataset2['test'] | ||
|
||
for split_name, dataset in save_splits.items(): | ||
output_file = self.dataset_root / f"{split_name}.jsonl" | ||
with output_file.open("w", encoding="utf-8") as f: | ||
for o in dataset: | ||
prompt = o['prompt'][: o['prompt'].find('### Output')] | ||
completion = o['output'] | ||
f.write(json.dumps({"input": prompt, "output": completion}) + "\n") | ||
|
||
logging.info(f"{split_name} split saved to {output_file}") | ||
|
||
if self.delete_raw: | ||
for p in self.dataset_root.iterdir(): | ||
if p.is_dir(): | ||
shutil.rmtree(p) | ||
elif '.jsonl' not in str(p.name): | ||
p.unlink() |