Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[speechx]add wfst decoder #2886

Merged
merged 1 commit into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion speechx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ set(FETCHCONTENT_BASE_DIR ${fc_patch})

# compiler option
# Keep the same with openfst, -fPIC or -fpic
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ldl")
SET(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O0 -Wall -g -ggdb")
SET(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} --std=c++14 -pthread -fPIC -O3 -Wall")

Expand Down
2 changes: 2 additions & 0 deletions speechx/speechx/asr/decoder/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
set(srcs)
list(APPEND srcs
ctc_prefix_beam_search_decoder.cc
ctc_tlg_decoder.cc
)

add_library(decoder STATIC ${srcs})
Expand All @@ -9,6 +10,7 @@ target_link_libraries(decoder PUBLIC utils fst frontend nnet kaldi-decoder)
# test
set(TEST_BINS
ctc_prefix_beam_search_decoder_main
ctc_tlg_decoder_main
)

foreach(bin_name IN LISTS TEST_BINS)
Expand Down
3 changes: 1 addition & 2 deletions speechx/speechx/asr/decoder/ctc_prefix_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class CTCPrefixBeamSearch : public DecoderBase {

void FinalizeSearch();

const std::shared_ptr<fst::SymbolTable> VocabTable() const {
const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unit table 应该是模型的输出,不一定是 word

return unit_table_;
}

Expand All @@ -57,7 +57,6 @@ class CTCPrefixBeamSearch : public DecoderBase {
}
const std::vector<std::vector<int>>& Times() const { return times_; }


protected:
std::string GetBestPath() override;
std::vector<std::pair<double, std::string>> GetNBestPath() override;
Expand Down
44 changes: 41 additions & 3 deletions speechx/speechx/asr/decoder/ctc_tlg_decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#include "decoder/ctc_tlg_decoder.h"
namespace ppspeech {

TLGDecoder::TLGDecoder(TLGDecoderOptions opts) {
TLGDecoder::TLGDecoder(TLGDecoderOptions opts) : opts_(opts) {
fst_.reset(fst::Fst<fst::StdArc>::Read(opts.fst_path));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

opts -> opts_

CHECK(fst_ != nullptr);

Expand Down Expand Up @@ -68,14 +68,52 @@ std::string TLGDecoder::GetPartialResult() {
return words;
}

// Finalizes decoding and caches the n-best paths of the resulting lattice.
//
// After the call the following members hold one entry per n-best path, all
// in the same order:
//   hypotheses_ - input labels of the path minus 1, with each run of
//                 identical consecutive labels collapsed to one entry;
//   times_      - alignment index at which each collapsed run ends
//                 (placeholder timing, see the todo below);
//   olabels     - output-label sequence of the path;
//   likelihood_ - negated sum of the two lattice weight components.
// Every cache is reset first, so repeated calls never accumulate stale
// entries and the four containers always stay index-aligned.
void TLGDecoder::FinalizeSearch() {
    decoder_->FinalizeDecoding();
    kaldi::CompactLattice clat;
    decoder_->GetLattice(&clat, true);
    kaldi::Lattice lat, nbest_lat;
    fst::ConvertLattice(clat, &lat);
    fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
    std::vector<kaldi::Lattice> nbest_lats;
    fst::ConvertNbestToVector(nbest_lat, &nbest_lats);

    hypotheses_.clear();
    hypotheses_.reserve(nbest_lats.size());
    likelihood_.clear();
    likelihood_.reserve(nbest_lats.size());
    times_.clear();
    times_.reserve(nbest_lats.size());
    // olabels must be reset together with the other caches; otherwise
    // Outputs() would keep results of earlier utterances and no longer line
    // up with Inputs()/Likelihood()/Times().
    olabels.clear();
    olabels.reserve(nbest_lats.size());

    // Iterate by const reference: copying a whole lattice per path is
    // unnecessary, and a distinct name avoids shadowing `lat` above.
    for (const auto& nbest_path : nbest_lats) {
        kaldi::LatticeWeight weight;
        std::vector<int> alignment;
        std::vector<int> words_id;
        fst::GetLinearSymbolSequence(
            nbest_path, &alignment, &words_id, &weight);

        std::vector<int> hypothese;
        std::vector<int> time;
        // An empty alignment has no "last" element. Guard explicitly, since
        // `alignment.size() - 1` would wrap around (unsigned) and every
        // subsequent index would be out of bounds.
        if (!alignment.empty()) {
            const int last = static_cast<int>(alignment.size()) - 1;
            for (int idx = 0; idx < last; ++idx) {
                if (alignment[idx] == 0) continue;
                // Emit a label only where a run of repeated labels ends.
                if (alignment[idx] != alignment[idx + 1]) {
                    hypothese.push_back(alignment[idx] - 1);
                    time.push_back(idx);  // fake time, todo later
                }
            }
            // The final label always closes the last run.
            hypothese.push_back(alignment[last] - 1);
            time.push_back(last);  // fake time, todo later
        }

        hypotheses_.push_back(hypothese);
        times_.push_back(time);
        olabels.push_back(words_id);
        likelihood_.push_back(-(weight.Value2() + weight.Value1()));
    }
}

std::string TLGDecoder::GetFinalBestPath() {
if (num_frame_decoded_ == 0) {
// Assertion failed: (this->NumFramesDecoded() > 0 && "You cannot call
// BestPathEnd if no frames were decoded.")
return std::string("");
}

decoder_->FinalizeDecoding();
kaldi::Lattice lat;
kaldi::LatticeWeight weight;
std::vector<int> alignment;
Expand Down
37 changes: 32 additions & 5 deletions speechx/speechx/asr/decoder/ctc_tlg_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "util/parse-options.h"


DECLARE_string(graph_path);
DECLARE_string(word_symbol_table);
DECLARE_string(graph_path);
DECLARE_int32(max_active);
DECLARE_double(beam);
DECLARE_double(lattice_beam);
Expand All @@ -33,6 +32,9 @@ struct TLGDecoderOptions {
// todo remove later, add into decode resource
std::string word_symbol_table;
std::string fst_path;
int nbest;

TLGDecoderOptions() : word_symbol_table(""), fst_path(""), nbest(10) {}

static TLGDecoderOptions InitFromFlags() {
TLGDecoderOptions decoder_opts;
Expand All @@ -44,6 +46,7 @@ struct TLGDecoderOptions {
decoder_opts.opts.max_active = FLAGS_max_active;
decoder_opts.opts.beam = FLAGS_beam;
decoder_opts.opts.lattice_beam = FLAGS_lattice_beam;
// decoder_opts.nbest = FLAGS_lattice_nbest;
SmileGoat marked this conversation as resolved.
Show resolved Hide resolved
LOG(INFO) << "LatticeFasterDecoder max active: "
<< decoder_opts.opts.max_active;
LOG(INFO) << "LatticeFasterDecoder beam: " << decoder_opts.opts.beam;
Expand All @@ -59,20 +62,38 @@ class TLGDecoder : public DecoderBase {
explicit TLGDecoder(TLGDecoderOptions opts);
~TLGDecoder() = default;

void InitDecoder();
void Reset();
void InitDecoder() override;
void Reset() override;

void AdvanceDecode(
const std::shared_ptr<kaldi::DecodableInterface>& decodable);
const std::shared_ptr<kaldi::DecodableInterface>& decodable) override;

void Decode();

std::string GetFinalBestPath() override;
std::string GetPartialResult() override;

const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const override {
return word_symbol_table_;
}

int DecodeLikelihoods(const std::vector<std::vector<BaseFloat>>& probs,
const std::vector<std::string>& nbest_words);

void FinalizeSearch() override;
// Input-label sequences of the cached n-best paths, one vector per path.
// Populated by FinalizeSearch(); each value is the path's alignment label
// minus 1, with consecutive repeats collapsed.
const std::vector<std::vector<int>>& Inputs() const override {
return hypotheses_;
}
// Output-label sequences of the cached n-best paths, one vector per path,
// as appended to `olabels` by FinalizeSearch().
const std::vector<std::vector<int>>& Outputs() const override {
return olabels;
}
// Score of each cached n-best path; FinalizeSearch() stores the negated sum
// of the path's two lattice weight components.
const std::vector<float>& Likelihood() const override {
return likelihood_;
}
// Per-path alignment indices recorded alongside each entry of Inputs().
// FinalizeSearch() marks these as placeholder ("fake") times.
const std::vector<std::vector<int>>& Times() const override {
return times_;
}

protected:
std::string GetBestPath() override {
CHECK(false);
Expand All @@ -90,9 +111,15 @@ class TLGDecoder : public DecoderBase {
private:
void AdvanceDecoding(kaldi::DecodableInterface* decodable);

std::vector<std::vector<int>> hypotheses_;
std::vector<std::vector<int>> olabels;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;

std::shared_ptr<kaldi::LatticeFasterOnlineDecoder> decoder_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_;
std::shared_ptr<fst::SymbolTable> word_symbol_table_;
TLGDecoderOptions opts_;
};


Expand Down
77 changes: 15 additions & 62 deletions speechx/speechx/asr/decoder/ctc_tlg_decoder_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@

// TODO: refactor and replace with gtest

#include "base/common.h"
#include "decoder/ctc_tlg_decoder.h"
#include "base/common.h"
#include "decoder/param.h"
#include "frontend/audio/data_cache.h"
#include "frontend/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/ds2_nnet.h"
#include "nnet/nnet_producer.h"


DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
// rspecifier of the nnet probability matrices that main() feeds to the
// decodable; the help text previously still described the feature flag.
DEFINE_string(nnet_prob_rspecifier, "", "test nnet prob rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");


Expand All @@ -39,8 +39,8 @@ int main(int argc, char* argv[]) {
google::InstallFailureSignalHandler();
FLAGS_logtostderr = 1;

kaldi::SequentialBaseFloatMatrixReader feature_reader(
FLAGS_feature_rspecifier);
kaldi::SequentialBaseFloatMatrixReader nnet_prob_reader(
FLAGS_nnet_prob_rspecifier);
kaldi::TokenWriter result_writer(FLAGS_result_wspecifier);

int32 num_done = 0, num_err = 0;
Expand All @@ -53,66 +53,19 @@ int main(int argc, char* argv[]) {

ppspeech::ModelOptions model_opts = ppspeech::ModelOptions::InitFromFlags();

std::shared_ptr<ppspeech::PaddleNnet> nnet(
new ppspeech::PaddleNnet(model_opts));
std::shared_ptr<ppspeech::DataCache> raw_data(new ppspeech::DataCache());
std::shared_ptr<ppspeech::NnetProducer> nnet_producer =
std::make_shared<ppspeech::NnetProducer>(nullptr);
std::shared_ptr<ppspeech::Decodable> decodable(
new ppspeech::Decodable(nnet, raw_data, FLAGS_acoustic_scale));

int32 chunk_size = FLAGS_receptive_field_length +
(FLAGS_nnet_decoder_chunk - 1) * FLAGS_subsampling_rate;
int32 chunk_stride = FLAGS_subsampling_rate * FLAGS_nnet_decoder_chunk;
int32 receptive_field_length = FLAGS_receptive_field_length;
LOG(INFO) << "chunk size (frame): " << chunk_size;
LOG(INFO) << "chunk stride (frame): " << chunk_stride;
LOG(INFO) << "receptive field (frame): " << receptive_field_length;
new ppspeech::Decodable(nnet_producer, FLAGS_acoustic_scale));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make_shared


decoder.InitDecoder();
kaldi::Timer timer;
for (; !feature_reader.Done(); feature_reader.Next()) {
string utt = feature_reader.Key();
kaldi::Matrix<BaseFloat> feature = feature_reader.Value();
raw_data->SetDim(feature.NumCols());
LOG(INFO) << "process utt: " << utt;
LOG(INFO) << "rows: " << feature.NumRows();
LOG(INFO) << "cols: " << feature.NumCols();

int32 row_idx = 0;
int32 padding_len = 0;
int32 ori_feature_len = feature.NumRows();
if ((feature.NumRows() - chunk_size) % chunk_stride != 0) {
padding_len =
chunk_stride - (feature.NumRows() - chunk_size) % chunk_stride;
feature.Resize(feature.NumRows() + padding_len,
feature.NumCols(),
kaldi::kCopyData);
}
int32 num_chunks = (feature.NumRows() - chunk_size) / chunk_stride + 1;
for (int chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) {
kaldi::Vector<kaldi::BaseFloat> feature_chunk(chunk_size *
feature.NumCols());
int32 feature_chunk_size = 0;
if (ori_feature_len > chunk_idx * chunk_stride) {
feature_chunk_size = std::min(
ori_feature_len - chunk_idx * chunk_stride, chunk_size);
}
if (feature_chunk_size < receptive_field_length) break;

int32 start = chunk_idx * chunk_stride;
for (int row_id = 0; row_id < chunk_size; ++row_id) {
kaldi::SubVector<kaldi::BaseFloat> tmp(feature, start);
kaldi::SubVector<kaldi::BaseFloat> f_chunk_tmp(
feature_chunk.Data() + row_id * feature.NumCols(),
feature.NumCols());
f_chunk_tmp.CopyFromVec(tmp);
++start;
}
raw_data->Accept(feature_chunk);
if (chunk_idx == num_chunks - 1) {
raw_data->SetFinished();
}
decoder.AdvanceDecode(decodable);
}

for (; !nnet_prob_reader.Done(); nnet_prob_reader.Next()) {
string utt = nnet_prob_reader.Key();
kaldi::Matrix<BaseFloat> prob = nnet_prob_reader.Value();
decodable->Acceptlikelihood(prob);
decoder.AdvanceDecode(decodable);
std::string result;
result = decoder.GetFinalBestPath();
decodable->Reset();
Expand Down
10 changes: 9 additions & 1 deletion speechx/speechx/asr/decoder/decoder_itf.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,6 +15,7 @@
#pragma once

#include "base/common.h"
#include "fst/symbol-table.h"
#include "kaldi/decoder/decodable-itf.h"

namespace ppspeech {
Expand All @@ -41,6 +41,14 @@ class DecoderInterface {

virtual std::string GetPartialResult() = 0;

virtual const std::shared_ptr<fst::SymbolTable> WordSymbolTable() const = 0;
virtual void FinalizeSearch() = 0;

virtual const std::vector<std::vector<int>>& Inputs() const = 0;
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
virtual const std::vector<float>& Likelihood() const = 0;
virtual const std::vector<std::vector<int>>& Times() const = 0;

protected:
// virtual void AdvanceDecoding(kaldi::DecodableInterface* decodable) = 0;

Expand Down
4 changes: 2 additions & 2 deletions speechx/speechx/asr/decoder/param.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ DEFINE_string(model_cache_shapes, "5-1-1024,5-1-1024", "model cache shapes");
// decoder
DEFINE_double(acoustic_scale, 1.0, "acoustic scale");

DEFINE_string(graph_path, "TLG", "decoder graph");
DEFINE_string(word_symbol_table, "words.txt", "word symbol table");
DEFINE_string(graph_path, "", "decoder graph");
DEFINE_string(word_symbol_table, "", "word symbol table");
DEFINE_int32(max_active, 7500, "max active");
DEFINE_double(beam, 15.0, "decoder beam");
// Copied into the decoder options' `lattice_beam` field by
// TLGDecoderOptions::InitFromFlags(); the help text previously duplicated
// the one of --beam.
DEFINE_double(lattice_beam, 7.5, "decoder lattice beam");
Expand Down
2 changes: 0 additions & 2 deletions speechx/speechx/asr/nnet/decodable.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class Decodable : public kaldi::DecodableInterface {
explicit Decodable(const std::shared_ptr<NnetProducer>& nnet_producer,
kaldi::BaseFloat acoustic_scale = 1.0);

// void Init(DecodableOpts config);

// nnet logprob output, used by wfst
virtual kaldi::BaseFloat LogLikelihood(int32 frame, int32 index);

Expand Down
16 changes: 8 additions & 8 deletions speechx/speechx/asr/nnet/nnet_producer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,25 @@ using kaldi::BaseFloat;
NnetProducer::NnetProducer(std::shared_ptr<NnetBase> nnet,
std::shared_ptr<FrontendInterface> frontend)
: nnet_(nnet), frontend_(frontend) {
abort_ = false;
Reset();
thread_ = std::thread(RunNnetEvaluation, this);
}
abort_ = false;
Reset();
if (nnet_ != nullptr) thread_ = std::thread(RunNnetEvaluation, this);
SmileGoat marked this conversation as resolved.
Show resolved Hide resolved
}

// Hands `inputs` to the frontend and then wakes one thread waiting on
// condition_variable_. The notification is sent after the data has been
// accepted so a woken waiter can observe it.
// NOTE(review): presumably the waiter is the evaluation thread started in
// the constructor -- confirm against RunNnetEvaluationInteral().
// NOTE(review): frontend_ is dereferenced unconditionally; producers built
// without a frontend must not call this -- confirm with callers.
void NnetProducer::Accept(const std::vector<kaldi::BaseFloat>& inputs) {
frontend_->Accept(inputs);
condition_variable_.notify_one();
}

void NnetProducer::UnLock() {
void NnetProducer::WaitProduce() {
std::unique_lock<std::mutex> lock(read_mutex_);
while (frontend_->IsFinished() == false && cache_.empty()) {
condition_read_ready_.wait(lock);
condition_read_ready_.wait(lock);
}
return;
}

void NnetProducer::RunNnetEvaluation(NnetProducer *me) {
void NnetProducer::RunNnetEvaluation(NnetProducer* me) {
me->RunNnetEvaluationInteral();
}

Expand All @@ -55,7 +55,7 @@ void NnetProducer::RunNnetEvaluationInteral() {
result = Compute();
} while (result);
if (frontend_->IsFinished() == true) {
if (cache_.empty()) finished_ = true;
if (cache_.empty()) finished_ = true;
}
}
LOG(INFO) << "NnetEvaluationInteral exit";
Expand Down
Loading