Add shufflers as utilities (#467)
gregtatum authored Feb 27, 2024
1 parent 3706913 commit 19e46e5
Showing 2 changed files with 336 additions and 0 deletions.
180 changes: 180 additions & 0 deletions pipeline/common/datasets.py
@@ -0,0 +1,180 @@
import os
import tempfile
from collections import deque
from io import TextIOWrapper
from random import Random
from typing import Iterator, Optional


def shuffle_with_max_lines(
line_stream: Iterator[str],
seed: str,
max_lines: int,
    max_words_in_sentence: int,
total_byte_size: int,
) -> Iterator[str]:
"""
Shuffle a line stream, but only retain up to a maximum number of lines in memory.
Note that the final ordering is determined by the seed and the contents of the file. So
running this multiple times on the same dataset will return the same result, but running
it with the same seed and different content will create a different ordering.
Only run for monolingual data or where the parallel sentences are separated by a delimiter.
The distribution should be even unless the initial content is not representative of the
general size of the sentences, in this case the distribution will be slightly biased. See
the test cases for more in-depth examples.
"""
lines = deque()

random = Random(seed) # Make this deterministic based on dataset key.

total_bytes = 0

    # Fill up the lines until the max, and measure the total bytes.
for line in line_stream:
# Encoding returns the underlying byte representation which is then measured.
total_bytes = total_bytes + len(line.encode("utf-8"))

if len(line.split()) > max_words_in_sentence:
# TODO(CJK) - Issue #424
# This sentence is too long.
continue

lines.append(line)

if len(lines) == max_lines:
break

random.shuffle(lines)

    # Consume the rest of the line stream, sampling lines at a rate chosen so that the
    # final collection remains representative of the entire dataset.

i = 0
for line in line_stream:
i = i + 1
# Continuously adjust this estimation in case the first sampled data is not representative.
total_bytes = total_bytes + len(line.encode("utf-8"))
average_bytes_per_line = total_bytes / (max_lines + i)
estimated_lines = total_byte_size / average_bytes_per_line
line_sampling_probability = max_lines / estimated_lines

if random.random() < line_sampling_probability:
# Shift the deque so the oldest line is shifted out, and this new sample is shifted in.
lines.popleft()
lines.append(line)

# Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
# set of shuffled lines.
random.shuffle(lines)

return lines


def shuffle_in_temp_files(
line_stream: Iterator[str],
output: TextIOWrapper,
seed: str,
chunk_bytes: int,
bucket_bytes: int,
chunk_dir: Optional[str] = tempfile.gettempdir(),
    keep_chunks: bool = False,
):
"""
Shuffle large datasets by storing chunks to the file system. The ordering is guaranteed to be
stable across two datasets as long as they are the same length. For instance it could be used
to shuffle `dataset.en.zst` and `dataset.ca.zst` the same if the two are parallel sentences.
Take in a stream of lines (from a download, or stdin) and split it out to chunks.
tmpdir
├── chunk.1
├── chunk.2
├── chunk.3
├── chunk.4
├── ...
└── chunk.100
After the entire dataset is written to chunks, pick random chunks and put them into a
bucket. Only one bucket is fully loaded into memory at a time, and the contents
of the bucket is shuffled in memory.
Bucket:
┌───────────┐
│ chunk.85 │
│ chunk.3 │
│ chunk.52 │
│ chunk.30 │
│ chunk.12 │
│ chunk.18 │
└───────────┘
• shuffle bucket lines
• write to output
At most 1 bucket will be held in memory. At most the dataset + 1 bucket of file space will be
needed when running this algorithm.
"""
random = Random(seed)

chunk_index = 0
chunk_file = open(os.path.join(chunk_dir, f"chunk.{chunk_index}"), "wt")

# Write out the chunks to disk.
bytes_written_to_chunk = 0
for line in line_stream:
        line_bytes = len(line.encode("utf-8")) + 1  # Plus one byte for the newline written below.

if bytes_written_to_chunk + line_bytes > chunk_bytes:
# Start a new chunk.
chunk_file.close()
chunk_index += 1
chunk_file = open(os.path.join(chunk_dir, f"chunk.{chunk_index}"), "wt")
bytes_written_to_chunk = 0

chunk_file.write(line + "\n")
bytes_written_to_chunk += line_bytes

chunk_file.close()

# Shuffle the chunk indexes
chunk_count = chunk_index + 1

shuffled_chunk_indexes = [*range(chunk_count)]
random.shuffle(shuffled_chunk_indexes)

    # Load one bucket into memory at a time, deleting each chunk after it has been read
    # (unless keep_chunks is set).
bucket_count = 0
bytes_in_bucket = 0
bucket = []

for chunk_index in shuffled_chunk_indexes:
chunk_name = os.path.join(chunk_dir, f"chunk.{chunk_index}")

# Read in the chunk line by line.
with open(chunk_name, "r") as file:
            for line in file:
bucket.append(line)
bytes_in_bucket += len(line.encode("utf-8"))

# If the bucket overflows, shuffle and write it out.
if bytes_in_bucket > bucket_bytes:
random.shuffle(bucket)
for shuffled_line in bucket:
output.write(shuffled_line)

# Create the new bucket.
bucket = []
bytes_in_bucket = 0
bucket_count += 1

if not keep_chunks:
os.remove(chunk_name)

    if len(bucket) > 0:
        # Flush the final, partially filled bucket.
        random.shuffle(bucket)
        for shuffled_line in bucket:
            output.write(shuffled_line)
        bucket_count += 1

    print(f"Shuffled with {bucket_count} buckets.")
156 changes: 156 additions & 0 deletions tests/test_common_datasets.py
@@ -0,0 +1,156 @@
import io
from typing import Iterator

import pytest
from fixtures import DataDir

from pipeline.common.datasets import shuffle_in_temp_files, shuffle_with_max_lines

ITEMS = 100_000
# ITEMS = 1_000
PERCENTAGE = 0.2
MAX_LINES = int(ITEMS * PERCENTAGE)


def get_total_byte_size(lines: list[str]) -> int:
total_byte_size = 0
for line in lines:
total_byte_size = total_byte_size + len(line.encode())
return total_byte_size


def compute_distribution(lines: Iterator[str], items=ITEMS, max_lines=MAX_LINES) -> list[float]:
"""
Computes a histogram (list of 10 items) with a percentage value of 0.0 - 100.0 for each item.
"""
histogram = [0] * 10
for line in lines:
        # This assumes the content is a tab-separated list, with the first item being the line's
        # original position in the sorted input.
key = int(int(line.split("\t")[0]) * 10 / items)
histogram[key] = histogram[key] + (1 / max_lines)

    # Round the values to three decimal places.
return [round(value * 1000) / 1000 for value in histogram]


# Test the distributions of the different types of datasets. This shuffler estimates the content
# size as it iterates through the line stream.
shuffle_params = [
(
# Each line is the same bytes as the next line. This should create an even distribution.
# [
# "000000000 000000000 000000000 ... 000000000",
# "000000001 000000001 000000001 ... 000000001",
# ...
# ]
"even-distribution",
[f"{line:09d}\t" * 10 for line in range(ITEMS)],
[0.102, 0.101, 0.099, 0.1, 0.1, 0.102, 0.099, 0.097, 0.1, 0.1],
),
(
# The initial lines are low in byte count, and gradually increase. In this case there
        # will be a bias to over-sample the initial items, but it will eventually even out as
# more bytes are read in and the average spreads out.
# [
# "0 0 0 ... 0",
# "1 1 1 ... 1",
# ...
# "99997 99997 99997 ... 99997",
# "99998 99998 99998 ... 99998",
# "99999 99999 99999 ... 99999",
# ]
"small-content-at-start",
[f"{line}\t" * 10 for line in range(ITEMS)],
[0.114, 0.116, 0.092, 0.095, 0.096, 0.099, 0.097, 0.095, 0.098, 0.098],
# | | |
# | | ^ Lower sample rate.
# ^^^^^^^ Higher sampling rate.
),
(
# [
# "99999 99999 99999 ... 99999",
# "99998 99998 99998 ... 99998",
# "99997 99997 99997 ... 99997",
# ...
# "1 1 1 ... 1",
# "0 0 0 ... 0",
# ]
"large-content-at-start",
[f"{line}\t" * 10 for line in range(ITEMS)][::-1],
[0.101, 0.102, 0.099, 0.102, 0.103, 0.102, 0.101, 0.102, 0.102, 0.086],
# lower sample rate ^^^^^
),
]


@pytest.mark.parametrize("params", shuffle_params, ids=[d[0] for d in shuffle_params])
def test_shuffle_with_max_lines(params):
    description, line_stream, histogram = params
# [
# "0000 0000 0000 ... 0000",
# "0001 0001 0001 ... 0001",
# "0002 0002 0002 ... 0002",
# ...
# ]

output = shuffle_with_max_lines(
line_stream,
seed="test",
max_lines=MAX_LINES,
max_words_in_sentence=100,
total_byte_size=get_total_byte_size(line_stream),
)

    assert compute_distribution(output) == histogram, description


def test_shuffle_in_temp_files():
# [
# "0000 0000 0000 ... 0000",
# "0001 0001 0001 ... 0001",
# "0002 0002 0002 ... 0002",
# ...
# ]
line_stream = [f"{line:09d}\t" * 10 for line in range(ITEMS)]

# Total byte size is ~10_000_000
chunk_bytes = 100_000
bucket_bytes = 2_000_000
data_dir = DataDir("test_common_datasets")

with io.StringIO() as output:
shuffle_in_temp_files(
line_stream,
output=output,
seed="test",
chunk_bytes=chunk_bytes,
bucket_bytes=bucket_bytes,
chunk_dir=data_dir.path,
keep_chunks=True,
)

data_dir.print_tree()

output.seek(0)
text = output.read()
lines = [*text.splitlines()]
sample = lines[:MAX_LINES]

output.seek(0)
with open(data_dir.join("shuffle.txt"), "w") as file:
print(output.getvalue(), file=file)

assert len(lines) == ITEMS
assert compute_distribution(sample) == [
0.149,
0.258,
0.04,
0.1,
0.001, # The distribution is not perfect with this strategy.
0.052,
0.05,
0.1,
0.101,
0.15,
]

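As a closing illustration, `shuffle_in_temp_files` first splits the input into chunk files of roughly `chunk_bytes` each, shuffles the chunk order, then reads the chunks back into an in-memory bucket that is shuffled and flushed to the output whenever it grows past `bucket_bytes`. A minimal usage sketch follows; the file names, byte sizes, and temporary-directory handling are assumptions for illustration and are not part of this commit.

import tempfile

from pipeline.common.datasets import shuffle_in_temp_files

# Shuffle one large file through on-disk chunks, holding at most one bucket in memory.
with open("corpus.en.txt", "rt", encoding="utf-8") as lines, open(
    "corpus.en.shuffled.txt", "wt", encoding="utf-8"
) as output, tempfile.TemporaryDirectory() as chunk_dir:
    shuffle_in_temp_files(
        (line.rstrip("\n") for line in lines),  # The shuffler writes its own newlines.
        output=output,
        seed="corpus.en",
        chunk_bytes=100_000_000,  # ~100 MB chunk files on disk.
        bucket_bytes=2_000_000_000,  # ~2 GB of lines shuffled in memory at a time.
        chunk_dir=chunk_dir,
    )

Per the docstring, reusing the same seed, `chunk_bytes`, and `bucket_bytes` on a second file of the same length (for example, the other side of a parallel corpus) should reproduce the same ordering.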