-
Notifications
You must be signed in to change notification settings - Fork 217
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Aux labels plus notes on Python interface #29
Changes from all commits
b9ffae3
a9e1768
9f76458
09915bb
10e6712
735f83d
80339a5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
// k2/csrc/determinize.cc | ||
|
||
// Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey [email protected], Haowen Qiu [email protected]) | ||
|
||
// See ../../LICENSE for clarification regarding multiple authors | ||
|
||
#include "k2/csrc/fsa_algo.h" | ||
|
||
#include <utility> | ||
#include <vector> | ||
|
||
namespace k2 { | ||
|
||
|
||
struct DetStateElement { | ||
// Element of the doubly linked list whose start/end are | ||
// members 'head' and 'tail' of DetState. | ||
// We can trace back the `parent` links, which will take | ||
// us backward along a path in the original FSA. | ||
DetStateElement *parent = nullptr; | ||
int32_t arc_index; // Index of most recent arc in path to the dest-state. | ||
// This data-structure represents a path through the FSA, | ||
// with this arc being the most recent arc on that path. | ||
int32_t symbol; // Symbol on the arc numbered `arc_index` of the input FSA | ||
// (copied here for convenience). | ||
|
||
double weight; // Weight from reference state to this state, along | ||
// the path taken by following the 'parent' links | ||
// (the path would have `seq_len` arcs in it). | ||
// Note: by "this state" we mean the destination-state of | ||
// the arc at `arc_index`. | ||
|
||
// `prev` and `next` form the doubly linked list of DetStateElement | ||
DetStateElement *prev = nullptr; | ||
DetStateElement *next = nullptr; | ||
|
||
// This comparator function compares the weights, but is careful in case of | ||
// ties to ensure deterministic behavior. | ||
bool operator < (const DetStateElement &other) const { | ||
if (weight < other.weight) return true; | ||
else if (weight > other.weight) return false; | ||
// TODO. | ||
} | ||
|
||
}; | ||
|
||
|
||
|
||
|
||
|
||
/* | ||
Conceptually a determinized state in weighted FSA determinization would normally | ||
be a weighted subset of states in the input FSA, with the weights normalized | ||
somehow (e.g. subtracting the sum of the weights). | ||
|
||
Two determinized states are equal if the states and weights are the same. To | ||
ensure differentiability, our assumption is that in general no two arcs in the | ||
input FSA have identical weights. We argue that two determinized states can | ||
always be represented as a base-state and a symbol sequence. Imagine that we | ||
follow arcs with that symbol sequence from the base-state, and then in case we | ||
reach the same states in the different ways we always select the best path | ||
from the base-state. That process gives us a set of states and weights. We | ||
argue that this representation is unique. (If not, it won't matter actually; | ||
it will just give us an output that's less minimal than it could be). | ||
|
||
|
||
*/ | ||
struct DetState { | ||
// `base_state` is a state in the input FSA. | ||
int32_t base_state; | ||
// seq_len is the length of symbol sequence that we follow from state `base_state`. | ||
// The sequence of symbols can be found by tracing back one of the DetStateElements | ||
// in the doubly linked list (it doesn't matter which you pick, the result will be the | ||
// same. | ||
int32_t seq_len; | ||
|
||
bool normalized { false }; | ||
|
||
DetState *parent; // Maybe not needed! | ||
|
||
DetStateElement *head; | ||
DetStateElement *tail; | ||
|
||
double forward_backward_weight; | ||
|
||
/* | ||
Normalizes this DetState and sets forward_backward_weight. | ||
|
||
By 'normalize' what we mean is the following: | ||
|
||
- Remove duplicates. | ||
|
||
If the DLL of DetStateElements contains duplicate elements (i.e. | ||
elements whose paths end in the same state) it removes whichever has the | ||
smallest weight. (Remember, a determinized state is, conceptually, a | ||
weighted subset of elements; we are implementing determinization in a | ||
tropical-like semiring where we take the best weight. | ||
|
||
In case of ties on the weights, we carefully re-examine the paths to | ||
make sure that the tie was not due to numerical roundoffi; and if it | ||
was still a tie, we disambiguate using a lexical order on state | ||
sequences. The reason it's important to have deterministic behavior in | ||
case of ties on weights, is that a failure here could lead to | ||
situations where we didn't advance the base state where we could, | ||
leading the number of determinized states to be larger than it could | ||
be. | ||
|
||
- Advance the base state if possible. Each DetState can be represented | ||
as a base state and a sequence of symbols from that base state, but | ||
if some initial subsequence of that symbol sequence takes us to | ||
a unique state then we say the DetState is not normalized. In that | ||
case we need to advance the base state and reduced `seq_len`. | ||
If this happens, then the arc sequence which takes us to the new | ||
base state will be output to `leftover_arcs`. When this is done, | ||
the 'weight' components of the DetStateElement members also need | ||
to be adjusted to remove the weight contribution from those arcs. | ||
|
||
The forward_backward_weight is the weight on the best path through the | ||
output determinized FSA that will include this DetState. It will determine | ||
the order of expansion of DetStates and also whether the states are | ||
expanded at all (if the pruning beam `beam` is finite). | ||
forward_backward_weight is the sum of the forward weight of the base state, | ||
plus (the greatest over the DetStateElements, of its `weight` element, | ||
plus the backward weight in the input FSA of the state that corresponds | ||
to it). | ||
|
||
|
||
worked outobtained from | ||
|
||
*/ | ||
void Normalize(std::vector<int32_t> *leftover_arcs); | ||
}; | ||
|
||
|
||
void DetState::Normalize(std::vector<int32_t> *input_arcs) { | ||
|
||
} | ||
|
||
|
||
class DetStateMap { | ||
public: | ||
/* | ||
Outputs the output state-id corresponding to a specific DetState structure. | ||
This does not store any pointers to the DetState or its contents, so | ||
you can delete the DetState without affecting this object's ability to map | ||
an equivalent DetState to the same state-id. | ||
|
||
@param [in] a The DetState that we're looking up | ||
@param [out] state_id The state-index in the output FSA | ||
corresponding to this DetState (will | ||
be freshly allocated if an equivalent of | ||
this DetState did not already exist. | ||
@return Returns true if this was a NEWLY CREATED state, | ||
false otherwise. | ||
*/ | ||
bool GetOutputState(const DetState &a, int32_t *state_id) { | ||
std::pair<uint64_t, uint64_t> compact; | ||
DetStateToCompact(a, &compact); | ||
auto p = map_.insert({compact, cur_output_state)); | ||
bool inserted = p.second; | ||
if (inserted) { | ||
*state_id = cur_output_state_++; | ||
return true; | ||
} else { | ||
*state_id = p.first->second; | ||
return false; | ||
} | ||
} | ||
|
||
int32_t size() const { return cur_output_state_; } | ||
|
||
private: | ||
|
||
int32_t cur_output_state_ { 0 }; | ||
std::unordered_map<std::pair<uint64_t, uint64_t>, int32_t, DetStateVectorHasher> map_; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, might do that. This code needs to be finished. |
||
|
||
/* Turns DetState into a compact form of 128 bits. Technically there | ||
could be collisions, which would be fatal for the algorithm, but this | ||
is one of those lifetime-of-the-universe type of things (kind of like | ||
the theoretical potential for git hash collision) that we ignore. | ||
|
||
The normalized form | ||
|
||
*/ | ||
void DetStateToCompact(const DetState &d, | ||
std::pair<uint64_t, uint64_t> *vec) { | ||
assert(d.normalized); | ||
|
||
uint64_t a = d.base_state + 17489 * d.seq_len, | ||
b = d.base_state * 103979 + d.seq_len; | ||
|
||
// We choose an arbitrary DetStateElement (the first one in the list) to | ||
// read the symbol sequence from; the symbol sequence will be the same no | ||
// matter which element we choose to trace back. | ||
DetStateElement *elem = d.head; | ||
int32_t seq_len = d.seq_len; | ||
for (int32_t i = 0; i < seq_len; ++i) { | ||
a = elem->symbol + 102299 * a; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there an error inside the
It should be |
||
b = elem->symbol + 102983 * b; | ||
elem = elem->parent | ||
} | ||
vec->first = a; | ||
vec->second = b; | ||
} | ||
|
||
struct DetStateHasher { | ||
size_t operator () (const std::pair<uint64_t, uint64_t> &p) const { | ||
return p.first; | ||
} | ||
}; | ||
|
||
|
||
|
||
}; | ||
|
||
|
||
|
||
void DeterminizeMax(const WfsaWithFbWeights &a, | ||
float beam, | ||
Fsa *b, | ||
std::vector<std::vector<int32_t> > *arc_map) { | ||
// TODO: use glog stuff. | ||
assert(IsValid(a) && IsEpsilonFree(a) && IsTopSortedAndAcyclic(a)); | ||
if (a.arc_indexes.empty()) { | ||
b->Clear(); | ||
return; | ||
} | ||
float cutoff = a.backward_state_weights[0] - beam; | ||
// TODO. | ||
|
||
} | ||
|
||
|
||
} // namespace k2 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,10 +17,6 @@ | |
|
||
namespace k2 { | ||
|
||
using Label = int32_t; | ||
using StateId = int32_t; | ||
using Weight = float; | ||
|
||
enum { | ||
kFinalSymbol = -1, // final-costs are represented as arcs with | ||
// kFinalSymbol as their label, to the final | ||
|
@@ -30,9 +26,9 @@ enum { | |
}; | ||
|
||
struct Arc { | ||
StateId src_state; | ||
StateId dest_state; | ||
Label label; // 'label' as in a finite state acceptor. | ||
int32_t src_state; | ||
int32_t dest_state; | ||
int32_t label; // 'label' as in a finite state acceptor. | ||
// For FSTs, the other label will be present in the | ||
// aux_label array. Which of the two represents the input | ||
// vs. the output can be decided by the user; in general, | ||
|
@@ -112,8 +108,8 @@ struct Fsa { | |
arc_indexes.push_back(index); | ||
} | ||
|
||
StateId NumStates() const { | ||
return !arc_indexes.empty() ? (static_cast<StateId>(arc_indexes.size()) - 1) | ||
int32_t NumStates() const { | ||
return !arc_indexes.empty() ? (static_cast<int32_t>(arc_indexes.size()) - 1) | ||
: 0; | ||
} | ||
}; | ||
|
@@ -134,7 +130,7 @@ struct Fsa { | |
weights[t,n]. | ||
*/ | ||
struct DenseFsa { | ||
Weight *weights; // Would typically be a log-prob or unnormalized log-prob | ||
float *weights; // Would typically be a log-prob or unnormalized log-prob | ||
int32_t T; // The number of time steps == rows in the matrix `weights`; | ||
// this FSA has T + 2 states, see explanation above. | ||
int32_t num_symbols; // The number of symbols == columns in the matrix | ||
|
@@ -148,15 +144,43 @@ struct DenseFsa { | |
CAUTION: we may later enforce that stride == num_symbols, in order to | ||
be able to know the layout of a phantom matrix of arcs. (?) | ||
*/ | ||
DenseFsa(Weight *data, int32_t T, int32_t num_symbols, int32_t stride); | ||
DenseFsa(float *data, int32_t T, int32_t num_symbols, int32_t stride); | ||
}; | ||
|
||
struct Fst { | ||
Fsa core; | ||
std::vector<int32_t> aux_label; | ||
}; | ||
|
||
using StatePair = std::pair<StateId, StateId>; | ||
/* | ||
This demonstrates an interface for a deterministic FSA or FST; it's similar | ||
to Kaldi's DeterministicOnDemandFst class. It can be used for things like | ||
language models. Actually we'll template on types like this. There is no | ||
need to actually inherit from this class. */ | ||
class DeterministicGenericFsa { | ||
public: | ||
int32_t Start(); | ||
|
||
|
||
bool LookupArc(int32_t cur_state, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are there no There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because this is an interface for a possibly dynamic object. |
||
int32_t label, | ||
int32_t *arc_index); | ||
|
||
|
||
float GetWeightForArc(int32_t arc_index); | ||
|
||
int32_t Getint32_tForArc(int32_t arc_index); | ||
|
||
int32_t GetPrevStateForArc(int32_t arc_index); | ||
|
||
int32_t GetNextStateForArc(int32_t arc_index); | ||
|
||
// Specific subclasses of this may have additional functions, e.g. | ||
int32_t GetOlabelForArc(int32_t arc_index); | ||
|
||
}; | ||
|
||
|
||
using FsaVec = std::vector<Fsa>; | ||
using FstVec = std::vector<Fst>; | ||
using DenseFsaVec = std::vector<DenseFsa>; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the return value if
weight == other.weight
??There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I didn't realize I had included this code, it is not finished. I have to get back to this.