Skip to content

Commit

Permalink
enable hot-word boosting (#3297)
Browse files Browse the repository at this point in the history
* enable hot-word boosting

* more consistent ordering of CLI arguments

* progress on review

* use map instead of set for hot-words, move string logic to client.cc

* typo bug

* pointer things?

* use map for hotwords, better string splitting

* add the boost, not multiply

* cleaning up

* cleaning whitespace

* remove <set> inclusion

* change typo set-->map

* rename boost_coefficient to boost

X-DeepSpeech: NOBUILD

* add hot_words to python bindings

* missing hot_words

* include map in swigwrapper.i

* add Map template to swigwrapper.i

* emacs intermediate file

* map things

* map-->unordered_map

* typo

* typo

* use dict() not None

* error out if hot_words without scorer

* two new functions: remove hot-word and clear all hot-words

* starting to work on better error messages

X-DeepSpeech: NOBUILD

* better error handling + .Net ERR codes

* allow for negative boosts:)

* adding TC test for hot-words

* add hot-words to python client, make TC test hot-words everywhere

* only run TC tests for C++ and Python

* fully expose API in python bindings

* expose API in Java (thanks spectie!)

* expose API in dotnet (thanks spectie!)

* expose API in javascript (thanks spectie!)

* java lol

* typo in javascript

* commenting

* java error codes from swig

* java docs from SWIG

* java and dotnet issues

* add hotword test to android tests

* dotnet fixes from carlos

* add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command

* make sure lm is on android for hotword test

* path to android model + nit

* path

* path
  • Loading branch information
JRMeyer authored Sep 24, 2020
1 parent d466fb0 commit 1eb155e
Show file tree
Hide file tree
Showing 25 changed files with 400 additions and 11 deletions.
10 changes: 9 additions & 1 deletion native_client/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ int json_candidate_transcripts = 3;

int stream_size = 0;

char* hot_words = NULL;

void PrintHelp(const char* bin)
{
std::cout <<
Expand All @@ -56,6 +58,7 @@ void PrintHelp(const char* bin)
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
"\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
"\t--help\t\t\t\tShow help\n"
"\t--version\t\t\tPrint version and exits\n";
char* version = DS_Version();
Expand All @@ -66,7 +69,7 @@ void PrintHelp(const char* bin)

bool ProcessArgs(int argc, char** argv)
{
const char* const short_opts = "m:l:a:b:c:d:tejs:vh";
const char* const short_opts = "m:l:a:b:c:d:tejs:w:vh";
const option long_opts[] = {
{"model", required_argument, nullptr, 'm'},
{"scorer", required_argument, nullptr, 'l'},
Expand All @@ -79,6 +82,7 @@ bool ProcessArgs(int argc, char** argv)
{"json", no_argument, nullptr, 'j'},
{"candidate_transcripts", required_argument, nullptr, 150},
{"stream", required_argument, nullptr, 's'},
{"hot_words", required_argument, nullptr, 'w'},
{"version", no_argument, nullptr, 'v'},
{"help", no_argument, nullptr, 'h'},
{nullptr, no_argument, nullptr, 0}
Expand Down Expand Up @@ -144,6 +148,10 @@ bool ProcessArgs(int argc, char** argv)
has_versions = true;
break;

case 'w':
hot_words = optarg;
break;

case 'h': // -h or --help
case '?': // Unrecognized option
default:
Expand Down
33 changes: 33 additions & 0 deletions native_client/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,22 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
}
}

// Split `in_string` into tokens separated by any character of `delim`.
// Mirrors strtok() semantics: consecutive delimiters are collapsed and
// empty tokens are never produced (e.g. ",,a,,b," -> {"a", "b"}).
//
// The previous implementation copied the input into a manually
// new[]-allocated buffer and walked it with strtok(); if push_back threw,
// the buffer leaked, and strtok() is not reentrant. This version uses
// std::string searches only: no manual allocation, exception-safe,
// thread-safe.
std::vector<std::string>
SplitStringOnDelim(std::string in_string, std::string delim)
{
  std::vector<std::string> out_vector;
  // Advance over leading/repeated delimiters, then take the run of
  // non-delimiter characters as the next token.
  size_t start = in_string.find_first_not_of(delim);
  while (start != std::string::npos) {
    size_t end = in_string.find_first_of(delim, start);
    out_vector.push_back(in_string.substr(start, end - start));
    start = in_string.find_first_not_of(delim, end);
  }
  return out_vector;
}

int
main(int argc, char **argv)
{
Expand Down Expand Up @@ -432,6 +448,23 @@ main(int argc, char **argv)
}
// sphinx-doc: c_ref_model_stop

if (hot_words) {
std::vector<std::string> hot_words_ = SplitStringOnDelim(hot_words, ",");
for ( std::string hot_word_ : hot_words_ ) {
std::vector<std::string> pair_ = SplitStringOnDelim(hot_word_, ":");
const char* word = (pair_[0]).c_str();
// the strtof function will return 0 in case of non numeric characters
// so, check the boost string before we turn it into a float
bool boost_is_valid = (pair_[1].find_first_not_of("-.0123456789") == std::string::npos);
float boost = strtof((pair_[1]).c_str(),0);
status = DS_AddHotWord(ctx, word, boost);
if (status != 0 || !boost_is_valid) {
fprintf(stderr, "Could not enable hot-word.\n");
return 1;
}
}
}

#ifndef NO_SOX
// Initialise SOX
assert(sox_init() == SOX_SUCCESS);
Expand Down
10 changes: 8 additions & 2 deletions native_client/ctcdecode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_prob=1.0,
cutoff_top_n=40,
scorer=None,
hot_words=dict(),
num_results=1):
"""Wrapper for the CTC Beam Search Decoder.
Expand All @@ -116,6 +117,8 @@ def ctc_beam_search_decoder(probs_seq,
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
Expand All @@ -124,7 +127,7 @@ def ctc_beam_search_decoder(probs_seq,
"""
beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
scorer, num_results)
scorer, hot_words, num_results)
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
return beam_results

