From b45999088ccbda19b57327c05810a8a015ce9a89 Mon Sep 17 00:00:00 2001 From: Robert Schubert Date: Fri, 8 Mar 2019 12:30:16 +0100 Subject: [PATCH] LSTM char_whitelist/blacklist (6ac2ff0): multi-code chars - move decision from ComputeTopN to ContinueContext, where it belongs: block context continuations which emit final codes translating to disabled unichar_ids. (The normal logic for fallback from top2 > top-n > rest will apply.) - pass UNICHARSET refs appropriately --- src/lstm/recodebeam.cpp | 30 +++++++++++++----------------- src/lstm/recodebeam.h | 8 ++++---- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/lstm/recodebeam.cpp b/src/lstm/recodebeam.cpp index 2635481de5..9c9f332568 100644 --- a/src/lstm/recodebeam.cpp +++ b/src/lstm/recodebeam.cpp @@ -87,7 +87,7 @@ void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio, if (lstm_choice_mode) timesteps.clear(); for (int t = 0; t < width; ++t) { - ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0], charset); + ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]); DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert, charset); if (lstm_choice_mode) { @@ -102,7 +102,7 @@ void RecodeBeamSearch::Decode(const GENERIC_2D_ARRAY& output, beam_size_ = 0; int width = output.dim1(); for (int t = 0; t < width; ++t) { - ComputeTopN(output[t], output.dim2(), kBeamWidths[0], charset); + ComputeTopN(output[t], output.dim2(), kBeamWidths[0]); DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset); } } @@ -456,19 +456,12 @@ WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space, // Fills top_n_flags_ with bools that are true iff the corresponding output // is one of the top_n. 
void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs, - int top_n, const UNICHARSET* charset) { + int top_n) { top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN); top_code_ = -1; second_code_ = -1; top_heap_.clear(); for (int i = 0; i < num_outputs; ++i) { - // Decode label via recoder_. - RecodedCharID code; - code.Set(0, i); - int label = recoder_.DecodeUnichar(code); - if (label != INVALID_UNICHAR_ID && // not part of a bigger code. - !charset->get_enabled(label)) // disabled - continue; if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) { TopPair entry(outputs[i], i); top_heap_.Push(&entry); @@ -505,10 +498,10 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t, if (t == 0) { // The first step can only use singles and initials. ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2, - dict_ratio, cert_offset, worst_dict_cert, step); + charset, dict_ratio, cert_offset, worst_dict_cert, step); if (dict_ != nullptr) { - ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, - TN_TOP2, dict_ratio, cert_offset, worst_dict_cert, step); + ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, TN_TOP2, + charset, dict_ratio, cert_offset, worst_dict_cert, step); } } else { RecodeBeam* prev = beam_[t - 1]; @@ -540,9 +533,8 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t, // best first, but it comes before a lot of the worst, so it is slightly // more efficient than going forwards. 
for (int i = prev->beams_[index].size() - 1; i >= 0; --i) { - ContinueContext(&prev->beams_[index].get(i).data, index, outputs, - top_n, dict_ratio, cert_offset, worst_dict_cert, - step); + ContinueContext(&prev->beams_[index].get(i).data, index, outputs, top_n, + charset, dict_ratio, cert_offset, worst_dict_cert, step); } } for (int index = 0; index < kNumBeams; ++index) { @@ -569,7 +561,9 @@ void RecodeBeamSearch::DecodeStep(const float* outputs, int t, // choices for which top_n_flags[index] == top_n_flag. void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index, const float* outputs, - TopNState top_n_flag, double dict_ratio, + TopNState top_n_flag, + const UNICHARSET* charset, + double dict_ratio, double cert_offset, double worst_dict_cert, RecodeBeam* step) { @@ -632,6 +626,8 @@ void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index, int unichar_id = recoder_.DecodeUnichar(full_code); // Map the null char to INVALID. if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID; + if (unichar_id != INVALID_UNICHAR_ID && !charset->get_enabled(unichar_id)) + continue; // disabled by whitelist/blacklist ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio, use_dawgs, NC_ANYTHING, prev, step); if (top_n_flag == TN_TOP2 && code != null_char_) { diff --git a/src/lstm/recodebeam.h b/src/lstm/recodebeam.h index 5db77b4b7c..7d8cec96d2 100644 --- a/src/lstm/recodebeam.h +++ b/src/lstm/recodebeam.h @@ -293,7 +293,7 @@ class RecodeBeamSearch { // Fills top_n_flags_ with bools that are true iff the corresponding output // is one of the top_n. - void ComputeTopN(const float* outputs, int num_outputs, int top_n, const UNICHARSET* unicharset); + void ComputeTopN(const float* outputs, int num_outputs, int top_n); // Adds the computation for the current time-step to the beam. Call at each // time-step in sequence from left to right. 
outputs is the activation vector @@ -310,9 +310,9 @@ class RecodeBeamSearch { // using the given network outputs to provide scores to the choices. Uses only // those choices for which top_n_flags[code] == top_n_flag. void ContinueContext(const RecodeNode* prev, int index, const float* outputs, - TopNState top_n_flag, double dict_ratio, - double cert_offset, double worst_dict_cert, - RecodeBeam* step); + TopNState top_n_flag, const UNICHARSET* unicharset, + double dict_ratio, double cert_offset, + double worst_dict_cert, RecodeBeam* step); // Continues for a new unichar, using dawg or non-dawg as per flag. void ContinueUnichar(int code, int unichar_id, float cert, float worst_dict_cert, float dict_ratio, bool use_dawgs,