Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support naive token level timestamp #86

Merged
merged 1 commit into from
Dec 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ TensorFlowASR implements some automatic speech recognition architectures such as

## What's New?

- (12/27/2020) Supported _naive_ token level timestamp, see [demo](./examples/demonstration/conformer.py) with flag `--timestamp`
- (12/17/2020) Supported ContextNet [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191)
- (12/12/2020) Add support for using masking
- (11/14/2020) Supported Gradient Accumulation for Training in Larger Batch Size
Expand Down Expand Up @@ -219,4 +220,3 @@ For pretrained models, go to [drive](https://drive.google.com/drive/folders/1BD0
Huy Le Nguyen

Email: [email protected]

17 changes: 13 additions & 4 deletions examples/demonstration/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

parser.add_argument("--beam_width", type=int, default=0, help="Beam width")

parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp")

parser.add_argument("--device", type=int, default=0, help="Device's id to run test on")

parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu")
Expand Down Expand Up @@ -66,9 +68,16 @@
features = speech_featurizer.tf_extract(signal)
input_length = get_reduced_length(tf.shape(features)[0], conformer.time_reduction_factor)

if (args.beam_width):
if args.beam_width:
transcript = conformer.recognize_beam(features[None, ...], input_length[None, ...])
print("Transcript:", transcript[0].numpy().decode("UTF-8"))
elif args.timestamp:
transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp(
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
print("Transcript:", transcript)
print("Start time:", stime)
print("End time:", etime)
else:
transcript = conformer.recognize(features[None, ...], input_length[None, ...])

tf.print("Transcript:", transcript[0])
transcript, _, _ = conformer.recognize_tflite(
signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state())
print("Transcript:", tf.strings.unicode_encode(transcript, "UTF-8").numpy().decode("UTF-8"))
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.5.5",
version="0.6.0",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_asr/featurizers/speech_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def shape(self) -> list:
def stft(self, signal):
return tf.square(
tf.abs(tf.signal.stft(signal, frame_length=self.frame_length,
frame_step=self.frame_step, fft_length=self.nfft)))
frame_step=self.frame_step, fft_length=self.nfft, pad_end=True)))

def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
if amin <= 0:
Expand Down
38 changes: 20 additions & 18 deletions tensorflow_asr/featurizers/text_featurizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ def __init_vocabulary(self):
self.tokens.insert(self.blank, "") # add blank token to tokens
self.num_classes = len(self.tokens)
self.tokens = tf.convert_to_tensor(self.tokens, dtype=tf.string)
self.upoints = tf.squeeze(
tf.strings.unicode_decode(
self.tokens, "UTF-8").to_tensor(shape=[None, 1])
)
self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1])

def extract(self, text: str) -> tf.Tensor:
"""
Expand Down Expand Up @@ -170,7 +167,7 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
with tf.name_scope("indices2upoints"):
indices = self.normalize_indices(indices)
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
return upoints
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))


class SubwordFeaturizer(TextFeaturizer):
Expand Down Expand Up @@ -265,18 +262,25 @@ def iextract(self, indices: tf.Tensor) -> tf.Tensor:
Returns:
transcripts: tf.Tensor of dtype tf.string with dim [B]
"""
indices = self.normalize_indices(indices)
with tf.device("/CPU:0"): # string data is not supported on GPU
def decode(x):
if x[0] == self.blank: x = x[1:]
return self.subwords.decode(x)

text = tf.map_fn(
lambda x: tf.numpy_function(decode, inp=[x], Tout=tf.string),
indices,
fn_output_signature=tf.TensorSpec([], dtype=tf.string)
total = tf.shape(indices)[0]
batch = tf.constant(0, dtype=tf.int32)
transcripts = tf.TensorArray(
dtype=tf.string, size=total, dynamic_size=False, infer_shape=False,
clear_after_read=False, element_shape=tf.TensorShape([])
)
return text

def cond(batch, total, transcripts): return tf.less(batch, total)

def body(batch, total, transcripts):
upoints = self.indices2upoints(indices[batch])
_transcript = tf.strings.unicode_encode(upoints, "UTF-8")
transcripts = transcripts.write(batch, _transcript)
return batch + 1, total, transcripts

_, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts])

return transcripts.stack()

@tf.function(
input_signature=[
Expand All @@ -295,6 +299,4 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
with tf.name_scope("indices2upoints"):
indices = self.normalize_indices(indices)
upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
# upoints now has shape [None, max_subword_length]
shape = tf.shape(upoints)
return tf.reshape(upoints, [shape[0] * shape[1]]) # flatten
return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0)))
77 changes: 47 additions & 30 deletions tensorflow_asr/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,9 @@
from ..featurizers.text_featurizers import TextFeaturizer
from .layers.embedding import Embedding

Hypothesis = collections.namedtuple(
"Hypothesis",
("index", "prediction", "states")
)
Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states"))

BeamHypothesis = collections.namedtuple(
"BeamHypothesis",
("score", "indices", "prediction", "states")
)
BeamHypothesis = collections.namedtuple("BeamHypothesis", ("score", "indices", "prediction", "states"))


