From b8b9bc11097dfc63b143a8859510bdeac92b3e7b Mon Sep 17 00:00:00 2001 From: sarane Date: Fri, 25 Oct 2024 19:04:23 +0530 Subject: [PATCH] Passing VAD states at runtime --- WORKSPACE | 4 +- riva/clients/asr/riva_streaming_asr_client.cc | 20 +++++++++- .../clients/asr/streaming_recognize_client.cc | 38 ++++++++++++++++++- riva/clients/asr/streaming_recognize_client.h | 9 ++++- 4 files changed, 65 insertions(+), 6 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 945c50e..c7399cb 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -69,8 +69,8 @@ grpc_extra_deps() git_repository( name = "nvriva_common", - remote = "https://github.com/nvidia-riva/common.git", - commit = "9b31412dc43a15740f5f55a97cbd8c3eb5b43d86" + remote = "https://github.com/sarane22/common.git", + commit = "c15d5f02e8aee4b01daebc0d3a09f71aa9aa40c0" ) http_archive( diff --git a/riva/clients/asr/riva_streaming_asr_client.cc b/riva/clients/asr/riva_streaming_asr_client.cc index 4630270..9751cd4 100644 --- a/riva/clients/asr/riva_streaming_asr_client.cc +++ b/riva/clients/asr/riva_streaming_asr_client.cc @@ -90,6 +90,24 @@ DEFINE_int32( DEFINE_double( stop_threshold_eou, -1., "Threshold value for likelihood of blanks before detecting end of utterance"); +DEFINE_double( + offset, -1., + "VAD offset"); +DEFINE_double( + onset, -1., + "VAD onset"); +DEFINE_double( + pad_offset, -1., + "VAD pad_offset"); +DEFINE_double( + pad_onset, -1., + "VAD pad_onset"); +DEFINE_double( + min_duration_off, -1., + "VAD min_duration_off"); +DEFINE_double( + min_duration_on, -1., + "VAD min_duration_on"); DEFINE_string( custom_configuration, "", "Custom configurations to be sent to the server as key value pairs "); @@ -210,7 +228,7 @@ main(int argc, char** argv) FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime, FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score, FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou, 
FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration); + FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration, FLAGS_offset, FLAGS_onset, FLAGS_pad_offset, FLAGS_pad_onset, FLAGS_min_duration_off, FLAGS_min_duration_on); if (FLAGS_audio_file.size()) { return recognize_client.DoStreamingFromFile( diff --git a/riva/clients/asr/streaming_recognize_client.cc b/riva/clients/asr/streaming_recognize_client.cc index 1680714..6502208 100644 --- a/riva/clients/asr/streaming_recognize_client.cc +++ b/riva/clients/asr/streaming_recognize_client.cc @@ -59,7 +59,7 @@ StreamingRecognizeClient::StreamingRecognizeClient( std::string output_filename, std::string model_name, bool simulate_realtime, bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, - float stop_threshold, float stop_threshold_eou, std::string custom_configuration) + float stop_threshold, float stop_threshold_eou, std::string custom_configuration, float offset, float onset, float pad_offset, float pad_onset, float min_duration_off, float min_duration_on) : print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)), language_code_(language_code), max_alternatives_(max_alternatives), profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets), @@ -71,7 +71,7 @@ StreamingRecognizeClient::StreamingRecognizeClient( verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score), start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history), stop_history_eou_(stop_history_eou), stop_threshold_(stop_threshold), - stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration) + stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration), onset_(onset), offset_(offset), pad_onset_(pad_onset), 
pad_offset_(pad_offset), min_duration_off_(min_duration_off), min_duration_on_(min_duration_on) { num_active_streams_.store(0); num_streams_finished_.store(0); @@ -141,6 +141,39 @@ StreamingRecognizeClient::UpdateEndpointingConfig(nr_asr::RecognitionConfig* con } } + +void +StreamingRecognizeClient::UpdateVADConfig(nr_asr::RecognitionConfig* config) +{ + if (!(offset_ > 0 || onset_ > 0 || pad_onset_ > 0 || pad_offset_ > 0 || + min_duration_off_ > 0 || min_duration_on_ > 0)) { + return; + } + // Set the VAD parameters; only strictly positive values are forwarded, so an explicit 0 is never sent + // Get a mutable pointer to the VAD config message of the recognition config + auto* vad_config = config->mutable_vad_config(); + + if (offset_ > 0) { + vad_config->set_offset(offset_); + } + if (onset_ > 0) { + vad_config->set_onset(onset_); + } + if (pad_onset_ > 0) { + vad_config->set_pad_onset(pad_onset_); + } + if (pad_offset_ > 0) { + vad_config->set_pad_offset(pad_offset_); + } + if (min_duration_off_ > 0) { + vad_config->set_min_duration_off(min_duration_off_); + } + if (min_duration_on_ > 0) { + vad_config->set_min_duration_on(min_duration_on_); + } +} + + void StreamingRecognizeClient::GenerateRequests(std::shared_ptr call) { @@ -181,6 +214,7 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr call) // Set the endpoint parameters UpdateEndpointingConfig(config); + UpdateVADConfig(config); call->streamer->Write(request); first_write = false; diff --git a/riva/clients/asr/streaming_recognize_client.h b/riva/clients/asr/streaming_recognize_client.h index 56d66bc..6c6575f 100644 --- a/riva/clients/asr/streaming_recognize_client.h +++ b/riva/clients/asr/streaming_recognize_client.h @@ -49,7 +49,7 @@ class StreamingRecognizeClient { bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score, int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou, float stop_threshold, - float stop_threshold_eou, std::string custom_configuration); + float stop_threshold_eou, std::string 
custom_configuration, float offset, float onset, float pad_offset, float pad_onset, float min_duration_off, float min_duration_on); ~StreamingRecognizeClient(); @@ -62,6 +62,7 @@ class StreamingRecognizeClient { void StartNewStream(std::unique_ptr stream); void UpdateEndpointingConfig(nr_asr::RecognitionConfig* config); + void UpdateVADConfig(nr_asr::RecognitionConfig* config); void GenerateRequests(std::shared_ptr call); @@ -126,4 +127,10 @@ class StreamingRecognizeClient { float stop_threshold_; float stop_threshold_eou_; std::string custom_configuration_; + float offset_; + float onset_; + float pad_offset_; + float pad_onset_; + float min_duration_off_; + float min_duration_on_; }; \ No newline at end of file