Expand All @@ -137,6 +140,7 @@ def ctc_beam_search_decoder_batch(probs_seq,
cutoff_prob=1.0,
cutoff_top_n=40,
scorer=None,
hot_words=dict(),
num_results=1):
"""Wrapper for the batched CTC beam search decoder.
Expand All @@ -161,13 +165,15 @@ def ctc_beam_search_decoder_batch(probs_seq,
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
results, in descending order of the confidence.
:rtype: list
"""
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results)
batch_beam_results = [
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results
Expand Down
28 changes: 24 additions & 4 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <utility>

#include "decoder_utils.h"
Expand All @@ -18,7 +18,8 @@ DecoderState::init(const Alphabet& alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer)
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words)
{
// assign special ids
abs_time_step_ = 0;
Expand All @@ -29,6 +30,7 @@ DecoderState::init(const Alphabet& alphabet,
cutoff_prob_ = cutoff_prob;
cutoff_top_n_ = cutoff_top_n;
ext_scorer_ = ext_scorer;
hot_words_ = hot_words;
start_expanding_ = false;

// init prefixes' root
Expand Down Expand Up @@ -160,8 +162,23 @@ DecoderState::next(const double *probs,
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer_->make_ngram(prefix_to_score);

float hot_boost = 0.0;
if (!hot_words_.empty()) {
std::unordered_map<std::string, float>::iterator iter;
// increase prob of prefix for every word
// that matches a word in the hot-words list
for (std::string word : ngram) {
iter = hot_words_.find(word);
if ( iter != hot_words_.end() ) {
// increase the log_cond_prob(prefix|LM)
hot_boost += iter->second;
}
}
}

bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha;
log_p += score;
log_p += ext_scorer_->beta;
}
Expand Down Expand Up @@ -256,11 +273,12 @@ std::vector<Output> ctc_beam_search_decoder(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results)
{
VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model.");
DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words);
state.next(probs, time_dim, class_dim);
return state.decode(num_results);
}
Expand All @@ -279,6 +297,7 @@ ctc_beam_search_decoder_batch(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results)
{
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
Expand All @@ -298,6 +317,7 @@ ctc_beam_search_decoder_batch(
cutoff_prob,
cutoff_top_n,
ext_scorer,
hot_words,
num_results));
}

Expand Down
10 changes: 9 additions & 1 deletion native_client/ctcdecode/ctc_beam_search_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class DecoderState {
std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_;
TimestepTreeNode timestep_tree_root_{nullptr, 0};
std::unordered_map<std::string, float> hot_words_;

public:
DecoderState() = default;
Expand All @@ -48,7 +49,8 @@ class DecoderState {
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer);
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words);

