Merge pull request #86 from TensorSpeech/dev/timestamp
Support naive token-level timestamps
nglehuy authored Dec 27, 2020
2 parents 3eac6c0 + 1640892 commit 3004f0e
Showing 7 changed files with 84 additions and 56 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -21,6 +21,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as

## What's New?

- (12/27/2020) Supported _naive_ token-level timestamps, see the [demo](./examples/demonstration/conformer.py) with flag `--timestamp`
- (12/17/2020) Supported ContextNet [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191)
- (12/12/2020) Add support for using masking
- (11/14/2020) Supported Gradient Accumulation for Training in Larger Batch Size
@@ -219,4 +220,3 @@ For pretrained models, go to [drive](https://drive.google.com/drive/folders/1BD0
Huy Le Nguyen
Email: [email protected]
17 changes: 13 additions & 4 deletions examples/demonstration/conformer.py
@@ -30,6 +30,8 @@

parser.add_argument("--beam_width", type=int, default=0, help="Beam width")

parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")

parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")

parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
@@ -66,9 +68,16 @@
features = speech_featurizer.tf_extract(signal)
input_length = get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor)

if (args.beam_width):
if args.beam_width:
transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...])
print("Transcript:", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
print("Transcript:", transcript)
print("Start time:", stime)
print("End time:", etime)
else:
transcript = conformer.recognize(features[None, ...], input_length[None, ...])

tf.print("Transcript:", transcript[0])
transcript, _, _ = conformer.recognize_tflite(
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
print("Transcript:", tf.strings.unicode_encode(transcript, "UTF-8").numpy().decode("UTF-8"))
2 changes: 1 addition & 1 deletion setup.py
@@ -33,7 +33,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.5.5",
version="0.6.0",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
2 changes: 1 addition & 1 deletion tensorflow_asr/featurizers/speech_featurizers.py
@@ -377,7 +377,7 @@ def shape(self) -> list:
def stft(self, signal):
return tf.square(
tf.abs(tf.signal.stft(signal, frame_length=self.frame_length,
frame_step=self.frame_step, fft_length=self.nfft)))
frame_step=self.frame_step, fft_length=self.nfft, pad_end=True)))

def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
if amin <= 0:
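A short sketch of why `pad_end=True` matters here (the frame sizes below are assumed, not taken from a shipped config): without padding, `tf.signal.stft` drops the trailing partial frame, so the number of feature frames no longer equals `ceil(num_samples / frame_step)`, which the naive timestamp grid in `recognize_tflite_with_timestamp` implicitly relies on.

```python
import tensorflow as tf

signal = tf.random.normal([16123])     # ~1 s of 16 kHz audio, not frame-aligned
frame_length, frame_step = 400, 160    # 25 ms window, 10 ms hop (assumed values)

no_pad = tf.signal.stft(signal, frame_length, frame_step, fft_length=512)
padded = tf.signal.stft(signal, frame_length, frame_step, fft_length=512, pad_end=True)

print(no_pad.shape[0])   # 99  = 1 + (16123 - 400) // 160, trailing samples dropped
print(padded.shape[0])   # 101 = ceil(16123 / 160), aligned with the timestamp grid
```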
38 changes: 20 additions & 18 deletions tensorflow_asr/featurizers/text_featurizers.py
@@ -119,10 +119,7 @@ def __init_vocabulary(self):
self.tokens.insert(self.blank, "") # add blank token to tokens
self.num_classes = len(self.tokens)
self.tokens = tf.convert_to_tensor(self.tokens, dtype=tf.string)
self.upoints = tf.squeeze(
tf.strings.unicode_decode(
self.tokens, "UTF-8").to_tensor(shape=[None, 1])
)
self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1])

def extract(self, text: str) -> tf.Tensor:
"""
@@ -170,7 +167,7 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
with tf.name_scope("indices2upoints"):
indices = self.normalize_indices(indices)
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
return upoints
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))


class SubwordFeaturizer(TextFeaturizer):
@@ -265,18 +262,25 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor:
Returns:
transcripts: tf.Tensor of dtype tf.string with dim [B]
"""
indices = self.normalize_indices(indices)
with tf.device("/CPU:0"): # string data is not supported on GPU
def decode(x):
if x[0] == self.blank: x = x[1:]
return self.subwords.decode(x)

text = tf.map_fn(
lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string),
indices,
fn_output_signature=tf.TensorSpec([], dtype=tf.string)
total = tf.shape(indices)[0]
batch = tf.constant(0, dtype=tf.int32)
transcripts = tf.TensorArray(
dtype=tf.string, size=total, dynamic_size=False, infer_shape=False,
clear_after_read=False, element_shape=tf.TensorShape([])
)
return text

def cond(batch, total, transcripts): return tf.less(batch, total)

def body(batch, total, transcripts):
upoints = self.indices2upoints(indices[batch])
_transcript = tf.strings.unicode_encode(upoints, "UTF-8")
transcripts = transcripts.write(batch, _transcript)
return batch + 1, total, transcripts

_, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts])

return transcripts.stack()

@tf.function(
input_signature=[
@@ -295,6 +299,4 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
with tf.name_scope("indices2upoints"):
indices = self.normalize_indices(indices)
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
# upoints now has shape [None, max_subword_length]
shape = tf.shape(upoints)
return tf.reshape(upoints, [shape[0] * shape[1]]) # flatten
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))
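A toy sketch of what the rewritten `indices2upoints` now returns (values below are made up, not the real vocabulary): each row of `upoints` is a zero-padded list of code points for one predicted token, and gathering the non-zero positions flattens them into a single code-point sequence that `tf.strings.unicode_encode` can turn into a transcript.

```python
import tensorflow as tf

# Toy zero-padded code-point matrix: one row per predicted token (values made up).
upoints = tf.constant([
    [104, 105, 0],   # "hi"
    [0,   0,   0],   # blank token contributes nothing
    [33,  0,   0],   # "!"
], dtype=tf.int32)

flat = tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))
print(flat.numpy())                                      # [104 105  33]
print(tf.strings.unicode_encode(flat, "UTF-8").numpy())  # b'hi!'
```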
77 changes: 47 additions & 30 deletions tensorflow_asr/models/transducer.py
@@ -23,15 +23,9 @@
from ..featurizers.text_featurizers import TextFeaturizer
from .layers.embedding import Embedding

Hypothesis = collections.namedtuple(
"Hypothesis",
("index", "prediction", "states")
)
Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states"))

BeamHypothesis = collections.namedtuple(
"BeamHypothesis",
("score", "indices", "prediction", "states")
)
BeamHypothesis = collections.namedtuple("BeamHypothesis", ("score", "indices", "prediction", "states"))


class TransducerPrediction(tf.keras.Model):
@@ -233,6 +227,7 @@ def __init__(self,
bias_regularizer=bias_regularizer,
name=f"{name}_joint"
)
self.time_reduction_factor = 1

def _build(self, input_shape):
inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)
@@ -369,6 +364,29 @@ def recognize_tflite(self, signal, predicted, states):
hypothesis.states
)

def recognize_tflite_with_timestamp(self, signal, predicted, states):
features = self.speech_featurizer.tf_extract(signal)
encoded = self.encoder_inference(features)
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length]

num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step

stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

non_blank = tf.where(tf.not_equal(upoints, 0))
non_blank_transcript = tf.gather_nd(upoints, non_blank)
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)

return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states
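The timestamp grid above is fixed by the frame hop and the encoder's time reduction. A worked sketch with assumed values (16 kHz audio, 10 ms hop, time_reduction_factor of 4, typical for a Conformer encoder):

```python
import tensorflow as tf

sample_rate = 16000
frame_step = 160                 # 10 ms hop (assumed featurizer config)
time_reduction_factor = 4        # assumed encoder reduction
num_samples = 16000.0            # 1 s of audio

step = float(time_reduction_factor * frame_step)   # 640 samples = 40 ms per decoder frame
stime = tf.range(0.0, num_samples, delta=step) / sample_rate
etime = tf.range(step, num_samples, delta=step) / sample_rate

print(stime.numpy()[:3])   # [0.   0.04 0.08]
print(etime.numpy()[:3])   # [0.04 0.08 0.12]
# A non-blank token emitted at decoder frame t is tagged with [stime[t], etime[t]],
# i.e. boundaries on a fixed 40 ms grid -- hence "naive" timestamps.
```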

def _perform_greedy_batch(self,
encoded: tf.Tensor,
encoded_length: tf.Tensor,
@@ -400,7 +418,7 @@ def body(batch, total, encoded, encoded_length, decoded):

batch, total, _, _, decoded = tf.while_loop(
condition, body,
loop_vars=(batch, total, encoded, encoded_length, decoded),
loop_vars=[batch, total, encoded, encoded_length, decoded],
parallel_iterations=parallel_iterations,
swap_memory=True,
)
@@ -419,45 +437,43 @@ def _perform_greedy(self,
total = encoded_length

hypothesis = Hypothesis(
index=tf.constant(0, dtype=tf.int32),
prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank,
index=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank,
states=states
)

def condition(time, total, encoded, hypothesis): return tf.less(time, total)

def body(time, total, encoded, hypothesis):
predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1))

ytu, new_states = self.decoder_inference(
ytu, states = self.decoder_inference(
# avoid using [index] in tflite
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
predicted=predicted,
predicted=hypothesis.index,
states=hypothesis.states
)
new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []

index, new_predicted, new_states = tf.cond(
tf.equal(new_predicted, self.text_featurizer.blank),
true_fn=lambda: (hypothesis.index, predicted, hypothesis.states),
false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states)
index, predict, states = tf.cond(
tf.equal(predict, self.text_featurizer.blank),
true_fn=lambda: (hypothesis.index, predict, hypothesis.states),
false_fn=lambda: (predict, predict, states) # update if the new prediction is a non-blank
)

hypothesis = Hypothesis(
index=index,
prediction=tf.tensor_scatter_nd_update(
hypothesis.prediction,
indices=tf.reshape(index, [1, 1]),
updates=tf.expand_dims(new_predicted, axis=-1)
indices=tf.reshape(time, [1, 1]),
updates=tf.expand_dims(predict, axis=-1)
),
states=new_states
states=states
)

return time + 1, total, encoded, hypothesis

time, total, encoded, hypothesis = tf.while_loop(
condition, body,
loop_vars=(time, total, encoded, hypothesis),
loop_vars=[time, total, encoded, hypothesis],
parallel_iterations=parallel_iterations,
swap_memory=swap_memory
)
@@ -512,7 +528,7 @@ def body(batch, total, encoded, encoded_length, decoded):

batch, total, _, _, decoded = tf.while_loop(
condition, body,
loop_vars=(batch, total, encoded, encoded_length, decoded),
loop_vars=[batch, total, encoded, encoded_length, decoded],
parallel_iterations=parallel_iterations,
swap_memory=True,
)
@@ -626,23 +642,23 @@ def predict_body(pred, A, A_i, B):

_, A, A_i, B = tf.while_loop(
predict_condition, predict_body,
loop_vars=(0, A, A_i, B),
loop_vars=[0, A, A_i, B],
parallel_iterations=parallel_iterations, swap_memory=swap_memory
)

return beam + 1, beam_width, A, A_i, B

_, _, A, A_i, B = tf.while_loop(
beam_condition, beam_body,
loop_vars=(0, beam_width, A, A_i, B),
loop_vars=[0, beam_width, A, A_i, B],
parallel_iterations=parallel_iterations, swap_memory=swap_memory
)

return time + 1, total, B

_, _, B = tf.while_loop(
condition, body,
loop_vars=(0, total, B),
loop_vars=[0, total, B],
parallel_iterations=parallel_iterations, swap_memory=swap_memory
)

@@ -665,9 +681,10 @@ def predict_body(pred, A, A_i, B):

# -------------------------------- TFLITE -------------------------------------

def make_tflite_function(self, greedy: bool = True):
def make_tflite_function(self, timestamp: bool = False):
tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite
return tf.function(
self.recognize_tflite,
tflite_func,
input_signature=[
tf.TensorSpec([None], dtype=tf.float32),
tf.TensorSpec([], dtype=tf.int32),
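A hedged sketch of exporting the new timestamp path via `make_tflite_function(timestamp=True)`; the converter settings below are illustrative TensorFlow Lite defaults, not taken from this repository's export scripts, and `conformer` is assumed to be a built model with weights loaded.

```python
import tensorflow as tf

# Assumes `conformer` is a built Conformer transducer with weights restored.
concrete_func = conformer.make_tflite_function(timestamp=True).get_concrete_function()

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,   # the while loops / string ops usually need select TF ops
]
tflite_model = converter.convert()

with open("conformer_timestamp.tflite", "wb") as f:
    f.write(tflite_model)
```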
2 changes: 1 addition & 1 deletion tensorflow_asr/utils/utils.py
@@ -133,7 +133,7 @@ def _body(i, result, yseqs, U):
_, result, _, _ = tf.while_loop(
_cond,
_body,
loop_vars=(i, result, yseqs, U),
loop_vars=[i, result, yseqs, U],
shape_invariants=(
tf.TensorShape([]),
tf.TensorShape([None]),
