Skip to content

Commit

Permalink
Merge pull request #87 from TensorSpeech/dev/testing
Browse files Browse the repository at this point in the history
Add unittest and Transducer Joint activation
  • Loading branch information
nglehuy authored Dec 27, 2020
2 parents 3004f0e + 4a97b87 commit b771734
Show file tree
Hide file tree
Showing 29 changed files with 746 additions and 599 deletions.
1 change: 1 addition & 0 deletions examples/conformer/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ model_config:
prediction_layer_norm: True
prediction_projection_units: 0
joint_dim: 320
joint_activation: tanh

learning_config:
augmentations:
Expand Down
1 change: 1 addition & 0 deletions examples/contextnet/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ model_config:
prediction_layer_norm: True
prediction_projection_units: 0
joint_dim: 640
joint_activation: tanh

learning_config:
augmentations:
Expand Down
3 changes: 2 additions & 1 deletion examples/deepspeech2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ model_config:
See `python examples/deepspeech2/train_*.py --help`

See `python examples/deepspeech2/test_*.py --help`
See `python examples/deepspeech2/test_*.py --help`

3 changes: 2 additions & 1 deletion examples/jasper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ model_config:
See `python examples/jasper/train_*.py --help`

See `python examples/jasper/test_*.py --help`
See `python examples/jasper/test_*.py --help`

1 change: 1 addition & 0 deletions examples/streaming_transducer/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ model_config:
prediction_projection_units: 320
prediction_layer_norm: True
joint_dim: 320
joint_activation: tanh

learning_config:
augmentations:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