/* Send data to the decoder
*
Expand Down Expand Up @@ -88,6 +90,8 @@ class DecoderState {
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* hot_words: A map of hot-words and their corresponding boosts
* The hot-word is a string and the boost is a float.
* num_results: Number of beams to return.
* Return:
* A vector where each element is a pair of score and decoding result,
Expand All @@ -103,6 +107,7 @@ std::vector<Output> ctc_beam_search_decoder(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results=1);

/* CTC Beam Search Decoder for batch data
Expand All @@ -117,6 +122,8 @@ std::vector<Output> ctc_beam_search_decoder(
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* hot_words: A map of hot-words and their corresponding boosts
* The hot-word is a string and the boost is a float.
* num_results: Number of beams to return.
* Return:
* A 2-D vector where each element is a vector of beam search decoding
Expand All @@ -136,6 +143,7 @@ ctc_beam_search_decoder_batch(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results=1);

#endif // CTC_BEAM_SEARCH_DECODER_H_
2 changes: 2 additions & 0 deletions native_client/ctcdecode/swigwrapper.i
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
%include <std_string.i>
%include <std_vector.i>
%include <std_shared_ptr.i>
%include <std_unordered_map.i>
%include "numpy.i"

%init %{
Expand All @@ -22,6 +23,7 @@ namespace std {
%template(UnsignedIntVector) vector<unsigned int>;
%template(OutputVector) vector<Output>;
%template(OutputVectorVector) vector<vector<Output>>;
%template(Map) unordered_map<string, float>;
}

%shared_ptr(Scorer);
Expand Down
50 changes: 49 additions & 1 deletion native_client/deepspeech.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,53 @@ DS_EnableExternalScorer(ModelState* aCtx,
return DS_ERR_OK;
}

// Register `word` as a hot-word with the given `boost`.
// Requires an external scorer to be enabled on the model context.
// Returns DS_ERR_OK on success, DS_ERR_FAIL_INSERT_HOTWORD if the word
// is already present, or DS_ERR_SCORER_NOT_ENABLED if no scorer is set.
int
DS_AddHotWord(ModelState* aCtx,
              const char* word,
              float boost)
{
  // Hot-words are applied through the scorer, so reject early without one.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }
  // unordered_map::insert refuses duplicates and reports that via .second,
  // which is exactly the "size did not grow" condition the caller expects.
  const bool inserted = aCtx->hot_words_.insert({word, boost}).second;
  return inserted ? DS_ERR_OK : DS_ERR_FAIL_INSERT_HOTWORD;
}

// Remove a previously added hot-word from the model context.
// Returns DS_ERR_OK on success, DS_ERR_FAIL_ERASE_HOTWORD if `word` was
// not present, or DS_ERR_SCORER_NOT_ENABLED if no scorer is enabled.
int
DS_EraseHotWord(ModelState* aCtx,
                const char* word)
{
  // Hot-words only exist alongside an external scorer.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }
  // erase() already returns the number of removed elements (0 or 1 for a
  // map), so use it directly. The original stored it in an unused local
  // `err` and compared container sizes before/after instead.
  if (aCtx->hot_words_.erase(word) == 0) {
    return DS_ERR_FAIL_ERASE_HOTWORD;
  }
  return DS_ERR_OK;
}

// Remove every hot-word from the model context.
// Returns DS_ERR_OK on success, DS_ERR_FAIL_CLEAR_HOTWORD if the map is
// somehow non-empty afterwards, or DS_ERR_SCORER_NOT_ENABLED if no
// scorer is enabled.
int
DS_ClearHotWords(ModelState* aCtx)
{
  // Without a scorer there are no hot-words to clear.
  if (!aCtx->scorer_) {
    return DS_ERR_SCORER_NOT_ENABLED;
  }
  aCtx->hot_words_.clear();
  // Defensive post-condition check, kept from the original contract.
  return aCtx->hot_words_.empty() ? DS_ERR_OK : DS_ERR_FAIL_CLEAR_HOTWORD;
}

int
DS_DisableExternalScorer(ModelState* aCtx)
{
Expand Down Expand Up @@ -390,7 +437,8 @@ DS_CreateStream(ModelState* aCtx,
aCtx->beam_width_,
cutoff_prob,
cutoff_top_n,
aCtx->scorer_);
aCtx->scorer_,
aCtx->hot_words_);

*retval = ctx.release();
return DS_ERR_OK;
Expand Down
Loading

0 comments on commit 1eb155e

Please sign in to comment.