diff --git a/pipeline/common/datasets.py b/pipeline/common/datasets.py
new file mode 100644
index 000000000..750c73549
--- /dev/null
+++ b/pipeline/common/datasets.py
@@ -0,0 +1,180 @@
+import os
+import tempfile
+from collections import deque
+from io import TextIOWrapper
+from random import Random
+from typing import Iterator, Optional
+
+
+def shuffle_with_max_lines(
+    line_stream: Iterator[str],
+    seed: str,
+    max_lines: int,
+    max_words_in_sentence: int,
+    total_byte_size: int,
+) -> Iterator[str]:
+    """
+    Shuffle a line stream, but only retain up to a maximum number of lines in memory.
+    Note that the final ordering is determined by the seed and the contents of the file. So
+    running this multiple times on the same dataset will return the same result, but running
+    it with the same seed and different content will create a different ordering.
+
+    Only run for monolingual data or where the parallel sentences are separated by a delimiter.
+
+    The distribution should be even unless the initial content is not representative of the
+    general size of the sentences; in this case the distribution will be slightly biased. See
+    the test cases for more in-depth examples.
+    """
+    lines = deque()
+
+    random = Random(seed)  # Make this deterministic based on the dataset key.
+
+    total_bytes = 0
+
+    # Fill up the lines until the max, and measure the total bytes.
+    for line in line_stream:
+        # Encoding returns the underlying byte representation, which is then measured.
+        total_bytes = total_bytes + len(line.encode("utf-8"))
+
+        if len(line.split()) > max_words_in_sentence:
+            # TODO(CJK) - Issue #424
+            # This sentence is too long.
+            continue
+
+        lines.append(line)
+
+        if len(lines) == max_lines:
+            break
+
+    random.shuffle(lines)
+
+    # Consume the rest of the line stream, but sample based on the probability that adding
+    # something to the collection will be representative.
+
+    i = 0
+    for line in line_stream:
+        i = i + 1
+        # Continuously adjust this estimation in case the first sampled data is not representative.
+        total_bytes = total_bytes + len(line.encode("utf-8"))
+        average_bytes_per_line = total_bytes / (max_lines + i)
+        estimated_lines = total_byte_size / average_bytes_per_line
+        line_sampling_probability = max_lines / estimated_lines
+
+        if random.random() < line_sampling_probability:
+            # Shift the deque so the oldest line is shifted out, and this new sample is shifted in.
+            lines.popleft()
+            lines.append(line)
+
+    # Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
+    # set of shuffled lines.
+    random.shuffle(lines)
+
+    return lines
+
+
+def shuffle_in_temp_files(
+    line_stream: Iterator[str],
+    output: TextIOWrapper,
+    seed: str,
+    chunk_bytes: int,
+    bucket_bytes: int,
+    chunk_dir: Optional[str] = tempfile.gettempdir(),
+    keep_chunks: bool = False,
+):
+    """
+    Shuffle large datasets by storing chunks to the file system. The ordering is guaranteed to be
+    stable across two datasets as long as they are the same length. For instance, it could be used
+    to shuffle `dataset.en.zst` and `dataset.ca.zst` identically if the two are parallel sentences.
+
+    Take in a stream of lines (from a download, or stdin) and split it out into chunks:
+
+    tmpdir
+    ├── chunk.1
+    ├── chunk.2
+    ├── chunk.3
+    ├── chunk.4
+    ├── ...
+    └── chunk.100
+
+    After the entire dataset is written to chunks, pick random chunks and put them into a
+    bucket. Only one bucket is fully loaded into memory at a time, and the contents
+    of the bucket are shuffled in memory.
+
+    Bucket:
+    ┌───────────┐
+    │ chunk.85  │
+    │ chunk.3   │
+    │ chunk.52  │
+    │ chunk.30  │
+    │ chunk.12  │
+    │ chunk.18  │
+    └───────────┘
+
+    • shuffle bucket lines
+    • write to output
+
+    At most 1 bucket will be held in memory. At most the dataset + 1 bucket of file space will be
+    needed when running this algorithm.
+    """
+    random = Random(seed)
+
+    chunk_index = 0
+    chunk_file = open(os.path.join(chunk_dir, f"chunk.{chunk_index}"), "wt")
+
+    # Write out the chunks to disk.
+    bytes_written_to_chunk = 0
+    for line in line_stream:
+        line_bytes = len(line.encode("utf-8")) + 1
+
+        if bytes_written_to_chunk + line_bytes > chunk_bytes:
+            # Start a new chunk.
+            chunk_file.close()
+            chunk_index += 1
+            chunk_file = open(os.path.join(chunk_dir, f"chunk.{chunk_index}"), "wt")
+            bytes_written_to_chunk = 0
+
+        chunk_file.write(line + "\n")
+        bytes_written_to_chunk += line_bytes
+
+    chunk_file.close()
+
+    # Shuffle the chunk indexes.
+    chunk_count = chunk_index + 1
+
+    shuffled_chunk_indexes = [*range(chunk_count)]
+    random.shuffle(shuffled_chunk_indexes)
+
+    # Load a single bucket into memory, discarding the chunks.
+    bucket_count = 0
+    bytes_in_bucket = 0
+    bucket = []
+
+    for chunk_index in shuffled_chunk_indexes:
+        chunk_name = os.path.join(chunk_dir, f"chunk.{chunk_index}")
+
+        # Read in the chunk line by line.
+        with open(chunk_name, "r") as file:
+            for line in file.readlines():
+                bucket.append(line)
+                bytes_in_bucket += len(line.encode("utf-8"))
+
+                # If the bucket overflows, shuffle and write it out.
+                if bytes_in_bucket > bucket_bytes:
+                    random.shuffle(bucket)
+                    for shuffled_line in bucket:
+                        output.write(shuffled_line)
+
+                    # Create the new bucket.
+                    bucket = []
+                    bytes_in_bucket = 0
+                    bucket_count += 1
+
+        if not keep_chunks:
+            os.remove(chunk_name)
+
+    if len(bucket) > 0:
+        random.shuffle(bucket)
+        for shuffled_line in bucket:
+            output.write(shuffled_line)
+
+    print(f"Shuffled with {bucket_count} buckets.")
diff --git a/tests/test_common_datasets.py b/tests/test_common_datasets.py
new file mode 100644
index 000000000..7e23283f5
--- /dev/null
+++ b/tests/test_common_datasets.py
@@ -0,0 +1,156 @@
+import io
+from typing import Iterator
+
+import pytest
+from fixtures import DataDir
+
+from pipeline.common.datasets import shuffle_in_temp_files, shuffle_with_max_lines
+
+ITEMS = 100_000
+# ITEMS = 1_000  # Uncomment for faster local iteration.
+PERCENTAGE = 0.2
+MAX_LINES = int(ITEMS * PERCENTAGE)
+
+
+def get_total_byte_size(lines: list[str]) -> int:
+    total_byte_size = 0
+    for line in lines:
+        total_byte_size = total_byte_size + len(line.encode())
+    return total_byte_size
+
+
+def compute_distribution(lines: Iterator[str], items=ITEMS, max_lines=MAX_LINES) -> list[float]:
+    """
+    Computes a histogram (list of 10 buckets) with a proportion of 0.0 - 1.0 for each bucket.
+    """
+    histogram = [0] * 10
+    for line in lines:
+        # This assumes the content is a tab-separated list, with the first item being the
+        # initial sorted order in the list.
+        key = int(int(line.split("\t")[0]) * 10 / items)
+        histogram[key] = histogram[key] + (1 / max_lines)
+
+    # Round the floats to three decimal places.
+    return [round(value * 1000) / 1000 for value in histogram]
+
+
+# Test the distributions of the different types of datasets. This shuffler estimates the content
+# size as it iterates through the line stream.
+shuffle_params = [
+    (
+        # Each line has the same number of bytes as the next. This should create an even distribution.
+        # [
+        #     "000000000 000000000 000000000 ... 000000000",
+        #     "000000001 000000001 000000001 ... 000000001",
+        #     ...
+        # ]
+        "even-distribution",
+        [f"{line:09d}\t" * 10 for line in range(ITEMS)],
+        [0.102, 0.101, 0.099, 0.1, 0.1, 0.102, 0.099, 0.097, 0.1, 0.1],
+    ),
+    (
+        # The initial lines are low in byte count, and gradually increase. In this case there
+        # will be a bias to over-sample the initial items, but it will eventually even out as
+        # more bytes are read in and the average spreads out.
+        # [
+        #     "0 0 0 ... 0",
+        #     "1 1 1 ... 1",
+        #     ...
+        #     "99997 99997 99997 ... 99997",
+        #     "99998 99998 99998 ... 99998",
+        #     "99999 99999 99999 ... 99999",
+        # ]
+        "small-content-at-start",
+        [f"{line}\t" * 10 for line in range(ITEMS)],
+        [0.114, 0.116, 0.092, 0.095, 0.096, 0.099, 0.097, 0.095, 0.098, 0.098],
+        #  |      |      |
+        #  |      |      ^ Lower sample rate.
+        #  ^^^^^^^ Higher sampling rate.
+    ),
+    (
+        # [
+        #     "99999 99999 99999 ... 99999",
+        #     "99998 99998 99998 ... 99998",
+        #     "99997 99997 99997 ... 99997",
+        #     ...
+        #     "1 1 1 ... 1",
+        #     "0 0 0 ... 0",
+        # ]
+        "large-content-at-start",
+        [f"{line}\t" * 10 for line in range(ITEMS)][::-1],
+        [0.101, 0.102, 0.099, 0.102, 0.103, 0.102, 0.101, 0.102, 0.102, 0.086],
+        #                                          lower sample rate ^^^^^
+    ),
+]
+
+
+@pytest.mark.parametrize("params", shuffle_params, ids=[d[0] for d in shuffle_params])
+def test_shuffle_with_max_lines(params):
+    description, line_stream, expected_distribution = params
+    # [
+    #     "0000 0000 0000 ... 0000",
+    #     "0001 0001 0001 ... 0001",
+    #     "0002 0002 0002 ... 0002",
+    #     ...
+    # ]
+
+    output = shuffle_with_max_lines(
+        line_stream,
+        seed="test",
+        max_lines=MAX_LINES,
+        max_words_in_sentence=100,
+        total_byte_size=get_total_byte_size(line_stream),
+    )
+
+    assert compute_distribution(output) == expected_distribution, description
+
+
+def test_shuffle_in_temp_files():
+    # [
+    #     "0000 0000 0000 ... 0000",
+    #     "0001 0001 0001 ... 0001",
+    #     "0002 0002 0002 ... 0002",
+    #     ...
+    # ]
+    line_stream = [f"{line:09d}\t" * 10 for line in range(ITEMS)]
+
+    # The total byte size is ~10_000_000.
+    chunk_bytes = 100_000
+    bucket_bytes = 2_000_000
+    data_dir = DataDir("test_common_datasets")
+
+    with io.StringIO() as output:
+        shuffle_in_temp_files(
+            line_stream,
+            output=output,
+            seed="test",
+            chunk_bytes=chunk_bytes,
+            bucket_bytes=bucket_bytes,
+            chunk_dir=data_dir.path,
+            keep_chunks=True,
+        )
+
+        data_dir.print_tree()
+
+        output.seek(0)
+        text = output.read()
+        lines = [*text.splitlines()]
+        sample = lines[:MAX_LINES]
+
+        output.seek(0)
+        with open(data_dir.join("shuffle.txt"), "w") as file:
+            print(output.getvalue(), file=file)
+
+        assert len(lines) == ITEMS
+        assert compute_distribution(sample) == [
+            0.149,
+            0.258,
+            0.04,
+            0.1,
+            0.001,  # The distribution is not perfect with this strategy.
+            0.052,
+            0.05,
+            0.1,
+            0.101,
+            0.15,
+        ]
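
For reviewers, here is a minimal sketch of how the two helpers might be called from elsewhere in the pipeline. It is illustrative only and not part of the patch: the file names (corpus.en.txt, corpus.en.shuffled.txt), seed strings, and size parameters are assumptions, not values taken from the pipeline.

import os

from pipeline.common.datasets import shuffle_in_temp_files, shuffle_with_max_lines

mono_path = "corpus.en.txt"  # Hypothetical plain-text monolingual file, one sentence per line.

# Keep an in-memory random sample of at most 1M lines of the file.
with open(mono_path, "rt", encoding="utf-8") as infile:
    sample = shuffle_with_max_lines(
        infile,
        seed=mono_path,
        max_lines=1_000_000,
        max_words_in_sentence=100,
        # The on-disk size is only used to estimate the total line count.
        total_byte_size=os.path.getsize(mono_path),
    )

# Or shuffle the entire file on disk while holding only one bucket in memory.
with open(mono_path, "rt", encoding="utf-8") as infile:
    with open("corpus.en.shuffled.txt", "wt", encoding="utf-8") as outfile:
        shuffle_in_temp_files(
            # shuffle_in_temp_files appends its own newlines, so strip them here.
            (line.rstrip("\n") for line in infile),
            output=outfile,
            seed=mono_path,
            chunk_bytes=100_000_000,  # ~100 MB of text per chunk file.
            bucket_bytes=2_000_000_000,  # ~2 GB of lines shuffled in memory at a time.
        )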