Skip to content

Commit

Permalink
Implement GetArcIndex1/2 and return arc_map in RandFsa (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
qindazhu authored Jul 7, 2020
1 parent c5d7406 commit 53e3317
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 107 deletions.
1 change: 0 additions & 1 deletion k2/csrc/connect_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "k2/csrc/fsa.h"
#include "k2/csrc/fsa_algo.h"
#include "k2/csrc/fsa_util.h"
#include "k2/csrc/properties.h"

Expand Down
22 changes: 0 additions & 22 deletions k2/csrc/fsa_algo.h

This file was deleted.

52 changes: 28 additions & 24 deletions k2/csrc/fsa_equivalent.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,17 +100,17 @@ static bool Intersect(const k2::Fsa &a, const k2::Fsa &b, k2::FsaCreator *c,
*/
static bool RandomPath(const k2::Fsa &fsa_in, bool no_eps_arc,
k2::FsaCreator *path,
std::vector<int32_t> *state_map = nullptr) {
std::vector<int32_t> *arc_map = nullptr) {
CHECK_NOTNULL(path);
k2::RandPath rand_path(fsa_in, no_eps_arc);
k2::Array2Size<int32_t> fsa_size;
rand_path.GetSizes(&fsa_size);

path->Init(fsa_size);
auto &path_fsa = path->GetFsa();
if (state_map != nullptr) state_map->resize(fsa_size.size2);
if (arc_map != nullptr) arc_map->resize(fsa_size.size2);
bool status = rand_path.GetOutput(
&path_fsa, state_map == nullptr ? nullptr : state_map->data());
&path_fsa, arc_map == nullptr ? nullptr : arc_map->data());
return status;
}

Expand Down Expand Up @@ -268,10 +268,10 @@ bool IsRandEquivalent(const Fsa &a, const float *a_weights, const Fsa &b,
::Intersect(valid_b, valid_path, &b_compose_path_storage, &arc_map_b_path);
std::vector<float> a_compose_weights(arc_map_a_path.size());
std::vector<float> b_compose_weights(arc_map_b_path.size());
GetArcWeights(valid_a_weights.data(), arc_map_a_path,
a_compose_weights.data());
GetArcWeights(valid_b_weights.data(), arc_map_b_path,
b_compose_weights.data());
GetArcWeights(valid_a_weights.data(), arc_map_a_path.data(),
arc_map_a_path.size(), a_compose_weights.data());
GetArcWeights(valid_b_weights.data(), arc_map_b_path.data(),
arc_map_b_path.size(), b_compose_weights.data());
// TODO(haowen): we may need to implement a version of `ShortestDistance`
// for non-top-sorted FSAs, but we prefer to decide this later as there's no
// such scenarios (input FSAs are not top-sorted) currently. If we finally
Expand Down Expand Up @@ -369,10 +369,10 @@ bool IsRandEquivalentAfterRmEpsPrunedLogSum(
::Intersect(valid_b, valid_path, &b_compose_path_storage, &arc_map_b_path);
std::vector<float> a_compose_weights(arc_map_a_path.size());
std::vector<float> b_compose_weights(arc_map_b_path.size());
GetArcWeights(valid_a_weights.data(), arc_map_a_path,
a_compose_weights.data());
GetArcWeights(valid_b_weights.data(), arc_map_b_path,
b_compose_weights.data());
GetArcWeights(valid_a_weights.data(), arc_map_a_path.data(),
arc_map_a_path.size(), a_compose_weights.data());
GetArcWeights(valid_b_weights.data(), arc_map_b_path.data(),
arc_map_b_path.size(), b_compose_weights.data());
// TODO(haowen): we may need to implement a version of `ShortestDistance`
// for non-top-sorted FSAs, but we prefer to decide this later as there's no
// such scenarios (input FSAs are not top-sorted) currently.
Expand Down Expand Up @@ -404,15 +404,16 @@ void RandPath::GetSizes(Array2Size<int32_t> *fsa_size) {

arc_indexes_.clear();
arcs_.clear();
state_map_.clear();
arc_map_.clear();

status_ = !IsEmpty(fsa_in_) && IsConnected(fsa_in_);
if (!status_) return;

int32_t num_states = fsa_in_.NumStates();
std::vector<int32_t> state_map_in_to_out(num_states, -1);
// `visited_arcs[i]` stores `arcs` leaving from state `i` in the output `path`
std::vector<std::unordered_set<Arc, ArcHash>> visited_arcs;
// `visited_arcs[i]` maps `arcs` leaving from state `i` in the output `path`
// to arc-index in the input FSA.
std::vector<std::unordered_map<Arc, int32_t, ArcHash>> visited_arcs;

std::random_device rd;
std::mt19937 generator(rd());
Expand All @@ -425,19 +426,19 @@ void RandPath::GetSizes(Array2Size<int32_t> *fsa_size) {
while (true) {
if (state_map_in_to_out[state] == -1) {
state_map_in_to_out[state] = num_visited_state;
state_map_.push_back(state);
visited_arcs.emplace_back(std::unordered_set<Arc, ArcHash>());
visited_arcs.emplace_back(std::unordered_map<Arc, int32_t, ArcHash>());
++num_visited_state;
}
if (state == final_state) break;
const Arc *curr_arc = nullptr;
int32_t arc_index_in = -1;
int32_t tries = 0;
do {
int32_t begin = fsa_in_.indexes[state];
int32_t end = fsa_in_.indexes[state + 1];
// since `fsa_in_` is valid, so every state contains at least one arc.
int32_t arc_index = begin + (distribution(generator) % (end - begin));
curr_arc = &fsa_in_.data[arc_index];
arc_index_in = begin + (distribution(generator) % (end - begin));
curr_arc = &fsa_in_.data[arc_index_in];
++tries;
} while (no_epsilon_arc_ && curr_arc->label == kEpsilon &&
tries < eps_arc_tries_);
Expand All @@ -448,22 +449,26 @@ void RandPath::GetSizes(Array2Size<int32_t> *fsa_size) {
}
int32_t state_id_out = state_map_in_to_out[state];
if (visited_arcs[state_id_out]
.insert({state, curr_arc->dest_state, curr_arc->label})
.insert({{state, curr_arc->dest_state, curr_arc->label},
arc_index_in - fsa_in_.indexes[0]})
.second)
++num_visited_arcs;
state = curr_arc->dest_state;
}

arc_indexes_.resize(num_visited_state);
arcs_.resize(num_visited_arcs);
arc_map_.resize(num_visited_arcs);
int32_t n = 0;
for (int32_t i = 0; i != num_visited_state; ++i) {
arc_indexes_[i] = n;
for (const auto &arc : visited_arcs[i]) {
for (const auto &arc_with_index : visited_arcs[i]) {
const auto &arc = arc_with_index.first;
auto &output_arc = arcs_[n];
output_arc.src_state = i;
output_arc.dest_state = state_map_in_to_out[arc.dest_state];
output_arc.label = arc.label;
arc_map_[n] = arc_with_index.second;
++n;
}
}
Expand All @@ -473,7 +478,7 @@ void RandPath::GetSizes(Array2Size<int32_t> *fsa_size) {
fsa_size->size2 = num_visited_arcs;
}

bool RandPath::GetOutput(Fsa *fsa_out, int32_t *state_map /*= nullptr*/) {
bool RandPath::GetOutput(Fsa *fsa_out, int32_t *arc_map /*= nullptr*/) {
CHECK_NOTNULL(fsa_out);
if (!status_) return false;

Expand All @@ -484,9 +489,8 @@ bool RandPath::GetOutput(Fsa *fsa_out, int32_t *state_map /*= nullptr*/) {
CHECK_EQ(arcs_.size(), fsa_out->size2);
std::copy(arcs_.begin(), arcs_.end(), fsa_out->data);

// output state map
if (state_map != nullptr)
std::copy(state_map_.begin(), state_map_.end(), state_map);
// output arc map
if (arc_map != nullptr) std::copy(arc_map_.begin(), arc_map_.end(), arc_map);

return true;
}
Expand Down
16 changes: 8 additions & 8 deletions k2/csrc/fsa_equivalent.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,20 +141,20 @@ class RandPath {

/*
Finish the operation and output the path to `path` and
state mapping information to `state_map` (if provided).
arc mapping information to `arc_map` (if provided).
@param [out] path Output path.
Must be initialized; search for 'initialized
definition' in class Array2 in array.h for meaning.
@param [out] state_map If non-NULL, Maps from state indexes in the output
path to state indexes in the input fsa.
If non-NULL, at entry it must be allocated with
size num-states of `fsa_out`,
e.g. `fsa_out->size1`.
@param [out] arc_map If non-NULL, will output a map from the arc-index
in `fsa_out` to the corresponding arc-index in
`fsa_in`.
If non-NULL, at entry it must be allocated with
size num-arcs of `fsa_out`, e.g. `fsa_out->size2`.
@return true if it succeeds; will be false if it fails,
`fsa_out` will be empty when it fails.
*/
bool GetOutput(Fsa *fsa_out, int32_t *state_map = nullptr);
bool GetOutput(Fsa *fsa_out, int32_t *arc_map = nullptr);

private:
const Fsa &fsa_in_;
Expand All @@ -164,7 +164,7 @@ class RandPath {
bool status_;
std::vector<int32_t> arc_indexes_; // arc_index of fsa_out
std::vector<Arc> arcs_; // arcs of fsa_out
std::vector<int32_t> state_map_;
std::vector<int32_t> arc_map_;
};

} // namespace k2
Expand Down
36 changes: 18 additions & 18 deletions k2/csrc/fsa_equivalent_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,10 @@ TEST(FsaEquivalent, RandomPathFail) {

FsaCreator fsa_creator_out(fsa_size);
auto &path = fsa_creator_out.GetFsa();
std::vector<int32_t> state_map(fsa_size.size1);
bool status = rand_path.GetOutput(&path, state_map.data());
std::vector<int32_t> arc_map(fsa_size.size2);
bool status = rand_path.GetOutput(&path, arc_map.data());
EXPECT_FALSE(status);
EXPECT_TRUE(state_map.empty());
EXPECT_TRUE(arc_map.empty());
}

{
Expand All @@ -215,10 +215,10 @@ TEST(FsaEquivalent, RandomPathFail) {

FsaCreator fsa_creator_out(fsa_size);
auto &path = fsa_creator_out.GetFsa();
std::vector<int32_t> state_map(fsa_size.size1);
bool status = rand_path.GetOutput(&path, state_map.data());
std::vector<int32_t> arc_map(fsa_size.size2);
bool status = rand_path.GetOutput(&path, arc_map.data());
EXPECT_FALSE(status);
EXPECT_TRUE(state_map.empty());
EXPECT_TRUE(arc_map.empty());
}
}

Expand Down Expand Up @@ -257,10 +257,10 @@ TEST(FsaEquivalent, RandomPathSuccess) {

FsaCreator fsa_creator_out(fsa_size);
auto &path = fsa_creator_out.GetFsa();
std::vector<int32_t> state_map(fsa_size.size1);
bool status = rand_path.GetOutput(&path, state_map.data());
std::vector<int32_t> arc_map(fsa_size.size2);
bool status = rand_path.GetOutput(&path, arc_map.data());
EXPECT_TRUE(status);
EXPECT_GT(state_map.size(), 0);
EXPECT_GT(arc_map.size(), 0);
}
}
}
Expand All @@ -280,8 +280,8 @@ TEST(FsaEquivalent, RandomPathSuccess) {

FsaCreator fsa_creator_out(fsa_size);
auto &path = fsa_creator_out.GetFsa();
std::vector<int32_t> state_map(fsa_size.size1);
bool status = rand_path.GetOutput(&path, state_map.data());
std::vector<int32_t> arc_map(fsa_size.size2);
bool status = rand_path.GetOutput(&path, arc_map.data());

EXPECT_TRUE(status);
std::vector<int32_t> arc_indexes(path.indexes,
Expand All @@ -292,7 +292,7 @@ TEST(FsaEquivalent, RandomPathSuccess) {
ASSERT_EQ(fsa.size2, path.size2);
for (std::size_t i = 0; i != path.size2; ++i)
EXPECT_EQ(arcs[i], src_arcs[i]);
EXPECT_THAT(state_map, ::testing::ElementsAre(0, 1, 2, 3));
EXPECT_THAT(arc_map, ::testing::ElementsAre(0, 1, 2));
}
}

Expand All @@ -313,10 +313,10 @@ TEST(FsaEquivalent, RandomPathWithoutEpsilonArc) {

FsaCreator fsa_creator_out(fsa_size);
auto &path = fsa_creator_out.GetFsa();
std::vector<int32_t> state_map(fsa_size.size1);
bool status = rand_path.GetOutput(&path, state_map.data());
std::vector<int32_t> arc_map(fsa_size.size2);
bool status = rand_path.GetOutput(&path, arc_map.data());
EXPECT_TRUE(status);
EXPECT_GT(state_map.size(), 0);
EXPECT_GT(arc_map.size(), 0);
for (const auto &arc : path) {
EXPECT_NE(arc.label, kEpsilon);
}
Expand All @@ -341,10 +341,10 @@ TEST(FsaEquivalent, RandomPathWithoutEpsilonArc) {

FsaCreator fsa_creator_out(fsa_size);
auto &path = fsa_creator_out.GetFsa();
std::vector<int32_t> state_map(fsa_size.size1);
bool status = rand_path.GetOutput(&path, state_map.data());
std::vector<int32_t> arc_map(fsa_size.size2);
bool status = rand_path.GetOutput(&path, arc_map.data());
EXPECT_FALSE(status);
EXPECT_TRUE(state_map.empty());
EXPECT_TRUE(arc_map.empty());
}
}
} // namespace k2
40 changes: 32 additions & 8 deletions k2/csrc/fsa_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,26 @@ void GetEnteringArcs(const Fsa &fsa, Array2<int32_t *, int32_t> *arc_indexes) {
}

void GetArcWeights(const float *arc_weights_in,
const std::vector<std::vector<int32_t>> &arc_map,
const Array2<int32_t *, int32_t> &arc_map,
float *arc_weights_out) {
CHECK_NOTNULL(arc_weights_in);
CHECK_NOTNULL(arc_weights_out);
for (const auto &arcs : arc_map) {
for (int32_t i = 0; i != arc_map.size1; ++i) {
float sum_weights = 0.0f;
for (auto arc : arcs) sum_weights += arc_weights_in[arc];
for (int32_t j = arc_map.indexes[i]; j != arc_map.indexes[i + 1]; ++j) {
int32_t arc_index_in = arc_map.data[j];
sum_weights += arc_weights_in[arc_index_in];
}
*arc_weights_out++ = sum_weights;
}
}

void GetArcWeights(const float *arc_weights_in,
const std::vector<int32_t> &arc_map,
float *arc_weights_out) {
void GetArcWeights(const float *arc_weights_in, const int32_t *arc_map,
int32_t num_arcs, float *arc_weights_out) {
CHECK_NOTNULL(arc_weights_in);
CHECK_NOTNULL(arc_weights_out);
for (const auto &arc : arc_map) {
*arc_weights_out++ = arc_weights_in[arc];
for (int32_t i = 0; i != num_arcs; ++i) {
*arc_weights_out++ = arc_weights_in[arc_map[i]];
}
}

Expand Down Expand Up @@ -220,6 +222,28 @@ void ReorderArcs(const std::vector<Arc> &arcs, Fsa *fsa,
if (arc_map != nullptr) arc_map->swap(arc_map_out);
}

void ConvertIndexes1(const int32_t *arc_map, int32_t num_arcs,
int64_t *indexes_out) {
CHECK_NOTNULL(arc_map);
CHECK_GE(num_arcs, 0);
CHECK_NOTNULL(indexes_out);
std::copy(arc_map, arc_map + num_arcs, indexes_out);
}

void GetArcIndexes2(const Array2<int32_t *, int32_t> &arc_map,
int64_t *indexes1, int64_t *indexes2) {
CHECK_NOTNULL(indexes1);
CHECK_NOTNULL(indexes2);
std::copy(arc_map.data + arc_map.indexes[0],
arc_map.data + arc_map.indexes[arc_map.size1], indexes1);
int32_t num_arcs = 0;
for (int32_t i = 0; i != arc_map.size1; ++i) {
int32_t curr_arc_mappings = arc_map.indexes[i + 1] - arc_map.indexes[i];
std::fill_n(indexes2 + num_arcs, curr_arc_mappings, i);
num_arcs += curr_arc_mappings;
}
}

void StringToFsa::GetSizes(Array2Size<int32_t> *fsa_size) {
CHECK_NOTNULL(fsa_size);
fsa_size->size1 = fsa_size->size2 = 0;
Expand Down
Loading

0 comments on commit 53e3317

Please sign in to comment.