Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[speechx] fix nnet input and output name #1740

Merged
merged 3 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ exclude =
.git,
# python cache
__pycache__,
# third party
utils/compute-wer.py,
third_party/,
# Provide a comma-separate list of glob patterns to include for checks.
filename =
Expand Down
6 changes: 4 additions & 2 deletions paddlespeech/cli/asr/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

__all__ = ['ASRExecutor']


@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
Expand Down Expand Up @@ -148,7 +149,7 @@ def _init_from_path(self,
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)

#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
Expand Down Expand Up @@ -278,7 +279,8 @@ def infer(self, model_type: str):
self._outputs["result"] = result_transcripts[0]

elif "conformer" in model_type or "transformer" in model_type:
logger.info(f"we will use the transformer like model : {model_type}")
logger.info(
f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
audio,
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/models/u2/u2.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def recognize(
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break

# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/modules/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
# init once
if self._ext_scorer is not None:
return

if language_model_path != '':
logger.info("begin to initialize the external scorer "
"for decoding")
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml

```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
```
2 changes: 1 addition & 1 deletion paddlespeech/server/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml

```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```
```
2 changes: 2 additions & 0 deletions paddlespeech/server/engine/asr/online/ctc_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import paddle

from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add

Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/t2s/exps/synthesize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def evaluate(args):
# acoustic model
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]

am_inference = get_am_inference(
am=args.am,
am_config=am_config,
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/vector/cluster/diarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import argparse
import copy
import warnings
from distutils.util import strtobool

import numpy as np
import scipy
import sklearn
from distutils.util import strtobool
from scipy import linalg
from scipy import sparse
from scipy.sparse.csgraph import connected_components
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");

Expand Down Expand Up @@ -76,6 +78,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_path;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
Expand Down
5 changes: 2 additions & 3 deletions speechx/examples/ds2_ol/decoder/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ if [ ! -f $lm ]; then
popd
fi


feat_wspecifier=$exp_dir/feats.ark
cmvn=$exp_dir/cmvn.ark

Expand All @@ -57,7 +56,7 @@ export GLOG_logtostderr=1
# dump json cmvn to kaldi
cmvn-json2kaldi \
--json_file $ckpt_dir/data/mean_std.json \
--cmvn_write_path $exp_dir/cmvn.ark \
--cmvn_write_path $cmvn \
--binary=false
echo "convert json cmvn to kaldi ark."

Expand All @@ -66,7 +65,7 @@ echo "convert json cmvn to kaldi ark."
linear-spectrogram-wo-db-norm-ol \
--wav_rspecifier=scp:$data/wav.scp \
--feature_wspecifier=ark,t:$feat_wspecifier \
--cmvn_file=$exp_dir/cmvn.ark
--cmvn_file=$cmvn
echo "compute linear spectrogram feature."

# run ctc beam search decoder as streaming
Expand Down
9 changes: 6 additions & 3 deletions speechx/examples/ds2_ol/decoder/wfst-decoder-ol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ DEFINE_int32(receptive_field_length,
DEFINE_int32(downsampling_rate,
4,
"two CNN(kernel=5) module downsampling rate.");
DEFINE_string(
model_input_names,
"audio_chunk,audio_chunk_lens,chunk_state_h_box,chunk_state_c_box",
"model input names");
DEFINE_string(model_output_names,
"save_infer_model/scale_0.tmp_1,save_infer_model/"
"scale_1.tmp_1,save_infer_model/scale_2.tmp_1,save_infer_model/"
"scale_3.tmp_1",
"softmax_0.tmp_0,tmp_5,concat_0.tmp_0,concat_1.tmp_0",
"model output names");
DEFINE_string(model_cache_names, "5-1-1024,5-1-1024", "model cache names");

Expand Down Expand Up @@ -79,6 +81,7 @@ int main(int argc, char* argv[]) {
model_opts.model_path = model_graph;
model_opts.params_path = model_params;
model_opts.cache_shape = FLAGS_model_cache_names;
model_opts.input_names = FLAGS_model_input_names;
model_opts.output_names = FLAGS_model_output_names;
std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
Expand Down
2 changes: 1 addition & 1 deletion speechx/examples/ds2_ol/feat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ target_link_libraries(${bin_name} frontend kaldi-util kaldi-feat-common gflags g
set(bin_name cmvn-json2kaldi)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog ${DEPS})
80 changes: 42 additions & 38 deletions speechx/examples/ds2_ol/feat/cmvn-json2kaldi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,68 +14,72 @@

// Note: Do not print/log ondemand object.

#include "base/common.h"
#include "base/flags.h"
#include "base/log.h"
#include "kaldi/matrix/kaldi-matrix.h"
#include "kaldi/util/kaldi-io.h"
#include "utils/file_utils.h"
#include "utils/simdjson.h"
// #include "boost/json.hpp"
#include <boost/json/src.hpp>

DEFINE_string(json_file, "", "cmvn json file");
DEFINE_string(cmvn_write_path, "./cmvn.ark", "write cmvn");
DEFINE_bool(binary, true, "write cmvn in binary (true) or text(false)");

using namespace simdjson;
using namespace boost::json; // from <boost/json.hpp>

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);

LOG(INFO) << "cmvn josn path: " << FLAGS_json_file;

try {
padded_string json = padded_string::load(FLAGS_json_file);

ondemand::parser parser;
ondemand::document doc = parser.iterate(json);
ondemand::value val = doc;
auto ifs = std::ifstream(FLAGS_json_file);
std::string json_str = ppspeech::ReadFile2String(FLAGS_json_file);
auto value = boost::json::parse(json_str);
if (!value.is_object()) {
LOG(ERROR) << "Input json file format error.";
}

ondemand::array mean_stat = val["mean_stat"];
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (double x : mean_stat) {
mean_stat_vec.push_back(x);
for (auto obj : value.as_object()) {
if (obj.key() == "mean_stat") {
LOG(INFO) << "mean_stat:" << obj.value();
}
// LOG(INFO) << mean_stat; this line will casue
// simdjson::simdjson_error("Objects and arrays can only be iterated
// when
// they are first encountered")

ondemand::array var_stat = val["var_stat"];
std::vector<kaldi::BaseFloat> var_stat_vec;
for (double x : var_stat) {
var_stat_vec.push_back(x);
if (obj.key() == "var_stat") {
LOG(INFO) << "var_stat: " << obj.value();
}

kaldi::int32 frame_num = uint64_t(val["frame_num"]);
LOG(INFO) << "nframe: " << frame_num;

size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
if (obj.key() == "frame_num") {
LOG(INFO) << "frame_num: " << obj.value();
}
cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats;
}

boost::json::array mean_stat = value.at("mean_stat").as_array();
std::vector<kaldi::BaseFloat> mean_stat_vec;
for (auto it = mean_stat.begin(); it != mean_stat.end(); it++) {
mean_stat_vec.push_back(it->as_double());
}

kaldi::WriteKaldiObject(
cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
} catch (simdjson::simdjson_error& err) {
LOG(ERROR) << err.what();
boost::json::array var_stat = value.at("var_stat").as_array();
std::vector<kaldi::BaseFloat> var_stat_vec;
for (auto it = var_stat.begin(); it != var_stat.end(); it++) {
var_stat_vec.push_back(it->as_double());
}

kaldi::int32 frame_num = uint64_t(value.at("frame_num").as_int64());
LOG(INFO) << "nframe: " << frame_num;

size_t mean_size = mean_stat_vec.size();
kaldi::Matrix<double> cmvn_stats(2, mean_size + 1);
for (size_t idx = 0; idx < mean_size; ++idx) {
cmvn_stats(0, idx) = mean_stat_vec[idx];
cmvn_stats(1, idx) = var_stat_vec[idx];
}
cmvn_stats(0, mean_size) = frame_num;
LOG(INFO) << cmvn_stats;

kaldi::WriteKaldiObject(cmvn_stats, FLAGS_cmvn_write_path, FLAGS_binary);
LOG(INFO) << "cmvn stats have write into: " << FLAGS_cmvn_write_path;
LOG(INFO) << "Binary: " << FLAGS_binary;
return 0;
}
16 changes: 6 additions & 10 deletions speechx/examples/ngram/zh/local/text_to_lexicon.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
from collections import Counter


def main(args):
counter = Counter()
with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout:
Expand All @@ -12,29 +13,24 @@ def main(args):
words = text.split()
else:
words = line.split()

counter.update(words)

for word in counter:
val = " ".join(list(word))
fout.write(f"{word}\t{val}\n")
fout.flush()


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).')
parser.add_argument(
'--has_key',
default=True,
help='text path, with utt or not')
'--has_key', default=True, help='text path, with utt or not')
parser.add_argument(
'--text',
required=True,
help='text path. line: utt1 中国 人 or 中国 人')
'--text', required=True, help='text path. line: utt1 中国 人 or 中国 人')
parser.add_argument(
'--lexicon',
required=True,
help='lexicon path. line:中国 中 国')
'--lexicon', required=True, help='lexicon path. line:中国 中 国')
args = parser.parse_args()
print(args)

Expand Down
Loading