setuptools.setup(
name="TensorFlowASR",
version="0.6.0",
version="0.6.1",
author="Huy Le Nguyen",
author_email="[email protected]",
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_asr/models/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def __init__(self,
prediction_layer_norm: bool = True,
prediction_projection_units: int = 0,
joint_dim: int = 1024,
joint_activation: str = "tanh",
kernel_regularizer=L2,
bias_regularizer=L2,
name: str = "conformer_transducer",
Expand Down Expand Up @@ -414,6 +415,7 @@ def __init__(self,
layer_norm=prediction_layer_norm,
projection_units=prediction_projection_units,
joint_dim=joint_dim,
joint_activation=joint_activation,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=name, **kwargs
Expand Down
4 changes: 3 additions & 1 deletion tensorflow_asr/models/contextnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ class ContextNet(Transducer):
def __init__(self,
vocabulary_size: int,
encoder_blocks: List[dict],
encoder_alpha: float,
encoder_alpha: float = 0.5,
prediction_embed_dim: int = 512,
prediction_embed_dropout: int = 0,
prediction_num_rnns: int = 1,
Expand All @@ -206,6 +206,7 @@ def __init__(self,
prediction_layer_norm: bool = True,
prediction_projection_units: int = 0,
joint_dim: int = 1024,
joint_activation: str = "tanh",
kernel_regularizer=L2,
bias_regularizer=L2,
name: str = "contextnet",
Expand All @@ -228,6 +229,7 @@ def __init__(self,
layer_norm=prediction_layer_norm,
projection_units=prediction_projection_units,
joint_dim=joint_dim,
joint_activation=joint_activation,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=name, **kwargs
Expand Down
5 changes: 3 additions & 2 deletions tensorflow_asr/models/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
class CtcModel(Model):
def __init__(self, **kwargs):
super(CtcModel, self).__init__(**kwargs)
self.time_reduction_factor = 1

def _build(self, input_shape):
features = tf.keras.Input(input_shape, dtype=tf.float32)
Expand Down Expand Up @@ -67,7 +68,7 @@ def recognize_tflite(self, signal):
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
input_length = get_reduced_length(input_length, self.time_reduction_factor)
input_length = tf.expand_dims(input_length, axis=0)
logits = self(features, training=False)
probs = tf.nn.softmax(logits)
Expand Down Expand Up @@ -113,7 +114,7 @@ def recognize_beam_tflite(self, signal):
features = self.speech_featurizer.tf_extract(signal)
features = tf.expand_dims(features, axis=0)
input_length = shape_list(features)[1]
input_length = get_reduced_length(input_length, self.base_model.time_reduction_factor)
input_length = get_reduced_length(input_length, self.time_reduction_factor)
input_length = tf.expand_dims(input_length, axis=0)
logits = self(features, training=False)
probs = tf.nn.softmax(logits)
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_asr/models/streaming_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self,
prediction_layer_norm: bool = True,
prediction_projection_units: int = 640,
joint_dim: int = 640,
joint_activation: str = "tanh",
kernel_regularizer = None,
bias_regularizer = None,
name = "StreamingTransducer",
Expand All @@ -217,6 +218,7 @@ def __init__(self,
layer_norm=prediction_layer_norm,
projection_units=prediction_projection_units,
joint_dim=joint_dim,
joint_activation=joint_activation,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=name, **kwargs
Expand Down
12 changes: 11 additions & 1 deletion tensorflow_asr/models/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,19 @@ class TransducerJoint(tf.keras.Model):
def __init__(self,
vocabulary_size: int,
joint_dim: int = 1024,
activation: str = "tanh",
kernel_regularizer=None,
bias_regularizer=None,
name="tranducer_joint",
**kwargs):
super(TransducerJoint, self).__init__(name=name, **kwargs)

activation = activation.lower()
if activation == "linear": self.activation = tf.keras.activation.linear
elif activation == "relu": self.activation = tf.nn.relu
elif activation == "tanh": self.activation = tf.nn.tanh
else: raise ValueError("activation must be either 'linear', 'relu' or 'tanh'")

self.ffn_enc = tf.keras.layers.Dense(
joint_dim, name=f"{name}_enc",
kernel_regularizer=kernel_regularizer,
Expand All @@ -174,7 +182,7 @@ def call(self, inputs, training=False, **kwargs):
pred_out = self.ffn_pred(pred_out, training=training) # [B, U, P] => [B, U, V]
enc_out = tf.expand_dims(enc_out, axis=2)
pred_out = tf.expand_dims(pred_out, axis=1)
outputs = tf.nn.tanh(enc_out + pred_out) # => [B, T, U, V]
outputs = self.activation(enc_out + pred_out) # => [B, T, U, V]
outputs = self.ffn_out(outputs, training=training)
return outputs

Expand All @@ -200,6 +208,7 @@ def __init__(self,
layer_norm: bool = True,
projection_units: int = 0,
joint_dim: int = 1024,
joint_activation: str = "tanh",
kernel_regularizer=None,
bias_regularizer=None,
name="transducer",
Expand All @@ -223,6 +232,7 @@ def __init__(self,
self.joint_net = TransducerJoint(
vocabulary_size=vocabulary_size,
joint_dim=joint_dim,
activation=joint_activation,
kernel_regularizer=kernel_regularizer,
bias_regularizer=bias_regularizer,
name=f"{name}_joint"
Expand Down
95 changes: 95 additions & 0 deletions tests/conformer/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

speech_config:
sample_rate: 16000
frame_ms: 25
stride_ms: 10
num_feature_bins: 80
feature_type: log_mel_spectrogram
preemphasis: 0.97
normalize_signal: True
normalize_feature: True
normalize_per_feature: False

decoder_config:
vocabulary: null
target_vocab_size: 1024
max_subword_length: 4
blank_at_zero: True
beam_width: 5
norm_score: True

model_config:
name: conformer
encoder_subsampling:
type: conv2d
filters: 144
kernel_size: 3
strides: 2
encoder_positional_encoding: sinusoid_concat
encoder_dmodel: 144
encoder_num_blocks: 16
encoder_head_size: 36
encoder_num_heads: 4
encoder_mha_type: relmha
encoder_kernel_size: 32
encoder_fc_factor: 0.5
encoder_dropout: 0.1
prediction_embed_dim: 320
prediction_embed_dropout: 0
prediction_num_rnns: 1
prediction_rnn_units: 320
prediction_rnn_type: lstm
prediction_rnn_implementation: 1
prediction_layer_norm: True
prediction_projection_units: 0
joint_dim: 320
joint_activation: tanh

learning_config:
augmentations:
after:
time_masking:
num_masks: 10
mask_factor: 100
p_upperbound: 0.05
freq_masking:
num_masks: 1
mask_factor: 27

dataset_config:
train_paths:
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
eval_paths:
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
test_paths:
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
tfrecords_dir: null

optimizer_config:
warmup_steps: 40000
beta1: 0.9
beta2: 0.98
epsilon: 1e-9

running_config:
batch_size: 2
accumulation_steps: 4
num_epochs: 20
outdir: /mnt/Miscellanea/Models/local/conformer
log_interval_steps: 300
eval_interval_steps: 500
save_interval_steps: 1000
57 changes: 57 additions & 0 deletions tests/conformer/test_conformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf

DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")

from tensorflow_asr.configs.config import Config
from tensorflow_asr.models.conformer import Conformer
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer
from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer


def test_conformer():
config = Config(DEFAULT_YAML, learning=False)

text_featurizer = CharFeaturizer(config.decoder_config)

speech_featurizer = TFSpeechFeaturizer(config.speech_config)

model = Conformer(vocabulary_size=text_featurizer.num_classes, **config.model_config)

model._build(speech_featurizer.shape)
model.summary(line_length=150)

model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer)

concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter.convert()

print("Converted successfully with no timestamp")

concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter.convert()

print("Converted successfully with timestamp")
Loading

0 comments on commit b771734

Please sign in to comment.