Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement RmEpsilonPruneMax #40

Merged
merged 3 commits into from
May 26, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions k2/csrc/determinize.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ int32_t DetState<TracebackState>::ProcessArcs(
derivs_per_arc->push_back(std::move(deriv_info));
if (is_new_state)
queue->push(std::unique_ptr<DetState<TracebackState>>(det_state));
else
delete det_state;
} else {
delete det_state;
}
Expand Down
111 changes: 111 additions & 0 deletions k2/csrc/fsa_algo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
#include "k2/csrc/fsa_algo.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <queue>
#include <stack>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "glog/logging.h"
Expand Down Expand Up @@ -275,6 +277,115 @@ bool Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map /*=nullptr*/) {
return is_acyclic;
}

void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
std::vector<std::vector<int32_t>> *arc_derivs) {
CHECK_EQ(a.weight_type, kMaxWeight);
CHECK_GT(beam, 0);
CHECK_NOTNULL(b);
CHECK_NOTNULL(arc_derivs);
b->arc_indexes.clear();
b->arcs.clear();
arc_derivs->clear();

qindazhu marked this conversation as resolved.
Show resolved Hide resolved
const auto &fsa = a.fsa;
if (IsEmpty(fsa)) return;
int32_t num_states_a = fsa.NumStates();
int32_t final_state = fsa.FinalState();
const auto &arcs_a = fsa.arcs;
const float *arc_weights_a = a.arc_weights;

// identify all states that should be kept
std::vector<char> non_eps_in(num_states_a, 0);
non_eps_in[0] = 1;
for (const auto &arc : arcs_a) {
// We suppose the input fsa `a` is top-sorted, but only check this in DEBUG
// time.
DCHECK_GE(arc.dest_state, arc.src_state);
if (arc.label != kEpsilon) non_eps_in[arc.dest_state] = 1;
}

// remap state id
std::vector<int32_t> state_map_a2b(num_states_a, -1);
int32_t num_states_b = 0;
for (int32_t i = 0; i != num_states_a; ++i) {
if (non_eps_in[i] == 1) state_map_a2b[i] = num_states_b++;
}
b->arc_indexes.reserve(num_states_b + 1);
int32_t arc_num_b = 0;

const double *forward_state_weights = a.ForwardStateWeights();
const double *backward_state_weights = a.BackwardStateWeights();
const double best_weight = forward_state_weights[final_state] - beam;
for (int32_t i = 0; i != num_states_a; ++i) {
if (non_eps_in[i] != 1) continue;
b->arc_indexes.push_back(arc_num_b);
int32_t curr_state_b = state_map_a2b[i];
// as the input FSA is top-sorted, we use a heap here so we can process
// states when they already have the best cost they are going to get
std::priority_queue<int32_t, std::vector<int32_t>, std::greater<int32_t>> q;
// stores states that have been queued
std::unordered_set<int32_t> qstates;
// state -> local_forward_state_weights of this state
std::unordered_map<int32_t, double> local_forward_weights;
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
// state -> (src_state, arc_index) entering this state which contributes to
// `local_forward_weights` of this state.
std::unordered_map<int32_t, std::pair<int32_t, int32_t>>
local_backward_arcs;
local_forward_weights.emplace(i, forward_state_weights[i]);
// `-1` means we have traced back to current state `i`
local_backward_arcs.emplace(i, std::make_pair(i, -1));
q.push(i);
qstates.insert(i);
while (!q.empty()) {
int32_t state = q.top();
q.pop();
int32_t arc_end = fsa.arc_indexes[state + 1];
for (int32_t arc_index = fsa.arc_indexes[state]; arc_index != arc_end;
++arc_index) {
int32_t next_state = arcs_a[arc_index].dest_state;
int32_t label = arcs_a[arc_index].label;
double next_weight =
local_forward_weights[state] + arc_weights_a[arc_index];
if (next_weight + backward_state_weights[next_state] >= best_weight) {
if (label == kEpsilon) {
auto result =
local_forward_weights.emplace(next_state, next_weight);
if (result.second) {
local_backward_arcs[next_state] =
std::make_pair(state, arc_index);
} else {
if (next_weight > result.first->second) {
result.first->second = next_weight;
local_backward_arcs[next_state] =
std::make_pair(state, arc_index);
}
}
if (qstates.find(next_state) == qstates.end()) {
q.push(next_state);
qstates.insert(next_state);
}
} else {
b->arcs.emplace_back(curr_state_b, state_map_a2b[next_state],
label);
std::vector<int32_t> curr_arc_deriv;
std::pair<int32_t, int32_t> curr_backward_arc{state, arc_index};
auto *backward_arc = &curr_backward_arc;
while (backward_arc->second != -1) {
curr_arc_deriv.push_back(backward_arc->second);
backward_arc = &(local_backward_arcs[backward_arc->first]);
}
std::reverse(curr_arc_deriv.begin(), curr_arc_deriv.end());
arc_derivs->emplace_back(std::move(curr_arc_deriv));
++arc_num_b;
}
}
}
}
}
// duplicate of final state
b->arc_indexes.push_back(b->arc_indexes.back());
}

