From e69ff9fd81815f9f3a12602460820e6096417984 Mon Sep 17 00:00:00 2001 From: sarane Date: Mon, 12 Aug 2024 23:11:12 +0530 Subject: [PATCH 01/18] Adding AST support for offline client --- riva/clients/asr/riva_asr_client.cc | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index c9b532f..baeec3c 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -78,6 +78,15 @@ DEFINE_double(stop_threshold, -1., "Threshold value to determine when endpoint d DEFINE_double( stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_string( + src_lang, "", + "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_string( + dest_lang, "", + "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_string( + task, "transcribe", + "Threshold value for likelihood of blanks before detecting end of utterance"); class RecognizeClient { public: @@ -88,7 +97,7 @@ class RecognizeClient { std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history, - int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou) + int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string src_lang, std::string dest_lang, std::string task) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), @@ -99,7 +108,7 @@ class RecognizeClient { verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history), stop_history_eou_(stop_history_eou), - stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou) + stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), src_lang_(src_lang), dest_lang_(dest_lang), task_(task) { if (!output_filename.empty()) { output_file_.open(output_filename); @@ -216,6 +225,12 @@ class RecognizeClient { config->set_enable_separate_recognition_per_channel(separate_recognition_per_channel_); auto custom_config = config->mutable_custom_configuration(); (*custom_config)["test_key"] = "test_value"; + if(src_lang_ != ""){ + (*custom_config)["src_lang"] = src_lang_; + (*custom_config)["dest_lang"] = dest_lang_; + (*custom_config)["task"] = task_; + } + auto speaker_diarization_config = config->mutable_diarization_config(); speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_); @@ -418,6 +433,9 @@ class RecognizeClient { int32_t stop_history_eou_; float stop_threshold_; float stop_threshold_eou_; + std::string src_lang_; + std::string dest_lang_; + std::string task_; }; int @@ -452,6 +470,9 @@ main(int argc, char** argv) str_usage << " --stop_history_eou=" << std::endl; str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; + str_usage << " --src_lang=" << std::endl; + str_usage << " --dest_lang=" << std::endl; + str_usage << " --task=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -499,7 +520,7 @@ main(int argc, char** argv) FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, - FLAGS_stop_threshold_eou); + FLAGS_stop_threshold_eou, FLAGS_src_lang, FLAGS_dest_lang, FLAGS_task); // Preload all wav files, sort by size to reduce tail effects std::vector> all_wav; From 05c45511b10a4cfdb8e6f373dae30d87da9b75ba Mon Sep 17 00:00:00 2001 From: sarane Date: Thu, 22 Aug 2024 22:27:51 +0530 Subject: [PATCH 02/18] Updating cli keywords --- riva/clients/asr/riva_asr_client.cc | 30 ++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index baeec3c..450d183 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -79,14 +79,14 @@ DEFINE_double( stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); DEFINE_string( - src_lang, "", - "Threshold value for likelihood of blanks before detecting end of utterance"); + source_language, "", + "Language of the audio file"); DEFINE_string( - dest_lang, "", - "Threshold value for likelihood of blanks before detecting end of utterance"); + target_language, "", + "Target language for translation"); DEFINE_string( task, "transcribe", - "Threshold value for likelihood of blanks before detecting end of utterance"); + "Task for the model (transcribe/translate)"); class RecognizeClient { public: @@ -97,7 +97,7 @@ class RecognizeClient { std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history, - int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string src_lang, std::string dest_lang, std::string task) + int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string source_language, std::string target_language, std::string task) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), @@ -108,7 +108,7 @@ class RecognizeClient { verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history), stop_history_eou_(stop_history_eou), - stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), src_lang_(src_lang), dest_lang_(dest_lang), task_(task) + stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), source_language_(source_language), target_language_(target_language), task_(task) { if (!output_filename.empty()) { output_file_.open(output_filename); @@ -225,9 +225,9 @@ class RecognizeClient { config->set_enable_separate_recognition_per_channel(separate_recognition_per_channel_); auto custom_config = config->mutable_custom_configuration(); (*custom_config)["test_key"] = "test_value"; - if(src_lang_ != ""){ - (*custom_config)["src_lang"] = src_lang_; - (*custom_config)["dest_lang"] = dest_lang_; + if(source_language_ != ""){ + (*custom_config)["source_language"] = source_language_; + (*custom_config)["target_language"] = target_language_; (*custom_config)["task"] = task_; } @@ -433,8 +433,8 @@ class RecognizeClient { int32_t stop_history_eou_; float stop_threshold_; float stop_threshold_eou_; - std::string src_lang_; - std::string dest_lang_; + std::string source_language_; + std::string target_language_; std::string task_; }; @@ -470,8 +470,8 @@ main(int argc, char** argv) str_usage << " --stop_history_eou=" << std::endl; str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; - str_usage << " --src_lang=" << std::endl; - str_usage << " --dest_lang=" << std::endl; + str_usage << " --source_language=" << std::endl; + str_usage << " --target_language=" << std::endl; str_usage << " --task=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -520,7 +520,7 @@ main(int argc, char** argv) FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, - FLAGS_stop_threshold_eou, FLAGS_src_lang, FLAGS_dest_lang, FLAGS_task); + FLAGS_stop_threshold_eou, FLAGS_source_language, FLAGS_target_language, FLAGS_task); // Preload all wav files, sort by size to reduce tail effects std::vector> all_wav; From a6f4eae0cc15e7c5199daf14a3f9fa603fd4bc78 Mon Sep 17 00:00:00 2001 From: shiralal Date: Wed, 28 Aug 2024 15:57:54 +0530 Subject: [PATCH 03/18] Add DNT (do not translate) support in S2S and S2T clients (#91) --- WORKSPACE | 2 +- riva/clients/asr/riva_asr_client.cc | 2 +- riva/clients/asr/riva_asr_client_helper.cc | 18 +-- riva/clients/asr/riva_asr_client_helper.h | 2 +- .../clients/asr/streaming_recognize_client.cc | 2 +- .../nmt/riva_nmt_streaming_s2s_client.cc | 139 ++++++++++-------- .../nmt/riva_nmt_streaming_s2t_client.cc | 6 +- riva/clients/nmt/streaming_s2s_client.cc | 11 +- riva/clients/nmt/streaming_s2s_client.h | 14 +- riva/clients/nmt/streaming_s2t_client.cc | 10 +- riva/clients/nmt/streaming_s2t_client.h | 9 +- riva/clients/tts/riva_tts_client.cc | 6 +- riva/clients/tts/riva_tts_perf_client.cc | 6 +- 13 files changed, 131 insertions(+), 96 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index f73d9e8..d09b312 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -70,7 +70,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/nvidia-riva/common.git", - commit = "a69334c778faf7ecb9e496b084184b17ddc53f3c" + commit = "a7d342c223106ccb84a812f189178afa6e68c123" ) http_archive( diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index c9b532f..0851b68 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -110,7 +110,7 @@ class RecognizeClient { } } - boosted_phrases_ = ReadBoostedPhrases(boosted_phrases_file); + boosted_phrases_ = ReadPhrasesFromFile(boosted_phrases_file); } ~RecognizeClient() diff --git a/riva/clients/asr/riva_asr_client_helper.cc b/riva/clients/asr/riva_asr_client_helper.cc index 946399c..ebb095d 100644 --- a/riva/clients/asr/riva_asr_client_helper.cc +++ b/riva/clients/asr/riva_asr_client_helper.cc @@ -7,23 +7,23 @@ #include "riva_asr_client_helper.h" std::vector -ReadBoostedPhrases(const std::string& boosted_phrases_file) +ReadPhrasesFromFile(const std::string& phrases_file) { - std::vector boosted_phrases; - if (!boosted_phrases_file.empty()) { - std::ifstream infile(boosted_phrases_file); + std::vector phrases; + if (!phrases_file.empty()) { + std::ifstream infile(phrases_file); if (infile.is_open()) { - std::string boosted_phrase; - while (getline(infile, boosted_phrase)) { - boosted_phrases.push_back(boosted_phrase); + std::string phrase; + while (getline(infile, phrase)) { + phrases.push_back(phrase); } } else { - std::string err = "Could not open file " + boosted_phrases_file; + std::string err = "Could not open file " + phrases_file; throw std::runtime_error(err); } } - return boosted_phrases; + return phrases; } bool diff --git a/riva/clients/asr/riva_asr_client_helper.h b/riva/clients/asr/riva_asr_client_helper.h index 1bff487..082dd3c 100644 --- a/riva/clients/asr/riva_asr_client_helper.h +++ b/riva/clients/asr/riva_asr_client_helper.h @@ -19,7 +19,7 @@ namespace nr = nvidia::riva; namespace nr_asr = nvidia::riva::asr; -std::vector ReadBoostedPhrases(const std::string& boosted_phrases_file); +std::vector ReadPhrasesFromFile(const std::string& phrases_file); bool WaitUntilReady( std::shared_ptr channel, std::chrono::system_clock::time_point& deadline); diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index f229230..021dc83 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -81,7 +81,7 @@ StreamingRecognizeClient::StreamingRecognizeClient( output_file_.open(output_filename); } - boosted_phrases_ = ReadBoostedPhrases(boosted_phrases_file); + boosted_phrases_ = ReadPhrasesFromFile(boosted_phrases_file); } StreamingRecognizeClient::~StreamingRecognizeClient() diff --git a/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc b/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc index fce44eb..6a67f6d 100644 --- a/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc +++ b/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -23,7 +24,6 @@ #include #include #include -#include #include "client_call.h" #include "riva/clients/utils/grpc.h" @@ -55,6 +55,9 @@ DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in f DEFINE_int32(chunk_duration_ms, 100, "Chunk duration in milliseconds"); DEFINE_string(source_language_code, "en-US", "Language code for the input speech"); DEFINE_string(target_language_code, "en-US", "Language code for the output speech"); +DEFINE_string( + dnt_phrases_file, "", + "File with a list of words and phrases to do not translate. One line per word or phrase."); DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); DEFINE_bool( @@ -89,91 +92,109 @@ signal_handler(int signal_num) count++; } -bool is_numeric(const std::string& str) { - if (str.empty()) return false; - size_t pos = str.find_first_not_of("0123456789.-+"); - if (pos != std::string::npos && pos != str.length()) { - return false; - } - try { - std::stod(str); - return true; - } catch (const std::invalid_argument&) { - return false; - } catch (const std::out_of_range&) { - return false; - } +bool +is_numeric(const std::string& str) +{ + if (str.empty()) + return false; + size_t pos = str.find_first_not_of("0123456789.-+"); + if (pos != std::string::npos && pos != str.length()) { + return false; + } + try { + std::stod(str); + return true; + } + catch (const std::invalid_argument&) { + return false; + } + catch (const std::out_of_range&) { + return false; + } } -bool in_range_or_error(std::string numeric_part, double min_value, double max_value, std::string type) { +bool +in_range_or_error(std::string numeric_part, double min_value, double max_value, std::string type) +{ double numeric_value = std::stod(numeric_part); if (numeric_value < min_value || numeric_value > max_value) { - std::cerr << "Value not in range [" << min_value << "," << max_value<< "] for " << type << std::endl; + std::cerr << "Value not in range [" << min_value << "," << max_value << "] for " << type + << std::endl; return false; } return true; } -bool validate_tts_prosody_pitch(std::string &value) { +bool +validate_tts_prosody_pitch(std::string& value) +{ if (value.empty()) { return true; } int len = value.size(); - if (value == "default" || value == "x-low" || value == "low" || value == "medium" || value == "high" || value == "x-high") { + if (value == "default" || value == "x-low" || value == "low" || value == "medium" || + value == "high" || value == "x-high") { return true; } else if ( - (len > 2 && ((value[len-2] == 'H' && value[len-1]=='z') || (value[len-2] == 'h' && value[len-1]=='Z'))) - && is_numeric(value.substr(0, len-2)) - && in_range_or_error(value.substr(0, len-2), -150.0, 150.0, "tts_prosody_pitch")) { + (len > 2 && ((value[len - 2] == 'H' && value[len - 1] == 'z') || + (value[len - 2] == 'h' && value[len - 1] == 'Z'))) && + is_numeric(value.substr(0, len - 2)) && + in_range_or_error(value.substr(0, len - 2), -150.0, 150.0, "tts_prosody_pitch")) { return true; } else if (is_numeric(value) && in_range_or_error(value, -3, 3, "tts_prosody_pitch")) { return true; - } - + } + std::cerr << "Invalid value for tts_prosody_pitch: " << value << std::endl; return false; } -bool validate_tts_prosody_rate(std::string &value) { - if (value.empty()) { - return true; - } +bool +validate_tts_prosody_rate(std::string& value) +{ + if (value.empty()) { + return true; + } - int len = value.size(); - if (value == "default" || value == "x-low" || value == "low" || value == "medium" || - value == "high" || value == "x-high") { - return true; - } else if (len > 1 && value[len - 1] == '%' && is_numeric(value.substr(0, len - 1)) && - in_range_or_error(value.substr(0, len - 1), 25.0, 250.0, "tts_prosody_rate")) { - return true; - } else if (is_numeric(value) && in_range_or_error(value, 25.0, 250.0, "tts_prosody_rate")) { - return true; - } + int len = value.size(); + if (value == "default" || value == "x-low" || value == "low" || value == "medium" || + value == "high" || value == "x-high") { + return true; + } else if ( + len > 1 && value[len - 1] == '%' && is_numeric(value.substr(0, len - 1)) && + in_range_or_error(value.substr(0, len - 1), 25.0, 250.0, "tts_prosody_rate")) { + return true; + } else if (is_numeric(value) && in_range_or_error(value, 25.0, 250.0, "tts_prosody_rate")) { + return true; + } - std::cerr << "Invalid value for tts_prosody_rate: " << value << std::endl; - return false; + std::cerr << "Invalid value for tts_prosody_rate: " << value << std::endl; + return false; } -bool validate_tts_prosody_volume(std::string &value) { - if (value.empty()) { - return true; - } +bool +validate_tts_prosody_volume(std::string& value) +{ + if (value.empty()) { + return true; + } - int len = value.size(); - if (value == "default" || value == "silent" || value == "x-soft" || value == "soft" || - value == "medium" || value == "loud" || value == "x-loud") { - return true; - } else if (len >= 2 && (value[len - 2] == 'd' && value[len - 1] == 'B') && - is_numeric(value.substr(0, len - 2)) && - in_range_or_error(value.substr(0, len - 2), -13.0, 8.0, "tts_prosody_volume")) { - return true; - } else if (is_numeric(value) && in_range_or_error(value, -13.0, 8.0, "tts_prosody_volume")) { - return true; - } + int len = value.size(); + if (value == "default" || value == "silent" || value == "x-soft" || value == "soft" || + value == "medium" || value == "loud" || value == "x-loud") { + return true; + } else if ( + len >= 2 && (value[len - 2] == 'd' && value[len - 1] == 'B') && + is_numeric(value.substr(0, len - 2)) && + in_range_or_error(value.substr(0, len - 2), -13.0, 8.0, "tts_prosody_volume")) { + return true; + } else if (is_numeric(value) && in_range_or_error(value, -13.0, 8.0, "tts_prosody_volume")) { + return true; + } - std::cerr << "Invalid value for tts_prosody_volume: " << value << std::endl; - return false; + std::cerr << "Invalid value for tts_prosody_volume: " << value << std::endl; + return false; } main(int argc, char** argv) @@ -197,6 +218,7 @@ main(int argc, char** argv) << std::endl; str_usage << " --target_language_code=" << std::endl; + str_usage << " --dnt_phrases_file=" << std::endl; str_usage << " --boosted_words_file=" << std::endl; str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; @@ -259,7 +281,8 @@ main(int argc, char** argv) StreamingS2SClient recognize_client( grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code, - FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_automatic_punctuation, + FLAGS_target_language_code, FLAGS_dnt_phrases_file, FLAGS_profanity_filter, + FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_chunk_duration_ms, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_tts_encoding, FLAGS_tts_audio_file, FLAGS_tts_sample_rate, FLAGS_tts_voice_name, diff --git a/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc b/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc index af002db..b717622 100644 --- a/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc +++ b/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc @@ -54,6 +54,8 @@ DEFINE_int32(num_parallel_requests, 1, "Number of parallel requests to keep in f DEFINE_int32(chunk_duration_ms, 100, "Chunk duration in milliseconds"); DEFINE_string(source_language_code, "en-US", "Language code for the input speech"); DEFINE_string(target_language_code, "en-US", "Language code for the output text"); +DEFINE_string( + dnt_phrases_file, "", "File with a list of words to do not translate. One line per word."); DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); DEFINE_bool( @@ -103,6 +105,7 @@ main(int argc, char** argv) << std::endl; str_usage << " --target_language_code=" << std::endl; + str_usage << " --dnt_phrases_file=" << std::endl; str_usage << " --boosted_words_file=" << std::endl; str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; @@ -147,7 +150,8 @@ main(int argc, char** argv) StreamingS2TClient recognize_client( grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code, - FLAGS_target_language_code, FLAGS_profanity_filter, FLAGS_automatic_punctuation, + FLAGS_target_language_code, FLAGS_dnt_phrases_file, FLAGS_profanity_filter, + FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_chunk_duration_ms, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_nmt_text_file); diff --git a/riva/clients/nmt/streaming_s2s_client.cc b/riva/clients/nmt/streaming_s2s_client.cc index 5648167..5c01b7c 100644 --- a/riva/clients/nmt/streaming_s2s_client.cc +++ b/riva/clients/nmt/streaming_s2s_client.cc @@ -53,9 +53,9 @@ MicrophoneThreadMain( StreamingS2SClient::StreamingS2SClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, + const std::string& dnt_phrases_file, bool profanity_filter, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, const std::string& tts_encoding, const std::string& tts_audio_file, int tts_sample_rate, const std::string& tts_voice_name, std::string& tts_prosody_rate, std::string& tts_prosody_pitch, std::string& tts_prosody_volume) @@ -74,7 +74,8 @@ StreamingS2SClient::StreamingS2SClient( num_streams_finished_.store(0); thread_pool_.reset(new ThreadPool(4 * num_parallel_requests)); - boosted_phrases_ = ReadBoostedPhrases(boosted_phrases_file); + boosted_phrases_ = ReadPhrasesFromFile(boosted_phrases_file); + dnt_phrases_ = ReadPhrasesFromFile(dnt_phrases_file); } StreamingS2SClient::~StreamingS2SClient() {} @@ -115,6 +116,7 @@ StreamingS2SClient::GenerateRequests(std::shared_ptr call) auto translation_config = streaming_s2s_config->mutable_translation_config(); translation_config->set_source_language_code(source_language_code_); translation_config->set_target_language_code(target_language_code_); + *(translation_config->mutable_dnt_phrases()) = {dnt_phrases_.begin(), dnt_phrases_.end()}; // set tts config auto tts_config = streaming_s2s_config->mutable_tts_config(); @@ -372,6 +374,7 @@ StreamingS2SClient::DoStreamingFromMicrophone(const std::string& audio_device, b auto translation_config = s2s_config->mutable_translation_config(); translation_config->set_source_language_code(source_language_code_); translation_config->set_target_language_code(target_language_code_); + // set tts config auto tts_config = s2s_config->mutable_tts_config(); if (tts_encoding_.empty() || tts_encoding_ == "pcm") { diff --git a/riva/clients/nmt/streaming_s2s_client.h b/riva/clients/nmt/streaming_s2s_client.h index 153e740..5e7927f 100644 --- a/riva/clients/nmt/streaming_s2s_client.h +++ b/riva/clients/nmt/streaming_s2s_client.h @@ -50,12 +50,13 @@ class StreamingS2SClient { StreamingS2SClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code_, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, - const std::string& tts_encoding, const std::string& tts_audio_file, int tts_sample_rate, - const std::string& tts_voice_name, std::string& tts_prosody_rate, - std::string& tts_prosody_pitch, std::string& tts_prosody_volume); + const std::string& dnt_words_file, bool profanity_filter, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, const std::string& tts_encoding, + const std::string& tts_audio_file, int tts_sample_rate, const std::string& tts_voice_name, + std::string& tts_prosody_rate, std::string& tts_prosody_pitch, + std::string& tts_prosody_volume); ~StreamingS2SClient(); @@ -93,6 +94,7 @@ class StreamingS2SClient { std::string tts_voice_name_; std::string source_language_code_; std::string target_language_code_; + std::vector dnt_phrases_; int tts_sample_rate_; bool profanity_filter_; diff --git a/riva/clients/nmt/streaming_s2t_client.cc b/riva/clients/nmt/streaming_s2t_client.cc index b62f132..96ccb58 100644 --- a/riva/clients/nmt/streaming_s2t_client.cc +++ b/riva/clients/nmt/streaming_s2t_client.cc @@ -53,9 +53,9 @@ MicrophoneThreadMain( StreamingS2TClient::StreamingS2TClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, + const std::string& dnt_phrases_file, bool profanity_filter, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, const std::string& nmt_text_file) : stub_(nr_nmt::RivaTranslation::NewStub(channel)), source_language_code_(source_language_code), target_language_code_(target_language_code), profanity_filter_(profanity_filter), @@ -69,7 +69,8 @@ StreamingS2TClient::StreamingS2TClient( num_streams_finished_.store(0); thread_pool_.reset(new ThreadPool(4 * num_parallel_requests)); - boosted_phrases_ = ReadBoostedPhrases(boosted_phrases_file); + boosted_phrases_ = ReadPhrasesFromFile(boosted_phrases_file); + dnt_phrases_ = ReadPhrasesFromFile(dnt_phrases_file); } StreamingS2TClient::~StreamingS2TClient() {} @@ -111,6 +112,7 @@ StreamingS2TClient::GenerateRequests(std::shared_ptr call) auto translation_config = streaming_s2t_config->mutable_translation_config(); translation_config->set_source_language_code(source_language_code_); translation_config->set_target_language_code(target_language_code_); + *(translation_config->mutable_dnt_phrases()) = {dnt_phrases_.begin(), dnt_phrases_.end()}; // set asr config auto streaming_asr_config = streaming_s2t_config->mutable_asr_config(); diff --git a/riva/clients/nmt/streaming_s2t_client.h b/riva/clients/nmt/streaming_s2t_client.h index 931e5d9..7752eac 100644 --- a/riva/clients/nmt/streaming_s2t_client.h +++ b/riva/clients/nmt/streaming_s2t_client.h @@ -49,10 +49,10 @@ class StreamingS2TClient { StreamingS2TClient( std::shared_ptr channel, int32_t num_parallel_requests, const std::string& source_language_code, const std::string& target_language_code, - bool profanity_filter, bool automatic_punctuation, bool separate_recognition_per_channel, - int32_t chunk_duration_ms, bool simulate_realtime, bool verbatim_transcripts, - const std::string& boosted_phrases_file, float boosted_phrases_score, - const std::string& nmt_text_file); + const std::string& dnt_words_file, bool profanity_filter, bool automatic_punctuation, + bool separate_recognition_per_channel, int32_t chunk_duration_ms, bool simulate_realtime, + bool verbatim_transcripts, const std::string& boosted_phrases_file, + float boosted_phrases_score, const std::string& nmt_text_file); ~StreamingS2TClient(); @@ -89,6 +89,7 @@ class StreamingS2TClient { std::string source_language_code_; std::string target_language_code_; + std::vector dnt_phrases_; bool profanity_filter_; int32_t channels_; bool automatic_punctuation_; diff --git a/riva/clients/tts/riva_tts_client.cc b/riva/clients/tts/riva_tts_client.cc index f26088d..83da197 100644 --- a/riva/clients/tts/riva_tts_client.cc +++ b/riva/clients/tts/riva_tts_client.cc @@ -12,8 +12,8 @@ #include #include #include -#include #include +#include #include "riva/clients/utils/grpc.h" #include "riva/proto/riva_tts.grpc.pb.h" @@ -52,7 +52,7 @@ DEFINE_string(custom_dictionary, "", " User dictionary containing graph-to-phone static const std::string LC_enUS = "en-US"; -std::string +std::string ReadUserDictionaryFile(const std::string& dictionary_file) { std::string dictionary_string; @@ -69,7 +69,7 @@ ReadUserDictionaryFile(const std::string& dictionary_file) if (pos != std::string::npos) { std::string key = line.substr(0, pos); - std::string value = std::regex_replace(line.substr(pos+2), std::regex("^ +"), ""); + std::string value = std::regex_replace(line.substr(pos + 2), std::regex("^ +"), ""); // Append the key-value pair to the dictionary string if (!dictionary_string.empty()) { dictionary_string += ","; diff --git a/riva/clients/tts/riva_tts_perf_client.cc b/riva/clients/tts/riva_tts_perf_client.cc index bcd7621..86507f2 100644 --- a/riva/clients/tts/riva_tts_perf_client.cc +++ b/riva/clients/tts/riva_tts_perf_client.cc @@ -14,10 +14,10 @@ #include #include #include +#include #include #include #include -#include #include "riva/clients/utils/grpc.h" #include "riva/proto/riva_tts.grpc.pb.h" @@ -71,7 +71,7 @@ CreateTTS(std::shared_ptr channel) return tts; } -std::string +std::string ReadUserDictionaryFile(const std::string& dictionary_file) { std::string dictionary_string; @@ -88,7 +88,7 @@ ReadUserDictionaryFile(const std::string& dictionary_file) if (pos != std::string::npos) { std::string key = line.substr(0, pos); - std::string value = std::regex_replace(line.substr(pos+2), std::regex("^ +"), ""); + std::string value = std::regex_replace(line.substr(pos + 2), std::regex("^ +"), ""); // Append the key-value pair to the dictionary string if (!dictionary_string.empty()) { dictionary_string += ","; From 9a8b03bc9a41d0e24264a62323dd3fd1897ca392 Mon Sep 17 00:00:00 2001 From: rmittal-github <61574997+rmittal-github@users.noreply.github.com> Date: Wed, 28 Aug 2024 17:52:33 +0530 Subject: [PATCH 04/18] Add list_models option for s2s/s2t clients (#93) --- riva/clients/nmt/riva_nmt_streaming_s2s_client.cc | 15 +++++++++++++++ riva/clients/nmt/riva_nmt_streaming_s2t_client.cc | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc b/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc index 6a67f6d..984ad8e 100644 --- a/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc +++ b/riva/clients/nmt/riva_nmt_streaming_s2s_client.cc @@ -58,6 +58,7 @@ DEFINE_string(target_language_code, "en-US", "Language code for the output speec DEFINE_string( dnt_phrases_file, "", "File with a list of words and phrases to do not translate. One line per word or phrase."); +DEFINE_bool(list_models, false, "List available models on server"); DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); DEFINE_bool( @@ -219,6 +220,7 @@ main(int argc, char** argv) str_usage << " --target_language_code=" << std::endl; str_usage << " --dnt_phrases_file=" << std::endl; + str_usage << " --list_models" << std::endl; str_usage << " --boosted_words_file=" << std::endl; str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; @@ -267,6 +269,19 @@ main(int argc, char** argv) return 1; } + if (FLAGS_list_models) { + std::unique_ptr nmt_s2s( + nr_nmt::RivaTranslation::NewStub(grpc_channel)); + grpc::ClientContext context; + nr_nmt::AvailableLanguageRequest request; + nr_nmt::AvailableLanguageResponse response; + + request.set_model("s2s_model"); // this is optional, if empty returns all available models/languages + nmt_s2s->ListSupportedLanguagePairs(&context, request, &response); + std::cout << response.DebugString() << std::endl; + return 0; + } + if (!FLAGS_tts_encoding.empty() && FLAGS_tts_encoding != "pcm" && FLAGS_tts_encoding != "opus") { std::cerr << "Unsupported encoding: \'" << FLAGS_tts_encoding << "\'" << std::endl; return -1; diff --git a/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc b/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc index b717622..7173aff 100644 --- a/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc +++ b/riva/clients/nmt/riva_nmt_streaming_s2t_client.cc @@ -56,6 +56,7 @@ DEFINE_string(source_language_code, "en-US", "Language code for the input speech DEFINE_string(target_language_code, "en-US", "Language code for the output text"); DEFINE_string( dnt_phrases_file, "", "File with a list of words to do not translate. One line per word."); +DEFINE_bool(list_models, false, "List available models on server"); DEFINE_string(boosted_words_file, "", "File with a list of words to boost. One line per word."); DEFINE_double(boosted_words_score, 10., "Score by which to boost the boosted words"); DEFINE_bool( @@ -106,6 +107,7 @@ main(int argc, char** argv) str_usage << " --target_language_code=" << std::endl; str_usage << " --dnt_phrases_file=" << std::endl; + str_usage << " --list_models" << std::endl; str_usage << " --boosted_words_file=" << std::endl; str_usage << " --boosted_words_score=" << std::endl; str_usage << " --ssl_cert=" << std::endl; @@ -148,6 +150,19 @@ main(int argc, char** argv) return 1; } + if (FLAGS_list_models) { + std::unique_ptr nmt_s2t( + nr_nmt::RivaTranslation::NewStub(grpc_channel)); + grpc::ClientContext context; + nr_nmt::AvailableLanguageRequest request; + nr_nmt::AvailableLanguageResponse response; + + request.set_model("s2t_model"); // this is optional, if empty returns all available models/languages + nmt_s2t->ListSupportedLanguagePairs(&context, request, &response); + std::cout << response.DebugString() << std::endl; + return 0; + } + StreamingS2TClient recognize_client( grpc_channel, FLAGS_num_parallel_requests, FLAGS_source_language_code, FLAGS_target_language_code, FLAGS_dnt_phrases_file, FLAGS_profanity_filter, From 6019c70037e29dcce8c7f6d7d5fa9d68ae1a4a02 Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 30 Aug 2024 14:43:34 +0530 Subject: [PATCH 05/18] exposing custom_configuration to cli --- riva/clients/asr/riva_asr_client.cc | 33 +++++++------------ riva/clients/asr/riva_asr_client_helper.cc | 18 ++++++++++ riva/clients/asr/riva_asr_client_helper.h | 3 ++ riva/clients/asr/riva_streaming_asr_client.cc | 6 +++- .../clients/asr/streaming_recognize_client.cc | 8 +++-- riva/clients/asr/streaming_recognize_client.h | 3 +- 6 files changed, 45 insertions(+), 26 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 450d183..1605928 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -79,14 +79,8 @@ DEFINE_double( stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); DEFINE_string( - source_language, "", - "Language of the audio file"); -DEFINE_string( - target_language, "", - "Target language for translation"); -DEFINE_string( - task, "transcribe", - "Task for the model (transcribe/translate)"); + custom_configuration, "", + "Add custom configurations to be sent to the custom backends. "); class RecognizeClient { public: @@ -97,7 +91,7 @@ class RecognizeClient { std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history, - int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string source_language, std::string target_language, std::string task) + int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string custom_configuration) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), @@ -108,7 +102,7 @@ class RecognizeClient { verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history), stop_history_eou_(stop_history_eou), - stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), source_language_(source_language), target_language_(target_language), task_(task) + stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration) { if (!output_filename.empty()) { output_file_.open(output_filename); @@ -224,12 +218,11 @@ class RecognizeClient { config->set_verbatim_transcripts(verbatim_transcripts_); config->set_enable_separate_recognition_per_channel(separate_recognition_per_channel_); auto custom_config = config->mutable_custom_configuration(); - (*custom_config)["test_key"] = "test_value"; - if(source_language_ != ""){ - (*custom_config)["source_language"] = source_language_; - (*custom_config)["target_language"] = target_language_; - (*custom_config)["task"] = task_; + std::unordered_map custom_configuration_map = ReadCustomConfiguration(custom_configuration_); + for (auto& it: custom_configuration_map) { + (*custom_config)[it.first] = it.second; } + (*custom_config)["test_key"] = "test_value"; auto speaker_diarization_config = config->mutable_diarization_config(); @@ -433,9 +426,7 @@ class RecognizeClient { int32_t stop_history_eou_; float stop_threshold_; float stop_threshold_eou_; - std::string source_language_; - std::string target_language_; - std::string task_; + std::string custom_configuration_; }; int @@ -470,9 +461,7 @@ main(int argc, char** argv) str_usage << " --stop_history_eou=" << std::endl; str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; - str_usage << " --source_language=" << std::endl; - str_usage << " --target_language=" << std::endl; - str_usage << " --task=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -520,7 +509,7 @@ main(int argc, char** argv) FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, - FLAGS_stop_threshold_eou, FLAGS_source_language, FLAGS_target_language, FLAGS_task); + FLAGS_stop_threshold_eou, FLAGS_custom_configuration); // Preload all wav files, sort by size to reduce tail effects std::vector> all_wav; diff --git a/riva/clients/asr/riva_asr_client_helper.cc b/riva/clients/asr/riva_asr_client_helper.cc index 946399c..de13f1d 100644 --- a/riva/clients/asr/riva_asr_client_helper.cc +++ b/riva/clients/asr/riva_asr_client_helper.cc @@ -189,4 +189,22 @@ PrintResult( std::cout << "Audio processed: " << output_result.audio_processed << " sec." << std::endl; std::cout << "-----------------------------------------------------------" << std::endl; std::cout << std::endl; +} + +std::unordered_map ReadCustomConfiguration(std::string& custom_configuration){ + custom_configuration = absl::StrReplaceAll(custom_configuration, {{" ", ""}}); + std::unordered_map custom_configuration_map; + // Split the input string by commas to get key-value pairs + std::vector pairs = absl::StrSplit(custom_configuration, ','); + for (const auto& pair : pairs) { + // Split each pair by colon to separate the key and value + std::vector key_value = absl::StrSplit(pair, absl::ByString(":")); + if (key_value.size() == 2) { + if (custom_configuration_map.find(std::string(key_value[0])) == custom_configuration_map.end()) { + // If the key does not exist, insert the new key-value pair + custom_configuration_map[std::string(key_value[0])] = std::string(key_value[1]); + } + } + } + return custom_configuration_map; } \ No newline at end of file diff --git a/riva/clients/asr/riva_asr_client_helper.h b/riva/clients/asr/riva_asr_client_helper.h index 1bff487..ea65b79 100644 --- a/riva/clients/asr/riva_asr_client_helper.h +++ b/riva/clients/asr/riva_asr_client_helper.h @@ -14,6 +14,7 @@ #include #include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "riva/proto/riva_asr.grpc.pb.h" namespace nr = nvidia::riva; @@ -51,3 +52,5 @@ void AppendResult( void PrintResult( Results& output_result, const std::string& filename, bool word_time_offsets, bool speaker_diarization); + +std::unordered_map ReadCustomConfiguration(std::string& custom_configuration); \ No newline at end of file diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 5e27c38..66a4d00 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -89,6 +89,9 @@ DEFINE_int32( DEFINE_double( stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_string( + custom_configuration, "", + "Add custom configurations to be sent to the custom backends. "); void signal_handler(int signal_num) @@ -137,6 +140,7 @@ main(int argc, char** argv) str_usage << " --stop_history_eou=" << std::endl; str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); @@ -185,7 +189,7 @@ main(int argc, char** argv) FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, - FLAGS_stop_threshold, FLAGS_stop_threshold_eou); + FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration); if (FLAGS_audio_file.size()) { return recognize_client.DoStreamingFromFile( diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index f229230..79b42df 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -59,7 +59,7 @@ StreamingRecognizeClient::StreamingRecognizeClient( std::string output_filename, std::string model_name, bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, - float stop_threshold, float stop_threshold_eou) + float stop_threshold, float stop_threshold_eou, std::string custom_configuration) : print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), @@ -71,7 +71,7 @@ StreamingRecognizeClient::StreamingRecognizeClient( verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history), stop_history_eou_(stop_history_eou), stop_threshold_(stop_threshold), - stop_threshold_eou_(stop_threshold_eou) + stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration) { num_active_streams_.store(0); num_streams_finished_.store(0); @@ -165,6 +165,10 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr call) config->set_enable_automatic_punctuation(automatic_punctuation_); config->set_enable_separate_recognition_per_channel(separate_recognition_per_channel_); auto custom_config = config->mutable_custom_configuration(); + std::unordered_map custom_configuration_map = ReadCustomConfiguration(custom_configuration_); + for (auto& it: custom_configuration_map) { + (*custom_config)[it.first] = it.second; + } (*custom_config)["test_key"] = "test_value"; config->set_verbatim_transcripts(verbatim_transcripts_); if (model_name_ != "") { diff --git a/riva/clients/asr/streaming_recognize_client.h b/riva/clients/asr/streaming_recognize_client.h index 98c4add..56d66bc 100644 --- a/riva/clients/asr/streaming_recognize_client.h +++ b/riva/clients/asr/streaming_recognize_client.h @@ -49,7 +49,7 @@ class StreamingRecognizeClient { bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, float stop_threshold, - float stop_threshold_eou); + float stop_threshold_eou, std::string custom_configuration); ~StreamingRecognizeClient(); @@ -125,4 +125,5 @@ class StreamingRecognizeClient { int32_t stop_history_eou_; float stop_threshold_; float stop_threshold_eou_; + std::string custom_configuration_; }; \ No newline at end of file From 92013447e8fd2320462df55a261ea98b4144ccc9 Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 30 Aug 2024 14:49:37 +0530 Subject: [PATCH 06/18] Code formatting --- riva/clients/asr/riva_asr_client.cc | 15 +++++++++------ riva/clients/asr/riva_asr_client_helper.cc | 7 +++++-- riva/clients/asr/riva_asr_client_helper.h | 3 ++- riva/clients/asr/streaming_recognize_client.cc | 7 ++++--- 4 files changed, 20 insertions(+), 12 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 1605928..8ab50f9 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -91,7 +91,8 @@ class RecognizeClient { std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, bool speaker_diarization, int32_t start_history, float start_threshold, int32_t stop_history, - int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, std::string custom_configuration) + int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, + std::string custom_configuration) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), @@ -102,7 +103,8 @@ class RecognizeClient { verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history), stop_history_eou_(stop_history_eou), - stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration) + stop_threshold_(stop_threshold), stop_threshold_eou_(stop_threshold_eou), + custom_configuration_(custom_configuration) { if (!output_filename.empty()) { output_file_.open(output_filename); @@ -218,12 +220,13 @@ class RecognizeClient { config->set_verbatim_transcripts(verbatim_transcripts_); config->set_enable_separate_recognition_per_channel(separate_recognition_per_channel_); auto custom_config = config->mutable_custom_configuration(); - std::unordered_map custom_configuration_map = ReadCustomConfiguration(custom_configuration_); - for (auto& it: custom_configuration_map) { - (*custom_config)[it.first] = it.second; + std::unordered_map custom_configuration_map = + ReadCustomConfiguration(custom_configuration_); + for (auto& it : custom_configuration_map) { + (*custom_config)[it.first] = it.second; } (*custom_config)["test_key"] = "test_value"; - + auto speaker_diarization_config = config->mutable_diarization_config(); speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_); diff --git a/riva/clients/asr/riva_asr_client_helper.cc b/riva/clients/asr/riva_asr_client_helper.cc index de13f1d..283647d 100644 --- a/riva/clients/asr/riva_asr_client_helper.cc +++ b/riva/clients/asr/riva_asr_client_helper.cc @@ -191,7 +191,9 @@ PrintResult( std::cout << std::endl; } -std::unordered_map ReadCustomConfiguration(std::string& custom_configuration){ +std::unordered_map +ReadCustomConfiguration(std::string& custom_configuration) +{ custom_configuration = absl::StrReplaceAll(custom_configuration, {{" ", ""}}); std::unordered_map custom_configuration_map; // Split the input string by commas to get key-value pairs @@ -200,7 +202,8 @@ std::unordered_map ReadCustomConfiguration(std::string // Split each pair by colon to separate the key and value std::vector key_value = absl::StrSplit(pair, absl::ByString(":")); if (key_value.size() == 2) { - if (custom_configuration_map.find(std::string(key_value[0])) == custom_configuration_map.end()) { + if (custom_configuration_map.find(std::string(key_value[0])) == + custom_configuration_map.end()) { // If the key does not exist, insert the new key-value pair custom_configuration_map[std::string(key_value[0])] = std::string(key_value[1]); } diff --git a/riva/clients/asr/riva_asr_client_helper.h b/riva/clients/asr/riva_asr_client_helper.h index ea65b79..576b2b3 100644 --- a/riva/clients/asr/riva_asr_client_helper.h +++ b/riva/clients/asr/riva_asr_client_helper.h @@ -53,4 +53,5 @@ void PrintResult( Results& output_result, const std::string& filename, bool word_time_offsets, bool speaker_diarization); -std::unordered_map ReadCustomConfiguration(std::string& custom_configuration); \ No newline at end of file +std::unordered_map ReadCustomConfiguration( + std::string& custom_configuration); diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index 79b42df..73f39f7 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -165,9 +165,10 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr call) config->set_enable_automatic_punctuation(automatic_punctuation_); config->set_enable_separate_recognition_per_channel(separate_recognition_per_channel_); auto custom_config = config->mutable_custom_configuration(); - std::unordered_map custom_configuration_map = ReadCustomConfiguration(custom_configuration_); - for (auto& it: custom_configuration_map) { - (*custom_config)[it.first] = it.second; + std::unordered_map custom_configuration_map = + ReadCustomConfiguration(custom_configuration_); + for (auto& it : custom_configuration_map) { + (*custom_config)[it.first] = it.second; } (*custom_config)["test_key"] = "test_value"; config->set_verbatim_transcripts(verbatim_transcripts_); From 014180a139c85ff7d31cec7d44d0f1ff513ee0fd Mon Sep 17 00:00:00 2001 From: sarane Date: Mon, 2 Sep 2024 12:27:05 +0530 Subject: [PATCH 07/18] Throwing expection for invalid custom configurations --- riva/clients/asr/riva_asr_client.cc | 4 +--- riva/clients/asr/riva_asr_client_helper.cc | 21 +++++++++++++------ riva/clients/asr/riva_streaming_asr_client.cc | 2 +- .../clients/asr/streaming_recognize_client.cc | 1 - 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index 8ab50f9..ed49623 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -80,7 +80,7 @@ DEFINE_double( "Threshold value for likelihood of blanks before detecting end of utterance"); DEFINE_string( custom_configuration, "", - "Add custom configurations to be sent to the custom backends. "); + "Add custom configurations to be sent to the custom backends. "); class RecognizeClient { public: @@ -225,8 +225,6 @@ class RecognizeClient { for (auto& it : custom_configuration_map) { (*custom_config)[it.first] = it.second; } - (*custom_config)["test_key"] = "test_value"; - auto speaker_diarization_config = config->mutable_diarization_config(); speaker_diarization_config->set_enable_speaker_diarization(speaker_diarization_); diff --git a/riva/clients/asr/riva_asr_client_helper.cc b/riva/clients/asr/riva_asr_client_helper.cc index 283647d..6fc4e54 100644 --- a/riva/clients/asr/riva_asr_client_helper.cc +++ b/riva/clients/asr/riva_asr_client_helper.cc @@ -197,15 +197,24 @@ ReadCustomConfiguration(std::string& custom_configuration) custom_configuration = absl::StrReplaceAll(custom_configuration, {{" ", ""}}); std::unordered_map custom_configuration_map; // Split the input string by commas to get key-value pairs + std::vector pairs = absl::StrSplit(custom_configuration, ','); for (const auto& pair : pairs) { // Split each pair by colon to separate the key and value - std::vector key_value = absl::StrSplit(pair, absl::ByString(":")); - if (key_value.size() == 2) { - if (custom_configuration_map.find(std::string(key_value[0])) == - custom_configuration_map.end()) { - // If the key does not exist, insert the new key-value pair - custom_configuration_map[std::string(key_value[0])] = std::string(key_value[1]); + if (pair != "") { + std::vector key_value = absl::StrSplit(pair, absl::ByString(":")); + if (key_value.size() == 2) { + if (custom_configuration_map.find(std::string(key_value[0])) == + custom_configuration_map.end()) { + // If the key does not exist, insert the new key-value pair + custom_configuration_map[std::string(key_value[0])] = std::string(key_value[1]); + } else { + std::string err = "custom_configuration key already used " + std::string(key_value[0]); + throw std::runtime_error(err); + } + } else { + std::string err = "Invalid custom_configuration key:value pair " + std::string(pair); + throw std::runtime_error(err); } } } diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 66a4d00..07e3c5b 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -91,7 +91,7 @@ DEFINE_double( "Threshold value for likelihood of blanks before detecting end of utterance"); DEFINE_string( custom_configuration, "", - "Add custom configurations to be sent to the custom backends. "); + "Add custom configurations to be sent to the custom backends. "); void signal_handler(int signal_num) diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index 73f39f7..98c4c5e 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -170,7 +170,6 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr call) for (auto& it : custom_configuration_map) { (*custom_config)[it.first] = it.second; } - (*custom_config)["test_key"] = "test_value"; config->set_verbatim_transcripts(verbatim_transcripts_); if (model_name_ != "") { config->set_model(model_name_); From b92668c955ea4c61672393f10ad2ddedcf8e5b37 Mon Sep 17 00:00:00 2001 From: sarane Date: Mon, 2 Sep 2024 14:44:35 +0530 Subject: [PATCH 08/18] Updating help message --- riva/clients/asr/riva_asr_client.cc | 4 ++-- riva/clients/asr/riva_streaming_asr_client.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index d0af223..2834339 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -80,7 +80,7 @@ DEFINE_double( "Threshold value for likelihood of blanks before detecting end of utterance"); DEFINE_string( custom_configuration, "", - "Add custom configurations to be sent to the custom backends. "); + "Custom configurations to be sent to the server as key value pairs "); class RecognizeClient { public: @@ -462,7 +462,7 @@ main(int argc, char** argv) str_usage << " --stop_history_eou=" << std::endl; str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; - str_usage << " --custom_configuration=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 07e3c5b..45af156 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -91,7 +91,7 @@ DEFINE_double( "Threshold value for likelihood of blanks before detecting end of utterance"); DEFINE_string( custom_configuration, "", - "Add custom configurations to be sent to the custom backends. "); + "Custom configurations to be sent to the server as key value pairs "); void signal_handler(int signal_num) @@ -140,7 +140,7 @@ main(int argc, char** argv) str_usage << " --stop_history_eou=" << std::endl; str_usage << " --stop_threshold=" << std::endl; str_usage << " --stop_threshold_eou=" << std::endl; - str_usage << " --custom_configuration=" << std::endl; + str_usage << " --custom_configuration=" << std::endl; gflags::SetUsageMessage(str_usage.str()); gflags::SetVersionString(::riva::utils::kBuildScmRevision); From 6131121fda820fed9f8c875d496551ae58103799 Mon Sep 17 00:00:00 2001 From: sarane Date: Mon, 2 Sep 2024 15:24:19 +0530 Subject: [PATCH 09/18] Updating client test --- riva/clients/asr/streaming_recognize_client_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riva/clients/asr/streaming_recognize_client_test.cc b/riva/clients/asr/streaming_recognize_client_test.cc index 35d22ef..2ff2af4 100644 --- a/riva/clients/asr/streaming_recognize_client_test.cc +++ b/riva/clients/asr/streaming_recognize_client_test.cc @@ -20,7 +20,7 @@ TEST(StreamingRecognizeClient, num_responses_requests) StreamingRecognizeClient recognize_client( grpc_channel, 1, "en-US", 1, false, false, false, false, false, 800, false, "dummy.txt", - "dummy", true, true, "", 10., 10, 0.98, 10, 8, 0.98, 0.98); + "dummy", true, true, "", 10., 10, 0.98, 10, 8, 0.98, 0.98, "test_key:test_value"); std::shared_ptr call = std::make_shared(1, true); uint32_t num_sends = 10; From 2e0595dac28464339ba069a9e88d91202ab97a40 Mon Sep 17 00:00:00 2001 From: sarane Date: Mon, 9 Sep 2024 11:24:41 +0530 Subject: [PATCH 10/18] Getting VAD states --- riva/clients/asr/client_call.cc | 76 +++++++++++++++++++-------------- 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index 9ceaaf0..46fa6b6 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -16,42 +16,56 @@ ClientCall::ClientCall(uint32_t corr_id, bool word_time_offsets) void ClientCall::AppendResult(const nr_asr::StreamingRecognitionResult& result) { - bool is_final = result.is_final(); - if (latest_result_.final_transcripts.size() < 1) { - latest_result_.final_transcripts.resize(1); - latest_result_.final_transcripts[0] = ""; - } - - if (is_final) { - int num_alternatives = result.alternatives_size(); - latest_result_.final_transcripts.resize(num_alternatives); - latest_result_.final_scores.resize(num_alternatives); - latest_result_.final_time_stamps.resize(num_alternatives); - for (int a = 0; a < num_alternatives; ++a) { - // Append to transcript - latest_result_.final_transcripts[a] += result.alternatives(a).transcript(); - latest_result_.final_scores[a] += result.alternatives(a).confidence(); + if (result.has_pipeline_states()) { + if (latest_result_.final_transcripts.size() == 0) { + latest_result_.final_transcripts.resize(1); + latest_result_.final_transcripts[0] = ""; + } + auto pipeline_states = result.pipeline_states(); + int prob_states_count = pipeline_states.vad_probabilities_size(); + std::string log = ""; + for (int i = 0; i < prob_states_count; i++) { + log += std::to_string(pipeline_states.vad_probabilities(i)) + " "; + } + VLOG(1) << "VAD states: " << log; + } else { + bool is_final = result.is_final(); + if (latest_result_.final_transcripts.size() < 1) { + latest_result_.final_transcripts.resize(1); + latest_result_.final_transcripts[0] = ""; } - VLOG(1) << "Final transcript: " << result.alternatives(0).transcript(); - if (word_time_offsets_) { - if (num_alternatives > 0) { - for (int a = 0; a < num_alternatives; ++a) { - for (int w = 0; w < result.alternatives(a).words_size(); ++w) { - latest_result_.final_time_stamps[a].push_back(result.alternatives(a).words(w)); + if (is_final) { + int num_alternatives = result.alternatives_size(); + latest_result_.final_transcripts.resize(num_alternatives); + latest_result_.final_scores.resize(num_alternatives); + latest_result_.final_time_stamps.resize(num_alternatives); + for (int a = 0; a < num_alternatives; ++a) { + // Append to transcript + latest_result_.final_transcripts[a] += result.alternatives(a).transcript(); + latest_result_.final_scores[a] += result.alternatives(a).confidence(); + } + VLOG(1) << "Final transcript: " << result.alternatives(0).transcript(); + + if (word_time_offsets_) { + if (num_alternatives > 0) { + for (int a = 0; a < num_alternatives; ++a) { + for (int w = 0; w < result.alternatives(a).words_size(); ++w) { + latest_result_.final_time_stamps[a].push_back(result.alternatives(a).words(w)); + } } } } - } - } else { - if (result.alternatives_size() > 0) { - if (result.stability() == 1) { - VLOG(1) << "Intermediate transcript: " << result.alternatives(0).transcript(); - } else { - latest_result_.partial_transcript += result.alternatives(0).transcript(); - if (word_time_offsets_) { - for (int w = 0; w < result.alternatives(0).words_size(); ++w) { - latest_result_.partial_time_stamps.emplace_back(result.alternatives(0).words(w)); + } else { + if (result.alternatives_size() > 0) { + if (result.stability() == 1) { + VLOG(1) << "Intermediate transcript: " << result.alternatives(0).transcript(); + } else { + latest_result_.partial_transcript += result.alternatives(0).transcript(); + if (word_time_offsets_) { + for (int w = 0; w < result.alternatives(0).words_size(); ++w) { + latest_result_.partial_time_stamps.emplace_back(result.alternatives(0).words(w)); + } } } } From 033fa2c0339ae2d4087945c44278133771171a9f Mon Sep 17 00:00:00 2001 From: sarane Date: Wed, 23 Oct 2024 21:45:11 +0530 Subject: [PATCH 11/18] Updating common commit ID --- WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/WORKSPACE b/WORKSPACE index d09b312..945c50e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -70,7 +70,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/nvidia-riva/common.git", - commit = "a7d342c223106ccb84a812f189178afa6e68c123" + commit = "9b31412dc43a15740f5f55a97cbd8c3eb5b43d86" ) http_archive( From 75ce170b42c20da7c9322d12c11c8b34a7513815 Mon Sep 17 00:00:00 2001 From: sarane Date: Tue, 29 Oct 2024 07:15:56 +0530 Subject: [PATCH 12/18] Dumping vad prob states to log file --- riva/clients/asr/client_call.cc | 14 +++++++++++--- riva/clients/asr/client_call.h | 2 ++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index 46fa6b6..8edeebc 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -13,6 +13,11 @@ ClientCall::ClientCall(uint32_t corr_id, bool word_time_offsets) recv_final_flags.reserve(1000); } +ClientCall::~ClientCall(){ + if (pipeline_states_logs_) + pipeline_states_logs_.close(); +} + void ClientCall::AppendResult(const nr_asr::StreamingRecognitionResult& result) { @@ -23,11 +28,14 @@ ClientCall::AppendResult(const nr_asr::StreamingRecognitionResult& result) } auto pipeline_states = result.pipeline_states(); int prob_states_count = pipeline_states.vad_probabilities_size(); - std::string log = ""; + std::string vad_log = ""; for (int i = 0; i < prob_states_count; i++) { - log += std::to_string(pipeline_states.vad_probabilities(i)) + " "; + vad_log += std::to_string(pipeline_states.vad_probabilities(i)) + " "; + } + if(!pipeline_states_logs_){ + pipeline_states_logs_.open("riva_asr_pipeline_states.log"); } - VLOG(1) << "VAD states: " << log; + pipeline_states_logs_ << "VAD states: " << vad_log << std::endl; } else { bool is_final = result.is_final(); if (latest_result_.final_transcripts.size() < 1) { diff --git a/riva/clients/asr/client_call.h b/riva/clients/asr/client_call.h index 7d9a92f..104a6eb 100644 --- a/riva/clients/asr/client_call.h +++ b/riva/clients/asr/client_call.h @@ -37,6 +37,7 @@ namespace nr_asr = nvidia::riva::asr; class ClientCall { public: ClientCall(uint32_t _corr_id, bool word_time_offsets); + ~ClientCall(); void AppendResult(const nr_asr::StreamingRecognitionResult& result); @@ -66,5 +67,6 @@ class ClientCall { std::vector recv_final_flags; grpc::Status finish_status; + std::ofstream pipeline_states_logs_; }; // ClientCall From 8cc9422d2fe5b959bab6cf3c085bd8a585861931 Mon Sep 17 00:00:00 2001 From: sarane Date: Wed, 30 Oct 2024 01:45:57 +0530 Subject: [PATCH 13/18] minor commit --- riva/clients/asr/client_call.cc | 4 ---- riva/clients/nmt/riva_nmt_t2t_client.cc | 1 + 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index 8edeebc..fdfebf6 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -22,10 +22,6 @@ void ClientCall::AppendResult(const nr_asr::StreamingRecognitionResult& result) { if (result.has_pipeline_states()) { - if (latest_result_.final_transcripts.size() == 0) { - latest_result_.final_transcripts.resize(1); - latest_result_.final_transcripts[0] = ""; - } auto pipeline_states = result.pipeline_states(); int prob_states_count = pipeline_states.vad_probabilities_size(); std::string vad_log = ""; diff --git a/riva/clients/nmt/riva_nmt_t2t_client.cc b/riva/clients/nmt/riva_nmt_t2t_client.cc index 1f80054..68ee10c 100644 --- a/riva/clients/nmt/riva_nmt_t2t_client.cc +++ b/riva/clients/nmt/riva_nmt_t2t_client.cc @@ -200,6 +200,7 @@ main(int argc, char** argv) if (FLAGS_text != "") { nr_nmt::TranslateTextRequest request; nr_nmt::TranslateTextResponse response; + VLOG(1) << "Setting up t2t config."; request.set_model(FLAGS_model_name); request.set_source_language(FLAGS_source_language_code); request.set_target_language(FLAGS_target_language_code); From cc2641566669340ab4709f04845931eec765e3b4 Mon Sep 17 00:00:00 2001 From: sarane Date: Mon, 4 Nov 2024 13:00:10 +0530 Subject: [PATCH 14/18] Updating ASR proto --- WORKSPACE | 2 +- riva/clients/asr/client_call.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 945c50e..795504e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -70,7 +70,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/nvidia-riva/common.git", - commit = "9b31412dc43a15740f5f55a97cbd8c3eb5b43d86" + commit = "75afa27124dde29a5e892e6817d9841c6cbc49e7" ) http_archive( diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index fdfebf6..a719f92 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -13,7 +13,7 @@ ClientCall::ClientCall(uint32_t corr_id, bool word_time_offsets) recv_final_flags.reserve(1000); } -ClientCall::~ClientCall(){ +ClientCall::~ClientCall() { if (pipeline_states_logs_) pipeline_states_logs_.close(); } From 78ad3a334472bf8957755b699d497d281268c660 Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 8 Nov 2024 11:22:39 +0530 Subject: [PATCH 15/18] Formatting --- riva/clients/asr/client_call.cc | 5 +++-- riva/clients/asr/riva_asr_client.cc | 19 +++++++++++-------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index a719f92..1411a04 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -13,7 +13,8 @@ ClientCall::ClientCall(uint32_t corr_id, bool word_time_offsets) recv_final_flags.reserve(1000); } -ClientCall::~ClientCall() { +ClientCall::~ClientCall() +{ if (pipeline_states_logs_) pipeline_states_logs_.close(); } @@ -28,7 +29,7 @@ ClientCall::AppendResult(const nr_asr::StreamingRecognitionResult& result) for (int i = 0; i < prob_states_count; i++) { vad_log += std::to_string(pipeline_states.vad_probabilities(i)) + " "; } - if(!pipeline_states_logs_){ + if (!pipeline_states_logs_) { pipeline_states_logs_.open("riva_asr_pipeline_states.log"); } pipeline_states_logs_ << "VAD states: " << vad_log << std::endl; diff --git a/riva/clients/asr/riva_asr_client.cc b/riva/clients/asr/riva_asr_client.cc index f10ca38..d6137f9 100644 --- a/riva/clients/asr/riva_asr_client.cc +++ b/riva/clients/asr/riva_asr_client.cc @@ -67,7 +67,9 @@ DEFINE_bool( "Whether to use SSL credentials or not. If ssl_cert is specified, " "this is assumed to be true"); DEFINE_bool(speaker_diarization, false, "Flag that controls if speaker diarization is requested"); -DEFINE_int32(diarization_max_speakers, 3, "Max number of speakers to detect when performing speaker diarization"); +DEFINE_int32( + diarization_max_speakers, 3, + "Max number of speakers to detect when performing speaker diarization"); DEFINE_string(metadata, "", "Comma separated key-value pair(s) of metadata to be sent to server"); DEFINE_int32(start_history, -1, "Value to detect and initiate start of speech utterance"); DEFINE_double( @@ -92,14 +94,15 @@ class RecognizeClient { bool automatic_punctuation, bool separate_recognition_per_channel, bool print_transcripts, std::string output_filename, std::string model_name, bool ctm, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, - bool speaker_diarization, int32_t diarization_max_speakers, int32_t start_history, float start_threshold, int32_t stop_history, - int32_t stop_history_eou, float stop_threshold, float stop_threshold_eou, - std::string custom_configuration) + bool speaker_diarization, int32_t diarization_max_speakers, int32_t start_history, + float start_threshold, int32_t stop_history, int32_t stop_history_eou, float stop_threshold, + float stop_threshold_eou, std::string custom_configuration) : stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), automatic_punctuation_(automatic_punctuation), separate_recognition_per_channel_(separate_recognition_per_channel), - speaker_diarization_(speaker_diarization), diarization_max_speakers_(diarization_max_speakers), print_transcripts_(print_transcripts), + speaker_diarization_(speaker_diarization), + diarization_max_speakers_(diarization_max_speakers), print_transcripts_(print_transcripts), done_sending_(false), num_requests_(0), num_responses_(0), num_failed_requests_(0), total_audio_processed_(0.), model_name_(model_name), output_filename_(output_filename), verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), @@ -533,9 +536,9 @@ main(int argc, char** argv) FLAGS_word_time_offsets, FLAGS_automatic_punctuation, /* separate_recognition_per_channel*/ false, FLAGS_print_transcripts, FLAGS_output_filename, FLAGS_model_name, FLAGS_output_ctm, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, - (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_diarization_max_speakers, FLAGS_start_history, - FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, FLAGS_stop_threshold, - FLAGS_stop_threshold_eou, FLAGS_custom_configuration); + (float)FLAGS_boosted_words_score, FLAGS_speaker_diarization, FLAGS_diarization_max_speakers, + FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, + FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration); // Preload all wav files, sort by size to reduce tail effects std::vector> all_wav; From 2cd2e8d57d1895ad6f94ad71ba47ef99dac5a47c Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 8 Nov 2024 15:27:39 +0530 Subject: [PATCH 16/18] minor updates --- riva/clients/asr/client_call.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index 1411a04..2a5c8a3 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -15,31 +15,31 @@ ClientCall::ClientCall(uint32_t corr_id, bool word_time_offsets) ClientCall::~ClientCall() { - if (pipeline_states_logs_) + if (pipeline_states_logs_.is_open()){ pipeline_states_logs_.close(); + } } void ClientCall::AppendResult(const nr_asr::StreamingRecognitionResult& result) { + if (latest_result_.final_transcripts.size() < 1) { + latest_result_.final_transcripts.resize(1); + latest_result_.final_transcripts[0] = ""; + } if (result.has_pipeline_states()) { auto pipeline_states = result.pipeline_states(); - int prob_states_count = pipeline_states.vad_probabilities_size(); + size_t prob_states_count = pipeline_states.vad_probabilities_size(); std::string vad_log = ""; - for (int i = 0; i < prob_states_count; i++) { + for (size_t i = 0; i < prob_states_count; i++) { vad_log += std::to_string(pipeline_states.vad_probabilities(i)) + " "; } - if (!pipeline_states_logs_) { + if (!pipeline_states_logs_.is_open()) { pipeline_states_logs_.open("riva_asr_pipeline_states.log"); } pipeline_states_logs_ << "VAD states: " << vad_log << std::endl; } else { bool is_final = result.is_final(); - if (latest_result_.final_transcripts.size() < 1) { - latest_result_.final_transcripts.resize(1); - latest_result_.final_transcripts[0] = ""; - } - if (is_final) { int num_alternatives = result.alternatives_size(); latest_result_.final_transcripts.resize(num_alternatives); From a7ebd5f99b8016b3550785e06983d8da81b5481c Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 8 Nov 2024 15:28:51 +0530 Subject: [PATCH 17/18] code formating --- riva/clients/asr/client_call.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riva/clients/asr/client_call.cc b/riva/clients/asr/client_call.cc index 2a5c8a3..37b2c93 100644 --- a/riva/clients/asr/client_call.cc +++ b/riva/clients/asr/client_call.cc @@ -15,7 +15,7 @@ ClientCall::ClientCall(uint32_t corr_id, bool word_time_offsets) ClientCall::~ClientCall() { - if (pipeline_states_logs_.is_open()){ + if (pipeline_states_logs_.is_open()) { pipeline_states_logs_.close(); } } From cb3d8c3b7876125e2727f45c9a094537125e9e96 Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 8 Nov 2024 17:34:11 +0530 Subject: [PATCH 18/18] Updating proto's commit ID --- WORKSPACE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/WORKSPACE b/WORKSPACE index 795504e..5f62bf9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -70,7 +70,7 @@ grpc_extra_deps() git_repository( name = "nvriva_common", remote = "https://github.com/nvidia-riva/common.git", - commit = "75afa27124dde29a5e892e6817d9841c6cbc49e7" + commit = "d73d7c13d3e291aace10f619a8f0e6fc6be78156" ) http_archive(