Skip to content

Commit

Permalink
Merge pull request #115 from TensorSpeech/fix/tflite
Browse files Browse the repository at this point in the history
Fix TFLite Conversion and Interpretation
  • Loading branch information
nglehuy authored Jan 15, 2021
2 parents 304330a + 6113267 commit 86e8c43
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 112 deletions.
5 changes: 2 additions & 3 deletions examples/conformer/tflite_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@
conformer.summary(line_length=150)
conformer.add_featurizers(speech_featurizer, text_featurizer)

concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function()
concrete_func = conformer.make_tflite_function().get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

if not os.path.exists(os.path.dirname(args.output)):
Expand Down
2 changes: 1 addition & 1 deletion examples/conformer/tflite_subword_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
conformer.summary(line_length=150)
conformer.add_featurizers(speech_featurizer, text_featurizer)

concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function()
concrete_func = conformer.make_tflite_function().get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.6.3",
version="0.6.4",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
93 changes: 87 additions & 6 deletions tensorflow_asr/models/contextnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
""" Ref: https://github.com/iankur/ContextNet """

from typing import List, Optional
from typing import List
import tensorflow as tf
from .transducer import Transducer
from ..utils.utils import merge_two_last_dims, get_reduced_length
Expand Down Expand Up @@ -245,13 +245,94 @@ def call(self, inputs, training=False, **kwargs):
outputs = self.joint_net([enc, pred], training=training, **kwargs)
return outputs

def encoder_inference(self,
features: tf.Tensor,
input_length: Optional[tf.Tensor] = None,
with_batch: bool = False):
def encoder_inference(self, features: tf.Tensor, input_length: tf.Tensor):
with tf.name_scope(f"{self.name}_encoder"):
if with_batch: return self.encoder([features, input_length], training=False)
input_length = tf.expand_dims(tf.shape(features)[0], axis=0)
outputs = tf.expand_dims(features, axis=0)
outputs = self.encoder([outputs, input_length], training=False)
return tf.squeeze(outputs, axis=0)

# -------------------------------- GREEDY -------------------------------------

@tf.function
def recognize(self,
features: tf.Tensor,
input_length: tf.Tensor,
parallel_iterations: int = 10,
swap_memory: bool = True):
"""
RNN Transducer Greedy decoding
Args:
features (tf.Tensor): a batch of padded extracted features
Returns:
tf.Tensor: a batch of decoded transcripts
"""
encoded = self.encoder([features, input_length], training=False)
return self._perform_greedy_batch(encoded, input_length,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

def recognize_tflite(self, signal, predicted, prediction_states):
"""
Function to convert to tflite using greedy decoding (default streaming mode)
Args:
signal: tf.Tensor with shape [None] indicating a single audio signal
predicted: last predicted character with shape []
prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P]
Return:
transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32
predicted: last predicted character with shape []
encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P]
prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P]
"""
features = self.speech_featurizer.tf_extract(signal)
encoded = self.encoder_inference(features, tf.shape(features)[0])
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
return transcript, hypothesis.index, hypothesis.states

def recognize_tflite_with_timestamp(self, signal, predicted, states):
features = self.speech_featurizer.tf_extract(signal)
encoded = self.encoder_inference(features, tf.shape(features)[0])
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
indices = self.text_featurizer.normalize_indices(hypothesis.prediction)
upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length]

num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32)
total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step

stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32)
etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32)

non_blank = tf.where(tf.not_equal(upoints, 0))
non_blank_transcript = tf.gather_nd(upoints, non_blank)
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)

return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states

# -------------------------------- BEAM SEARCH -------------------------------------

@tf.function
def recognize_beam(self,
features: tf.Tensor,
input_length: tf.Tensor,
lm: bool = False,
parallel_iterations: int = 10,
swap_memory: bool = True):
"""
RNN Transducer Beam Search
Args:
features (tf.Tensor): a batch of padded extracted features
lm (bool, optional): whether to use language model. Defaults to False.
Returns:
tf.Tensor: a batch of decoded transcripts
"""
encoded = self.encoder([features, input_length], training=False)
return self._perform_beam_search_batch(encoded, input_length, lm,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)
37 changes: 5 additions & 32 deletions tensorflow_asr/models/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
""" http://arxiv.org/abs/1811.06621 """

