Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement randomPath #16

Merged
merged 1 commit into from
Apr 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ add_library(properties properties.cc)
target_include_directories(properties PUBLIC ${CMAKE_SOURCE_DIR})
target_compile_features(properties PUBLIC cxx_std_11)

add_library(fsa_equivalent fsa_equivalent.cc)
target_include_directories(fsa_equivalent PUBLIC ${CMAKE_SOURCE_DIR})
target_compile_features(fsa_equivalent PUBLIC cxx_std_11)
target_link_libraries(fsa_equivalent PUBLIC properties)

add_library(fsa_util fsa_util.cc)
target_include_directories(fsa_util PUBLIC ${CMAKE_SOURCE_DIR})
target_compile_features(fsa_util PUBLIC cxx_std_11)
Expand All @@ -24,6 +29,20 @@ add_test(NAME Test.properties_test
$<TARGET_FILE:properties_test>
)

add_executable(fsa_equivalent_test fsa_equivalent_test.cc)

target_link_libraries(fsa_equivalent_test
PRIVATE
fsa_equivalent
gtest
gtest_main
)

add_test(NAME Test.fsa_equivalent_test
COMMAND
$<TARGET_FILE:fsa_equivalent_test>
)

add_executable(fsa_util_test fsa_util_test.cc)

