From 1e4efad313c3ef5c0eb2865807dfdb206750fad8 Mon Sep 17 00:00:00 2001 From: Haowen Qiu Date: Mon, 27 Apr 2020 19:19:15 +0800 Subject: [PATCH] implement randomPath --- k2/csrc/CMakeLists.txt | 19 ++++++ k2/csrc/fsa.h | 18 +++++- k2/csrc/fsa_equivalent.cc | 80 +++++++++++++++++++++++++ k2/csrc/{tests.h => fsa_equivalent.h} | 19 +++--- k2/csrc/fsa_equivalent_test.cc | 86 +++++++++++++++++++++++++++ k2/csrc/util.h | 24 ++++++++ 6 files changed, 236 insertions(+), 10 deletions(-) create mode 100644 k2/csrc/fsa_equivalent.cc rename k2/csrc/{tests.h => fsa_equivalent.h} (53%) create mode 100644 k2/csrc/fsa_equivalent_test.cc create mode 100644 k2/csrc/util.h diff --git a/k2/csrc/CMakeLists.txt b/k2/csrc/CMakeLists.txt index 273c4773d..5138c0637 100644 --- a/k2/csrc/CMakeLists.txt +++ b/k2/csrc/CMakeLists.txt @@ -2,6 +2,11 @@ add_library(properties properties.cc) target_include_directories(properties PUBLIC ${CMAKE_SOURCE_DIR}) target_compile_features(properties PUBLIC cxx_std_11) +add_library(fsa_equivalent fsa_equivalent.cc) +target_include_directories(fsa_equivalent PUBLIC ${CMAKE_SOURCE_DIR}) +target_compile_features(fsa_equivalent PUBLIC cxx_std_11) +target_link_libraries(fsa_equivalent PUBLIC properties) + add_library(fsa_util fsa_util.cc) target_include_directories(fsa_util PUBLIC ${CMAKE_SOURCE_DIR}) target_compile_features(fsa_util PUBLIC cxx_std_11) @@ -24,6 +29,20 @@ add_test(NAME Test.properties_test $ ) +add_executable(fsa_equivalent_test fsa_equivalent_test.cc) + +target_link_libraries(fsa_equivalent_test + PRIVATE + fsa_equivalent + gtest + gtest_main +) + +add_test(NAME Test.fsa_equivalent_test + COMMAND + $ +) + add_executable(fsa_util_test fsa_util_test.cc) target_link_libraries(fsa_util_test diff --git a/k2/csrc/fsa.h b/k2/csrc/fsa.h index 98b2e08d3..624552197 100644 --- a/k2/csrc/fsa.h +++ b/k2/csrc/fsa.h @@ -8,9 +8,10 @@ #define K2_CSRC_FSA_H_ #include -#include #include +#include "k2/csrc/util.h" + namespace k2 { using Label = int32_t; @@ -37,6 +38,21 @@ struct Arc { /* Note: the costs are not stored here but outside the Fst object, in some kind of array indexed by arc-index. */ + + bool operator==(const Arc &other) const { + return src_state == other.src_state && dest_state == other.dest_state && + label == other.label; + } +}; + +struct ArcHash { + std::size_t operator()(const Arc &arc) const noexcept { + std::size_t result = 0; + hash_combine(&result, arc.src_state); + hash_combine(&result, arc.dest_state); + hash_combine(&result, arc.label); + return result; + } }; struct ArcLabelCompare { diff --git a/k2/csrc/fsa_equivalent.cc b/k2/csrc/fsa_equivalent.cc new file mode 100644 index 000000000..7956b6986 --- /dev/null +++ b/k2/csrc/fsa_equivalent.cc @@ -0,0 +1,80 @@ +// k2/csrc/fsa_equivalent.cc + +// Copyright (c) 2020 Haowen Qiu + +// See ../../LICENSE for clarification regarding multiple authors + +#include "k2/csrc/fsa_equivalent.h" + +#include +#include +#include +#include +#include + +#include "k2/csrc/fsa.h" +#include "k2/csrc/properties.h" + +namespace k2 { + +bool RandomPath(const Fsa &a, Fsa *b, + std::vector *state_map /*=nullptr*/) { + if (IsEmpty(a) || b == nullptr) return false; + // we cannot do `connect` on `a` here to get a connected fsa + // as `state_map` will map to states in the connected fsa + // instead of in `a` if we do that. + if (!IsConnected(a)) return false; + + int32_t num_states = a.NumStates(); + std::vector state_map_b2a; + std::vector state_map_a2b(num_states, -1); + // `visited_arcs[i]` stores `arcs` leaving from state `i` in `b` + std::vector> visited_arcs; + + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution distribution(0); + + int32_t num_visited_arcs = 0; + int32_t num_visited_state = 0; + int32_t state = 0; + int32_t final_state = num_states - 1; + while (true) { + if (state_map_a2b[state] == -1) { + state_map_a2b[state] = num_visited_state; + state_map_b2a.push_back(state); + visited_arcs.push_back(std::unordered_set()); + ++num_visited_state; + } + if (state == final_state) break; + int32_t begin = a.arc_indexes[state]; + int32_t end = a.arc_indexes[state + 1]; + // since `a` is valid, so every states contains at least one arc. + int32_t arc_index = begin + (distribution(generator) % (end - begin)); + int32_t state_id_in_b = state_map_a2b[state]; + const auto &curr_arc = a.arcs[arc_index]; + if (visited_arcs[state_id_in_b].insert(curr_arc).second) ++num_visited_arcs; + state = curr_arc.dest_state; + } + + // create `b` + b->arc_indexes.resize(num_visited_state); + b->arcs.resize(num_visited_arcs); + int32_t n = 0; + for (int32_t i = 0; i < num_visited_state; ++i) { + b->arc_indexes[i] = n; + for (const auto &arc : visited_arcs[i]) { + auto &b_arc = b->arcs[n]; + b_arc.src_state = i; + b_arc.dest_state = state_map_a2b[arc.dest_state]; + b_arc.label = arc.label; + ++n; + } + } + if (state_map != nullptr) { + state_map->swap(state_map_b2a); + } + return true; +} + +} // namespace k2 diff --git a/k2/csrc/tests.h b/k2/csrc/fsa_equivalent.h similarity index 53% rename from k2/csrc/tests.h rename to k2/csrc/fsa_equivalent.h index 8128715c6..21f9e82a5 100644 --- a/k2/csrc/tests.h +++ b/k2/csrc/fsa_equivalent.h @@ -1,19 +1,16 @@ -// k2/csrc/tests.h +// k2/csrc/fsa_equivalent.h // Copyright (c) 2020 Daniel Povey // See ../../LICENSE for clarification regarding multiple authors -// TODO(fangjun): rename this file -// since tests.h is not a good name - #include #include #include "k2/csrc/fsa.h" -#ifndef K2_CSRC_TESTS_H_ -#define K2_CSRC_TESTS_H_ +#ifndef K2_CSRC_FSA_EQUIVALENT_H_ +#define K2_CSRC_FSA_EQUIVALENT_H_ namespace k2 { @@ -23,9 +20,13 @@ namespace k2 { */ bool IsEquivalent(const Fsa &a, const Fsa &b); -/* Gets a random path from an Fsa `a` */ -void RandomPath(const Fsa &a, Fsa *b, std::vector *state_map = NULL); +/* + Gets a random path from an Fsa `a`, returns true if we get one path + successfully. +*/ +bool RandomPath(const Fsa &a, Fsa *b, + std::vector *state_map = nullptr); } // namespace k2 -#endif // K2_CSRC_TESTS_H_ +#endif // K2_CSRC_FSA_EQUIVALENT_H_ diff --git a/k2/csrc/fsa_equivalent_test.cc b/k2/csrc/fsa_equivalent_test.cc new file mode 100644 index 000000000..d740a7604 --- /dev/null +++ b/k2/csrc/fsa_equivalent_test.cc @@ -0,0 +1,86 @@ +// k2/csrc/fsa_equivalent_test.cc + +// Copyright (c) 2020 Haowen Qiu + +// See ../../LICENSE for clarification regarding multiple authors + +#include "k2/csrc/fsa_equivalent.h" + +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "k2/csrc/fsa.h" + +namespace k2 { + +TEST(Properties, RandomPathFail) { + { + Fsa fsa; + Fsa path; + bool status = RandomPath(fsa, &path); + EXPECT_FALSE(status); + } + // TODO(haowen): add tests for non-connected fsa +} + +TEST(Properties, RandomPathSuccess) { + { + Fsa fsa; + std::vector arcs = { + {0, 1, 1}, {0, 2, 2}, {1, 2, 3}, {2, 3, 4}, + {2, 4, 5}, {3, 4, 7}, {4, 5, 9}, + }; + std::vector arc_indexes = {0, 2, 3, 5, 6, 7}; + fsa.arc_indexes = std::move(arc_indexes); + fsa.arcs = std::move(arcs); + Fsa path; + + { + bool status = RandomPath(fsa, &path); + EXPECT_TRUE(status); + } + + { + std::vector state_map; + for (auto i = 0; i != 20; ++i) { + bool status = RandomPath(fsa, &path, &state_map); + EXPECT_TRUE(status); + EXPECT_GT(state_map.size(), 0); + } + } + } + + // test with linear structure fsa to check the resulted path + { + Fsa fsa; + std::vector arcs = { + {0, 1, 1}, + {1, 2, 3}, + {2, 3, 4}, + }; + std::vector arc_indexes = {0, 1, 2, 3}; + fsa.arc_indexes = std::move(arc_indexes); + fsa.arcs = std::move(arcs); + Fsa path; + + std::vector state_map; + bool status = RandomPath(fsa, &path, &state_map); + EXPECT_TRUE(status); + ASSERT_EQ(fsa.arcs.size(), path.arcs.size()); + EXPECT_TRUE(fsa.arcs == path.arcs); + ASSERT_EQ(fsa.arc_indexes.size(), path.arc_indexes.size()); + EXPECT_TRUE(fsa.arc_indexes == path.arc_indexes); + EXPECT_THAT(state_map, ::testing::ElementsAre(0, 1, 2, 3)); + } + + // TODO(haowen): add tests for non-connected fsa + std::vector arcs = { + {0, 1, 1}, {0, 2, 2}, {1, 2, 3}, {2, 3, 4}, {2, 4, 5}, + {3, 1, 6}, {3, 4, 7}, {4, 3, 8}, {4, 5, 9}, + }; + std::vector arc_indexes = {0, 2, 3, 5, 7, 9}; +} +} // namespace k2 diff --git a/k2/csrc/util.h b/k2/csrc/util.h new file mode 100644 index 000000000..a8bd57554 --- /dev/null +++ b/k2/csrc/util.h @@ -0,0 +1,24 @@ +// k2/csrc/util.h + +// Copyright (c) 2020 Haowen Qiu + +// See ../../LICENSE for clarification regarding multiple authors + +#ifndef K2_CSRC_UTIL_H_ +#define K2_CSRC_UTIL_H_ + +#include + +#include "k2/csrc/fsa.h" + +namespace k2 { + +// boost::hash_combine +template +inline void hash_combine(std::size_t *seed, const T &v) { + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); +} + +} // namespace k2 +#endif // K2_CSRC_UTIL_H_