from typing import Optional
import tensorflow as tf

from .layers.subsampling import TimeReduction
Expand Down Expand Up @@ -225,24 +224,18 @@ def __init__(self,
)
self.time_reduction_factor = self.encoder.time_reduction_factor

def encoder_inference(self,
features: tf.Tensor,
states: tf.Tensor,
input_length: Optional[tf.Tensor] = None,
with_batch: bool = False):
def encoder_inference(self, features: tf.Tensor, states: tf.Tensor):
"""Infer function for encoder (or encoders)
Args:
features (tf.Tensor): features with shape [T, F, C]
states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P]
with_batch (bool): indicates whether the features included batch dim or not
Returns:
tf.Tensor: output of encoders with shape [T, E]
tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P]
"""
with tf.name_scope(f"{self.name}_encoder"):
if with_batch: return self.encoder.recognize(features, states)
outputs = tf.expand_dims(features, axis=0)
outputs, new_states = self.encoder.recognize(outputs, states)
return tf.squeeze(outputs, axis=0), new_states
Expand All @@ -263,11 +256,7 @@ def recognize(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
encoded, _ = self.encoder_inference(
features,
self.encoder.get_initial_state(),
input_length=input_length, with_batch=True
)
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
return self._perform_greedy_batch(encoded, input_length,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand All @@ -290,12 +279,7 @@ def recognize_tflite(self, signal, predicted, encoder_states, prediction_states)
encoded, new_encoder_states = self.encoder_inference(features, encoder_states)
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states)
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
return (
transcript,
hypothesis.prediction[-1],
new_encoder_states,
hypothesis.states
)
return transcript, hypothesis.index, new_encoder_states, hypothesis.states

def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states):
features = self.speech_featurizer.tf_extract(signal)
Expand All @@ -318,14 +302,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, pre
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)

return (
non_blank_transcript,
non_blank_stime,
non_blank_etime,
hypothesis.prediction,
new_encoder_states,
hypothesis.states
)
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, new_encoder_states, hypothesis.states

# -------------------------------- BEAM SEARCH -------------------------------------

