Skip to content

Commit

Permalink
Passing VAD states at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
sarane22 committed Oct 25, 2024
1 parent 033fa2c commit b8b9bc1
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 6 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ grpc_extra_deps()

git_repository(
name = "nvriva_common",
remote = "https://github.com/nvidia-riva/common.git",
commit = "9b31412dc43a15740f5f55a97cbd8c3eb5b43d86"
remote = "https://github.com/sarane22/common.git",
commit = "c15d5f02e8aee4b01daebc0d3a09f71aa9aa40c0"
)

http_archive(
Expand Down
20 changes: 19 additions & 1 deletion riva/clients/asr/riva_streaming_asr_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ DEFINE_int32(
DEFINE_double(
stop_threshold_eou, -1.,
"Threshold value for likelihood of blanks before detecting end of utterance");
// Optional VAD (voice activity detection) overrides, passed through to
// StreamingRecognizeClient. Every flag defaults to -1; the client copies a
// value into the request's VAD config only when it is greater than zero, so
// the default (or any non-positive value) leaves that parameter unset.
DEFINE_double(
offset, -1.,
"VAD offset");
DEFINE_double(
onset, -1.,
"VAD onset");
DEFINE_double(
pad_offset, -1.,
"VAD pad_offset");
DEFINE_double(
pad_onset, -1.,
"VAD pad_onset");
DEFINE_double(
min_duration_off, -1.,
"VAD min_duration_off");
DEFINE_double(
min_duration_on, -1.,
"VAD min_duration_on");
DEFINE_string(
custom_configuration, "",
"Custom configurations to be sent to the server as key value pairs <key:value,key:value,...>");
Expand Down Expand Up @@ -210,7 +228,7 @@ main(int argc, char** argv)
FLAGS_interim_results, FLAGS_output_filename, FLAGS_model_name, FLAGS_simulate_realtime,
FLAGS_verbatim_transcripts, FLAGS_boosted_words_file, FLAGS_boosted_words_score,
FLAGS_start_history, FLAGS_start_threshold, FLAGS_stop_history, FLAGS_stop_history_eou,
FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration);
FLAGS_stop_threshold, FLAGS_stop_threshold_eou, FLAGS_custom_configuration, FLAGS_offset, FLAGS_onset, FLAGS_pad_offset, FLAGS_pad_onset, FLAGS_min_duration_off, FLAGS_min_duration_on);

if (FLAGS_audio_file.size()) {
return recognize_client.DoStreamingFromFile(
Expand Down
38 changes: 36 additions & 2 deletions riva/clients/asr/streaming_recognize_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ StreamingRecognizeClient::StreamingRecognizeClient(
std::string output_filename, std::string model_name, bool simulate_realtime,
bool verbatim_transcripts, const std::string& boosted_phrases_file, float boosted_phrases_score,
int32_t start_history, float start_threshold, int32_t stop_history, int32_t stop_history_eou,
float stop_threshold, float stop_threshold_eou, std::string custom_configuration)
float stop_threshold, float stop_threshold_eou, std::string custom_configuration, float offset, float onset, float pad_offset, float pad_onset, float min_duration_off, float min_duration_on)
: print_latency_stats_(true), stub_(nr_asr::RivaSpeechRecognition::NewStub(channel)),
language_code_(language_code), max_alternatives_(max_alternatives),
profanity_filter_(profanity_filter), word_time_offsets_(word_time_offsets),
Expand All @@ -71,7 +71,7 @@ StreamingRecognizeClient::StreamingRecognizeClient(
verbatim_transcripts_(verbatim_transcripts), boosted_phrases_score_(boosted_phrases_score),
start_history_(start_history), start_threshold_(start_threshold), stop_history_(stop_history),
stop_history_eou_(stop_history_eou), stop_threshold_(stop_threshold),
stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration)
stop_threshold_eou_(stop_threshold_eou), custom_configuration_(custom_configuration), onset_(onset), offset_(offset), pad_onset_(pad_onset), pad_offset_(pad_offset), min_duration_off_(min_duration_off), min_duration_on_(min_duration_on)
{
num_active_streams_.store(0);
num_streams_finished_.store(0);
Expand Down Expand Up @@ -141,6 +141,39 @@ StreamingRecognizeClient::UpdateEndpointingConfig(nr_asr::RecognitionConfig* con
}
}


// Copies this client's VAD (voice activity detection) parameters into the VAD
// config of `config`.
//
// A parameter is written only when its stored value is greater than zero.
// When no parameter qualifies, the function returns early without calling
// mutable_vad_config(), leaving `config` untouched.
void
StreamingRecognizeClient::UpdateVADConfig(nr_asr::RecognitionConfig* config)
{
  const bool use_offset = offset_ > 0;
  const bool use_onset = onset_ > 0;
  const bool use_pad_onset = pad_onset_ > 0;
  const bool use_pad_offset = pad_offset_ > 0;
  const bool use_min_duration_off = min_duration_off_ > 0;
  const bool use_min_duration_on = min_duration_on_ > 0;

  const bool any_override = use_offset || use_onset || use_pad_onset || use_pad_offset ||
                            use_min_duration_off || use_min_duration_on;
  if (!any_override) {
    return;
  }

  // At least one VAD parameter is set: fetch the mutable VAD config message
  // and fill in only the fields that were provided.
  auto* vad = config->mutable_vad_config();

  if (use_offset) {
    vad->set_offset(offset_);
  }
  if (use_onset) {
    vad->set_onset(onset_);
  }
  if (use_pad_onset) {
    vad->set_pad_onset(pad_onset_);
  }
  if (use_pad_offset) {
    vad->set_pad_offset(pad_offset_);
  }
  if (use_min_duration_off) {
    vad->set_min_duration_off(min_duration_off_);
  }
  if (use_min_duration_on) {
    vad->set_min_duration_on(min_duration_on_);
  }
}


void
StreamingRecognizeClient::GenerateRequests(std::shared_ptr<ClientCall> call)
{
Expand Down Expand Up @@ -181,6 +214,7 @@ StreamingRecognizeClient::GenerateRequests(std::shared_ptr<ClientCall> call)

// Set the endpoint parameters
UpdateEndpointingConfig(config);
UpdateVADConfig(config);

call->streamer->Write(request);
first_write = false;
Expand Down
9 changes: 8 additions & 1 deletion riva/clients/asr/streaming_recognize_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class StreamingRecognizeClient {
bool verbatim_transcripts, const std::string& boosted_phrases_file,
float boosted_phrases_score, int32_t start_history, float start_threshold,
int32_t stop_history, int32_t stop_history_eou, float stop_threshold,
float stop_threshold_eou, std::string custom_configuration);
float stop_threshold_eou, std::string custom_configuration, float offset, float onset, float pad_offset, float pad_onset, float min_duration_off, float min_duration_on);

~StreamingRecognizeClient();

Expand All @@ -62,6 +62,7 @@ class StreamingRecognizeClient {
void StartNewStream(std::unique_ptr<Stream> stream);

void UpdateEndpointingConfig(nr_asr::RecognitionConfig* config);
void UpdateVADConfig(nr_asr::RecognitionConfig* config);

void GenerateRequests(std::shared_ptr<ClientCall> call);

Expand Down Expand Up @@ -126,4 +127,10 @@ class StreamingRecognizeClient {
float stop_threshold_;
float stop_threshold_eou_;
std::string custom_configuration_;
float offset_;
float onset_;
float pad_offset_;
float pad_onset_;
float min_duration_off_;
float min_duration_on_;
};

0 comments on commit b8b9bc1

Please sign in to comment.