Skip to content

Commit

Permalink
Support configurable extra fields for LazyNeMoTarredIterator (NVIDIA#…
Browse files Browse the repository at this point in the history
…9548)

* Support configurable extra fields for LazyNeMoTarredIterator

Signed-off-by: Piotr Żelasko <[email protected]>

* Add tests and fixes

Signed-off-by: Piotr Żelasko <[email protected]>

* Documentation, more tests

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
  • Loading branch information
pzelasko authored and XuesongYang committed Jan 18, 2025
1 parent f6fa4b8 commit 909d23d
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 6 deletions.
1 change: 1 addition & 0 deletions nemo/collections/common/data/lhotse/cutset.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet:
"lang_field": config.lang_field,
"shuffle_shards": config.shuffle,
"shard_seed": config.shard_seed,
"extra_fields": config.get("extra_fields", None),
}
# The option below is to allow a special case of NeMo manifest iteration as Lhotse CutSet
# without performing any I/O. NeMo manifests typically don't have sampling_rate information required by Lhotse,
Expand Down
160 changes: 154 additions & 6 deletions nemo/collections/common/data/lhotse/nemo_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import random
import re
import tarfile
from collections.abc import Mapping, Sequence
from io import BytesIO
from pathlib import Path
from typing import Generator, Iterable, List, Literal

import lhotse.serialization
import soundfile
from cytoolz import groupby
from lhotse import AudioSource, Recording, SupervisionSegment
Expand All @@ -28,6 +30,7 @@
from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator
from lhotse.serialization import open_best
from lhotse.utils import compute_num_samples

from nemo.collections.common.parts.preprocessing.manifest import get_full_path


Expand Down Expand Up @@ -56,16 +59,33 @@ class LazyNeMoIterator:
Example::
>>> cuts = lhotse.CutSet(LazyNeMoIterator("nemo_manifests/train.json"))
We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument.
In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line
under ``cut.question`` using the field type ``text_iter``::
>>> cuts = lhotse.CutSet(LazyNeMoIterator(
... "nemo_manifests/train.json",
... extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}],
... ))
We also support random sampling of lines with field type ``text_sample``::
>>> cuts = lhotse.CutSet(LazyNeMoIterator(
... "nemo_manifests/train.json",
... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}],
... ))
"""

def __init__(
self,
path: str | Path,
path: str | Path | list[str],
text_field: str = "text",
lang_field: str = "lang",
metadata_only: bool = False,
shuffle_shards: bool = False,
shard_seed: int | Literal["randomized", "trng"] = "trng",
extra_fields: list[dict[str, str]] | None = None,
) -> None:
self.path = path
self.shuffle_shards = shuffle_shards
Expand All @@ -80,8 +100,13 @@ def __init__(
self.text_field = text_field
self.lang_field = lang_field
self.metadata_only = metadata_only
self.extra_fields = extra_fields
validate_extra_fields(self.extra_fields)

def __iter__(self) -> Generator[Cut, None, None]:
seed = resolve_seed(self.shard_seed)
# Propagate the random seed
extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()]
for data in self.source:
audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path))
duration = data.pop("duration")
Expand All @@ -104,6 +129,8 @@ def __iter__(self) -> Generator[Cut, None, None]:
)
)
cut.custom = data
for extra_field in extra_fields:
extra_field.attach_to(cut)
yield cut

def __len__(self) -> int:
Expand Down Expand Up @@ -180,20 +207,39 @@ class LazyNeMoTarredIterator:
Example of CutSet with inter-shard shuffling enabled::
>>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
... manifest_path="nemo_manifests/train.json",
... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
... tar_paths=["nemo_manifests/audio_0.tar", ...],
... shuffle_shards=True,
... ))
We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument.
In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line
under ``cut.question`` using the field type ``text_iter``::
>>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
... tar_paths=["nemo_manifests/audio_0.tar", ...],
... extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}],
... ))
We also support random sampling of lines with field type ``text_sample``::
>>> cuts = lhotse.CutSet(LazyNeMoTarredIterator(
... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...],
... tar_paths=["nemo_manifests/audio_0.tar", ...],
... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}],
... ))
"""

def __init__(
self,
manifest_path: str | Path,
manifest_path: str | Path | list[str],
tar_paths: str | list,
shuffle_shards: bool = False,
shard_seed: int | Literal["trng", "randomized"] = "trng",
text_field: str = "text",
lang_field: str = "lang",
extra_fields: list[dict[str, str]] | None = None,
) -> None:
self.shard_id_to_manifest: dict[int, Iterable[dict]]
self.paths = expand_sharded_filepaths(manifest_path)
Expand Down Expand Up @@ -235,6 +281,7 @@ def __init__(
self.shard_seed = shard_seed
self.text_field = text_field
self.lang_field = lang_field
self.extra_fields = extra_fields
self._validate()

def to_shards(self) -> List["LazyNeMoTarredIterator"]:
Expand Down Expand Up @@ -266,6 +313,7 @@ def _validate(self) -> None:
f"* JSON manifest(s) indicate(s) IDs: {sorted(shard_ids_manifest)}\n"
f"* Tar path(s) indicate(s) IDs: {sorted(shard_ids_tars)}\n"
)
validate_extra_fields(self.extra_fields)

