Merge r1.1 bugfixes into main #2407

Merged · 13 commits · Jun 25, 2021
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -6,7 +6,7 @@ pipeline {
}
}
options {
-timeout(time: 1, unit: 'HOURS')
+timeout(time: 2, unit: 'HOURS')
disableConcurrentBuilds()
}
stages {
2 changes: 1 addition & 1 deletion examples/asr/asr_webapp/Dockerfile
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-ARG BASE_IMAGE=nvcr.io/nvidia/nemo:1.0.0rc1
+ARG BASE_IMAGE=nvcr.io/nvidia/nemo:1.0.1

# build an image that includes only the nemo dependencies, ensures that dependencies
# are included first for optimal caching, and useful for building a development
2 changes: 1 addition & 1 deletion examples/nlp/machine_translation/enc_dec_nmt.py
@@ -111,7 +111,7 @@ def main(cfg: MTEncDecConfig) -> None:
# training is managed by PyTorch Lightning
trainer_cfg = OmegaConf.to_container(cfg.trainer)
trainer_cfg.pop('plugins', None)
-trainer = Trainer(plugins=[NLPDDPPlugin()], **trainer_cfg)
+trainer = Trainer(plugins=[NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)], **trainer_cfg)

# tokenizers will be trained and tarred training data will be created if needed
# model config is then updated
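The fix above threads the trainer's `num_nodes` into `NLPDDPPlugin` instead of letting the plugin fall back to its one-node default. A toy sketch of the failure mode, using a stand-in class rather than the real plugin:

```python
# Stand-in for NLPDDPPlugin; only the num_nodes default matters here.
class FakePlugin:
    def __init__(self, num_nodes: int = 1):
        self.num_nodes = num_nodes

    def world_size(self, gpus_per_node: int) -> int:
        # distributed samplers are sized from num_nodes * gpus_per_node
        return self.num_nodes * gpus_per_node


trainer_num_nodes = 4
before = FakePlugin()                            # old call: default of 1 node
after = FakePlugin(num_nodes=trainer_num_nodes)  # new call: matches the trainer

assert before.world_size(8) == 8   # wrong on a 4-node cluster
assert after.world_size(8) == 32   # correct world size
```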
2 changes: 1 addition & 1 deletion examples/speaker_recognition/extract_speaker_embeddings.py
@@ -89,7 +89,7 @@ def main():
labels=None,
batch_size=1,
shuffle=False,
-time_length=8,
+time_length=20,
embedding_dir=args.embedding_dir,
)
)
2 changes: 1 addition & 1 deletion examples/speaker_recognition/voxceleb_eval.py
@@ -70,7 +70,7 @@ def get_acc(trial_file='', emb='', save_kaldi_emb=False):
keys.append(y_speaker)
trial_embs.extend([Y])

-score = (X @ Y.T) / (((X @ X.T) * (Y @ Y.T)) ** 0.5)
+score = np.dot(X, Y) / ((np.dot(X, X) * np.dot(Y, Y)) ** 0.5)
score = (score + 1) / 2

all_scores.append(score)
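The rewritten line computes cosine similarity for 1-D embedding vectors (`np.dot(X, X)` is the squared L2 norm), then maps it from [-1, 1] into [0, 1]. A minimal numpy sketch of the scoring, with random stand-in embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
X, Y = rng.normal(size=192), rng.normal(size=192)  # stand-in speaker embeddings

# cosine similarity for 1-D vectors, as in the patched line
score = np.dot(X, Y) / ((np.dot(X, X) * np.dot(Y, Y)) ** 0.5)
score = (score + 1) / 2  # map [-1, 1] -> [0, 1]

# equivalent formulation via explicit norms
expected = (np.dot(X, Y) / (np.linalg.norm(X) * np.linalg.norm(Y)) + 1) / 2
assert np.isclose(score, expected)
```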
7 changes: 7 additions & 0 deletions nemo/collections/asr/models/ctc_models.py
@@ -127,6 +127,13 @@ def list_available_models(cls) -> Optional[PretrainedModelInfo]:
)
results.append(model)

+model = PretrainedModelInfo(
+    pretrained_model_name="stt_zh_citrinet_1024_gamma_0_25",
+    description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:stt_zh_citrinet_1024_gamma_0_25",
+    location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/stt_zh_citrinet_1024_gamma_0_25/versions/1.0.0/files/stt_zh_citrinet_1024_gamma_0_25.nemo",
+)
+results.append(model)

model = PretrainedModelInfo(
pretrained_model_name="asr_talknet_aligner",
description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:asr_talknet_aligner",
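The new entry makes the Mandarin Citrinet checkpoint discoverable by name. A sketch of how such a model is typically loaded, assuming nemo_toolkit is installed and NGC is reachable (the audio path is a placeholder):

```python
import nemo.collections.asr as nemo_asr

# downloads the .nemo file from NGC on first use, then caches it locally
model = nemo_asr.models.EncDecCTCModel.from_pretrained(
    model_name="stt_zh_citrinet_1024_gamma_0_25"
)

# 'audio_zh.wav' stands in for a 16 kHz mono recording
transcripts = model.transcribe(paths2audio_files=["audio_zh.wav"])
print(transcripts[0])
```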
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/label_models.py
@@ -29,6 +29,7 @@
from nemo.collections.asr.models.asr_model import ExportableEncDecModel
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations
+from nemo.collections.asr.parts.utils.speaker_utils import embedding_normalize
from nemo.collections.common.losses import CrossEntropyLoss as CELoss
from nemo.collections.common.metrics import TopKClassificationAccuracy
from nemo.core.classes import ModelPT
@@ -381,6 +382,7 @@ def test_epoch_end(self, outputs):
slices = torch.cat([x['slices'] for x in outputs])
emb_shape = embs.shape[-1]
embs = embs.view(-1, emb_shape).cpu().numpy()
+embs = embedding_normalize(embs)
out_embeddings = {}
start_idx = 0
with open(self.test_manifest, 'r') as manifest:
8 changes: 4 additions & 4 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -434,7 +434,7 @@ def input_types(self):
"""
return {
"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType(), optional=True),
"length": NeuralType(tuple('B'), LengthsType()),
}

@property
@@ -474,7 +474,7 @@ def __init__(
mask_value=mask_value,
)
else:
-self.spec_augment = lambda input_spec: input_spec
+self.spec_augment = lambda input_spec, length: input_spec

# Check if numba is supported, and use a Numba kernel if it is
if use_numba_spec_augment and numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__):
@@ -490,15 +490,15 @@ def __init__(
self.spec_augment_numba = None

@typecheck()
-def forward(self, input_spec, length=None):
+def forward(self, input_spec, length):
augmented_spec = self.spec_cutout(input_spec=input_spec)

# To run the Numba kernel, correct numba version is required as well as
# tensor must be on GPU and length must be provided
if self.spec_augment_numba is not None and spec_augment_launch_heuristics(augmented_spec, length):
augmented_spec = self.spec_augment_numba(input_spec=augmented_spec, length=length)
else:
-augmented_spec = self.spec_augment(input_spec=augmented_spec)
+augmented_spec = self.spec_augment(input_spec=augmented_spec, length=length)
return augmented_spec


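With the signature change above, `length` is a required input, so masking can be restricted to each sample's real frames rather than the padded width. A minimal call sketch, assuming nemo_toolkit is installed:

```python
import torch

from nemo.collections.asr.modules import SpectrogramAugmentation

aug = SpectrogramAugmentation(freq_masks=2, time_masks=2, freq_width=15, time_width=0.05)

spec = torch.randn(4, 64, 500)               # (B, D, T) spectrogram batch
length = torch.randint(100, 500, size=(4,))  # valid frame count per sample

out = aug(input_spec=spec, length=length)    # length can no longer be omitted
assert out.shape == spec.shape
```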
33 changes: 24 additions & 9 deletions nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py
@@ -251,11 +251,6 @@ def forward(self, input_spec, length):
sh = input_spec.shape
bs = sh[0]

-if self.adaptive_temporal_width:
-    time_width = max(1, int(sh[2] * self.time_width))
-else:
-    time_width = self.time_width

# Construct the freq and time masks as well as start positions
if self.freq_masks > 0:
freq_starts = torch.randint(
@@ -267,10 +262,30 @@
freq_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device)

if self.time_masks > 0:
-time_starts = torch.randint(
-    0, sh[2] - time_width + 1, size=[bs, self.time_masks], device=input_spec.device
-)
-time_lengths = torch.randint(0, time_width + 1, size=[bs, self.time_masks], device=input_spec.device)
+if self.adaptive_temporal_width:
+    time_width = (length * self.time_width).int().clamp(min=1)
+else:
+    time_width = (
+        torch.tensor(self.time_width, dtype=torch.int32, device=input_spec.device)
+        .unsqueeze(0)
+        .repeat(sh[0])
+    )
+
+time_starts = []
+time_lengths = []
+for idx in range(sh[0]):
+    time_starts.append(
+        torch.randint(
+            0, max(1, length[idx] - time_width[idx]), size=[1, self.time_masks], device=input_spec.device
+        )
+    )
+    time_lengths.append(
+        torch.randint(0, time_width[idx] + 1, size=[1, self.time_masks], device=input_spec.device)
+    )
+
+time_starts = torch.cat(time_starts, 0)
+time_lengths = torch.cat(time_lengths, 0)

else:
time_starts = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device)
time_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device)
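The rewritten block draws mask start positions and widths per sample, bounded by each utterance's true length rather than the padded T dimension. A stand-alone sketch of the shape logic (CPU tensors, hypothetical sizes):

```python
import torch

bs, time_masks = 4, 2
length = torch.tensor([120, 90, 60, 30])          # true frames per sample
time_width = (length * 0.05).int().clamp(min=1)   # adaptive per-sample width

time_starts, time_lengths = [], []
for idx in range(bs):
    hi = max(1, int(length[idx] - time_width[idx]))  # keep masks inside real audio
    time_starts.append(torch.randint(0, hi, size=[1, time_masks]))
    time_lengths.append(torch.randint(0, int(time_width[idx]) + 1, size=[1, time_masks]))

time_starts = torch.cat(time_starts, 0)    # shape (bs, time_masks)
time_lengths = torch.cat(time_lengths, 0)  # shape (bs, time_masks)
assert time_starts.shape == (bs, time_masks)
```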
23 changes: 13 additions & 10 deletions nemo/collections/asr/parts/submodules/spectr_augment.py
@@ -18,7 +18,7 @@
import torch.nn as nn

from nemo.core.classes import Typing, typecheck
-from nemo.core.neural_types import NeuralType, SpectrogramType
+from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType


class SpecAugment(nn.Module, Typing):
@@ -43,7 +43,10 @@ class SpecAugment(nn.Module, Typing):
def input_types(self):
"""Returns definitions of module input types
"""
return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
return {
"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
@@ -54,7 +57,7 @@ def output_types(self):
def __init__(
self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, rng=None, mask_value=0.0,
):
-super(SpecAugment, self).__init__()
+super().__init__()

self._rng = random.Random() if rng is None else rng

@@ -76,14 +79,9 @@

@typecheck()
@torch.no_grad()
-def forward(self, input_spec):
+def forward(self, input_spec, length):
sh = input_spec.shape

-if self.adaptive_temporal_width:
-    time_width = max(1, int(sh[2] * self.time_width))
-else:
-    time_width = self.time_width

for idx in range(sh[0]):
for i in range(self.freq_masks):
x_left = self._rng.randint(0, sh[1] - self.freq_width)
Expand All @@ -93,7 +91,12 @@ def forward(self, input_spec):
input_spec[idx, x_left : x_left + w, :] = self.mask_value

for i in range(self.time_masks):
-y_left = self._rng.randint(0, sh[2] - time_width)
+if self.adaptive_temporal_width:
+    time_width = max(1, int(length[idx] * self.time_width))
+else:
+    time_width = self.time_width
+
+y_left = self._rng.randint(0, max(1, length[idx] - time_width))

w = self._rng.randint(0, time_width)

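The key behavioral change: the mask width is now derived from `length[idx]`, the sample's unpadded frame count, instead of `sh[2]`, the padded batch width. A small self-contained sketch of that per-sample computation:

```python
import random

rng = random.Random(0)

lengths = [400, 250, 90]  # unpadded frames per sample; padded T would be 400
time_width_cfg = 0.05     # < 1.0, so it is treated as a fraction of length

for idx, n_frames in enumerate(lengths):
    time_width = max(1, int(n_frames * time_width_cfg))     # per-sample width
    y_left = rng.randint(0, max(1, n_frames - time_width))  # start inside real audio
    w = rng.randint(0, time_width)
    print(f"sample {idx}: mask [{y_left}, {y_left + w}) of {n_frames} frames")
```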
17 changes: 17 additions & 0 deletions nemo/collections/asr/parts/utils/speaker_utils.py
@@ -398,3 +398,20 @@ def write_rttm2manifest(paths2audio_files, paths2rttm_files, manifest_file):
outfile.write("\n")
f.close()
return manifest_file


+def embedding_normalize(embs, use_std=False, eps=1e-10):
+    """
+    Mean-normalize and L2-length-normalize the input speaker embeddings.
+    input:
+        embs: embeddings of shape (batch, emb_size)
+    output:
+        embs: normalized embeddings of shape (batch, emb_size)
+    """
+    embs = embs - embs.mean(axis=0)
+    if use_std:
+        embs = embs / (embs.std(axis=0) + eps)
+    embs_l2_norm = np.expand_dims(np.linalg.norm(embs, ord=2, axis=-1), axis=1)
+    embs = embs / embs_l2_norm
+
+    return embs
1 change: 1 addition & 0 deletions nemo/collections/common/tokenizers/bytelevel_tokenizers.py
@@ -15,6 +15,7 @@
import re
from pathlib import Path
from typing import List

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ['ByteLevelProcessor', 'ByteLevelTokenizer']
Expand Down
2 changes: 1 addition & 1 deletion nemo/collections/nlp/parts/nlp_overrides.py
@@ -184,7 +184,7 @@ def distributed_sampler_kwargs(self):
return distributed_sampler_kwargs

else:
-return super().distributed_sampler_kwargs
+return super(NLPDDPPlugin, self).distributed_sampler_kwargs


class NLPCheckpointConnector(CheckpointConnector):
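A toy illustration of the explicit two-argument `super()` used above: it resolves to the same parent here, but keeps working in contexts where Python cannot resolve the implicit `__class__` cell (for example, a function defined outside the class body and attached afterwards):

```python
class Base:
    @property
    def distributed_sampler_kwargs(self):
        return {"num_replicas": 1, "rank": 0}


class Plugin(Base):
    @property
    def distributed_sampler_kwargs(self):
        # explicit form: start the MRO lookup just after Plugin
        return super(Plugin, self).distributed_sampler_kwargs


print(Plugin().distributed_sampler_kwargs)  # {'num_replicas': 1, 'rank': 0}
```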
@@ -107,8 +107,9 @@
parser.add_argument(
'--max_duration',
default=None,
+required=True,
type=float,
-help='Maximum duration of audio clip in the dataset. By default, it is None and will not filter files.',
+help='Maximum duration of audio clip in the dataset. This argument is required and must be set explicitly.',
)
parser.add_argument(
'--min_duration',
16 changes: 16 additions & 0 deletions scripts/tokenizers/process_asr_text_tokenizer.py
@@ -106,6 +106,9 @@
help="Character coverage percentage for SentencePiece tokenization. For languages "
"with large vocabulary, should be close to 0.9995, otherwise kept as 1.0",
)
+parser.add_argument('--spe_bos', action='store_true', help='Add <s> token to SentencePiece Tokenizer.')
+parser.add_argument('--spe_eos', action='store_true', help='Add </s> token to SentencePiece Tokenizer.')
+parser.add_argument('--spe_pad', action='store_true', help='Add <pad> token to SentencePiece Tokenizer.')
parser.add_argument(
'--spe_sample_size',
type=int,
@@ -173,6 +176,9 @@ def __process_data(
spe_train_extremely_large_corpus: bool,
spe_sample_size: int,
spe_max_sentencepiece_length: int,
+spe_bos: bool,
+spe_eos: bool,
+spe_pad: bool,
lower_case: bool,
):
"""
@@ -191,6 +197,9 @@
this flag can be set to try to train the tokenizer. Will silently fail if it runs out of RAM.
spe_max_sentencepiece_length: Limits the maximum length of the SentencePiece subword that can be constructed.
By default, no limit is placed.
+spe_bos: Bool flag, whether to add <s> to SentencePiece tokenizer vocabulary.
+spe_eos: Bool flag, whether to add </s> to SentencePiece tokenizer vocabulary.
+spe_pad: Bool flag, whether to add <pad> to SentencePiece tokenizer vocabulary.
lower_case: whether to tokenize with lower case character set only (for english)

Returns:
@@ -222,6 +231,9 @@
character_coverage=spe_character_coverage,
train_extremely_large_corpus=spe_train_extremely_large_corpus,
max_sentencepiece_length=spe_max_sentencepiece_length,
+bos=spe_bos,
+eos=spe_eos,
+pad=spe_pad,
)

else:
@@ -249,6 +261,7 @@ def main():
spe_sample_size = args.spe_sample_size
spe_train_extremely_large_corpus = args.spe_train_extremely_large_corpus
spe_max_sentencepiece_length = args.spe_max_sentencepiece_length
+spe_bos, spe_eos, spe_pad = args.spe_bos, args.spe_eos, args.spe_pad
lower_case = args.lower_case

if not os.path.exists(data_root):
@@ -272,6 +285,9 @@
spe_sample_size=spe_sample_size,
spe_train_extremely_large_corpus=spe_train_extremely_large_corpus,
spe_max_sentencepiece_length=spe_max_sentencepiece_length,
+spe_bos=spe_bos,
+spe_eos=spe_eos,
+spe_pad=spe_pad,
)

print("Serialized tokenizer at location :", tokenizer_path)
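The three new flags forward bos/eos/pad into SentencePiece training. A sketch of inspecting the resulting model, assuming the sentencepiece package and a tokenizer trained with `--spe_bos --spe_eos --spe_pad` (the model path below is a placeholder):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="tokenizer_dir/tokenizer.model")

# ids are >= 0 when the corresponding token was added at training time;
# -1 means the token is absent from the vocabulary
print("bos:", sp.bos_id())  # id of <s>
print("eos:", sp.eos_id())  # id of </s>
print("pad:", sp.pad_id())  # id of <pad>
```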
30 changes: 24 additions & 6 deletions tests/collections/asr/numba/spec_augment/test_spec_aug_numba.py
@@ -57,10 +57,7 @@ def prepare_data(b, f, t, device='cuda', freq_masks=0, time_masks=0, freq_width=

adaptive_temporal_width = True

-if adaptive_temporal_width:
-    time_width = max(1, int(sh[2] * time_width))
-else:
-    time_width = time_width
+original_time_width = time_width

# Construct the freq and time masks as well as start positions
if freq_masks > 0:
@@ -71,8 +68,29 @@
freq_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=x.device)

if time_masks > 0:
-time_starts = torch.randint(0, sh[2] - time_width + 1, size=[bs, time_masks], device=x.device)
-time_lengths = torch.randint(0, time_width + 1, size=[bs, time_masks], device=x.device)
+if adaptive_temporal_width:
+    time_width = (x_len * original_time_width).int().clamp(min=1)
+else:
+    time_width = (
+        torch.tensor(original_time_width, dtype=torch.int32, device=x.device)
+        .unsqueeze(0)
+        .repeat(sh[0])
+    )
+
+time_starts = []
+time_lengths = []
+for idx in range(sh[0]):
+    time_starts.append(
+        torch.randint(
+            0, max(1, x_len[idx] - time_width[idx]), size=[1, time_masks], device=x.device
+        )
+    )
+    time_lengths.append(
+        torch.randint(0, time_width[idx] + 1, size=[1, time_masks], device=x.device)
+    )
+
+time_starts = torch.cat(time_starts, 0)
+time_lengths = torch.cat(time_lengths, 0)
else:
time_starts = torch.zeros([bs, 1], dtype=torch.int64, device=x.device)
time_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=x.device)
2 changes: 1 addition & 1 deletion tests/collections/asr/test_asr_modules.py
@@ -125,7 +125,7 @@ def test_SpectrogramAugmentationr(self):
input_signal = torch.randn(size=(4, 512))
length = torch.randint(low=161, high=500, size=[4])
res0 = instance0(input_signal=input_signal, length=length)
-res = instance1(input_spec=res0[0])
+res = instance1(input_spec=res0[0], length=length)

assert res.shape == res0[0].shape

2 changes: 1 addition & 1 deletion tutorials/00_NeMo_Primer.ipynb
@@ -67,7 +67,7 @@
"!pip install unidecode\n",
"\n",
"# ## Install NeMo\n",
"BRANCH = 'v1.0.2'\n",
"BRANCH = 'main'\n",
"!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
"\n",
"## Install TorchAudio\n",