Add Alpaca Finetune Datamodule (NVIDIA#11185)
* Add Alpaca Datamodule

* Apply isort and black reformatting

Signed-off-by: suiyoubi <[email protected]>

---------

Signed-off-by: suiyoubi <[email protected]>
Co-authored-by: suiyoubi <[email protected]>
2 people authored and HuiyingLi committed Nov 15, 2024
1 parent 51808f4 commit 75bbd5e
Showing 3 changed files with 129 additions and 0 deletions.
1 change: 1 addition & 0 deletions nemo/collections/llm/__init__.py
@@ -19,6 +19,7 @@

from nemo.collections.llm import peft
from nemo.collections.llm.gpt.data import (
    AlpacaDataModule,
    DollyDataModule,
    FineTuningDataModule,
    HfDatasetDataModule,
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule
from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
@@ -21,6 +22,7 @@

__all__ = [
"FineTuningDataModule",
"AlpacaDataModule",
"SquadDataModule",
"DollyDataModule",
"MockDataModule",
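With both `__init__.py` changes applied, the new class is exposed at two levels of the package. A quick import check (a sketch, assuming an environment with a NeMo build that includes this commit):

from nemo.collections.llm import AlpacaDataModule
from nemo.collections.llm.gpt.data import AlpacaDataModule  # same class, narrower namespace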
126 changes: 126 additions & 0 deletions nemo/collections/llm/gpt/data/alpaca.py
@@ -0,0 +1,126 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from datasets import load_dataset

from nemo.collections.llm.gpt.data.core import get_dataset_root
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class AlpacaDataModule(FineTuningDataModule, IOMixin):
"""A data module for fine-tuning on the Alpaca Python dataset.
This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models
on the "iamtarun/python_code_instructions_18k_alpaca" dataset. It handles data download, preprocessing, splitting,
and preparing the data in a format suitable for training, validation, and testing.
Args:
force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally.
Defaults to False.
delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing.
Defaults to True.
See FineTuningDataModule for the other args
"""

    def __init__(
        self,
        seq_length: int = 2048,
        tokenizer: Optional["TokenizerSpec"] = None,
        micro_batch_size: int = 4,
        global_batch_size: int = 8,
        rampup_batch_size: Optional[List[int]] = None,
        force_redownload: bool = False,
        delete_raw: bool = True,
        seed: int = 1234,
        memmap_workers: int = 1,
        num_workers: int = 8,
        pin_memory: bool = True,
        persistent_workers: bool = False,
        packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
        dataset_kwargs: Optional[Dict[str, Any]] = None,
    ):
        self.force_redownload = force_redownload
        self.delete_raw = delete_raw

        super().__init__(
            dataset_root=get_dataset_root("alpaca"),
            seq_length=seq_length,
            tokenizer=tokenizer,
            micro_batch_size=micro_batch_size,
            global_batch_size=global_batch_size,
            rampup_batch_size=rampup_batch_size,
            seed=seed,
            memmap_workers=memmap_workers,
            num_workers=num_workers,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            packed_sequence_specs=packed_sequence_specs,
            dataset_kwargs=dataset_kwargs,
        )

    def prepare_data(self) -> None:
        # Skip download and preprocessing when the training split already exists
        # locally, unless a re-download is explicitly forced.
        if not self.train_path.exists() or self.force_redownload:
            dset = self._download_data()
            self._preprocess_and_split_data(dset)
        super().prepare_data()

    def _download_data(self):
        logging.info(f"Downloading {self.__class__.__name__}...")
        return load_dataset(
            "iamtarun/python_code_instructions_18k_alpaca",
            cache_dir=str(self.dataset_root),
            download_mode="force_redownload" if self.force_redownload else None,
        )

    def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):
        logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")

        test_ratio = 1 - train_ratio - val_ratio
        save_splits = {}
        dataset = dset.get('train')
        # First carve off the training portion, then re-split the held-out part
        # into validation and test. With the defaults this yields 80/15/5.
        split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
        split_dataset2 = split_dataset['test'].train_test_split(
            test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
        )
        save_splits['training'] = split_dataset['train']
        save_splits['validation'] = split_dataset2['train']
        save_splits['test'] = split_dataset2['test']

        for split_name, dataset in save_splits.items():
            output_file = self.dataset_root / f"{split_name}.jsonl"
            with output_file.open("w", encoding="utf-8") as f:
                for o in dataset:
                    # Truncate the formatted prompt just before its '### Output'
                    # section so the input does not leak the completion.
                    prompt = o['prompt'][: o['prompt'].find('### Output')]
                    completion = o['output']
                    f.write(json.dumps({"input": prompt, "output": completion}) + "\n")

            logging.info(f"{split_name} split saved to {output_file}")

        if self.delete_raw:
            for p in self.dataset_root.iterdir():
                if p.is_dir():
                    shutil.rmtree(p)
                elif '.jsonl' not in str(p.name):
                    p.unlink()
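The two-stage split is worth a quick sanity check: the held-out 20% is re-split with test_size = 0.05 / 0.20 = 0.25, giving 80/15/5 overall. A standalone sketch using the same Hugging Face datasets API (toy data, not the Alpaca set):

from datasets import Dataset

ds = Dataset.from_dict({"idx": list(range(100))})
first = ds.train_test_split(test_size=0.20, seed=1234)  # 80 train / 20 held out
second = first["test"].train_test_split(test_size=0.25, seed=1234)  # 15 val / 5 test
print(len(first["train"]), len(second["train"]), len(second["test"]))  # -> 80 15 5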

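End to end, a minimal usage sketch (hypothetical configuration; assumes network access to Hugging Face and relies on FineTuningDataModule's default handling of the unset tokenizer):

from nemo.collections.llm import AlpacaDataModule

data = AlpacaDataModule(
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=8,
    force_redownload=False,  # reuse a cached download when one exists
    delete_raw=True,  # keep only the generated JSONL splits
)
# Downloads the dataset on first use, then writes training.jsonl,
# validation.jsonl, and test.jsonl under the resolved dataset root.
data.prepare_data()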