Skip to content

Commit

Permalink
Construct FSA from un-sorted list of arcs (fix #31). (#32)
Browse files Browse the repository at this point in the history
* Construct FSA from un-sorted list of arcs (fix #31).

* fix style check.

* replace std::unqiue_ptr with passed pointer argument.
  • Loading branch information
csukuangfj authored May 7, 2020
1 parent 52d0b36 commit c47681e
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 41 deletions.
101 changes: 92 additions & 9 deletions k2/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

namespace {

static constexpr int8_t kNotVisited = 0; // a node that has not been visited
static constexpr int8_t kVisiting = 1; // a node that is under visiting
static constexpr int8_t kVisited = 2; // a node that has been visited
constexpr int8_t kNotVisited = 0; // a node that has not been visited
constexpr int8_t kVisiting = 1; // a node that is under visiting
constexpr int8_t kVisited = 2; // a node that has been visited
// depth first search state
struct DfsState {
int32_t state; // state number of the visiting node
Expand Down Expand Up @@ -364,8 +364,8 @@ bool Intersect(const Fsa &a, const Fsa &b, Fsa *c,
auto b_arc_range =
std::equal_range(b_arc_iter_begin, b_arc_iter_end, curr_a_arc,
[](const Arc &left, const Arc &right) {
return left.label < right.label;
});
return left.label < right.label;
});
for (ArcIterator it_b = b_arc_range.first; it_b != b_arc_range.second;
++it_b) {
Arc curr_b_arc = *it_b;
Expand Down Expand Up @@ -424,12 +424,13 @@ void ArcSort(const Fsa &a, Fsa *b,
std::transform(arc_begin_iter + begin, arc_begin_iter + end,
index_begin_iter + begin,
std::back_inserter(arc_range_to_be_sorted),
[](const Arc & arc, int32_t index)
->ArcWithIndex { return std::make_pair(arc, index); });
[](const Arc &arc, int32_t index) -> ArcWithIndex {
return std::make_pair(arc, index);
});
std::sort(arc_range_to_be_sorted.begin(), arc_range_to_be_sorted.end(),
[](const ArcWithIndex &left, const ArcWithIndex &right) {
return left.first < right.first; // sort on arc
});
return left.first < right.first; // sort on arc
});
// copy index mappings back to `indexes`
std::transform(arc_range_to_be_sorted.begin(), arc_range_to_be_sorted.end(),
index_begin_iter + begin,
Expand Down Expand Up @@ -540,4 +541,86 @@ bool TopSort(const Fsa &a, Fsa *b,
return true;
}

void CreateFsa(const std::vector<Arc> &arcs, Fsa *fsa) {
CHECK_NOTNULL(fsa);
fsa->arc_indexes.clear();
fsa->arcs.clear();

if (arcs.empty()) return;

std::vector<std::vector<Arc>> vec;
for (const auto &arc : arcs) {
auto src_state = arc.src_state;
auto dest_state = arc.dest_state;
auto new_size = std::max(src_state, dest_state);
if (new_size >= vec.size()) vec.resize(new_size + 1);
vec[src_state].push_back(arc);
}

std::stack<DfsState> stack;
std::vector<char> state_status(vec.size(), kNotVisited);
std::vector<int32_t> order;

auto num_states = static_cast<int32_t>(vec.size());
for (auto i = 0; i != num_states; ++i) {
if (state_status[i] == kVisited) continue;
stack.push({i, 0, static_cast<int32_t>(vec[i].size())});
state_status[i] = kVisiting;
while (!stack.empty()) {
auto &current_state = stack.top();
auto state = current_state.state;

if (current_state.arc_begin == current_state.arc_end) {
state_status[state] = kVisited;
order.push_back(state);
stack.pop();
continue;
}

const auto &arc = vec[state][current_state.arc_begin];
auto next_state = arc.dest_state;
auto status = state_status[next_state];
switch (status) {
case kNotVisited:
state_status[next_state] = kVisiting;
stack.push(
{next_state, 0, static_cast<int32_t>(vec[next_state].size())});
++current_state.arc_begin;
break;
case kVisiting:
LOG(FATAL) << "there is a cycle: " << state << " -> " << next_state;
break;
case kVisited:
++current_state.arc_begin;
break;
default:
LOG(FATAL) << "Unreachable code is executed!";
break;
}
}
}

CHECK_EQ(num_states, static_cast<int32_t>(order.size()));

std::reverse(order.begin(), order.end());

fsa->arc_indexes.resize(num_states + 1);
fsa->arcs.reserve(arcs.size());

std::vector<int32_t> old_to_new(num_states);
for (auto i = 0; i != num_states; ++i) old_to_new[order[i]] = i;

for (auto i = 0; i != num_states; ++i) {
auto old_state = order[i];
fsa->arc_indexes[i] = static_cast<int32_t>(fsa->arcs.size());
for (auto arc : vec[old_state]) {
arc.src_state = i;
arc.dest_state = old_to_new[arc.dest_state];
fsa->arcs.push_back(arc);
}
}

fsa->arc_indexes.back() = static_cast<int32_t>(fsa->arcs.size());
}

} // namespace k2
13 changes: 13 additions & 0 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,19 @@ bool TopSort(const Fsa& a, Fsa* b, std::vector<int32_t>* state_map = nullptr);
void Determinize(const Fsa &a, Fsa *b,
std::vector<std::vector<int32_t>> *state_map);

/* Create an acyclic FSA from a list of arcs.
Arcs do not need to be pre-sorted by src_state.
If there is a cycle, it aborts.
The start state MUST be 0. The final state will be automatically determined
by topological sort.
@param [in] arcs A list of arcs.
@param [out] fsa Output fsa.
*/
void CreateFsa(const std::vector<Arc> &arcs, Fsa *fsa);

} // namespace k2

#endif // K2_CSRC_FSA_ALGO_H_
85 changes: 54 additions & 31 deletions k2/csrc/fsa_algo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "k2/csrc/fsa_renderer.h"
#include "k2/csrc/fsa_util.h"

namespace k2 {
Expand Down Expand Up @@ -173,7 +174,8 @@ TEST(FsaAlgo, Connect) {
EXPECT_THAT(b.arc_indexes, ::testing::ElementsAre(0, 2, 4, 5, 5));

std::vector<Arc> target_arcs = {
{0, 2, 1}, {0, 1, 2}, {1, 3, 3}, {1, 2, 1}, {2, 3, 6}, };
{0, 2, 1}, {0, 1, 2}, {1, 3, 3}, {1, 2, 1}, {2, 3, 6},
};
for (auto i = 0; i != target_arcs.size(); ++i)
EXPECT_EQ(b.arcs[i], target_arcs[i]);

Expand Down Expand Up @@ -300,18 +302,17 @@ TEST(FsaAlgo, Intersect) {
}

{
std::vector<Arc> arcs_a = {{0, 1, 1},
{1, 2, 0},
{1, 3, 1},
{1, 4, 2},
{2, 2, 1},
{2, 3, 1},
{2, 3, 2},
{3, 3, 0},
{3, 4, 1}};
std::vector<Arc> arcs_a = {{0, 1, 1}, {1, 2, 0}, {1, 3, 1},
{1, 4, 2}, {2, 2, 1}, {2, 3, 1},
{2, 3, 2}, {3, 3, 0}, {3, 4, 1}};
Fsa a(std::move(arcs_a), 4);

std::vector<Arc> arcs_b = {{0, 1, 1}, {1, 3, 1}, {1, 2, 2}, {2, 3, 1}, };
std::vector<Arc> arcs_b = {
{0, 1, 1},
{1, 3, 1},
{1, 2, 2},
{2, 3, 1},
};
Fsa b(std::move(arcs_b), 3);

Fsa c;
Expand All @@ -320,16 +321,10 @@ TEST(FsaAlgo, Intersect) {
bool status = Intersect(a, b, &c, &arc_map_a, &arc_map_b);
EXPECT_TRUE(status);

std::vector<Arc> arcs_c = {{0, 1, 1},
{1, 2, 0},
{1, 3, 1},
{1, 4, 2},
{2, 5, 1},
{2, 6, 1},
{2, 6, 2},
{3, 3, 0},
{6, 6, 0},
{6, 7, 1}, };
std::vector<Arc> arcs_c = {
{0, 1, 1}, {1, 2, 0}, {1, 3, 1}, {1, 4, 2}, {2, 5, 1},
{2, 6, 1}, {2, 6, 2}, {3, 3, 0}, {6, 6, 0}, {6, 7, 1},
};
std::vector<int32_t> arc_indexes_c = {0, 1, 4, 6, 8, 8, 8, 10, 10};

ASSERT_EQ(c.arc_indexes.size(), arc_indexes_c.size());
Expand Down Expand Up @@ -386,7 +381,8 @@ TEST(FsaAlgo, ArcSort) {

{
std::vector<Arc> arcs = {
{0, 1, 2}, {0, 4, 0}, {0, 2, 0}, {1, 2, 1}, {1, 3, 0}, {2, 1, 0}, };
{0, 1, 2}, {0, 4, 0}, {0, 2, 0}, {1, 2, 1}, {1, 3, 0}, {2, 1, 0},
};
Fsa fsa(std::move(arcs), 4);
Fsa arc_sorted;
std::vector<int32_t> arc_map;
Expand All @@ -395,7 +391,8 @@ TEST(FsaAlgo, ArcSort) {
::testing::ElementsAre(0, 3, 5, 6, 6, 6));
ASSERT_EQ(arc_sorted.arcs.size(), fsa.arcs.size());
std::vector<Arc> target_arcs = {
{0, 2, 0}, {0, 4, 0}, {0, 1, 2}, {1, 3, 0}, {1, 2, 1}, {2, 1, 0}, };
{0, 2, 0}, {0, 4, 0}, {0, 1, 2}, {1, 3, 0}, {1, 2, 1}, {2, 1, 0},
};
for (std::size_t i = 0; i != target_arcs.size(); ++i)
EXPECT_EQ(arc_sorted.arcs[i], target_arcs[i]);

Expand Down Expand Up @@ -494,19 +491,45 @@ TEST(FsaAlgo, TopSort) {

ASSERT_EQ(arc_indexes.size(), 8u);
EXPECT_THAT(arc_indexes, ::testing::ElementsAre(0, 2, 3, 4, 5, 7, 8, 8));
std::vector<Arc> expected_arcs = {{0, 1, 40},
{0, 3, 20},
{1, 2, 50},
{2, 3, 8},
{3, 4, 30},
{4, 6, 60},
{4, 5, 10},
{5, 6, 2}, };
std::vector<Arc> expected_arcs = {
{0, 1, 40}, {0, 3, 20}, {1, 2, 50}, {2, 3, 8},
{3, 4, 30}, {4, 6, 60}, {4, 5, 10}, {5, 6, 2},
};

for (auto i = 0; i != 8; ++i) {
EXPECT_EQ(arcs[i], expected_arcs[i]);
}
}
}

TEST(FsaAlgo, CreateFsa) {
{
// clang-format off
std::vector<Arc> arcs = {
{0, 3, 3},
{0, 2, 2},
{2, 3, 3},
{2, 4, 4},
{3, 1, 1},
{1, 4, 4},
{1, 8, 8},
{4, 8, 8},
{8, 6, 6},
{8, 7, 7},
{6, 7, 7},
{7, 5, 5},
};
// clang-format on
Fsa a;
CreateFsa(arcs, &a);

auto num_states = a.NumStates();

Fsa b;
Swap(&a, &b);
EXPECT_EQ(a.NumStates(), 0);
EXPECT_EQ(b.NumStates(), num_states);
}
}

} // namespace k2
7 changes: 7 additions & 0 deletions k2/csrc/fsa_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ void GetEnteringArcs(const Fsa &fsa, std::vector<int32_t> *arc_index,
}
}

void Swap(Fsa *a, Fsa *b) {
CHECK_NOTNULL(a);
CHECK_NOTNULL(b);
std::swap(a->arc_indexes, b->arc_indexes);
std::swap(a->arcs, b->arcs);
}

std::unique_ptr<Fsa> StringToFsa(const std::string &s) {
static constexpr const char *kDelim = " \t";

Expand Down
4 changes: 3 additions & 1 deletion k2/csrc/fsa_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ void ConvertIndexes1(const std::vector<int32_t> &arc_map, int64_t *indexes_out);
total number of int32's in arc_map, will contain
arc-indexes in the output FSA
*/
void GetArcIndexes2(const std::vector<std::vector<int32_t> > &arc_map,
void GetArcIndexes2(const std::vector<std::vector<int32_t>> &arc_map,
std::vector<int64_t> *indexes1,
std::vector<int64_t> *indexes2);

void Swap(Fsa *a, Fsa *b);

/** Build a FSA from a string.
The input string is a transition table with the following
Expand Down

0 comments on commit c47681e

Please sign in to comment.