Showing 2 changed files with 336 additions and 0 deletions.
@@ -0,0 +1,180 @@
import os
import tempfile
from collections import deque
from io import TextIOWrapper
from random import Random
from typing import Iterator, Optional


def shuffle_with_max_lines(
    line_stream: Iterator[str],
    seed: str,
    max_lines: int,
    max_words_in_sentence,
    total_byte_size: int,
) -> Iterator[str]:
    """
    Shuffle a line stream, but only retain up to a maximum number of lines in memory.

    Note that the final ordering is determined by the seed and the contents of the file. So
    running this multiple times on the same dataset will return the same result, but running
    it with the same seed and different content will create a different ordering.

    Only run this for monolingual data, or for data where the parallel sentences are separated
    by a delimiter.

    The distribution should be even unless the initial content is not representative of the
    general size of the sentences; in that case the distribution will be slightly biased. See
    the test cases for more in-depth examples.
    """
    lines = deque()

    random = Random(seed)  # Make this deterministic based on the dataset key.

    total_bytes = 0

    # Fill up the lines until the max, and measure the total bytes.
    for line in line_stream:
        # Encoding returns the underlying byte representation, which is then measured.
        total_bytes = total_bytes + len(line.encode("utf-8"))

        if len(line.split()) > max_words_in_sentence:
            # TODO(CJK) - Issue #424
            # This sentence is too long.
            continue

        lines.append(line)

        if len(lines) == max_lines:
            break

    random.shuffle(lines)

    # Consume the rest of the line stream, but sample based on the probability that adding
    # something to the collection will be representative.

    i = 0
    for line in line_stream:
        i = i + 1
        # Continuously adjust this estimate in case the first sampled data is not representative.
        total_bytes = total_bytes + len(line.encode("utf-8"))
        average_bytes_per_line = total_bytes / (max_lines + i)
        estimated_lines = total_byte_size / average_bytes_per_line
        line_sampling_probability = max_lines / estimated_lines

        if random.random() < line_sampling_probability:
            # Shift the deque so the oldest line is shifted out, and this new sample is shifted in.
            lines.popleft()
            lines.append(line)

    # Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
    # set of shuffled lines.
    random.shuffle(lines)

    return lines


def shuffle_in_temp_files(
    line_stream: Iterator[str],
    output: TextIOWrapper,
    seed: str,
    chunk_bytes: int,
    bucket_bytes: int,
    chunk_dir: Optional[str] = tempfile.gettempdir(),
    keep_chunks=False,
):
    """
    Shuffle large datasets by storing chunks on the file system. The ordering is guaranteed to be
    stable across two datasets as long as they are the same length. For instance, it could be used
    to shuffle `dataset.en.zst` and `dataset.ca.zst` the same way if the two contain parallel
    sentences.

    Take in a stream of lines (from a download, or stdin) and split it out to chunks.

    tmpdir
    ├── chunk.1
    ├── chunk.2
    ├── chunk.3
    ├── chunk.4
    ├── ...
    └── chunk.100

    After the entire dataset is written to chunks, pick random chunks and put them into a
    bucket. Only one bucket is fully loaded into memory at a time, and the contents
    of the bucket are shuffled in memory.

    Bucket:
    ┌───────────┐
    │ chunk.85  │
    │ chunk.3   │
    │ chunk.52  │
    │ chunk.30  │
    │ chunk.12  │
    │ chunk.18  │
    └───────────┘
      • shuffle bucket lines
      • write to output

    At most 1 bucket will be held in memory. At most the dataset + 1 bucket of file space will be
    needed when running this algorithm.
    """
    random = Random(seed)

    chunk_index = 0
    chunk_file = open(os.path.join(chunk_dir, f"chunk.{chunk_index}"), "wt")

    # Write out the chunks to disk.
    bytes_written_to_chunk = 0
    for line in line_stream:
        line_bytes = len(line.encode("utf-8")) + 1

        if bytes_written_to_chunk + line_bytes > chunk_bytes:
            # Start a new chunk.
            chunk_file.close()
            chunk_index += 1
            chunk_file = open(os.path.join(chunk_dir, f"chunk.{chunk_index}"), "wt")
            bytes_written_to_chunk = 0

        chunk_file.write(line + "\n")
        bytes_written_to_chunk += line_bytes

    chunk_file.close()

    # Shuffle the chunk indexes.
    chunk_count = chunk_index + 1

    shuffled_chunk_indexes = [*range(chunk_count)]
    random.shuffle(shuffled_chunk_indexes)

    # Load a single bucket into memory, discarding the chunks.
    bucket_count = 0
    bytes_in_bucket = 0
    bucket = []

    for chunk_index in shuffled_chunk_indexes:
        chunk_name = os.path.join(chunk_dir, f"chunk.{chunk_index}")

        # Read in the chunk line by line.
        with open(chunk_name, "r") as file:
            for line in file.readlines():
                bucket.append(line)
                bytes_in_bucket += len(line.encode("utf-8"))

                # If the bucket overflows, shuffle and write it out.
                if bytes_in_bucket > bucket_bytes:
                    random.shuffle(bucket)
                    for shuffled_line in bucket:
                        output.write(shuffled_line)

                    # Create the new bucket.
                    bucket = []
                    bytes_in_bucket = 0
                    bucket_count += 1

        if not keep_chunks:
            os.remove(chunk_name)

    if len(bucket) > 0:
        random.shuffle(bucket)
        for shuffled_line in bucket:
            output.write(shuffled_line)

    print(f"Shuffled with {bucket_count} buckets.")
@@ -0,0 +1,156 @@
import io
from typing import Iterator

import pytest
from fixtures import DataDir

from pipeline.common.datasets import shuffle_in_temp_files, shuffle_with_max_lines

ITEMS = 100_000
# ITEMS = 1_000
PERCENTAGE = 0.2
MAX_LINES = int(ITEMS * PERCENTAGE)


def get_total_byte_size(lines: list[str]) -> int:
    total_byte_size = 0
    for line in lines:
        total_byte_size = total_byte_size + len(line.encode())
    return total_byte_size


def compute_distribution(lines: Iterator[str], items=ITEMS, max_lines=MAX_LINES) -> list[float]:
    """
    Computes a histogram (list of 10 items) with a proportion of 0.0 - 1.0 for each item.
    """
    histogram = [0] * 10
    for line in lines:
        # This assumes the content is a tab-separated list, with the first item being the
        # initial sorted order in the list.
        key = int(int(line.split("\t")[0]) * 10 / items)
        histogram[key] = histogram[key] + (1 / max_lines)

    # Lower the precision of the floats.
    return [round(value * 1000) / 1000 for value in histogram]


# Test the distributions of the different types of datasets. This shuffler estimates the content
# size as it iterates through the line stream.
shuffle_params = [
    (
        # Each line is the same number of bytes as the next line. This should create an even
        # distribution.
        # [
        #     "000000000 000000000 000000000 ... 000000000",
        #     "000000001 000000001 000000001 ... 000000001",
        #     ...
        # ]
        "even-distribution",
        [f"{line:09d}\t" * 10 for line in range(ITEMS)],
        [0.102, 0.101, 0.099, 0.1, 0.1, 0.102, 0.099, 0.097, 0.1, 0.1],
    ),
    (
        # The initial lines are low in byte count, and gradually increase. In this case there
        # will be a bias to over-sample the initial items, but it will eventually even out as
        # more bytes are read in and the average spreads out.
        # [
        #     "0 0 0 ... 0",
        #     "1 1 1 ... 1",
        #     ...
        #     "99997 99997 99997 ... 99997",
        #     "99998 99998 99998 ... 99998",
        #     "99999 99999 99999 ... 99999",
        # ]
        "small-content-at-start",
        [f"{line}\t" * 10 for line in range(ITEMS)],
        [0.114, 0.116, 0.092, 0.095, 0.096, 0.099, 0.097, 0.095, 0.098, 0.098],
        #   |      |      |
        #   |      |      ^ Lower sample rate.
        #   ^^^^^^^^ Higher sampling rate.
    ),
    (
        # [
        #     "99999 99999 99999 ... 99999",
        #     "99998 99998 99998 ... 99998",
        #     "99997 99997 99997 ... 99997",
        #     ...
        #     "1 1 1 ... 1",
        #     "0 0 0 ... 0",
        # ]
        "large-content-at-start",
        [f"{line}\t" * 10 for line in range(ITEMS)][::-1],
        [0.101, 0.102, 0.099, 0.102, 0.103, 0.102, 0.101, 0.102, 0.102, 0.086],
        #                                       lower sample rate ^^^^^
    ),
]


@pytest.mark.parametrize("params", shuffle_params, ids=[d[0] for d in shuffle_params])
def test_shuffle_with_max_lines(params):
    description, line_stream, histogram = params
    # [
    #     "0000 0000 0000 ... 0000",
    #     "0001 0001 0001 ... 0001",
    #     "0002 0002 0002 ... 0002",
    #     ...
    # ]

    output = shuffle_with_max_lines(
        line_stream,
        seed="test",
        max_lines=MAX_LINES,
        max_words_in_sentence=100,
        total_byte_size=get_total_byte_size(line_stream),
    )

    assert compute_distribution(output) == histogram, description


def test_shuffle_in_temp_files():
    # [
    #     "0000 0000 0000 ... 0000",
    #     "0001 0001 0001 ... 0001",
    #     "0002 0002 0002 ... 0002",
    #     ...
    # ]
    line_stream = [f"{line:09d}\t" * 10 for line in range(ITEMS)]

    # The total byte size is ~10_000_000.
    chunk_bytes = 100_000
    bucket_bytes = 2_000_000
    data_dir = DataDir("test_common_datasets")

    with io.StringIO() as output:
        shuffle_in_temp_files(
            line_stream,
            output=output,
            seed="test",
            chunk_bytes=chunk_bytes,
            bucket_bytes=bucket_bytes,
            chunk_dir=data_dir.path,
            keep_chunks=True,
        )

        data_dir.print_tree()

        output.seek(0)
        text = output.read()
        lines = [*text.splitlines()]
        sample = lines[:MAX_LINES]

        output.seek(0)
        with open(data_dir.join("shuffle.txt"), "w") as file:
            print(output.getvalue(), file=file)

        assert len(lines) == ITEMS
        assert compute_distribution(sample) == [
            0.149,
            0.258,
            0.04,
            0.1,
            0.001,  # The distribution is not perfect with this strategy.
            0.052,
            0.05,
            0.1,
            0.101,
            0.15,
        ]
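The expected histograms in the parametrized cases above fall out of the byte-based sampling estimate in `shuffle_with_max_lines`. A rough walk-through of that estimate, using the numbers from the `even-distribution` case (100-byte lines, `ITEMS = 100_000`, `MAX_LINES = 20_000`); the variable names mirror the ones in the function, and the arithmetic is only illustrative:

# Values taken from the even-distribution test case.
max_lines = 20_000            # MAX_LINES
total_byte_size = 10_000_000  # 100_000 lines * 100 bytes each

# After the first pass, the deque holds max_lines lines totalling 2_000_000 bytes.
total_bytes = 2_000_000

# Reading the next 100-byte line (i = 1) re-estimates the sampling probability:
i = 1
total_bytes += 100
average_bytes_per_line = total_bytes / (max_lines + i)      # = 100.0
estimated_lines = total_byte_size / average_bytes_per_line  # = 100_000.0
line_sampling_probability = max_lines / estimated_lines     # = 0.2

# Each remaining line therefore replaces the oldest retained line with probability 0.2,
# which is why every decile of the input ends up holding roughly 0.1 of the sample.
print(line_sampling_probability)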