class TransducerPrediction(tf.keras.Model):
Expand Down Expand Up @@ -233,6 +227,7 @@ def __init__(self,
bias_regularizer=bias_regularizer,
name=f"{name}_joint"
)
self.time_reduction_factor = 1

def _build(self, input_shape):
inputs = tf.keras.Input(shape=input_shape, dtype=tf.float32)
Expand Down Expand Up @@ -369,6 +364,29 @@ def recognize_tflite(self, signal, predicted, states):
hypothesis.states
)

def recognize_tflite_with_timestamp(self, signal, predicted, states):
features = self.speech_featurizer.tf_extract(signal)
encoded = self.encoder_inference(features)
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length]

num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step

stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

non_blank = tf.where(tf.not_equal(upoints, 0))
non_blank_transcript = tf.gather_nd(upoints, non_blank)
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)

return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states

def _perform_greedy_batch(self,
encoded: tf.Tensor,
encoded_length: tf.Tensor,
Expand Down Expand Up @@ -400,7 +418,7 @@ def body(batch, total, encoded, encoded_length, decoded):

batch, total, _, _, decoded = tf.while_loop(
condition, body,
loop_vars=(batch, total, encoded, encoded_length, decoded),
loop_vars=[batch, total, encoded, encoded_length, decoded],
parallel_iterations=parallel_iterations,
swap_memory=True,
)
Expand All @@ -419,45 +437,43 @@ def _perform_greedy(self,
total = encoded_length

hypothesis = Hypothesis(
index=tf.constant(0, dtype=tf.int32),
prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank,
index=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank,
states=states
)

def condition(time, total, encoded, hypothesis): return tf.less(time, total)

def body(time, total, encoded, hypothesis):
predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1))

ytu, new_states = self.decoder_inference(
ytu, states = self.decoder_inference(
# avoid using [index] in tflite
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
predicted=predicted,
predicted=hypothesis.index,
states=hypothesis.states
)
new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []

index, new_predicted, new_states = tf.cond(
tf.equal(new_predicted, self.text_featurizer.blank),
true_fn=lambda: (hypothesis.index, predicted, hypothesis.states),
false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states)
index, predict, states = tf.cond(
tf.equal(predict, self.text_featurizer.blank),
true_fn=lambda: (hypothesis.index, predict, hypothesis.states),
false_fn=lambda: (predict, predict, states) # update if the new prediction is a non-blank
)

hypothesis = Hypothesis(
index=index,
prediction=tf.tensor_scatter_nd_update(
hypothesis.prediction,
indices=tf.reshape(index, [1, 1]),
updates=tf.expand_dims(new_predicted, axis=-1)
indices=tf.reshape(time, [1, 1]),
updates=tf.expand_dims(predict, axis=-1)
),
states=new_states
states=states
)

return time + 1, total, encoded, hypothesis

time, total, encoded, hypothesis = tf.while_loop(
condition, body,
loop_vars=(time, total, encoded, hypothesis),
loop_vars=[time, total, encoded, hypothesis],
parallel_iterations=parallel_iterations,
swap_memory=swap_memory
)
Expand Down Expand Up @@ -512,7 +528,7 @@ def body(batch, total, encoded, encoded_length, decoded):

batch, total, _, _, decoded = tf.while_loop(
condition, body,
loop_vars=(batch, total, encoded, encoded_length, decoded),
loop_vars=[batch, total, encoded, encoded_length, decoded],
parallel_iterations=parallel_iterations,
swap_memory=True,
)
Expand Down Expand Up @@ -626,23 +642,23 @@ def predict_body(pred, A, A_i, B):

_, A, A_i, B = tf.while_loop(
predict_condition, predict_body,
loop_vars=(0, A, A_i, B),
loop_vars=[0, A, A_i, B],
parallel_iterations=parallel_iterations, swap_memory=swap_memory
)

return beam + 1, beam_width, A, A_i, B

_, _, A, A_i, B = tf.while_loop(
beam_condition, beam_body,
loop_vars=(0, beam_width, A, A_i, B),
loop_vars=[0, beam_width, A, A_i, B],
parallel_iterations=parallel_iterations, swap_memory=swap_memory
)

return time + 1, total, B

_, _, B = tf.while_loop(
condition, body,
loop_vars=(0, total, B),
loop_vars=[0, total, B],
parallel_iterations=parallel_iterations, swap_memory=swap_memory
)

Expand All @@ -665,9 +681,10 @@ def predict_body(pred, A, A_i, B):

# -------------------------------- TFLITE -------------------------------------

def make_tflite_function(self, greedy: bool = True):
def make_tflite_function(self, timestamp: bool = False):
tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite
return tf.function(
self.recognize_tflite,
tflite_func,
input_signature=[
tf.TensorSpec([None], dtype=tf.float32),
tf.TensorSpec([], dtype=tf.int32),
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_asr/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _body(i, result, yseqs, U):
_, result, _, _ = tf.while_loop(
_cond,
_body,
loop_vars=(i, result, yseqs, U),
loop_vars=[i, result, yseqs, U],
shape_invariants=(
tf.TensorShape([]),
tf.TensorShape([None]),
Expand Down