Skip to content

Commit

Permalink
Fix mono downloader (#576)
Browse files Browse the repository at this point in the history
* Make shuffling efficient

* Optimize memory usage, don't run redundant shuffling

* Fix tests
  • Loading branch information
eu9ene authored May 9, 2024
1 parent 04e9e9c commit 90ba11d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
26 changes: 18 additions & 8 deletions pipeline/common/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import deque
from io import TextIOWrapper
from random import Random
from typing import Iterator, Optional
from typing import Iterable, Iterator, Optional


class Dataset:
Expand Down Expand Up @@ -56,7 +56,7 @@ def shuffle_with_max_lines(
max_lines: int,
max_words_in_sentence,
total_byte_size: int,
) -> Iterator[str]:
) -> Iterable[str]:
"""
Shuffle a line stream, but only retain up to a maximum number of lines in memory.
Note that the final ordering is determined by the seed and the contents of the file. So
Expand Down Expand Up @@ -90,14 +90,21 @@ def shuffle_with_max_lines(
if len(lines) == max_lines:
break

random.shuffle(lines)
# random.shuffle needs random access by index. A deque appends and pops at
# either end in O(1), but indexing into its middle is O(n), which is too slow
# for shuffling a large collection, so shuffle a list copy instead.
lines_list = list(lines)
lines = None
random.shuffle(lines_list)

# Consume the rest of the line stream, but sample based on the probability that adding
# something to the collection will be representative.

i = 0
for line in line_stream:
i = i + 1
if lines is None:
lines = deque(lines_list)
lines_list = None
# Continuously adjust this estimate in case the first sampled data is not representative.
total_bytes = total_bytes + len(line.encode("utf-8"))
average_bytes_per_line = total_bytes / (max_lines + i)
Expand All @@ -109,11 +116,14 @@ def shuffle_with_max_lines(
lines.popleft()
lines.append(line)

# Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
# set of shuffled lines.
random.shuffle(lines)
if i != 0:
# Do a final shuffle to ensure that the newly sampled lines are shuffled with the original
# set of shuffled lines.
lines_list = list(lines)
del lines
random.shuffle(lines_list)

return lines
return lines_list


def shuffle_in_temp_files(
Expand Down
8 changes: 4 additions & 4 deletions tests/test_data_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ def make_url_dataset(lang: str):


mono_params = [
("news-crawl", "en", "news_2021", [0, 1, 4, 6, 3, 7, 5, 2]),
("news-crawl", "ru", "news_2021", [0, 1, 4, 6, 3, 7, 5, 2]),
("url", "en", make_url_dataset("en"), [2, 1, 5, 4, 0, 7, 6, 3]),
("url", "ru", make_url_dataset("ru"), [5, 4, 2, 0, 7, 1, 3, 6]),
("news-crawl", "en", "news_2021", [2, 5, 3, 7, 0, 6, 4, 1]),
("news-crawl", "ru", "news_2021", [2, 5, 3, 7, 0, 6, 4, 1]),
("url", "en", make_url_dataset("en"), [3, 4, 5, 0, 1, 6, 2, 7]),
("url", "ru", make_url_dataset("ru"), [5, 6, 2, 4, 7, 1, 3, 0]),
]


Expand Down

0 comments on commit 90ba11d

Please sign in to comment.