bool Intersect(const Fsa &a, const Fsa &b, Fsa *c,
std::vector<int32_t> *arc_map_a /*= nullptr*/,
std::vector<int32_t> *arc_map_b /*= nullptr*/) {
Expand Down
70 changes: 6 additions & 64 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,72 +90,14 @@ bool Connect(const Fsa &a, Fsa *b, std::vector<int32_t> *arc_map = nullptr);
keep paths that are within `beam` of the best path.
Just make this very large if you don't want pruning.
@param [out] b The output FSA; will be epsilon-free, and the states
will be in the same order that they were in in `a`.
@param [out] arc_map If non-NULL: for each arc in `b`, a list of
the arc-indexes in `a`, in order, that contributed
to that arc (e.g. its cost would be a sum of their costs).

Notes on algorithm (please rework all this when it's complete, i.e. just
make sure the code is clear and remove this).

The states in the output FSA will correspond to the subset of states in the
input FSA which are within `beam` of the best path and which have at least
one non-epsilon arc entering them, plus the start state. (Note: this
automatically includes the final state, assuming `a` has at least one
successful path; if it does not, the output will be empty).

If we ever need the associated state map from calling code, we'll add an
extra output argument to this function.

The basic algorithm is to (1) identify the kept states, (2) from each kept
input-state ki, we'll iterate over all states that are reachable via zero
or more epsilons from this state and process the non-epsilon outgoing arcs
from those states, which will become the arcs in the output. We'll also
store a back-pointer array that will allow us to figure out the best path
back to ki, in order to produce the output `arc_map`. Assume we have
arrays

local_forward_weights (float) and local_backpointers (int) indexed by
state-id, and that the local_forward_weights are initialized with
-infinity's each time we process a new ki. (we have to figure out how to do
this efficiently).


Processing input-state ki:
local_forward_state_weights[ki] = forward_state_weights[ki] // from
WfsaWithFbWeights.
// Caution:
we should probably use
// double
here; these kinds of algorithms
// are
extremely sensitive to roundoff for
// very
long FSAs. local_backpointers[ki] = -1 // will terminate a sequence..
queue.push_back(ki)
while (!queue.empty()) {
ji = queue.front() // we have to be a bit careful about order here,
to make sure
// we always process states when they already
have the
// best cost they are going to get. If
// FSA was top-sorted at the start, which we
assume, we could perhaps
// process them in numerical order, e.g. using a
heap. queue.pop_front() for each arc leaving state ji: next_weight =
local_forward_state_weights[ji] + arc_weights[this_arc_index] if next_weight
+ backward_state_weights[arc_dest_state] < best_path_weight - beam: if arc
label is epsilon: if next_weight < local_forward_state_weight[next_state]:
local_forward_state_weight[next_state] = next_weight
local_backpointers[next_state] = ji
else:
add an arc to the output FSA, and create the appropriate
arc_map entry by following backpointers (hopefully you can
figure out the details). Note: the output FSA's weights can
be computed later on, by calling code, using the info in arc_map.
will be in the same order that they were in `a`.
@param [out] arc_derivs Indexed by arc in `b`, this is the sequence of
qindazhu marked this conversation as resolved.
Show resolved Hide resolved
arcs in `a` that this arc in `b` corresponds to; the
weight of the arc in b will equal the sum of those input
arcs' weights
*/
void RmEpsilonsPrunedMax(const WfsaWithFbWeights &a, float beam, Fsa *b,
std::vector<std::vector<int32_t>> *arc_map);
std::vector<std::vector<int32_t>> *arc_derivs);
csukuangfj marked this conversation as resolved.
Show resolved Hide resolved

/*
Version of RmEpsilonsPrunedMax that doesn't support pruning; see its
Expand Down
48 changes: 48 additions & 0 deletions k2/csrc/fsa_algo_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,54 @@ TEST(FsaAlgo, Connect) {
}
}

class RmEpsilonTest : public ::testing::Test {
protected:
RmEpsilonTest() {
std::vector<Arc> arcs = {
{0, 4, 1}, {0, 1, 1}, {1, 2, 0}, {1, 3, 0}, {1, 4, 0},
{2, 7, 0}, {3, 7, 0}, {4, 6, 1}, {4, 6, 0}, {4, 8, 1},
{4, 9, -1}, {5, 9, -1}, {6, 9, -1}, {7, 9, -1}, {8, 9, -1},
};
fsa_ = new Fsa(std::move(arcs), 9);
num_states_ = fsa_->NumStates();

auto num_arcs = fsa_->arcs.size();
arc_weights_ = new float[num_arcs];
std::vector<float> weights = {1, 1, 2, 3, 2, 4, 5, 2, 3, 3, 2, 4, 3, 5, 6};
std::copy_n(weights.begin(), num_arcs, arc_weights_);

max_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kMaxWeight);
log_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kLogSumWeight);
}

~RmEpsilonTest() {
delete fsa_;
delete[] arc_weights_;
delete max_wfsa_;
delete log_wfsa_;
}

WfsaWithFbWeights *max_wfsa_;
WfsaWithFbWeights *log_wfsa_;
Fsa *fsa_;
int32_t num_states_;
float *arc_weights_;
};

TEST_F(RmEpsilonTest, RmEpsilonsPrunedMax) {
Fsa b;
std::vector<std::vector<int32_t>> arc_derivs_b;
RmEpsilonsPrunedMax(*max_wfsa_, 8, &b, &arc_derivs_b);

EXPECT_TRUE(IsEpsilonFree(b));
ASSERT_EQ(b.arcs.size(), 11);
ASSERT_EQ(b.arc_indexes.size(), 7);
ASSERT_EQ(arc_derivs_b.size(), 11);

// TODO(haowen): check the equivalence after implementing RandEquivalent for
// WFSA
}

TEST(FsaAlgo, Intersect) {
// empty fsa
{
Expand Down