Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for initializing lhotse shar dataloader via field: list[path] mapping #11460

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/asr/datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,11 @@ Some other Lhotse related arguments we support:
Specifying this option will result in ``manifest_filepaths`` and ``tarred_audio_filepaths`` being ignored.
* ``shar_path``
Can be provided to read data from a Lhotse Shar manifest instead of a NeMo manifest.
Specifying this option will result in ``manifest_filepaths`` and ``tarred_audio_filepaths`` being ignored.
This argument can be a string (single Shar directory), a list of strings (Shar directories),
or a list of 2-item lists, where the first item is a Shar directory path, and the other is a sampling weight.
Specifying this option will result in ``manifest_filepaths`` and ``tarred_audio_filepaths`` being ignored.
Alternatively, the user can provide a dict that maps each Lhotse Shar field name to the shard paths holding the data for that field.
Such a dict must contain the ``cuts`` key pointing to the cut manifests.
For details about Lhotse Shar format, see: |tutorial_shar|
* ``bucket_duration_bins``
Duration bins are a list of float values (seconds) that when provided, will skip the initial bucket bin estimation
and save some time. It has to have a length of ``num_buckets - 1``. An optimal value can be obtained by running CLI:
Expand Down
27 changes: 24 additions & 3 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@
from functools import partial
from itertools import repeat
from pathlib import Path
from typing import Sequence, Tuple, Union
from typing import Mapping, Sequence, Tuple, Union

import omegaconf
from lhotse import CutSet, Features, Recording
from lhotse.array import Array, TemporalArray
from lhotse.cut import Cut, MixedCut, PaddingCut
from omegaconf import DictConfig, ListConfig, OmegaConf

from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator, LazyNeMoTarredIterator
from nemo.collections.common.data.lhotse.nemo_adapters import (
LazyNeMoIterator,
LazyNeMoTarredIterator,
expand_sharded_filepaths,
)
from nemo.collections.common.data.lhotse.text_adapters import LhotseTextAdapter, LhotseTextPairAdapter
from nemo.collections.common.parts.preprocessing.manifest import get_full_path

Expand Down Expand Up @@ -281,7 +285,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet:
)
if not metadata_only and not force_finite:
cuts = cuts.repeat()
else:
elif isinstance(config.shar_path, Sequence):
# Multiple datasets in Lhotse Shar format: we will dynamically multiplex them
# with probability approximately proportional to their size
logging.info(
Expand Down Expand Up @@ -318,6 +322,23 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet:
seed=config.shard_seed,
force_finite=force_finite,
)
elif isinstance(config.shar_path, Mapping):
fields = {k: expand_sharded_filepaths(v) for k, v in config.shar_path.items()}
assert "cuts" in config.shar_path.keys(), (
f"Invalid value for key 'shar_path': a dict was provided, but didn't specify key 'cuts' pointing "
f"to the manifests. We got the following: {config.shar_path=}"
)
if metadata_only:
fields = {"cuts": fields["cuts"]}
cuts = CutSet.from_shar(fields=fields, shuffle_shards=True, seed=shard_seed)
if not metadata_only and not force_finite:
cuts = cuts.repeat()
else:
raise RuntimeError(
f"Unexpected value for key 'shar_path'. We support string, list of strings, "
f"list of tuples[string,float], and dict[string,list[string]], "
f"but got: {type(config.shar_path)=} {config.shar_path=}"
)
else:
# Regular Lhotse manifest points to individual audio files (like native NeMo manifest).
path = config.cuts_path
Expand Down
59 changes: 59 additions & 0 deletions tests/collections/common/test_lhotse_dataloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from lhotse.audio import AudioLoadingError
from lhotse.cut import Cut, MixedCut
from lhotse.cut.text import TextPairExample
from lhotse.shar import JsonlShardWriter
from lhotse.testing.dummies import dummy_recording
from omegaconf import OmegaConf

Expand Down Expand Up @@ -413,6 +414,64 @@ def test_dataloader_from_lhotse_shar_cuts(cutset_shar_path: Path):
assert b["audio"].shape[0] == b["audio_lens"].shape[0] == 3


def test_dataloader_from_lhotse_shar_cuts_via_fields(cutset_shar_path: Path):
    """Check that ``shar_path`` given as a field -> shard-path mapping yields loadable cuts.

    The config points the "cuts" and "recording" fields at sharded path patterns
    located under ``cutset_shar_path``. The test pulls the first batch from the
    resulting dataloader and verifies its size and that audio can be loaded
    from its first element.
    """
    # Per-field shard locations, expressed with the _OP_/_CL_ range pattern.
    shar_fields = {
        "cuts": f"{cutset_shar_path}/cuts._OP_000000..000001_CL_.jsonl.gz",
        "recording": f"{cutset_shar_path}/recording._OP_000000..000001_CL_.tar",
    }
    settings = {
        "shar_path": shar_fields,
        "sample_rate": 16000,
        "num_workers": 0,
        "shuffle": False,
        "batch_size": 4,
        "seed": 0,
        "shard_seed": 0,
    }
    config = OmegaConf.create(settings)

    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

    first_batch = next(iter(dl))
    assert len(first_batch) == 4
    # The recording field must have been attached for audio loading to work.
    loaded = first_batch[0].load_audio()
    assert isinstance(loaded, np.ndarray)


def test_dataloader_from_lhotse_shar_cuts_add_new_field(tmp_path_factory, cutset_shar_path: Path):
    """Check that an extra, user-written Shar field is attached to the loaded cuts.

    A new "wer" field is written as sharded JSONL manifests into a fresh
    temporary directory and listed in ``shar_path`` next to the existing "cuts"
    and "recording" fields. The test then verifies that the first cut of the
    first batch exposes the value as ``cut.wer``.
    """
    # The "wer" field is created here and is meant to be attached dynamically to the Lhotse Shar cuts.
    # Every "wer" shard is a jsonl manifest that has to line up with the sharded "cuts" manifest.
    # Each entry carries a "cut_id" used for a runtime check that the user supplied the right paths.
    # The value ends up available on each cut as `cut.wer` / cut.custom["wer"].
    wer_dir = tmp_path_factory.mktemp("wer_dir")
    wer_pattern = f"{wer_dir}/wer.%06d.jsonl.gz"
    with JsonlShardWriter(wer_pattern, shard_size=5) as shard_writer:
        # Ten entries are written with shard_size=5.
        for idx in range(10):
            entry = {"cut_id": "dummy-mono-cut-%04d" % idx, "wer": 0.5}
            shard_writer.write(entry)

    shar_fields = {
        "cuts": f"{cutset_shar_path}/cuts._OP_000000..000001_CL_.jsonl.gz",
        "recording": f"{cutset_shar_path}/recording._OP_000000..000001_CL_.tar",
        "wer": f"{wer_dir}/wer._OP_000000..000001_CL_.jsonl.gz",
    }
    settings = {
        "shar_path": shar_fields,
        "sample_rate": 16000,
        "num_workers": 0,
        "shuffle": False,
        "batch_size": 4,
        "seed": 0,
        "shard_seed": 0,
    }
    config = OmegaConf.create(settings)

    dl = get_lhotse_dataloader_from_config(config=config, global_rank=0, world_size=1, dataset=Identity())

    first_batch = next(iter(dl))
    assert len(first_batch) == 4
    # The custom field written above should be readable straight off the cut.
    assert first_batch[0].wer == 0.5


def test_dataloader_from_nemo_manifest(nemo_manifest_path: Path):
config = OmegaConf.create(
{
Expand Down
Loading