@property
def shard_ids(self) -> List[int]:
Expand All @@ -274,10 +322,13 @@ def shard_ids(self) -> List[int]:
def __iter__(self) -> Generator[Cut, None, None]:
shard_ids = self.shard_ids

seed = resolve_seed(self.shard_seed)
if self.shuffle_shards:
seed = resolve_seed(self.shard_seed)
random.Random(seed).shuffle(shard_ids)

# Propagate the random seed
extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()]

for sid in shard_ids:
manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0]
shard_manifest = {data["audio_filepath"]: data for data in self.shard_id_to_manifest[sid]}
Expand Down Expand Up @@ -314,6 +365,8 @@ def __iter__(self) -> Generator[Cut, None, None]:
)
)
cut.custom = _to_custom_attr_dict(data)
for extra_field in extra_fields:
extra_field.attach_to(cut)
yield cut

def __len__(self) -> int:
Expand All @@ -323,11 +376,106 @@ def __add__(self, other):
return LazyIteratorChain(self, other)


def expand_sharded_filepaths(path: str | Path) -> list[str]:
class ExtraField:
TYPE = None
SUPPORTED_TYPES = {}

def attach_to(self, cut):
raise NotImplementedError()

def __init_subclass__(cls, **kwargs):
if cls.__name__ not in ExtraField.SUPPORTED_TYPES:
ExtraField.SUPPORTED_TYPES[cls.TYPE] = cls
super().__init_subclass__(**kwargs)

@staticmethod
def from_dict(data: dict) -> "ExtraField":
assert data["type"] in ExtraField.SUPPORTED_TYPES, f"Unknown transform type: {data['type']}"
return ExtraField.SUPPORTED_TYPES[data["type"]](**{k: v for k, v in data.items() if k != 'type'})

@classmethod
def is_supported(cls, field_type: str) -> bool:
return field_type in cls.SUPPORTED_TYPES

@classmethod
def supported_types(cls) -> list[str]:
return list(cls.SUPPORTED_TYPES)


class TextIteratorExtraField(ExtraField):
TYPE = "text_iter"

def __init__(self, name: str, path: str, seed=None):
self.name = name
self.path = path
self.iterator = None

def _maybe_init(self):
if self.iterator is None:
self.iterator = iter(map(str.strip, open_best(self.path)))

def attach_to(self, cut):
self._maybe_init()
try:
attached_value = next(self.iterator)
except StopIteration:
raise RuntimeError(f"Not enough lines in file {self.path} to attach to cuts under field {self.name}.")
setattr(cut, self.name, attached_value)
return cut


class TextSampleExtraField(ExtraField):
TYPE = "text_sample"

def __init__(self, name: str, path: str, seed: int | str):
self.name = name
self.path = path
self.seed = seed
self.population = None
self.rng = None

def _maybe_init(self):
if self.population is None:
self.population = list(map(str.strip, open_best(self.path)))
self.rng = random.Random(resolve_seed(self.seed))

def attach_to(self, cut):
self._maybe_init()
attached_value = self.rng.choice(self.population)
setattr(cut, self.name, attached_value)
return cut


def validate_extra_fields(extra_fields):
if extra_fields is None:
return
assert isinstance(
extra_fields, Sequence
), f"The argument provided to 'extra_fields' must be a list of dicts. We received {extra_fields=}"
for field in extra_fields:
assert isinstance(
field, Mapping
), f"Each item in 'extra_fields' must be a dict. We received {field=} in {extra_fields=}"
field_type = field.get("type")
assert ExtraField.is_supported(field_type), (
f"Each item in 'extra_fields' must contain a 'type' field with one of "
f"the supported values ({ExtraField.supported_types()}). "
f"We got {field_type=} in {extra_fields=}"
)
assert "name" in field, (
f"Each item in 'extra_fields' must contain a 'name' field so that the field is available under cut.<name>."
f"We found {field=} in {extra_fields=}"
)


def expand_sharded_filepaths(paths: str | Path | list[str]) -> list[str]:
# local import to avoid circular imports
from nemo.collections.asr.data.audio_to_text import expand_sharded_filepaths as _expand_sharded_filepaths

return _expand_sharded_filepaths(str(path), shard_strategy="replicate", world_size=1, global_rank=0)
if isinstance(paths, Path):
paths = str(paths)

return _expand_sharded_filepaths(paths, shard_strategy="replicate", world_size=1, global_rank=0)


def _to_custom_attr_dict(d: dict, _excluded_fields: set[str] = {"duration", "audio_filepath"}) -> dict:
Expand Down
Loading

0 comments on commit 909d23d

Please sign in to comment.