From cfd8b79b7867b5bf39b1a3c1a071fdd2759d3b0e Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Sat, 16 May 2020 19:00:44 +0800 Subject: [PATCH 1/2] fix determinize issues --- k2/csrc/determinize.cc | 1 + k2/csrc/determinize.h | 55 ++++++++++++++++++---------------------- k2/csrc/fsa_algo_test.cc | 47 ++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 31 deletions(-) diff --git a/k2/csrc/determinize.cc b/k2/csrc/determinize.cc index c3bed63aa..54b166bb8 100644 --- a/k2/csrc/determinize.cc +++ b/k2/csrc/determinize.cc @@ -115,6 +115,7 @@ void TraceBack(std::unordered_set *cur_states, // `deriv_out` is just a list of arc indexes in the input FSA // that this output arc depends on (it's their sum). (*deriv_out)[i] = state->arc_id; + state = state->prev_state.get(); } double prev_forward_prob = state->forward_prob; *weight_out = cur_forward_prob - prev_forward_prob; diff --git a/k2/csrc/determinize.h b/k2/csrc/determinize.h index 8235dd10e..f0ac82e05 100644 --- a/k2/csrc/determinize.h +++ b/k2/csrc/determinize.h @@ -20,6 +20,7 @@ #include #include "k2/csrc/fsa.h" +#include "k2/csrc/fsa_algo.h" #include "k2/csrc/util.h" #include "k2/csrc/weights.h" @@ -371,20 +372,20 @@ class DetState; template struct DetStateCompare { - bool operator()(const std::unique_ptr> &a, - const std::unique_ptr> &b) { + bool operator()(const std::shared_ptr> &a, + const std::shared_ptr> &b) { return a->forward_backward_prob < b->forward_backward_prob; } }; // Priority queue template arguments: -// item queued = unique_ptr (using pointer equality as comparison) -// container type = vector > +// item queued = shared_ptr (using pointer equality as comparison) +// container type = vector > // less-than operator = DetStateCompare (which compares the // forward_backward_prob). template using DetStatePriorityQueue = - std::priority_queue>, - std::vector>>, + std::priority_queue>, + std::vector>>, DetStateCompare>; template @@ -448,7 +449,7 @@ class DetState { const std::shared_ptr &src, int32_t incoming_arc_index, int32_t arc_weight) { auto ret = elements.insert({state_id, nullptr}); - if (!ret.second) { // No such state existed in `elements` + if (ret.second) { // No such state existed in `elements` ret.first->second = std::make_shared( state_id, src, incoming_arc_index, arc_weight); } else { // A state with this staste_id existed in `elements`. @@ -527,20 +528,6 @@ class DetState { DetStateMap *state_map, DetStatePriorityQueue *queue); - // Computes the forward-backward weight of this DetState. This is - // related to the best cost of any path through the output FSA - // that included this determinized state. I say "related to" - // because while it should be exact in the Max case, in the - // LogSum case the relationship is a bit more complicated; - // maybe just best to say that this is a weight that we use - // for pruning. - // @param [in] backward_state_weight Array, indexed by - // state in input WFSA, of the weight from this state - // to the end. (Of the best path or the sum of paths, - // depending how it was computed; this will of - // course affect the pruning). - void ComputeFbWeight(const float *backward_state_weights); - /* Normalizes this DetState by reducing seq_len to the extent possible and outputting the weight and derivative info corresponding to this @@ -597,12 +584,13 @@ int32_t DetState::ProcessArcs( iter->second = new DetState(seq_len + 1); } DetState *det_state = iter->second; - det_state->AcceptIncomingArc(state_id, state_ptr, a, weight); + det_state->AcceptIncomingArc(arc.dest_state, state_ptr, a, weight); } } CHECK(!label_to_state.empty() || - elements[0]->state_id == fsa.FinalState()); // I'm assuming the input - // FSA is connected. + elements.begin()->second->state_id == + fsa.FinalState()); // I'm assuming the input + // FSA is connected. // The following loop normalizes successor det-states, outputs the arcs // that lead to them, and adds them to the queue if necessary. @@ -743,8 +731,10 @@ class DetStateMap { std::pair *vec) { assert(d.normalized); - uint64_t a = d.state_id + 17489 * d.seq_len, - b = d.state_id * 103979 + d.seq_len; + int32_t input_state_id = d.elements.begin()->first; + + uint64_t a = input_state_id + 17489 * d.seq_len, + b = input_state_id * 103979 + d.seq_len; // We choose an arbitrary DetStateElement (the first one in the list) to // read the symbol sequence from; the symbol sequence will be the same no @@ -766,8 +756,10 @@ class DetStateMap { const Fsa &fsa, std::pair *vec) { assert(d.normalized); - uint64_t a = d.state_id + 17489 * d.seq_len, - b = d.state_id * 103979 + d.seq_len; + int32_t input_state_id = d.elements.begin()->first; + + uint64_t a = input_state_id + 17489 * d.seq_len, + b = input_state_id * 103979 + d.seq_len; // We choose an arbitrary DetStateElement (the first one in the list) to // read the symbol sequence from; the symbol sequence will be the same no @@ -799,14 +791,13 @@ float DeterminizePrunedTpl( std::vector> *arc_derivs_out) { CHECK_GT(beam, 0); - CHECK(IsDeterministic(wfsa_in.fsa)); CHECK(!IsEmpty(wfsa_in.fsa)); DetStatePriorityQueue queue; DetStateMap map; using DS = DetState; - std::shared_ptr start_state = std::make_shared(); + std::shared_ptr start_state(new DS()); std::vector arcs_out; arc_weights_out->clear(); @@ -822,13 +813,15 @@ float DeterminizePrunedTpl( double total_prob = wfsa_in.BackwardStateWeights()[0], prune_cutoff = total_prob - beam; + queue.push(std::move(start_state)); while (num_steps < max_step && !queue.empty()) { - std::shared_ptr state(queue.top().get()); + std::shared_ptr state(queue.top()); queue.pop(); num_steps += state->ProcessArcs(wfsa_in, prune_cutoff, &arcs_out, arc_weights_out, arc_derivs_out, &map, &queue); } + CreateFsa(arcs_out, fsa_out); if (!queue.empty()) { // We stopped early due to max_step return total_prob - queue.top()->forward_backward_prob; } else { diff --git a/k2/csrc/fsa_algo_test.cc b/k2/csrc/fsa_algo_test.cc index 9ee84715d..e956eb43b 100644 --- a/k2/csrc/fsa_algo_test.cc +++ b/k2/csrc/fsa_algo_test.cc @@ -502,6 +502,53 @@ TEST(FsaAlgo, TopSort) { } } +class DeterminizeTest : public ::testing::Test { + protected: + DeterminizeTest() { + std::vector arcs = {{0, 4, 1}, {0, 1, 1}, {1, 2, 2}, {1, 3, 3}, + {2, 7, 1}, {3, 7, 1}, {4, 6, 1}, {4, 6, 1}, + {4, 5, 1}, {4, 8, -1}, {5, 8, -1}, {6, 8, -1}, + {7, 8, -1}}; + fsa_ = new Fsa(std::move(arcs), 8); + num_states_ = fsa_->NumStates(); + + auto num_arcs = fsa_->arcs.size(); + arc_weights_ = new float[num_arcs]; + std::vector weights = {1, 1, 2, 3, 4, 5, 2, 3, 3, 2, 4, 3, 5}; + std::copy_n(weights.begin(), num_arcs, arc_weights_); + + max_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kMaxWeight); + log_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kLogSumWeight); + } + + ~DeterminizeTest() { + delete fsa_; + delete[] arc_weights_; + delete max_wfsa_; + delete log_wfsa_; + } + + WfsaWithFbWeights *max_wfsa_; + WfsaWithFbWeights *log_wfsa_; + Fsa *fsa_; + int32_t num_states_; + float *arc_weights_; +}; + +TEST_F(DeterminizeTest, DeterminizePrunedMax) { + Fsa b; + std::vector b_arc_weights; + std::vector> arc_derivs; + DeterminizePrunedMax(*max_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs); +} + +TEST_F(DeterminizeTest, DeterminizePrunedLogSum) { + Fsa b; + std::vector b_arc_weights; + std::vector>> arc_derivs; + DeterminizePrunedLogSum(*log_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs); +} + TEST(FsaAlgo, CreateFsa) { { // clang-format off From d3b0ce82cafb03d880d60b02980b7d306bf95bcd Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Sat, 16 May 2020 20:49:31 +0800 Subject: [PATCH 2/2] fix hash issue --- k2/csrc/determinize.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/k2/csrc/determinize.h b/k2/csrc/determinize.h index f0ac82e05..a4f6553b8 100644 --- a/k2/csrc/determinize.h +++ b/k2/csrc/determinize.h @@ -731,10 +731,7 @@ class DetStateMap { std::pair *vec) { assert(d.normalized); - int32_t input_state_id = d.elements.begin()->first; - - uint64_t a = input_state_id + 17489 * d.seq_len, - b = input_state_id * 103979 + d.seq_len; + uint64_t a = 17489 * d.seq_len, b = d.seq_len; // We choose an arbitrary DetStateElement (the first one in the list) to // read the symbol sequence from; the symbol sequence will be the same no @@ -748,6 +745,9 @@ class DetStateMap { b = symbol + 102983 * b; elem = elem->prev_state; } + // This is `base_state`: the state from which we + // start (and accept the specified symbol sequence). + a = elem->state_id + 14051 * a; vec->first = a; vec->second = b; } @@ -756,10 +756,7 @@ class DetStateMap { const Fsa &fsa, std::pair *vec) { assert(d.normalized); - int32_t input_state_id = d.elements.begin()->first; - - uint64_t a = input_state_id + 17489 * d.seq_len, - b = input_state_id * 103979 + d.seq_len; + uint64_t a = 17489 * d.seq_len, b = d.seq_len; // We choose an arbitrary DetStateElement (the first one in the list) to // read the symbol sequence from; the symbol sequence will be the same no @@ -773,6 +770,9 @@ class DetStateMap { b = symbol + 102983 * b; elem = elem->prev_elements[0].prev_state; } + // This is `base_state`: the state from which we + // start (and accept the specified symbol sequence). + a = elem->state_id + 14051 * a; vec->first = a; vec->second = b; }