Add Alpaca Finetune Datamodule #11185

Merged · 3 commits · Nov 6, 2024
1 change: 1 addition & 0 deletions nemo/collections/llm/__init__.py
@@ -19,6 +19,7 @@

from nemo.collections.llm import peft
from nemo.collections.llm.gpt.data import (
AlpacaDataModule,
DollyDataModule,
FineTuningDataModule,
HfDatasetDataModule,
2 changes: 2 additions & 0 deletions nemo/collections/llm/gpt/data/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule
from nemo.collections.llm.gpt.data.dolly import DollyDataModule
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule
@@ -21,6 +22,7 @@

__all__ = [
"FineTuningDataModule",
"AlpacaDataModule",
"SquadDataModule",
"DollyDataModule",
"MockDataModule",
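
With both export lists updated, the new datamodule can be imported from either namespace; a minimal sketch:

# Either import path resolves to the same class after this change:
from nemo.collections.llm import AlpacaDataModule
# or:
from nemo.collections.llm.gpt.data import AlpacaDataModule
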
126 changes: 126 additions & 0 deletions nemo/collections/llm/gpt/data/alpaca.py
@@ -0,0 +1,126 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import shutil
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from datasets import load_dataset

from nemo.collections.llm.gpt.data.core import get_dataset_root
from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule
from nemo.lightning.io.mixin import IOMixin
from nemo.utils import logging

if TYPE_CHECKING:
from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs


class AlpacaDataModule(FineTuningDataModule, IOMixin):
"""A data module for fine-tuning on the Alpaca Python dataset.
This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models
on the "iamtarun/python_code_instructions_18k_alpaca" dataset. It handles data download, preprocessing, splitting,
and preparing the data in a format suitable for training, validation, and testing.
Args:
force_redownload (bool, optional): Whether to force re-download the dataset even if it exists locally.
Defaults to False.
delete_raw (bool, optional): Whether to delete the raw downloaded dataset after preprocessing.
Defaults to True.
See FineTuningDataModule for the other args
"""

def __init__(
self,
seq_length: int = 2048,
tokenizer: Optional["TokenizerSpec"] = None,
micro_batch_size: int = 4,
global_batch_size: int = 8,
rampup_batch_size: Optional[List[int]] = None,
force_redownload: bool = False,
delete_raw: bool = True,
seed: int = 1234,
memmap_workers: int = 1,
num_workers: int = 8,
pin_memory: bool = True,
persistent_workers: bool = False,
packed_sequence_specs: Optional["PackedSequenceSpecs"] = None,
dataset_kwargs: Optional[Dict[str, Any]] = None,
):
self.force_redownload = force_redownload
self.delete_raw = delete_raw

super().__init__(
dataset_root=get_dataset_root("alpaca"),
seq_length=seq_length,
tokenizer=tokenizer,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
seed=seed,
memmap_workers=memmap_workers,
num_workers=num_workers,
pin_memory=pin_memory,
persistent_workers=persistent_workers,
packed_sequence_specs=packed_sequence_specs,
dataset_kwargs=dataset_kwargs,
)

def prepare_data(self) -> None:
        # if the processed train file already exists and re-download is not forced, there is nothing to do
if not self.train_path.exists() or self.force_redownload:
dset = self._download_data()
self._preprocess_and_split_data(dset)
super().prepare_data()

def _download_data(self):
logging.info(f"Downloading {self.__class__.__name__}...")
return load_dataset(
"iamtarun/python_code_instructions_18k_alpaca",
cache_dir=str(self.dataset_root),
download_mode="force_redownload" if self.force_redownload else None,
)

def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):
logging.info(f"Preprocessing {self.__class__.__name__} to jsonl format and splitting...")

test_ratio = 1 - train_ratio - val_ratio
save_splits = {}
dataset = dset.get('train')
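        # Two-stage split: first hold out (val_ratio + test_ratio) of the examples, then split
        # that held-out portion into validation and test. With the defaults (train_ratio=0.80,
        # val_ratio=0.15) this yields roughly an 80/15/5 train/validation/test split.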
split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)
split_dataset2 = split_dataset['test'].train_test_split(
test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed
)
save_splits['training'] = split_dataset['train']
save_splits['validation'] = split_dataset2['train']
save_splits['test'] = split_dataset2['test']

for split_name, dataset in save_splits.items():
output_file = self.dataset_root / f"{split_name}.jsonl"
with output_file.open("w", encoding="utf-8") as f:
for o in dataset:
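                    # The 'prompt' field holds the full Alpaca-style prompt including the
                    # '### Output' section; keep only the text before it as the input and
                    # pair it with the 'output' field as the completion.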
prompt = o['prompt'][: o['prompt'].find('### Output')]
completion = o['output']
f.write(json.dumps({"input": prompt, "output": completion}) + "\n")

logging.info(f"{split_name} split saved to {output_file}")

if self.delete_raw:
for p in self.dataset_root.iterdir():
if p.is_dir():
shutil.rmtree(p)
elif '.jsonl' not in str(p.name):
p.unlink()
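
For context, a minimal usage sketch based on the class above (argument values are just the constructor defaults; whether a tokenizer must be passed explicitly depends on the surrounding FineTuningDataModule setup):

from nemo.collections.llm.gpt.data import AlpacaDataModule

# Construct the datamodule with the defaults defined in this PR.
data = AlpacaDataModule(
    seq_length=2048,
    micro_batch_size=4,
    global_batch_size=8,
    force_redownload=False,
    delete_raw=True,
)

# The first call downloads "iamtarun/python_code_instructions_18k_alpaca" and writes
# training.jsonl / validation.jsonl / test.jsonl under the "alpaca" dataset root,
# one {"input": ..., "output": ...} record per line.
data.prepare_data()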