Text Memmap Parsing Improvements (NVIDIA#5265)
* Fixed a text-memmap indexing issue when the boundary (new-line) is missing from the end of the file.

Signed-off-by: Micha Livne <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed style.

Signed-off-by: Micha Livne <[email protected]>

* Added support for partial sample loading and alternative decoding.

Signed-off-by: Micha Livne <[email protected]>

* Fixed syntax issues.

Signed-off-by: Micha Livne <[email protected]>

* Minor change.

Signed-off-by: Micha Livne <[email protected]>

* Extended the flexibility of index mapping.

Signed-off-by: Micha Livne <[email protected]>

* Added validation of the dtype returned by the indexing function.

Signed-off-by: Micha Livne <[email protected]>

Signed-off-by: Micha Livne <[email protected]>
Co-authored-by: Micha Livne <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Sandeep Subramanian <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
5 people authored and Hainan Xu committed Nov 29, 2022
1 parent 7363277 commit 40f3ba6
Showing 1 changed file with 74 additions and 27 deletions.
nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -30,15 +30,55 @@
__idx_suffix__ = 'idx' # index file suffix


def _build_index_from_memdata(fn, newline_int):
"""
Build index of delimiter positions between samples in memmap.
Can be provided externally.
Returns a 1D array of ints.
"""
# use memmap to read file
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
# find newline positions
midx = np.where(mdata == newline_int)[0]
midx_dtype = midx.dtype
# make sure to account for all data
midx = midx.tolist()
# add last item in case there is no new-line at the end of the file
if (len(midx) == 0) or (midx[-1] + 1 != len(mdata)):
midx = midx + [len(mdata) + 1]

# remove empty lines from end of file
while len(midx) > 1 and (midx[-1] - midx[-2]) < 2:
midx.pop(-1)
midx = np.asarray(midx, dtype=midx_dtype)

# free memmap
mdata._mmap.close()
del mdata

return midx

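# --- Illustrative sketch, not part of this commit: since the index builder
# "can be provided externally", any picklable, module-level callable with the
# signature build_index_fn(fn, newline_int) -> 1D integer array can be passed
# to TextMemMapDataset below. The function name and chunk size here are
# invented for the example; it trades the memmap scan for a buffered byte scan.
def _build_index_from_pyscan(fn, newline_int):
    positions = []
    offset = 0
    delim = bytes([newline_int])
    with open(fn, "rb") as f:
        while True:
            chunk = f.read(1 << 20)  # read in 1 MiB chunks
            if not chunk:
                break
            start = 0
            while True:
                hit = chunk.find(delim, start)
                if hit == -1:
                    break
                positions.append(offset + hit)
                start = hit + 1
            offset += len(chunk)
    # mirror the default builder: add a virtual delimiter when the file
    # does not end with a newline
    if (len(positions) == 0) or (positions[-1] + 1 != offset):
        positions.append(offset + 1)
    return np.asarray(positions, dtype=np.int64)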

class TextMemMapDataset(Dataset):
"""
Allow per-line lazy access to multiple text files using numpy memmap.
"""

# FIXME: header_lines=0 by default
def __init__(
self, dataset_paths, newline_int=10, header_lines=0, workers=None, tokenizer=None, sort_dataset_paths=True,
self,
dataset_paths,
newline_int=10,
header_lines=0,
workers=None,
tokenizer=None,
sort_dataset_paths=True,
build_index_fn=_build_index_from_memdata,
):
"""
build_index_fn - a callable build_index_fn(fn, newline_int) -> midx [np.array] that returns the indices of newlines in the file fn;
must be pickleable (to be used in multiprocessing.Pool.map)
"""
super().__init__()
self.mdata_midx_list = []

@@ -65,7 +105,7 @@ def __init__(
is_ditributed = torch.distributed.is_available() and torch.distributed.is_initialized()

if not is_ditributed or (is_ditributed and torch.distributed.get_rank() == 0):
build_index_files(dataset_paths, newline_int, workers=self._worker)
build_index_files(dataset_paths, newline_int, workers=self._worker, build_index_fn=build_index_fn)

if is_ditributed:
torch.distributed.barrier()
@@ -83,20 +123,23 @@ def __init__(
self.midx_bins = midx_bins
self.mdata_midx_list = mdata_midx_list

# figure out size of the dataset
self._size = self.midx_bins[-1]

def __del__(self):
if self.mdata_midx_list:
for mdata, midx in self.mdata_midx_list:
mdata._mmap.close()

def __len__(self):
return self.midx_bins[-1]
return self._size

def __getitem__(self, idx):
"""
Return a string from binary memmap
"""
if (idx >= self.midx_bins[-1]) or (idx < 0):
raise IndexError(f"Index {idx} is out of dataset range with {self.midx_bins[-1]} samples")
if (idx >= len(self)) or (idx < 0):
raise IndexError(f"Index {idx} is out of dataset range with {len(self)} samples")

# Identify the file containing the record
file_id = np.digitize(idx, self.midx_bins, right=False)
@@ -111,13 +154,21 @@ def __getitem__(self, idx):
i = midx[file_idx - 1] + 1 # ignore newline
j = midx[file_idx]

text = mdata[i:j].tobytes().decode("utf-8")
# fetch sample from memmap

sample = self._fetch_sample_from_memmap(mdata, i, j)
# parse raw text (e.g., tokenize)
data = self._build_data_from_text(text)
data = self._build_data_from_text(sample)

return data

def _fetch_sample_from_memmap(self, mdata, i, j):
"""Fetchs the text sample. Can be overriden by child-classes to support loading of partial samples and alternative decode methods"""
# load text sample by slicing memmap data[i:j]
text = mdata[i:j].tobytes().decode("utf-8")

return text

def _build_data_from_text(self, text):
"""Allows child-classes to modify the parsing of raw text, prior to tokenization"""
# tokenize text if tokenizer is given
@@ -207,39 +258,35 @@ def _build_data_from_text(self, text):
return super()._build_data_from_text(text)

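# --- Illustrative sketch, not part of this commit: a hypothetical subclass
# that uses the new _fetch_sample_from_memmap hook for partial sample loading
# and an alternative decode policy. The class name and the max_bytes parameter
# are invented for this example.
class TruncatedTextMemMapDataset(TextMemMapDataset):
    def __init__(self, dataset_paths, max_bytes=512, **kwargs):
        super().__init__(dataset_paths, **kwargs)
        self._max_bytes = max_bytes

    def _fetch_sample_from_memmap(self, mdata, i, j):
        # load at most max_bytes of the line, and skip undecodable bytes
        # instead of raising UnicodeDecodeError
        j = min(j, i + self._max_bytes)
        return mdata[i:j].tobytes().decode("utf-8", errors="ignore")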

def _build_memmap_index_files(newline_int, fn):
def _build_memmap_index_files(newline_int, build_index_fn, fn):
"""Helper function to build an index file"""
idx_fn = f"{fn}.{__idx_suffix__}"

# create data map
mdata = np.memmap(fn, dtype=np.uint8, mode='r')
if os.path.exists(idx_fn + ".npy"):
return False
else:
logging.info(f"Building idx file = {idx_fn}.npy")
midx = np.where(mdata == newline_int)[0]
midx_dtype = midx.dtype
# add last item in case there is no new-line
if (len(midx) == 0) or (midx[-1] + 1 != len(mdata)):
midx = np.asarray(midx.tolist() + [len(midx) + 1], dtype=midx_dtype)

# remove empty lines from end of file
midx = midx.tolist()
while len(midx) > 1 and (midx[-1] - midx[-2]) < 2:
midx.pop(-1)
midx = np.asarray(midx, dtype=midx_dtype)

logging.info(f"Building indexing for fn = {fn}")
# find all newline positions
midx = build_index_fn(fn, newline_int)
# validate midx
midx = np.asarray(midx)
if not np.issubdtype(midx.dtype, np.integer):
raise TypeError(f"midx must be an integer array, but got type = {midx.dtype}")

# create the metadata file
data = dict(newline_int=newline_int, version=__idx_version__)

# save index as numpy array to enable memmap reading
logging.info(f"Saving idx file = {idx_fn}.npy")
np.save(idx_fn + ".npy", midx, allow_pickle=True)
logging.info(f"Saving metadata file = {idx_fn}.info")
pickle.dump(data, open(idx_fn + ".info", "wb"))
mdata._mmap.close()
del mdata

return True


def build_index_files(dataset_paths, newline_int, workers=None):
def build_index_files(dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata):
"""Auxiliary method to build multiple index files"""
if len(dataset_paths) < 1:
raise ValueError("files_list must contain at leat one file name")
@@ -251,7 +298,7 @@ def build_index_files(dataset_paths, newline_int, workers=None):
# load all files into memmap
start_time = time.time()
with mp.Pool(workers) as p:
build_status = p.map(partial(_build_memmap_index_files, newline_int), dataset_paths)
build_status = p.map(partial(_build_memmap_index_files, newline_int, build_index_fn), dataset_paths)

logging.info(
f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'

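For reference, a minimal usage sketch of the API as extended by this commit (file names are illustrative, not taken from the commit): build_index_files pre-builds the ".idx.npy"/".idx.info" files, and TextMemMapDataset then serves per-line lazy access. A custom build_index_fn must be a picklable, module-level function, since it is dispatched through multiprocessing.Pool.map.

from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import (
    TextMemMapDataset,
    build_index_files,
)

# hypothetical newline-delimited corpus shards
paths = ["train_shard_0.txt", "train_shard_1.txt"]

# pre-build index files with 2 workers; build_index_fn defaults to
# _build_index_from_memdata
build_index_files(paths, newline_int=10, workers=2)

# per-line lazy access; with tokenizer=None, samples are returned as raw strings
ds = TextMemMapDataset(paths, newline_int=10, header_lines=0, tokenizer=None)
print(len(ds), ds[0])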