Skip to content

Commit

Permalink
Merge pull request #4885 from danpovey/fix_simple_decoder2
Browse files Browse the repository at this point in the history
Fix #4870, spurious error in ProcessNonemitting; queue can validly be empty.
  • Loading branch information
danpovey authored Nov 10, 2023
2 parents a670447 + cd2b835 commit e187518
Showing 1 changed file with 31 additions and 39 deletions.
70 changes: 31 additions & 39 deletions src/decoder/lattice-simple-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ void LatticeSimpleDecoder::InitDecoding() {

bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) {
InitDecoding();
while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {

while (!decodable->IsLastFrame(NumFramesDecoded() - 1)) {
if (NumFramesDecoded() % config_.prune_interval == 0)
PruneActiveTokens(config_.lattice_beam * config_.prune_scale);
ProcessEmitting(decodable);
Expand All @@ -57,7 +57,7 @@ bool LatticeSimpleDecoder::Decode(DecodableInterface *decodable) {
ProcessNonemitting();
}
FinalizeDecoding();

// Returns true if we have any kind of traceback available (not necessarily
// to the end state; query ReachedFinal() for that).
return !final_costs_.empty();
Expand Down Expand Up @@ -88,9 +88,9 @@ bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst,
if (decoding_finalized_ && !use_final_probs)
KALDI_ERR << "You cannot call FinalizeDecoding() and then call "
<< "GetRawLattice() with use_final_probs == false";

unordered_map<Token*, BaseFloat> final_costs_local;

const unordered_map<Token*, BaseFloat> &final_costs =
(decoding_finalized_ ? final_costs_ : final_costs_local);

Expand All @@ -100,7 +100,7 @@ bool LatticeSimpleDecoder::GetRawLattice(Lattice *ofst,
ofst->DeleteStates();
int32 num_frames = NumFramesDecoded();
KALDI_ASSERT(num_frames > 0);
const int32 bucket_count = num_toks_/2 + 3;
const int32 bucket_count = num_toks_/2 + 3;
unordered_map<Token*, StateId> tok_map(bucket_count);
// First create all states.
for (int32 f = 0; f <= num_frames; f++) {
Expand Down Expand Up @@ -169,10 +169,10 @@ bool LatticeSimpleDecoder::GetLattice(
fst::ILabelCompare<LatticeArc> ilabel_comp;
ArcSort(&raw_fst, ilabel_comp); // sort on ilabel; makes
// lattice-determinization more efficient.

fst::DeterminizeLatticePrunedOptions lat_opts;
lat_opts.max_mem = config_.det_opts.max_mem;

DeterminizeLatticePruned(raw_fst, config_.lattice_beam, ofst, lat_opts);
raw_fst.DeleteStates(); // Free memory-- raw_fst no longer needed.
Connect(ofst); // Remove unreachable states... there might be
Expand All @@ -196,7 +196,7 @@ inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken(
bool emitting, bool *changed) {
KALDI_ASSERT(frame < active_toks_.size());
Token *&toks = active_toks_[frame].toks;

unordered_map<StateId, Token*>::iterator find_iter = cur_toks_.find(state);
if (find_iter == cur_toks_.end()) { // no such token presently.
// Create one.
Expand All @@ -221,7 +221,7 @@ inline LatticeSimpleDecoder::Token *LatticeSimpleDecoder::FindOrAddToken(
return tok;
}
}

// delta is the amount by which the extra_costs must
// change before it sets "extra_costs_changed" to true. If delta is larger,
// we'll tend to go back less far toward the beginning of the file.
Expand All @@ -242,7 +242,7 @@ void LatticeSimpleDecoder::PruneForwardLinks(
warned_ = true;
}
}

bool changed = true;
while (changed) {
changed = false;
Expand Down Expand Up @@ -300,7 +300,7 @@ void LatticeSimpleDecoder::ComputeFinalCosts(
BaseFloat infinity = std::numeric_limits<BaseFloat>::infinity();
BaseFloat best_cost = infinity,
best_cost_with_final = infinity;

for (unordered_map<StateId, Token*>::const_iterator iter = cur_toks_.begin();
iter != cur_toks_.end(); ++iter) {
StateId state = iter->first;
Expand Down Expand Up @@ -336,19 +336,19 @@ void LatticeSimpleDecoder::ComputeFinalCosts(
// on the final frame. If there are final tokens active, it uses the final-probs
// for pruning, otherwise it treats all tokens as final.
void LatticeSimpleDecoder::PruneForwardLinksFinal() {
KALDI_ASSERT(!active_toks_.empty());
KALDI_ASSERT(!active_toks_.empty());
int32 frame_plus_one = active_toks_.size() - 1;

if (active_toks_[frame_plus_one].toks == NULL) // empty list; should not happen.
KALDI_WARN << "No tokens alive at end of file\n";

typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
typedef unordered_map<Token*, BaseFloat>::const_iterator IterType;
ComputeFinalCosts(&final_costs_, &final_relative_cost_, &final_best_cost_);
decoding_finalized_ = true;
// We're about to delete some of the tokens active on the final frame, so we
// clear cur_toks_ because otherwise it would then contain dangling pointers.
cur_toks_.clear();

// Now go through tokens on this frame, pruning forward links... may have to
// iterate a few times until there is no more change, because the list is not
// in topological order. This is a modified version of the code in
Expand Down Expand Up @@ -429,7 +429,7 @@ BaseFloat LatticeSimpleDecoder::FinalRelativeCost() const {
return final_relative_cost_;
}
}

// Prune away any tokens on this frame that have no forward links. [we don't do
// this in PruneForwardLinks because it would give us a problem with dangling
// pointers].
Expand All @@ -453,22 +453,22 @@ void LatticeSimpleDecoder::PruneTokensForFrame(int32 frame) {
}
}
}

// Go backwards through still-alive tokens, pruning them, starting not from
// the current frame (where we want to keep all tokens) but from the frame before
// that. We go backwards through the frames and stop when we reach a point
// where the delta-costs are not changing (and the delta controls when we consider
// a cost to have "not changed").
void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
int32 cur_frame_plus_one = NumFramesDecoded();
int32 cur_frame_plus_one = NumFramesDecoded();
int32 num_toks_begin = num_toks_;
// The index "f" below represents a "frame plus one", i.e. you'd have to subtract
// one to get the corresponding index for the decodable object.
for (int32 f = cur_frame_plus_one - 1; f >= 0; f--) {
// Reason why we need to prune forward links in this situation:
// (1) we have never pruned them
// (2) we have not yet pruned the forward links to the next frame after
// any of those tokens have changed their extra_cost.
//
//
if (active_toks_[f].must_prune_forward_links) {
bool extra_costs_changed = false, links_pruned = false;
PruneForwardLinks(f, &extra_costs_changed, &links_pruned, delta);
Expand All @@ -478,7 +478,7 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
active_toks_[f].must_prune_tokens = true;
active_toks_[f].must_prune_forward_links = false;
}
if (f+1 < cur_frame_plus_one &&
if (f+1 < cur_frame_plus_one &&
active_toks_[f+1].must_prune_tokens) {
PruneTokensForFrame(f+1);
active_toks_[f+1].must_prune_tokens = false;
Expand All @@ -493,20 +493,20 @@ void LatticeSimpleDecoder::PruneActiveTokens(BaseFloat delta) {
// (optionally) on the final frame. Takes into account the final-prob of
// tokens. This function used to be called PruneActiveTokensFinal().
void LatticeSimpleDecoder::FinalizeDecoding() {
  // Snapshot the frame count and the token count before any pruning; the
  // token count is only needed for the verbose log message at the end.
  int32 final_frame_plus_one = NumFramesDecoded();
  int32 num_toks_begin = num_toks_;

  // Prune the forward links of the tokens on the final frame first.
  PruneForwardLinksFinal();

  // Then go backwards through the earlier frames, pruning the forward links
  // of frame f and afterwards the tokens of the frame that follows it.
  for (int32 f = final_frame_plus_one - 1; f >= 0; f--) {
    bool b1, b2;  // values not used.
    BaseFloat dontcare = 0.0;  // delta argument; the flags it affects are
                               // ignored here.
    PruneForwardLinks(f, &b1, &b2, dontcare);
    PruneTokensForFrame(f + 1);
  }
  // The loop above only prunes tokens on frames 1 and up, so handle frame 0
  // separately.
  PruneTokensForFrame(0);

  KALDI_VLOG(3) << "pruned tokens from " << num_toks_begin
                << " to " << num_toks_;
}

void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
int32 frame = active_toks_.size() - 1; // frame is the frame-index
// (zero-based) used to get likelihoods
Expand Down Expand Up @@ -538,9 +538,9 @@ void LatticeSimpleDecoder::ProcessEmitting(DecodableInterface *decodable) {
// FindOrAddToken adds next_tok to cur_toks_ (if not already present).
Token *next_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
true, NULL);

// Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
}
Expand All @@ -553,7 +553,7 @@ void LatticeSimpleDecoder::ProcessNonemitting() {
// Note: "frame" is the time-index we just processed, or -1 if
// we are processing the nonemitting transitions before the
// first frame (called from InitDecoding()).

// Processes nonemitting arcs for one frame. Propagates within
// cur_toks_. Note-- this queue structure is not very optimal as
// it may cause us to process states unnecessarily (e.g. more than once),
Expand All @@ -569,15 +569,9 @@ void LatticeSimpleDecoder::ProcessNonemitting() {
queue.push_back(state);
best_cost = std::min(best_cost, iter->second->tot_cost);
}
if (queue.empty()) {
if (!warned_) {
KALDI_ERR << "Error in ProcessEmitting: no surviving tokens: frame is "
<< frame;
warned_ = true;
}
}

BaseFloat cutoff = best_cost + config_.beam;

while (!queue.empty()) {
StateId state = queue.back();
queue.pop_back();
Expand All @@ -600,10 +594,10 @@ void LatticeSimpleDecoder::ProcessNonemitting() {
bool changed;
Token *new_tok = FindOrAddToken(arc.nextstate, frame + 1, tot_cost,
false, &changed);

tok->links = new ForwardLink(new_tok, 0, arc.olabel,
graph_cost, 0, tok->links);

// "changed" tells us whether the new token has a different
// cost from before, or is new [if so, add into queue].
if (changed && fst_.NumInputEpsilons(arc.nextstate) != 0)
Expand Down Expand Up @@ -662,5 +656,3 @@ void LatticeSimpleDecoder::PruneCurrentTokens(BaseFloat beam, unordered_map<Stat


} // end namespace kaldi.


0 comments on commit e187518

Please sign in to comment.