Skip to content

Commit

Permalink
Merge pull request #70 from TensorSpeech/fix/transducer
Browse files: browse the repository at this point in the history
Fix typo and format for transducer
  • Loading branch information
nglehuy authored Dec 9, 2020
2 parents 189c7d7 + d750838 commit 96da1e2
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ Session.vim
.idea
.vscode
__pycache__
.pytest*
venv
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[flake8]
ignore = E402,E701,E702,E704,E251
-max-line-length = 150
+max-line-length = 127

[pep8]
ignore = E402,E701,E702,E704,E251
-max-line-length = 150
+max-line-length = 127
indent-size = 4
11 changes: 7 additions & 4 deletions tensorflow_asr/models/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def recognize(self, signals):
"""
def execute(signal: tf.Tensor):
features = self.speech_featurizer.tf_extract(signal)
-encoded, _ = self.encoder_inference(features, self.encoder.get_initial_states())
+encoded, _ = self.encoder_inference(features, self.encoder.get_initial_state())
hypothesis = self.perform_greedy(
encoded,
predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32),
Expand Down Expand Up @@ -310,10 +310,13 @@ def recognize_beam(self, signals, lm=False):
"""
def execute(signal: tf.Tensor):
features = self.speech_featurizer.tf_extract(signal)
-encoded, _ = self.encoder_inference(features, self.encoder.get_initial_states())
+encoded, _ = self.encoder_inference(features, self.encoder.get_initial_state())
hypothesis = self.perform_beam_search(encoded, lm)
-prediction = tf.map_fn(lambda x: tf.strings.to_number(x, tf.int32),
-tf.strings.split(hypothesis.prediction), fn_output_signature=tf.TensorSpec([], dtype=tf.int32))
+prediction = tf.map_fn(
+lambda x: tf.strings.to_number(x, tf.int32),
+tf.strings.split(hypothesis.prediction),
+fn_output_signature=tf.TensorSpec([], dtype=tf.int32)
+)
transcripts = self.text_featurizer.iextract(tf.expand_dims(prediction, axis=0))
return tf.squeeze(transcripts) # reshape from [1] to []

Expand Down
7 changes: 5 additions & 2 deletions tensorflow_asr/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,11 @@ def execute(signal: tf.Tensor):
features = self.speech_featurizer.tf_extract(signal)
encoded = self.encoder_inference(features)
hypothesis = self.perform_beam_search(encoded, lm)
-prediction = tf.map_fn(lambda x: tf.strings.to_number(x, tf.int32),
-tf.strings.split(hypothesis.prediction), fn_output_signature=tf.TensorSpec([], dtype=tf.int32))
+prediction = tf.map_fn(
+lambda x: tf.strings.to_number(x, tf.int32),
+tf.strings.split(hypothesis.prediction),
+fn_output_signature=tf.TensorSpec([], dtype=tf.int32)
+)
transcripts = self.text_featurizer.iextract(tf.expand_dims(prediction, axis=0))
return tf.squeeze(transcripts) # reshape from [1] to []

Expand Down
1 change: 0 additions & 1 deletion tests/plot_learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow_asr.optimizers.schedules import SANSchedule, TransformerSchedule
Expand Down

0 comments on commit 96da1e2

Please sign in to comment.