diff --git a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
index 67181106bd04..9a7dd12d1d7d 100644
--- a/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
+++ b/nemo/collections/nlp/data/language_modeling/text_memmap_dataset.py
@@ -30,15 +30,55 @@
 __idx_suffix__ = 'idx'  # index file suffix
 
 
+def _build_index_from_memdata(fn, newline_int):
+    """
+    Build index of delimiter positions between samples in memmap.
+    Can be provided externally.
+
+    Returns a 1D array of ints.
+    """
+    # use memmap to read file
+    mdata = np.memmap(fn, dtype=np.uint8, mode='r')
+    # find newline positions
+    midx = np.where(mdata == newline_int)[0]
+    midx_dtype = midx.dtype
+    # make sure to account for all data
+    midx = midx.tolist()
+    # add last item in case there is no new-line at the end of the file
+    if (len(midx) == 0) or (midx[-1] + 1 != len(mdata)):
+        midx = midx + [len(mdata) + 1]
+
+    # remove empty lines from end of file
+    while len(midx) > 1 and (midx[-1] - midx[-2]) < 2:
+        midx.pop(-1)
+    midx = np.asarray(midx, dtype=midx_dtype)
+
+    # free memmap
+    mdata._mmap.close()
+    del mdata
+
+    return midx
+
+
 class TextMemMapDataset(Dataset):
     """
     Allow per-line lazy access to multiple text files using numpy memmap.
     """
 
-    # FIXME: header_lines=0 by default
     def __init__(
-        self, dataset_paths, newline_int=10, header_lines=0, workers=None, tokenizer=None, sort_dataset_paths=True,
+        self,
+        dataset_paths,
+        newline_int=10,
+        header_lines=0,
+        workers=None,
+        tokenizer=None,
+        sort_dataset_paths=True,
+        build_index_fn=_build_index_from_memdata,
     ):
+        """
+        build_index_fn - a callable build_index_fn(fn, newline_int) -> midx [np.array] that returns the index of newlines in a file fn;
+            must be pickleable (to be used in multiprocessing.Pool.map)
+        """
         super().__init__()
 
         self.mdata_midx_list = []
@@ -65,7 +105,7 @@ def __init__(
         is_ditributed = torch.distributed.is_available() and torch.distributed.is_initialized()
 
         if not is_ditributed or (is_ditributed and torch.distributed.get_rank() == 0):
-            build_index_files(dataset_paths, newline_int, workers=self._worker)
+            build_index_files(dataset_paths, newline_int, workers=self._worker, build_index_fn=build_index_fn)
 
         if is_ditributed:
             torch.distributed.barrier()
@@ -83,20 +123,23 @@ def __init__(
         self.midx_bins = midx_bins
         self.mdata_midx_list = mdata_midx_list
 
+        # figure out size of the dataset
+        self._size = self.midx_bins[-1]
+
     def __del__(self):
         if self.mdata_midx_list:
            for mdata, midx in self.mdata_midx_list:
                mdata._mmap.close()
 
     def __len__(self):
-        return self.midx_bins[-1]
+        return self._size
 
     def __getitem__(self, idx):
         """
         Return a string from binary memmap
         """
-        if (idx >= self.midx_bins[-1]) or (idx < 0):
-            raise IndexError(f"Index {idx} if out of dataset range with {self.midx_bins[-1]} samples")
+        if (idx >= len(self)) or (idx < 0):
+            raise IndexError(f"Index {idx} is out of dataset range with {len(self)} samples")
 
         # Identify the file containing the record
         file_id = np.digitize(idx, self.midx_bins, right=False)
@@ -111,13 +154,21 @@ def __getitem__(self, idx):
             i = midx[file_idx - 1] + 1  # ignore newline
             j = midx[file_idx]
 
-        text = mdata[i:j].tobytes().decode("utf-8")
+        # fetch sample from memmap
+        sample = self._fetch_sample_from_memmap(mdata, i, j)
         # parse raw text (e.g., tokenize)
-        data = self._build_data_from_text(text)
+        data = self._build_data_from_text(sample)
 
         return data
 
+    def _fetch_sample_from_memmap(self, mdata, i, j):
+        """Fetches the text sample. Can be overridden by child-classes to support loading of partial samples and alternative decode methods."""
+        # load text sample by slicing memmap data[i:j]
+        text = mdata[i:j].tobytes().decode("utf-8")
+
+        return text
+
     def _build_data_from_text(self, text):
         """Allows child-classes to modify the parsing of raw text, prior to tokenization"""
         # tokenize text if tokenizer is given
@@ -207,39 +258,35 @@ def _build_data_from_text(self, text):
         return super()._build_data_from_text(text)
 
 
-def _build_memmap_index_files(newline_int, fn):
+def _build_memmap_index_files(newline_int, build_index_fn, fn):
     """Helper function to build an index file"""
     idx_fn = f"{fn}.{__idx_suffix__}"
 
     # create data map
-    mdata = np.memmap(fn, dtype=np.uint8, mode='r')
     if os.path.exists(idx_fn + ".npy"):
         return False
     else:
-        logging.info(f"Building idx file = {idx_fn}.npy")
-        midx = np.where(mdata == newline_int)[0]
-        midx_dtype = midx.dtype
-        # add last item in case there is no new-line
-        if (len(midx) == 0) or (midx[-1] + 1 != len(mdata)):
-            midx = np.asarray(midx.tolist() + [len(midx) + 1], dtype=midx_dtype)
-
-        # remove empty lines from end of file
-        midx = midx.tolist()
-        while len(midx) > 1 and (midx[-1] - midx[-2]) < 2:
-            midx.pop(-1)
-        midx = np.asarray(midx, dtype=midx_dtype)
-
+        logging.info(f"Building indexing for fn = {fn}")
+        # find all newline positions
+        midx = build_index_fn(fn, newline_int)
+        # validate midx
+        midx = np.asarray(midx)
+        if not np.issubdtype(midx.dtype, np.integer):
+            raise TypeError(f"midx must be an integer array, but got type = {midx.dtype}")
+
+        # create metadata file
         data = dict(newline_int=newline_int, version=__idx_version__)
 
+        # save index as numpy array to enable memmap reading
+        logging.info(f"Saving idx file = {idx_fn}.npy")
         np.save(idx_fn + ".npy", midx, allow_pickle=True)
+        logging.info(f"Saving metadata file = {idx_fn}.info")
         pickle.dump(data, open(idx_fn + ".info", "wb"))
 
-        mdata._mmap.close()
-        del mdata
         return True
 
 
-def build_index_files(dataset_paths, newline_int, workers=None):
+def build_index_files(dataset_paths, newline_int, workers=None, build_index_fn=_build_index_from_memdata):
     """Auxiliary method to build multiple index files"""
     if len(dataset_paths) < 1:
         raise ValueError("files_list must contain at leat one file name")
@@ -251,7 +298,7 @@ def build_index_files(dataset_paths, newline_int, workers=None):
     # load all files into memmap
     start_time = time.time()
     with mp.Pool(workers) as p:
-        build_status = p.map(partial(_build_memmap_index_files, newline_int), dataset_paths)
+        build_status = p.map(partial(_build_memmap_index_files, newline_int, build_index_fn), dataset_paths)
 
     logging.info(
         f'Time building {sum(build_status)} / {len(build_status)} mem-mapped files: {datetime.timedelta(seconds=time.time() - start_time)}'
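For reference, a minimal usage sketch of the build_index_fn hook added by this patch (not part of the diff): the dataset path below is a hypothetical placeholder, and the custom builder is deliberately simplified, assuming the file ends with a newline (the default _build_index_from_memdata additionally handles a missing trailing newline and empty trailing lines). Per the constructor docstring, the callable must be a pickleable, module-level function with signature build_index_fn(fn, newline_int) returning an integer array of newline positions.

    import numpy as np

    from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import TextMemMapDataset


    def my_build_index(fn, newline_int):
        # module-level (pickleable) callable: return integer positions of newline_int bytes in fn
        mdata = np.memmap(fn, dtype=np.uint8, mode='r')
        midx = np.where(mdata == newline_int)[0]
        mdata._mmap.close()
        del mdata
        return midx


    dataset = TextMemMapDataset(
        dataset_paths=['/data/corpus.txt'],  # hypothetical path
        newline_int=10,  # ASCII '\n'
        build_index_fn=my_build_index,
    )
    print(len(dataset), dataset[0])

Because the builder runs inside multiprocessing.Pool.map, lambdas or locally defined closures will fail to pickle; only importable module-level functions work here.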