target_link_libraries(fsa_util_test
Expand Down
18 changes: 17 additions & 1 deletion k2/csrc/fsa.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
#define K2_CSRC_FSA_H_

#include <cstdint>
#include <utility>
#include <vector>

#include "k2/csrc/util.h"
qindazhu marked this conversation as resolved.
Show resolved Hide resolved

namespace k2 {

using Label = int32_t;
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -37,6 +38,21 @@ struct Arc {

/* Note: the costs are not stored here but outside the Fst object, in some
kind of array indexed by arc-index. */

bool operator==(const Arc &other) const {
return src_state == other.src_state && dest_state == other.dest_state &&
label == other.label;
}
};

struct ArcHash {
std::size_t operator()(const Arc &arc) const noexcept {
std::size_t result = 0;
hash_combine(&result, arc.src_state);
hash_combine(&result, arc.dest_state);
hash_combine(&result, arc.label);
return result;
}
};

struct ArcLabelCompare {
Expand Down
80 changes: 80 additions & 0 deletions k2/csrc/fsa_equivalent.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// k2/csrc/fsa_equivalent.cc

// Copyright (c) 2020 Haowen Qiu

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/fsa_equivalent.h"

#include <algorithm>
#include <random>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "k2/csrc/fsa.h"
#include "k2/csrc/properties.h"

namespace k2 {

bool RandomPath(const Fsa &a, Fsa *b,
std::vector<int32_t> *state_map /*=nullptr*/) {
if (IsEmpty(a) || b == nullptr) return false;
// we cannot do `connect` on `a` here to get a connected fsa
// as `state_map` will map to states in the connected fsa
// instead of in `a` if we do that.
if (!IsConnected(a)) return false;

int32_t num_states = a.NumStates();
std::vector<int32_t> state_map_b2a;
std::vector<int32_t> state_map_a2b(num_states, -1);
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
// `visited_arcs[i]` stores `arcs` leaving from state `i` in `b`
std::vector<std::unordered_set<Arc, ArcHash>> visited_arcs;
qindazhu marked this conversation as resolved.
Show resolved Hide resolved

std::random_device rd;
std::mt19937 generator(rd());
std::uniform_int_distribution<int32_t> distribution(0);

int32_t num_visited_arcs = 0;
int32_t num_visited_state = 0;
int32_t state = 0;
int32_t final_state = num_states - 1;
while (true) {
if (state_map_a2b[state] == -1) {
state_map_a2b[state] = num_visited_state;
state_map_b2a.push_back(state);
visited_arcs.push_back(std::unordered_set<Arc, ArcHash>());
++num_visited_state;
}
if (state == final_state) break;
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
int32_t begin = a.arc_indexes[state];
int32_t end = a.arc_indexes[state + 1];
// since `a` is valid, so every states contains at least one arc.
int32_t arc_index = begin + (distribution(generator) % (end - begin));
int32_t state_id_in_b = state_map_a2b[state];
const auto &curr_arc = a.arcs[arc_index];
if (visited_arcs[state_id_in_b].insert(curr_arc).second) ++num_visited_arcs;
state = curr_arc.dest_state;
}

// create `b`
b->arc_indexes.resize(num_visited_state);
b->arcs.resize(num_visited_arcs);
int32_t n = 0;
for (int32_t i = 0; i < num_visited_state; ++i) {
b->arc_indexes[i] = n;
for (const auto &arc : visited_arcs[i]) {
auto &b_arc = b->arcs[n];
b_arc.src_state = i;
b_arc.dest_state = state_map_a2b[arc.dest_state];
b_arc.label = arc.label;
++n;
}
}
if (state_map != nullptr) {
state_map->swap(state_map_b2a);
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
}
return true;
}

} // namespace k2
19 changes: 10 additions & 9 deletions k2/csrc/tests.h → k2/csrc/fsa_equivalent.h
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
// k2/csrc/tests.h
// k2/csrc/fsa_equivalent.h

// Copyright (c) 2020 Daniel Povey

// See ../../LICENSE for clarification regarding multiple authors

// TODO(fangjun): rename this file
// since tests.h is not a good name

#include <cstdint>
#include <vector>

#include "k2/csrc/fsa.h"

#ifndef K2_CSRC_TESTS_H_
#define K2_CSRC_TESTS_H_
#ifndef K2_CSRC_FSA_EQUIVALENT_H_
#define K2_CSRC_FSA_EQUIVALENT_H_

namespace k2 {

Expand All @@ -23,9 +20,13 @@ namespace k2 {
*/
bool IsEquivalent(const Fsa &a, const Fsa &b);

/* Gets a random path from an Fsa `a` */
void RandomPath(const Fsa &a, Fsa *b, std::vector<int32_t> *state_map = NULL);
/*
Gets a random path from an Fsa `a`, returns true if we get one path
successfully.
*/
bool RandomPath(const Fsa &a, Fsa *b,
std::vector<int32_t> *state_map = nullptr);

} // namespace k2

#endif // K2_CSRC_TESTS_H_
#endif // K2_CSRC_FSA_EQUIVALENT_H_
86 changes: 86 additions & 0 deletions k2/csrc/fsa_equivalent_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// k2/csrc/fsa_equivalent_test.cc

// Copyright (c) 2020 Haowen Qiu

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/fsa_equivalent.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "k2/csrc/fsa.h"

namespace k2 {

TEST(Properties, RandomPathFail) {
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
{
Fsa fsa;
Fsa path;
bool status = RandomPath(fsa, &path);
EXPECT_FALSE(status);
}
// TODO(haowen): add tests for non-connected fsa
}

TEST(Properties, RandomPathSuccess) {
{
Fsa fsa;
std::vector<Arc> arcs = {
{0, 1, 1}, {0, 2, 2}, {1, 2, 3}, {2, 3, 4},
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
{2, 4, 5}, {3, 4, 7}, {4, 5, 9},
};
std::vector<int32_t> arc_indexes = {0, 2, 3, 5, 6, 7};
fsa.arc_indexes = std::move(arc_indexes);
fsa.arcs = std::move(arcs);
Fsa path;

{
bool status = RandomPath(fsa, &path);
EXPECT_TRUE(status);
}

{
std::vector<int32_t> state_map;
for (auto i = 0; i != 20; ++i) {
bool status = RandomPath(fsa, &path, &state_map);
EXPECT_TRUE(status);
EXPECT_GT(state_map.size(), 0);
}
}
}

// test with linear structure fsa to check the resulted path
{
Fsa fsa;
std::vector<Arc> arcs = {
{0, 1, 1},
{1, 2, 3},
{2, 3, 4},
};
std::vector<int32_t> arc_indexes = {0, 1, 2, 3};
fsa.arc_indexes = std::move(arc_indexes);
fsa.arcs = std::move(arcs);
Fsa path;

std::vector<int32_t> state_map;
bool status = RandomPath(fsa, &path, &state_map);
EXPECT_TRUE(status);
ASSERT_EQ(fsa.arcs.size(), path.arcs.size());
EXPECT_TRUE(fsa.arcs == path.arcs);
ASSERT_EQ(fsa.arc_indexes.size(), path.arc_indexes.size());
EXPECT_TRUE(fsa.arc_indexes == path.arc_indexes);
EXPECT_THAT(state_map, ::testing::ElementsAre(0, 1, 2, 3));
}

// TODO(haowen): add tests for non-connected fsa
std::vector<Arc> arcs = {
{0, 1, 1}, {0, 2, 2}, {1, 2, 3}, {2, 3, 4}, {2, 4, 5},
{3, 1, 6}, {3, 4, 7}, {4, 3, 8}, {4, 5, 9},
};
std::vector<int32_t> arc_indexes = {0, 2, 3, 5, 7, 9};
}
} // namespace k2
24 changes: 24 additions & 0 deletions k2/csrc/util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// k2/csrc/util.h

// Copyright (c) 2020 Haowen Qiu

// See ../../LICENSE for clarification regarding multiple authors

#ifndef K2_CSRC_UTIL_H_
#define K2_CSRC_UTIL_H_

#include <functional>

#include "k2/csrc/fsa.h"

namespace k2 {

// boost::hash_combine
template <class T>
inline void hash_combine(std::size_t *seed, const T &v) {
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2);
}

} // namespace k2
#endif // K2_CSRC_UTIL_H_