Expand All @@ -345,11 +322,7 @@ def recognize_beam(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
encoded, _ = self.encoder_inference(
features,
self.encoder.get_initial_state(),
input_length=input_length, with_batch=True
)
encoded, _ = self.encoder.recognize(features, self.encoder.get_initial_state())
return self._perform_beam_search_batch(encoded, input_length, lm,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand Down
74 changes: 31 additions & 43 deletions tensorflow_asr/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
""" https://arxiv.org/pdf/1811.06621.pdf """

import collections
from typing import Optional
import tensorflow as tf

from . import Model
Expand Down Expand Up @@ -285,22 +284,16 @@ def call(self, inputs, training=False, **kwargs):
outputs = self.joint_net([enc, pred], training=training, **kwargs)
return outputs

def encoder_inference(self,
features: tf.Tensor,
input_length: Optional[tf.Tensor] = None,
with_batch: Optional[bool] = False):
def encoder_inference(self, features: tf.Tensor):
"""Infer function for encoder (or encoders)
Args:
features (tf.Tensor): features with shape [T, F, C]
input_length (tf.Tensor): optional features length with shape []
with_batch (bool): indicates whether the features included batch dim or not
Returns:
tf.Tensor: output of encoders with shape [T, E]
"""
with tf.name_scope(f"{self.name}_encoder"):
if with_batch: return self.encoder(features, training=False)
outputs = tf.expand_dims(features, axis=0)
outputs = self.encoder(outputs, training=False)
return tf.squeeze(outputs, axis=0)
Expand All @@ -321,7 +314,7 @@ def decoder_inference(self, encoded: tf.Tensor, predicted: tf.Tensor, states: tf
predicted = tf.reshape(predicted, [1, 1]) # [] => [1, 1]
y, new_states = self.predict_net.recognize(predicted, states) # [1, 1, P], states
ytu = tf.nn.log_softmax(self.joint_net([encoded, y], training=False)) # [1, 1, V]
ytu = tf.squeeze(ytu, axis=None) # [1, 1, V] => [V]
ytu = tf.reshape(ytu, shape=[-1]) # [1, 1, V] => [V]
return ytu, new_states

def get_config(self):
Expand All @@ -347,7 +340,7 @@ def recognize(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
encoded = self.encoder_inference(features, input_length, with_batch=True)
encoded = self.encoder(features, training=True)
return self._perform_greedy_batch(encoded, input_length,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand All @@ -368,11 +361,7 @@ def recognize_tflite(self, signal, predicted, states):
encoded = self.encoder_inference(features)
hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states)
transcript = self.text_featurizer.indices2upoints(hypothesis.prediction)
return (
transcript,
hypothesis.prediction[-1],
hypothesis.states
)
return transcript, hypothesis.index, hypothesis.states

def recognize_tflite_with_timestamp(self, signal, predicted, states):
features = self.speech_featurizer.tf_extract(signal)
Expand All @@ -395,7 +384,7 @@ def recognize_tflite_with_timestamp(self, signal, predicted, states):
non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)
non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank)

return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.prediction, hypothesis.states
return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states

def _perform_greedy_batch(self,
encoded: tf.Tensor,
Expand Down Expand Up @@ -450,48 +439,47 @@ def _perform_greedy(self,
total = encoded_length

hypothesis = Hypothesis(
index=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
prediction=tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank,
index=predicted,
prediction=tf.TensorArray(
dtype=tf.int32, size=total, dynamic_size=False,
clear_after_read=False, element_shape=tf.TensorShape([])
),
states=states
)

def condition(time, total, encoded, hypothesis): return tf.less(time, total)
def condition(_time, _total, _encoded, _hypothesis): return tf.less(_time, _total)

def body(time, total, encoded, hypothesis):
ytu, states = self.decoder_inference(
def body(_time, _total, _encoded, _hypothesis):
ytu, _states = self.decoder_inference(
# avoid using [index] in tflite
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
predicted=hypothesis.index,
states=hypothesis.states
encoded=tf.gather_nd(_encoded, tf.reshape(_time, shape=[1])),
predicted=_hypothesis.index,
states=_hypothesis.states
)
predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []
_predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []

index, predict, states = tf.cond(
tf.equal(predict, self.text_featurizer.blank),
true_fn=lambda: (hypothesis.index, predict, hypothesis.states),
false_fn=lambda: (predict, predict, states) # update if the new prediction is a non-blank
)
# something is wrong with tflite that drop support for tf.cond
# def equal_blank_fn(): return _hypothesis.index, _hypothesis.states
# def non_equal_blank_fn(): return _predict, _states # update if the new prediction is a non-blank
# _index, _states = tf.cond(tf.equal(_predict, blank), equal_blank_fn, non_equal_blank_fn)

hypothesis = Hypothesis(
index=index,
prediction=tf.tensor_scatter_nd_update(
hypothesis.prediction,
indices=tf.reshape(time, [1, 1]),
updates=tf.expand_dims(predict, axis=-1)
),
states=states
)
_equal = tf.equal(_predict, self.text_featurizer.blank)
_index = tf.where(_equal, _hypothesis.index, _predict)
_states = tf.where(_equal, _hypothesis.states, _states)

_prediction = _hypothesis.prediction.write(_time, _predict)
_hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states)

return time + 1, total, encoded, hypothesis
return _time + 1, _total, _encoded, _hypothesis

time, total, encoded, hypothesis = tf.while_loop(
_, _, _, hypothesis = tf.while_loop(
condition, body,
loop_vars=[time, total, encoded, hypothesis],
parallel_iterations=parallel_iterations,
swap_memory=swap_memory
)

return hypothesis
return Hypothesis(index=hypothesis.index, prediction=hypothesis.prediction.stack(), states=hypothesis.states)

# -------------------------------- BEAM SEARCH -------------------------------------

Expand All @@ -511,7 +499,7 @@ def recognize_beam(self,
Returns:
tf.Tensor: a batch of decoded transcripts
"""
encoded = self.encoder_inference(features, input_length, with_batch=True)
encoded = self.encoder(features, training=True)
return self._perform_beam_search_batch(encoded, input_length, lm,
parallel_iterations=parallel_iterations, swap_memory=swap_memory)

Expand Down
Loading

0 comments on commit 86e8c43

Please sign in to comment.