Skip to content

Commit

Permalink
Merge pull request #83 from TensorSpeech/fix/tflite
Browse files Browse the repository at this point in the history
Fix TFLite Conversion for Transducer Greedy
  • Loading branch information
nglehuy authored Dec 19, 2020
2 parents 6d70eab + 6df0393 commit 1bad037
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 41 deletions.
4 changes: 2 additions & 2 deletions examples/conformer/tflite_subword_conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@

concrete_func = conformer.make_tflite_function(greedy=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

if not os.path.exists(os.path.dirname(args.output)):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.5.1",
version="0.5.2",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
54 changes: 16 additions & 38 deletions tensorflow_asr/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,73 +417,51 @@ def __perform_greedy(self,
with tf.name_scope(f"{self.name}_greedy"):
time = tf.constant(0, dtype=tf.int32)
total = encoded_length
# Initialize prediction with a blank
# Prediction can not be longer than the encoded of audio plus blank
prediction = tf.TensorArray(
dtype=tf.int32,
size=(total + 1),
dynamic_size=False,
element_shape=tf.TensorShape([]),
clear_after_read=False
)

hypothesis = Hypothesis(
index=tf.constant(0, dtype=tf.int32),
prediction=prediction.write(0, predicted),
prediction=tf.ones([total + 1], dtype=tf.int32) * self.text_featurizer.blank,
states=states
)

def condition(time, total, encoded, hypothesis): return tf.less(time, total)

def body(time, total, encoded, hypothesis):
predicted = tf.gather_nd(hypothesis.prediction, tf.expand_dims(hypothesis.index, axis=-1))

ytu, new_states = self.decoder_inference(
# avoid using [index] in tflite
encoded=tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)),
predicted=hypothesis.prediction.read(hypothesis.index),
predicted=predicted,
states=hypothesis.states
)
char = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []

index, char, new_states = tf.cond(
tf.equal(char, self.text_featurizer.blank),
true_fn=lambda: (
hypothesis.index,
hypothesis.prediction.read(hypothesis.index),
hypothesis.states
),
false_fn=lambda: (
hypothesis.index + 1,
char,
new_states
)
new_predicted = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax []

index, new_predicted, new_states = tf.cond(
tf.equal(new_predicted, self.text_featurizer.blank),
true_fn=lambda: (hypothesis.index, predicted, hypothesis.states),
false_fn=lambda: (hypothesis.index + 1, new_predicted, new_states)
)

hypothesis = Hypothesis(
index=index,
prediction=hypothesis.prediction.write(index, char),
prediction=tf.tensor_scatter_nd_update(
hypothesis.prediction,
indices=tf.reshape(index, [1, 1]),
updates=tf.expand_dims(new_predicted, axis=-1)
),
states=new_states
)

return time + 1, total, encoded, hypothesis

time, total, encoded, hypothesis = tf.while_loop(
condition,
body,
condition, body,
loop_vars=(time, total, encoded, hypothesis),
parallel_iterations=parallel_iterations,
swap_memory=swap_memory
)

# Gather predicted sequence
hypothesis = Hypothesis(
index=hypothesis.index,
prediction=tf.gather_nd(
params=hypothesis.prediction.stack(),
indices=tf.expand_dims(tf.range(hypothesis.index + 1), axis=-1)
),
states=hypothesis.states
)

return hypothesis

# -------------------------------- BEAM SEARCH -------------------------------------
Expand Down

0 comments on commit 1bad037

Please sign in to comment.