diff --git a/README.md b/README.md index 2058bfb007..14c9e55b16 100755 --- a/README.md +++ b/README.md @@ -87,12 +87,14 @@ For tensorflow 2.4.x, run `pip3 install -U 'TensorFlowASR[tf2.4]'` or `pip3 inst For tensorflow 2.5.x, run `pip3 install -U 'TensorFlowASR[tf2.5]'` or `pip3 install -U 'TensorFlowASR[tf2.5-gpu]'` +For tensorflow 2.6.x, run `pip3 install -U 'TensorFlowASR[tf2.6]'` or `pip3 install -U 'TensorFlowASR[tf2.6-gpu]'` + ### Installing from source ```bash git clone https://github.com/TensorSpeech/TensorFlowASR.git cd TensorFlowASR -pip3 install '.[tf2.3]' # or '.[tf2.3-gpu]' or '.[tf2.4]' or '.[tf2.4-gpu]' or '.[tf2.5]' or '.[tf2.5-gpu]' +pip3 install -e '.[tf2.6]' # see other options in setup.py file ``` For anaconda3: diff --git a/examples/conformer/saved_model.py b/examples/conformer/saved_model.py new file mode 100644 index 0000000000..4a007aefe5 --- /dev/null +++ b/examples/conformer/saved_model.py @@ -0,0 +1,116 @@ +# Copyright 2020 Huy Le Nguyen (@usimarit) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from tensorflow_asr.utils import env_util + +logger = env_util.setup_environment() +import tensorflow as tf + +DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml") + +tf.keras.backend.clear_session() + +parser = argparse.ArgumentParser(prog="Conformer Testing") + +parser.add_argument( + "--config", + type=str, + default=DEFAULT_YAML, + help="The file path of model configuration file", +) + +parser.add_argument( + "--h5", + type=str, + default=None, + help="Path to saved h5 weights", +) + +parser.add_argument( + "--sentence_piece", + default=False, + action="store_true", + help="Whether to use `SentencePiece` model", +) + +parser.add_argument( + "--subwords", + default=False, + action="store_true", + help="Use subwords", +) + +parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Output directory for saved model", +) + +args = parser.parse_args() + +assert args.h5 +assert args.output_dir + +from tensorflow_asr.configs.config import Config +from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer +from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SentencePieceFeaturizer, SubwordFeaturizer +from tensorflow_asr.models.transducer.conformer import Conformer + +config = Config(args.config) +speech_featurizer = TFSpeechFeaturizer(config.speech_config) + +if args.sentence_piece: + logger.info("Use SentencePiece ...") + text_featurizer = SentencePieceFeaturizer(config.decoder_config) +elif args.subwords: + logger.info("Use subwords ...") + text_featurizer = SubwordFeaturizer(config.decoder_config) +else: + logger.info("Use characters ...") + text_featurizer = CharFeaturizer(config.decoder_config) + +tf.random.set_seed(0) + +# build model +conformer = Conformer(**config.model_config, vocabulary_size=text_featurizer.num_classes) +conformer.make(speech_featurizer.shape) +conformer.load_weights(args.h5, by_name=True) +conformer.summary(line_length=100) +conformer.add_featurizers(speech_featurizer, text_featurizer) + + +class aModule(tf.Module): + def __init__(self, model): + super().__init__() + self.model = model + + @tf.function( + input_signature=[ + { + "inputs": tf.TensorSpec(shape=[None, None, 80, 1], dtype=tf.float32, name="inputs"), + "inputs_length": tf.TensorSpec(shape=[None], dtype=tf.int32, name="inputs_length"), + } + ] + ) + def pred(self, input_batch): + result = self.model.recognize(input_batch) + return {"ASR": result} + + +module = aModule(conformer) +tf.saved_model.save(module, args.output_dir, signatures={"serving_default": module.pred}) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..dcff8a778b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 127 diff --git a/requirements.txt b/requirements.txt index afc3f53fe8..3a54699881 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,18 +1,67 @@ -cython -numpy -scipy -sklearn -pandas -tensorflow-datasets>=4.2.0 -tensorflow-addons>=0.11.1 -setuptools>=47.1.1 -librosa>=0.8.0 -soundfile>=0.10.3 -PyYAML>=5.3.1 -matplotlib>=3.2.1 -sox>=1.4.1 -tqdm>=4.54.1 -colorama>=0.4.4 -nlpaug>=1.1.1 -nltk>=3.5 -sentencepiece>=0.1.94 \ No newline at end of file +absl-py==0.12.0 +appdirs==1.4.4 +astroid==2.6.6 +attrs==21.2.0 +audioread==2.1.9 +black==21.7b0 +certifi==2021.5.30 +cffi==1.14.6 +charset-normalizer==2.0.4 +click==8.0.1 +colorama==0.4.4 +cycler==0.10.0 +Cython==0.29.24 +decorator==5.0.9 +dill==0.3.4 +flake8==3.9.2 +future==0.18.2 +googleapis-common-protos==1.53.0 +idna==3.2 +isort==5.9.3 +joblib==1.0.1 +kiwisolver==1.3.1 +lazy-object-proxy==1.6.0 +librosa==0.8.1 +llvmlite==0.36.0 +matplotlib==3.4.3 +mccabe==0.6.1 +mypy-extensions==0.4.3 +nlpaug==1.1.7 +nltk==3.6.2 +numba==0.53.1 +numpy==1.19.5 +packaging==21.0 +pandas==1.3.1 +pathspec==0.9.0 +Pillow==8.3.1 +pooch==1.4.0 +promise==2.3 +protobuf==3.17.3 +pycodestyle==2.7.0 +pycparser==2.20 +pyflakes==2.3.1 +pyparsing==2.4.7 +python-dateutil==2.8.2 +pytz==2021.1 +PyYAML==5.4.1 +regex==2021.8.3 +requests==2.26.0 +resampy==0.2.2 +scikit-learn==0.24.2 +scipy==1.7.1 +sentencepiece==0.1.96 +six==1.15.0 +sklearn==0.0 +SoundFile==0.10.3.post1 +sox==1.4.1 +tensorflow-addons==0.13.0 +tensorflow-datasets==4.4.0 +tensorflow-metadata==1.2.0 +termcolor==1.1.0 +threadpoolctl==2.2.0 +toml==0.10.2 +tomli==1.2.1 +tqdm==4.62.1 +typeguard==2.12.1 +urllib3==1.26.6 +wrapt==1.12.1 diff --git a/setup.cfg b/setup.cfg index 3db07c49e8..246351288b 100755 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,8 @@ [flake8] -ignore = E402,E701,E702,E704,E251,W503,W504,C901 +ignore = E402,E701,E702,E704,E251,E203,W503,W504,C901 max-line-length = 127 [pep8] -ignore = E402,E701,E702,E704,E251,W503,W504,C901 +ignore = E402,E701,E702,E704,E251,E203,W503,W504,C901 max-line-length = 127 indent-size = 4 diff --git a/setup.py b/setup.py index bba454bee8..cd4984e85b 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ setuptools.setup( name="TensorFlowASR", - version="1.0.2", + version="1.0.3", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", @@ -32,12 +32,14 @@ packages=setuptools.find_packages(include=["tensorflow_asr*"]), install_requires=requirements, extras_require={ - "tf2.3": ["tensorflow>=2.3.0,<2.4", "tensorflow-text>2.3.0,<2.4", "tensorflow-io>=0.16.0,<0.17"], - "tf2.3-gpu": ["tensorflow-gpu>=2.3.0,<2.4", "tensorflow-text>=2.3.0,<2.4", "tensorflow-io>=0.16.0,<0.17"], - "tf2.4": ["tensorflow>=2.4.0,<2.5", "tensorflow-text>=2.4.0,<2.5", "tensorflow-io>=0.17.0,<0.18"], - "tf2.4-gpu": ["tensorflow-gpu>=2.4.0,<2.5", "tensorflow-text>=2.4.0,<2.5", "tensorflow-io>=0.17.0,<0.18"], - "tf2.5": ["tensorflow>=2.5.0,<2.6", "tensorflow-text>=2.5.0,<2.6", "tensorflow-io>=0.18.0,<0.19"], - "tf2.5-gpu": ["tensorflow-gpu>=2.5.0,<2.6", "tensorflow-text>=2.5.0,<2.6", "tensorflow-io>=0.18.0,<0.19"] + "tf2.3": ["tensorflow~=2.3.0", "tensorflow-text~=2.3.0", "tensorflow-io~=0.16.0"], + "tf2.3-gpu": ["tensorflow-gpu~=2.3.0", "tensorflow-text~=2.3.0", "tensorflow-io~=0.16.0"], + "tf2.4": ["tensorflow~=2.4.0", "tensorflow-text~=2.4.0", "tensorflow-io~=0.17.0"], + "tf2.4-gpu": ["tensorflow-gpu~=2.4.0", "tensorflow-text~=2.4.0", "tensorflow-io~=0.17.0"], + "tf2.5": ["tensorflow~=2.5.0", "tensorflow-text~=2.5.0", "tensorflow-io~=0.18.0"], + "tf2.5-gpu": ["tensorflow-gpu~=2.5.0", "tensorflow-text~=2.5.0", "tensorflow-io~=0.18.0"], + "tf2.6": ["tensorflow~=2.6.0", "tensorflow-text~=2.6.0rc0", "tensorflow-io~=0.20.0"], + "tf2.6-gpu": ["tensorflow-gpu~=2.6.0", "tensorflow-text~=2.6.0rc0", "tensorflow-io~=0.20.0"], }, classifiers=[ "Programming Language :: Python :: 3.6", @@ -46,7 +48,7 @@ "Intended Audience :: Science/Research", "Operating System :: POSIX :: Linux", "License :: OSI Approved :: Apache Software License", - "Topic :: Software Development :: Libraries :: Python Modules" + "Topic :: Software Development :: Libraries :: Python Modules", ], - python_requires='>=3.6', + python_requires=">=3.6", ) diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 525d8bdc45..ef4e78e814 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -12,21 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import io import abc +import io import math +import os from typing import Union -import numpy as np + import librosa +import numpy as np import soundfile as sf import tensorflow as tf import tensorflow_io as tfio -from ..utils import math_util, env_util +from ..utils import env_util, math_util from .methods import gammatone - # def tf_resample(signal, rate_in, rate_out): # if rate_in == rate_out: return signal # rate_in = tf.cast(rate_in, dtype=tf.float32) @@ -35,28 +35,39 @@ # nsamples = tf.math.ceil(tf.shape(signal)[0] * ratio) -def load_and_convert_to_wav(path: str) -> tf.Tensor: +def load_and_convert_to_wav( + path: str, +) -> tf.Tensor: wave, rate = librosa.load(os.path.expanduser(path), sr=None, mono=True) return tf.audio.encode_wav(tf.expand_dims(wave, axis=-1), sample_rate=rate) -def read_raw_audio(audio: Union[str, bytes, np.ndarray], sample_rate=16000) -> np.ndarray: +def read_raw_audio( + audio: Union[str, bytes, np.ndarray], + sample_rate=16000, +) -> np.ndarray: if isinstance(audio, str): wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate, mono=True) elif isinstance(audio, bytes): wave, sr = sf.read(io.BytesIO(audio)) - if wave.ndim > 1: wave = np.mean(wave, axis=-1) + if wave.ndim > 1: + wave = np.mean(wave, axis=-1) wave = np.asfortranarray(wave) - if sr != sample_rate: wave = librosa.resample(wave, sr, sample_rate) + if sr != sample_rate: + wave = librosa.resample(wave, sr, sample_rate) elif isinstance(audio, np.ndarray): - if audio.ndim > 1: ValueError("input audio must be single channel") + if audio.ndim > 1: + ValueError("input audio must be single channel") return audio else: raise ValueError("input audio must be either a path or bytes") return wave -def tf_read_raw_audio(audio: tf.Tensor, sample_rate=16000) -> tf.Tensor: +def tf_read_raw_audio( + audio: tf.Tensor, + sample_rate=16000, +) -> tf.Tensor: wave, rate = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1) if not env_util.has_devices("TPU"): resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate) @@ -64,36 +75,44 @@ def tf_read_raw_audio(audio: tf.Tensor, sample_rate=16000) -> tf.Tensor: return tf.reshape(wave, shape=[-1]) # reshape for using tf.signal -def slice_signal(signal, window_size, stride=0.5) -> np.ndarray: - """ Return windows of the given signal by sweeping in stride fractions of window """ +def slice_signal( + signal, + window_size, + stride=0.5, +) -> np.ndarray: + """Return windows of the given signal by sweeping in stride fractions of window""" assert signal.ndim == 1, signal.ndim n_samples = signal.shape[0] offset = int(window_size * stride) slices = [] - for beg_i, end_i in zip(range(0, n_samples, offset), - range(window_size, n_samples + offset, - offset)): + for beg_i, end_i in zip(range(0, n_samples, offset), range(window_size, n_samples + offset, offset)): slice_ = signal[beg_i:end_i] if slice_.shape[0] < window_size: - slice_ = np.pad( - slice_, (0, window_size - slice_.shape[0]), 'constant', constant_values=0.0) + slice_ = np.pad(slice_, (0, window_size - slice_.shape[0]), "constant", constant_values=0.0) if slice_.shape[0] == window_size: slices.append(slice_) return np.array(slices, dtype=np.float32) -def tf_merge_slices(slices: tf.Tensor) -> tf.Tensor: +def tf_merge_slices( + slices: tf.Tensor, +) -> tf.Tensor: # slices shape = [batch, window_size] return tf.keras.backend.flatten(slices) # return shape = [-1, ] -def merge_slices(slices: np.ndarray) -> np.ndarray: +def merge_slices( + slices: np.ndarray, +) -> np.ndarray: # slices shape = [batch, window_size] return np.reshape(slices, [-1]) -def normalize_audio_feature(audio_feature: np.ndarray, per_frame=False) -> np.ndarray: - """ Mean and variance normalization """ +def normalize_audio_feature( + audio_feature: np.ndarray, + per_frame=False, +) -> np.ndarray: + """Mean and variance normalization""" axis = 1 if per_frame else None mean = np.mean(audio_feature, axis=axis) std_dev = np.sqrt(np.var(audio_feature, axis=axis) + 1e-9) @@ -101,7 +120,10 @@ def normalize_audio_feature(audio_feature: np.ndarray, per_frame=False) -> np.nd return normalized -def tf_normalize_audio_features(audio_feature: tf.Tensor, per_frame=False) -> tf.Tensor: +def tf_normalize_audio_features( + audio_feature: tf.Tensor, + per_frame=False, +) -> tf.Tensor: """ TF Mean and variance features normalization Args: @@ -116,13 +138,17 @@ def tf_normalize_audio_features(audio_feature: tf.Tensor, per_frame=False) -> tf return (audio_feature - mean) / std_dev -def normalize_signal(signal: np.ndarray) -> np.ndarray: - """ Normailize signal to [-1, 1] range """ +def normalize_signal( + signal: np.ndarray, +) -> np.ndarray: + """Normailize signal to [-1, 1] range""" gain = 1.0 / (np.max(np.abs(signal)) + 1e-9) return signal * gain -def tf_normalize_signal(signal: tf.Tensor) -> tf.Tensor: +def tf_normalize_signal( + signal: tf.Tensor, +) -> tf.Tensor: """ TF Normailize signal to [-1, 1] range Args: @@ -135,13 +161,19 @@ def tf_normalize_signal(signal: tf.Tensor) -> tf.Tensor: return signal * gain -def preemphasis(signal: np.ndarray, coeff=0.97) -> np.ndarray: +def preemphasis( + signal: np.ndarray, + coeff=0.97, +) -> np.ndarray: if not coeff or coeff <= 0.0: return signal return np.append(signal[0], signal[1:] - coeff * signal[:-1]) -def tf_preemphasis(signal: tf.Tensor, coeff=0.97): +def tf_preemphasis( + signal: tf.Tensor, + coeff=0.97, +): """ TF Pre-emphasis Args: @@ -151,14 +183,19 @@ def tf_preemphasis(signal: tf.Tensor, coeff=0.97): Returns: pre-emphasized signal with shape [None] """ - if not coeff or coeff <= 0.0: return signal + if not coeff or coeff <= 0.0: + return signal s0 = tf.expand_dims(signal[0], axis=-1) s1 = signal[1:] - coeff * signal[:-1] return tf.concat([s0, s1], axis=-1) -def depreemphasis(signal: np.ndarray, coeff=0.97) -> np.ndarray: - if not coeff or coeff <= 0.0: return signal +def depreemphasis( + signal: np.ndarray, + coeff=0.97, +) -> np.ndarray: + if not coeff or coeff <= 0.0: + return signal x = np.zeros(signal.shape[0], dtype=np.float32) x[0] = signal[0] for n in range(1, signal.shape[0], 1): @@ -166,7 +203,10 @@ def depreemphasis(signal: np.ndarray, coeff=0.97) -> np.ndarray: return x -def tf_depreemphasis(signal: tf.Tensor, coeff=0.97) -> tf.Tensor: +def tf_depreemphasis( + signal: tf.Tensor, + coeff=0.97, +) -> tf.Tensor: """ TF Depreemphasis Args: @@ -176,7 +216,8 @@ def tf_depreemphasis(signal: tf.Tensor, coeff=0.97) -> tf.Tensor: Returns: depre-emphasized signal with shape [B, None] """ - if not coeff or coeff <= 0.0: return signal + if not coeff or coeff <= 0.0: + return signal def map_fn(elem): x = tf.expand_dims(elem[0], axis=-1) @@ -189,7 +230,10 @@ def map_fn(elem): class SpeechFeaturizer(metaclass=abc.ABCMeta): - def __init__(self, speech_config: dict): + def __init__( + self, + speech_config: dict, + ): """ We should use TFSpeechFeaturizer for training to avoid differences between tf and librosa when converting to tflite in post-training stage @@ -226,20 +270,27 @@ def __init__(self, speech_config: dict): @property def nfft(self) -> int: - """ Number of FFT """ + """Number of FFT""" return 2 ** (self.frame_length - 1).bit_length() @property def shape(self) -> list: - """ The shape of extracted features """ + """The shape of extracted features""" raise NotImplementedError() - def get_length_from_duration(self, duration): + def get_length_from_duration( + self, + duration, + ): nsamples = math.ceil(float(duration) * self.sample_rate) - if self.center: nsamples += self.nfft + if self.center: + nsamples += self.nfft return 1 + (nsamples - self.nfft) // self.frame_step # https://www.tensorflow.org/api_docs/python/tf/signal/frame - def update_length(self, length: int): + def update_length( + self, + length: int, + ): self.max_length = max(self.max_length, length) def reset_length(self): @@ -255,7 +306,7 @@ def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0): @abc.abstractmethod def extract(self, signal): - """ Function to perform feature extraction """ + """Function to perform feature extraction""" raise NotImplementedError() @@ -284,15 +335,36 @@ def shape(self) -> list: return [length, self.num_feature_bins, channel_dim] - def stft(self, signal): + def stft( + self, + signal, + ): return np.square( - np.abs(librosa.core.stft(signal, n_fft=self.nfft, hop_length=self.frame_step, - win_length=self.frame_length, center=self.center, window="hann"))) + np.abs( + librosa.core.stft( + signal, + n_fft=self.nfft, + hop_length=self.frame_step, + win_length=self.frame_length, + center=self.center, + window="hann", + ) + ) + ) - def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0): + def power_to_db( + self, + S, + ref=1.0, + amin=1e-10, + top_db=80.0, + ): return librosa.power_to_db(S, ref=ref, amin=amin, top_db=top_db) - def extract(self, signal: np.ndarray) -> np.ndarray: + def extract( + self, + signal: np.ndarray, + ) -> np.ndarray: signal = np.asfortranarray(signal) if self.normalize_signal: signal = normalize_signal(signal) @@ -307,9 +379,9 @@ def extract(self, signal: np.ndarray) -> np.ndarray: elif self.feature_type == "log_gammatone_spectrogram": features = self.compute_log_gammatone_spectrogram(signal) else: - raise ValueError("feature_type must be either 'mfcc', " - "'log_mel_spectrogram', 'log_gammatone_spectrogram' " - "or 'spectrogram'") + raise ValueError( + "feature_type must be either 'mfcc', " "'log_mel_spectrogram', 'log_gammatone_spectrogram' " "or 'spectrogram'" + ) original_features = features.copy() @@ -327,80 +399,103 @@ def extract(self, signal: np.ndarray) -> np.ndarray: if self.delta_delta: delta_delta = librosa.feature.delta(original_features.T, order=2).T if self.normalize_feature: - delta_delta = normalize_audio_feature( - delta_delta, per_frame=self.normalize_per_frame) + delta_delta = normalize_audio_feature(delta_delta, per_frame=self.normalize_per_frame) features = np.concatenate([features, np.expand_dims(delta_delta, axis=-1)], axis=-1) if self.pitch: pitches = self.compute_pitch(signal) if self.normalize_feature: - pitches = normalize_audio_feature( - pitches, per_frame=self.normalize_per_frame) + pitches = normalize_audio_feature(pitches, per_frame=self.normalize_per_frame) features = np.concatenate([features, np.expand_dims(pitches, axis=-1)], axis=-1) return features - def compute_pitch(self, signal: np.ndarray) -> np.ndarray: + def compute_pitch( + self, + signal: np.ndarray, + ) -> np.ndarray: pitches, _ = librosa.core.piptrack( - y=signal, sr=self.sample_rate, - n_fft=self.nfft, hop_length=self.frame_step, - fmin=0.0, fmax=int(self.sample_rate / 2), win_length=self.frame_length, center=False + y=signal, + sr=self.sample_rate, + n_fft=self.nfft, + hop_length=self.frame_step, + fmin=0.0, + fmax=int(self.sample_rate / 2), + win_length=self.frame_length, + center=False, ) pitches = pitches.T - assert self.num_feature_bins <= self.frame_length // 2 + 1, \ - "num_features for spectrogram should \ + assert ( + self.num_feature_bins <= self.frame_length // 2 + 1 + ), "num_features for spectrogram should \ be <= (sample_rate * window_size // 2 + 1)" - return pitches[:, :self.num_feature_bins] + return pitches[:, : self.num_feature_bins] - def compute_spectrogram(self, signal: np.ndarray) -> np.ndarray: + def compute_spectrogram( + self, + signal: np.ndarray, + ) -> np.ndarray: powspec = self.stft(signal) features = self.power_to_db(powspec.T) - assert self.num_feature_bins <= self.frame_length // 2 + 1, \ - "num_features for spectrogram should \ + assert ( + self.num_feature_bins <= self.frame_length // 2 + 1 + ), "num_features for spectrogram should \ be <= (sample_rate * window_size // 2 + 1)" # cut high frequency part, keep num_feature_bins features - features = features[:, :self.num_feature_bins] + features = features[:, : self.num_feature_bins] return features - def compute_mfcc(self, signal: np.ndarray) -> np.ndarray: + def compute_mfcc( + self, + signal: np.ndarray, + ) -> np.ndarray: S = self.stft(signal) - mel = librosa.filters.mel(self.sample_rate, self.nfft, - n_mels=self.num_feature_bins, - fmin=0.0, fmax=int(self.sample_rate / 2)) + mel = librosa.filters.mel( + self.sample_rate, self.nfft, n_mels=self.num_feature_bins, fmin=0.0, fmax=int(self.sample_rate / 2) + ) mel_spectrogram = np.dot(S.T, mel.T) - mfcc = librosa.feature.mfcc(sr=self.sample_rate, - S=self.power_to_db(mel_spectrogram).T, - n_mfcc=self.num_feature_bins) + mfcc = librosa.feature.mfcc(sr=self.sample_rate, S=self.power_to_db(mel_spectrogram).T, n_mfcc=self.num_feature_bins) return mfcc.T - def compute_log_mel_spectrogram(self, signal: np.ndarray) -> np.ndarray: + def compute_log_mel_spectrogram( + self, + signal: np.ndarray, + ) -> np.ndarray: S = self.stft(signal) - mel = librosa.filters.mel(self.sample_rate, self.nfft, - n_mels=self.num_feature_bins, - fmin=0.0, fmax=int(self.sample_rate / 2)) + mel = librosa.filters.mel( + self.sample_rate, self.nfft, n_mels=self.num_feature_bins, fmin=0.0, fmax=int(self.sample_rate / 2) + ) mel_spectrogram = np.dot(S.T, mel.T) return self.power_to_db(mel_spectrogram) - def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray: + def compute_log_gammatone_spectrogram( + self, + signal: np.ndarray, + ) -> np.ndarray: S = self.stft(signal) - gtone = gammatone.fft_weights(self.nfft, self.sample_rate, - self.num_feature_bins, width=1.0, - fmin=0, fmax=int(self.sample_rate / 2), - maxlen=(self.nfft / 2 + 1)) + gtone = gammatone.fft_weights( + self.nfft, + self.sample_rate, + self.num_feature_bins, + width=1.0, + fmin=0, + fmax=int(self.sample_rate / 2), + maxlen=(self.nfft / 2 + 1), + ) gtone = gtone.numpy().astype(np.float32) @@ -415,8 +510,12 @@ def shape(self) -> list: length = self.max_length if self.max_length > 0 else None return [length, self.num_feature_bins, 1] - def stft(self, signal): - if self.center: signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT") + def stft( + self, + signal, + ): + if self.center: + signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT") window = tf.signal.hann_window(self.frame_length, periodic=True) left_pad = (self.nfft - self.frame_length) // 2 right_pad = self.nfft - self.frame_length - left_pad @@ -425,23 +524,33 @@ def stft(self, signal): framed_signals *= window return tf.square(tf.abs(tf.signal.rfft(framed_signals, [self.nfft]))) - def power_to_db(self, S, amin=1e-10): + def power_to_db( + self, + S, + amin=1e-10, + ): log_spec = 10.0 * math_util.log10(tf.maximum(amin, S)) log_spec -= 10.0 * math_util.log10(tf.maximum(amin, 1.0)) if self.top_db is not None: if self.top_db < 0: - raise ValueError('top_db must be non-negative') + raise ValueError("top_db must be non-negative") log_spec = tf.maximum(log_spec, tf.reduce_max(log_spec) - self.top_db) return log_spec - def extract(self, signal: np.ndarray) -> np.ndarray: + def extract( + self, + signal: np.ndarray, + ) -> np.ndarray: signal = np.asfortranarray(signal) features = self.tf_extract(tf.convert_to_tensor(signal, dtype=tf.float32)) return features.numpy() - def tf_extract(self, signal: tf.Tensor) -> tf.Tensor: + def tf_extract( + self, + signal: tf.Tensor, + ) -> tf.Tensor: """ Extract speech features from signals (for using in tflite) Args: @@ -472,33 +581,51 @@ def tf_extract(self, signal: tf.Tensor) -> tf.Tensor: return features - def compute_log_mel_spectrogram(self, signal): + def compute_log_mel_spectrogram( + self, + signal, + ): spectrogram = self.stft(signal) linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix( num_mel_bins=self.num_feature_bins, num_spectrogram_bins=spectrogram.shape[-1], sample_rate=self.sample_rate, - lower_edge_hertz=0.0, upper_edge_hertz=(self.sample_rate / 2) + lower_edge_hertz=0.0, + upper_edge_hertz=(self.sample_rate / 2), ) mel_spectrogram = tf.tensordot(spectrogram, linear_to_weight_matrix, 1) return self.power_to_db(mel_spectrogram) - def compute_spectrogram(self, signal): + def compute_spectrogram( + self, + signal, + ): S = self.stft(signal) spectrogram = self.power_to_db(S) - return spectrogram[:, :self.num_feature_bins] + return spectrogram[:, : self.num_feature_bins] - def compute_mfcc(self, signal): + def compute_mfcc( + self, + signal, + ): log_mel_spectrogram = self.compute_log_mel_spectrogram(signal) return tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrogram) - def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray: + def compute_log_gammatone_spectrogram( + self, + signal: np.ndarray, + ) -> np.ndarray: S = self.stft(signal) - gtone = gammatone.fft_weights(self.nfft, self.sample_rate, - self.num_feature_bins, width=1.0, - fmin=0, fmax=int(self.sample_rate / 2), - maxlen=(self.nfft / 2 + 1)) + gtone = gammatone.fft_weights( + self.nfft, + self.sample_rate, + self.num_feature_bins, + width=1.0, + fmin=0, + fmax=int(self.sample_rate / 2), + maxlen=(self.nfft / 2 + 1), + ) gtone_spectrogram = tf.tensordot(S, gtone, 1) diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index 9936e666b8..bd993fb791 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -12,25 +12,57 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import abc import codecs +import os import unicodedata from multiprocessing import cpu_count -import sentencepiece as sp + import numpy as np +import sentencepiece as sp import tensorflow as tf import tensorflow_datasets as tds from ..configs.config import DecoderConfig from ..utils import file_util -ENGLISH_CHARACTERS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", - "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"] +ENGLISH_CHARACTERS = [ + " ", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "'", +] class TextFeaturizer(metaclass=abc.ABCMeta): - def __init__(self, decoder_config: dict): + def __init__( + self, + decoder_config: dict, + ): self.scorer = None self.decoder_config = DecoderConfig(decoder_config) self.blank = None @@ -47,7 +79,10 @@ def shape(self) -> list: def prepand_shape(self) -> list: return [self.max_length + 1 if self.max_length > 0 else None] - def update_length(self, length: int): + def update_length( + self, + length: int, + ): self.max_length = max(self.max_length, length) def reset_length(self): @@ -57,11 +92,17 @@ def preprocess_text(self, text): text = unicodedata.normalize("NFC", text.lower()) return text.strip("\n") # remove trailing newline - def add_scorer(self, scorer: any = None): - """ Add scorer to this instance """ + def add_scorer( + self, + scorer: any = None, + ): + """Add scorer to this instance""" self.scorer = scorer - def normalize_indices(self, indices: tf.Tensor) -> tf.Tensor: + def normalize_indices( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Remove -1 in indices by replacing them with blanks Args: @@ -75,8 +116,11 @@ def normalize_indices(self, indices: tf.Tensor) -> tf.Tensor: blank_like = self.blank * tf.ones_like(indices, dtype=tf.int32) return tf.where(indices == minus_one, blank_like, indices) - def prepand_blank(self, text: tf.Tensor) -> tf.Tensor: - """ Prepand blank index for transducer models """ + def prepand_blank( + self, + text: tf.Tensor, + ) -> tf.Tensor: + """Prepand blank index for transducer models""" return tf.concat([[self.blank], text], axis=0) @abc.abstractclassmethod @@ -99,7 +143,10 @@ class CharFeaturizer(TextFeaturizer): converted to a sequence of integer indexes. """ - def __init__(self, decoder_config: dict): + def __init__( + self, + decoder_config: dict, + ): """ decoder_config = { "vocabulary": str, @@ -126,18 +173,23 @@ def __init_vocabulary(self): index = 1 if self.blank == 0 else 0 for line in lines: line = self.preprocess_text(line) - if line.startswith("#") or not line: continue + if line.startswith("#") or not line: + continue self.tokens2indices[line[0]] = index self.tokens.append(line[0]) index += 1 - if self.blank is None: self.blank = len(self.tokens) # blank not at zero + if self.blank is None: + self.blank = len(self.tokens) # blank not at zero self.non_blank_tokens = self.tokens.copy() self.tokens.insert(self.blank, "") # add blank token to tokens self.num_classes = len(self.tokens) self.tokens = tf.convert_to_tensor(self.tokens, dtype=tf.string) self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1]) - def extract(self, text: str) -> tf.Tensor: + def extract( + self, + text: str, + ) -> tf.Tensor: """ Convert string to a list of integers Args: @@ -151,7 +203,10 @@ def extract(self, text: str) -> tf.Tensor: indices = [self.tokens2indices[token] for token in text] return tf.convert_to_tensor(indices, dtype=tf.int32) - def iextract(self, indices: tf.Tensor) -> tf.Tensor: + def iextract( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Convert list of indices to string Args: @@ -166,12 +221,11 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor: tokens = tf.strings.reduce_join(tokens, axis=-1) return tokens - @tf.function( - input_signature=[ - tf.TensorSpec([None], dtype=tf.int32) - ] - ) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: + @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) + def indices2upoints( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Transform Predicted Indices to Unicode Code Points (for using tflite) Args: @@ -193,7 +247,11 @@ class SubwordFeaturizer(TextFeaturizer): converted to a sequence of integer indexes. """ - def __init__(self, decoder_config: dict, subwords=None): + def __init__( + self, + decoder_config: dict, + subwords=None, + ): """ decoder_config = { "target_vocab_size": int, @@ -227,7 +285,11 @@ def __load_subwords(self): return tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) @classmethod - def build_from_corpus(cls, decoder_config: dict, corpus_files: list = None): + def build_from_corpus( + cls, + decoder_config: dict, + corpus_files: list = None, + ): dconf = DecoderConfig(decoder_config.copy()) corpus_files = dconf.corpus_files if corpus_files is None or len(corpus_files) == 0 else corpus_files @@ -245,24 +307,34 @@ def corpus_generator(): dconf.target_vocab_size, dconf.max_subword_length, dconf.max_corpus_chars, - dconf.reserved_tokens + dconf.reserved_tokens, ) return cls(decoder_config, subwords) @classmethod - def load_from_file(cls, decoder_config: dict, filename: str = None): + def load_from_file( + cls, + decoder_config: dict, + filename: str = None, + ): dconf = DecoderConfig(decoder_config.copy()) filename = dconf.vocabulary if filename is None else file_util.preprocess_paths(filename) filename_prefix = os.path.splitext(filename)[0] subwords = tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) return cls(decoder_config, subwords) - def save_to_file(self, filename: str = None): + def save_to_file( + self, + filename: str = None, + ): filename = self.decoder_config.vocabulary if filename is None else file_util.preprocess_paths(filename) filename_prefix = os.path.splitext(filename)[0] return self.subwords.save_to_file(filename_prefix) - def extract(self, text: str) -> tf.Tensor: + def extract( + self, + text: str, + ) -> tf.Tensor: """ Convert string to a list of integers Args: @@ -276,7 +348,10 @@ def extract(self, text: str) -> tf.Tensor: indices = self.subwords.encode(text) return tf.convert_to_tensor(indices, dtype=tf.int32) - def iextract(self, indices: tf.Tensor) -> tf.Tensor: + def iextract( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Convert list of indices to string Args: @@ -289,11 +364,16 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor: total = tf.shape(indices)[0] batch = tf.constant(0, dtype=tf.int32) transcripts = tf.TensorArray( - dtype=tf.string, size=total, dynamic_size=False, infer_shape=False, - clear_after_read=False, element_shape=tf.TensorShape([]) + dtype=tf.string, + size=total, + dynamic_size=False, + infer_shape=False, + clear_after_read=False, + element_shape=tf.TensorShape([]), ) - def cond(batch, total, _): return tf.less(batch, total) + def cond(batch, total, _): + return tf.less(batch, total) def body(batch, total, transcripts): norm_indices = self.normalize_indices(indices[batch]) @@ -306,12 +386,11 @@ def body(batch, total, transcripts): return transcripts.stack() - @tf.function( - input_signature=[ - tf.TensorSpec([None], dtype=tf.int32) - ] - ) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: + @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) + def indices2upoints( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Transform Predicted Indices to Unicode Code Points (for using tflite) Args: @@ -330,12 +409,17 @@ class SentencePieceFeaturizer(TextFeaturizer): """ Extract text feature based on sentence piece package. """ + UNK_TOKEN, UNK_TOKEN_ID = "", 1 BOS_TOKEN, BOS_TOKEN_ID = "", 2 EOS_TOKEN, EOS_TOKEN_ID = "", 3 PAD_TOKEN, PAD_TOKEN_ID = "", 0 # unused, by default - def __init__(self, decoder_config: dict, model=None): + def __init__( + self, + decoder_config: dict, + model=None, + ): super(SentencePieceFeaturizer, self).__init__(decoder_config) self.model = self.__load_model() if model is None else model self.blank = 0 # treats blank as 0 (pad) @@ -359,7 +443,10 @@ def __init_vocabulary(self): self.upoints = self.upoints.to_tensor() # [num_classes, max_subword_length] @classmethod - def build_from_corpus(cls, decoder_config: dict): + def build_from_corpus( + cls, + decoder_config: dict, + ): """ --model_prefix: output model name prefix. .model and .vocab are generated. --vocab_size: vocabulary size, e.g., 8000, 16000, or 32000 @@ -387,7 +474,7 @@ def corpus_iterator(): bos_id=cls.BOS_TOKEN_ID, eos_id=cls.EOS_TOKEN_ID, pad_id=cls.PAD_TOKEN_ID, - unk_surface='__UNKNOWN__' # change default unk surface U+2047("⁇") by "__UNKNOWN__" + unk_surface="__UNKNOWN__", # change default unk surface U+2047("⁇") by "__UNKNOWN__" ) # Export fairseq dictionary processor = sp.SentencePieceProcessor() @@ -398,11 +485,7 @@ def corpus_iterator(): and vocab.get(cls.BOS_TOKEN_ID) == cls.BOS_TOKEN and vocab.get(cls.EOS_TOKEN_ID) == cls.EOS_TOKEN ) - vocab = { - i: s - for i, s in vocab.items() - if s not in {cls.UNK_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.PAD_TOKEN} - } + vocab = {i: s for i, s in vocab.items() if s not in {cls.UNK_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.PAD_TOKEN}} with open(decoder_cfg.output_path_prefix + ".txt", "w") as f_out: for _, s in sorted(vocab.items(), key=lambda x: x[0]): f_out.write(f"{s} 1\n") @@ -410,7 +493,11 @@ def corpus_iterator(): return cls(decoder_config, processor) @classmethod - def load_from_file(cls, decoder_config: dict, filename: str = None): + def load_from_file( + cls, + decoder_config: dict, + filename: str = None, + ): if filename is not None: filename_prefix = os.path.splitext(file_util.preprocess_paths(filename))[0] else: @@ -419,7 +506,10 @@ def load_from_file(cls, decoder_config: dict, filename: str = None): processor.load(filename_prefix + ".model") return cls(decoder_config, processor) - def extract(self, text: str) -> tf.Tensor: + def extract( + self, + text: str, + ) -> tf.Tensor: """ Convert string to a list of integers # encode: text => id @@ -436,7 +526,10 @@ def extract(self, text: str) -> tf.Tensor: indices = self.model.encode_as_ids(text) return tf.convert_to_tensor(indices, dtype=tf.int32) - def iextract(self, indices: tf.Tensor) -> tf.Tensor: + def iextract( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Convert list of indices to string # decode: id => text @@ -451,23 +544,24 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor: """ indices = self.normalize_indices(indices) with tf.device("/CPU:0"): # string data is not supported on GPU + def decode(x): - if x[0] == self.blank: x = x[1:] + if x[0] == self.blank: + x = x[1:] return self.model.decode_ids(x.tolist()) text = tf.map_fn( lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string), indices, - fn_output_signature=tf.TensorSpec([], dtype=tf.string) + fn_output_signature=tf.TensorSpec([], dtype=tf.string), ) return text - @tf.function( - input_signature=[ - tf.TensorSpec([None], dtype=tf.int32) - ] - ) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: + @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) + def indices2upoints( + self, + indices: tf.Tensor, + ) -> tf.Tensor: """ Transform Predicted Indices to Unicode Code Points (for using tflite) Args: diff --git a/tensorflow_asr/losses/ctc_loss.py b/tensorflow_asr/losses/ctc_loss.py index 89519a4e60..2774080526 100644 --- a/tensorflow_asr/losses/ctc_loss.py +++ b/tensorflow_asr/losses/ctc_loss.py @@ -15,25 +15,41 @@ class CtcLoss(tf.keras.losses.Loss): - def __init__(self, blank=0, global_batch_size=None, name=None): + def __init__( + self, + blank=0, + global_batch_size=None, + name=None, + ): super(CtcLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name) self.blank = blank self.global_batch_size = global_batch_size - def call(self, y_true, y_pred): + def call( + self, + y_true, + y_pred, + ): loss = ctc_loss( y_pred=y_pred["logits"], input_length=y_pred["logits_length"], y_true=y_true["labels"], label_length=y_true["labels_length"], blank=self.blank, - name=self.name + name=self.name, ) return tf.nn.compute_average_loss(loss, global_batch_size=self.global_batch_size) @tf.function -def ctc_loss(y_true, y_pred, input_length, label_length, blank, name=None): +def ctc_loss( + y_true, + y_pred, + input_length, + label_length, + blank, + name=None, +): return tf.nn.ctc_loss( labels=tf.cast(y_true, tf.int32), logit_length=tf.cast(input_length, tf.int32), @@ -41,5 +57,5 @@ def ctc_loss(y_true, y_pred, input_length, label_length, blank, name=None): label_length=tf.cast(label_length, tf.int32), logits_time_major=False, blank_index=blank, - name=name + name=name, ) diff --git a/tensorflow_asr/losses/rnnt_loss.py b/tensorflow_asr/losses/rnnt_loss.py index 035f7aadae..0b1305b2d9 100644 --- a/tensorflow_asr/losses/rnnt_loss.py +++ b/tensorflow_asr/losses/rnnt_loss.py @@ -15,6 +15,7 @@ import tensorflow as tf from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2 + from ..utils import env_util logger = tf.get_logger() @@ -23,6 +24,7 @@ try: from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss + use_warprnnt = True logger.info("Use RNNT loss in WarpRnnt") except ImportError: @@ -31,7 +33,12 @@ class RnntLoss(tf.keras.losses.Loss): - def __init__(self, blank=0, global_batch_size=None, name=None): + def __init__( + self, + blank=0, + global_batch_size=None, + name=None, + ): super(RnntLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name) self.blank = blank self.global_batch_size = global_batch_size @@ -43,21 +50,41 @@ def call(self, y_true, y_pred): labels=y_true["labels"], label_length=y_true["labels_length"], blank=self.blank, - name=self.name + name=self.name, ) return tf.nn.compute_average_loss(loss, global_batch_size=self.global_batch_size) @tf.function -def rnnt_loss(logits, labels, label_length, logit_length, blank=0, name=None): +def rnnt_loss( + logits, + labels, + label_length, + logit_length, + blank=0, + name=None, +): if use_warprnnt: - return rnnt_loss_warprnnt(logits=logits, labels=labels, - label_length=label_length, logit_length=logit_length, blank=blank) + return rnnt_loss_warprnnt( + logits=logits, labels=labels, label_length=label_length, logit_length=logit_length, blank=blank + ) else: - return rnnt_loss_tf(logits=logits, labels=labels, label_length=label_length, logit_length=logit_length, name=name) + return rnnt_loss_tf( + logits=logits, + labels=labels, + label_length=label_length, + logit_length=logit_length, + name=name, + ) -def rnnt_loss_warprnnt(logits, labels, label_length, logit_length, blank=0): +def rnnt_loss_warprnnt( + logits, + labels, + label_length, + logit_length, + blank=0, +): if not env_util.has_devices(["GPU", "TPU"]): logits = tf.nn.log_softmax(logits) loss = warp_rnnt_loss( @@ -65,35 +92,47 @@ def rnnt_loss_warprnnt(logits, labels, label_length, logit_length, blank=0): label_lengths=tf.cast(label_length, tf.int32), labels=tf.cast(labels, tf.int32), input_lengths=tf.cast(logit_length, tf.int32), - blank_label=blank + blank_label=blank, ) return loss -def nan_to_zero(input_tensor): +def nan_to_zero( + input_tensor, +): return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor) -def reduce_logsumexp(input_tensor, axis): +def reduce_logsumexp( + input_tensor, + axis, +): maximum = tf.reduce_max(input_tensor, axis=axis) input_tensor = nan_to_zero(input_tensor - maximum) return tf.math.log(tf.reduce_sum(tf.exp(input_tensor), axis=axis)) + maximum -def extract_diagonals(log_probs): +def extract_diagonals( + log_probs, +): time_steps = tf.shape(log_probs)[1] # T output_steps = tf.shape(log_probs)[2] # U + 1 reverse_log_probs = tf.reverse(log_probs, axis=[-1]) paddings = [[0, 0], [0, 0], [time_steps - 1, 0]] - padded_reverse_log_probs = tf.pad(reverse_log_probs, paddings, - 'CONSTANT', constant_values=LOG_0) - diagonals = matrix_diag_part_v2(padded_reverse_log_probs, k=(0, time_steps + output_steps - 2), - padding_value=LOG_0) + padded_reverse_log_probs = tf.pad(reverse_log_probs, paddings, "CONSTANT", constant_values=LOG_0) + diagonals = matrix_diag_part_v2( + padded_reverse_log_probs, + k=(0, time_steps + output_steps - 2), + padding_value=LOG_0, + ) return tf.transpose(diagonals, perm=[1, 0, 2]) -def transition_probs(one_hot_labels, log_probs): +def transition_probs( + one_hot_labels, + log_probs, +): """ :return: blank_probs with shape batch_size x input_max_len x target_max_len truth_probs with shape batch_size x input_max_len x (target_max_len-1) @@ -104,7 +143,13 @@ def transition_probs(one_hot_labels, log_probs): return blank_probs, truth_probs -def forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len): +def forward_dp( + bp_diags, + tp_diags, + batch_size, + input_max_len, + target_max_len, +): """ :return: forward variable alpha with shape batch_size x input_max_len x target_max_len """ @@ -120,7 +165,12 @@ def next_state(x, trans_probs): return x initial_alpha = tf.concat( - [tf.zeros(shape=[batch_size, 1]), tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0], axis=1) + [ + tf.zeros(shape=[batch_size, 1]), + tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0, + ], + axis=1, + ) fwd = tf.scan(next_state, (bp_diags[:-1, :, :-1], tp_diags), initializer=initial_alpha) @@ -131,9 +181,18 @@ def next_state(x, trans_probs): return alpha -def backward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len, label_length, logit_length, blank_sl): +def backward_dp( + bp_diags, + tp_diags, + batch_size, + input_max_len, + target_max_len, + label_length, + logit_length, + blank_sl, +): """ - :return: backward variable beta with shape batch_size x input_max_len x target_max_len + :return: backward variable beta with shape batch_size x input_max_len x target_max_len """ def next_state(x, mask_and_trans_probs): @@ -143,8 +202,9 @@ def next_state(x, mask_and_trans_probs): beta_t = tf.concat([x[:, :-1] + truth_probs, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1) beta_next = reduce_logsumexp(tf.stack([beta_b, beta_t], axis=0), axis=0) - masked_beta_next = \ - nan_to_zero(beta_next * tf.expand_dims(mask_s, axis=1)) + nan_to_zero(x * tf.expand_dims((1.0 - mask_s), axis=1)) + masked_beta_next = nan_to_zero(beta_next * tf.expand_dims(mask_s, axis=1)) + nan_to_zero( + x * tf.expand_dims((1.0 - mask_s), axis=1) + ) return tf.reshape(masked_beta_next, shape=tf.shape(x)) # Initial beta for batches. @@ -152,10 +212,19 @@ def next_state(x, mask_and_trans_probs): initial_beta = tf.expand_dims(blank_sl, axis=1) * initial_beta_mask + nan_to_zero(LOG_0 * (1.0 - initial_beta_mask)) # Mask for scan iterations. - mask = tf.sequence_mask(logit_length + label_length - 1, input_max_len + target_max_len - 2, dtype=tf.dtypes.float32) + mask = tf.sequence_mask( + logit_length + label_length - 1, + input_max_len + target_max_len - 2, + dtype=tf.dtypes.float32, + ) mask = tf.transpose(mask, perm=[1, 0]) - bwd = tf.scan(next_state, (mask, bp_diags[:-1, :, :], tp_diags), initializer=initial_beta, reverse=True) + bwd = tf.scan( + next_state, + (mask, bp_diags[:-1, :, :], tp_diags), + initializer=initial_beta, + reverse=True, + ) beta = tf.transpose(tf.concat([bwd, tf.expand_dims(initial_beta, axis=0)], axis=0), perm=[1, 2, 0])[:, :-1, :] beta = matrix_diag_part_v2(beta, k=(0, target_max_len - 1), padding_value=LOG_0) @@ -170,18 +239,26 @@ def compute_rnnt_loss_and_grad_helper(logits, labels, label_length, logit_length target_max_len = tf.shape(logits)[2] vocab_size = tf.shape(logits)[3] - one_hot_labels = tf.one_hot(tf.tile(tf.expand_dims(labels, axis=1), - multiples=[1, input_max_len, 1]), depth=vocab_size) + one_hot_labels = tf.one_hot( + tf.tile(tf.expand_dims(labels, axis=1), multiples=[1, input_max_len, 1]), + depth=vocab_size, + ) log_probs = tf.nn.log_softmax(logits) blank_probs, truth_probs = transition_probs(one_hot_labels, log_probs) bp_diags = extract_diagonals(blank_probs) tp_diags = extract_diagonals(truth_probs) - label_mask = tf.expand_dims(tf.sequence_mask(label_length + 1, maxlen=target_max_len, dtype=tf.float32), axis=1) + label_mask = tf.expand_dims( + tf.sequence_mask(label_length + 1, maxlen=target_max_len, dtype=tf.float32), + axis=1, + ) small_label_mask = tf.expand_dims(tf.sequence_mask(label_length, maxlen=target_max_len, dtype=tf.float32), axis=1) input_mask = tf.expand_dims(tf.sequence_mask(logit_length, maxlen=input_max_len, dtype=tf.float32), axis=2) - small_input_mask = tf.expand_dims(tf.sequence_mask(logit_length - 1, maxlen=input_max_len, dtype=tf.float32), axis=2) + small_input_mask = tf.expand_dims( + tf.sequence_mask(logit_length - 1, maxlen=input_max_len, dtype=tf.float32), + axis=2, + ) mask = label_mask * input_mask grad_blank_mask = (label_mask * small_input_mask)[:, :-1, :] grad_truth_mask = (small_label_mask * input_mask)[:, :, :-1] @@ -191,65 +268,111 @@ def compute_rnnt_loss_and_grad_helper(logits, labels, label_length, logit_length indices = tf.stack([logit_length - 1, label_length], axis=1) blank_sl = tf.gather_nd(blank_probs, indices, batch_dims=1) - beta = backward_dp(bp_diags, tp_diags, batch_size, input_max_len, - target_max_len, label_length, logit_length, blank_sl) * mask + beta = ( + backward_dp( + bp_diags, + tp_diags, + batch_size, + input_max_len, + target_max_len, + label_length, + logit_length, + blank_sl, + ) + * mask + ) beta = tf.where(tf.math.is_nan(beta), tf.zeros_like(beta), beta) final_state_probs = beta[:, 0, 0] # Compute gradients of loss w.r.t. blank log-probabilities. - grads_blank = -tf.exp( - ( - alpha[:, :-1, :] + beta[:, 1:, :] - - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) - + blank_probs[:, :-1, :] - ) * grad_blank_mask - ) * grad_blank_mask + grads_blank = ( + -tf.exp( + ( + alpha[:, :-1, :] + + beta[:, 1:, :] + - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + + blank_probs[:, :-1, :] + ) + * grad_blank_mask + ) + * grad_blank_mask + ) grads_blank = tf.concat([grads_blank, tf.zeros(shape=(batch_size, 1, target_max_len))], axis=1) last_grads_blank = -1 * tf.scatter_nd( - tf.concat([tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=[batch_size, 1]), - tf.cast(indices, dtype=tf.int64)], axis=1), + tf.concat( + [ + tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=[batch_size, 1]), + tf.cast(indices, dtype=tf.int64), + ], + axis=1, + ), tf.ones(batch_size, dtype=tf.float32), - [batch_size, input_max_len, target_max_len] + [batch_size, input_max_len, target_max_len], ) grads_blank = grads_blank + last_grads_blank # Compute gradients of loss w.r.t. truth log-probabilities. - grads_truth = -tf.exp( - ( - alpha[:, :, :-1] + beta[:, :, 1:] - - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) - + truth_probs + grads_truth = ( + -tf.exp( + (alpha[:, :, :-1] + beta[:, :, 1:] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + truth_probs) + * grad_truth_mask ) * grad_truth_mask - ) * grad_truth_mask + ) # Compute gradients of loss w.r.t. activations. - a = tf.tile(tf.reshape(tf.range(target_max_len - 1, dtype=tf.int64), shape=(1, 1, target_max_len - 1, 1)), - multiples=[batch_size, 1, 1, 1]) - b = tf.cast(tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), dtype=tf.int64) + a = tf.tile( + tf.reshape( + tf.range(target_max_len - 1, dtype=tf.int64), + shape=(1, 1, target_max_len - 1, 1), + ), + multiples=[batch_size, 1, 1, 1], + ) + b = tf.cast( + tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), + dtype=tf.int64, + ) if not env_util.has_devices(["GPU", "TPU"]): b = tf.where(tf.equal(b, -1), tf.zeros_like(b), b) # for cpu testing (index -1 on cpu will raise errors) c = tf.concat([a, b], axis=3) d = tf.tile(c, multiples=(1, input_max_len, 1, 1)) - e = tf.tile(tf.reshape(tf.range(input_max_len, dtype=tf.int64), shape=(1, input_max_len, 1, 1)), - multiples=(batch_size, 1, target_max_len - 1, 1)) + e = tf.tile( + tf.reshape(tf.range(input_max_len, dtype=tf.int64), shape=(1, input_max_len, 1, 1)), + multiples=(batch_size, 1, target_max_len - 1, 1), + ) f = tf.concat([e, d], axis=3) - g = tf.tile(tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=(batch_size, 1, 1, 1)), - multiples=[1, input_max_len, target_max_len - 1, 1]) + g = tf.tile( + tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=(batch_size, 1, 1, 1)), + multiples=[1, input_max_len, target_max_len - 1, 1], + ) scatter_idx = tf.concat([g, f], axis=3) # TODO - improve the part of code for scatter_idx computation. probs = tf.exp(log_probs) - grads_truth_scatter = tf.scatter_nd(scatter_idx, grads_truth, - [batch_size, input_max_len, target_max_len, vocab_size - 1]) - grads = tf.concat([tf.reshape(grads_blank, shape=(batch_size, input_max_len, target_max_len, -1)), - grads_truth_scatter], axis=3) + grads_truth_scatter = tf.scatter_nd( + scatter_idx, + grads_truth, + [batch_size, input_max_len, target_max_len, vocab_size - 1], + ) + grads = tf.concat( + [ + tf.reshape(grads_blank, shape=(batch_size, input_max_len, target_max_len, -1)), + grads_truth_scatter, + ], + axis=3, + ) grads_logits = grads - probs * (tf.reduce_sum(grads, axis=3, keepdims=True)) loss = -final_state_probs return loss, grads_logits -def rnnt_loss_tf(logits, labels, label_length, logit_length, name=None): +def rnnt_loss_tf( + logits, + labels, + label_length, + logit_length, + name=None, +): name = "rnnt_loss" if name is None else name with tf.name_scope(name): logits = tf.convert_to_tensor(logits, name="logits") @@ -266,7 +389,12 @@ def compute_rnnt_loss_and_grad(logits_t, labels_t, label_length_t, logit_length_ labels_t.set_shape(labels.shape) label_length_t.set_shape(label_length.shape) logit_length_t.set_shape(logit_length.shape) - kwargs = dict(logits=logits_t, labels=labels_t, label_length=label_length_t, logit_length=logit_length_t) + kwargs = dict( + logits=logits_t, + labels=labels_t, + label_length=label_length_t, + logit_length=logit_length_t, + ) result = compute_rnnt_loss_and_grad_helper(**kwargs) def grad(grad_loss): diff --git a/tensorflow_asr/metrics/error_rates.py b/tensorflow_asr/metrics/error_rates.py index 2d6880e35e..fd06385e18 100644 --- a/tensorflow_asr/metrics/error_rates.py +++ b/tensorflow_asr/metrics/error_rates.py @@ -16,15 +16,24 @@ class ErrorRate(tf.keras.metrics.Metric): - """ Metric for WER or CER """ + """Metric for WER or CER""" - def __init__(self, func, name="error_rate", **kwargs): + def __init__( + self, + func, + name="error_rate", + **kwargs, + ): super(ErrorRate, self).__init__(name=name, **kwargs) self.numerator = self.add_weight(name=f"{name}_numerator", initializer="zeros") self.denominator = self.add_weight(name=f"{name}_denominator", initializer="zeros") self.func = func - def update_state(self, decode: tf.Tensor, target: tf.Tensor): + def update_state( + self, + decode: tf.Tensor, + target: tf.Tensor, + ): n, d = self.func(decode, target) self.numerator.assign_add(n) self.denominator.assign_add(d) diff --git a/tensorflow_asr/models/activations/glu.py b/tensorflow_asr/models/activations/glu.py index 579aecee06..bc2fec47c0 100644 --- a/tensorflow_asr/models/activations/glu.py +++ b/tensorflow_asr/models/activations/glu.py @@ -16,14 +16,20 @@ class GLU(tf.keras.layers.Layer): - def __init__(self, - axis=-1, - name="glu_activation", - **kwargs): + def __init__( + self, + axis=-1, + name="glu_activation", + **kwargs, + ): super(GLU, self).__init__(name=name, **kwargs) self.axis = axis - def call(self, inputs, **kwargs): + def call( + self, + inputs, + **kwargs, + ): a, b = tf.split(inputs, 2, axis=self.axis) b = tf.nn.sigmoid(b) return tf.multiply(a, b) diff --git a/tensorflow_asr/models/base_model.py b/tensorflow_asr/models/base_model.py index ccc6da39bb..c9a0c07f54 100644 --- a/tensorflow_asr/models/base_model.py +++ b/tensorflow_asr/models/base_model.py @@ -15,23 +15,20 @@ import tensorflow as tf from tensorflow.keras import mixed_precision as mxp -from ..utils import file_util, env_util +from ..utils import env_util, file_util class BaseModel(tf.keras.Model): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._metrics = {} - self.use_loss_scale = False - - def save(self, - filepath, - overwrite=True, - include_optimizer=True, - save_format=None, - signatures=None, - options=None, - save_traces=True): + def save( + self, + filepath, + overwrite=True, + include_optimizer=True, + save_format=None, + signatures=None, + options=None, + save_traces=True, + ): with file_util.save_file(filepath) as path: super().save( filepath=path, @@ -40,52 +37,54 @@ def save(self, save_format=save_format, signatures=signatures, options=options, - save_traces=save_traces + save_traces=save_traces, ) - def save_weights(self, - filepath, - overwrite=True, - save_format=None, - options=None): + def save_weights( + self, + filepath, + overwrite=True, + save_format=None, + options=None, + ): with file_util.save_file(filepath) as path: - super().save_weights( - filepath=path, - overwrite=overwrite, - save_format=save_format, - options=options - ) - - def load_weights(self, - filepath, - by_name=False, - skip_mismatch=False, - options=None): + super().save_weights(filepath=path, overwrite=overwrite, save_format=save_format, options=options) + + def load_weights( + self, + filepath, + by_name=False, + skip_mismatch=False, + options=None, + ): with file_util.read_file(filepath) as path: - super().load_weights( - filepath=path, - by_name=by_name, - skip_mismatch=skip_mismatch, - options=options - ) + super().load_weights(filepath=path, by_name=by_name, skip_mismatch=skip_mismatch, options=options) - @property - def metrics(self): - return self._metrics.values() - - def add_metric(self, metric: tf.keras.metrics.Metric): - self._metrics.append({metric.name: metric}) + def add_metric( + self, + metric: tf.keras.metrics.Metric, + ): + if not hasattr(self, "_metrics"): + self._metrics = {} + self._metrics[metric.name] = metric def make(self, *args, **kwargs): - """ Custom function for building model (uses self.build so cannot overwrite that function) """ + """Custom function for building model (uses self.build so cannot overwrite that function)""" raise NotImplementedError() - def compile(self, loss, optimizer, run_eagerly=None, **kwargs): + def compile( + self, + loss, + optimizer, + run_eagerly=None, + **kwargs, + ): + self.use_loss_scale = False if not env_util.has_devices("TPU"): optimizer = mxp.experimental.LossScaleOptimizer(tf.keras.optimizers.get(optimizer), "dynamic") self.use_loss_scale = True loss_metric = tf.keras.metrics.Mean(name="loss", dtype=tf.float32) - self._metrics = {loss_metric.name: loss_metric} + self.add_metric(loss_metric) super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs) # -------------------------------- STEP FUNCTIONS ------------------------------------- @@ -112,7 +111,7 @@ def train_step(self, batch): gradients = tape.gradient(loss, self.trainable_weights) self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) self._metrics["loss"].update_state(loss) - return {m.name: m.result() for m in self.metrics} + return {m.name: m.result() for m in self._metrics.values()} def test_step(self, batch): """ @@ -127,7 +126,7 @@ def test_step(self, batch): y_pred = self(inputs, training=False) loss = self.loss(y_true, y_pred) self._metrics["loss"].update_state(loss) - return {m.name: m.result() for m in self.metrics} + return {m.name: m.result() for m in self._metrics.values()} def predict_step(self, batch): """ @@ -149,9 +148,9 @@ def predict_step(self, batch): # -------------------------------- INFERENCE FUNCTIONS ------------------------------------- def recognize(self, *args, **kwargs): - """ Greedy decoding function that used in self.predict_step """ + """Greedy decoding function that used in self.predict_step""" raise NotImplementedError() def recognize_beam(self, *args, **kwargs): - """ Beam search decoding function that used in self.predict_step """ + """Beam search decoding function that used in self.predict_step""" raise NotImplementedError() diff --git a/tensorflow_asr/models/ctc/ctc.py b/tensorflow_asr/models/ctc/ctc.py index d7dcf5dd5e..f1c6cdacaf 100644 --- a/tensorflow_asr/models/ctc/ctc.py +++ b/tensorflow_asr/models/ctc/ctc.py @@ -13,82 +13,112 @@ # limitations under the License. from typing import Dict, Union + import numpy as np import tensorflow as tf -from ..base_model import BaseModel from ...featurizers.speech_featurizers import TFSpeechFeaturizer from ...featurizers.text_featurizers import TextFeaturizer -from ...utils import math_util, shape_util, data_util from ...losses.ctc_loss import CtcLoss +from ...utils import data_util, math_util, shape_util +from ..base_model import BaseModel class CtcModel(BaseModel): - def __init__(self, - encoder: tf.keras.Model, - decoder: Union[tf.keras.Model, tf.keras.layers.Layer] = None, - vocabulary_size: int = None, - **kwargs): + def __init__( + self, + encoder: tf.keras.Model, + decoder: Union[tf.keras.Model, tf.keras.layers.Layer] = None, + vocabulary_size: int = None, + **kwargs, + ): super().__init__(**kwargs) self.encoder = encoder if decoder is None: assert vocabulary_size is not None, "vocabulary_size must be set" - self.decoder = tf.keras.layers.Dense(units=vocabulary_size, name=f"{self.name}_logits") + self.decoder = tf.keras.layers.Dense( + units=vocabulary_size, + name=f"{self.name}_logits", + ) else: self.decoder = decoder self.time_reduction_factor = 1 - def make(self, input_shape, batch_size=None): + def make( + self, + input_shape, + batch_size=None, + ): inputs = tf.keras.Input(input_shape, batch_size=batch_size, dtype=tf.float32) inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) self( data_util.create_inputs( inputs=inputs, - inputs_length=inputs_length + inputs_length=inputs_length, ), - training=False + training=False, ) - def compile(self, - optimizer, - global_batch_size, - blank=0, - run_eagerly=None, - **kwargs): + def compile( + self, + optimizer, + global_batch_size, + blank=0, + run_eagerly=None, + **kwargs, + ): loss = CtcLoss(blank=blank, global_batch_size=global_batch_size) super().compile(loss=loss, optimizer=optimizer, run_eagerly=run_eagerly, **kwargs) - def add_featurizers(self, - speech_featurizer: TFSpeechFeaturizer, - text_featurizer: TextFeaturizer): + def add_featurizers( + self, + speech_featurizer: TFSpeechFeaturizer, + text_featurizer: TextFeaturizer, + ): self.speech_featurizer = speech_featurizer self.text_featurizer = text_featurizer - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): logits = self.encoder(inputs["inputs"], training=training, **kwargs) logits = self.decoder(logits, training=training, **kwargs) return data_util.create_logits( logits=logits, - logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) + logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor), ) # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, inputs: Dict[str, tf.Tensor]): + def recognize( + self, + inputs: Dict[str, tf.Tensor], + ): logits = self(inputs, training=False) probs = tf.nn.softmax(logits["logits"]) - def map_fn(prob): return tf.numpy_function(self._perform_greedy, inp=[prob], Tout=tf.string) + def map_fn(prob): + return tf.numpy_function(self._perform_greedy, inp=[prob], Tout=tf.string) return tf.map_fn(map_fn, probs, fn_output_signature=tf.TensorSpec([], dtype=tf.string)) - def _perform_greedy(self, probs: np.ndarray): + def _perform_greedy( + self, + probs: np.ndarray, + ): from ctc_decoders import ctc_greedy_decoder + decoded = ctc_greedy_decoder(probs, vocabulary=self.text_featurizer.non_blank_tokens) return tf.convert_to_tensor(decoded, dtype=tf.string) - def recognize_tflite(self, signal): + def recognize_tflite( + self, + signal, + ): """ Function to convert to tflite using greedy decoding Args: @@ -105,9 +135,7 @@ def recognize_tflite(self, signal): logits = self.encoder(features, training=False) logits = self.decoder(logits, training=False) probs = tf.nn.softmax(logits) - decoded = tf.keras.backend.ctc_decode( - y_pred=probs, input_length=input_length, greedy=True - ) + decoded = tf.keras.backend.ctc_decode(y_pred=probs, input_length=input_length, greedy=True) decoded = tf.cast(decoded[0][0][0], dtype=tf.int32) transcript = self.text_featurizer.indices2upoints(decoded) return transcript @@ -115,27 +143,40 @@ def recognize_tflite(self, signal): # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): + def recognize_beam( + self, + inputs: Dict[str, tf.Tensor], + lm: bool = False, + ): logits = self(inputs, training=False) probs = tf.nn.softmax(logits["logits"]) - def map_fn(prob): return tf.numpy_function(self._perform_beam_search, inp=[prob, lm], Tout=tf.string) + def map_fn(prob): + return tf.numpy_function(self._perform_beam_search, inp=[prob, lm], Tout=tf.string) return tf.map_fn(map_fn, probs, dtype=tf.string) - def _perform_beam_search(self, probs: np.ndarray, lm: bool = False): + def _perform_beam_search( + self, + probs: np.ndarray, + lm: bool = False, + ): from ctc_decoders import ctc_beam_search_decoder + decoded = ctc_beam_search_decoder( probs_seq=probs, vocabulary=self.text_featurizer.non_blank_tokens, beam_size=self.text_featurizer.decoder_config.beam_width, - ext_scoring_func=self.text_featurizer.scorer if lm else None + ext_scoring_func=self.text_featurizer.scorer if lm else None, ) decoded = decoded[0][-1] return tf.convert_to_tensor(decoded, dtype=tf.string) - def recognize_beam_tflite(self, signal): + def recognize_beam_tflite( + self, + signal, + ): """ Function to convert to tflite using beam search decoding Args: @@ -153,8 +194,10 @@ def recognize_beam_tflite(self, signal): logits = self.decoder(logits, training=False) probs = tf.nn.softmax(logits) decoded = tf.keras.backend.ctc_decode( - y_pred=probs, input_length=input_length, greedy=False, - beam_width=self.text_featurizer.decoder_config.beam_width + y_pred=probs, + input_length=input_length, + greedy=False, + beam_width=self.text_featurizer.decoder_config.beam_width, ) decoded = tf.cast(decoded[0][0][0], dtype=tf.int32) transcript = self.text_featurizer.indices2upoints(decoded) @@ -162,17 +205,16 @@ def recognize_beam_tflite(self, signal): # -------------------------------- TFLITE ------------------------------------- - def make_tflite_function(self, greedy: bool = False): + def make_tflite_function( + self, + greedy: bool = False, + ): if greedy: return tf.function( self.recognize_tflite, - input_signature=[ - tf.TensorSpec([None], dtype=tf.float32) - ] + input_signature=[tf.TensorSpec([None], dtype=tf.float32)], ) return tf.function( self.recognize_beam_tflite, - input_signature=[ - tf.TensorSpec([None], dtype=tf.float32) - ] + input_signature=[tf.TensorSpec([None], dtype=tf.float32)], ) diff --git a/tensorflow_asr/models/ctc/deepspeech2.py b/tensorflow_asr/models/ctc/deepspeech2.py index c8788cbf05..de03a09e07 100644 --- a/tensorflow_asr/models/ctc/deepspeech2.py +++ b/tensorflow_asr/models/ctc/deepspeech2.py @@ -21,28 +21,44 @@ class Reshape(tf.keras.layers.Layer): - def call(self, inputs): return math_util.merge_two_last_dims(inputs) + def call( + self, + inputs, + ): + return math_util.merge_two_last_dims(inputs) class ConvBlock(tf.keras.layers.Layer): - def __init__(self, - conv_type: str = "conv2d", - kernels: list = [11, 41], - strides: list = [2, 2], - filters: int = 32, - dropout: float = 0.1, - **kwargs): + def __init__( + self, + conv_type: str = "conv2d", + kernels: list = [11, 41], + strides: list = [2, 2], + filters: int = 32, + dropout: float = 0.1, + **kwargs, + ): super(ConvBlock, self).__init__(**kwargs) CNN = layer_util.get_conv(conv_type) - self.conv = CNN(filters=filters, kernel_size=kernels, - strides=strides, padding="same", - dtype=tf.float32, name=f"{self.name}_{conv_type}") + self.conv = CNN( + filters=filters, + kernel_size=kernels, + strides=strides, + padding="same", + dtype=tf.float32, + name=f"{self.name}_{conv_type}", + ) self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn") self.relu = tf.keras.layers.ReLU(name=f"{self.name}_relu") self.do = tf.keras.layers.Dropout(dropout, name=f"{self.name}_dropout") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.conv(inputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.relu(outputs, training=training) @@ -59,20 +75,23 @@ def get_config(self): class ConvModule(tf.keras.Model): - def __init__(self, - conv_type: str = "conv2d", - kernels: list = [[11, 41], [11, 21], [11, 21]], - strides: list = [[2, 2], [1, 2], [1, 2]], - filters: list = [32, 32, 96], - dropout: float = 0.1, - **kwargs): + def __init__( + self, + conv_type: str = "conv2d", + kernels: list = [[11, 41], [11, 21], [11, 21]], + strides: list = [[2, 2], [1, 2], [1, 2]], + filters: list = [32, 32, 96], + dropout: float = 0.1, + **kwargs, + ): super(ConvModule, self).__init__(**kwargs) assert len(kernels) == len(strides) == len(filters) assert dropout >= 0.0 self.preprocess = None # reshape from [B, T, F, C] to [B, T, F * C] - if conv_type == "conv1d": self.preprocess = Reshape(name=f"{self.name}_preprocess") + if conv_type == "conv1d": + self.preprocess = Reshape(name=f"{self.name}_preprocess") self.blocks = [ ConvBlock( @@ -81,22 +100,32 @@ def __init__(self, strides=strides[i], filters=filters[i], dropout=dropout, - name=f"{self.name}_block_{i}" - ) for i in range(len(filters)) + name=f"{self.name}_block_{i}", + ) + for i in range(len(filters)) ] self.postprocess = None # reshape from [B, T, F, C] to [B, T, F * C] - if conv_type == "conv2d": self.postprocess = Reshape(name=f"{self.name}_postprocess") + if conv_type == "conv2d": + self.postprocess = Reshape(name=f"{self.name}_postprocess") self.reduction_factor = 1 - for s in strides: self.reduction_factor *= s[0] - - def call(self, inputs, training=False, **kwargs): + for s in strides: + self.reduction_factor *= s[0] + + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = inputs - if self.preprocess is not None: outputs = self.preprocess(outputs) + if self.preprocess is not None: + outputs = self.preprocess(outputs) for block in self.blocks: outputs = block(outputs, training=training, **kwargs) - if self.postprocess is not None: outputs = self.postprocess(outputs) + if self.postprocess is not None: + outputs = self.postprocess(outputs) return outputs def get_config(self): @@ -109,27 +138,42 @@ def get_config(self): class RnnBlock(tf.keras.layers.Layer): - def __init__(self, - rnn_type: str = "lstm", - units: int = 1024, - bidirectional: bool = True, - rowconv: int = 0, - dropout: float = 0.1, - **kwargs): + def __init__( + self, + rnn_type: str = "lstm", + units: int = 1024, + bidirectional: bool = True, + rowconv: int = 0, + dropout: float = 0.1, + **kwargs, + ): super(RnnBlock, self).__init__(**kwargs) RNN = layer_util.get_rnn(rnn_type) - self.rnn = RNN(units, dropout=dropout, return_sequences=True, - use_bias=True, name=f"{self.name}_{rnn_type}") + self.rnn = RNN( + units, + dropout=dropout, + return_sequences=True, + use_bias=True, + name=f"{self.name}_{rnn_type}", + ) if bidirectional: self.rnn = tf.keras.layers.Bidirectional(self.rnn, name=f"{self.name}_b{rnn_type}") self.bn = SequenceBatchNorm(time_major=False, name=f"{self.name}_bn") self.rowconv = None if not bidirectional and rowconv > 0: - self.rowconv = RowConv1D(filters=units, future_context=rowconv, - name=f"{self.name}_rowconv") - - def call(self, inputs, training=False, **kwargs): + self.rowconv = RowConv1D( + filters=units, + future_context=rowconv, + name=f"{self.name}_rowconv", + ) + + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.rnn(inputs, training=training) outputs = self.bn(outputs, training=training) if self.rowconv is not None: @@ -146,14 +190,16 @@ def get_config(self): class RnnModule(tf.keras.Model): - def __init__(self, - nlayers: int = 5, - rnn_type: str = "lstm", - units: int = 1024, - bidirectional: bool = True, - rowconv: int = 0, - dropout: float = 0.1, - **kwargs): + def __init__( + self, + nlayers: int = 5, + rnn_type: str = "lstm", + units: int = 1024, + bidirectional: bool = True, + rowconv: int = 0, + dropout: float = 0.1, + **kwargs, + ): super(RnnModule, self).__init__(**kwargs) self.blocks = [ @@ -163,11 +209,17 @@ def __init__(self, bidirectional=bidirectional, rowconv=rowconv, dropout=dropout, - name=f"{self.name}_block_{i}" - ) for i in range(nlayers) + name=f"{self.name}_block_{i}", + ) + for i in range(nlayers) ] - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = inputs for block in self.blocks: outputs = block(outputs, training=training, **kwargs) @@ -181,10 +233,12 @@ def get_config(self): class FcBlock(tf.keras.layers.Layer): - def __init__(self, - units: int = 1024, - dropout: float = 0.1, - **kwargs): + def __init__( + self, + units: int = 1024, + dropout: float = 0.1, + **kwargs, + ): super(FcBlock, self).__init__(**kwargs) self.fc = tf.keras.layers.Dense(units, name=f"{self.name}_fc") @@ -192,7 +246,12 @@ def __init__(self, self.relu = tf.keras.layers.ReLU(name=f"{self.name}_relu") self.do = tf.keras.layers.Dropout(dropout, name=f"{self.name}_dropout") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.fc(inputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.relu(outputs, training=training) @@ -209,22 +268,23 @@ def get_config(self): class FcModule(tf.keras.Model): - def __init__(self, - nlayers: int = 0, - units: int = 1024, - dropout: float = 0.1, - **kwargs): + def __init__( + self, + nlayers: int = 0, + units: int = 1024, + dropout: float = 0.1, + **kwargs, + ): super(FcModule, self).__init__(**kwargs) - self.blocks = [ - FcBlock( - units=units, - dropout=dropout, - name=f"{self.name}_block_{i}" - ) for i in range(nlayers) - ] + self.blocks = [FcBlock(units=units, dropout=dropout, name=f"{self.name}_block_{i}") for i in range(nlayers)] - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = inputs for block in self.blocks: outputs = block(outputs, training=training, **kwargs) @@ -238,23 +298,25 @@ def get_config(self): class DeepSpeech2Encoder(tf.keras.Model): - def __init__(self, - conv_type: str = "conv2d", - conv_kernels: list = [[11, 41], [11, 21], [11, 21]], - conv_strides: list = [[2, 2], [1, 2], [1, 2]], - conv_filters: list = [32, 32, 96], - conv_dropout: float = 0.1, - rnn_nlayers: int = 5, - rnn_type: str = "lstm", - rnn_units: int = 1024, - rnn_bidirectional: bool = True, - rnn_rowconv: int = 0, - rnn_dropout: float = 0.1, - fc_nlayers: int = 0, - fc_units: int = 1024, - fc_dropout: float = 0.1, - name="deepspeech2_encoder", - **kwargs): + def __init__( + self, + conv_type: str = "conv2d", + conv_kernels: list = [[11, 41], [11, 21], [11, 21]], + conv_strides: list = [[2, 2], [1, 2], [1, 2]], + conv_filters: list = [32, 32, 96], + conv_dropout: float = 0.1, + rnn_nlayers: int = 5, + rnn_type: str = "lstm", + rnn_units: int = 1024, + rnn_bidirectional: bool = True, + rnn_rowconv: int = 0, + rnn_dropout: float = 0.1, + fc_nlayers: int = 0, + fc_units: int = 1024, + fc_dropout: float = 0.1, + name="deepspeech2_encoder", + **kwargs, + ): super().__init__(**kwargs) self.conv_module = ConvModule( @@ -263,7 +325,7 @@ def __init__(self, strides=conv_strides, filters=conv_filters, dropout=conv_dropout, - name=f"{self.name}_conv_module" + name=f"{self.name}_conv_module", ) self.rnn_module = RnnModule( @@ -273,23 +335,32 @@ def __init__(self, bidirectional=rnn_bidirectional, rowconv=rnn_rowconv, dropout=rnn_dropout, - name=f"{self.name}_rnn_module" + name=f"{self.name}_rnn_module", ) self.fc_module = FcModule( nlayers=fc_nlayers, units=fc_units, dropout=fc_dropout, - name=f"{self.name}_fc_module" + name=f"{self.name}_fc_module", ) - def summary(self, line_length=100, **kwargs): + def summary( + self, + line_length=100, + **kwargs, + ): self.conv_module.summary(line_length=line_length, **kwargs) self.rnn_module.summary(line_length=line_length, **kwargs) self.fc_module.summary(line_length=line_length, **kwargs) super().summary(line_length=line_length, **kwargs) - def call(self, inputs, training, **kwargs): + def call( + self, + inputs, + training, + **kwargs, + ): outputs = self.conv_module(inputs, training=training, **kwargs) outputs = self.rnn_module(outputs, training=training, **kwargs) outputs = self.fc_module(outputs, training=training, **kwargs) @@ -304,24 +375,26 @@ def get_config(self): class DeepSpeech2(CtcModel): - def __init__(self, - vocabulary_size: int, - conv_type: str = "conv2d", - conv_kernels: list = [[11, 41], [11, 21], [11, 21]], - conv_strides: list = [[2, 2], [1, 2], [1, 2]], - conv_filters: list = [32, 32, 96], - conv_dropout: float = 0.1, - rnn_nlayers: int = 5, - rnn_type: str = "lstm", - rnn_units: int = 1024, - rnn_bidirectional: bool = True, - rnn_rowconv: int = 0, - rnn_dropout: float = 0.1, - fc_nlayers: int = 0, - fc_units: int = 1024, - fc_dropout: float = 0.1, - name: str = "deepspeech2", - **kwargs): + def __init__( + self, + vocabulary_size: int, + conv_type: str = "conv2d", + conv_kernels: list = [[11, 41], [11, 21], [11, 21]], + conv_strides: list = [[2, 2], [1, 2], [1, 2]], + conv_filters: list = [32, 32, 96], + conv_dropout: float = 0.1, + rnn_nlayers: int = 5, + rnn_type: str = "lstm", + rnn_units: int = 1024, + rnn_bidirectional: bool = True, + rnn_rowconv: int = 0, + rnn_dropout: float = 0.1, + fc_nlayers: int = 0, + fc_units: int = 1024, + fc_dropout: float = 0.1, + name: str = "deepspeech2", + **kwargs, + ): super().__init__( encoder=DeepSpeech2Encoder( conv_type=conv_type, @@ -338,10 +411,10 @@ def __init__(self, fc_nlayers=fc_nlayers, fc_units=fc_units, fc_dropout=fc_dropout, - name=f"{name}_encoder" + name=f"{name}_encoder", ), vocabulary_size=vocabulary_size, name=name, - **kwargs + **kwargs, ) self.time_reduction_factor = self.encoder.conv_module.reduction_factor diff --git a/tensorflow_asr/models/ctc/jasper.py b/tensorflow_asr/models/ctc/jasper.py index 23b47ed063..b363f4d0cc 100644 --- a/tensorflow_asr/models/ctc/jasper.py +++ b/tensorflow_asr/models/ctc/jasper.py @@ -19,33 +19,44 @@ class Reshape(tf.keras.layers.Layer): - def call(self, inputs): return math_util.merge_two_last_dims(inputs) + def call(self, inputs): + return math_util.merge_two_last_dims(inputs) class JasperSubBlock(tf.keras.layers.Layer): - def __init__(self, - channels: int = 256, - kernels: int = 11, - strides: int = 1, - dropout: float = 0.1, - dilation: int = 1, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): + def __init__( + self, + channels: int = 256, + kernels: int = 11, + strides: int = 1, + dropout: float = 0.1, + dilation: int = 1, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(JasperSubBlock, self).__init__(**kwargs) self.conv1d = tf.keras.layers.Conv1D( - filters=channels, kernel_size=kernels, - strides=strides, dilation_rate=dilation, padding="same", + filters=channels, + kernel_size=kernels, + strides=strides, + dilation_rate=dilation, + padding="same", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_conv1d" + name=f"{self.name}_conv1d", ) self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn") self.relu = tf.keras.layers.ReLU(name=f"{self.name}_relu") self.do = tf.keras.layers.Dropout(dropout, name=f"{self.name}_dropout") self.reduction_factor = strides - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = inputs outputs = self.conv1d(outputs, training=training) outputs = self.bn(outputs, training=training) @@ -63,22 +74,31 @@ def get_config(self): class JasperResidual(tf.keras.layers.Layer): - def __init__(self, - channels: int = 256, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): + def __init__( + self, + channels: int = 256, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(JasperResidual, self).__init__(**kwargs) self.pointwise_conv1d = tf.keras.layers.Conv1D( - filters=channels, kernel_size=1, - strides=1, padding="same", + filters=channels, + kernel_size=1, + strides=1, + padding="same", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_pointwise_conv1d" + name=f"{self.name}_pointwise_conv1d", ) self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.pointwise_conv1d(inputs, training=training) outputs = self.bn(outputs, training=training) return outputs @@ -91,21 +111,27 @@ def get_config(self): class JasperSubBlockResidual(JasperSubBlock): - def __init__(self, - channels: int = 256, - kernels: int = 11, - strides: int = 1, - dropout: float = 0.1, - dilation: int = 1, - nresiduals: int = 1, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): + def __init__( + self, + channels: int = 256, + kernels: int = 11, + strides: int = 1, + dropout: float = 0.1, + dilation: int = 1, + nresiduals: int = 1, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(JasperSubBlockResidual, self).__init__( - channels=channels, kernels=kernels, - strides=strides, dropout=dropout, - dilation=dilation, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, **kwargs + channels=channels, + kernels=kernels, + strides=strides, + dropout=dropout, + dilation=dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + **kwargs, ) self.residuals = [ @@ -113,13 +139,19 @@ def __init__(self, channels=channels, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_residual_{i}" - ) for i in range(nresiduals) + name=f"{self.name}_residual_{i}", + ) + for i in range(nresiduals) ] self.add = tf.keras.layers.Add(name=f"{self.name}_add") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs, residuals = inputs outputs = self.conv1d(outputs, training=training) outputs = self.bn(outputs, training=training) @@ -138,16 +170,18 @@ def get_config(self): class JasperBlock(tf.keras.Model): - def __init__(self, - nsubblocks: int = 3, - channels: int = 256, - kernels: int = 11, - dropout: float = 0.1, - dense: bool = False, - nresiduals: int = 1, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): + def __init__( + self, + nsubblocks: int = 3, + channels: int = 256, + kernels: int = 11, + dropout: float = 0.1, + dense: bool = False, + nresiduals: int = 1, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(JasperBlock, self).__init__(**kwargs) self.dense = dense @@ -159,8 +193,9 @@ def __init__(self, dropout=dropout, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_subordinate_{i}" - ) for i in range(nsubblocks - 1) + name=f"{self.name}_subordinate_{i}", + ) + for i in range(nsubblocks - 1) ] self.subblock_residual = JasperSubBlockResidual( @@ -170,12 +205,17 @@ def __init__(self, nresiduals=nresiduals, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_subordinate_{nsubblocks - 1}" + name=f"{self.name}_subordinate_{nsubblocks - 1}", ) self.reduction_factor = 1 - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): inputs, residuals = inputs outputs = inputs for subblock in self.subblocks: @@ -196,31 +236,33 @@ def get_config(self): class JasperEncoder(tf.keras.Model): - def __init__(self, - dense: bool = False, - first_additional_block_channels: int = 256, - first_additional_block_kernels: int = 11, - first_additional_block_strides: int = 2, - first_additional_block_dilation: int = 1, - first_additional_block_dropout: int = 0.2, - nsubblocks: int = 5, - block_channels: list = [256, 384, 512, 640, 768], - block_kernels: list = [11, 13, 17, 21, 25], - block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], - second_additional_block_channels: int = 896, - second_additional_block_kernels: int = 1, - second_additional_block_strides: int = 1, - second_additional_block_dilation: int = 2, - second_additional_block_dropout: int = 0.4, - third_additional_block_channels: int = 1024, - third_additional_block_kernels: int = 1, - third_additional_block_strides: int = 1, - third_additional_block_dilation: int = 1, - third_additional_block_dropout: int = 0.4, - kernel_regularizer=None, - bias_regularizer=None, - name: str = "jasper_encoder", - **kwargs): + def __init__( + self, + dense: bool = False, + first_additional_block_channels: int = 256, + first_additional_block_kernels: int = 11, + first_additional_block_strides: int = 2, + first_additional_block_dilation: int = 1, + first_additional_block_dropout: int = 0.2, + nsubblocks: int = 5, + block_channels: list = [256, 384, 512, 640, 768], + block_kernels: list = [11, 13, 17, 21, 25], + block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], + second_additional_block_channels: int = 896, + second_additional_block_kernels: int = 1, + second_additional_block_strides: int = 1, + second_additional_block_dilation: int = 2, + second_additional_block_dropout: int = 0.4, + third_additional_block_channels: int = 1024, + third_additional_block_kernels: int = 1, + third_additional_block_strides: int = 1, + third_additional_block_dilation: int = 1, + third_additional_block_dropout: int = 0.4, + kernel_regularizer=None, + bias_regularizer=None, + name: str = "jasper_encoder", + **kwargs, + ): super().__init__(name=name, **kwargs) assert len(block_channels) == len(block_kernels) == len(block_dropout) @@ -235,7 +277,7 @@ def __init__(self, dilation=first_additional_block_dilation, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_first_block" + name=f"{self.name}_first_block", ) self.blocks = [ @@ -248,8 +290,9 @@ def __init__(self, nresiduals=(i + 1) if dense else 1, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_block_{i}" - ) for i in range(len(block_channels)) + name=f"{self.name}_block_{i}", + ) + for i in range(len(block_channels)) ] self.second_additional_block = JasperSubBlock( @@ -260,7 +303,7 @@ def __init__(self, dilation=second_additional_block_dilation, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_second_block" + name=f"{self.name}_second_block", ) self.third_additional_block = JasperSubBlock( @@ -271,10 +314,15 @@ def __init__(self, dilation=third_additional_block_dilation, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_third_block" + name=f"{self.name}_third_block", ) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.reshape(inputs) outputs = self.first_additional_block(outputs, training=training, **kwargs) @@ -286,7 +334,11 @@ def call(self, inputs, training=False, **kwargs): outputs = self.third_additional_block(outputs, training=training, **kwargs) return outputs - def summary(self, line_length=100, **kwargs): + def summary( + self, + line_length=100, + **kwargs, + ): super().summary(line_length=line_length, **kwargs) def get_config(self): @@ -301,32 +353,34 @@ def get_config(self): class Jasper(CtcModel): - def __init__(self, - vocabulary_size: int, - dense: bool = False, - first_additional_block_channels: int = 256, - first_additional_block_kernels: int = 11, - first_additional_block_strides: int = 2, - first_additional_block_dilation: int = 1, - first_additional_block_dropout: int = 0.2, - nsubblocks: int = 5, - block_channels: list = [256, 384, 512, 640, 768], - block_kernels: list = [11, 13, 17, 21, 25], - block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], - second_additional_block_channels: int = 896, - second_additional_block_kernels: int = 1, - second_additional_block_strides: int = 1, - second_additional_block_dilation: int = 2, - second_additional_block_dropout: int = 0.4, - third_additional_block_channels: int = 1024, - third_additional_block_kernels: int = 1, - third_additional_block_strides: int = 1, - third_additional_block_dilation: int = 1, - third_additional_block_dropout: int = 0.4, - kernel_regularizer=None, - bias_regularizer=None, - name="jasper", - **kwargs): + def __init__( + self, + vocabulary_size: int, + dense: bool = False, + first_additional_block_channels: int = 256, + first_additional_block_kernels: int = 11, + first_additional_block_strides: int = 2, + first_additional_block_dilation: int = 1, + first_additional_block_dropout: int = 0.2, + nsubblocks: int = 5, + block_channels: list = [256, 384, 512, 640, 768], + block_kernels: list = [11, 13, 17, 21, 25], + block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], + second_additional_block_channels: int = 896, + second_additional_block_kernels: int = 1, + second_additional_block_strides: int = 1, + second_additional_block_dilation: int = 2, + second_additional_block_dropout: int = 0.4, + third_additional_block_channels: int = 1024, + third_additional_block_kernels: int = 1, + third_additional_block_strides: int = 1, + third_additional_block_dilation: int = 1, + third_additional_block_dropout: int = 0.4, + kernel_regularizer=None, + bias_regularizer=None, + name="jasper", + **kwargs, + ): super().__init__( encoder=JasperEncoder( dense=dense, @@ -353,15 +407,17 @@ def __init__(self, bias_regularizer=None, ), decoder=tf.keras.layers.Conv1D( - filters=vocabulary_size, kernel_size=1, - strides=1, padding="same", + filters=vocabulary_size, + kernel_size=1, + strides=1, + padding="same", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{name}_logits" + name=f"{name}_logits", ), vocabulary_size=vocabulary_size, name=name, - **kwargs + **kwargs, ) self.time_reduction_factor = self.encoder.first_additional_block.reduction_factor self.time_reduction_factor *= self.encoder.second_additional_block.reduction_factor diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index de7b767fdd..8acb53bc8c 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -14,47 +14,56 @@ import tensorflow as tf +from ...utils import shape_util from ..activations.glu import GLU -from ..layers.subsampling import VggSubsampling, Conv2dSubsampling -from ..layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat from ..layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention -from ...utils import shape_util +from ..layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat +from ..layers.subsampling import Conv2dSubsampling, VggSubsampling L2 = tf.keras.regularizers.l2(1e-6) class FFModule(tf.keras.layers.Layer): - def __init__(self, - input_dim, - dropout=0.0, - fc_factor=0.5, - kernel_regularizer=L2, - bias_regularizer=L2, - name="ff_module", - **kwargs): + def __init__( + self, + input_dim, + dropout=0.0, + fc_factor=0.5, + kernel_regularizer=L2, + bias_regularizer=L2, + name="ff_module", + **kwargs, + ): super(FFModule, self).__init__(name=name, **kwargs) self.fc_factor = fc_factor self.ln = tf.keras.layers.LayerNormalization( name=f"{name}_ln", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer + beta_regularizer=bias_regularizer, ) self.ffn1 = tf.keras.layers.Dense( - 4 * input_dim, name=f"{name}_dense_1", + 4 * input_dim, + name=f"{name}_dense_1", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation") self.do1 = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout_1") self.ffn2 = tf.keras.layers.Dense( - input_dim, name=f"{name}_dense_2", + input_dim, + name=f"{name}_dense_2", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.do2 = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout_2") self.res_add = tf.keras.layers.Add(name=f"{name}_add") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.ln(inputs, training=training) outputs = self.ffn1(outputs, training=training) outputs = self.swish(outputs) @@ -78,34 +87,38 @@ def get_config(self): class MHSAModule(tf.keras.layers.Layer): - def __init__(self, - head_size, - num_heads, - dropout=0.0, - mha_type="relmha", - kernel_regularizer=L2, - bias_regularizer=L2, - name="mhsa_module", - **kwargs): + def __init__( + self, + head_size, + num_heads, + dropout=0.0, + mha_type="relmha", + kernel_regularizer=L2, + bias_regularizer=L2, + name="mhsa_module", + **kwargs, + ): super(MHSAModule, self).__init__(name=name, **kwargs) self.ln = tf.keras.layers.LayerNormalization( name=f"{name}_ln", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer + beta_regularizer=bias_regularizer, ) if mha_type == "relmha": self.mha = RelPositionMultiHeadAttention( name=f"{name}_mhsa", - head_size=head_size, num_heads=num_heads, + head_size=head_size, + num_heads=num_heads, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) elif mha_type == "mha": self.mha = MultiHeadAttention( name=f"{name}_mhsa", - head_size=head_size, num_heads=num_heads, + head_size=head_size, + num_heads=num_heads, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) else: raise ValueError("mha_type must be either 'mha' or 'relmha'") @@ -113,7 +126,13 @@ def __init__(self, self.res_add = tf.keras.layers.Add(name=f"{name}_add") self.mha_type = mha_type - def call(self, inputs, training=False, mask=None, **kwargs): + def call( + self, + inputs, + training=False, + mask=None, + **kwargs, + ): inputs, pos = inputs # pos is positional encoding outputs = self.ln(inputs, training=training) if self.mha_type == "relmha": @@ -136,47 +155,65 @@ def get_config(self): class ConvModule(tf.keras.layers.Layer): - def __init__(self, - input_dim, - kernel_size=32, - dropout=0.0, - depth_multiplier=1, - kernel_regularizer=L2, - bias_regularizer=L2, - name="conv_module", - **kwargs): + def __init__( + self, + input_dim, + kernel_size=32, + dropout=0.0, + depth_multiplier=1, + kernel_regularizer=L2, + bias_regularizer=L2, + name="conv_module", + **kwargs, + ): super(ConvModule, self).__init__(name=name, **kwargs) self.ln = tf.keras.layers.LayerNormalization() self.pw_conv_1 = tf.keras.layers.Conv2D( - filters=2 * input_dim, kernel_size=1, strides=1, - padding="valid", name=f"{name}_pw_conv_1", + filters=2 * input_dim, + kernel_size=1, + strides=1, + padding="valid", + name=f"{name}_pw_conv_1", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.glu = GLU(name=f"{name}_glu") self.dw_conv = tf.keras.layers.DepthwiseConv2D( - kernel_size=(kernel_size, 1), strides=1, - padding="same", name=f"{name}_dw_conv", + kernel_size=(kernel_size, 1), + strides=1, + padding="same", + name=f"{name}_dw_conv", depth_multiplier=depth_multiplier, depthwise_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.bn = tf.keras.layers.BatchNormalization( name=f"{name}_bn", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer + beta_regularizer=bias_regularizer, + ) + self.swish = tf.keras.layers.Activation( + tf.nn.swish, + name=f"{name}_swish_activation", ) - self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation") self.pw_conv_2 = tf.keras.layers.Conv2D( - filters=input_dim, kernel_size=1, strides=1, - padding="valid", name=f"{name}_pw_conv_2", + filters=input_dim, + kernel_size=1, + strides=1, + padding="valid", + name=f"{name}_pw_conv_2", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.do = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout") self.res_add = tf.keras.layers.Add(name=f"{name}_add") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.ln(inputs, training=training) B, T, E = shape_util.shape_list(outputs) outputs = tf.reshape(outputs, [B, T, 1, E]) @@ -206,53 +243,69 @@ def get_config(self): class ConformerBlock(tf.keras.layers.Layer): - def __init__(self, - input_dim, - dropout=0.0, - fc_factor=0.5, - head_size=36, - num_heads=4, - mha_type="relmha", - kernel_size=32, - depth_multiplier=1, - kernel_regularizer=L2, - bias_regularizer=L2, - name="conformer_block", - **kwargs): + def __init__( + self, + input_dim, + dropout=0.0, + fc_factor=0.5, + head_size=36, + num_heads=4, + mha_type="relmha", + kernel_size=32, + depth_multiplier=1, + kernel_regularizer=L2, + bias_regularizer=L2, + name="conformer_block", + **kwargs, + ): super(ConformerBlock, self).__init__(name=name, **kwargs) self.ffm1 = FFModule( - input_dim=input_dim, dropout=dropout, - fc_factor=fc_factor, name=f"{name}_ff_module_1", + input_dim=input_dim, + dropout=dropout, + fc_factor=fc_factor, + name=f"{name}_ff_module_1", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.mhsam = MHSAModule( mha_type=mha_type, - head_size=head_size, num_heads=num_heads, - dropout=dropout, name=f"{name}_mhsa_module", + head_size=head_size, + num_heads=num_heads, + dropout=dropout, + name=f"{name}_mhsa_module", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.convm = ConvModule( - input_dim=input_dim, kernel_size=kernel_size, - dropout=dropout, name=f"{name}_conv_module", + input_dim=input_dim, + kernel_size=kernel_size, + dropout=dropout, + name=f"{name}_conv_module", depth_multiplier=depth_multiplier, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.ffm2 = FFModule( - input_dim=input_dim, dropout=dropout, - fc_factor=fc_factor, name=f"{name}_ff_module_2", + input_dim=input_dim, + dropout=dropout, + fc_factor=fc_factor, + name=f"{name}_ff_module_2", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.ln = tf.keras.layers.LayerNormalization( name=f"{name}_ln", gamma_regularizer=kernel_regularizer, - beta_regularizer=kernel_regularizer + beta_regularizer=kernel_regularizer, ) - def call(self, inputs, training=False, mask=None, **kwargs): + def call( + self, + inputs, + training=False, + mask=None, + **kwargs, + ): inputs, pos = inputs # pos is positional encoding outputs = self.ffm1(inputs, training=training, **kwargs) outputs = self.mhsam([outputs, pos], training=training, mask=mask, **kwargs) @@ -272,22 +325,24 @@ def get_config(self): class ConformerEncoder(tf.keras.Model): - def __init__(self, - subsampling, - positional_encoding="sinusoid", - dmodel=144, - num_blocks=16, - mha_type="relmha", - head_size=36, - num_heads=4, - kernel_size=32, - depth_multiplier=1, - fc_factor=0.5, - dropout=0.0, - kernel_regularizer=L2, - bias_regularizer=L2, - name="conformer_encoder", - **kwargs): + def __init__( + self, + subsampling, + positional_encoding="sinusoid", + dmodel=144, + num_blocks=16, + mha_type="relmha", + head_size=36, + num_heads=4, + kernel_size=32, + depth_multiplier=1, + fc_factor=0.5, + dropout=0.0, + kernel_regularizer=L2, + bias_regularizer=L2, + name="conformer_encoder", + **kwargs, + ): super(ConformerEncoder, self).__init__(name=name, **kwargs) subsampling_name = subsampling.pop("type", "conv2d") @@ -299,9 +354,10 @@ def __init__(self, raise ValueError("subsampling must be either 'conv2d' or 'vgg'") self.conv_subsampling = subsampling_class( - **subsampling, name=f"{name}_subsampling", + **subsampling, + name=f"{name}_subsampling", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) if positional_encoding == "sinusoid": @@ -315,13 +371,16 @@ def __init__(self, elif positional_encoding == "subsampling": self.pe = tf.keras.layers.Activation("linear", name=f"{name}_pe") else: - raise ValueError("positional_encoding must be either 'sinusoid', \ - 'sinusoid_concat', 'sinusoid_v2', 'sinusoid_concat_v2' or 'subsampling'") + raise ValueError( + "positional_encoding must be either 'sinusoid', \ + 'sinusoid_concat', 'sinusoid_v2', 'sinusoid_concat_v2' or 'subsampling'" + ) self.linear = tf.keras.layers.Dense( - dmodel, name=f"{name}_linear", + dmodel, + name=f"{name}_linear", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.do = tf.keras.layers.Dropout(dropout, name=f"{name}_dropout") @@ -338,11 +397,17 @@ def __init__(self, depth_multiplier=depth_multiplier, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{name}_block_{i}" + name=f"{name}_block_{i}", ) self.conformer_blocks.append(conformer_block) - def call(self, inputs, training=False, mask=None, **kwargs): + def call( + self, + inputs, + training=False, + mask=None, + **kwargs, + ): # input with shape [B, T, V1, V2] outputs = self.conv_subsampling(inputs, training=training) outputs = self.linear(outputs, training=training) diff --git a/tensorflow_asr/models/encoders/contextnet.py b/tensorflow_asr/models/encoders/contextnet.py index 5fd9924972..d36ef9907e 100644 --- a/tensorflow_asr/models/encoders/contextnet.py +++ b/tensorflow_asr/models/encoders/contextnet.py @@ -14,44 +14,65 @@ """ Ref: https://github.com/iankur/ContextNet """ from typing import List + import tensorflow as tf + from ...utils import math_util L2 = tf.keras.regularizers.l2(1e-6) -def get_activation(activation: str = "silu"): +def get_activation( + activation: str = "silu", +): activation = activation.lower() - if activation in ["silu", "swish"]: return tf.nn.swish - elif activation == "relu": return tf.nn.relu - elif activation == "linear": return tf.keras.activations.linear - else: raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'") + if activation in ["silu", "swish"]: + return tf.nn.swish + elif activation == "relu": + return tf.nn.relu + elif activation == "linear": + return tf.keras.activations.linear + else: + raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'") class Reshape(tf.keras.layers.Layer): - def call(self, inputs): return math_util.merge_two_last_dims(inputs) + def call(self, inputs): + return math_util.merge_two_last_dims(inputs) class ConvModule(tf.keras.layers.Layer): - def __init__(self, - kernel_size: int = 3, - strides: int = 1, - filters: int = 256, - activation: str = "silu", - kernel_regularizer = None, - bias_regularizer = None, - **kwargs): + def __init__( + self, + kernel_size: int = 3, + strides: int = 1, + filters: int = 256, + activation: str = "silu", + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(ConvModule, self).__init__(**kwargs) self.strides = strides self.conv = tf.keras.layers.SeparableConv1D( - filters=filters, kernel_size=kernel_size, strides=strides, padding="same", - depthwise_regularizer=kernel_regularizer, pointwise_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, name=f"{self.name}_conv" + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding="same", + depthwise_regularizer=kernel_regularizer, + pointwise_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_conv", ) self.bn = tf.keras.layers.BatchNormalization(name=f"{self.name}_bn") self.activation = get_activation(activation) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.conv(inputs, training=training) outputs = self.bn(outputs, training=training) outputs = self.activation(outputs) @@ -59,26 +80,36 @@ def call(self, inputs, training=False, **kwargs): class SEModule(tf.keras.layers.Layer): - def __init__(self, - kernel_size: int = 3, - strides: int = 1, - filters: int = 256, - activation: str = "silu", - kernel_regularizer = None, - bias_regularizer = None, - **kwargs): + def __init__( + self, + kernel_size: int = 3, + strides: int = 1, + filters: int = 256, + activation: str = "silu", + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(SEModule, self).__init__(**kwargs) self.conv = ConvModule( - kernel_size=kernel_size, strides=strides, - filters=filters, activation=activation, - kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_conv_module" + kernel_size=kernel_size, + strides=strides, + filters=filters, + activation=activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_conv_module", ) self.activation = get_activation(activation) self.fc1 = tf.keras.layers.Dense(filters // 8, name=f"{self.name}_fc1") self.fc2 = tf.keras.layers.Dense(filters, name=f"{self.name}_fc2") - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): features, input_length = inputs outputs = self.conv(features, training=training) @@ -95,17 +126,19 @@ def call(self, inputs, training=False, **kwargs): class ConvBlock(tf.keras.layers.Layer): - def __init__(self, - nlayers: int = 3, - kernel_size: int = 3, - filters: int = 256, - strides: int = 1, - residual: bool = True, - activation: str = 'silu', - alpha: float = 1.0, - kernel_regularizer = None, - bias_regularizer = None, - **kwargs): + def __init__( + self, + nlayers: int = 3, + kernel_size: int = 3, + filters: int = 256, + strides: int = 1, + residual: bool = True, + activation: str = "silu", + alpha: float = 1.0, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(ConvBlock, self).__init__(**kwargs) self.dmodel = filters @@ -116,38 +149,56 @@ def __init__(self, for i in range(nlayers - 1): self.convs.append( ConvModule( - kernel_size=kernel_size, strides=1, - filters=filters, activation=activation, - kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_conv_module_{i}" + kernel_size=kernel_size, + strides=1, + filters=filters, + activation=activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_conv_module_{i}", ) ) self.last_conv = ConvModule( - kernel_size=kernel_size, strides=strides, - filters=filters, activation=activation, - kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_conv_module_{nlayers - 1}" + kernel_size=kernel_size, + strides=strides, + filters=filters, + activation=activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_conv_module_{nlayers - 1}", ) self.se = SEModule( - kernel_size=kernel_size, strides=1, filters=filters, activation=activation, - kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_se" + kernel_size=kernel_size, + strides=1, + filters=filters, + activation=activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_se", ) self.residual = None if residual: self.residual = ConvModule( - kernel_size=kernel_size, strides=strides, - filters=filters, activation="linear", - kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_residual" + kernel_size=kernel_size, + strides=strides, + filters=filters, + activation="linear", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_residual", ) self.activation = get_activation(activation) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): features, input_length = inputs outputs = features for conv in self.convs: @@ -163,12 +214,14 @@ def call(self, inputs, training=False, **kwargs): class ContextNetEncoder(tf.keras.Model): - def __init__(self, - blocks: List[dict] = [], - alpha: float = 1.0, - kernel_regularizer = None, - bias_regularizer = None, - **kwargs): + def __init__( + self, + blocks: List[dict] = [], + alpha: float = 1.0, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super(ContextNetEncoder, self).__init__(**kwargs) self.reshape = Reshape(name=f"{self.name}_reshape") @@ -177,13 +230,20 @@ def __init__(self, for i, config in enumerate(blocks): self.blocks.append( ConvBlock( - **config, alpha=alpha, - kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_block_{i}" + **config, + alpha=alpha, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"{self.name}_block_{i}", ) ) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs, input_length = inputs outputs = self.reshape(outputs) for block in self.blocks: diff --git a/tensorflow_asr/models/layers/bnlstmcell.py b/tensorflow_asr/models/layers/bnlstmcell.py index 0edb225f3b..9346bf6def 100755 --- a/tensorflow_asr/models/layers/bnlstmcell.py +++ b/tensorflow_asr/models/layers/bnlstmcell.py @@ -16,46 +16,69 @@ from tensorflow.python.ops import array_ops -def ds2_rnn_batch_norm(x_i, x_f, x_c, x_o, beta=None, gamma=None): +def ds2_rnn_batch_norm( + x_i, + x_f, + x_c, + x_o, + beta=None, + gamma=None, +): # x is input * weight with shape [batch_size, units * 4] # Merge into single array of features # https://www.tensorflow.org/api_docs/python/tf/nn/moments x = tf.concat([x_i, x_f, x_c, x_o], axis=1) mean, variance = tf.nn.moments(x, axes=[0, 1], keepdims=False) - x = tf.nn.batch_normalization(x=x, mean=mean, variance=variance, - offset=beta, scale=gamma, - variance_epsilon=K.epsilon()) - x_i, x_f, x_c, x_o = array_ops.split(x, num_or_size_splits=4, - axis=1) + x = tf.nn.batch_normalization( + x=x, + mean=mean, + variance=variance, + offset=beta, + scale=gamma, + variance_epsilon=K.epsilon(), + ) + x_i, x_f, x_c, x_o = array_ops.split(x, num_or_size_splits=4, axis=1) return x_i, x_f, x_c, x_o # Frame-wise Batch Norm RNN class BNLSTMCell(tf.keras.layers.LSTMCell): - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + **kwargs, + ): super().__init__(*args, **kwargs) - self.beta = self.add_weight(shape=(self.units * 4,), - name='lstm_bn_beta', initializer='zeros', - regularizer=None, constraint=None, trainable=True) - self.gamma = self.add_weight(shape=(self.units * 4,), - name='lstm_bn_gamma', initializer='ones', - regularizer=None, constraint=None, trainable=True) + self.beta = self.add_weight( + shape=(self.units * 4,), + name="lstm_bn_beta", + initializer="zeros", + regularizer=None, + constraint=None, + trainable=True, + ) + self.gamma = self.add_weight( + shape=(self.units * 4,), + name="lstm_bn_gamma", + initializer="ones", + regularizer=None, + constraint=None, + trainable=True, + ) - def _compute_carry_and_output(self, x, h_tm1, c_tm1): + def _compute_carry_and_output( + self, + x, + h_tm1, + c_tm1, + ): """Computes carry and output using split kernels.""" x_i, x_f, x_c, x_o = x - x_i, x_f, x_c, x_o = ds2_rnn_batch_norm(x_i, x_f, x_c, x_o, - beta=self.beta, - gamma=self.gamma) + x_i, x_f, x_c, x_o = ds2_rnn_batch_norm(x_i, x_f, x_c, x_o, beta=self.beta, gamma=self.gamma) h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 - i = self.recurrent_activation( - x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) - f = self.recurrent_activation(x_f + K.dot( - h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])) - c = f * c_tm1 + i * self.activation(x_c + K.dot( - h_tm1_c, - self.recurrent_kernel[:, self.units * 2:self.units * 3])) - o = self.recurrent_activation( - x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) + i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])) + f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2])) + c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3])) + o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])) return c, o diff --git a/tensorflow_asr/models/layers/embedding.py b/tensorflow_asr/models/layers/embedding.py index 2cf7bc48cd..ab711f91a9 100644 --- a/tensorflow_asr/models/layers/embedding.py +++ b/tensorflow_asr/models/layers/embedding.py @@ -16,13 +16,15 @@ class Embedding(tf.keras.layers.Layer): - def __init__(self, - vocab_size, - embed_dim, - contraint=None, - regularizer=None, - initializer=None, - **kwargs): + def __init__( + self, + vocab_size, + embed_dim, + contraint=None, + regularizer=None, + initializer=None, + **kwargs, + ): super(Embedding, self).__init__(**kwargs) self.vocab_size = vocab_size self.embed_dim = embed_dim @@ -30,27 +32,37 @@ def __init__(self, self.regularizer = tf.keras.regularizers.get(regularizer) self.initializer = tf.keras.initializers.get(initializer) - def build(self, input_shape): + def build( + self, + input_shape, + ): self.embeddings = self.add_weight( - name="embeddings", dtype=tf.float32, + name="embeddings", + dtype=tf.float32, shape=[self.vocab_size, self.embed_dim], initializer=self.initializer, - trainable=True, regularizer=self.regularizer, - constraint=self.contraint + trainable=True, + regularizer=self.regularizer, + constraint=self.contraint, ) self.built = True - def call(self, inputs): + def call( + self, + inputs, + ): outputs = tf.cast(tf.expand_dims(inputs, axis=-1), dtype=tf.int32) return tf.gather_nd(self.embeddings, outputs) def get_config(self): conf = super(Embedding, self).get_config() - conf.update({ - "vocab_size": self.vocab_size, - "embed_dim": self.embed_dim, - "contraint": self.contraint, - "regularizer": self.regularizer, - "initializer": self.initializer - }) + conf.update( + { + "vocab_size": self.vocab_size, + "embed_dim": self.embed_dim, + "contraint": self.contraint, + "regularizer": self.regularizer, + "initializer": self.initializer, + } + ) return conf diff --git a/tensorflow_asr/models/layers/multihead_attention.py b/tensorflow_asr/models/layers/multihead_attention.py index 3b3143b85d..4a78e204a9 100644 --- a/tensorflow_asr/models/layers/multihead_attention.py +++ b/tensorflow_asr/models/layers/multihead_attention.py @@ -13,24 +13,27 @@ # limitations under the License. import typing + import tensorflow as tf class MultiHeadAttention(tf.keras.layers.Layer): - def __init__(self, - num_heads, - head_size, - output_size: int = None, - dropout: float = 0.0, - use_projection_bias: bool = True, - return_attn_coef: bool = False, - kernel_initializer: typing.Union[str, typing.Callable] = "glorot_uniform", - kernel_regularizer: typing.Union[str, typing.Callable] = None, - kernel_constraint: typing.Union[str, typing.Callable] = None, - bias_initializer: typing.Union[str, typing.Callable] = "zeros", - bias_regularizer: typing.Union[str, typing.Callable] = None, - bias_constraint: typing.Union[str, typing.Callable] = None, - **kwargs): + def __init__( + self, + num_heads, + head_size, + output_size: int = None, + dropout: float = 0.0, + use_projection_bias: bool = True, + return_attn_coef: bool = False, + kernel_initializer: typing.Union[str, typing.Callable] = "glorot_uniform", + kernel_regularizer: typing.Union[str, typing.Callable] = None, + kernel_constraint: typing.Union[str, typing.Callable] = None, + bias_initializer: typing.Union[str, typing.Callable] = "zeros", + bias_regularizer: typing.Union[str, typing.Callable] = None, + bias_constraint: typing.Union[str, typing.Callable] = None, + **kwargs, + ): super(MultiHeadAttention, self).__init__(**kwargs) if output_size is not None and output_size < 1: @@ -52,15 +55,14 @@ def __init__(self, self.dropout = tf.keras.layers.Dropout(dropout, name="dropout") self._droput_rate = dropout - def build(self, input_shape): + def build( + self, + input_shape, + ): num_query_features = input_shape[0][-1] num_key_features = input_shape[1][-1] - num_value_features = ( - input_shape[2][-1] if len(input_shape) > 2 else num_key_features - ) - output_size = ( - self.output_size if self.output_size is not None else num_value_features - ) + num_value_features = input_shape[2][-1] if len(input_shape) > 2 else num_key_features + output_size = self.output_size if self.output_size is not None else num_value_features self.query_kernel = self.add_weight( name="query_kernel", shape=[self.num_heads, num_query_features, self.head_size], @@ -100,12 +102,17 @@ def build(self, input_shape): else: self.projection_bias = None - def call_qkv(self, query, key, value, training=False): + def call_qkv( + self, + query, + key, + value, + training=False, + ): # verify shapes if key.shape[-2] != value.shape[-2]: raise ValueError( - "the number of elements in 'key' must be equal to " - "the same as the number of elements in 'value'" + "the number of elements in 'key' must be equal to " "the same as the number of elements in 'value'" ) # Linear transformations query = tf.einsum("...NI,HIO->...NHO", query, self.query_kernel) @@ -114,20 +121,23 @@ def call_qkv(self, query, key, value, training=False): return query, key, value - def call_attention(self, query, key, value, logits, training=False, mask=None): + def call_attention( + self, + query, + key, + value, + logits, + training=False, + mask=None, + ): # mask = attention mask with shape [B, Tquery, Tkey] with 1 is for positions we want to attend, 0 for masked if mask is not None: if len(mask.shape) < 2: raise ValueError("'mask' must have at least 2 dimensions") if query.shape[-3] != mask.shape[-2]: - raise ValueError( - "mask's second to last dimension must be equal to " - "the number of elements in 'query'" - ) + raise ValueError("mask's second to last dimension must be equal to " "the number of elements in 'query'") if key.shape[-3] != mask.shape[-1]: - raise ValueError( - "mask's last dimension must be equal to the number of elements in 'key'" - ) + raise ValueError("mask's last dimension must be equal to the number of elements in 'key'") # apply mask if mask is not None: mask = tf.cast(mask, tf.float32) @@ -155,7 +165,13 @@ def call_attention(self, query, key, value, logits, training=False, mask=None): return output, attn_coef - def call(self, inputs, training=False, mask=None, **kwargs): + def call( + self, + inputs, + training=False, + mask=None, + **kwargs, + ): query, key, value = inputs query, key, value = self.call_qkv(query, key, value, training=training) @@ -168,21 +184,19 @@ def call(self, inputs, training=False, mask=None, **kwargs): # Calculate dot product attention logits = tf.einsum("...NHO,...MHO->...HNM", query, key) - output, attn_coef = self.call_attention(query, key, value, logits, - training=training, mask=mask) + output, attn_coef = self.call_attention(query, key, value, logits, training=training, mask=mask) if self.return_attn_coef: return output, attn_coef else: return output - def compute_output_shape(self, input_shape): - num_value_features = ( - input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1] - ) - output_size = ( - self.output_size if self.output_size is not None else num_value_features - ) + def compute_output_shape( + self, + input_shape, + ): + num_value_features = input_shape[2][-1] if len(input_shape) > 2 else input_shape[1][-1] + output_size = self.output_size if self.output_size is not None else num_value_features output_shape = input_shape[0][:-1] + (output_size,) @@ -221,28 +235,31 @@ def get_config(self): class RelPositionMultiHeadAttention(MultiHeadAttention): - def build(self, input_shape): + def build( + self, + input_shape, + ): num_pos_features = input_shape[-1][-1] self.pos_kernel = self.add_weight( name="pos_kernel", shape=[self.num_heads, num_pos_features, self.head_size], initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint + constraint=self.kernel_constraint, ) self.pos_bias_u = self.add_weight( name="pos_bias_u", shape=[self.num_heads, self.head_size], regularizer=self.kernel_regularizer, initializer=self.kernel_initializer, - constraint=self.kernel_constraint + constraint=self.kernel_constraint, ) self.pos_bias_v = self.add_weight( name="pos_bias_v", shape=[self.num_heads, self.head_size], regularizer=self.kernel_regularizer, initializer=self.kernel_initializer, - constraint=self.kernel_constraint + constraint=self.kernel_constraint, ) super(RelPositionMultiHeadAttention, self).build(input_shape[:-1]) @@ -254,7 +271,13 @@ def relative_shift(x): x = tf.reshape(x[:, :, 1:, :], x_shape) return x - def call(self, inputs, training=False, mask=None, **kwargs): + def call( + self, + inputs, + training=False, + mask=None, + **kwargs, + ): query, key, value, pos = inputs query, key, value = self.call_qkv(query, key, value, training=training) @@ -268,13 +291,12 @@ def call(self, inputs, training=False, mask=None, **kwargs): logits_with_v = tf.einsum("...NHO,...MHO->...HNM", query_with_v, pos) logits_with_v = self.relative_shift(logits_with_v) - logits = logits_with_u + logits_with_v[:, :, :, :tf.shape(logits_with_u)[3]] + logits = logits_with_u + logits_with_v[:, :, :, : tf.shape(logits_with_u)[3]] depth = tf.constant(self.head_size, dtype=tf.float32) logits /= tf.sqrt(depth) - output, attn_coef = self.call_attention(query, key, value, logits, - training=training, mask=mask) + output, attn_coef = self.call_attention(query, key, value, logits, training=training, mask=mask) if self.return_attn_coef: return output, attn_coef diff --git a/tensorflow_asr/models/layers/point_wise_ffn.py b/tensorflow_asr/models/layers/point_wise_ffn.py index a9e4820f4a..222f4ae8fe 100755 --- a/tensorflow_asr/models/layers/point_wise_ffn.py +++ b/tensorflow_asr/models/layers/point_wise_ffn.py @@ -16,13 +16,15 @@ class PointWiseFFN(tf.keras.layers.Layer): - def __init__(self, - size, - output_size, - activation="relu", - dropout=0.1, - name="point_wise_ffn", - **kwargs): + def __init__( + self, + size, + output_size, + activation="relu", + dropout=0.1, + name="point_wise_ffn", + **kwargs, + ): super(PointWiseFFN, self).__init__(name=name, **kwargs) self.ffn1 = tf.keras.layers.Dense(units=size, activation=activation) self.do1 = tf.keras.layers.Dropout(dropout) diff --git a/tensorflow_asr/models/layers/positional_encoding.py b/tensorflow_asr/models/layers/positional_encoding.py index bf108aa263..d6a6650f23 100755 --- a/tensorflow_asr/models/layers/positional_encoding.py +++ b/tensorflow_asr/models/layers/positional_encoding.py @@ -13,21 +13,34 @@ # limitations under the License. import tensorflow as tf + from ...utils.shape_util import shape_list class PositionalEncoding(tf.keras.layers.Layer): - def __init__(self, alpha: int = 1, beta: int = 0, name="positional_encoding", **kwargs): + def __init__( + self, + alpha: int = 1, + beta: int = 0, + name="positional_encoding", + **kwargs, + ): super().__init__(trainable=False, name=name, **kwargs) self.alpha = alpha self.beta = beta - def build(self, input_shape): + def build( + self, + input_shape, + ): dmodel = input_shape[-1] assert dmodel % 2 == 0, f"Input last dim must be even: {dmodel}" @staticmethod - def encode(max_len, dmodel): + def encode( + max_len, + dmodel, + ): pos = tf.expand_dims(tf.range(max_len - 1, -1, -1.0, dtype=tf.float32), axis=1) index = tf.expand_dims(tf.range(0, dmodel, dtype=tf.float32), axis=0) @@ -35,17 +48,19 @@ def encode(max_len, dmodel): # Sin cos will be [max_len, size // 2] # we add 0 between numbers by using padding and reshape - sin = tf.pad(tf.expand_dims(tf.sin(pe[:, 0::2]), -1), - [[0, 0], [0, 0], [0, 1]], mode="CONSTANT", constant_values=0) + sin = tf.pad(tf.expand_dims(tf.sin(pe[:, 0::2]), -1), [[0, 0], [0, 0], [0, 1]], mode="CONSTANT", constant_values=0) sin = tf.reshape(sin, [max_len, dmodel]) - cos = tf.pad(tf.expand_dims(tf.cos(pe[:, 1::2]), -1), - [[0, 0], [0, 0], [1, 0]], mode="CONSTANT", constant_values=0) + cos = tf.pad(tf.expand_dims(tf.cos(pe[:, 1::2]), -1), [[0, 0], [0, 0], [1, 0]], mode="CONSTANT", constant_values=0) cos = tf.reshape(cos, [max_len, dmodel]) # Then add sin and cos, which results in [time, size] pe = tf.add(sin, cos) return tf.expand_dims(pe, axis=0) # [1, time, size] - def call(self, inputs, **kwargs): + def call( + self, + inputs, + **kwargs, + ): # inputs shape [B, T, V] _, max_len, dmodel = shape_list(inputs) pe = self.encode(max_len * self.alpha + self.beta, dmodel) @@ -57,12 +72,18 @@ def get_config(self): class PositionalEncodingConcat(PositionalEncoding): - def build(self, input_shape): + def build( + self, + input_shape, + ): dmodel = input_shape[-1] assert dmodel % 2 == 0, f"Input last dim must be even: {dmodel}" @staticmethod - def encode(max_len, dmodel): + def encode( + max_len, + dmodel, + ): pos = tf.range(max_len - 1, -1, -1.0, dtype=tf.float32) index = tf.range(0, dmodel, 2.0, dtype=tf.float32) @@ -73,7 +94,11 @@ def encode(max_len, dmodel): return tf.expand_dims(pos, axis=0) - def call(self, inputs, **kwargs): + def call( + self, + inputs, + **kwargs, + ): # inputs shape [B, T, V] _, max_len, dmodel = shape_list(inputs) pe = self.encode(max_len * self.alpha + self.beta, dmodel) diff --git a/tensorflow_asr/models/layers/row_conv_1d.py b/tensorflow_asr/models/layers/row_conv_1d.py index d62b43f6c1..ef827259a4 100755 --- a/tensorflow_asr/models/layers/row_conv_1d.py +++ b/tensorflow_asr/models/layers/row_conv_1d.py @@ -12,69 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. import tensorflow as tf -from tensorflow.python.ops import nn_ops from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.ops import nn_ops class RowConv1D(tf.keras.layers.Conv1D): - def __init__(self, filters, future_context, **kwargs): + def __init__( + self, + filters, + future_context, + **kwargs, + ): assert future_context >= 0, "Future context must be positive" - super().__init__(filters=filters, - kernel_size=(future_context * 2 + 1), **kwargs) + super().__init__(filters=filters, kernel_size=(future_context * 2 + 1), **kwargs) self.future_context = future_context - def build(self, input_shape): + def build( + self, + input_shape, + ): input_shape = tf.TensorShape(input_shape) input_channel = self._get_input_channel(input_shape) kernel_shape = self.kernel_size + (input_channel, self.filters) self.kernel = self.add_weight( - name='kernel', + name="kernel", shape=kernel_shape, initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True, - dtype=self.dtype) + dtype=self.dtype, + ) # Add mask to remove weights on half of the kernel to the left # (only keep future # context) - left_kernel_dims = ( - self.future_context, input_channel, self.filters) + left_kernel_dims = (self.future_context, input_channel, self.filters) left_kernel = tf.fill(dims=left_kernel_dims, value=0) - right_kernel_dims = ( - self.future_context + 1, input_channel, self.filters) + right_kernel_dims = (self.future_context + 1, input_channel, self.filters) right_kernel = tf.fill(dims=right_kernel_dims, value=1) - mask_kernel = tf.cast( - tf.concat([left_kernel, right_kernel], axis=0), - dtype=self.dtype) + mask_kernel = tf.cast(tf.concat([left_kernel, right_kernel], axis=0), dtype=self.dtype) self.kernel = tf.multiply(self.kernel, mask_kernel) if self.use_bias: self.bias = self.add_weight( - name='bias', + name="bias", shape=(self.filters,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True, - dtype=self.dtype) + dtype=self.dtype, + ) else: self.bias = None channel_axis = self._get_channel_axis() - self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2, - axes={channel_axis: input_channel}) + self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2, axes={channel_axis: input_channel}) self.make_conv_op_input_shape = input_shape self.make_input_channel = input_channel self._padding_op = self._get_padding_op() - self._conv_op_data_format = conv_utils.convert_data_format( - self.data_format, self.rank + 2) + self._conv_op_data_format = conv_utils.convert_data_format(self.data_format, self.rank + 2) self._convolution_op = nn_ops.Convolution( input_shape, filter_shape=self.kernel.shape, dilation_rate=self.dilation_rate, strides=self.strides, padding=self._padding_op, - data_format=self._conv_op_data_format) + data_format=self._conv_op_data_format, + ) self.built = True diff --git a/tensorflow_asr/models/layers/sequence_wise_bn.py b/tensorflow_asr/models/layers/sequence_wise_bn.py index 0c2742445b..3a6a98e7e9 100644 --- a/tensorflow_asr/models/layers/sequence_wise_bn.py +++ b/tensorflow_asr/models/layers/sequence_wise_bn.py @@ -20,15 +20,32 @@ def __init__(self, name, time_major=False, **kwargs): super(SequenceBatchNorm, self).__init__(name=name, **kwargs) self.time_major = time_major - def build(self, input_shape): - self.beta = self.add_weight(shape=[input_shape[-1]], - name='beta', initializer='zeros', - regularizer=None, constraint=None, trainable=True) - self.gamma = self.add_weight(shape=[input_shape[-1]], - name='gamma', initializer='ones', - regularizer=None, constraint=None, trainable=True) + def build( + self, + input_shape, + ): + self.beta = self.add_weight( + shape=[input_shape[-1]], + name="beta", + initializer="zeros", + regularizer=None, + constraint=None, + trainable=True, + ) + self.gamma = self.add_weight( + shape=[input_shape[-1]], + name="gamma", + initializer="ones", + regularizer=None, + constraint=None, + trainable=True, + ) - def call(self, inputs, **kwargs): + def call( + self, + inputs, + **kwargs, + ): mean, variance = tf.nn.moments(inputs, axes=[0, 1], keepdims=False) if self.time_major: total_padded_frames = tf.cast(tf.shape(inputs)[0], tf.keras.backend.dtype(mean)) @@ -37,22 +54,22 @@ def call(self, inputs, **kwargs): total_padded_frames = tf.cast(tf.shape(inputs)[1], tf.keras.backend.dtype(mean)) batch_size = tf.cast(tf.shape(inputs)[0], tf.keras.backend.dtype(mean)) total_unpadded_frames_batch = tf.math.count_nonzero( - inputs, axis=[0, 1], keepdims=False, - dtype=tf.keras.backend.dtype(mean) + inputs, axis=[0, 1], keepdims=False, dtype=tf.keras.backend.dtype(mean) ) mean = (mean * total_padded_frames * batch_size) / total_unpadded_frames_batch variance = (variance * total_padded_frames * batch_size) / total_unpadded_frames_batch return tf.nn.batch_normalization( - inputs, mean=mean, variance=variance, - offset=self.beta, scale=self.gamma, - variance_epsilon=tf.keras.backend.epsilon() + inputs, + mean=mean, + variance=variance, + offset=self.beta, + scale=self.gamma, + variance_epsilon=tf.keras.backend.epsilon(), ) def get_config(self): config = super(SequenceBatchNorm, self).get_config() - config.update({ - "time_major": self.time_major - }) + config.update({"time_major": self.time_major}) return config def from_config(self, config): diff --git a/tensorflow_asr/models/layers/subsampling.py b/tensorflow_asr/models/layers/subsampling.py index 3e69f4dcdf..a68e05dc82 100644 --- a/tensorflow_asr/models/layers/subsampling.py +++ b/tensorflow_asr/models/layers/subsampling.py @@ -14,19 +14,31 @@ import tensorflow as tf -from ...utils import shape_util, math_util +from ...utils import math_util, shape_util class TimeReduction(tf.keras.layers.Layer): - def __init__(self, factor: int, name: str = "TimeReduction", **kwargs): + def __init__( + self, + factor: int, + name: str = "TimeReduction", + **kwargs, + ): super(TimeReduction, self).__init__(name=name, **kwargs) self.time_reduction_factor = factor - def padding(self, time): + def padding( + self, + time, + ): new_time = tf.math.ceil(time / self.time_reduction_factor) * self.time_reduction_factor return tf.cast(new_time, dtype=tf.int32) - time - def call(self, inputs, **kwargs): + def call( + self, + inputs, + **kwargs, + ): shape = shape_util.shape_list(inputs) outputs = tf.pad(inputs, [[0, 0], [0, self.padding(shape[1])], [0, 0]]) outputs = tf.reshape(outputs, [shape[0], -1, shape[-1] * self.time_reduction_factor]) @@ -39,50 +51,63 @@ def get_config(self): class VggSubsampling(tf.keras.layers.Layer): - def __init__(self, - filters: tuple or list = (32, 64), - kernel_size: int or list or tuple = 3, - strides: int or list or tuple = 2, - kernel_regularizer=None, - bias_regularizer=None, - name="VggSubsampling", - **kwargs): + def __init__( + self, + filters: tuple or list = (32, 64), + kernel_size: int or list or tuple = 3, + strides: int or list or tuple = 2, + kernel_regularizer=None, + bias_regularizer=None, + name="VggSubsampling", + **kwargs, + ): super(VggSubsampling, self).__init__(name=name, **kwargs) self.conv1 = tf.keras.layers.Conv2D( - filters=filters[0], kernel_size=kernel_size, strides=1, - padding="same", name=f"{name}_conv_1", + filters=filters[0], + kernel_size=kernel_size, + strides=1, + padding="same", + name=f"{name}_conv_1", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.conv2 = tf.keras.layers.Conv2D( - filters=filters[0], kernel_size=kernel_size, strides=1, - padding="same", name=f"{name}_conv_2", + filters=filters[0], + kernel_size=kernel_size, + strides=1, + padding="same", + name=f"{name}_conv_2", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer - ) - self.maxpool1 = tf.keras.layers.MaxPool2D( - pool_size=strides, - padding="same", name=f"{name}_maxpool_1" + bias_regularizer=bias_regularizer, ) + self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=strides, padding="same", name=f"{name}_maxpool_1") self.conv3 = tf.keras.layers.Conv2D( - filters=filters[1], kernel_size=kernel_size, strides=1, - padding="same", name=f"{name}_conv_3", + filters=filters[1], + kernel_size=kernel_size, + strides=1, + padding="same", + name=f"{name}_conv_3", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.conv4 = tf.keras.layers.Conv2D( - filters=filters[1], kernel_size=kernel_size, strides=1, - padding="same", name=f"{name}_conv_4", + filters=filters[1], + kernel_size=kernel_size, + strides=1, + padding="same", + name=f"{name}_conv_4", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer - ) - self.maxpool2 = tf.keras.layers.MaxPool2D( - pool_size=strides, - padding="same", name=f"{name}_maxpool_2" + bias_regularizer=bias_regularizer, ) + self.maxpool2 = tf.keras.layers.MaxPool2D(pool_size=strides, padding="same", name=f"{name}_maxpool_2") self.time_reduction_factor = self.maxpool1.pool_size[0] + self.maxpool2.pool_size[0] - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.conv1(inputs, training=training) outputs = tf.nn.relu(outputs) outputs = self.conv2(outputs, training=training) @@ -97,7 +122,9 @@ def call(self, inputs, training=False, **kwargs): return math_util.merge_two_last_dims(outputs) - def get_config(self): + def get_config( + self, + ): conf = super(VggSubsampling, self).get_config() conf.update(self.conv1.get_config()) conf.update(self.conv2.get_config()) @@ -109,30 +136,43 @@ def get_config(self): class Conv2dSubsampling(tf.keras.layers.Layer): - def __init__(self, - filters: int, - strides: list or tuple or int = 2, - kernel_size: int or list or tuple = 3, - kernel_regularizer=None, - bias_regularizer=None, - name="Conv2dSubsampling", - **kwargs): + def __init__( + self, + filters: int, + strides: list or tuple or int = 2, + kernel_size: int or list or tuple = 3, + kernel_regularizer=None, + bias_regularizer=None, + name="Conv2dSubsampling", + **kwargs, + ): super(Conv2dSubsampling, self).__init__(name=name, **kwargs) self.conv1 = tf.keras.layers.Conv2D( - filters=filters, kernel_size=kernel_size, - strides=strides, padding="same", name=f"{name}_1", + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding="same", + name=f"{name}_1", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.conv2 = tf.keras.layers.Conv2D( - filters=filters, kernel_size=kernel_size, - strides=strides, padding="same", name=f"{name}_2", + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding="same", + name=f"{name}_2", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) self.time_reduction_factor = self.conv1.strides[0] + self.conv2.strides[0] - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.conv1(inputs, training=training) outputs = tf.nn.relu(outputs) outputs = self.conv2(outputs, training=training) diff --git a/tensorflow_asr/models/transducer/conformer.py b/tensorflow_asr/models/transducer/conformer.py index b5d151e266..5a068c16a1 100644 --- a/tensorflow_asr/models/transducer/conformer.py +++ b/tensorflow_asr/models/transducer/conformer.py @@ -13,44 +13,46 @@ # limitations under the License. -from ..encoders.conformer import ConformerEncoder, L2 +from ..encoders.conformer import L2, ConformerEncoder from .transducer import Transducer class Conformer(Transducer): - def __init__(self, - vocabulary_size: int, - encoder_subsampling: dict, - encoder_positional_encoding: str = "sinusoid", - encoder_dmodel: int = 144, - encoder_num_blocks: int = 16, - encoder_head_size: int = 36, - encoder_num_heads: int = 4, - encoder_mha_type: str = "relmha", - encoder_kernel_size: int = 32, - encoder_depth_multiplier: int = 1, - encoder_fc_factor: float = 0.5, - encoder_dropout: float = 0, - encoder_trainable: bool = True, - prediction_embed_dim: int = 512, - prediction_embed_dropout: int = 0, - prediction_num_rnns: int = 1, - prediction_rnn_units: int = 320, - prediction_rnn_type: str = "lstm", - prediction_rnn_implementation: int = 2, - prediction_layer_norm: bool = True, - prediction_projection_units: int = 0, - prediction_trainable: bool = True, - joint_dim: int = 1024, - joint_activation: str = "tanh", - prejoint_linear: bool = True, - postjoint_linear: bool = False, - joint_mode: str = "add", - joint_trainable: bool = True, - kernel_regularizer=L2, - bias_regularizer=L2, - name: str = "conformer", - **kwargs): + def __init__( + self, + vocabulary_size: int, + encoder_subsampling: dict, + encoder_positional_encoding: str = "sinusoid", + encoder_dmodel: int = 144, + encoder_num_blocks: int = 16, + encoder_head_size: int = 36, + encoder_num_heads: int = 4, + encoder_mha_type: str = "relmha", + encoder_kernel_size: int = 32, + encoder_depth_multiplier: int = 1, + encoder_fc_factor: float = 0.5, + encoder_dropout: float = 0, + encoder_trainable: bool = True, + prediction_embed_dim: int = 512, + prediction_embed_dropout: int = 0, + prediction_num_rnns: int = 1, + prediction_rnn_units: int = 320, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_layer_norm: bool = True, + prediction_projection_units: int = 0, + prediction_trainable: bool = True, + joint_dim: int = 1024, + joint_activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=L2, + bias_regularizer=L2, + name: str = "conformer", + **kwargs, + ): super(Conformer, self).__init__( encoder=ConformerEncoder( subsampling=encoder_subsampling, @@ -67,7 +69,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, - name=f"{name}_encoder" + name=f"{name}_encoder", ), vocabulary_size=vocabulary_size, embed_dim=prediction_embed_dim, @@ -88,7 +90,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name=name, - **kwargs + **kwargs, ) self.dmodel = encoder_dmodel self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor diff --git a/tensorflow_asr/models/transducer/contextnet.py b/tensorflow_asr/models/transducer/contextnet.py index 134a06a8f2..15e7e601d4 100644 --- a/tensorflow_asr/models/transducer/contextnet.py +++ b/tensorflow_asr/models/transducer/contextnet.py @@ -13,38 +13,41 @@ # limitations under the License. from typing import Dict, List + import tensorflow as tf -from ..encoders.contextnet import ContextNetEncoder, L2 +from ...utils import data_util, math_util +from ..encoders.contextnet import L2, ContextNetEncoder from .transducer import Transducer -from ...utils import math_util, data_util class ContextNet(Transducer): - def __init__(self, - vocabulary_size: int, - encoder_blocks: List[dict], - encoder_alpha: float = 0.5, - encoder_trainable: bool = True, - prediction_embed_dim: int = 512, - prediction_embed_dropout: int = 0, - prediction_num_rnns: int = 1, - prediction_rnn_units: int = 320, - prediction_rnn_type: str = "lstm", - prediction_rnn_implementation: int = 2, - prediction_layer_norm: bool = True, - prediction_projection_units: int = 0, - prediction_trainable: bool = True, - joint_dim: int = 1024, - joint_activation: str = "tanh", - prejoint_linear: bool = True, - postjoint_linear: bool = False, - joint_mode: str = "add", - joint_trainable: bool = True, - kernel_regularizer=L2, - bias_regularizer=L2, - name: str = "contextnet", - **kwargs): + def __init__( + self, + vocabulary_size: int, + encoder_blocks: List[dict], + encoder_alpha: float = 0.5, + encoder_trainable: bool = True, + prediction_embed_dim: int = 512, + prediction_embed_dropout: int = 0, + prediction_num_rnns: int = 1, + prediction_rnn_units: int = 320, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_layer_norm: bool = True, + prediction_projection_units: int = 0, + prediction_trainable: bool = True, + joint_dim: int = 1024, + joint_activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=L2, + bias_regularizer=L2, + name: str = "contextnet", + **kwargs, + ): super(ContextNet, self).__init__( encoder=ContextNetEncoder( blocks=encoder_blocks, @@ -52,7 +55,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, - name=f"{name}_encoder" + name=f"{name}_encoder", ), vocabulary_size=vocabulary_size, embed_dim=prediction_embed_dim, @@ -73,22 +76,31 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name=name, - **kwargs + **kwargs, ) self.dmodel = self.encoder.blocks[-1].dmodel self.time_reduction_factor = 1 - for block in self.encoder.blocks: self.time_reduction_factor *= block.time_reduction_factor - - def call(self, inputs, training=False, **kwargs): + for block in self.encoder.blocks: + self.time_reduction_factor *= block.time_reduction_factor + + def call( + self, + inputs, + training=False, + **kwargs, + ): enc = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=training, **kwargs) pred = self.predict_net([inputs["predictions"], inputs["predictions_length"]], training=training, **kwargs) logits = self.joint_net([enc, pred], training=training, **kwargs) return data_util.create_logits( - logits=logits, - logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) + logits=logits, logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) ) - def encoder_inference(self, features: tf.Tensor, input_length: tf.Tensor): + def encoder_inference( + self, + features: tf.Tensor, + input_length: tf.Tensor, + ): with tf.name_scope(f"{self.name}_encoder"): input_length = tf.expand_dims(tf.shape(features)[0], axis=0) outputs = tf.expand_dims(features, axis=0) @@ -98,7 +110,10 @@ def encoder_inference(self, features: tf.Tensor, input_length: tf.Tensor): # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, inputs: Dict[str, tf.Tensor]): + def recognize( + self, + inputs: Dict[str, tf.Tensor], + ): """ RNN Transducer Greedy decoding Args: @@ -111,7 +126,12 @@ def recognize(self, inputs: Dict[str, tf.Tensor]): encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length) - def recognize_tflite(self, signal, predicted, prediction_states): + def recognize_tflite( + self, + signal, + predicted, + prediction_states, + ): """ Function to convert to tflite using greedy decoding (default streaming mode) Args: @@ -131,7 +151,12 @@ def recognize_tflite(self, signal, predicted, prediction_states): transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) return transcript, hypothesis.index, hypothesis.states - def recognize_tflite_with_timestamp(self, signal, predicted, states): + def recognize_tflite_with_timestamp( + self, + signal, + predicted, + states, + ): features = self.speech_featurizer.tf_extract(signal) encoded = self.encoder_inference(features, tf.shape(features)[0]) hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states) @@ -157,7 +182,11 @@ def recognize_tflite_with_timestamp(self, signal, predicted, states): # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): + def recognize_beam( + self, + inputs: Dict[str, tf.Tensor], + lm: bool = False, + ): """ RNN Transducer Beam Search Args: diff --git a/tensorflow_asr/models/transducer/rnn_transducer.py b/tensorflow_asr/models/transducer/rnn_transducer.py index f1456bc175..72c23cac55 100644 --- a/tensorflow_asr/models/transducer/rnn_transducer.py +++ b/tensorflow_asr/models/transducer/rnn_transducer.py @@ -14,27 +14,31 @@ """ http://arxiv.org/abs/1811.06621 """ from typing import Dict + import tensorflow as tf +from ...utils import layer_util, math_util, shape_util from ..layers.subsampling import TimeReduction from .transducer import Transducer -from ...utils import layer_util, math_util, shape_util class Reshape(tf.keras.layers.Layer): - def call(self, inputs): return math_util.merge_two_last_dims(inputs) + def call(self, inputs): + return math_util.merge_two_last_dims(inputs) class RnnTransducerBlock(tf.keras.Model): - def __init__(self, - reduction_factor: int = 0, - dmodel: int = 640, - rnn_type: str = "lstm", - rnn_units: int = 2048, - layer_norm: bool = True, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): + def __init__( + self, + reduction_factor: int = 0, + dmodel: int = 640, + rnn_type: str = "lstm", + rnn_units: int = 2048, + layer_norm: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super().__init__(**kwargs) if reduction_factor > 0: @@ -44,10 +48,12 @@ def __init__(self, RNN = layer_util.get_rnn(rnn_type) self.rnn = RNN( - units=rnn_units, return_sequences=True, - name=f"{self.name}_{rnn_type}", return_state=True, + units=rnn_units, + return_sequences=True, + name=f"{self.name}_{rnn_type}", + return_state=True, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) if layer_norm: @@ -56,12 +62,18 @@ def __init__(self, self.ln = None self.projection = tf.keras.layers.Dense( - dmodel, name=f"{self.name}_projection", + dmodel, + name=f"{self.name}_projection", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = inputs if self.reduction is not None: outputs = self.reduction(outputs) @@ -72,7 +84,11 @@ def call(self, inputs, training=False, **kwargs): outputs = self.projection(outputs, training=training) return outputs - def recognize(self, inputs, states): + def recognize( + self, + inputs, + states, + ): outputs = inputs if self.reduction is not None: outputs = self.reduction(outputs) @@ -96,16 +112,18 @@ def get_config(self): class RnnTransducerEncoder(tf.keras.Model): - def __init__(self, - reductions: dict = {0: 3, 1: 2}, - dmodel: int = 640, - nlayers: int = 8, - rnn_type: str = "lstm", - rnn_units: int = 2048, - layer_norm: bool = True, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs): + def __init__( + self, + reductions: dict = {0: 3, 1: 2}, + dmodel: int = 640, + nlayers: int = 8, + rnn_type: str = "lstm", + rnn_units: int = 2048, + layer_norm: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): super().__init__(**kwargs) self.reshape = Reshape(name=f"{self.name}_reshape") @@ -119,16 +137,21 @@ def __init__(self, layer_norm=layer_norm, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - name=f"{self.name}_block_{i}" - ) for i in range(nlayers) + name=f"{self.name}_block_{i}", + ) + for i in range(nlayers) ] self.time_reduction_factor = 1 for i in range(nlayers): reduction_factor = reductions.get(i, 0) - if reduction_factor > 0: self.time_reduction_factor *= reduction_factor + if reduction_factor > 0: + self.time_reduction_factor *= reduction_factor - def get_initial_state(self, batch_size=1): + def get_initial_state( + self, + batch_size=1, + ): """Get zeros states Returns: @@ -136,22 +159,25 @@ def get_initial_state(self, batch_size=1): """ states = [] for block in self.blocks: - states.append( - tf.stack( - block.rnn.get_initial_state( - tf.zeros([batch_size, 1, 1], dtype=tf.float32) - ), axis=0 - ) - ) + states.append(tf.stack(block.rnn.get_initial_state(tf.zeros([batch_size, 1, 1], dtype=tf.float32)), axis=0)) return tf.stack(states, axis=0) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): outputs = self.reshape(inputs) for block in self.blocks: outputs = block(outputs, training=training, **kwargs) return outputs - def recognize(self, inputs, states): + def recognize( + self, + inputs, + states, + ): """Recognize function for encoder network Args: @@ -173,38 +199,41 @@ def get_config(self): conf = self.reshape.get_config() if self.fnorm is not None: conf.update(self.fnorm.get_config()) - for block in self.blocks: conf.update(block.get_config()) + for block in self.blocks: + conf.update(block.get_config()) return conf class RnnTransducer(Transducer): - def __init__(self, - vocabulary_size: int, - encoder_reductions: dict = {0: 3, 1: 2}, - encoder_dmodel: int = 640, - encoder_nlayers: int = 8, - encoder_rnn_type: str = "lstm", - encoder_rnn_units: int = 2048, - encoder_layer_norm: bool = True, - encoder_trainable: bool = True, - prediction_embed_dim: int = 320, - prediction_embed_dropout: float = 0, - prediction_num_rnns: int = 2, - prediction_rnn_units: int = 2048, - prediction_rnn_type: str = "lstm", - prediction_layer_norm: bool = True, - prediction_projection_units: int = 640, - prediction_trainable: bool = True, - joint_dim: int = 640, - joint_activation: str = "tanh", - prejoint_linear: bool = True, - postjoint_linear: bool = False, - joint_mode: str = "add", - joint_trainable: bool = True, - kernel_regularizer = None, - bias_regularizer = None, - name = "RnnTransducer", - **kwargs): + def __init__( + self, + vocabulary_size: int, + encoder_reductions: dict = {0: 3, 1: 2}, + encoder_dmodel: int = 640, + encoder_nlayers: int = 8, + encoder_rnn_type: str = "lstm", + encoder_rnn_units: int = 2048, + encoder_layer_norm: bool = True, + encoder_trainable: bool = True, + prediction_embed_dim: int = 320, + prediction_embed_dropout: float = 0, + prediction_num_rnns: int = 2, + prediction_rnn_units: int = 2048, + prediction_rnn_type: str = "lstm", + prediction_layer_norm: bool = True, + prediction_projection_units: int = 640, + prediction_trainable: bool = True, + joint_dim: int = 640, + joint_activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + name="RnnTransducer", + **kwargs, + ): super().__init__( encoder=RnnTransducerEncoder( reductions=encoder_reductions, @@ -216,7 +245,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, - name=f"{name}_encoder" + name=f"{name}_encoder", ), vocabulary_size=vocabulary_size, embed_dim=prediction_embed_dim, @@ -236,12 +265,16 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name=name, - **kwargs + **kwargs, ) self.time_reduction_factor = self.encoder.time_reduction_factor self.dmodel = encoder_dmodel - def encoder_inference(self, features: tf.Tensor, states: tf.Tensor): + def encoder_inference( + self, + features: tf.Tensor, + states: tf.Tensor, + ): """Infer function for encoder (or encoders) Args: @@ -260,7 +293,10 @@ def encoder_inference(self, features: tf.Tensor, states: tf.Tensor): # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, inputs: Dict[str, tf.Tensor]): + def recognize( + self, + inputs: Dict[str, tf.Tensor], + ): """ RNN Transducer Greedy decoding Args: @@ -274,7 +310,13 @@ def recognize(self, inputs: Dict[str, tf.Tensor]): encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length) - def recognize_tflite(self, signal, predicted, encoder_states, prediction_states): + def recognize_tflite( + self, + signal, + predicted, + encoder_states, + prediction_states, + ): """ Function to convert to tflite using greedy decoding (default streaming mode) Args: @@ -295,7 +337,13 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states) transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) return transcript, hypothesis.index, new_encoder_states, hypothesis.states - def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states): + def recognize_tflite_with_timestamp( + self, + signal, + predicted, + encoder_states, + prediction_states, + ): features = self.speech_featurizer.tf_extract(signal) encoded, new_encoder_states = self.encoder_inference(features, encoder_states) hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) @@ -321,7 +369,11 @@ def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, pre # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): + def recognize_beam( + self, + inputs: Dict[str, tf.Tensor], + lm: bool = False, + ): """ RNN Transducer Beam Search Args: @@ -338,7 +390,10 @@ def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): # -------------------------------- TFLITE ------------------------------------- - def make_tflite_function(self, timestamp: bool = True): + def make_tflite_function( + self, + timestamp: bool = True, + ): tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite return tf.function( tflite_func, @@ -346,6 +401,6 @@ def make_tflite_function(self, timestamp: bool = True): tf.TensorSpec([None], dtype=tf.float32), tf.TensorSpec([], dtype=tf.int32), tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32), - tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32) - ] + tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32), + ], ) diff --git a/tensorflow_asr/models/transducer/transducer.py b/tensorflow_asr/models/transducer/transducer.py index a55163f22d..fa672bc511 100644 --- a/tensorflow_asr/models/transducer/transducer.py +++ b/tensorflow_asr/models/transducer/transducer.py @@ -15,14 +15,15 @@ import collections from typing import Dict + import tensorflow as tf -from ..base_model import BaseModel -from ...utils import math_util, layer_util, shape_util, data_util from ...featurizers.speech_featurizers import SpeechFeaturizer from ...featurizers.text_featurizers import TextFeaturizer -from ..layers.embedding import Embedding from ...losses.rnnt_loss import RnntLoss +from ...utils import data_util, layer_util, math_util, shape_util +from ..base_model import BaseModel +from ..layers.embedding import Embedding Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states")) @@ -30,34 +31,37 @@ class TransducerPrediction(tf.keras.Model): - def __init__(self, - vocabulary_size: int, - embed_dim: int, - embed_dropout: float = 0, - num_rnns: int = 1, - rnn_units: int = 512, - rnn_type: str = "lstm", - rnn_implementation: int = 2, - layer_norm: bool = True, - projection_units: int = 0, - kernel_regularizer=None, - bias_regularizer=None, - name="transducer_prediction", - **kwargs): + def __init__( + self, + vocabulary_size: int, + embed_dim: int, + embed_dropout: float = 0, + num_rnns: int = 1, + rnn_units: int = 512, + rnn_type: str = "lstm", + rnn_implementation: int = 2, + layer_norm: bool = True, + projection_units: int = 0, + kernel_regularizer=None, + bias_regularizer=None, + name="transducer_prediction", + **kwargs, + ): super().__init__(name=name, **kwargs) - self.embed = Embedding(vocabulary_size, embed_dim, - regularizer=kernel_regularizer, name=f"{name}_embedding") + self.embed = Embedding(vocabulary_size, embed_dim, regularizer=kernel_regularizer, name=f"{name}_embedding") self.do = tf.keras.layers.Dropout(embed_dropout, name=f"{name}_dropout") # Initialize rnn layers RNN = layer_util.get_rnn(rnn_type) self.rnns = [] for i in range(num_rnns): rnn = RNN( - units=rnn_units, return_sequences=True, - name=f"{name}_{rnn_type}_{i}", return_state=True, + units=rnn_units, + return_sequences=True, + name=f"{name}_{rnn_type}_{i}", + return_state=True, implementation=rnn_implementation, kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) if layer_norm: ln = tf.keras.layers.LayerNormalization(name=f"{name}_ln_{i}") @@ -68,7 +72,7 @@ def __init__(self, projection_units, name=f"{name}_projection_{i}", kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + bias_regularizer=bias_regularizer, ) else: projection = None @@ -82,13 +86,7 @@ def get_initial_state(self): """ states = [] for rnn in self.rnns: - states.append( - tf.stack( - rnn["rnn"].get_initial_state( - tf.zeros([1, 1, 1], dtype=tf.float32) - ), axis=0 - ) - ) + states.append(tf.stack(rnn["rnn"].get_initial_state(tf.zeros([1, 1, 1], dtype=tf.float32)), axis=0)) return tf.stack(states, axis=0) def call(self, inputs, training=False, **kwargs): @@ -144,10 +142,7 @@ def get_config(self): class TransducerJointReshape(tf.keras.layers.Layer): - def __init__(self, - axis: int = 1, - name="transducer_joint_reshape", - **kwargs): + def __init__(self, axis: int = 1, name="transducer_joint_reshape", **kwargs): super().__init__(name=name, trainable=False, **kwargs) self.axis = axis @@ -162,17 +157,19 @@ def get_config(self): class TransducerJoint(tf.keras.Model): - def __init__(self, - vocabulary_size: int, - joint_dim: int = 1024, - activation: str = "tanh", - prejoint_linear: bool = True, - postjoint_linear: bool = False, - joint_mode: str = "add", - kernel_regularizer=None, - bias_regularizer=None, - name="tranducer_joint", - **kwargs): + def __init__( + self, + vocabulary_size: int, + joint_dim: int = 1024, + activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + kernel_regularizer=None, + bias_regularizer=None, + name="tranducer_joint", + **kwargs, + ): super().__init__(name=name, **kwargs) activation = activation.lower() @@ -190,13 +187,10 @@ def __init__(self, if self.prejoint_linear: self.ffn_enc = tf.keras.layers.Dense( - joint_dim, name=f"{name}_enc", - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + joint_dim, name=f"{name}_enc", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer ) self.ffn_pred = tf.keras.layers.Dense( - joint_dim, use_bias=False, name=f"{name}_pred", - kernel_regularizer=kernel_regularizer + joint_dim, use_bias=False, name=f"{name}_pred", kernel_regularizer=kernel_regularizer ) self.enc_reshape = TransducerJointReshape(axis=2, name=f"{name}_enc_reshape") @@ -211,15 +205,11 @@ def __init__(self, if self.postjoint_linear: self.ffn = tf.keras.layers.Dense( - joint_dim, name=f"{name}_ffn", - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + joint_dim, name=f"{name}_ffn", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer ) self.ffn_out = tf.keras.layers.Dense( - vocabulary_size, name=f"{name}_vocab", - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer + vocabulary_size, name=f"{name}_vocab", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer ) def call(self, inputs, training=False, **kwargs): @@ -249,30 +239,32 @@ def get_config(self): class Transducer(BaseModel): - """ Transducer Model Warper """ - - def __init__(self, - encoder: tf.keras.Model, - vocabulary_size: int, - embed_dim: int = 512, - embed_dropout: float = 0, - num_rnns: int = 1, - rnn_units: int = 320, - rnn_type: str = "lstm", - rnn_implementation: int = 2, - layer_norm: bool = True, - projection_units: int = 0, - prediction_trainable: bool = True, - joint_dim: int = 1024, - joint_activation: str = "tanh", - prejoint_linear: bool = True, - postjoint_linear: bool = False, - joint_mode: str = "add", - joint_trainable: bool = True, - kernel_regularizer=None, - bias_regularizer=None, - name="transducer", - **kwargs): + """Transducer Model Warper""" + + def __init__( + self, + encoder: tf.keras.Model, + vocabulary_size: int, + embed_dim: int = 512, + embed_dropout: float = 0, + num_rnns: int = 1, + rnn_units: int = 320, + rnn_type: str = "lstm", + rnn_implementation: int = 2, + layer_norm: bool = True, + projection_units: int = 0, + prediction_trainable: bool = True, + joint_dim: int = 1024, + joint_activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + name="transducer", + **kwargs, + ): super().__init__(name=name, **kwargs) self.encoder = encoder self.predict_net = TransducerPrediction( @@ -288,7 +280,7 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=prediction_trainable, - name=f"{name}_prediction" + name=f"{name}_prediction", ) self.joint_net = TransducerJoint( vocabulary_size=vocabulary_size, @@ -300,11 +292,16 @@ def __init__(self, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=joint_trainable, - name=f"{name}_joint" + name=f"{name}_joint", ) self.time_reduction_factor = 1 - def make(self, input_shape, prediction_shape=[None], batch_size=None): + def make( + self, + input_shape, + prediction_shape=[None], + batch_size=None, + ): inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) @@ -314,20 +311,27 @@ def make(self, input_shape, prediction_shape=[None], batch_size=None): inputs=inputs, inputs_length=inputs_length, predictions=predictions, - predictions_length=predictions_length + predictions_length=predictions_length, ), - training=False + training=False, ) - def summary(self, line_length=None, **kwargs): - if self.encoder is not None: self.encoder.summary(line_length=line_length, **kwargs) + def summary( + self, + line_length=None, + **kwargs, + ): + if self.encoder is not None: + self.encoder.summary(line_length=line_length, **kwargs) self.predict_net.summary(line_length=line_length, **kwargs) self.joint_net.summary(line_length=line_length, **kwargs) super(Transducer, self).summary(line_length=line_length, **kwargs) - def add_featurizers(self, - speech_featurizer: SpeechFeaturizer, - text_featurizer: TextFeaturizer): + def add_featurizers( + self, + speech_featurizer: SpeechFeaturizer, + text_featurizer: TextFeaturizer, + ): """ Function to add featurizer to model to convert to end2end tflite Args: @@ -338,27 +342,37 @@ def add_featurizers(self, self.speech_featurizer = speech_featurizer self.text_featurizer = text_featurizer - def compile(self, - optimizer, - global_batch_size, - blank=0, - run_eagerly=None, - **kwargs): + def compile( + self, + optimizer, + global_batch_size, + blank=0, + run_eagerly=None, + **kwargs, + ): loss = RnntLoss(blank=blank, global_batch_size=global_batch_size) super().compile(loss=loss, optimizer=optimizer, run_eagerly=run_eagerly, **kwargs) - def call(self, inputs, training=False, **kwargs): + def call( + self, + inputs, + training=False, + **kwargs, + ): enc = self.encoder(inputs["inputs"], training=training, **kwargs) pred = self.predict_net([inputs["predictions"], inputs["predictions_length"]], training=training, **kwargs) logits = self.joint_net([enc, pred], training=training, **kwargs) return data_util.create_logits( logits=logits, - logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) + logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor), ) # -------------------------------- INFERENCES ------------------------------------- - def encoder_inference(self, features: tf.Tensor): + def encoder_inference( + self, + features: tf.Tensor, + ): """Infer function for encoder (or encoders) Args: @@ -372,7 +386,12 @@ def encoder_inference(self, features: tf.Tensor): outputs = self.encoder(outputs, training=False) return tf.squeeze(outputs, axis=0) - def decoder_inference(self, encoded: tf.Tensor, predicted: tf.Tensor, states: tf.Tensor): + def decoder_inference( + self, + encoded: tf.Tensor, + predicted: tf.Tensor, + states: tf.Tensor, + ): """Infer function for decoder Args: @@ -400,7 +419,10 @@ def get_config(self): # -------------------------------- GREEDY ------------------------------------- @tf.function - def recognize(self, inputs: Dict[str, tf.Tensor]): + def recognize( + self, + inputs: Dict[str, tf.Tensor], + ): """ RNN Transducer Greedy decoding Args: @@ -414,7 +436,12 @@ def recognize(self, inputs: Dict[str, tf.Tensor]): encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length) - def recognize_tflite(self, signal, predicted, states): + def recognize_tflite( + self, + signal, + predicted, + states, + ): """ Function to convert to tflite using greedy decoding (default streaming mode) Args: @@ -433,7 +460,12 @@ def recognize_tflite(self, signal, predicted, states): transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) return transcript, hypothesis.index, hypothesis.states - def recognize_tflite_with_timestamp(self, signal, predicted, states): + def recognize_tflite_with_timestamp( + self, + signal, + predicted, + states, + ): features = self.speech_featurizer.tf_extract(signal) encoded = self.encoder_inference(features) hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states) @@ -456,21 +488,27 @@ def recognize_tflite_with_timestamp(self, signal, predicted, states): return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states - def _perform_greedy_batch(self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - parallel_iterations: int = 10, - swap_memory: bool = False): + def _perform_greedy_batch( + self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = False, + ): with tf.name_scope(f"{self.name}_perform_greedy_batch"): total_batch = tf.shape(encoded)[0] batch = tf.constant(0, dtype=tf.int32) decoded = tf.TensorArray( - dtype=tf.int32, size=total_batch, dynamic_size=False, - clear_after_read=False, element_shape=tf.TensorShape([None]) + dtype=tf.int32, + size=total_batch, + dynamic_size=False, + clear_after_read=False, + element_shape=tf.TensorShape([None]), ) - def condition(batch, _): return tf.less(batch, total_batch) + def condition(batch, _): + return tf.less(batch, total_batch) def body(batch, decoded): hypothesis = self._perform_greedy( @@ -479,26 +517,31 @@ def body(batch, decoded): predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32), states=self.predict_net.get_initial_state(), parallel_iterations=parallel_iterations, - swap_memory=swap_memory + swap_memory=swap_memory, ) decoded = decoded.write(batch, hypothesis.prediction) return batch + 1, decoded batch, decoded = tf.while_loop( - condition, body, loop_vars=[batch, decoded], - parallel_iterations=parallel_iterations, swap_memory=True, + condition, + body, + loop_vars=[batch, decoded], + parallel_iterations=parallel_iterations, + swap_memory=True, ) decoded = math_util.pad_prediction_tfarray(decoded, blank=self.text_featurizer.blank) return self.text_featurizer.iextract(decoded.stack()) - def _perform_greedy(self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - predicted: tf.Tensor, - states: tf.Tensor, - parallel_iterations: int = 10, - swap_memory: bool = False): + def _perform_greedy( + self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + predicted: tf.Tensor, + states: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = False, + ): with tf.name_scope(f"{self.name}_greedy"): time = tf.constant(0, dtype=tf.int32) total = encoded_length @@ -506,20 +549,24 @@ def _perform_greedy(self, hypothesis = Hypothesis( index=predicted, prediction=tf.TensorArray( - dtype=tf.int32, size=total, dynamic_size=False, - clear_after_read=False, element_shape=tf.TensorShape([]) + dtype=tf.int32, + size=total, + dynamic_size=False, + clear_after_read=False, + element_shape=tf.TensorShape([]), ), - states=states + states=states, ) - def condition(_time, _hypothesis): return tf.less(_time, total) + def condition(_time, _hypothesis): + return tf.less(_time, total) def body(_time, _hypothesis): ytu, _states = self.decoder_inference( # avoid using [index] in tflite encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])), predicted=_hypothesis.index, - states=_hypothesis.states + states=_hypothesis.states, ) _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] @@ -538,22 +585,29 @@ def body(_time, _hypothesis): return _time + 1, _hypothesis time, hypothesis = tf.while_loop( - condition, body, + condition, + body, loop_vars=[time, hypothesis], parallel_iterations=parallel_iterations, - swap_memory=swap_memory + swap_memory=swap_memory, ) - return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states) + return Hypothesis( + index=hypothesis.index, + prediction=hypothesis.prediction.stack(), + states=hypothesis.states, + ) - def _perform_greedy_v2(self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - predicted: tf.Tensor, - states: tf.Tensor, - parallel_iterations: int = 10, - swap_memory: bool = False): - """ Ref: https://arxiv.org/pdf/1801.00841.pdf """ + def _perform_greedy_v2( + self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + predicted: tf.Tensor, + states: tf.Tensor, + parallel_iterations: int = 10, + swap_memory: bool = False, + ): + """Ref: https://arxiv.org/pdf/1801.00841.pdf""" with tf.name_scope(f"{self.name}_greedy_v2"): time = tf.constant(0, dtype=tf.int32) total = encoded_length @@ -561,20 +615,24 @@ def _perform_greedy_v2(self, hypothesis = Hypothesis( index=predicted, prediction=tf.TensorArray( - dtype=tf.int32, size=0, dynamic_size=True, - clear_after_read=False, element_shape=tf.TensorShape([]) + dtype=tf.int32, + size=0, + dynamic_size=True, + clear_after_read=False, + element_shape=tf.TensorShape([]), ), - states=states + states=states, ) - def condition(_time, _hypothesis): return tf.less(_time, total) + def condition(_time, _hypothesis): + return tf.less(_time, total) def body(_time, _hypothesis): ytu, _states = self.decoder_inference( # avoid using [index] in tflite encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])), predicted=_hypothesis.index, - states=_hypothesis.states + states=_hypothesis.states, ) _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] @@ -589,18 +647,27 @@ def body(_time, _hypothesis): return _time, _hypothesis time, hypothesis = tf.while_loop( - condition, body, + condition, + body, loop_vars=[time, hypothesis], parallel_iterations=parallel_iterations, - swap_memory=swap_memory + swap_memory=swap_memory, ) - return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states) + return Hypothesis( + index=hypothesis.index, + prediction=hypothesis.prediction.stack(), + states=hypothesis.states, + ) # -------------------------------- BEAM SEARCH ------------------------------------- @tf.function - def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): + def recognize_beam( + self, + inputs: Dict[str, tf.Tensor], + lm: bool = False, + ): """ RNN Transducer Beam Search Args: @@ -612,52 +679,70 @@ def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): """ encoded = self.encoder(inputs["inputs"], training=False) encoded_length = math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor) - return self._perform_beam_search_batch(encoded=encoded, encoded_length=encoded_length, lm=lm) - - def _perform_beam_search_batch(self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - lm: bool = False, - parallel_iterations: int = 10, - swap_memory: bool = True): + return self._perform_beam_search_batch( + encoded=encoded, + encoded_length=encoded_length, + lm=lm, + ) + + def _perform_beam_search_batch( + self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = True, + ): with tf.name_scope(f"{self.name}_perform_beam_search_batch"): total_batch = tf.shape(encoded)[0] batch = tf.constant(0, dtype=tf.int32) decoded = tf.TensorArray( - dtype=tf.int32, size=total_batch, dynamic_size=False, - clear_after_read=False, element_shape=None + dtype=tf.int32, + size=total_batch, + dynamic_size=False, + clear_after_read=False, + element_shape=None, ) - def condition(batch, _): return tf.less(batch, total_batch) + def condition(batch, _): + return tf.less(batch, total_batch) def body(batch, decoded): hypothesis = self._perform_beam_search( - encoded[batch], encoded_length[batch], lm, - parallel_iterations=parallel_iterations, swap_memory=swap_memory + encoded[batch], + encoded_length[batch], + lm, + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, ) decoded = decoded.write(batch, hypothesis.prediction) return batch + 1, decoded batch, decoded = tf.while_loop( - condition, body, loop_vars=[batch, decoded], - parallel_iterations=parallel_iterations, swap_memory=True, + condition, + body, + loop_vars=[batch, decoded], + parallel_iterations=parallel_iterations, + swap_memory=True, ) decoded = math_util.pad_prediction_tfarray(decoded, blank=self.text_featurizer.blank) return self.text_featurizer.iextract(decoded.stack()) - def _perform_beam_search(self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - lm: bool = False, - parallel_iterations: int = 10, - swap_memory: bool = True): + def _perform_beam_search( + self, + encoded: tf.Tensor, + encoded_length: tf.Tensor, + lm: bool = False, + parallel_iterations: int = 10, + swap_memory: bool = True, + ): with tf.name_scope(f"{self.name}_beam_search"): beam_width = tf.cond( tf.less(self.text_featurizer.decoder_config.beam_width, self.text_featurizer.num_classes), true_fn=lambda: self.text_featurizer.decoder_config.beam_width, - false_fn=lambda: self.text_featurizer.num_classes - 1 + false_fn=lambda: self.text_featurizer.num_classes - 1, ) total = encoded_length @@ -668,28 +753,28 @@ def initialize_beam(dynamic=False): size=beam_width if not dynamic else 0, dynamic_size=dynamic, element_shape=tf.TensorShape([]), - clear_after_read=False + clear_after_read=False, ), indices=tf.TensorArray( dtype=tf.int32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, element_shape=tf.TensorShape([]), - clear_after_read=False + clear_after_read=False, ), prediction=tf.TensorArray( dtype=tf.int32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, element_shape=None, - clear_after_read=False + clear_after_read=False, ), states=tf.TensorArray( dtype=tf.float32, size=beam_width if not dynamic else 0, dynamic_size=dynamic, element_shape=tf.TensorShape(shape_util.shape_list(self.predict_net.get_initial_state())), - clear_after_read=False + clear_after_read=False, ), ) @@ -698,10 +783,11 @@ def initialize_beam(dynamic=False): score=B.score.write(0, 0.0), indices=B.indices.write(0, self.text_featurizer.blank), prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank), - states=B.states.write(0, self.predict_net.get_initial_state()) + states=B.states.write(0, self.predict_net.get_initial_state()), ) - def condition(time, total, B): return tf.less(time, total) + def condition(time, total, B): + return tf.less(time, total) def body(time, total, B): A = initialize_beam(dynamic=True) @@ -709,10 +795,7 @@ def body(time, total, B): score=A.score.unstack(B.score.stack()), indices=A.indices.unstack(B.indices.stack()), prediction=A.prediction.unstack( - math_util.pad_prediction_tfarray( - B.prediction, - blank=self.text_featurizer.blank - ).stack() + math_util.pad_prediction_tfarray(B.prediction, blank=self.text_featurizer.blank).stack() ), states=A.states.unstack(B.states.stack()), ) @@ -721,7 +804,8 @@ def body(time, total, B): encoded_t = tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)) - def beam_condition(beam, beam_width, A, A_i, B): return tf.less(beam, beam_width) + def beam_condition(beam, beam_width, A, A_i, B): + return tf.less(beam, beam_width) def beam_body(beam, beam_width, A, A_i, B): # get y_hat @@ -730,7 +814,7 @@ def beam_body(beam, beam_width, A, A_i, B): y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index) y_hat_prediction = tf.gather_nd( math_util.pad_prediction_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), - y_hat_score_index + y_hat_score_index, ) y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index) @@ -744,7 +828,7 @@ def beam_body(beam, beam_width, A, A_i, B): prediction=A.prediction.unstack( tf.gather_nd( math_util.pad_prediction_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), - remain_indices + remain_indices, ) ), states=A.states.unstack(tf.gather_nd(A.states.stack(), remain_indices)), @@ -753,7 +837,8 @@ def beam_body(beam, beam_width, A, A_i, B): ytu, new_states = self.decoder_inference(encoded=encoded_t, predicted=y_hat_index, states=y_hat_states) - def predict_condition(pred, A, A_i, B): return tf.less(pred, self.text_featurizer.num_classes) + def predict_condition(pred, A, A_i, B): + return tf.less(pred, self.text_featurizer.num_classes) def predict_body(pred, A, A_i, B): new_score = y_hat_score + tf.gather_nd(ytu, tf.expand_dims(pred, axis=-1)) @@ -776,7 +861,7 @@ def false_fn(): updated_prediction = tf.tensor_scatter_nd_update( y_hat_prediction, indices=tf.reshape(scatter_index, [1, 1]), - updates=tf.expand_dims(pred, axis=-1) + updates=tf.expand_dims(pred, axis=-1), ) return ( B.score, @@ -787,12 +872,12 @@ def false_fn(): A.indices.write(A_i, pred), A.prediction.write(A_i, updated_prediction), A.states.write(A_i, new_states), - A_i + 1 + A_i + 1, ) - b_score, b_indices, b_prediction, b_states, \ - a_score, a_indices, a_prediction, a_states, A_i = tf.cond( - tf.equal(pred, self.text_featurizer.blank), true_fn=true_fn, false_fn=false_fn) + b_score, b_indices, b_prediction, b_states, a_score, a_indices, a_prediction, a_states, A_i = tf.cond( + tf.equal(pred, self.text_featurizer.blank), true_fn=true_fn, false_fn=false_fn + ) B = BeamHypothesis(score=b_score, indices=b_indices, prediction=b_prediction, states=b_states) A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states) @@ -800,25 +885,31 @@ def false_fn(): return pred + 1, A, A_i, B _, A, A_i, B = tf.while_loop( - predict_condition, predict_body, + predict_condition, + predict_body, loop_vars=[0, A, A_i, B], - parallel_iterations=parallel_iterations, swap_memory=swap_memory + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, ) return beam + 1, beam_width, A, A_i, B _, _, A, A_i, B = tf.while_loop( - beam_condition, beam_body, + beam_condition, + beam_body, loop_vars=[0, beam_width, A, A_i, B], - parallel_iterations=parallel_iterations, swap_memory=swap_memory + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, ) return time + 1, total, B _, _, B = tf.while_loop( - condition, body, + condition, + body, loop_vars=[0, total, B], - parallel_iterations=parallel_iterations, swap_memory=swap_memory + parallel_iterations=parallel_iterations, + swap_memory=swap_memory, ) scores = B.score.stack() @@ -837,13 +928,16 @@ def false_fn(): # -------------------------------- TFLITE ------------------------------------- - def make_tflite_function(self, timestamp: bool = False): + def make_tflite_function( + self, + timestamp: bool = False, + ): tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite return tf.function( tflite_func, input_signature=[ tf.TensorSpec([None], dtype=tf.float32), tf.TensorSpec([], dtype=tf.int32), - tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32) - ] + tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32), + ], ) diff --git a/tensorflow_asr/optimizers/schedules.py b/tensorflow_asr/optimizers/schedules.py index ec8d151774..ba13a0d93f 100755 --- a/tensorflow_asr/optimizers/schedules.py +++ b/tensorflow_asr/optimizers/schedules.py @@ -13,8 +13,6 @@ # limitations under the License. import tensorflow as tf -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops from tensorflow.keras.optimizers.schedules import ExponentialDecay @@ -41,7 +39,7 @@ def get_config(self): return { "d_model": self.d_model, "warmup_steps": self.warmup_steps, - "max_lr": self.max_lr + "max_lr": self.max_lr, } @@ -64,7 +62,7 @@ def get_config(self): return { "lamb": self.lamb, "d_model": self.d_model, - "warmup_steps": self.warmup_steps + "warmup_steps": self.warmup_steps, } @@ -74,20 +72,18 @@ def __init__(self, min_lr=0.0, **kwargs): self.min_lr = min_lr def __call__(self, step): - with ops.name_scope_v2(self.name or "ExponentialDecay") as name: - initial_learning_rate = ops.convert_to_tensor( - self.initial_learning_rate, name="initial_learning_rate") + with tf.name_scope(self.name or "ExponentialDecay") as name: + initial_learning_rate = tf.convert_to_tensor(self.initial_learning_rate, name="initial_learning_rate") dtype = initial_learning_rate.dtype - decay_steps = math_ops.cast(self.decay_steps, dtype) - decay_rate = math_ops.cast(self.decay_rate, dtype) + decay_steps = tf.cast(self.decay_steps, dtype) + decay_rate = tf.cast(self.decay_rate, dtype) - global_step_recomp = math_ops.cast(step, dtype) + global_step_recomp = tf.cast(step, dtype) p = global_step_recomp / decay_steps if self.staircase: - p = math_ops.floor(p) - new_lr = math_ops.multiply( - initial_learning_rate, math_ops.pow(decay_rate, p), name=name) - return math_ops.maximum(self.min_lr, new_lr) + p = tf.math.floor(p) + new_lr = tf.multiply(initial_learning_rate, tf.pow(decay_rate, p), name=name) + return tf.maximum(self.min_lr, new_lr) class CyclicTransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): @@ -110,8 +106,7 @@ class CyclicTransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedu https://arxiv.org/abs/1506.01186) """ - def __init__(self, d_model, warmup_steps=4000, max_lr=None, - step_size=None): + def __init__(self, d_model, warmup_steps=4000, max_lr=None, step_size=None): """Applies triangular cyclic to the square root decay learning rate. Args: d_model: Model dimension @@ -134,9 +129,8 @@ def __call__(self, step): lr = tf.math.minimum(self.max_lr, lr) cycle = tf.math.floor(1 + step / (2 * self.step_size)) x = tf.math.abs(step / self.step_size - 2 * cycle + 1) - lr = lr * (0.5 + tf.math.maximum(0., x)) - lr = tf.math.minimum(self.max_lr, - tf.math.minimum(lr, warmup)) + lr = lr * (0.5 + tf.math.maximum(0.0, x)) + lr = tf.math.minimum(self.max_lr, tf.math.minimum(lr, warmup)) return lr def get_config(self): @@ -144,5 +138,5 @@ def get_config(self): "d_model": self.d_model, "warmup_steps": self.warmup_steps, "max_lr": self.max_lr, - "step_size": self.step_size + "step_size": self.step_size, } diff --git a/tensorflow_asr/utils/app_util.py b/tensorflow_asr/utils/app_util.py index 6dda52707c..089b3ea79c 100644 --- a/tensorflow_asr/utils/app_util.py +++ b/tensorflow_asr/utils/app_util.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tqdm import tqdm import tensorflow as tf +from tqdm import tqdm -from .metric_util import wer, cer from ..metrics.error_rates import ErrorRate from .file_util import read_file +from .metric_util import cer, wer logger = tf.get_logger() -def evaluate_results(filepath: str): +def evaluate_results( + filepath: str, +): logger.info(f"Evaluating result from {filepath} ...") metrics = { "greedy_wer": ErrorRate(wer, name="greedy_wer", dtype=tf.float32), "greedy_cer": ErrorRate(cer, name="greedy_cer", dtype=tf.float32), "beamsearch_wer": ErrorRate(wer, name="beamsearch_wer", dtype=tf.float32), - "beamsearch_cer": ErrorRate(cer, name="beamsearch_cer", dtype=tf.float32) + "beamsearch_cer": ErrorRate(cer, name="beamsearch_cer", dtype=tf.float32), } with read_file(filepath) as path: with open(path, "r", encoding="utf-8") as openfile: diff --git a/tensorflow_asr/utils/data_util.py b/tensorflow_asr/utils/data_util.py index 2bcdca8d4e..a146e36c19 100644 --- a/tensorflow_asr/utils/data_util.py +++ b/tensorflow_asr/utils/data_util.py @@ -17,10 +17,12 @@ import tensorflow as tf -def create_inputs(inputs: tf.Tensor, - inputs_length: tf.Tensor, - predictions: tf.Tensor = None, - predictions_length: tf.Tensor = None) -> dict: +def create_inputs( + inputs: tf.Tensor, + inputs_length: tf.Tensor, + predictions: tf.Tensor = None, + predictions_length: tf.Tensor = None, +) -> dict: data = { "inputs": inputs, "inputs_length": inputs_length, @@ -32,14 +34,17 @@ def create_inputs(inputs: tf.Tensor, return data -def create_logits(logits: tf.Tensor, logits_length: tf.Tensor) -> dict: - return { - "logits": logits, - "logits_length": logits_length - } +def create_logits( + logits: tf.Tensor, + logits_length: tf.Tensor, +) -> dict: + return {"logits": logits, "logits_length": logits_length} -def create_labels(labels: tf.Tensor, labels_length: tf.Tensor) -> dict: +def create_labels( + labels: tf.Tensor, + labels_length: tf.Tensor, +) -> dict: return { "labels": labels, "labels_length": labels_length, diff --git a/tensorflow_asr/utils/env_util.py b/tensorflow_asr/utils/env_util.py index 854bcccd9f..221bd9383e 100644 --- a/tensorflow_asr/utils/env_util.py +++ b/tensorflow_asr/utils/env_util.py @@ -13,15 +13,16 @@ # limitations under the License. import logging -from typing import Union, List import warnings +from typing import List, Union + import tensorflow as tf logger = tf.get_logger() def setup_environment(): - """ Setting tensorflow running environment """ + """Setting tensorflow running environment""" warnings.simplefilter("ignore") logger.setLevel(logging.INFO) return logger @@ -50,7 +51,9 @@ def setup_tpu(tpu_address=None): if tpu_address is None: resolver = tf.distribute.cluster_resolver.TPUClusterResolver() else: - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="grpc://" + tpu_address) + resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + tpu="grpc://" + tpu_address + ) tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) logger.info(f"All TPUs: {tf.config.list_logical_devices('TPU')}") diff --git a/tensorflow_asr/utils/feature_util.py b/tensorflow_asr/utils/feature_util.py index 0d8a294ce1..46dc4ff1ea 100644 --- a/tensorflow_asr/utils/feature_util.py +++ b/tensorflow_asr/utils/feature_util.py @@ -15,13 +15,19 @@ import tensorflow as tf -def float_feature(list_of_floats): +def float_feature( + list_of_floats, +): return tf.train.Feature(float_list=tf.train.FloatList(value=list_of_floats)) -def int64_feature(list_of_ints): +def int64_feature( + list_of_ints, +): return tf.train.Feature(int64_list=tf.train.Int64List(value=list_of_ints)) -def bytestring_feature(list_of_bytestrings): +def bytestring_feature( + list_of_bytestrings, +): return tf.train.Feature(bytes_list=tf.train.BytesList(value=list_of_bytestrings)) diff --git a/tensorflow_asr/utils/file_util.py b/tensorflow_asr/utils/file_util.py index c46363d1ac..5e2dcbf788 100644 --- a/tensorflow_asr/utils/file_util.py +++ b/tensorflow_asr/utils/file_util.py @@ -12,38 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import os import re -import yaml import tempfile -import contextlib -from typing import Union, List +from typing import List, Union + import tensorflow as tf +import yaml -def load_yaml(path): +def load_yaml( + path: str, +) -> dict: # Fix yaml numbers https://stackoverflow.com/a/30462009/11037553 loader = yaml.SafeLoader loader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: + u"tag:yaml.org,2002:float", + re.compile( + u"""^(?: [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) |\\.[0-9_]+(?:[eE][-+][0-9]+)? |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')) + |\\.(?:nan|NaN|NAN))$""", + re.X, + ), + list(u"-+0123456789."), + ) with open(path, "r", encoding="utf-8") as file: return yaml.load(file, Loader=loader) -def is_hdf5_filepath(filepath: str) -> bool: - return (filepath.endswith('.h5') or filepath.endswith('.keras') or filepath.endswith('.hdf5')) +def is_hdf5_filepath( + filepath: str, +) -> bool: + return ( + filepath.endswith(".h5") + or filepath.endswith(".keras") + or filepath.endswith(".hdf5") + ) -def is_cloud_path(path: str) -> bool: - """ Check if the path is on cloud (which requires tf.io.gfile) +def is_cloud_path( + path: str, +) -> bool: + """Check if the path is on cloud (which requires tf.io.gfile) Args: path (str): Path to directory or file @@ -54,8 +69,11 @@ def is_cloud_path(path: str) -> bool: return bool(re.match(r"^[a-z]+://", path)) -def preprocess_paths(paths: Union[List[str], str], isdir: bool = False) -> Union[List[str], str]: - """ Expand the path to the root "/" and makedirs +def preprocess_paths( + paths: Union[List[str], str], + isdir: bool = False, +) -> Union[List[str], str]: + """Expand the path to the root "/" and makedirs Args: paths (Union[List, str]): A path or list of paths @@ -64,21 +82,32 @@ def preprocess_paths(paths: Union[List[str], str], isdir: bool = False) -> Union Union[List, str]: A processed path or list of paths, return None if it's not path """ if isinstance(paths, list): - paths = [path if is_cloud_path(path) else os.path.abspath(os.path.expanduser(path)) for path in paths] + paths = [ + path if is_cloud_path(path) else os.path.abspath(os.path.expanduser(path)) + for path in paths + ] for path in paths: dirpath = path if isdir else os.path.dirname(path) - if not tf.io.gfile.exists(dirpath): tf.io.gfile.makedirs(dirpath) + if not tf.io.gfile.exists(dirpath): + tf.io.gfile.makedirs(dirpath) return paths if isinstance(paths, str): - paths = paths if is_cloud_path(paths) else os.path.abspath(os.path.expanduser(paths)) + paths = ( + paths + if is_cloud_path(paths) + else os.path.abspath(os.path.expanduser(paths)) + ) dirpath = paths if isdir else os.path.dirname(paths) - if not tf.io.gfile.exists(dirpath): tf.io.gfile.makedirs(dirpath) + if not tf.io.gfile.exists(dirpath): + tf.io.gfile.makedirs(dirpath) return paths return None @contextlib.contextmanager -def save_file(filepath: str): +def save_file( + filepath: str, +): if is_cloud_path(filepath) and is_hdf5_filepath(filepath): _, ext = os.path.splitext(filepath) with tempfile.NamedTemporaryFile(suffix=ext) as tmp: @@ -89,7 +118,9 @@ def save_file(filepath: str): @contextlib.contextmanager -def read_file(filepath: str): +def read_file( + filepath: str, +): if is_cloud_path(filepath) and is_hdf5_filepath(filepath): _, ext = os.path.splitext(filepath) with tempfile.NamedTemporaryFile(suffix=ext) as tmp: diff --git a/tensorflow_asr/utils/layer_util.py b/tensorflow_asr/utils/layer_util.py index 6e2647f581..bf61e7b250 100644 --- a/tensorflow_asr/utils/layer_util.py +++ b/tensorflow_asr/utils/layer_util.py @@ -1,4 +1,3 @@ - # Copyright 2020 Huy Le Nguyen (@usimarit) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,14 +15,21 @@ import tensorflow as tf -def get_rnn(rnn_type: str): +def get_rnn( + rnn_type: str, +): assert rnn_type in ["lstm", "gru", "rnn"] - if rnn_type == "lstm": return tf.keras.layers.LSTM - if rnn_type == "gru": return tf.keras.layers.GRU + if rnn_type == "lstm": + return tf.keras.layers.LSTM + if rnn_type == "gru": + return tf.keras.layers.GRU return tf.keras.layers.SimpleRNN -def get_conv(conv_type): +def get_conv( + conv_type: str, +): assert conv_type in ["conv1d", "conv2d"] - if conv_type == "conv1d": return tf.keras.layers.Conv1D + if conv_type == "conv1d": + return tf.keras.layers.Conv1D return tf.keras.layers.Conv2D diff --git a/tensorflow_asr/utils/math_util.py b/tensorflow_asr/utils/math_util.py index fef1ace805..3432a55c82 100644 --- a/tensorflow_asr/utils/math_util.py +++ b/tensorflow_asr/utils/math_util.py @@ -13,6 +13,7 @@ # limitations under the License. import math + import numpy as np import tensorflow as tf @@ -25,27 +26,56 @@ def log10(x): return numerator / denominator -def get_num_batches(nsamples, batch_size, drop_remainders=True): - if nsamples is None or batch_size is None: return None - if drop_remainders: return math.floor(float(nsamples) / float(batch_size)) +def get_num_batches( + nsamples, + batch_size, + drop_remainders=True, +): + if nsamples is None or batch_size is None: + return None + if drop_remainders: + return math.floor(float(nsamples) / float(batch_size)) return math.ceil(float(nsamples) / float(batch_size)) -def nan_to_zero(input_tensor): - return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor) +def nan_to_zero( + input_tensor: tf.Tensor, +): + return tf.where( + tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor + ) -def bytes_to_string(array: np.ndarray, encoding: str = "utf-8"): - if array is None: return None +def bytes_to_string( + array: np.ndarray, + encoding: str = "utf-8", +): + if array is None: + return None return [transcript.decode(encoding) for transcript in array] -def get_reduced_length(length, reduction_factor): - return tf.cast(tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), dtype=tf.int32) +def get_reduced_length( + length, + reduction_factor, +): + return tf.cast( + tf.math.ceil(tf.divide(length, tf.cast(reduction_factor, dtype=length.dtype))), + dtype=tf.int32, + ) -def count_non_blank(tensor: tf.Tensor, blank: int or tf.Tensor = 0, axis=None): - return tf.reduce_sum(tf.where(tf.not_equal(tensor, blank), x=tf.ones_like(tensor), y=tf.zeros_like(tensor)), axis=axis) +def count_non_blank( + tensor: tf.Tensor, + blank: int or tf.Tensor = 0, + axis=None, +): + return tf.reduce_sum( + tf.where( + tf.not_equal(tensor, blank), x=tf.ones_like(tensor), y=tf.zeros_like(tensor) + ), + axis=axis, + ) def merge_two_last_dims(x): @@ -53,13 +83,17 @@ def merge_two_last_dims(x): return tf.reshape(x, shape=[b, -1, f * c]) -def merge_repeated(yseqs, blank=0): +def merge_repeated( + yseqs, + blank=0, +): result = tf.reshape(yseqs[0], [1]) U = shape_util.shape_list(yseqs)[0] i = tf.constant(1, dtype=tf.int32) - def _cond(i, result, yseqs, U): return tf.less(i, U) + def _cond(i, result, yseqs, U): + return tf.less(i, U) def _body(i, result, yseqs, U): if yseqs[i] != result[-1]: @@ -74,20 +108,25 @@ def _body(i, result, yseqs, U): tf.TensorShape([]), tf.TensorShape([None]), tf.TensorShape([None]), - tf.TensorShape([]) - ) + tf.TensorShape([]), + ), ) - return tf.pad(result, [[U - shape_util.shape_list(result)[0], 0]], constant_values=blank) + return tf.pad( + result, [[U - shape_util.shape_list(result)[0], 0]], constant_values=blank + ) -def find_max_length_prediction_tfarray(tfarray: tf.TensorArray) -> tf.Tensor: +def find_max_length_prediction_tfarray( + tfarray: tf.TensorArray, +) -> tf.Tensor: with tf.name_scope("find_max_length_prediction_tfarray"): index = tf.constant(0, dtype=tf.int32) total = tfarray.size() max_length = tf.constant(0, dtype=tf.int32) - def condition(index, _): return tf.less(index, total) + def condition(index, _): + return tf.less(index, total) def body(index, max_length): prediction = tfarray.read(index) @@ -95,26 +134,36 @@ def body(index, max_length): max_length = tf.where(tf.greater(length, max_length), length, max_length) return index + 1, max_length - index, max_length = tf.while_loop(condition, body, loop_vars=[index, max_length], swap_memory=False) + index, max_length = tf.while_loop( + condition, body, loop_vars=[index, max_length], swap_memory=False + ) return max_length -def pad_prediction_tfarray(tfarray: tf.TensorArray, blank: int or tf.Tensor) -> tf.TensorArray: +def pad_prediction_tfarray( + tfarray: tf.TensorArray, + blank: int or tf.Tensor, +) -> tf.TensorArray: with tf.name_scope("pad_prediction_tfarray"): index = tf.constant(0, dtype=tf.int32) total = tfarray.size() max_length = find_max_length_prediction_tfarray(tfarray) + 1 - def condition(index, _): return tf.less(index, total) + def condition(index, _): + return tf.less(index, total) def body(index, tfarray): prediction = tfarray.read(index) prediction = tf.pad( - prediction, paddings=[[0, max_length - tf.shape(prediction)[0]]], - mode="CONSTANT", constant_values=blank + prediction, + paddings=[[0, max_length - tf.shape(prediction)[0]]], + mode="CONSTANT", + constant_values=blank, ) tfarray = tfarray.write(index, prediction) return index + 1, tfarray - index, tfarray = tf.while_loop(condition, body, loop_vars=[index, tfarray], swap_memory=False) + index, tfarray = tf.while_loop( + condition, body, loop_vars=[index, tfarray], swap_memory=False + ) return tfarray diff --git a/tensorflow_asr/utils/metric_util.py b/tensorflow_asr/utils/metric_util.py index c26dcc451f..9b6f851f7e 100644 --- a/tensorflow_asr/utils/metric_util.py +++ b/tensorflow_asr/utils/metric_util.py @@ -13,13 +13,17 @@ # limitations under the License. from typing import Tuple -from nltk.metrics import distance + import tensorflow as tf +from nltk.metrics import distance from . import math_util -def execute_wer(decode, target): +def execute_wer( + decode, + target, +): decode = math_util.bytes_to_string(decode) target = math_util.bytes_to_string(target) dis = 0.0 @@ -31,12 +35,17 @@ def execute_wer(decode, target): new_decode = [chr(word2char[w]) for w in dec.split()] new_target = [chr(word2char[w]) for w in tar.split()] - dis += distance.edit_distance(''.join(new_decode), ''.join(new_target)) + dis += distance.edit_distance("".join(new_decode), "".join(new_target)) length += len(tar.split()) - return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32) + return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor( + length, tf.float32 + ) -def wer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: +def wer( + decode: tf.Tensor, + target: tf.Tensor, +) -> Tuple[tf.Tensor, tf.Tensor]: """Word Error Rate Args: @@ -46,7 +55,9 @@ def wer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: Returns: tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text """ - return tf.numpy_function(execute_wer, inp=[decode, target], Tout=[tf.float32, tf.float32]) + return tf.numpy_function( + execute_wer, inp=[decode, target], Tout=[tf.float32, tf.float32] + ) def execute_cer(decode, target): @@ -57,10 +68,15 @@ def execute_cer(decode, target): for dec, tar in zip(decode, target): dis += distance.edit_distance(dec, tar) length += len(tar) - return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32) + return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor( + length, tf.float32 + ) -def cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: +def cer( + decode: tf.Tensor, + target: tf.Tensor, +) -> Tuple[tf.Tensor, tf.Tensor]: """Character Error Rate Args: @@ -70,10 +86,15 @@ def cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: Returns: tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text """ - return tf.numpy_function(execute_cer, inp=[decode, target], Tout=[tf.float32, tf.float32]) + return tf.numpy_function( + execute_cer, inp=[decode, target], Tout=[tf.float32, tf.float32] + ) -def tf_cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: +def tf_cer( + decode: tf.Tensor, + target: tf.Tensor, +) -> Tuple[tf.Tensor, tf.Tensor]: """Tensorflwo Charactor Error rate Args: @@ -85,6 +106,8 @@ def tf_cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: """ decode = tf.strings.bytes_split(decode) # [B, N] target = tf.strings.bytes_split(target) # [B, M] - distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False) # [B] + distances = tf.edit_distance( + decode.to_sparse(), target.to_sparse(), normalize=False + ) # [B] lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32) # [B] return tf.reduce_sum(distances), tf.reduce_sum(lengths)