fix determinize issues #42

Merged (2 commits, May 16, 2020)
k2/csrc/determinize.cc (1 addition, 0 deletions)

@@ -115,6 +115,7 @@ void TraceBack(std::unordered_set<MaxTracebackState *> *cur_states,
// `deriv_out` is just a list of arc indexes in the input FSA
// that this output arc depends on (it's their sum).
(*deriv_out)[i] = state->arc_id;
+ state = state->prev_state.get();
}
double prev_forward_prob = state->forward_prob;
*weight_out = cur_forward_prob - prev_forward_prob;
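
Note on the hunk above: the added line is the actual fix. Without advancing `state`, the loop would record the same arc_id at every step instead of walking back through the chain of traceback states, and the prev_forward_prob read after the loop would come from the wrong state. A minimal standalone sketch of that walk (illustrative only; apart from the names visible in the diff, the struct and function names here are made up for the example):

    #include <cstdint>
    #include <memory>
    #include <vector>

    // Simplified stand-in for a traceback state: one input-arc index per step,
    // a forward probability, and a pointer to the previous state in the chain.
    struct TracebackStateSketch {
      int32_t arc_id;
      double forward_prob;
      std::shared_ptr<TracebackStateSketch> prev_state;
    };

    // Walk back `num_steps` steps from `state`, recording one input-arc index
    // per step, and return the state where the walk ends (its forward_prob is
    // then subtracted from the current one to obtain the output-arc weight).
    TracebackStateSketch *WalkBack(TracebackStateSketch *state,
                                   int32_t num_steps,
                                   std::vector<int32_t> *deriv_out) {
      deriv_out->resize(num_steps);
      for (int32_t i = 0; i < num_steps; ++i) {
        (*deriv_out)[i] = state->arc_id;
        state = state->prev_state.get();  // the line this patch adds
      }
      return state;
    }
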
k2/csrc/determinize.h (24 additions, 31 deletions)

@@ -20,6 +20,7 @@
#include <vector>

#include "k2/csrc/fsa.h"
+ #include "k2/csrc/fsa_algo.h"
#include "k2/csrc/util.h"
#include "k2/csrc/weights.h"

@@ -371,20 +372,20 @@ class DetState;

template <class TracebackState>
struct DetStateCompare {
- bool operator()(const std::unique_ptr<DetState<TracebackState>> &a,
- const std::unique_ptr<DetState<TracebackState>> &b) {
+ bool operator()(const std::shared_ptr<DetState<TracebackState>> &a,
+ const std::shared_ptr<DetState<TracebackState>> &b) {
return a->forward_backward_prob < b->forward_backward_prob;
}
};
// Priority queue template arguments:
- // item queued = unique_ptr<DetState> (using pointer equality as comparison)
- // container type = vector<unique_ptr<DetState> >
+ // item queued = shared_ptr<DetState> (using pointer equality as comparison)
+ // container type = vector<shared_ptr<DetState> >
// less-than operator = DetStateCompare (which compares the
// forward_backward_prob).
template <class TracebackState>
using DetStatePriorityQueue =
- std::priority_queue<std::unique_ptr<DetState<TracebackState>>,
- std::vector<std::unique_ptr<DetState<TracebackState>>>,
+ std::priority_queue<std::shared_ptr<DetState<TracebackState>>,
+ std::vector<std::shared_ptr<DetState<TracebackState>>>,
Comment on lines +387 to +388

Collaborator: What was the problem with unique_ptr here? I made it unique_ptr as it's more lightweight than shared_ptr (and this is only ever owned in one place, I believed).

Collaborator (author): priority_queue.top() returns a const reference, so we cannot call std::move() on its returned value to construct any unique_ptr or shared_ptr.

Collaborator: I see; shared_ptr is fine.
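
A minimal standalone illustration (not k2 code) of the point made above: priority_queue::top() returns a const reference, so a unique_ptr element cannot be moved out of the queue, while a shared_ptr element can simply be copied and then popped.

    #include <memory>
    #include <queue>
    #include <vector>

    struct Item {
      double prob;
    };

    struct ItemCompare {
      bool operator()(const std::shared_ptr<Item> &a,
                      const std::shared_ptr<Item> &b) const {
        return a->prob < b->prob;
      }
    };

    int main() {
      std::priority_queue<std::shared_ptr<Item>,
                          std::vector<std::shared_ptr<Item>>, ItemCompare>
          queue;
      queue.push(std::make_shared<Item>(Item{1.5}));

      // top() returns a const reference. With unique_ptr elements,
      // std::move(queue.top()) would still bind to `const unique_ptr &`, and
      // unique_ptr is not copyable, so there is no way to take ownership of
      // the element. With shared_ptr we just copy it and then pop.
      std::shared_ptr<Item> item = queue.top();
      queue.pop();
      return item->prob > 0 ? 0 : 1;
    }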

DetStateCompare<TracebackState>>;

template <class TracebackState>
@@ -448,7 +449,7 @@ class DetState {
const std::shared_ptr<TracebackState> &src,
int32_t incoming_arc_index, int32_t arc_weight) {
auto ret = elements.insert({state_id, nullptr});
- if (!ret.second) { // No such state existed in `elements`
+ if (ret.second) { // No such state existed in `elements`
ret.first->second = std::make_shared<TracebackState>(
state_id, src, incoming_arc_index, arc_weight);
} else { // A state with this state_id existed in `elements`.
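
The flipped condition above follows from std::unordered_map::insert semantics: the .second member of the returned pair is true exactly when a new element was inserted, i.e. when no entry with that key existed yet. A minimal standalone reminder (not k2 code):

    #include <cassert>
    #include <cstdint>
    #include <memory>
    #include <unordered_map>

    int main() {
      std::unordered_map<int32_t, std::shared_ptr<int>> elements;

      auto ret = elements.insert({42, nullptr});
      assert(ret.second);  // true: key 42 was absent, a new entry was created
      ret.first->second = std::make_shared<int>(7);

      auto ret2 = elements.insert({42, nullptr});
      assert(!ret2.second);              // false: key 42 already existed
      assert(*ret2.first->second == 7);  // the existing value is untouched
      return 0;
    }
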
@@ -527,20 +528,6 @@ class DetState {
DetStateMap<TracebackState> *state_map,
DetStatePriorityQueue<TracebackState> *queue);

- // Computes the forward-backward weight of this DetState. This is
- // related to the best cost of any path through the output FSA
- // that included this determinized state. I say "related to"
- // because while it should be exact in the Max case, in the
- // LogSum case the relationship is a bit more complicated;
- // maybe just best to say that this is a weight that we use
- // for pruning.
- // @param [in] backward_state_weight Array, indexed by
- // state in input WFSA, of the weight from this state
- // to the end. (Of the best path or the sum of paths,
- // depending how it was computed; this will of
- // course affect the pruning).
- void ComputeFbWeight(const float *backward_state_weights);
-
/*
Normalizes this DetState by reducing seq_len to the extent possible
and outputting the weight and derivative info corresponding to this
@@ -597,12 +584,13 @@ int32_t DetState<TracebackState>::ProcessArcs(
iter->second = new DetState<TracebackState>(seq_len + 1);
}
DetState<TracebackState> *det_state = iter->second;
- det_state->AcceptIncomingArc(state_id, state_ptr, a, weight);
+ det_state->AcceptIncomingArc(arc.dest_state, state_ptr, a, weight);
}
}
CHECK(!label_to_state.empty() ||
- elements[0]->state_id == fsa.FinalState()); // I'm assuming the input
- // FSA is connected.
+ elements.begin()->second->state_id ==
+ fsa.FinalState()); // I'm assuming the input
+ // FSA is connected.

// The following loop normalizes successor det-states, outputs the arcs
// that lead to them, and adds them to the queue if necessary.
@@ -743,8 +731,7 @@ class DetStateMap {
std::pair<uint64_t, uint64_t> *vec) {
assert(d.normalized);

- uint64_t a = d.state_id + 17489 * d.seq_len,
- b = d.state_id * 103979 + d.seq_len;
+ uint64_t a = 17489 * d.seq_len, b = d.seq_len;

// We choose an arbitrary DetStateElement (the first one in the list) to
// read the symbol sequence from; the symbol sequence will be the same no
@@ -758,6 +745,9 @@
b = symbol + 102983 * b;
elem = elem->prev_state;
}
+ // This is `base_state`: the state from which we
+ // start (and accept the specified symbol sequence).
+ a = elem->state_id + 14051 * a;
vec->first = a;
vec->second = b;
}
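
The two hunks above change what the state signature identifies: d.state_id is no longer mixed into the seed, and base_state (the traceback state reached after walking back through the whole symbol sequence) is mixed in at the end instead, so two normalized DetStates with the same symbol sequence but different base states hash differently. A simplified standalone sketch of the pair-of-64-bit-hashes idea (illustrative only; the multiplier used when folding each symbol into `a` is not visible in this diff, so the 31981 below is a stand-in, while 17489, 102983 and 14051 are taken from the code shown):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Illustrative only: hash a (base_state, symbol sequence) pair into two
    // 64-bit values, mirroring the scheme used by DetStateMap.
    std::pair<uint64_t, uint64_t> Signature(int32_t base_state,
                                            const std::vector<int32_t> &symbols) {
      uint64_t a = 17489 * static_cast<uint64_t>(symbols.size()),
               b = symbols.size();
      for (int32_t symbol : symbols) {
        a = symbol + 31981 * a;   // stand-in multiplier
        b = symbol + 102983 * b;  // constant from the diff
      }
      // Mix in base_state last, as the patch does: DetStates that accept the
      // same symbol sequence from different base states must hash differently.
      a = base_state + 14051 * a;
      return {a, b};
    }
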
@@ -766,8 +756,7 @@
const Fsa &fsa, std::pair<uint64_t, uint64_t> *vec) {
assert(d.normalized);

- uint64_t a = d.state_id + 17489 * d.seq_len,
- b = d.state_id * 103979 + d.seq_len;
+ uint64_t a = 17489 * d.seq_len, b = d.seq_len;

// We choose an arbitrary DetStateElement (the first one in the list) to
// read the symbol sequence from; the symbol sequence will be the same no
@@ -781,6 +770,9 @@
b = symbol + 102983 * b;
elem = elem->prev_elements[0].prev_state;
}
+ // This is `base_state`: the state from which we
+ // start (and accept the specified symbol sequence).
+ a = elem->state_id + 14051 * a;
vec->first = a;
vec->second = b;
}
@@ -799,14 +791,13 @@ float DeterminizePrunedTpl(
std::vector<std::vector<typename TracebackState::DerivType>>
*arc_derivs_out) {
CHECK_GT(beam, 0);
- CHECK(IsDeterministic(wfsa_in.fsa));
CHECK(!IsEmpty(wfsa_in.fsa));

DetStatePriorityQueue<TracebackState> queue;
DetStateMap<TracebackState> map;
using DS = DetState<TracebackState>;

- std::shared_ptr<DS> start_state = std::make_shared<DS>();
+ std::shared_ptr<DS> start_state(new DS());

std::vector<Arc> arcs_out;
arc_weights_out->clear();
@@ -822,13 +813,15 @@

double total_prob = wfsa_in.BackwardStateWeights()[0],
prune_cutoff = total_prob - beam;
+ queue.push(std::move(start_state));
while (num_steps < max_step && !queue.empty()) {
- std::shared_ptr<DS> state(queue.top().get());
+ std::shared_ptr<DS> state(queue.top());
queue.pop();
num_steps +=
state->ProcessArcs(wfsa_in, prune_cutoff, &arcs_out, arc_weights_out,
arc_derivs_out, &map, &queue);
}
+ CreateFsa(arcs_out, fsa_out);
if (!queue.empty()) { // We stopped early due to max_step
return total_prob - queue.top()->forward_backward_prob;
} else {
k2/csrc/fsa_algo_test.cc (47 additions, 0 deletions)

@@ -502,6 +502,53 @@ TEST(FsaAlgo, TopSort) {
}
}

class DeterminizeTest : public ::testing::Test {
protected:
DeterminizeTest() {
std::vector<Arc> arcs = {{0, 4, 1}, {0, 1, 1}, {1, 2, 2}, {1, 3, 3},
{2, 7, 1}, {3, 7, 1}, {4, 6, 1}, {4, 6, 1},
{4, 5, 1}, {4, 8, -1}, {5, 8, -1}, {6, 8, -1},
{7, 8, -1}};
fsa_ = new Fsa(std::move(arcs), 8);
num_states_ = fsa_->NumStates();

auto num_arcs = fsa_->arcs.size();
arc_weights_ = new float[num_arcs];
std::vector<float> weights = {1, 1, 2, 3, 4, 5, 2, 3, 3, 2, 4, 3, 5};
std::copy_n(weights.begin(), num_arcs, arc_weights_);

max_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kMaxWeight);
log_wfsa_ = new WfsaWithFbWeights(*fsa_, arc_weights_, kLogSumWeight);
}

~DeterminizeTest() {
delete fsa_;
delete[] arc_weights_;
delete max_wfsa_;
delete log_wfsa_;
}

WfsaWithFbWeights *max_wfsa_;
WfsaWithFbWeights *log_wfsa_;
Fsa *fsa_;
int32_t num_states_;
float *arc_weights_;
};

TEST_F(DeterminizeTest, DeterminizePrunedMax) {
Fsa b;
std::vector<float> b_arc_weights;
std::vector<std::vector<int32_t>> arc_derivs;
DeterminizePrunedMax(*max_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
Collaborator (author): Just leaving the EXPECT statements empty for now; I will add them after fixing the weights issues. Please ignore this for now.

}

TEST_F(DeterminizeTest, DeterminizePrunedLogSum) {
Fsa b;
std::vector<float> b_arc_weights;
std::vector<std::vector<std::pair<int32_t, float>>> arc_derivs;
DeterminizePrunedLogSum(*log_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);
}
Collaborator: I'd like to have this code automatically:
(1) check that the result is deterministic,
(2) check that the result is equivalent to the original (in the appropriate semiring),
and, if possible, check that the arc_derivs make sense somehow, although that is more complex and can be left for now.

Collaborator: I can merge this PR when you fix the other issue, though; you can work on this later.

Collaborator (author): I will add all the EXPECT statements after fixing the weights issues (as I said before, the weights_out are not correct at the moment).
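
Once the weights are fixed, the checks requested above might take roughly the following shape. This is only a sketch: IsDeterministic is a real function in this codebase, but IsEquivalentInSemiring is a hypothetical placeholder for whatever equivalence test (e.g. comparing weights of random paths) is eventually used, and the test name is made up.

    // Sketch only, assuming the weights issues are fixed.
    TEST_F(DeterminizeTest, DeterminizePrunedMaxProperties) {
      Fsa b;
      std::vector<float> b_arc_weights;
      std::vector<std::vector<int32_t>> arc_derivs;
      DeterminizePrunedMax(*max_wfsa_, 10, 100, &b, &b_arc_weights, &arc_derivs);

      EXPECT_TRUE(IsDeterministic(b));
      // Hypothetical equivalence check against the input WFSA:
      // EXPECT_TRUE(IsEquivalentInSemiring(*fsa_, arc_weights_, b,
      //                                    b_arc_weights.data(), kMaxWeight));
    }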


TEST(FsaAlgo, CreateFsa) {
{
// clang-format off