diff --git a/README.md b/README.md index 2cbc4d5440..3c6b70e902 100755 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as ## What's New? +- (12/27/2020) Supported _naive_ token level timestamp, see [demo](./examples/demonstration/conformer.py) with flag `--timestamp` - (12/17/2020) Supported ContextNet [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191) - (12/12/2020) Add support for using masking - (11/14/2020) Supported Gradient Accumulation for Training in Larger Batch Size @@ -219,4 +220,3 @@ For pretrained models, go to [drive](https://drive.google.com/drive/folders/1BD0 Huy Le Nguyen Email: nlhuy.cs.16@gmail.com - diff --git a/examples/demonstration/conformer.py b/examples/demonstration/conformer.py index 1d318146ab..668a3bc5e5 100644 --- a/examples/demonstration/conformer.py +++ b/examples/demonstration/conformer.py @@ -30,6 +30,8 @@ parser.add_argument("--beam_width", type=int, default=0, help="Beam width") +parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") + parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") @@ -66,9 +68,16 @@ features = speech_featurizer.tf_extract(signal) input_length = get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor) -if (args.beam_width): +if args.beam_width: transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...]) + print("Transcript:", transcript[0].numpy().decode("UTF-8")) +elif args.timestamp: + transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp( + signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()) + print("Transcript:", transcript) + print("Start time:", stime) + print("End time:", etime) else: - transcript = conformer.recognize(features[None, ...], input_length[None, ...]) - -tf.print("Transcript:", transcript[0]) + transcript, _, _ = conformer.recognize_tflite( + signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state()) + print("Transcript:", tf.strings.unicode_encode(transcript, "UTF-8").numpy().decode("UTF-8")) diff --git a/setup.py b/setup.py index 840e8fc19e..3bfb61767d 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ setuptools.setup( name="TensorFlowASR", - version="0.5.5", + version="0.6.0", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py index 3594157ccb..f8ffaf83f7 100755 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ b/tensorflow_asr/featurizers/speech_featurizers.py @@ -377,7 +377,7 @@ def shape(self) -> list: def stft(self, signal): return tf.square( tf.abs(tf.signal.stft(signal, frame_length=self.frame_length, - frame_step=self.frame_step, fft_length=self.nfft))) + frame_step=self.frame_step, fft_length=self.nfft, pad_end=True))) def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0): if amin <= 0: diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py index 5372b99e34..04e1c2c5a2 100755 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ b/tensorflow_asr/featurizers/text_featurizers.py @@ -119,10 +119,7 @@ def __init_vocabulary(self): self.tokens.insert(self.blank, "") # add blank token to tokens self.num_classes = len(self.tokens) self.tokens = tf.convert_to_tensor(self.tokens, dtype=tf.string) - self.upoints = tf.squeeze( - tf.strings.unicode_decode( - self.tokens, "UTF-8").to_tensor(shape=[None, 1]) - ) + self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1]) def extract(self, text: str) -> tf.Tensor: """ @@ -170,7 +167,7 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: with tf.name_scope("indices2upoints"): indices = self.normalize_indices(indices) upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1)) - return upoints + return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0))) class SubwordFeaturizer(TextFeaturizer): @@ -265,18 +262,25 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor: Returns: transcripts: tf.Tensor of dtype tf.string with dim [B] """ - indices = self.normalize_indices(indices) with tf.device("/CPU:0"): # string data is not supported on GPU - def decode(x): - if x[0] == self.blank: x = x[1:] - return self.subwords.decode(x) - - text = tf.map_fn( - lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string), - indices, - fn_output_signature=tf.TensorSpec([], dtype=tf.string) + total = tf.shape(indices)[0] + batch = tf.constant(0, dtype=tf.int32) + transcripts = tf.TensorArray( + dtype=tf.string, size=total, dynamic_size=False, infer_shape=False, + clear_after_read=False, element_shape=tf.TensorShape([]) ) - return text + + def cond(batch, total, transcripts): return tf.less(batch, total) + + def body(batch, total, transcripts): + upoints = self.indices2upoints(indices[batch]) + _transcript = tf.strings.unicode_encode(upoints, "UTF-8") + transcripts = transcripts.write(batch, _transcript) + return batch + 1, total, transcripts + + _, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts]) + + return transcripts.stack() @tf.function( input_signature=[ @@ -295,6 +299,4 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: with tf.name_scope("indices2upoints"): indices = self.normalize_indices(indices) upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1)) - # upoints now has shape [None, max_subword_length] - shape = tf.shape(upoints) - return tf.reshape(upoints, [shape[0] * shape[1]]) # flatten + return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0))) diff --git a/tensorflow_asr/models/transducer.py b/tensorflow_asr/models/transducer.py index b0f9928e8e..ec40de3469 100755 --- a/tensorflow_asr/models/transducer.py +++ b/tensorflow_asr/models/transducer.py @@ -23,15 +23,9 @@ from ..featurizers.text_featurizers import TextFeaturizer from .layers.embedding import Embedding -Hypothesis = collections.namedtuple( - "Hypothesis", - ("index", "prediction", "states") -) +Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states")) -BeamHypothesis = collections.namedtuple( - "BeamHypothesis", - ("score", "indices", "prediction", "states") -) +BeamHypothesis = collections.namedtuple("BeamHypothesis", ("score", "indices", "prediction", "states")) class TransducerPrediction(tf.keras.Model): @@ -233,6 +227,7 @@ def __init__(self, bias_regularizer=bias_regularizer, name=f"{name}_joint" ) + self.time_reduction_factor = 1 def _build(self, input_shape): inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32) @@ -369,6 +364,29 @@ def recognize_tflite(self, signal, predicted, states): hypothesis.states ) + def recognize_tflite_with_timestamp(self, signal, predicted, states): + features = self.speech_featurizer.tf_extract(signal) + encoded = self.encoder_inference(features) + hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states) + indices = self.text_featurizer.normalize_indices(hypothesis.prediction) + upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length] + + num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) + total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step + + stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + non_blank = tf.where(tf.not_equal(upoints, 0)) + non_blank_transcript = tf.gather_nd(upoints, non_blank) + non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + + return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states + def _perform_greedy_batch(self, encoded: tf.Tensor, encoded_length: tf.Tensor, @@ -400,7 +418,7 @@ def body(batch, total, encoded, encoded_length, decoded): batch, total, _, _, decoded = tf.while_loop( condition, body, - loop_vars=(batch, total, encoded, encoded_length, decoded), + loop_vars=[batch, total, encoded, encoded_length, decoded], parallel_iterations=parallel_iterations, swap_memory=True, ) @@ -419,45 +437,43 @@ def _perform_greedy(self, total = encoded_length hypothesis = Hypothesis( - index=tf.constant(0, dtype=tf.int32), - prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank, + index=tf.constant(self.text_featurizer.blank, dtype=tf.int32), + prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank, states=states ) def condition(time, total, encoded, hypothesis): return tf.less(time, total) def body(time, total, encoded, hypothesis): - predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1)) - - ytu, new_states = self.decoder_inference( + ytu, states = self.decoder_inference( # avoid using [index] in tflite encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)), - predicted=predicted, + predicted=hypothesis.index, states=hypothesis.states ) - new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] + predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] - index, new_predicted, new_states = tf.cond( - tf.equal(new_predicted, self.text_featurizer.blank), - true_fn=lambda: (hypothesis.index, predicted, hypothesis.states), - false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states) + index, predict, states = tf.cond( + tf.equal(predict, self.text_featurizer.blank), + true_fn=lambda: (hypothesis.index, predict, hypothesis.states), + false_fn=lambda: (predict, predict, states) # update if the new prediction is a non-blank ) hypothesis = Hypothesis( index=index, prediction=tf.tensor_scatter_nd_update( hypothesis.prediction, - indices=tf.reshape(index, [1, 1]), - updates=tf.expand_dims(new_predicted, axis=-1) + indices=tf.reshape(time, [1, 1]), + updates=tf.expand_dims(predict, axis=-1) ), - states=new_states + states=states ) return time + 1, total, encoded, hypothesis time, total, encoded, hypothesis = tf.while_loop( condition, body, - loop_vars=(time, total, encoded, hypothesis), + loop_vars=[time, total, encoded, hypothesis], parallel_iterations=parallel_iterations, swap_memory=swap_memory ) @@ -512,7 +528,7 @@ def body(batch, total, encoded, encoded_length, decoded): batch, total, _, _, decoded = tf.while_loop( condition, body, - loop_vars=(batch, total, encoded, encoded_length, decoded), + loop_vars=[batch, total, encoded, encoded_length, decoded], parallel_iterations=parallel_iterations, swap_memory=True, ) @@ -626,7 +642,7 @@ def predict_body(pred, A, A_i, B): _, A, A_i, B = tf.while_loop( predict_condition, predict_body, - loop_vars=(0, A, A_i, B), + loop_vars=[0, A, A_i, B], parallel_iterations=parallel_iterations, swap_memory=swap_memory ) @@ -634,7 +650,7 @@ def predict_body(pred, A, A_i, B): _, _, A, A_i, B = tf.while_loop( beam_condition, beam_body, - loop_vars=(0, beam_width, A, A_i, B), + loop_vars=[0, beam_width, A, A_i, B], parallel_iterations=parallel_iterations, swap_memory=swap_memory ) @@ -642,7 +658,7 @@ def predict_body(pred, A, A_i, B): _, _, B = tf.while_loop( condition, body, - loop_vars=(0, total, B), + loop_vars=[0, total, B], parallel_iterations=parallel_iterations, swap_memory=swap_memory ) @@ -665,9 +681,10 @@ def predict_body(pred, A, A_i, B): # -------------------------------- TFLITE ------------------------------------- - def make_tflite_function(self, greedy: bool = True): + def make_tflite_function(self, timestamp: bool = False): + tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite return tf.function( - self.recognize_tflite, + tflite_func, input_signature=[ tf.TensorSpec([None], dtype=tf.float32), tf.TensorSpec([], dtype=tf.int32), diff --git a/tensorflow_asr/utils/utils.py b/tensorflow_asr/utils/utils.py index 869d87549c..e3e35da211 100755 --- a/tensorflow_asr/utils/utils.py +++ b/tensorflow_asr/utils/utils.py @@ -133,7 +133,7 @@ def _body(i, result, yseqs, U): _, result, _, _ = tf.while_loop( _cond, _body, - loop_vars=(i, result, yseqs, U), + loop_vars=[i, result, yseqs, U], shape_invariants=( tf.TensorShape([]), tf.TensorShape([None]),