Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make tokenize_and_concatenate work with more datasets #473

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions transformer_lens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch.nn.functional as F
import transformers
from datasets.arrow_dataset import Dataset
from datasets.iterable_dataset import IterableDataset
from datasets.load import load_dataset
from huggingface_hub import hf_hub_download
from jaxtyping import Float, Int
Expand Down Expand Up @@ -209,13 +210,15 @@ def keep_single_column(dataset: Dataset, col_name: str):


def tokenize_and_concatenate(
dataset: Dataset,
dataset: Union[Dataset, IterableDataset],
tokenizer: AutoTokenizer,
streaming: bool = False,
max_length: int = 1024,
column_name: str = "text",
add_bos_token: bool = True,
num_proc: int = 10,
remove_pad_tokens: bool = True,
set_format: bool = True,
) -> Dataset:
"""Helper function to tokenizer and concatenate a dataset of text. This converts the text to tokens, concatenates them (separated by EOS tokens) and then reshapes them into a 2D array of shape (____, sequence_length), dropping the last batch. Tokenizers are much faster if parallelised, so we chop the string into 20, feed it into the tokenizer, in parallel with padding, then remove padding at the end.

Expand All @@ -228,6 +231,10 @@ def tokenize_and_concatenate(
max_length (int, optional): The length of the context window of the sequence. Defaults to 1024.
column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'.
add_bos_token (bool, optional): . Defaults to True.
num_proc (int, optional): The number of processes to use for parallel tokenization. Defaults
to 10.
remove_pad_tokens (bool, optional): Whether to remove the padding tokens. Defaults to True.
set_format (bool, optional): Whether to set the format of the dataset to torch and remove a column. Defaults to True.

Returns:
Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens"
Expand Down Expand Up @@ -259,8 +266,9 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
tokens = tokenizer(chunks, return_tensors="np", padding=True)[
"input_ids"
].flatten()
# Drop padding tokens
tokens = tokens[tokens != tokenizer.pad_token_id]
if remove_pad_tokens:
# Drop padding tokens
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment that padding tokens may be there because the chunks are uneven length, because we split the text into 20 chunks (for tokenization efficiency), in addition to maybe being in the training data.

tokens = tokens[tokens != tokenizer.pad_token_id]
num_tokens = len(tokens)
num_batches = num_tokens // (seq_len)
# Drop the final tokens if not enough to make a full sequence
Expand All @@ -276,10 +284,23 @@ def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, np.ndarray]:
tokenized_dataset = dataset.map(
tokenize_function,
batched=True,
num_proc=(num_proc if not streaming else None),
remove_columns=[column_name],
# Don't even pass the num_proc argument if we're streaming
**({"num_proc": num_proc} if not streaming else {}),
)
tokenized_dataset.set_format(type="torch", columns=["tokens"])

if set_format:
# This cleans up the dataset, removing the column name and setting the format to torch
# Doesn't work for all datasets (eg when streaming)
# Creating a generator which will be lazily loaded, e.g
# ````
# formatted_tokenized_dataset = (torch.LongTensor(example['tokens']) for example in
# tokenized_dataset)
# ````
# May work for your use case

tokenized_dataset.set_format(type="torch", columns=["tokens"])

return tokenized_dataset


Expand Down