diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode_2nd.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode_2nd.py new file mode 100755 index 00000000..dee1d63a --- /dev/null +++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode_2nd.py @@ -0,0 +1,579 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu, Fangjun Kuang) +# 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import argparse +import k2 +import logging +import numpy as np +import os +import sys +import torch +from k2 import Fsa, SymbolTable +from pathlib import Path +from typing import List +from typing import Optional +from typing import Union + +from snowfall.common2 import average_checkpoint, average_checkpoint_2nd, store_transcripts +from snowfall.common2 import find_first_disambig_symbol +from snowfall.common2 import get_texts +from snowfall.common2 import write_error_stats +from snowfall.common2 import load_checkpoint +from snowfall.common2 import setup_logger +from snowfall.common2 import str2bool +from snowfall.data import LibriSpeechAsrDataModule +from snowfall.decoding.graph import compile_HLG +from snowfall.decoding.lm_rescore2 import decode_with_lm_rescoring +from snowfall.models import AcousticModel +from snowfall.models.transformer import Transformer +from snowfall.models.conformer import Conformer +from snowfall.models.second_pass_model import SecondPassModel +from snowfall.training.ctc_graph import build_ctc_topo +from snowfall.training.mmi_graph2 import create_bigram_phone_lm +from snowfall.training.mmi_graph2 import get_phone_symbols + + +# TODO(fangjun): Replace it with +# https://github.com/k2-fsa/snowfall/issues/232 +@torch.no_grad() +def second_pass_decoding(second_pass: AcousticModel, + lats: k2.Fsa, + supervision_segments: torch.Tensor, + encoder_memory: torch.Tensor, + num_paths: int = 10): + ''' + The fundamental idea is to get an n-best list from the first pass + decoding lattice and use the second pass model to do rescoring. + The path with the highest score is used as the final decoding output. + + Args: + lats: + It's the decoding lattice from the first pass. + ''' + device = lats.device + assert len(lats.shape) == 3 + + # First, extract `num_paths` paths for each sequence. + # paths is a k2.RaggedInt with axes [seq][path][arc_pos] + paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + + # phone_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains phone IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + phone_seqs = k2.index(lats.labels.clone(), paths) + + num_seqs = phone_seqs.dim0() + + indexes = torch.arange(start=0, + end=(num_seqs - 1) * num_paths + 1, + step=num_paths, + dtype=torch.int32, + device=device) + # From indexes, we can select + # + # (1) path 0 of seq0, seq1, seq2, ... + # (2) path 1 of seq0, seq1, seq2, ... + # (3) path 2 of seq0, seq1, seq2, ... 
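+    # For example, with num_seqs = 3 and num_paths = 10, indexes is
+    # [0, 10, 20], so `indexes + i` picks path i of seq0, seq1 and seq2
+    # once the seq axis of phone_seqs is removed below.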
+ + paths_offset = phone_seqs.row_splits(1) + phone_seqs = k2.ragged.remove_axis(phone_seqs, 0) + # Note, from now on phone_seqs has only two axes + + phone_seqs = k2.ragged.remove_values_leq(phone_seqs, -1) + log_probs = torch.empty(num_paths, num_seqs) + for i in range(num_paths): + # path is a k2.RaggedInt with axes [path][phones], excluding -1 + p, _ = k2.ragged.index(phone_seqs, indexes + i, axis=0) + phone_fsa = k2.linear_fsa(p) + nnet_output_2nd = second_pass(encoder_memory, phone_fsa, supervision_segments) + + nnet_output_2nd = nnet_output_2nd.cpu() + row_splits1 = p.row_splits(1).cpu() + value = p.values().cpu() + + for k in range(num_seqs): + nnet_output_this_seq = nnet_output_2nd[k] + start = row_splits1[k] + end = row_splits1[k+1] + num_frames = end - start + + this_value = value[start:end].tolist() + log_p = 0 + for idx, v in enumerate(this_value): + log_p += nnet_output_this_seq[idx, v] + log_p /= num_frames + log_probs[i, k] = log_p + + # Now get the best score of the path within n-best list + # of each seq + best_indexes = torch.argmax(log_probs, dim=0).to(torch.int32) + best_paths_indexes = best_indexes.to(device) + paths_offset[:-1] + + # best_paths has three axes [seq][path][arc_pos] + # each seq has only one path + best_paths, _ = k2.ragged.index(paths, best_paths_indexes, axis=1) + + best_phone_seqs = k2.index(lats.labels.clone(), best_paths) + + best_phone_seqs = k2.ragged.remove_values_leq(best_phone_seqs, -1) + best_phone_seqs = k2.ragged.remove_axis(best_phone_seqs, 0) + + ans = k2.linear_fsa(best_phone_seqs) + ans.aux_labels = k2.index(lats.aux_labels, best_paths.values()) + + return ans + + + +@torch.no_grad() +def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel, + second_pass: Optional[AcousticModel], + device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable, + num_paths: int, G: k2.Fsa, use_whole_lattice: bool, output_beam_size: float): + tot_num_cuts = len(dataloader.dataset.cuts) + num_cuts = 0 + results = [] # a list of pair (ref_words, hyp_words) + for batch_idx, batch in enumerate(dataloader): + feature = batch['inputs'] + supervisions = batch['supervisions'] + supervision_segments = torch.stack( + (supervisions['sequence_idx'], + (((supervisions['start_frame'] - 1) // 2 - 1) // 2), + (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32) + supervision_segments = torch.clamp(supervision_segments, min=0) + indices = torch.argsort(supervision_segments[:, 2], descending=True) + supervision_segments = supervision_segments[indices] + texts = supervisions['text'] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + with torch.no_grad(): + nnet_output, encoder_memory, _ = model(feature, supervisions) + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, + 1) # now nnet_output is [N, T, C] + + # blank_bias = -3.0 + # nnet_output[:, :, 0] += blank_bias + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + # assert HLG.is_cuda() + assert HLG.device == nnet_output.device, \ + f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})" + # TODO(haowen): with a small `beam`, we may get empty `target_graph`, + # thus `tot_scores` will be `inf`. Definitely we need to handle this later. 
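+        # Start from the configured output beam; if k2.intersect_dense_pruned
+        # below raises a RuntimeError (typically CUDA OOM), the beam is shrunk
+        # by 5% and the intersection is retried.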
+ beam_size = output_beam_size + while True: + try: + if beam_size < 2: + logging.error(f'beam size {beam_size} is too small') + sys.exit(1) + lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, beam_size, 30, + 10000) + if second_pass is not None: + best_paths = second_pass_decoding( + second_pass=second_pass, + lats=lattices, + supervision_segments=supervision_segments, + encoder_memory=encoder_memory, + num_paths=num_paths) + else: + if G is None: + best_paths = k2.shortest_path(lattices, use_double_scores=True) + else: + best_paths = decode_with_lm_rescoring( + lattices, + G, + num_paths=num_paths, + use_whole_lattice=use_whole_lattice) + break + except RuntimeError as e: + logging.info(f'Caught exception:\n{e}\n') + new_beam_size = beam_size * 0.95 + logging.info(f'Change beam_size from {beam_size} to {new_beam_size}') + beam_size = new_beam_size + + # if second_pass is not None: + # # Now for the second pass model + # best_paths = k2.shortest_path(lattices, use_double_scores=True) + # + # nnet_output_2nd = second_pass(encoder_memory, best_paths, supervision_segments) + # # nnet_output_2nd is [N, T, C] + # + # assert nnet_output_2nd.shape[0] == supervision_segments.shape[0] + # + # supervision_segments_2nd = supervision_segments.clone() + # + # # [0, 1, 2, 3, ...] + # supervision_segments_2nd[:, 0] = torch.arange(supervision_segments_2nd.shape[0], dtype=torch.int32) + # + # # start offset + # supervision_segments_2nd[:, 1] = 0 + # + # # duration of supervision_segments_2nd is kept unchanged + # + # dense_fsa_vec_2nd = k2.DenseFsaVec(nnet_output_2nd, supervision_segments_2nd) + # + # lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec_2nd, 20.0, output_beam_size, 30, + # 10000) + + # if G is None: + # best_paths = k2.shortest_path(lattices, use_double_scores=True) + # else: + # best_paths = decode_with_lm_rescoring( + # lattices, + # G, + # num_paths=num_paths, + # use_whole_lattice=use_whole_lattice) + + assert best_paths.shape[0] == len(texts) + hyps = get_texts(best_paths, indices) + assert len(hyps) == len(texts) + + for i in range(len(texts)): + hyp_words = [symbols.get(x) for x in hyps[i]] + ref_words = texts[i].split(' ') + results.append((ref_words, hyp_words)) + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format( + batch_idx, num_cuts, tot_num_cuts, + float(num_cuts) / tot_num_cuts * 100)) + + num_cuts += len(texts) + + return results + + +def print_transition_probabilities(P: k2.Fsa, phone_symbol_table: SymbolTable, + phone_ids: List[int], filename: str): + '''Print the transition probabilities of a phone LM. + + Args: + P: + A bigram phone LM. + phone_symbol_table: + The phone symbol table. + phone_ids: + A list of phone ids + filename: + Filename to save the printed result. + ''' + num_phones = len(phone_ids) + table = np.zeros((num_phones + 1, num_phones + 2)) + table[:, 0] = 0 + table[0, -1] = 0 # the start state has no arcs to the final state + assert P.arcs.dim0() == num_phones + 2 + arcs = P.arcs.values()[:, :3] + probability = P.scores.exp().tolist() + + assert arcs.shape[0] - num_phones == num_phones * (num_phones + 1) + for i, arc in enumerate(arcs.tolist()): + src_state, dest_state, label = arc[0], arc[1], arc[2] + prob = probability[i] + if label != -1: + assert label == dest_state + else: + assert dest_state == num_phones + 1 + table[src_state][dest_state] = prob + + try: + from prettytable import PrettyTable + except ImportError: + print('Please run `pip install prettytable`. 
Skip printing') + return + + x = PrettyTable() + + field_names = ['source'] + field_names.append('sum') + for i in phone_ids: + field_names.append(phone_symbol_table[i]) + field_names.append('final') + + x.field_names = field_names + + for row in range(num_phones + 1): + this_row = [] + if row == 0: + this_row.append('start') + else: + this_row.append(phone_symbol_table[row]) + this_row.append('{:.6f}'.format(table[row, 1:].sum())) + for col in range(1, num_phones + 2): + this_row.append('{:.6f}'.format(table[row, col])) + x.add_row(this_row) + with open(filename, 'w') as f: + f.write(str(x)) + + +def get_parser(): + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--model-type', + type=str, + default="conformer", + choices=["transformer", "conformer"], + help="Model type.") + parser.add_argument( + '--epoch', + type=int, + default=10, + help="Decoding epoch.") + parser.add_argument( + '--avg', + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by '--epoch'. ") + parser.add_argument( + '--att-rate', + type=float, + default=0.0, + help="Attention loss rate.") + parser.add_argument( + '--nhead', + type=int, + default=4, + help="Number of attention heads in transformer.") + parser.add_argument( + '--attention-dim', + type=int, + default=256, + help="Number of units in transformer attention layers.") + parser.add_argument( + '--output-beam-size', + type=float, + default=8, + help='Output beam size. Used in k2.intersect_dense_pruned. '\ + 'Choose a large value (e.g., 20) for 1-best decoding '\ + 'and n-best rescoring. Choose a small value (e.g., 8) for ' \ + 'rescoring with the whole lattice.') + parser.add_argument( + '--use-lm-rescoring', + type=str2bool, + default=True, + help='When enabled, it uses LM for rescoring') + parser.add_argument( + '--use-second-pass', + type=str2bool, + default=True, + help='When enabled, it uses the second pass model') + parser.add_argument( + '--num-paths', + type=int, + default=-1, + help='Number of paths for rescoring using n-best list. '
\ + 'If it is negative, then rescore with the whole lattice.'\ + 'CAUTION: You have to reduce max_duration in case of CUDA OOM' + ) + return parser + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + model_type = args.model_type + epoch = args.epoch + avg = args.avg + att_rate = args.att_rate + num_paths = args.num_paths + use_lm_rescoring = args.use_lm_rescoring + use_whole_lattice = False + if use_lm_rescoring and num_paths < 1: + # It doesn't make sense to use n-best list for rescoring + # when n is less than 1 + use_whole_lattice = True + + output_beam_size = args.output_beam_size + use_second_pass = args.use_second_pass + + exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-new') + setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug') + + logging.info(f'output_beam_size: {output_beam_size}') + + # load L, G, symbol_table + lang_dir = Path('data/lang_nosp') + symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt') + phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt') + + phone_ids = get_phone_symbols(phone_symbol_table) + P = create_bigram_phone_lm(phone_ids) + + phone_ids_with_blank = [0] + phone_ids + ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank)) + + logging.debug("About to load model") + # Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N + # device = torch.device('cuda', 1) + device = torch.device('cuda') + + if att_rate != 0.0: + num_decoder_layers = 6 + else: + num_decoder_layers = 0 + + if model_type == "transformer": + model = Transformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers) + else: + model = Conformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers) + + model.P_scores = torch.nn.Parameter(P.scores.clone(), requires_grad=False) + + if use_second_pass: + second_pass = SecondPassModel(max_phone_id=max(phone_ids)).to(device) + + if avg == 1: + checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt') + load_checkpoint(checkpoint, model) + + if use_second_pass: + checkpoint_2nd = os.path.join(exp_dir, '2nd-epoch-' + str(epoch - 1) + '.pt') + logging.info(f'loading {checkpoint_2nd}') + second_pass.load_state_dict(torch.load(checkpoint_2nd, map_location='cpu')) + else: + checkpoints = [os.path.join(exp_dir, 'epoch-' + str(avg_epoch) + '.pt') for avg_epoch in + range(epoch - avg, epoch)] + average_checkpoint(checkpoints, model) + + if use_second_pass: + checkpoints_2nd = [os.path.join(exp_dir, '2nd-epoch-' + str(avg_epoch) + '.pt') for avg_epoch in + range(epoch - avg, epoch)] + average_checkpoint_2nd(checkpoints_2nd, second_pass) + + model.to(device) + model.eval() + + if use_second_pass: + second_pass.to(device) + second_pass.eval() + else: + second_pass = None + logging.info('Not using the second pass model') + + assert P.requires_grad is False + P.scores = model.P_scores.cpu() + print_transition_probabilities(P, phone_symbol_table, phone_ids, filename='model_P_scores.txt') + + P.set_scores_stochastic_(model.P_scores) + print_transition_probabilities(P, phone_symbol_table, phone_ids, filename='P_scores.txt') + + if not os.path.exists(lang_dir / 'HLG.pt'): + logging.debug("Loading L_disambig.fst.txt") + with open(lang_dir 
/ 'L_disambig.fst.txt') as f: + L = k2.Fsa.from_openfst(f.read(), acceptor=False) + logging.debug("Loading G.fst.txt") + with open(lang_dir / 'G.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table) + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + HLG = compile_HLG(L=L, + G=G, + H=ctc_topo, + labels_disambig_id_start=first_phone_disambig_id, + aux_labels_disambig_id_start=first_word_disambig_id) + torch.save(HLG.as_dict(), lang_dir / 'HLG.pt') + else: + logging.debug("Loading pre-compiled HLG") + d = torch.load(lang_dir / 'HLG.pt') + HLG = k2.Fsa.from_dict(d) + + if use_lm_rescoring: + if use_whole_lattice: + logging.info('Rescoring with the whole lattice') + else: + logging.info(f'Rescoring with n-best list, n is {num_paths}') + first_word_disambig_id = find_first_disambig_symbol(symbol_table) + if not os.path.exists(lang_dir / 'G_4_gram.pt'): + logging.debug('Loading G_4_gram.fst.txt') + with open(lang_dir / 'G_4_gram.fst.txt') as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION(fangjun): The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + G = k2.create_fsa_vec([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt') + else: + logging.debug('Loading pre-compiled G_4_gram.pt') + d = torch.load(lang_dir / 'G_4_gram.pt') + G = k2.Fsa.from_dict(d).to(device) + + if use_whole_lattice: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + G.lm_scores = G.scores.clone() + else: + logging.debug('Decoding without LM rescoring') + G = None + + logging.debug("convert HLG to device") + HLG = HLG.to(device) + HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0) + HLG.requires_grad_(False) + + if not hasattr(HLG, 'lm_scores'): + HLG.lm_scores = HLG.scores.clone() + + # load dataset + librispeech = LibriSpeechAsrDataModule(args) + test_sets = ['test-clean', 'test-other'] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + logging.info(f'* DECODING: {test_set}') + + results = decode(dataloader=test_dl, + model=model, + second_pass=second_pass, + device=device, + HLG=HLG, + symbols=symbol_table, + num_paths=num_paths, + G=G, + use_whole_lattice=use_whole_lattice, + output_beam_size=output_beam_size) + + recog_path = exp_dir / f'recogs-{test_set}.txt' + store_transcripts(path=recog_path, texts=results) + logging.info(f'The transcripts are stored in {recog_path}') + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = exp_dir / f'errs-{test_set}.txt' + with open(errs_filename, 'w') as f: + write_error_stats(f, test_set, results) + logging.info('Wrote detailed error stats to {}'.format(errs_filename)) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_train_2nd.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train_2nd.py new file mode 100755 index 00000000..e769a299 --- /dev/null +++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_train_2nd.py @@ -0,0 +1,735 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey +# Haowen Qiu +# Fangjun Kuang) +# 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Apache 2.0 + +import argparse +import logging +import os +import sys +from datetime import datetime +from pathlib import Path +from typing import Dict, Optional + +import k2 +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch import nn +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_value_ +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from lhotse.utils import fix_random_seed, nullcontext +from snowfall.common2 import describe, str2bool +from snowfall.common2 import load_checkpoint, save_checkpoint +from snowfall.common2 import save_training_info +from snowfall.common2 import setup_logger +from snowfall.data.librispeech import LibriSpeechAsrDataModule +from snowfall.dist import cleanup_dist +from snowfall.dist import setup_dist +from snowfall.lexicon import Lexicon +from snowfall.models import AcousticModel +from snowfall.models.conformer import Conformer +from snowfall.models.tdnn_lstm import TdnnLstm1b # alignment model +from snowfall.models.transformer import Noam, Transformer +from snowfall.models.second_pass_model import SecondPassModel +from snowfall.objectives import encode_supervisions +from snowfall.objectives.mmi2 import LFMMILoss +from snowfall.objectives.common import get_tot_objf_and_num_frames +from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change +from snowfall.training.mmi_graph2 import MmiTrainingGraphCompiler +from snowfall.training.mmi_graph2 import create_bigram_phone_lm + + +def get_objf(batch: Dict, + model: AcousticModel, + second_pass: AcousticModel, + ali_model: Optional[AcousticModel], + P: k2.Fsa, + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + is_training: bool, + is_update: bool, + accum_grad: int = 1, + den_scale: float = 1.0, + att_rate: float = 0.0, + tb_writer: Optional[SummaryWriter] = None, + global_batch_idx_train: Optional[int] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scaler: GradScaler = None + ): + feature = batch['inputs'] + # at entry, feature is [N, T, C] + feature = feature.permute(0, 2, 1) # now feature is [N, C, T] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch['supervisions'] + supervision_segments, texts = encode_supervisions(supervisions) + + loss_fn = LFMMILoss( + graph_compiler=graph_compiler, + P=P, + den_scale=den_scale + ) + + grad_context = nullcontext if is_training else torch.no_grad + + with autocast(enabled=scaler.is_enabled()), grad_context(): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + if 
att_rate != 0.0: + if hasattr(model, 'module'): + att_loss = model.module.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler) + else: + att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler) + + if (ali_model is not None and global_batch_idx_train is not None and + global_batch_idx_train * accum_grad < 4000): + with torch.no_grad(): + ali_model_output = ali_model(feature) + # subsampling is done slightly differently, may be small length + # differences. + min_len = min(ali_model_output.shape[2], nnet_output.shape[2]) + # scale less than one so it will be encouraged + # to mimic ali_model's output + ali_model_scale = 500.0 / (global_batch_idx_train*accum_grad + 500) + nnet_output = nnet_output.clone() # or log-softmax backprop will fail. + nnet_output[:, :,:min_len] += ali_model_scale * ali_model_output[:, :,:min_len] + + # nnet_output is [N, C, T] + nnet_output = nnet_output.permute(0, 2, 1) # now nnet_output is [N, T, C] + + mmi_loss, tot_frames, all_frames, den_lats, num_den_graphs, a_to_b_map = \ + loss_fn(nnet_output, texts, supervision_segments, True) + + # Get the best path of each sequence. Only the labels of each path + # are used in the following code. 0s in labels are not removed. + best_paths = k2.shortest_path(den_lats, use_double_scores=True) + + nnet_output_2nd = second_pass(encoder_memory, best_paths, supervision_segments) + assert nnet_output_2nd.shape[0] == supervision_segments.shape[0] + + supervision_segments_2nd = supervision_segments.clone() + + # [0, 1, 2, 3, ...] + supervision_segments_2nd[:, 0] = torch.arange(supervision_segments_2nd.shape[0], dtype=torch.int32) + + # start offset + supervision_segments_2nd[:, 1] = 0 + + # duration of supervision_segments_2nd is kept unchanged + + dense_fsa_vec_2nd = k2.DenseFsaVec(nnet_output_2nd, supervision_segments_2nd) + + num_den_lats_2nd = k2.intersect_dense(num_den_graphs, + dense_fsa_vec_2nd, + output_beam=10.0, + a_to_b_map=a_to_b_map) + + num_den_tot_scores_2nd = num_den_lats_2nd.get_tot_scores( + log_semiring=True, use_double_scores=True) + + num_tot_scores_2nd = num_den_tot_scores_2nd[::2] + den_tot_scores_2nd = num_den_tot_scores_2nd[1::2] + + tot_scores_2nd = num_tot_scores_2nd - den_tot_scores_2nd + tot_score_2nd, tot_frames_2nd, all_frames_2nd = get_tot_objf_and_num_frames( + tot_scores_2nd, supervision_segments_2nd[:, 2]) + # print(mmi_loss.item()/tot_frames, tot_score_2nd.item()/tot_frames_2nd) + + mmi_loss = mmi_loss + tot_score_2nd + tot_frames = tot_frames + tot_frames_2nd + all_frames = all_frames + all_frames_2nd + + if is_training: + def maybe_log_gradients(tag: str): + if tb_writer is not None and global_batch_idx_train is not None and global_batch_idx_train % 200 == 0: + tb_writer.add_scalars( + tag, + measure_gradient_norms(model, norm='l1'), + global_step=global_batch_idx_train + ) + + if att_rate != 0.0: + loss = (- (1.0 - att_rate) * mmi_loss + att_rate * att_loss) / (len(texts) * accum_grad) + else: + loss = (-mmi_loss) / (len(texts) * accum_grad) + scaler.scale(loss).backward() + if is_update: + maybe_log_gradients('train/grad_norms') + scaler.unscale_(optimizer) + clip_grad_value_(model.parameters(), 5.0) + clip_grad_value_(second_pass.parameters(), 5.0) + maybe_log_gradients('train/clipped_grad_norms') + if tb_writer is not None and (global_batch_idx_train // accum_grad) % 200 == 0: + # Once in a time we will perform a more costly diagnostic + # to check the relative parameter change per minibatch. 
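+                # As its name suggests, this helper also applies the optimizer
+                # step while measuring the change, which is why scaler.step()
+                # is only called in the else branch below.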
+ deltas = optim_step_and_measure_param_change(model, optimizer, scaler) + tb_writer.add_scalars( + 'train/relative_param_change_per_minibatch', + deltas, + global_step=global_batch_idx_train + ) + else: + scaler.step(optimizer) + optimizer.zero_grad() + scaler.update() + + ans = -mmi_loss.detach().cpu().item(), tot_frames, all_frames + return ans + + +def get_validation_objf(dataloader: torch.utils.data.DataLoader, + model: AcousticModel, + second_pass: AcousticModel, + ali_model: Optional[AcousticModel], + P: k2.Fsa, + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + scaler: GradScaler, + den_scale: float = 1, + ): + total_objf = 0. + total_frames = 0. # for display only + total_all_frames = 0. # all frames including those seqs that failed. + + model.eval() + second_pass.eval() + + from torchaudio.datasets.utils import bg_iterator + for batch_idx, batch in enumerate(bg_iterator(dataloader, 2)): + objf, frames, all_frames = get_objf( + batch=batch, + model=model, + second_pass=second_pass, + ali_model=ali_model, + P=P, + device=device, + graph_compiler=graph_compiler, + is_training=False, + is_update=False, + den_scale=den_scale, + scaler=scaler + ) + total_objf += objf + total_frames += frames + total_all_frames += all_frames + + return total_objf, total_frames, total_all_frames + + +def train_one_epoch(dataloader: torch.utils.data.DataLoader, + valid_dataloader: torch.utils.data.DataLoader, + model: AcousticModel, + second_pass: AcousticModel, + ali_model: Optional[AcousticModel], + P: k2.Fsa, + device: torch.device, + graph_compiler: MmiTrainingGraphCompiler, + optimizer: torch.optim.Optimizer, + accum_grad: int, + den_scale: float, + att_rate: float, + current_epoch: int, + tb_writer: SummaryWriter, + num_epochs: int, + global_batch_idx_train: int, + world_size: int, + scaler: GradScaler + ): + """One epoch training and validation. + + Args: + dataloader: Training dataloader + valid_dataloader: Validation dataloader + model: Acoustic model to be trained + P: An FSA representing the bigram phone LM + device: Training device, torch.device("cpu") or torch.device("cuda", device_id) + graph_compiler: MMI training graph compiler + optimizer: Training optimizer + accum_grad: Number of gradient accumulation + den_scale: Denominator scale in mmi loss + att_rate: Attention loss rate, final loss is att_rate * att_loss + (1-att_rate) * other_loss + current_epoch: current training epoch, for logging only + tb_writer: tensorboard SummaryWriter + num_epochs: total number of training epochs, for logging only + global_batch_idx_train: global training batch index before this epoch, for logging only + + Returns: + A tuple of 3 scalar: (total_objf / total_frames, valid_average_objf, global_batch_idx_train) + - `total_objf / total_frames` is the average training loss + - `valid_average_objf` is the average validation loss + - `global_batch_idx_train` is the global training batch index after this epoch + """ + total_objf, total_frames, total_all_frames = 0., 0., 0. 
+ valid_average_objf = float('inf') + time_waiting_for_batch = 0 + forward_count = 0 + prev_timestamp = datetime.now() + + model.train() + second_pass.train() + for batch_idx, batch in enumerate(dataloader): + forward_count += 1 + if forward_count == accum_grad: + is_update = True + forward_count = 0 + else: + is_update = False + + global_batch_idx_train += 1 + timestamp = datetime.now() + time_waiting_for_batch += (timestamp - prev_timestamp).total_seconds() + + if forward_count == 1 or accum_grad == 1: + if hasattr(model, 'module'): + P.set_scores_stochastic_(model.module.P_scores) + else: + P.set_scores_stochastic_(model.P_scores) + assert P.requires_grad is True + + curr_batch_objf, curr_batch_frames, curr_batch_all_frames = get_objf( + batch=batch, + model=model, + second_pass=second_pass, + ali_model=ali_model, + P=P, + device=device, + graph_compiler=graph_compiler, + is_training=True, + is_update=is_update, + accum_grad=accum_grad, + den_scale=den_scale, + att_rate=att_rate, + tb_writer=tb_writer, + global_batch_idx_train=global_batch_idx_train, + optimizer=optimizer, + scaler=scaler + ) + + total_objf += curr_batch_objf + total_frames += curr_batch_frames + total_all_frames += curr_batch_all_frames + + if batch_idx % 10 == 0: + logging.info( + 'batch {}, epoch {}/{} ' + 'global average objf: {:.6f} over {} ' + 'frames ({:.1f}% kept), current batch average objf: {:.6f} over {} frames ({:.1f}% kept) ' + 'avg time waiting for batch {:.3f}s'.format( + batch_idx, current_epoch, num_epochs, + total_objf / total_frames, total_frames, + 100.0 * total_frames / total_all_frames, + curr_batch_objf / (curr_batch_frames + 0.001), + curr_batch_frames, + 100.0 * curr_batch_frames / curr_batch_all_frames, + time_waiting_for_batch / max(1, batch_idx))) + + if tb_writer is not None: + tb_writer.add_scalar('train/global_average_objf', + total_objf / total_frames, global_batch_idx_train) + + tb_writer.add_scalar('train/current_batch_average_objf', + curr_batch_objf / (curr_batch_frames + 0.001), + global_batch_idx_train) + # if batch_idx >= 10: + # print("Exiting early to get profile info") + # sys.exit(0) + + if batch_idx > 0 and batch_idx % 200 == 0: + total_valid_objf, total_valid_frames, total_valid_all_frames = get_validation_objf( + dataloader=valid_dataloader, + model=model, + second_pass=second_pass, + ali_model=ali_model, + P=P, + device=device, + graph_compiler=graph_compiler, + scaler=scaler) + if world_size > 1: + s = torch.tensor([ + total_valid_objf, total_valid_frames, + total_valid_all_frames + ]).to(device) + + dist.all_reduce(s, op=dist.ReduceOp.SUM) + total_valid_objf, total_valid_frames, total_valid_all_frames = s.cpu().tolist() + + valid_average_objf = total_valid_objf / total_valid_frames + model.train() + second_pass.train() + logging.info( + 'Validation average objf: {:.6f} over {} frames ({:.1f}% kept)' + .format(valid_average_objf, + total_valid_frames, + 100.0 * total_valid_frames / total_valid_all_frames)) + + if tb_writer is not None: + tb_writer.add_scalar('train/global_valid_average_objf', + valid_average_objf, + global_batch_idx_train) + if hasattr(model, 'module'): + model.module.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) + else: + model.write_tensorboard_diagnostics(tb_writer, global_step=global_batch_idx_train) + prev_timestamp = datetime.now() + return total_objf / total_frames, valid_average_objf, global_batch_idx_train + + +def get_parser(): + parser = 
argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + '--world-size', + type=int, + default=1, + help='Number of GPUs for DDP training.') + parser.add_argument( + '--master-port', + type=int, + default=12354, + help='Master port to use for DDP training.') + parser.add_argument( + '--model-type', + type=str, + default="conformer", + choices=["transformer", "conformer"], + help="Model type.") + parser.add_argument( + '--num-epochs', + type=int, + default=10, + help="Number of training epochs.") + parser.add_argument( + '--start-epoch', + type=int, + default=0, + help="Number of start epoch.") + parser.add_argument( + '--warm-step', + type=int, + default=5000, + help='The number of warm-up steps for Noam optimizer.' + ) + parser.add_argument( + '--accum-grad', + type=int, + default=1, + help="Number of gradient accumulation.") + parser.add_argument( + '--den-scale', + type=float, + default=1.0, + help="denominator scale in mmi loss.") + parser.add_argument( + '--att-rate', + type=float, + default=0.0, + help="Attention loss rate.") + parser.add_argument( + '--nhead', + type=int, + default=4, + help="Number of attention heads in transformer.") + parser.add_argument( + '--attention-dim', + type=int, + default=256, + help="Number of units in transformer attention layers.") + parser.add_argument( + '--tensorboard', + type=str2bool, + default=True, + help='Should various information be logged in tensorboard.' + ) + parser.add_argument( + '--amp', + type=str2bool, + default=True, + help='Should we use automatic mixed precision (AMP) training.' + ) + parser.add_argument( + '--use-ali-model', + type=str2bool, + default=True, + help='If true, we assume that you have run ./ctc_train.py ' + 'and you have some checkpoints inside the directory ' + 'exp-lstm-adam-ctc-musan/ .' + 'It will use exp-lstm-adam-ctc-musan/epoch-{ali-model-epoch}.pt ' + 'as the pre-trained alignment model' + ) + parser.add_argument( + '--ali-model-epoch', + type=int, + default=7, + help='If --use-ali-model is True, load ' + 'exp-lstm-adam-ctc-musan/epoch-{ali-model-epoch}.pt as the alignment model.' + 'Used only if --use-ali-model is True.' + ) + return parser + + +def run(rank, world_size, args): + ''' + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + ''' + model_type = args.model_type + start_epoch = args.start_epoch + num_epochs = args.num_epochs + accum_grad = args.accum_grad + den_scale = args.den_scale + att_rate = args.att_rate + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, args.master_port) + + exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa-new') + setup_logger(f'{exp_dir}/log/log-train-{rank}') + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') + else: + tb_writer = None + # tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard') if args.tensorboard and rank == 0 else None + + logging.info("Loading lexicon and symbol tables") + lang_dir = Path('data/lang_nosp') + lexicon = Lexicon(lang_dir) + + device_id = rank + device = torch.device('cuda', device_id) + + graph_compiler = MmiTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + phone_ids = lexicon.phone_symbols() + P = create_bigram_phone_lm(phone_ids) + P.scores = torch.zeros_like(P.scores) + P = P.to(device) + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + if not torch.cuda.is_available(): + logging.error('No GPU detected!') + sys.exit(-1) + + logging.info("About to create model") + + if att_rate != 0.0: + num_decoder_layers = 6 + else: + num_decoder_layers = 0 + + if model_type == "transformer": + model = Transformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers) + else: + model = Conformer( + num_features=80, + nhead=args.nhead, + d_model=args.attention_dim, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4, + num_decoder_layers=num_decoder_layers) + + model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True) + + model.to(device) + describe(model) + + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + + second_pass = SecondPassModel(max_phone_id=graph_compiler.max_phone_id).to(device) + logging.info('second pass model') + describe(second_pass) + + # Now for the alignment model, if any + if args.use_ali_model: + ali_model = TdnnLstm1b( + num_features=80, + num_classes=len(phone_ids) + 1, # +1 for the blank symbol + subsampling_factor=4) + + ali_model_fname = Path(f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt') + assert ali_model_fname.is_file(), \ + f'ali model filename {ali_model_fname} does not exist!' 
+ ali_model.load_state_dict(torch.load(ali_model_fname, map_location='cpu')['state_dict']) + ali_model.to(device) + + ali_model.eval() + ali_model.requires_grad_(False) + logging.info(f'Use ali_model: {ali_model_fname}') + else: + ali_model = None + logging.info('No ali_model') + + params = [ + { + 'params': model.parameters() + }, + { + 'params': second_pass.parameters() + }, + ] + optimizer = Noam(params, + model_size=args.attention_dim, + factor=1.0, + warm_step=args.warm_step) + + scaler = GradScaler(enabled=args.amp) + + best_objf = np.inf + best_valid_objf = np.inf + best_epoch = start_epoch + best_model_path = os.path.join(exp_dir, 'best_model.pt') + best_model_path_2nd = os.path.join(exp_dir, 'best_model_2nd.pt') + best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info') + global_batch_idx_train = 0 # for logging only + + # TODO(fangjun): support saving/loading the second pass model + if start_epoch > 0: + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(start_epoch - 1)) + ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scaler=scaler) + best_objf = ckpt['objf'] + best_valid_objf = ckpt['valid_objf'] + global_batch_idx_train = ckpt['global_batch_idx_train'] + logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}") + + second_pass_model_path = os.path.join(exp_dir, '2nd-epoch-{}.pt'.format(start_epoch - 1)) + logging.info(f'loading {second_pass_model_path}') + second_pass.load_state_dict(torch.load(second_pass_model_path, map_location='cpu')) + + for epoch in range(start_epoch, num_epochs): + train_dl.sampler.set_epoch(epoch) + curr_learning_rate = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train) + tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train) + + logging.info('epoch {}, learning rate {}'.format(epoch, curr_learning_rate)) + objf, valid_objf, global_batch_idx_train = train_one_epoch( + dataloader=train_dl, + valid_dataloader=valid_dl, + model=model, + second_pass=second_pass, + ali_model=ali_model, + P=P, + device=device, + graph_compiler=graph_compiler, + optimizer=optimizer, + accum_grad=accum_grad, + den_scale=den_scale, + att_rate=att_rate, + current_epoch=epoch, + tb_writer=tb_writer, + num_epochs=num_epochs, + global_batch_idx_train=global_batch_idx_train, + world_size=world_size, + scaler=scaler + ) + # the lower, the better + if valid_objf < best_valid_objf: + best_valid_objf = valid_objf + best_objf = objf + best_epoch = epoch + save_checkpoint(filename=best_model_path, + optimizer=None, + scheduler=None, + scaler=None, + model=model, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train, + local_rank=rank) + torch.save(second_pass.state_dict(), best_model_path_2nd) + + save_training_info(filename=best_epoch_info_filename, + model_path=best_model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch, + local_rank=rank) + + # we always save the model for every epoch + model_path = os.path.join(exp_dir, 'epoch-{}.pt'.format(epoch)) + save_checkpoint(filename=model_path, + optimizer=optimizer, + scheduler=None, + scaler=scaler, + model=model, + epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + valid_objf=valid_objf, + global_batch_idx_train=global_batch_idx_train, + 
local_rank=rank) + model_path_2nd = os.path.join(exp_dir, '2nd-epoch-{}.pt'.format(epoch)) + torch.save(second_pass.state_dict(), model_path_2nd) + + + epoch_info_filename = os.path.join(exp_dir, 'epoch-{}-info'.format(epoch)) + save_training_info(filename=epoch_info_filename, + model_path=model_path, + current_epoch=epoch, + learning_rate=curr_learning_rate, + objf=objf, + best_objf=best_objf, + valid_objf=valid_objf, + best_valid_objf=best_valid_objf, + best_epoch=best_epoch, + local_rank=rank) + + logging.warning('Done') + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=world_size, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == '__main__': + main() diff --git a/snowfall/common2.py b/snowfall/common2.py new file mode 100755 index 00000000..85c5b0e8 --- /dev/null +++ b/snowfall/common2.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python3 + +# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) +# Apache 2.0 +import argparse +import logging +import os +import re +from collections import defaultdict +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import k2 +import kaldialign +import torch +import torch.distributed as dist +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel + +from snowfall.models import AcousticModel + +Pathlike = Union[str, Path] + + +def setup_logger(log_filename: Pathlike, log_level: str = 'info', use_console: bool = True) -> None: + now = datetime.now() + date_time = now.strftime('%Y-%m-%d-%H-%M-%S') + log_filename = '{}-{}'.format(log_filename, date_time) + os.makedirs(os.path.dirname(log_filename), exist_ok=True) + + if dist.is_available() and dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + formatter = f'%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s' + else: + formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s' + + level = logging.ERROR + if log_level == 'debug': + level = logging.DEBUG + elif log_level == 'info': + level = logging.INFO + elif log_level == 'warning': + level = logging.WARNING + logging.basicConfig(filename=log_filename, + format=formatter, + level=level, + filemode='w') + if use_console: + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter(formatter)) + logging.getLogger('').addHandler(console) + + +def load_checkpoint( + filename: Pathlike, + model: AcousticModel, + optimizer: Optional[object] = None, + scheduler: Optional[object] = None, + scaler: Optional[GradScaler] = None, +) -> Dict[str, Any]: + logging.info('load checkpoint from {}'.format(filename)) + + checkpoint = torch.load(filename, map_location='cpu') + + keys = [ + 'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate', 'objf', 'valid_objf', + 'num_features', 'num_classes', 'subsampling_factor', + 'global_batch_idx_train' + ] + missing_keys = set(keys) - set(checkpoint.keys()) + if missing_keys: + raise ValueError(f"Missing keys in checkpoint: {missing_keys}") + + if isinstance(model, DistributedDataParallel): + model = 
model.module + + if not list(model.state_dict().keys())[0].startswith('module.') \ + and list(checkpoint['state_dict'])[0].startswith('module.'): + # the checkpoint was saved by DDP + logging.info('load checkpoint from DDP') + dst_state_dict = model.state_dict() + src_state_dict = checkpoint['state_dict'] + for key in dst_state_dict.keys(): + src_key = '{}.{}'.format('module', key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict) + else: + model.load_state_dict(checkpoint['state_dict']) + + model.num_features = checkpoint['num_features'] + model.num_classes = checkpoint['num_classes'] + model.subsampling_factor = checkpoint['subsampling_factor'] + + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + + if scheduler is not None: + scheduler.load_state_dict(checkpoint['scheduler']) + + if scaler is not None: + scaler.load_state_dict(checkpoint['grad_scaler']) + + return checkpoint + + +def average_checkpoint_2nd(filenames: List[Pathlike], model: AcousticModel) -> Dict[str, Any]: + '''Average checkpoints for the 2nd pass model. + ''' + logging.info('average over checkpoints {}'.format(filenames)) + + avg_model = None + # sum + for filename in filenames: + checkpoint_model = torch.load(filename, map_location='cpu') + if avg_model is None: + avg_model = checkpoint_model + else: + for k in avg_model.keys(): + avg_model[k] += checkpoint_model[k] + + # average + for k in avg_model.keys(): + if avg_model[k] is not None: + if avg_model[k].is_floating_point(): + avg_model[k] /= len(filenames) + else: + avg_model[k] //= len(filenames) + + if not next(iter(model.state_dict().keys())).startswith('module.') \ + and next(iter(avg_model.keys())).startswith('module.'): + # the checkpoint was saved by DDP + dst_state_dict = model.state_dict() + src_state_dict = avg_model + for key in dst_state_dict.keys(): + src_key = '{}.{}'.format('module', key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict) + else: + model.load_state_dict(avg_model) + + +def average_checkpoint(filenames: List[Pathlike], model: AcousticModel) -> Dict[str, Any]: + logging.info('average over checkpoints {}'.format(filenames)) + + avg_model = None + + # sum + for filename in filenames: + checkpoint = torch.load(filename, map_location='cpu') + checkpoint_model = checkpoint['state_dict'] + if avg_model is None: + avg_model = checkpoint_model + else: + for k in avg_model.keys(): + avg_model[k] += checkpoint_model[k] + # average + for k in avg_model.keys(): + if avg_model[k] is not None: + if avg_model[k].is_floating_point(): + avg_model[k] /= len(filenames) + else: + avg_model[k] //= len(filenames) + + checkpoint['state_dict'] = avg_model + + keys = [ + 'state_dict', 'optimizer', 'scheduler', 'epoch', 'learning_rate', 'objf', 'valid_objf', + 'num_features', 'num_classes', 'subsampling_factor', + 'global_batch_idx_train' + ] + missing_keys = set(keys) - set(checkpoint.keys()) + if missing_keys: + raise ValueError(f"Missing keys in checkpoint: {missing_keys}") + + if not list(model.state_dict().keys())[0].startswith('module.') \ + and list(checkpoint['state_dict'])[0].startswith('module.'): + # the checkpoint was saved by DDP + logging.info('load checkpoint from DDP') + dst_state_dict = model.state_dict() + src_state_dict = checkpoint['state_dict'] + for key in dst_state_dict.keys(): + src_key = '{}.{}'.format('module', key) + dst_state_dict[key] = 
src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict) + else: + model.load_state_dict(checkpoint['state_dict']) + + model.num_features = checkpoint['num_features'] + model.num_classes = checkpoint['num_classes'] + model.subsampling_factor = checkpoint['subsampling_factor'] + + return checkpoint + + +def save_checkpoint( + filename: Pathlike, + model: Union[AcousticModel, DistributedDataParallel], + optimizer: object, + scheduler: object, + scaler: Optional[GradScaler], + epoch: int, + learning_rate: float, + objf: float, + valid_objf: float, + global_batch_idx_train: int, + local_rank: int = 0 +) -> None: + if local_rank is not None and local_rank != 0: + return + if isinstance(model, DistributedDataParallel): + model = model.module + logging.info(f'Save checkpoint to {filename}: epoch={epoch}, ' + f'learning_rate={learning_rate}, objf={objf}, valid_objf={valid_objf}') + checkpoint = { + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict() if optimizer is not None else None, + 'scheduler': scheduler.state_dict() if scheduler is not None else None, + 'grad_scaler': scaler.state_dict() if scaler is not None else None, + 'epoch': epoch, + 'learning_rate': learning_rate, + 'objf': objf, + 'valid_objf': valid_objf, + 'global_batch_idx_train': global_batch_idx_train, + 'num_features': model.num_features, + 'num_classes': model.num_classes, + 'subsampling_factor': model.subsampling_factor, + } + torch.save(checkpoint, filename) + + +def save_training_info( + filename: Pathlike, + model_path: Pathlike, + current_epoch: int, + learning_rate: float, + objf: float, + best_objf: float, + valid_objf: float, + best_valid_objf: float, + best_epoch: int, + local_rank: int = 0 +): + if local_rank is not None and local_rank != 0: + return + + with open(filename, 'w') as f: + f.write('model_path: {}\n'.format(model_path)) + f.write('epoch: {}\n'.format(current_epoch)) + f.write('learning rate: {}\n'.format(learning_rate)) + f.write('objf: {}\n'.format(objf)) + f.write('best objf: {}\n'.format(best_objf)) + f.write('valid objf: {}\n'.format(valid_objf)) + f.write('best valid objf: {}\n'.format(best_valid_objf)) + f.write('best epoch: {}\n'.format(best_epoch)) + + logging.info('write training info to {}'.format(filename)) + + +def get_phone_symbols(symbol_table: k2.SymbolTable, + pattern: str = r'^#\d+$') -> List[int]: + '''Return a list of phone IDs containing no disambiguation symbols. + + Caution: + 0 is not a phone ID so it is excluded from the return value. + + Args: + symbol_table: + A symbol table in k2. + pattern: + Symbols containing this pattern are disambiguation symbols. + Returns: + Return a list of symbol IDs excluding those from disambiguation symbols. + ''' + regex = re.compile(pattern) + symbols = symbol_table.symbols + ans = [] + for s in symbols: + if not regex.match(s): + ans.append(symbol_table[s]) + if 0 in ans: + ans.remove(0) + ans.sort() + return ans + + +def cut_id_dumper(dataloader, path: Path): + """ + Debugging utility. Writes processed cut IDs to a file. + Expects ``return_cuts=True`` to be passed to the Dataset class. + + Example:: + + >>> for batch in cut_id_dumper(dataloader): + ... 
pass + """ + if not dataloader.dataset.return_cuts: + return dataloader # do nothing, "return_cuts=True" was not set + with path.open('w') as f: + for batch in dataloader: + for cut in batch['supervisions']['cut']: + print(cut.id, file=f) + yield batch + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def describe(model: torch.nn.Module): + logging.info('=' * 80) + logging.info('Model parameters summary:') + logging.info('=' * 80) + total = 0 + for name, param in model.named_parameters(): + num_params = param.numel() + total += num_params + logging.info(f'* {name}: {num_params:>{80 - len(name) - 4}}') + logging.info('=' * 80) + logging.info(f'Total: {total}') + logging.info('=' * 80) + + +def get_texts(best_paths: k2.Fsa, indices: Optional[torch.Tensor] = None) -> List[List[int]]: + '''Extract the texts from the best-path FSAs, in the original order (before + the permutation given by `indices`). + Args: + best_paths: a k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Must have the 'aux_labels' attribute, as + a ragged tensor. + indices: possibly a torch.Tensor giving the permutation that we used + on the supervisions of this minibatch to put them in decreasing + order of num-frames. We'll apply the inverse permutation. + Doesn't have to be on the same device as `best_paths` + Return: + Returns a list of lists of int, containing the label sequences we + decoded. + ''' + # remove any 0's or -1's (there should be no 0's left but may be -1's.) + if isinstance(best_paths.aux_labels, k2.RaggedInt): + aux_labels = k2.ragged.remove_values_leq(best_paths.aux_labels, 0) + aux_shape = k2.ragged.compose_ragged_shapes(best_paths.arcs.shape(), + aux_labels.shape()) + # remove the states and arcs axes. + aux_shape = k2.ragged.remove_axis(aux_shape, 1) + aux_shape = k2.ragged.remove_axis(aux_shape, 1) + aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) + else: + # remove axis corresponding to states. + aux_shape = k2.ragged.remove_axis(best_paths.arcs.shape(), 1) + aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) + # remove 0's and -1's. 
+ aux_labels = k2.ragged.remove_values_leq(aux_labels, 0) + + assert (aux_labels.num_axes() == 2) + aux_labels, _ = k2.ragged.index(aux_labels, + invert_permutation(indices).to(dtype=torch.int32, + device=best_paths.device)) + return k2.ragged.to_list(aux_labels) + + +def invert_permutation(indices: torch.Tensor) -> torch.Tensor: + ans = torch.zeros(indices.shape, device=indices.device, dtype=torch.long) + ans[indices] = torch.arange(0, indices.shape[0], device=indices.device) + return ans + + +def find_first_disambig_symbol(symbols: k2.SymbolTable) -> int: + return min(v for k, v in symbols._sym2id.items() if k.startswith('#')) + + +def store_transcripts(path: Pathlike, texts: Iterable[Tuple[str, str]]): + with open(path, 'w') as f: + for ref, hyp in texts: + print(f'ref={ref}', file=f) + print(f'hyp={hyp}', file=f) + +def write_error_stats(f: TextIO, test_set_name: str, results: List[Tuple[str,str]]) -> None: + subs: Dict[Tuple[str,str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0,0,0,0,0]) + num_corr = 0 + ERR = '*' + for ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word,hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word,hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for r,_ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = '%.2f' % (100.0 * tot_errs / ref_len) + + logging.info( + f'[{test_set_name}] %WER {tot_errs / ref_len:.2%} ' + f'[{tot_errs} / {ref_len}, {ins_errs} ins, {del_errs} del, {sub_errs} sub ]' + ) + + print(f"%WER = {tot_err_rate}", file=f) + print(f"Errors: {ins_errs} insertions, {del_errs} deletions, {sub_errs} substitutions, over {ref_len} reference words ({num_corr} correct)", + file=f) + print("Search below for sections starting with PER-UTT DETAILS:, SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [ [[x],[y]] for x,y in ali ] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i+1][0] != ali[i+1][1]: + ali[i+1][0] = ali[i][0] + ali[i+1][0] + ali[i+1][1] = ali[i][1] + ali[i+1][1] + ali[i] = [[],[]] + ali = [ [list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y))] + for x,y in ali ] + ali = list(filter(lambda x: x != [[],[]], ali)) + ali = [ [ERR if x == [] else ' '.join(x), + ERR if y == [] else ' '.join(y)] + for x,y in ali ] + + print(' '.join((ref_word if ref_word == hyp_word else f'({ref_word}->{hyp_word})' + for ref_word,hyp_word in ali)), file=f) + + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count,(ref,hyp) in sorted([(v,k) for k,v in subs.items()], reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count,ref in sorted([(v,k) for k,v in dels.items()], reverse=True): + print(f"{count} 
{ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count,hyp in sorted([(v,k) for k,v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + for _,word,counts in sorted([(sum(v[1:]),k,v) for k,v in words.items()], reverse=True): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) diff --git a/snowfall/decoding/lm_rescore2.py b/snowfall/decoding/lm_rescore2.py new file mode 100644 index 00000000..3d0c1b51 --- /dev/null +++ b/snowfall/decoding/lm_rescore2.py @@ -0,0 +1,328 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from typing import Optional + +import logging +import math + +import k2 +import torch + + +def _intersect_device(a_fsas: k2.Fsa, b_fsas: k2.Fsa, b_to_a_map: torch.Tensor, + sorted_match_a: bool): + '''This is a wrapper of k2.intersect_device and its purpose is to split + b_fsas into several batches and process each batch separately to avoid + CUDA OOM error. + + The arguments and return value of this function are the same as + k2.intersect_device. + ''' + # NOTE: You can decrease batch_size in case of CUDA out of memory error. + batch_size = 500 + num_fsas = b_fsas.shape[0] + if num_fsas <= batch_size: + return k2.intersect_device(a_fsas, + b_fsas, + b_to_a_map=b_to_a_map, + sorted_match_a=sorted_match_a) + + num_batches = int(math.ceil(float(num_fsas) / batch_size)) + splits = [] + for i in range(num_batches): + start = i * batch_size + end = min(start + batch_size, num_fsas) + splits.append((start, end)) + + ans = [] + for start, end in splits: + indexes = torch.arange(start, end).to(b_to_a_map) + + fsas = k2.index(b_fsas, indexes) + b_to_a = k2.index(b_to_a_map, indexes) + path_lats = k2.intersect_device(a_fsas, + fsas, + b_to_a_map=b_to_a, + sorted_match_a=sorted_match_a) + ans.append(path_lats) + + return k2.cat(ans) + + +def compute_am_scores(lats: k2.Fsa, word_fsas_with_epsilon_loops: k2.Fsa, + path_to_seq_map: torch.Tensor) -> torch.Tensor: + '''Compute AM scores of n-best lists (represented as word_fsas). + + Args: + lats: + An FsaVec, which is the output of `k2.intersect_dense_pruned`. + It must have the attribute `lm_scores`. + word_fsas_with_epsilon_loops: + An FsaVec representing a n-best list. Note that it has been processed + by `k2.add_epsilon_self_loops`. + path_to_seq_map: + A 1-D torch.Tensor with dtype torch.int32. path_to_seq_map[i] indicates + which sequence the i-th Fsa in word_fsas_with_epsilon_loops belongs to. + path_to_seq_map.numel() == word_fsas_with_epsilon_loops.arcs.dim0(). + Returns: + Return a 1-D torch.Tensor containing the AM scores of each path. + `ans.numel() == word_fsas_with_epsilon_loops.shape[0]` + ''' + device = lats.device + assert len(lats.shape) == 3 + assert hasattr(lats, 'lm_scores') + + # k2.compose() currently does not support b_to_a_map. To void + # replicating `lats`, we use k2.intersect_device here. + # + # lats has phone IDs as `labels` and word IDs as aux_labels, so we + # need to invert it here. 
+ inverted_lats = k2.invert(lats) + + # Now the `labels` of inverted_lats are word IDs (a 1-D torch.Tensor) + # and its `aux_labels` are phone IDs ( a k2.RaggedInt with 2 axes) + + # Remove its `aux_labels` since it is not needed in the + # following computation + del inverted_lats.aux_labels + inverted_lats = k2.arc_sort(inverted_lats) + + am_path_lats = _intersect_device(inverted_lats, + word_fsas_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True) + + am_path_lats = k2.top_sort(k2.connect(am_path_lats)) + + # The `scores` of every arc consists of `am_scores` and `lm_scores` + am_path_lats.scores = am_path_lats.scores - am_path_lats.lm_scores + + am_scores = am_path_lats.get_tot_scores(True, True) + + return am_scores + + +@torch.no_grad() +def rescore_with_n_best_list(lats: k2.Fsa, G: k2.Fsa, + num_paths: int) -> k2.Fsa: + '''Decode using n-best list with LM rescoring. + + `lats` is a decoding lattice, which has 3 axes. This function first + extracts `num_paths` paths from `lats` for each sequence using + `k2.random_paths`. The `am_scores` of these paths are computed. + For each path, its `lm_scores` is computed using `G` (which is an LM). + The final `tot_scores` is the sum of `am_scores` and `lm_scores`. + The path with the greatest `tot_scores` within a sequence is used + as the decoding output. + + Args: + lats: + An FsaVec. It can be the output of `k2.intersect_dense_pruned`. + G: + An FsaVec representing the language model (LM). Note that it + is an FsaVec, but it contains only one Fsa. + num_paths: + It is the size `n` in `n-best` list. + Returns: + An FsaVec representing the best decoding path for each sequence + in the lattice. + ''' + device = lats.device + + assert len(lats.shape) == 3 + assert hasattr(lats, 'aux_labels') + assert hasattr(lats, 'lm_scores') + + assert G.shape == (1, None, None) + assert G.device == device + assert hasattr(G, 'aux_labels') is False + + # First, extract `num_paths` paths for each sequence. + # paths is a k2.RaggedInt with axes [seq][path][arc_pos] + paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + + # word_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains word IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + word_seqs = k2.index(lats.aux_labels, paths) + + # Remove epsilons and -1 from word_seqs + word_seqs = k2.ragged.remove_values_leq(word_seqs, 0) + + # Remove repeated sequences to avoid redundant computation later. + # + # unique_word_seqs is still a k2.RaggedInt with 3 axes [seq][path][word] + # except that there are no repeated paths with the same word_seq + # within a seq. + # + # num_repeats is also a k2.RaggedInt with 2 axes containing the + # multiplicities of each path. + # num_repeats.num_elements() == unique_word_seqs.num_elements() + # + # Since k2.ragged.unique_sequences will reorder paths within a seq, + # `new2old` is a 1-D torch.Tensor mapping from the output path index + # to the input path index. + # new2old.numel() == unique_word_seqs.num_elements() + unique_word_seqs, num_repeats, new2old = k2.ragged.unique_sequences( + word_seqs, need_num_repeats=True, need_new2old_indexes=True) + + seq_to_path_shape = k2.ragged.get_layer(unique_word_seqs.shape(), 0) + + # path_to_seq_map is a 1-D torch.Tensor. + # path_to_seq_map[i] is the seq to which the i-th path + # belongs. + path_to_seq_map = seq_to_path_shape.row_ids(1) + + # Remove the seq axis. 
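+    # The seq axis is no longer needed: k2.linear_fsa() below expects a
+    # two-axis ragged tensor (one sublist per path), and path_to_seq_map
+    # already records which seq every path belongs to.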
+ # Now unique_word_seqs has only two axes [path][word] + unique_word_seqs = k2.ragged.remove_axis(unique_word_seqs, 0) + + # word_fsas is an FsaVec with axes [path][state][arc] + word_fsas = k2.linear_fsa(unique_word_seqs) + + word_fsas_with_epsilon_loops = k2.add_epsilon_self_loops(word_fsas) + + am_scores = compute_am_scores(lats, word_fsas_with_epsilon_loops, + path_to_seq_map) + + # Now compute lm_scores + b_to_a_map = torch.zeros_like(path_to_seq_map) + lm_path_lats = _intersect_device(G, + word_fsas_with_epsilon_loops, + b_to_a_map=b_to_a_map, + sorted_match_a=True) + lm_path_lats = k2.top_sort(k2.connect(lm_path_lats)) + lm_scores = lm_path_lats.get_tot_scores(True, False) + + tot_scores = am_scores + lm_scores + + # Remember that we used `k2.ragged.unique_sequences` to remove repeated + # paths to avoid redundant computation in `k2.intersect_device`. + # Now we use `num_repeats` to correct the scores for each path. + # + # NOTE(fangjun): It is commented out as it leads to a worse WER + # tot_scores = tot_scores * num_repeats.values() + + # TODO(fangjun): We may need to add `k2.RaggedDouble` + ragged_tot_scores = k2.RaggedFloat(seq_to_path_shape, + tot_scores.to(torch.float32)) + argmax_indexes = k2.ragged.argmax_per_sublist(ragged_tot_scores) + + # Use k2.index here since argmax_indexes' dtype is torch.int32 + best_path_indexes = k2.index(new2old, argmax_indexes) + + paths = k2.ragged.remove_axis(paths, 0) + + # best_path is a k2.RaggedInt with 2 axes [path][arc_pos] + best_paths = k2.index(paths, best_path_indexes) + + # labels is a k2.RaggedInt with 2 axes [path][phone_id] + # Note that it contains -1s. + labels = k2.index(lats.labels.contiguous(), best_paths) + + labels = k2.ragged.remove_values_eq(labels, -1) + + # lats.aux_labels is a k2.RaggedInt tensor with 2 axes, so + # aux_labels is also a k2.RaggedInt with 2 axes + aux_labels = k2.index(lats.aux_labels, best_paths.values()) + + best_path_fsas = k2.linear_fsa(labels) + best_path_fsas.aux_labels = aux_labels + + return best_path_fsas + + +@torch.no_grad() +def rescore_with_whole_lattice(lats: k2.Fsa, + G_with_epsilon_loops: k2.Fsa) -> k2.Fsa: + '''Use whole lattice to rescore. + + Args: + lats: + An FsaVec It can be the output of `k2.intersect_dense_pruned`. + G_with_epsilon_loops: + An FsaVec representing the language model (LM). Note that it + is an FsaVec, but it contains only one Fsa. + ''' + assert len(lats.shape) == 3 + assert hasattr(lats, 'lm_scores') + assert G_with_epsilon_loops.shape == (1, None, None) + + device = lats.device + lats.scores = lats.scores - lats.lm_scores + del lats.lm_scores + # Now, lats.scores contains only am_scores + + # inverted_lats has word IDs as labels. + # Its aux_labels are phone IDs, which is a ragged tensor k2.RaggedInt + inverted_lats = k2.invert(lats) + num_seqs = lats.shape[0] + + b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + try: + rescoring_lats = k2.intersect_device(G_with_epsilon_loops, + inverted_lats, + b_to_a_map, + sorted_match_a=True) + except RuntimeError as e: + logging.info(f'Caught exception:\n{e}\n') + logging.info(f'Number of FSAs: {inverted_lats.shape[0]}') + logging.info(f'num_arcs before pruning: {inverted_lats.arcs.num_elements()}') + + # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here + # to avoid OOM. We may need to fine tune it. 
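+        # Arcs whose posterior falls below the threshold are removed,
+        # shrinking the lattice so that the retried intersection fits
+        # in GPU memory.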
+ inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True) + logging.info(f'num_arcs after pruning: {inverted_lats.arcs.num_elements()}') + + rescoring_lats = k2.intersect_device(G_with_epsilon_loops, + inverted_lats, + b_to_a_map, + sorted_match_a=True) + + rescoring_lats = k2.top_sort(k2.connect(rescoring_lats)) + + if rescoring_lats.num_arcs == 0: + return rescoring_lats + + inverted_rescoring_lats = k2.invert(rescoring_lats) + # inverted rescoring_lats has phone IDs as labels + # and word IDs as aux_labels. + + best_paths = k2.shortest_path(inverted_rescoring_lats, + use_double_scores=True) + return best_paths + + +@torch.no_grad() +def decode_with_lm_rescoring(lats: k2.Fsa, G: k2.Fsa, num_paths: int, + use_whole_lattice: bool) -> k2.Fsa: + '''Decode using n-best list with LM rescoring. + + `lats` is a decoding lattice, which has 3 axes. This function first + extracts `num_paths` paths from `lats` for each sequence using + `k2.random_paths`. The `am_scores` of these paths are computed. + For each path, its `lm_scores` is computed using `G` (which is an LM). + The final `tot_scores` is the sum of `am_scores` and `lm_scores`. + The path with the greatest `tot_scores` within a sequence is used + as the decoding output. + + Args: + lats: + An FsaVec It can be the output of `k2.intersect_dense_pruned`. + G: + An FsaVec representing the language model (LM). Note that it + is an FsaVec, but it contains only one Fsa. + num_paths: + It is the size `n` in `n-best` list. + Used only if use_whole_lattice is False. + use_whole_lattice: + True to use whole lattice for rescoring. False to use n-best list + for rescoring. + Returns: + An FsaVec representing the best decoding path for each sequence + in the lattice. + ''' + if use_whole_lattice: + return rescore_with_whole_lattice(lats, G) + else: + return rescore_with_n_best_list(lats, G, num_paths) diff --git a/snowfall/decoding/second_pass.py b/snowfall/decoding/second_pass.py new file mode 100644 index 00000000..81d782aa --- /dev/null +++ b/snowfall/decoding/second_pass.py @@ -0,0 +1,258 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +# This file implements the ideas proposed by Daniel Povey. +# +# See https://github.com/k2-fsa/snowfall/issues/232 for more details +# +from typing import List + +import k2 +import torch + +# Note: We use `utterance` and `sequence` interchangeably in the comment + + +class Nbest(object): + ''' + An Nbest object contains two fields: + + (1) fsa, its type is k2.Fsa + (2) shape, its type is k2.RaggedShape (alias to _k2.RaggedShape) + + The field `fsa` is an FsaVec containing a vector of **linear** FSAs. + + The field `shape` has two axes [utt][path]. `shape.dim0()` contains + the number of utterances, which is also the number of rows in the + supervision_segments. `shape.tot_size(1)` contains the number + of paths, which is also the number of FSAs in `fsa`. + ''' + + def __init__(self, fsa: k2.Fsa, shape: k2.RaggedShape) -> None: + assert len(fsa.shape) == 3, f'fsa.shape: {fsa.shape}' + assert shape.num_axes() == 2, f'num_axes: {shape.num_axes()}' + + assert fsa.shape[0] == shape.tot_size(1), \ + f'{fsa.shape[0]} vs {shape.tot_size(1)}' + + self.fsa = fsa + self.shape = shape + + def __str__(self): + s = 'Nbest(' + s += f'num_seqs:{self.shape.dim0()}, ' + s += f'num_fsas:{self.fsa.shape[0]})' + return s + + def intersect(self, lats: k2.Fsa) -> 'Nbest': + '''Intersect this Nbest object with a lattice and get 1-best + path from the resulting FsaVec. 
+ + Caution: + We assume FSAs in `self.fsa` don't have epsilon self-loops. + We also assume `self.fsa.labels` and `lats.labels` are token IDs. + + Args: + lats: + An FsaVec. It can be the return value of + :func:`whole_lattice_rescoring`. + Returns: + Return a new Nbest. This new Nbest shares the same shape with `self`, + while its `fsa` is the 1-best path from intersecting `self.fsa` and `lats. + ''' + assert self.fsa.device == lats.device, \ + f'{self.fsa.device} vs {lats.device}' + assert len(lats.shape) == 3, f'{lats.shape}' + assert lats.arcs.dim0() == self.shape.dim0(), \ + f'{lats.arcs.dim0()} vs {self.shape.dim0()}' + + lats = k2.arc_sort(lats) # no-op if lats is already arc sorted + + fsas_with_epsilon_loops = k2.add_epsilon_self_loops(self.fsa) + + path_to_seq_map = self.shape.row_ids(1) + + ans_lats = k2.intersect_device(a_fsas=lats, + b_fsas=fsas_with_epsilon_loops, + b_to_a_map=path_to_seq_map, + sorted_match_a=True) + + one_best = k2.shortest_path(ans_lats, use_double_scores=True) + + one_best = k2.remove_epsilon(one_best) + + return Nbest(fsa=one_best, shape=self.shape) + + def total_scores(self) -> k2.RaggedFloat: + '''Get total scores of the FSAs in this Nbest. + + Note: + Since FSAs in Nbest are just linear FSAs, log-semirng and tropical + semiring produce the same total scores. + + Returns: + Return a ragged tensor with two axes [utt][path_scores]. + ''' + scores = self.fsa.get_tot_scores(use_double_scores=True, + log_semiring=False) + # We use single precision here since we only wrap k2.RaggedFloat. + # If k2.RaggedDouble is wrapped, we can use double precision here. + return k2.RaggedFloat(self.shape, scores.float()) + + def top_k(self, k: int) -> 'Nbest': + '''Get a subset of paths in the Nbest. The resulting Nbest is regular + in that each sequence (i.e., utterance) has the same number of paths (k). + + We select the top-k paths according to the total_scores of each path. + If a utterance has less than k paths, then its last path, after sorting + by tot_scores in descending order, is repeated so that each utterance + has exactly k paths. + + Args: + k: + Number of paths in each utterance. + Returns: + Return a new Nbest with a regular shape. + ''' + ragged_scores = self.total_scores() + + # indexes contains idx01's for self.shape + # ragged_scores.values()[indexes] is sorted + indexes = k2.ragged.sort_sublist(ragged_scores, + descending=True, + need_new2old_indexes=True) + + ragged_indexes = k2.RaggedInt(self.shape, indexes) + + padded_indexes = k2.ragged.pad(ragged_indexes, + mode='replicate', + value=-1) + assert torch.ge(padded_indexes, 0).all(), \ + f'Some utterances contain empty n-best: {self.shape.row_splits(1)}' + + # Select the idx01's of top-k paths of each utterance + top_k_indexes = padded_indexes[:, :k].flatten().contiguous() + + top_k_fsas = k2.index_fsa(self.fsa, top_k_indexes) + + top_k_shape = k2.ragged.regular_ragged_shape(dim0=self.shape.dim0(), + dim1=k) + return Nbest(top_k_fsas, top_k_shape) + + +def whole_lattice_rescoring(lats: k2.Fsa, + G_with_epsilon_loops: k2.Fsa) -> k2.Fsa: + '''Rescore the 1st pass lattice with an LM. + + In general, the G in HLG used to obtain `lats` is a 3-gram LM. + This function replaces the 3-gram LM in `lats` with a 4-gram LM. + + Args: + lats: + The decoding lattice from the 1st pass. We assume it is the result + of intersecting HLG with the network output. + G_with_epsilon_loops: + An LM. It is usually a 4-gram LM with epsilon self-loops. + It should be arc sorted. 
+ Returns: + Return a new lattice rescored with a given G. + ''' + assert len(lats.shape) == 3, f'{lats.shape}' + assert hasattr(lats, 'lm_scores') + assert G_with_epsilon_loops.shape == (1, None, None), \ + f'{G_with_epsilon_loops.shape}' + + device = lats.device + lats.scores = lats.scores - lats.lm_scores + # Now lats contains only acoustic scores + + # We will use lm_scores from the given G, so remove lats.lm_scores here + del lats.lm_scores + assert hasattr(lats, 'lm_scores') is False + + # inverted_lats has word IDs as labels. + # Its aux_labels are token IDs, which is a ragged tensor k2.RaggedInt + # if lats.aux_labels is a ragged tensor + inverted_lats = k2.invert(lats) + num_seqs = lats.shape[0] + + b_to_a_map = torch.zeros(num_seqs, device=device, dtype=torch.int32) + + while True: + try: + rescoring_lats = k2.intersect_device(G_with_epsilon_loops, + inverted_lats, + b_to_a_map, + sorted_match_a=True) + break + except RuntimeError as e: + logging.info(f'Caught exception:\n{e}\n') + # Usually, this is an OOM exception. We reduce + # the size of the lattice and redo k2.intersect_device() + + # NOTE(fangjun): The choice of the threshold 1e-5 is arbitrary here + # to avoid OOM. We may need to fine tune it. + logging.info(f'num_arcs before: {inverted_lats.num_arcs}') + inverted_lats = k2.prune_on_arc_post(inverted_lats, 1e-5, True) + logging.info(f'num_arcs after: {inverted_lats.num_arcs}') + + rescoring_lats = k2.top_sort(k2.connect(rescoring_lats)) + + # inv_rescoring_lats has token IDs as labels + # and word IDs as aux_labels. + inv_rescoring_lats = k2.invert(rescoring_lats) + return inv_rescoring_lats + + +def generate_nbest_list(lats: k2.Fsa, num_paths: int) -> Nbest: + '''Generate an n-best list from a lattice. + + Args: + lats: + The decoding lattice from the first pass after LM rescoring. + lats is an FsaVec. It can be the return value of + :func:`whole_lattice_rescoring` + num_paths: + Size of n for n-best list. CAUTION: After removing paths + that represent the same token sequences, the number of paths + in different sequences may not be equal. + Return: + Return an Nbest object. Note the returned FSAs don't have epsilon + self-loops. + ''' + assert len(lats.shape) == 3 + + # CAUTION: We use `phones` instead of `tokens` here because + # :func:`compile_HLG` uses `phones` + assert hasattr(lats, 'phones') + + assert not hasattr(lats, 'tokens') + lats.tokens = lats.phones + # we use tokens instead of phones in the following code + + # First, extract `num_paths` paths for each sequence. + # paths is a k2.RaggedInt with axes [seq][path][arc_pos] + paths = k2.random_paths(lats, num_paths=num_paths, use_double_scores=True) + + # token_seqs is a k2.RaggedInt sharing the same shape as `paths` + # but it contains token IDs. Note that it also contains 0s and -1s. + # The last entry in each sublist is -1. + # Its axes are [seq][path][token_id] + token_seqs = k2.index(lats.tokens, paths) + + # Remove epsilons (0s) and -1 from token_seqs + token_seqs = k2.ragged.remove_values_leq(token_seqs, 0) + + # unique_token_seqs is still a k2.RaggedInt with axes [seq][path]token_id]. + # But then number of pathsin each sequence may be different. + unique_token_seqs, _, _ = k2.ragged.unique_sequences( + word_seqs, need_num_repeats=False, need_new2old_indexes=False) + + seq_to_path_shape = k2.ragged.get_layer(unique_token_seqs.shape(), 0) + + # Remove the seq axis. 
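+    # seq_to_path_shape (returned below as the Nbest shape) keeps the
+    # path-to-seq mapping, so the seq axis itself can be dropped before
+    # building the linear FSAs.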
+ # Now unique_token_seqs has only two axes [path][token_id] + unique_token_seqs = k2.ragged.remove_axis(unique_token_seqs, 0) + + token_fsas = k2.linear_fsa(unique_token_seqs) + + return Nbest(fsa=token_fsas, shape=seq_to_path_shape) diff --git a/snowfall/decoding/second_pass_test.py b/snowfall/decoding/second_pass_test.py new file mode 100755 index 00000000..b918f098 --- /dev/null +++ b/snowfall/decoding/second_pass_test.py @@ -0,0 +1,77 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +from snowfall.decoding.second_pass import Nbest + +import k2 +import torch + + +def test_nbest_constructor(): + fsa = k2.Fsa.from_str(''' + 0 1 -1 0.1 + 1 + ''') + + fsa_vec = k2.create_fsa_vec([fsa, fsa, fsa]) + shape = k2.RaggedShape('[[x x] [x]]') + print(shape.num_axes()) + + nbest = Nbest(fsa_vec, shape) + print(nbest) + + +def test_top_k(): + fsa0 = k2.Fsa.from_str(''' + 0 1 -1 0 + 1 + ''') + fsas = [fsa0.clone() for i in range(10)] + fsa_vec = k2.create_fsa_vec(fsas) + fsa_vec.scores = torch.tensor([3, 0, 1, 5, 4, 2, 8, 1, 9, 6], + dtype=torch.float) + # 0 1 2 3 4 5 6 7 8 9 + # [ [3 0] [1 5 4] [2 8 1 9 6] + shape = k2.RaggedShape('[ [x x] [x x x] [x x x x x] ]') + nbest = Nbest(fsa_vec, shape) + + # top_k: k is 1 + nbest1 = nbest.top_k(1) + expected_fsa = k2.create_fsa_vec([fsa_vec[0], fsa_vec[3], fsa_vec[8]]) + assert str(nbest1.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x] [x] [x] ]') + assert nbest1.shape == expected_shape + + # top_k: k is 2 + nbest2 = nbest.top_k(2) + expected_fsa = k2.create_fsa_vec([ + fsa_vec[0], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[8], fsa_vec[6] + ]) + assert str(nbest2.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x x] [x x] [x x] ]') + assert nbest2.shape == expected_shape + + # top_k: k is 3 + nbest3 = nbest.top_k(3) + expected_fsa = k2.create_fsa_vec([ + fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4], fsa_vec[2], + fsa_vec[8], fsa_vec[6], fsa_vec[9] + ]) + assert str(nbest3.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x x x] [x x x] [x x x] ]') + assert nbest3.shape == expected_shape + + # top_k: k is 4 + nbest4 = nbest.top_k(4) + expected_fsa = k2.create_fsa_vec([ + fsa_vec[0], fsa_vec[1], fsa_vec[1], fsa_vec[1], fsa_vec[3], fsa_vec[4], + fsa_vec[2], fsa_vec[2], fsa_vec[8], fsa_vec[6], fsa_vec[9], fsa_vec[5] + ]) + assert str(nbest4.fsa) == str(expected_fsa) + + expected_shape = k2.RaggedShape('[ [x x x x] [x x x x] [x x x x] ]') + assert nbest4.shape == expected_shape diff --git a/snowfall/models/second_pass_model.py b/snowfall/models/second_pass_model.py new file mode 100755 index 00000000..5372d971 --- /dev/null +++ b/snowfall/models/second_pass_model.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +import logging + +import torch +import torch.nn as nn +from snowfall.models.transformer import PositionalEncoding +from snowfall.models.transformer import TransformerDecoderLayer +from snowfall.models.transformer import generate_square_subsequent_mask + +import k2 + + +def _compute_padding_mask(len_per_seq: torch.Tensor): + '''Sequences are of different lengths in number of frames + and they are padded to the longest length. This function + returns a mask to exclude the padded positions for attention. + + The returned mask is called `key_padding_mask` in PyTorch's + implementation of multihead attention. 
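+    For example, len_per_seq = [2, 3] produces the mask
+    [[False, False, True], [False, False, False]].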
+ + Args: + len_per_seq: + A 1-D tensor of dtype torch.int32 containing the number + of entries per seq before padding. + Returns: + Return a bool tensor of shape (len_per_seq.shape[0], max(len_per_seq)). + The masked positions contain True, while non-masked positions contain + False. + ''' + assert len_per_seq.ndim == 1 + num_seqs = len_per_seq.shape[0] + + device = len_per_seq.device + + max_len = len_per_seq.max().item() + + # [0, 1, 2, ..., max_len - 1] + seq_range = torch.arange(0, max_len, dtype=torch.int64, device=device) + + # [ + # [0, 1, 2, ..., max_len - 1] + # [0, 1, 2, ..., max_len - 1] + # [0, 1, 2, ..., max_len - 1] + # ... + # ] + seq_range_expanded = seq_range.unsqueeze(0).expand(num_seqs, max_len) + + # [ + # [x] + # [x] + # [x] + # ... + # ] + len_per_seq = len_per_seq.unsqueeze(-1) + + # It counts from zero, so >= instead of > is used. + # + # Padding positions are set to True + mask = seq_range_expanded >= len_per_seq + return mask + + +class SecondPassModel(nn.Module): + ''' + The second pass model accepts two inputs: + + - The encoder memory output of the first pass model + - The decoding denominator lattice of the first pass model + + For each sequence in the lattice, it computes the best path of it. + Then the labels of the best path are extracted, which are phone IDs. + Therefore, for each input frame, we can get its corresponding phone + ID, i.e., its alignment. + + The phone IDs of each best path is used as a query to an decoder + model. The encoder memory output from the first pass model is used + as input memory for the decoder model. + + At the inference stage, the second pass model is used for rescoring. + ''' + + def __init__(self, + max_phone_id: int, + d_model: int = 256, + dropout: float = 0.1, + nhead: int = 4, + dim_feedforward: int = 2048, + num_decoder_layers: int = 6): + super().__init__() + normalize_before = True # True to use pre_LayerNorm + + num_classes = max_phone_id + 1 # +1 for the blank symbol + + self.decoder_embed = nn.Embedding(num_classes, d_model) + + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder(decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm) + + self.output_linear = nn.Linear(d_model, num_classes) + + def forward(self, encoder_memory: torch.Tensor, best_paths: k2.Fsa, + supervision_segments: torch.Tensor): + ''' + Args: + encoder_memory: + The output of the first network before applying log-softmax. + If you're using Transformer/Conformer, it is the encoder output. + Its shape is (T, batch_size, d_model) + best_paths: + The 1-best results from the 1st pass decoding lattice. + ''' + device = encoder_memory.device + + # offset indicates the arc start index of each seq + offset = k2.index(best_paths.arcs.row_splits(2), + best_paths.arcs.row_splits(1)) + + # Note that len_per_seq does not count -1 + # + # number of phones per seq. + # minus 1 to exclude the label -1 + len_per_seq = offset[1:] - offset[:-1] - 1 + # We use clamp(0) here since it may happen + # that the best_path is empty when pruned_intersect + # is used. This happens rarely in the decoding script. 
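+        # An empty best path contributes no arcs, so offset[i + 1] - offset[i]
+        # is 0 and the "- 1" above would make the entry negative; clamp(0)
+        # turns such paths into empty phone sequences instead.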
+ len_per_seq = len_per_seq.clamp(0) + + # Note: `phones` also contains -1, for the arcs entering the final state + phones = best_paths.labels.clone() + + # remove label -1 + phones = phones[phones != -1] + + # torch.split requires a tuple/list for sizes, so we use tolist() here. + phones_per_seq = torch.split(phones, len_per_seq.tolist()) + + # default padding value is 0 + padded_phones = nn.utils.rnn.pad_sequence(phones_per_seq, + batch_first=True) + # padded_phones is of shape (num_seqs, T) + # encoder_memory is of shape (T, num_batches, F) + + # Number of frames T should be equal + assert padded_phones.shape[1] == encoder_memory.shape[0] + # Caution: number of seqs is not necessarily equal to number of batches + # assert padded_phones.shape[0] == encoder_memory.shape[1] + + encoder_memory = encoder_memory.permute(1, 0, 2) + # Now encoder_memory is (num_batches, T, F) + + acoustic_out = [] + for segment in supervision_segments.tolist(): + batch_idx, start, duration = segment + end = start + duration + acoustic_tensor = encoder_memory[batch_idx, start:end] + acoustic_out.append(acoustic_tensor) + + padded_acoustics = nn.utils.rnn.pad_sequence(acoustic_out, + batch_first=True) + # padded_acoustics is of shape (num_seqs, T, F) + + x2 = self.decoder_embed(padded_phones.long()) + # x2 is (num_seqs, T, F) + x2 = self.decoder_pos(x2) + + assert x2.shape == padded_acoustics.shape + + # (B, T, F) -> (T, B, F) + x2 = x2.permute(1, 0, 2) + padded_acoustics = padded_acoustics.permute(1, 0, 2) + + # compute two masks + # (1) padding_mask + # (2) attn_mask for masked self-attention + + key_padding_mask = _compute_padding_mask(len_per_seq) + # key_padding_mask is of shape (B, T) + + attn_mask = generate_square_subsequent_mask(x2.shape[0]).to(device) + # attn_mask is of shape (T, T) + + x2 = self.decoder(tgt=x2, + memory=padded_acoustics, + tgt_mask=attn_mask, + tgt_key_padding_mask=key_padding_mask, + memory_key_padding_mask=key_padding_mask) + + # x2 is (T, B, F) + + x2 = x2.permute(1, 0, 2) + + # x2 is (B, T, F) + + out = self.output_linear(x2) + + out = nn.functional.log_softmax(out, dim=2) # (B, T, F) + + return out + + +def _test_compute_padding_mask(): + len_per_seq = torch.tensor([3, 5, 1, 2, 4]) + mask = _compute_padding_mask(len_per_seq) + expected_mask = torch.tensor([ + [False, False, False, True, True], + [False, False, False, False, False], + [False, True, True, True, True], + [False, False, True, True, True], + [False, False, False, False, True], + ]) + assert torch.all(torch.eq(mask, expected_mask)) + + +if __name__ == '__main__': + _test_compute_padding_mask() diff --git a/snowfall/objectives/common.py b/snowfall/objectives/common.py index e5a82676..4e018156 100644 --- a/snowfall/objectives/common.py +++ b/snowfall/objectives/common.py @@ -3,7 +3,6 @@ from torch import Tensor from typing import Dict, List, Tuple - def encode_supervisions(supervisions: Dict[str, Tensor]) -> Tuple[Tensor, List[str]]: """ Encodes Lhotse's ``batch["supervisions"]`` dict into a pair of torch Tensor, @@ -24,11 +23,13 @@ def encode_supervisions(supervisions: Dict[str, Tensor]) -> Tuple[Tensor, List[s (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1 ).to(torch.int32) + supervision_segments = torch.clamp(supervision_segments, min=0) indices = torch.argsort(supervision_segments[:, 2], descending=True) supervision_segments = supervision_segments[indices] texts = supervisions['text'] texts = [texts[idx] for idx in indices] + return supervision_segments, texts @@ -55,4 +56,4 @@ def 
get_tot_objf_and_num_frames( finite_indexes = torch.nonzero(mask).squeeze(1) ok_frames = frames_per_seq[finite_indexes].sum() all_frames = frames_per_seq.sum() - return tot_scores[finite_indexes].sum(), ok_frames, all_frames + return tot_scores[finite_indexes].sum(), ok_frames.item(), all_frames.item() diff --git a/snowfall/objectives/mmi2.py b/snowfall/objectives/mmi2.py new file mode 100644 index 00000000..486943a3 --- /dev/null +++ b/snowfall/objectives/mmi2.py @@ -0,0 +1,152 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn + +import k2 + +from snowfall.objectives.common2 import get_tot_objf_and_num_frames +from snowfall.training.mmi_graph2 import MmiTrainingGraphCompiler + + +class LFMMILoss(nn.Module): + """ + Computes Lattice-Free Maximum Mutual Information (LFMMI) loss. + + TODO: more detailed description + """ + + def __init__( + self, + graph_compiler: MmiTrainingGraphCompiler, + P: k2.Fsa, + den_scale: float = 1.0, + ): + super().__init__() + self.graph_compiler = graph_compiler + self.P = P + self.den_scale = den_scale + + def forward(self, + nnet_output: torch.Tensor, + texts: List[str], + supervision_segments: torch.Tensor, + ret_den_lats: bool = False + ) -> Tuple[torch.Tensor, int, int, Optional[k2.Fsa]]: + ''' + Args: + nnet_output: + A 3-D tensor of shape (N, T, F). It is passed to + :func:`k2.DenseFsaVec`, so it represents log_probs, + from `log_softmax()`. + texts: + A list of str. Each list item contains a transcript. + A transcript consists of space(s) separated words. + An example transcript looks like 'hello snowfall'. + An example texts is given below: + + ['hello k2', 'hello snowfall'] + + supervision_segments: + A 2-D tensor that will be passed to :func:`k2.DenseFsaVec`. + See :func:`k2.DenseFsaVec` for its format. + ret_den_lats: + True to also return the resulting denominator lattice. + Returns: + Return a tuple containing 6 entries: + + - A tensor with only one element containing the loss + + - Number of frames that contributes to the returned loss. + Note that frames of sequences that result in an infinity + loss are not counted. + + - Number of frames used in the computation. + + - The denominator lattice if ret_den_lats is True. + Otherwise, it is None. + - + ''' + num_graphs, den_graphs = self.graph_compiler.compile( + texts, self.P, replicate_den=False) + + dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments) + + device = num_graphs.device + + num_fsas = num_graphs.shape[0] + assert dense_fsa_vec.dim0() == num_fsas + + assert den_graphs.shape[0] == 1 + + # the aux_labels of num_graphs is k2.RaggedInt + # but it is torch.Tensor for den_graphs. + # + # The following converts den_graphs.aux_labels + # from torch.Tensor to k2.RaggedInt so that + # we can use k2.cat() later + den_graphs.convert_attr_to_ragged_(name='aux_labels') + + num_den_graphs = k2.cat([num_graphs, den_graphs]) + + # NOTE: The a_to_b_map in k2.intersect_dense must be sorted + # so the following reorders num_den_graphs. It also replicates + # den_graphs. + + # [0, 1, 2, ... ] + num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32) + + # [num_fsas, num_fsas, num_fsas, ... ] + den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, + dtype=torch.int32) + + # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... 
] + num_den_graphs_indexes = torch.stack( + [num_graphs_indexes, + den_graphs_indexes]).t().reshape(-1).to(device) + + num_den_reordered_graphs = k2.index(num_den_graphs, + num_den_graphs_indexes) + # Now num_den_reordered_graphs contains + # [num_graph0, den_graph0, num_graph1, den_graph1, ... ] + + # [[0, 1, 2, ...]] + a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1) + + # [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ] + a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device) + + num_den_lats = k2.intersect_dense(num_den_reordered_graphs, + dense_fsa_vec, + output_beam=10.0, + a_to_b_map=a_to_b_map) + # num_den_lats contains + # [num_lats0, den_lats0, num_lats1, den_lats1, ... ] + + num_den_tot_scores = num_den_lats.get_tot_scores( + log_semiring=True, use_double_scores=True) + + num_tot_scores = num_den_tot_scores[::2] + den_tot_scores = num_den_tot_scores[1::2] + + tot_scores = num_tot_scores - self.den_scale * den_tot_scores + tot_score, tot_frames, all_frames = get_tot_objf_and_num_frames( + tot_scores, supervision_segments[:, 2]) + if ret_den_lats: + # [1, 3, 5, ... ] + den_lats_indexes = torch.arange(start=1, + end=(2 * num_fsas), + step=2, + dtype=torch.int32, + device=device) + with torch.no_grad(): + den_lats = k2.index(num_den_lats, den_lats_indexes) + + assert den_lats.requires_grad is False + else: + den_lats = None + num_den_reordered_graphs = None + a_to_b_map = None + + # TODO(fangjun): return a dict + return tot_score, tot_frames, all_frames, den_lats, num_den_reordered_graphs, a_to_b_map diff --git a/snowfall/training/mmi_graph2.py b/snowfall/training/mmi_graph2.py new file mode 100644 index 00000000..a32be382 --- /dev/null +++ b/snowfall/training/mmi_graph2.py @@ -0,0 +1,177 @@ +# Copyright (c) 2020 Xiaomi Corp. (author: Fangjun Kuang) + +from typing import Iterable +from typing import List +from typing import Tuple + +import k2 +import torch + +from .ctc_graph import build_ctc_topo +from snowfall.common2 import get_phone_symbols +from ..lexicon import Lexicon + + +def create_bigram_phone_lm(phones: List[int]) -> k2.Fsa: + '''Create a bigram phone LM. + The resulting FSA (P) has a start-state and a state for + each phone 1, 2, ....; and each of the above-mentioned states + has a transition to the state for each phone and also to the final-state. + + Caution: + blank is not a phone. + + Args: + A list of phone IDs. + + Returns: + An FSA representing the bigram phone LM. + ''' + assert 0 not in phones + final_state = len(phones) + 1 + rules = '' + for i in range(1, final_state): + rules += f'0 {i} {phones[i-1]} 0.0\n' + + for i in range(1, final_state): + for j in range(1, final_state): + rules += f'{i} {j} {phones[j-1]} 0.0\n' + rules += f'{i} {final_state} -1 0.0\n' + rules += f'{final_state}' + return k2.Fsa.from_str(rules) + + +class MmiTrainingGraphCompiler(object): + + def __init__( + self, + lexicon: Lexicon, + device: torch.device, + oov: str = '' + ): + ''' + Args: + L_inv: + Its labels are words, while its aux_labels are phones. + phones: + The phone symbol table. + words: + The word symbol table. + oov: + Out of vocabulary word. 
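+        lexicon:
+            A Lexicon object; it supplies L_inv together with the phone
+            and word symbol tables used below.
+        device:
+            The device on which the graphs are built and stored.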
+ ''' + self.lexicon = lexicon + L_inv = self.lexicon.L_inv.to(device) + + if L_inv.properties & k2.fsa_properties.ARC_SORTED != 0: + L_inv = k2.arc_sort(L_inv) + + assert L_inv.requires_grad is False + + assert oov in self.lexicon.words + + self.L_inv = L_inv + self.oov_id = self.lexicon.words[oov] + self.oov = oov + self.device = device + + phone_symbols = get_phone_symbols(self.lexicon.phones) + phone_symbols_with_blank = [0] + phone_symbols + + ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device) + assert ctc_topo.requires_grad is False + + self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_()) + + self.max_phone_id = max(phone_symbols) + + def compile(self, + texts: Iterable[str], + P: k2.Fsa, + replicate_den: bool = True) -> Tuple[k2.Fsa, k2.Fsa]: + '''Create numerator and denominator graphs from transcripts + and the bigram phone LM. + + Args: + texts: + A list of transcripts. Within a transcript, words are + separated by spaces. + P: + The bigram phone LM created by :func:`create_bigram_phone_lm`. + replicate_den: + If True, the returned den_graph is replicated to match the number + of FSAs in the returned num_graph; if False, the returned den_graph + contains only a single FSA + Returns: + A tuple (num_graph, den_graph), where + + - `num_graph` is the numerator graph. It is an FsaVec with + shape `(len(texts), None, None)`. + + - `den_graph` is the denominator graph. It is an FsaVec with the same + shape of the `num_graph` if replicate_den is True; otherwise, it + is an FsaVec containing only a single FSA. + ''' + assert P.device == self.device + P_with_self_loops = k2.add_epsilon_self_loops(P) + + ctc_topo_P = k2.intersect(self.ctc_topo_inv, + P_with_self_loops, + treat_epsilons_specially=False).invert() + + ctc_topo_P = k2.arc_sort(ctc_topo_P) + + num_graphs = self.build_num_graphs(texts) + num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops( + num_graphs) + + num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops) + + num = k2.compose(ctc_topo_P, + num_graphs_with_self_loops, + treat_epsilons_specially=False) + num = k2.arc_sort(num) + + ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()]) + if replicate_den: + indexes = torch.zeros(len(texts), + dtype=torch.int32, + device=self.device) + den = k2.index_fsa(ctc_topo_P_vec, indexes) + else: + den = ctc_topo_P_vec + + return num, den + + def build_num_graphs(self, texts: List[str]) -> k2.Fsa: + '''Convert transcript to an Fsa with the help of lexicon + and word symbol table. + + Args: + texts: + Each element is a transcript containing words separated by spaces. + For instance, it may be 'HELLO SNOWFALL', which contains + two words. + + Returns: + Return an FST (FsaVec) corresponding to the transcript. Its `labels` are + phone IDs and `aux_labels` are word IDs. + ''' + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(' '): + if word in self.lexicon.words: + word_ids.append(self.lexicon.words[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + + fsa = k2.linear_fsa(word_ids_list, self.device) + fsa = k2.add_epsilon_self_loops(fsa) + assert fsa.device == self.device + num_graphs = k2.intersect(self.L_inv, + fsa, + treat_epsilons_specially=False).invert_() + num_graphs = k2.arc_sort(num_graphs) + return num_graphs
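A minimal sketch of how create_bigram_phone_lm, MmiTrainingGraphCompiler and LFMMILoss from this patch fit together in one LF-MMI training step. The lang directory path, the '<UNK>' OOV word and the random batch are placeholders, and in real training P's scores would be learned rather than left at 0.0:

import torch

from snowfall.lexicon import Lexicon
from snowfall.objectives.mmi2 import LFMMILoss
from snowfall.training.mmi_graph2 import (MmiTrainingGraphCompiler,
                                          create_bigram_phone_lm,
                                          get_phone_symbols)

device = torch.device('cuda', 0)

# Assumed: a standard lang directory providing L_inv and the symbol tables.
lexicon = Lexicon('data/lang_nosp')
graph_compiler = MmiTrainingGraphCompiler(lexicon,
                                          device=device,
                                          oov='<UNK>')  # assumed OOV word

phone_ids = get_phone_symbols(lexicon.phones)
P = create_bigram_phone_lm(phone_ids).to(device)  # scores stay at 0.0 here

loss_fn = LFMMILoss(graph_compiler=graph_compiler, P=P, den_scale=1.0)

# Placeholder batch: 2 utterances, 100 frames, sorted by decreasing length.
num_classes = graph_compiler.max_phone_id + 1
logits = torch.randn(2, 100, num_classes, device=device, requires_grad=True)
nnet_output = logits.log_softmax(dim=-1)
supervision_segments = torch.tensor([[0, 0, 100], [1, 0, 90]],
                                    dtype=torch.int32)
texts = ['HELLO SNOWFALL', 'HELLO K2']

tot_score, tot_frames, all_frames, den_lats, den_graphs, a_to_b_map = \
    loss_fn(nnet_output, texts, supervision_segments, ret_den_lats=True)

loss = -tot_score / tot_frames
loss.backward()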