From c47681e62e241b6a5e1cef153ee301e91703c29b Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 7 May 2020 22:26:29 +0800 Subject: [PATCH] Construct FSA from un-sorted list of arcs (fix #31). (#32) * Construct FSA from un-sorted list of arcs (fix #31). * fix style check. * replace std::unqiue_ptr with passed pointer argument. --- k2/csrc/fsa_algo.cc | 101 +++++++++++++++++++++++++++++++++++---- k2/csrc/fsa_algo.h | 13 +++++ k2/csrc/fsa_algo_test.cc | 85 ++++++++++++++++++++------------ k2/csrc/fsa_util.cc | 7 +++ k2/csrc/fsa_util.h | 4 +- 5 files changed, 169 insertions(+), 41 deletions(-) diff --git a/k2/csrc/fsa_algo.cc b/k2/csrc/fsa_algo.cc index a20e5938a..2fe49b00c 100644 --- a/k2/csrc/fsa_algo.cc +++ b/k2/csrc/fsa_algo.cc @@ -21,9 +21,9 @@ namespace { -static constexpr int8_t kNotVisited = 0; // a node that has not been visited -static constexpr int8_t kVisiting = 1; // a node that is under visiting -static constexpr int8_t kVisited = 2; // a node that has been visited +constexpr int8_t kNotVisited = 0; // a node that has not been visited +constexpr int8_t kVisiting = 1; // a node that is under visiting +constexpr int8_t kVisited = 2; // a node that has been visited // depth first search state struct DfsState { int32_t state; // state number of the visiting node @@ -364,8 +364,8 @@ bool Intersect(const Fsa &a, const Fsa &b, Fsa *c, auto b_arc_range = std::equal_range(b_arc_iter_begin, b_arc_iter_end, curr_a_arc, [](const Arc &left, const Arc &right) { - return left.label < right.label; - }); + return left.label < right.label; + }); for (ArcIterator it_b = b_arc_range.first; it_b != b_arc_range.second; ++it_b) { Arc curr_b_arc = *it_b; @@ -424,12 +424,13 @@ void ArcSort(const Fsa &a, Fsa *b, std::transform(arc_begin_iter + begin, arc_begin_iter + end, index_begin_iter + begin, std::back_inserter(arc_range_to_be_sorted), - [](const Arc & arc, int32_t index) - ->ArcWithIndex { return std::make_pair(arc, index); }); + [](const Arc &arc, int32_t index) -> ArcWithIndex { + return std::make_pair(arc, index); + }); std::sort(arc_range_to_be_sorted.begin(), arc_range_to_be_sorted.end(), [](const ArcWithIndex &left, const ArcWithIndex &right) { - return left.first < right.first; // sort on arc - }); + return left.first < right.first; // sort on arc + }); // copy index mappings back to `indexes` std::transform(arc_range_to_be_sorted.begin(), arc_range_to_be_sorted.end(), index_begin_iter + begin, @@ -540,4 +541,86 @@ bool TopSort(const Fsa &a, Fsa *b, return true; } +void CreateFsa(const std::vector &arcs, Fsa *fsa) { + CHECK_NOTNULL(fsa); + fsa->arc_indexes.clear(); + fsa->arcs.clear(); + + if (arcs.empty()) return; + + std::vector> vec; + for (const auto &arc : arcs) { + auto src_state = arc.src_state; + auto dest_state = arc.dest_state; + auto new_size = std::max(src_state, dest_state); + if (new_size >= vec.size()) vec.resize(new_size + 1); + vec[src_state].push_back(arc); + } + + std::stack stack; + std::vector state_status(vec.size(), kNotVisited); + std::vector order; + + auto num_states = static_cast(vec.size()); + for (auto i = 0; i != num_states; ++i) { + if (state_status[i] == kVisited) continue; + stack.push({i, 0, static_cast(vec[i].size())}); + state_status[i] = kVisiting; + while (!stack.empty()) { + auto ¤t_state = stack.top(); + auto state = current_state.state; + + if (current_state.arc_begin == current_state.arc_end) { + state_status[state] = kVisited; + order.push_back(state); + stack.pop(); + continue; + } + + const auto &arc = vec[state][current_state.arc_begin]; + auto next_state = arc.dest_state; + auto status = state_status[next_state]; + switch (status) { + case kNotVisited: + state_status[next_state] = kVisiting; + stack.push( + {next_state, 0, static_cast(vec[next_state].size())}); + ++current_state.arc_begin; + break; + case kVisiting: + LOG(FATAL) << "there is a cycle: " << state << " -> " << next_state; + break; + case kVisited: + ++current_state.arc_begin; + break; + default: + LOG(FATAL) << "Unreachable code is executed!"; + break; + } + } + } + + CHECK_EQ(num_states, static_cast(order.size())); + + std::reverse(order.begin(), order.end()); + + fsa->arc_indexes.resize(num_states + 1); + fsa->arcs.reserve(arcs.size()); + + std::vector old_to_new(num_states); + for (auto i = 0; i != num_states; ++i) old_to_new[order[i]] = i; + + for (auto i = 0; i != num_states; ++i) { + auto old_state = order[i]; + fsa->arc_indexes[i] = static_cast(fsa->arcs.size()); + for (auto arc : vec[old_state]) { + arc.src_state = i; + arc.dest_state = old_to_new[arc.dest_state]; + fsa->arcs.push_back(arc); + } + } + + fsa->arc_indexes.back() = static_cast(fsa->arcs.size()); +} + } // namespace k2 diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 6ed53060b..b9aab345c 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -258,6 +258,19 @@ bool TopSort(const Fsa& a, Fsa* b, std::vector* state_map = nullptr); void Determinize(const Fsa &a, Fsa *b, std::vector> *state_map); +/* Create an acyclic FSA from a list of arcs. + + Arcs do not need to be pre-sorted by src_state. + If there is a cycle, it aborts. + + The start state MUST be 0. The final state will be automatically determined + by topological sort. + + @param [in] arcs A list of arcs. + @param [out] fsa Output fsa. +*/ +void CreateFsa(const std::vector &arcs, Fsa *fsa); + } // namespace k2 #endif // K2_CSRC_FSA_ALGO_H_ diff --git a/k2/csrc/fsa_algo_test.cc b/k2/csrc/fsa_algo_test.cc index aa934dc56..9ee84715d 100644 --- a/k2/csrc/fsa_algo_test.cc +++ b/k2/csrc/fsa_algo_test.cc @@ -13,6 +13,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "k2/csrc/fsa_renderer.h" #include "k2/csrc/fsa_util.h" namespace k2 { @@ -173,7 +174,8 @@ TEST(FsaAlgo, Connect) { EXPECT_THAT(b.arc_indexes, ::testing::ElementsAre(0, 2, 4, 5, 5)); std::vector target_arcs = { - {0, 2, 1}, {0, 1, 2}, {1, 3, 3}, {1, 2, 1}, {2, 3, 6}, }; + {0, 2, 1}, {0, 1, 2}, {1, 3, 3}, {1, 2, 1}, {2, 3, 6}, + }; for (auto i = 0; i != target_arcs.size(); ++i) EXPECT_EQ(b.arcs[i], target_arcs[i]); @@ -300,18 +302,17 @@ TEST(FsaAlgo, Intersect) { } { - std::vector arcs_a = {{0, 1, 1}, - {1, 2, 0}, - {1, 3, 1}, - {1, 4, 2}, - {2, 2, 1}, - {2, 3, 1}, - {2, 3, 2}, - {3, 3, 0}, - {3, 4, 1}}; + std::vector arcs_a = {{0, 1, 1}, {1, 2, 0}, {1, 3, 1}, + {1, 4, 2}, {2, 2, 1}, {2, 3, 1}, + {2, 3, 2}, {3, 3, 0}, {3, 4, 1}}; Fsa a(std::move(arcs_a), 4); - std::vector arcs_b = {{0, 1, 1}, {1, 3, 1}, {1, 2, 2}, {2, 3, 1}, }; + std::vector arcs_b = { + {0, 1, 1}, + {1, 3, 1}, + {1, 2, 2}, + {2, 3, 1}, + }; Fsa b(std::move(arcs_b), 3); Fsa c; @@ -320,16 +321,10 @@ TEST(FsaAlgo, Intersect) { bool status = Intersect(a, b, &c, &arc_map_a, &arc_map_b); EXPECT_TRUE(status); - std::vector arcs_c = {{0, 1, 1}, - {1, 2, 0}, - {1, 3, 1}, - {1, 4, 2}, - {2, 5, 1}, - {2, 6, 1}, - {2, 6, 2}, - {3, 3, 0}, - {6, 6, 0}, - {6, 7, 1}, }; + std::vector arcs_c = { + {0, 1, 1}, {1, 2, 0}, {1, 3, 1}, {1, 4, 2}, {2, 5, 1}, + {2, 6, 1}, {2, 6, 2}, {3, 3, 0}, {6, 6, 0}, {6, 7, 1}, + }; std::vector arc_indexes_c = {0, 1, 4, 6, 8, 8, 8, 10, 10}; ASSERT_EQ(c.arc_indexes.size(), arc_indexes_c.size()); @@ -386,7 +381,8 @@ TEST(FsaAlgo, ArcSort) { { std::vector arcs = { - {0, 1, 2}, {0, 4, 0}, {0, 2, 0}, {1, 2, 1}, {1, 3, 0}, {2, 1, 0}, }; + {0, 1, 2}, {0, 4, 0}, {0, 2, 0}, {1, 2, 1}, {1, 3, 0}, {2, 1, 0}, + }; Fsa fsa(std::move(arcs), 4); Fsa arc_sorted; std::vector arc_map; @@ -395,7 +391,8 @@ TEST(FsaAlgo, ArcSort) { ::testing::ElementsAre(0, 3, 5, 6, 6, 6)); ASSERT_EQ(arc_sorted.arcs.size(), fsa.arcs.size()); std::vector target_arcs = { - {0, 2, 0}, {0, 4, 0}, {0, 1, 2}, {1, 3, 0}, {1, 2, 1}, {2, 1, 0}, }; + {0, 2, 0}, {0, 4, 0}, {0, 1, 2}, {1, 3, 0}, {1, 2, 1}, {2, 1, 0}, + }; for (std::size_t i = 0; i != target_arcs.size(); ++i) EXPECT_EQ(arc_sorted.arcs[i], target_arcs[i]); @@ -494,14 +491,10 @@ TEST(FsaAlgo, TopSort) { ASSERT_EQ(arc_indexes.size(), 8u); EXPECT_THAT(arc_indexes, ::testing::ElementsAre(0, 2, 3, 4, 5, 7, 8, 8)); - std::vector expected_arcs = {{0, 1, 40}, - {0, 3, 20}, - {1, 2, 50}, - {2, 3, 8}, - {3, 4, 30}, - {4, 6, 60}, - {4, 5, 10}, - {5, 6, 2}, }; + std::vector expected_arcs = { + {0, 1, 40}, {0, 3, 20}, {1, 2, 50}, {2, 3, 8}, + {3, 4, 30}, {4, 6, 60}, {4, 5, 10}, {5, 6, 2}, + }; for (auto i = 0; i != 8; ++i) { EXPECT_EQ(arcs[i], expected_arcs[i]); @@ -509,4 +502,34 @@ TEST(FsaAlgo, TopSort) { } } +TEST(FsaAlgo, CreateFsa) { + { + // clang-format off + std::vector arcs = { + {0, 3, 3}, + {0, 2, 2}, + {2, 3, 3}, + {2, 4, 4}, + {3, 1, 1}, + {1, 4, 4}, + {1, 8, 8}, + {4, 8, 8}, + {8, 6, 6}, + {8, 7, 7}, + {6, 7, 7}, + {7, 5, 5}, + }; + // clang-format on + Fsa a; + CreateFsa(arcs, &a); + + auto num_states = a.NumStates(); + + Fsa b; + Swap(&a, &b); + EXPECT_EQ(a.NumStates(), 0); + EXPECT_EQ(b.NumStates(), num_states); + } +} + } // namespace k2 diff --git a/k2/csrc/fsa_util.cc b/k2/csrc/fsa_util.cc index c4cd9fb1d..3b1e7acd3 100644 --- a/k2/csrc/fsa_util.cc +++ b/k2/csrc/fsa_util.cc @@ -129,6 +129,13 @@ void GetEnteringArcs(const Fsa &fsa, std::vector *arc_index, } } +void Swap(Fsa *a, Fsa *b) { + CHECK_NOTNULL(a); + CHECK_NOTNULL(b); + std::swap(a->arc_indexes, b->arc_indexes); + std::swap(a->arcs, b->arcs); +} + std::unique_ptr StringToFsa(const std::string &s) { static constexpr const char *kDelim = " \t"; diff --git a/k2/csrc/fsa_util.h b/k2/csrc/fsa_util.h index fe7eaa356..0a9273a54 100644 --- a/k2/csrc/fsa_util.h +++ b/k2/csrc/fsa_util.h @@ -61,10 +61,12 @@ void ConvertIndexes1(const std::vector &arc_map, int64_t *indexes_out); total number of int32's in arc_map, will contain arc-indexes in the output FSA */ -void GetArcIndexes2(const std::vector > &arc_map, +void GetArcIndexes2(const std::vector> &arc_map, std::vector *indexes1, std::vector *indexes2); +void Swap(Fsa *a, Fsa *b); + /** Build a FSA from a string. The input string is a transition table with the following