enable hot-word boosting #3297

Merged: 48 commits, merged on Sep 24, 2020. The diff below reflects changes from 44 of the 48 commits.

Commits (all authored by JRMeyer)
26155ac  enable hot-word boosting (Aug 31, 2020)
3745a51  more consistent ordering of CLI arguments (Sep 1, 2020)
9a3be21  progress on review (Sep 2, 2020)
2b44f74  use map instead of set for hot-words, move string logic to client.cc (Sep 2, 2020)
42f1f3a  typo bug (Sep 3, 2020)
103ee93  pointer things? (Sep 3, 2020)
81157f5  use map for hotwords, better string splitting (Sep 10, 2020)
ba80943  add the boost, not multiply (Sep 10, 2020)
ff74bd5  cleaning up (Sep 10, 2020)
b997bab  cleaning whitespace (Sep 10, 2020)
6fbad16  remove <set> inclusion (Sep 10, 2020)
96cd43d  change typo set-->map (Sep 10, 2020)
d3a5378  rename boost_coefficient to boost (Sep 10, 2020)
cdf44aa  add hot_words to python bindings (Sep 10, 2020)
d8a779d  missing hot_words (Sep 10, 2020)
b047db2  include map in swigwrapper.i (Sep 10, 2020)
9e8ff99  add Map template to swigwrapper.i (Sep 10, 2020)
0fc3521  emacs intermediate file (Sep 14, 2020)
c64c68b  map things (Sep 14, 2020)
c805ab7  map-->unordered_map (Sep 14, 2020)
8c611bc  typu (Sep 14, 2020)
97b0416  typu (Sep 14, 2020)
82a582d  use dict() not None (Sep 14, 2020)
6155e43  error out if hot_words without scorer (Sep 14, 2020)
6df4297  two new functions: remove hot-word and clear all hot-words (Sep 15, 2020)
22af3c6  starting to work on better error messages (Sep 18, 2020)
c90b054  better error handling + .Net ERR codes (Sep 21, 2020)
b69a99c  allow for negative boosts :) (Sep 21, 2020)
753b62f  adding TC test for hot-words (Sep 21, 2020)
a7a6dcc  add hot-words to python client, make TC test hot-words everywhere (Sep 21, 2020)
9c437c3  only run TC tests for C++ and Python (Sep 21, 2020)
63ccc4a  fully expose API in python bindings (Sep 22, 2020)
44bf59d  expose API in Java (thanks spectie!) (Sep 22, 2020)
4866341  expose API in dotnet (thanks spectie!) (Sep 22, 2020)
3d0b67d  expose API in javascript (thanks spectie!) (Sep 22, 2020)
6636a65  java lol (Sep 22, 2020)
ecb3d27  typo in javascript (Sep 22, 2020)
f742e05  commenting (Sep 22, 2020)
37a27f6  java error codes from swig (Sep 22, 2020)
90f1611  java docs from SWIG (Sep 22, 2020)
102445c  java and dotnet issues (Sep 22, 2020)
fe7275e  add hotword test to android tests (Sep 22, 2020)
5432f56  dotnet fixes from carlos (Sep 22, 2020)
4945380  add DS_BINARY_PREFIX to tc-asserts.sh for hotwords command (Sep 23, 2020)
ae1d39f  make sure lm is on android for hotword test (Sep 24, 2020)
90c6a87  path to android model + nit (Sep 24, 2020)
019a514  path (Sep 24, 2020)
a382783  path (Sep 24, 2020)
10 changes: 9 additions & 1 deletion native_client/args.h
@@ -38,6 +38,8 @@ int json_candidate_transcripts = 3;

int stream_size = 0;

char* hot_words = NULL;

void PrintHelp(const char* bin)
{
std::cout <<
@@ -56,6 +58,7 @@ void PrintHelp(const char* bin)
"\t--json\t\t\t\tExtended output, shows word timings as JSON\n"
"\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n"
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
"\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
"\t--help\t\t\t\tShow help\n"
"\t--version\t\t\tPrint version and exits\n";
char* version = DS_Version();
@@ -66,7 +69,7 @@

bool ProcessArgs(int argc, char** argv)
{
const char* const short_opts = "m:l:a:b:c:d:tejs:vh";
const char* const short_opts = "m:l:a:b:c:d:tejs:w:vh";
const option long_opts[] = {
{"model", required_argument, nullptr, 'm'},
{"scorer", required_argument, nullptr, 'l'},
@@ -79,6 +82,7 @@ bool ProcessArgs(int argc, char** argv)
{"json", no_argument, nullptr, 'j'},
{"candidate_transcripts", required_argument, nullptr, 150},
{"stream", required_argument, nullptr, 's'},
{"hot_words", required_argument, nullptr, 'w'},
{"version", no_argument, nullptr, 'v'},
{"help", no_argument, nullptr, 'h'},
{nullptr, no_argument, nullptr, 0}
@@ -144,6 +148,10 @@ bool ProcessArgs(int argc, char** argv)
has_versions = true;
break;

case 'w':
hot_words = optarg;
break;

case 'h': // -h or --help
case '?': // Unrecognized option
default:
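As the help text above states, --hot_words takes a single string of comma-separated Word:Boost pairs; the parsing lives in client.cc below. A hypothetical invocation (binary name, file names, and boost values are illustrative, not taken from this PR) might look like:

    ./deepspeech --model output_graph.pbmm --scorer kenlm.scorer --audio audio.wav --hot_words "friend:6.5,enemy:-3.2"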
33 changes: 33 additions & 0 deletions native_client/client.cc
@@ -390,6 +390,22 @@ ProcessFile(ModelState* context, const char* path, bool show_times)
}
}

std::vector<std::string>
SplitStringOnDelim(std::string in_string, std::string delim)
{
std::vector<std::string> out_vector;
char * tmp_str = new char[in_string.size() + 1];
std::copy(in_string.begin(), in_string.end(), tmp_str);
tmp_str[in_string.size()] = '\0';
const char* token = strtok(tmp_str, delim.c_str());
while( token != NULL ) {
out_vector.push_back(token);
token = strtok(NULL, delim.c_str());
}
delete[] tmp_str;
return out_vector;
}

int
main(int argc, char **argv)
{
@@ -432,6 +448,23 @@ main(int argc, char **argv)
}
// sphinx-doc: c_ref_model_stop

if (hot_words) {
std::vector<std::string> hot_words_ = SplitStringOnDelim(hot_words, ",");
for ( std::string hot_word_ : hot_words_ ) {
std::vector<std::string> pair_ = SplitStringOnDelim(hot_word_, ":");
const char* word = (pair_[0]).c_str();
// the strtof function will return 0 in case of non numeric characters
// so, check the boost string before we turn it into a float
bool boost_is_valid = (pair_[1].find_first_not_of("-.0123456789") == std::string::npos);
float boost = strtof((pair_[1]).c_str(),0);
status = DS_AddHotWord(ctx, word, boost);
if (status != 0 || !boost_is_valid) {
fprintf(stderr, "Could not enable hot-word.\n");
return 1;
}
}
}

#ifndef NO_SOX
// Initialise SOX
assert(sox_init() == SOX_SUCCESS);
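One caveat in the parsing above: the boost check is a character filter, not a full float validation. It catches inputs where strtof would silently return 0 for non-numeric text, but a malformed string such as "1.2.3" still passes the filter, and strtof quietly parses only the leading "1.2". A small self-contained sketch of that behavior (the test inputs are hypothetical, not from the PR):

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    // Mirrors the client.cc check: accept only digits, '-', and '.'.
    static bool boost_chars_valid(const std::string& s) {
      return s.find_first_not_of("-.0123456789") == std::string::npos;
    }

    int main() {
      // "1.2.3" passes the character filter even though it is not a
      // well-formed float; strtof stops after the leading "1.2".
      const char* inputs[] = {"6.5", "-3.2", "1.2.3", "abc"};
      for (const char* in : inputs) {
        printf("%-6s filter=%d strtof=%g\n",
               in, boost_chars_valid(in), strtof(in, nullptr));
      }
      return 0;
    }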
10 changes: 8 additions & 2 deletions native_client/ctcdecode/__init__.py
@@ -96,6 +96,7 @@ def ctc_beam_search_decoder(probs_seq,
cutoff_prob=1.0,
cutoff_top_n=40,
scorer=None,
hot_words=dict(),
num_results=1):
"""Wrapper for the CTC Beam Search Decoder.
@@ -116,6 +117,8 @@
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
@@ -124,7 +127,7 @@
"""
beam_results = swigwrapper.ctc_beam_search_decoder(
probs_seq, alphabet, beam_size, cutoff_prob, cutoff_top_n,
scorer, num_results)
scorer, hot_words, num_results)
beam_results = [(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
return beam_results

@@ -137,6 +140,7 @@ def ctc_beam_search_decoder_batch(probs_seq,
cutoff_prob=1.0,
cutoff_top_n=40,
scorer=None,
hot_words=dict(),
num_results=1):
"""Wrapper for the batched CTC beam search decoder.
@@ -161,13 +165,15 @@
:param scorer: External scorer for partially decoded sentence, e.g. word
count or language model.
:type scorer: Scorer
:param hot_words: Map of words (keys) to their assigned boosts (values)
:type hot_words: map{string:float}
:param num_results: Number of beams to return.
:type num_results: int
:return: List of tuples of confidence and sentence as decoding
results, in descending order of the confidence.
:rtype: list
"""
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, num_results)
batch_beam_results = swigwrapper.ctc_beam_search_decoder_batch(probs_seq, seq_lengths, alphabet, beam_size, num_processes, cutoff_prob, cutoff_top_n, scorer, hot_words, num_results)
batch_beam_results = [
[(res.confidence, alphabet.Decode(res.tokens)) for res in beam_results]
for beam_results in batch_beam_results
28 changes: 24 additions & 4 deletions native_client/ctcdecode/ctc_beam_search_decoder.cpp
@@ -4,7 +4,7 @@
#include <cmath>
#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <utility>

#include "decoder_utils.h"
@@ -18,7 +18,8 @@ DecoderState::init(const Alphabet& alphabet,
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer)
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words)
{
// assign special ids
abs_time_step_ = 0;
@@ -29,6 +30,7 @@
cutoff_prob_ = cutoff_prob;
cutoff_top_n_ = cutoff_top_n;
ext_scorer_ = ext_scorer;
hot_words_ = hot_words;
start_expanding_ = false;

// init prefixes' root
@@ -160,8 +162,23 @@
float score = 0.0;
std::vector<std::string> ngram;
ngram = ext_scorer_->make_ngram(prefix_to_score);

float hot_boost = 0.0;
if (!hot_words_.empty()) {
std::unordered_map<std::string, float>::iterator iter;
// increase prob of prefix for every word
// that matches a word in the hot-words list
for (std::string word : ngram) {
[Inline review discussion on this line]
Collaborator: have you measured perf impact with scorers?
JRMeyer (author): perf as in WER? or perf as in latency?
Collaborator: latency
JRMeyer (author): I have not measured the latency effects yet, no. Are there any TC jobs that do this, or should I profile locally? What do you recommend?
Collaborator: Unfortunately, you'd have to do it locally. Using perf should be quite easy.
[End of discussion; diff continues]
iter = hot_words_.find(word);
if ( iter != hot_words_.end() ) {
// increase the log_cond_prob(prefix|LM)
hot_boost += iter->second;
}
}
}

bool bos = ngram.size() < ext_scorer_->get_max_order();
score = ext_scorer_->get_log_cond_prob(ngram, bos) * ext_scorer_->alpha;
score = ( ext_scorer_->get_log_cond_prob(ngram, bos) + hot_boost ) * ext_scorer_->alpha;
log_p += score;
log_p += ext_scorer_->beta;
}
@@ -256,11 +273,12 @@ std::vector<Output> ctc_beam_search_decoder(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results)
{
VALID_CHECK_EQ(alphabet.GetSize()+1, class_dim, "Number of output classes in acoustic model does not match number of labels in the alphabet file. Alphabet file must be the same one that was used to train the acoustic model.");
DecoderState state;
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer);
state.init(alphabet, beam_size, cutoff_prob, cutoff_top_n, ext_scorer, hot_words);
state.next(probs, time_dim, class_dim);
return state.decode(num_results);
}
@@ -279,6 +297,7 @@ ctc_beam_search_decoder_batch(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results)
{
VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!");
@@ -298,6 +317,7 @@
cutoff_prob,
cutoff_top_n,
ext_scorer,
hot_words,
num_results));
}

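The scoring change above is the core of the feature, and commit ba80943 ("add the boost, not multiply") describes it precisely: the scorer works with log conditional probabilities, so adding hot_boost inside the parentheses multiplies the underlying word probability by e^boost before the alpha weighting, which also makes negative boosts act as penalties. Each word of the scoring n-gram that appears in the table contributes its boost once. A self-contained toy version of the update (alpha, beta, the log-probability, and the boost values are made-up stand-ins for decoder state):

    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      // Stand-ins for decoder state: LM weights and the hot-word table.
      const float alpha = 0.93f, beta = 1.18f;
      std::unordered_map<std::string, float> hot_words{{"friend", 6.5f}};

      // Hypothetical n-gram context; -4.2 stands in for
      // ext_scorer_->get_log_cond_prob(ngram, bos).
      std::vector<std::string> ngram{"hello", "my", "friend"};
      const float log_cond_prob = -4.2f;

      // Same accumulation as DecoderState::next: one boost per matching word.
      float hot_boost = 0.0f;
      for (const std::string& word : ngram) {
        auto it = hot_words.find(word);
        if (it != hot_words.end()) {
          hot_boost += it->second;
        }
      }

      const float unboosted = log_cond_prob * alpha + beta;
      const float boosted   = (log_cond_prob + hot_boost) * alpha + beta;
      printf("score without boost: %.3f, with boost: %.3f\n", unboosted, boosted);
      return 0;
    }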
10 changes: 9 additions & 1 deletion native_client/ctcdecode/ctc_beam_search_decoder.h
@@ -22,6 +22,7 @@ class DecoderState {
std::vector<PathTrie*> prefixes_;
std::unique_ptr<PathTrie> prefix_root_;
TimestepTreeNode timestep_tree_root_{nullptr, 0};
std::unordered_map<std::string, float> hot_words_;

public:
DecoderState() = default;
@@ -48,7 +49,8 @@
size_t beam_size,
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer);
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words);

/* Send data to the decoder
*
@@ -88,6 +90,8 @@
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* hot_words: A map of hot-words and their corresponding boosts
* The hot-word is a string and the boost is a float.
* num_results: Number of beams to return.
* Return:
* A vector where each element is a pair of score and decoding result,
@@ -103,6 +107,7 @@ std::vector<Output> ctc_beam_search_decoder(
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results=1);

/* CTC Beam Search Decoder for batch data
@@ -117,6 +122,8 @@
* ext_scorer: External scorer to evaluate a prefix, which consists of
* n-gram language model scoring and word insertion term.
* Default null, decoding the input sample without scorer.
* hot_words: A map of hot-words and their corresponding boosts
* The hot-word is a string and the boost is a float.
* num_results: Number of beams to return.
* Return:
* A 2-D vector where each element is a vector of beam search decoding
@@ -136,6 +143,7 @@
double cutoff_prob,
size_t cutoff_top_n,
std::shared_ptr<Scorer> ext_scorer,
std::unordered_map<std::string, float> hot_words,
size_t num_results=1);

#endif // CTC_BEAM_SEARCH_DECODER_H_
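A side note on the signature: init() takes the hot_words map by value, so each DecoderState snapshots the table when a stream is created; mutating the caller's map afterwards affects only streams created later. A minimal sketch of that copy behavior (DecoderLike and init_like are hypothetical stand-ins, not DeepSpeech types):

    #include <cassert>
    #include <string>
    #include <unordered_map>
    #include <utility>

    using HotWords = std::unordered_map<std::string, float>;

    struct DecoderLike {
      HotWords hot_words_;
      // By-value parameter, mirroring DecoderState::init: the decoder
      // stores its own copy rather than a reference to the caller's map.
      void init_like(HotWords hot_words) { hot_words_ = std::move(hot_words); }
    };

    int main() {
      HotWords table{{"friend", 6.5f}};
      DecoderLike decoder;
      decoder.init_like(table);

      table["enemy"] = -3.2f;                          // added after "stream creation"
      assert(decoder.hot_words_.count("enemy") == 0);  // the snapshot is unaffected
      return 0;
    }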
2 changes: 2 additions & 0 deletions native_client/ctcdecode/swigwrapper.i
@@ -11,6 +11,7 @@
%include <std_string.i>
%include <std_vector.i>
%include <std_shared_ptr.i>
%include <std_unordered_map.i>
%include "numpy.i"

%init %{
@@ -22,6 +23,7 @@ namespace std {
%template(UnsignedIntVector) vector<unsigned int>;
%template(OutputVector) vector<Output>;
%template(OutputVectorVector) vector<vector<Output>>;
%template(Map) unordered_map<string, float>;
}

%shared_ptr(Scorer);
50 changes: 49 additions & 1 deletion native_client/deepspeech.cc
@@ -342,6 +342,53 @@ DS_EnableExternalScorer(ModelState* aCtx,
return DS_ERR_OK;
}

int
DS_AddHotWord(ModelState* aCtx,
const char* word,
float boost)
{
if (aCtx->scorer_) {
const int size_before = aCtx->hot_words_.size();
aCtx->hot_words_.insert( std::pair<std::string,float> (word, boost) );
const int size_after = aCtx->hot_words_.size();
if (size_before == size_after) {
return DS_ERR_FAIL_INSERT_HOTWORD;
}
return DS_ERR_OK;
}
return DS_ERR_SCORER_NOT_ENABLED;
}

int
DS_EraseHotWord(ModelState* aCtx,
const char* word)
{
if (aCtx->scorer_) {
const int size_before = aCtx->hot_words_.size();
int err = aCtx->hot_words_.erase(word);
const int size_after = aCtx->hot_words_.size();
if (size_before == size_after) {
return DS_ERR_FAIL_ERASE_HOTWORD;
}
return DS_ERR_OK;
}
return DS_ERR_SCORER_NOT_ENABLED;
}

int
DS_ClearHotWords(ModelState* aCtx)
{
if (aCtx->scorer_) {
aCtx->hot_words_.clear();
const int size_after = aCtx->hot_words_.size();
if (size_after != 0) {
return DS_ERR_FAIL_CLEAR_HOTWORD;
}
return DS_ERR_OK;
}
return DS_ERR_SCORER_NOT_ENABLED;
}

int
DS_DisableExternalScorer(ModelState* aCtx)
{
@@ -390,7 +437,8 @@ DS_CreateStream(ModelState* aCtx,
aCtx->beam_width_,
cutoff_prob,
cutoff_top_n,
aCtx->scorer_);
aCtx->scorer_,
aCtx->hot_words_);

*retval = ctx.release();
return DS_ERR_OK;
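Taken together, the intended calling pattern for the new C API is: enable an external scorer first (all three hot-word functions return DS_ERR_SCORER_NOT_ENABLED without one), then add, erase, or clear boosts before creating a stream. A hedged usage sketch (file paths and boost values are placeholders, error handling is abbreviated, and it assumes the preexisting DS_CreateModel, DS_EnableExternalScorer, and DS_FreeModel entry points from deepspeech.h):

    #include <cstdio>
    #include "deepspeech.h"

    int main() {
      ModelState* ctx = nullptr;
      if (DS_CreateModel("output_graph.pbmm", &ctx) != 0) {
        fprintf(stderr, "Could not create model.\n");
        return 1;
      }
      // Hot-words require a scorer; without this call, DS_AddHotWord
      // returns DS_ERR_SCORER_NOT_ENABLED.
      if (DS_EnableExternalScorer(ctx, "kenlm.scorer") != 0) {
        fprintf(stderr, "Could not enable scorer.\n");
        return 1;
      }

      DS_AddHotWord(ctx, "friend", 6.5f);  // positive boost favors the word
      DS_AddHotWord(ctx, "enemy", -3.2f);  // negative boosts allowed (commit b69a99c)
      DS_EraseHotWord(ctx, "enemy");       // remove one entry...
      DS_ClearHotWords(ctx);               // ...or clear the whole table

      DS_FreeModel(ctx);
      return 0;
    }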