Skip to content

Commit

Permalink
Add whisper parameters handling and update related files for improved…
Browse files Browse the repository at this point in the history
… functionality
  • Loading branch information
royshil committed Nov 20, 2024
1 parent d50639e commit 9cd26c4
Show file tree
Hide file tree
Showing 12 changed files with 293 additions and 263 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ target_sources(
src/whisper-utils/whisper-processing.cpp
src/whisper-utils/whisper-utils.cpp
src/whisper-utils/whisper-model-utils.cpp
src/whisper-utils/whisper-params.cpp
src/whisper-utils/silero-vad-onnx.cpp
src/whisper-utils/token-buffer-thread.cpp
src/whisper-utils/vad-processing.cpp
Expand Down
2 changes: 0 additions & 2 deletions src/plugin-support.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ extern "C" {
extern const char *PLUGIN_NAME;
extern const char *PLUGIN_VERSION;

#define MT_ obs_module_text

void obs_log(int log_level, const char *format, ...);

#ifdef __cplusplus
Expand Down
2 changes: 1 addition & 1 deletion src/transcription-filter-callbacks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void send_sentence_to_file(struct transcription_filter_data *gf,
if (!gf->save_srt) {
obs_log(gf->log_level, "Saving sentence '%s' to file %s", str_copy.c_str(),
gf->output_file_path.c_str());
// Write raw sentence to file
// Write raw sentence to text file (non-srt format)
try {
std::ofstream output_file(gf->output_file_path, openmode);
output_file << str_copy << std::endl;
Expand Down
1 change: 1 addition & 0 deletions src/transcription-filter-properties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "transcription-filter-utils.h"
#include "whisper-utils/whisper-language.h"
#include "whisper-utils/vad-processing.h"
#include "whisper-utils/whisper-params.h"
#include "model-utils/model-downloader-types.h"
#include "translation/language_codes.h"
#include "ui/filter-replace-dialog.h"
Expand Down
12 changes: 8 additions & 4 deletions src/transcription-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "whisper-utils/whisper-language.h"
#include "whisper-utils/whisper-model-utils.h"
#include "whisper-utils/whisper-utils.h"
#include "whisper-utils/whisper-params.h"
#include "translation/language_codes.h"
#include "translation/translation-utils.h"
#include "translation/translation.h"
Expand Down Expand Up @@ -357,17 +358,20 @@ void transcription_filter_update(void *data, obs_data_t *s)
if (!new_translate || gf->translation_model_index != "whisper-based-translation") {
const char *whisper_language_select =
obs_data_get_string(s, "whisper_language_select");
gf->whisper_params.language = (whisper_language_select != nullptr &&
strlen(whisper_language_select) > 0)
? whisper_language_select
: "auto";
const bool language_selected = whisper_language_select != nullptr &&
strlen(whisper_language_select) > 0;
gf->whisper_params.language = (language_selected) ? whisper_language_select
: "auto";
gf->whisper_params.detect_language = !language_selected;
} else {
// take the language from gf->target_lang
if (language_codes_to_whisper.count(gf->target_lang) > 0) {
gf->whisper_params.language =
language_codes_to_whisper[gf->target_lang].c_str();
gf->whisper_params.detect_language = false;
} else {
gf->whisper_params.language = "auto";
gf->whisper_params.detect_language = true;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/transcription-filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
extern "C" {
#endif

#define MT_ obs_module_text

void transcription_filter_activate(void *data);
void *transcription_filter_create(obs_data_t *settings, obs_source_t *filter);
void transcription_filter_update(void *data, obs_data_t *s);
Expand Down
35 changes: 25 additions & 10 deletions src/whisper-utils/vad-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,15 @@ vad_state vad_disabled_segmentation(transcription_filter_data *gf, vad_state las
const int ret = get_data_from_buf_and_resample(gf, start_timestamp_offset_ns,
end_timestamp_offset_ns);
if (ret != 0) {
// if there's data on the whisper buffer - run inference as "final" segment
if (gf->whisper_buffer.size > 0) {
obs_log(gf->log_level,
"VAD disabled: no new input but whisper buffer has %lu bytes, run inference",
gf->whisper_buffer.size);
run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms,
last_vad_state.end_ts_offset_ms,
VAD_STATE_WAS_OFF);
}
return last_vad_state;
}

Expand All @@ -141,8 +150,8 @@ vad_state vad_disabled_segmentation(transcription_filter_data *gf, vad_state las
circlebuf_pop_front(&gf->resampled_buffer, nullptr, gf->resampled_buffer.size);

const uint64_t whisper_buf_samples = gf->whisper_buffer.size / sizeof(float);
const bool is_partial_segment = whisper_buf_samples <
gf->segment_duration * WHISPER_SAMPLE_RATE / 1000;
const bool is_partial_segment =
whisper_buf_samples < (uint64_t)(gf->segment_duration * WHISPER_SAMPLE_RATE / 1000);

#ifdef LOCALVOCAL_EXTRA_VERBOSE
obs_log(gf->log_level,
Expand All @@ -151,7 +160,6 @@ vad_state vad_disabled_segmentation(transcription_filter_data *gf, vad_state las
is_partial_segment ? "PARTIAL" : "OFF");
#endif

const uint64_t start_ts_offset_ms = start_timestamp_offset_ns / 1000000;
const uint64_t end_ts_offset_ms = end_timestamp_offset_ns / 1000000;

if (is_partial_segment) {
Expand All @@ -160,13 +168,20 @@ vad_state vad_disabled_segmentation(transcription_filter_data *gf, vad_state las
const uint64_t unprocessed_length_ms =
end_ts_offset_ms - last_vad_state.last_partial_segment_end_ts;
if (unprocessed_length_ms > (uint64_t)gf->partial_latency) {
obs_log(gf->log_level,
"VAD disabled: partial segment with %lu ms unprocessed audio. start %lu, end %lu",
unprocessed_length_ms, last_vad_state.start_ts_offest_ms,
end_ts_offset_ms);
// Send to inference
run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms,
end_ts_offset_ms, VAD_STATE_PARTIAL);
if (gf->partial_transcription) {
obs_log(gf->log_level,
"VAD disabled: partial segment with %lu ms unprocessed audio. start %lu, end %lu",
unprocessed_length_ms, last_vad_state.start_ts_offest_ms,
end_ts_offset_ms);
// Send to inference
run_inference_and_callbacks(gf, last_vad_state.start_ts_offest_ms,
end_ts_offset_ms, VAD_STATE_PARTIAL);
} else {
obs_log(gf->log_level,
"VAD disabled: partial segment with %lu ms unprocessed audio. start %lu, end %lu. Skipping.",
unprocessed_length_ms, last_vad_state.start_ts_offest_ms,
end_ts_offset_ms);
}
// update the last partial segment end timestamp
last_vad_state.last_partial_segment_end_ts = end_ts_offset_ms;
}
Expand Down
192 changes: 192 additions & 0 deletions src/whisper-utils/whisper-params.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
#include "whisper-params.h"

#include <obs-module.h>

#define MT_ obs_module_text

void whisper_params_pretty_print(whisper_full_params &params)
{
obs_log(LOG_INFO, "Whisper params:");
obs_log(LOG_INFO, "strategy: %s",
params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH
? "beam_search"
: "greedy");
obs_log(LOG_INFO, "n_threads: %d", params.n_threads);
obs_log(LOG_INFO, "n_max_text_ctx: %d", params.n_max_text_ctx);
obs_log(LOG_INFO, "offset_ms: %d", params.offset_ms);
obs_log(LOG_INFO, "duration_ms: %d", params.duration_ms);
obs_log(LOG_INFO, "translate: %s", params.translate ? "true" : "false");
obs_log(LOG_INFO, "no_context: %s", params.no_context ? "true" : "false");
obs_log(LOG_INFO, "no_timestamps: %s", params.no_timestamps ? "true" : "false");
obs_log(LOG_INFO, "single_segment: %s", params.single_segment ? "true" : "false");
obs_log(LOG_INFO, "print_special: %s", params.print_special ? "true" : "false");
obs_log(LOG_INFO, "print_progress: %s", params.print_progress ? "true" : "false");
obs_log(LOG_INFO, "print_realtime: %s", params.print_realtime ? "true" : "false");
obs_log(LOG_INFO, "print_timestamps: %s", params.print_timestamps ? "true" : "false");
obs_log(LOG_INFO, "token_timestamps: %s", params.token_timestamps ? "true" : "false");
obs_log(LOG_INFO, "thold_pt: %f", params.thold_pt);
obs_log(LOG_INFO, "thold_ptsum: %f", params.thold_ptsum);
obs_log(LOG_INFO, "max_len: %d", params.max_len);
obs_log(LOG_INFO, "split_on_word: %s", params.split_on_word ? "true" : "false");
obs_log(LOG_INFO, "max_tokens: %d", params.max_tokens);
obs_log(LOG_INFO, "debug_mode: %s", params.debug_mode ? "true" : "false");
obs_log(LOG_INFO, "audio_ctx: %d", params.audio_ctx);
obs_log(LOG_INFO, "tdrz_enable: %s", params.tdrz_enable ? "true" : "false");
obs_log(LOG_INFO, "suppress_regex: %s", params.suppress_regex);
obs_log(LOG_INFO, "initial_prompt: %s", params.initial_prompt);
obs_log(LOG_INFO, "language: %s", params.language);
obs_log(LOG_INFO, "detect_language: %s", params.detect_language ? "true" : "false");
obs_log(LOG_INFO, "suppress_blank: %s", params.suppress_blank ? "true" : "false");
obs_log(LOG_INFO, "suppress_non_speech_tokens: %s",
params.suppress_non_speech_tokens ? "true" : "false");
obs_log(LOG_INFO, "temperature: %f", params.temperature);
obs_log(LOG_INFO, "max_initial_ts: %f", params.max_initial_ts);
obs_log(LOG_INFO, "length_penalty: %f", params.length_penalty);
obs_log(LOG_INFO, "temperature_inc: %f", params.temperature_inc);
obs_log(LOG_INFO, "entropy_thold: %f", params.entropy_thold);
obs_log(LOG_INFO, "logprob_thold: %f", params.logprob_thold);
obs_log(LOG_INFO, "no_speech_thold: %f", params.no_speech_thold);
obs_log(LOG_INFO, "greedy.best_of: %d", params.greedy.best_of);
obs_log(LOG_INFO, "beam_search.beam_size: %d", params.beam_search.beam_size);
obs_log(LOG_INFO, "beam_search.patience: %f", params.beam_search.patience);
}

void apply_whisper_params_defaults_on_settings(obs_data_t *s)
{
whisper_full_params whisper_params_tmp = whisper_full_default_params(
whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH);

obs_data_set_default_int(s, "strategy",
whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH);
obs_data_set_default_int(s, "n_threads", whisper_params_tmp.n_threads);
obs_data_set_default_int(s, "n_max_text_ctx", whisper_params_tmp.n_max_text_ctx);
obs_data_set_default_int(s, "offset_ms", whisper_params_tmp.offset_ms);
obs_data_set_default_int(s, "duration_ms", whisper_params_tmp.duration_ms);
obs_data_set_default_bool(s, "whisper_translate", whisper_params_tmp.translate);
obs_data_set_default_bool(s, "no_context", whisper_params_tmp.no_context);
obs_data_set_default_bool(s, "no_timestamps", whisper_params_tmp.no_timestamps);
obs_data_set_default_bool(s, "single_segment", whisper_params_tmp.single_segment);
obs_data_set_default_bool(s, "print_special", false);
obs_data_set_default_bool(s, "print_progress", false);
obs_data_set_default_bool(s, "print_realtime", false);
obs_data_set_default_bool(s, "print_timestamps", false);
obs_data_set_default_bool(s, "token_timestamps", whisper_params_tmp.token_timestamps);
obs_data_set_default_double(s, "thold_pt", whisper_params_tmp.thold_pt);
obs_data_set_default_double(s, "thold_ptsum", whisper_params_tmp.thold_ptsum);
obs_data_set_default_int(s, "max_len", whisper_params_tmp.max_len);
obs_data_set_default_bool(s, "split_on_word", whisper_params_tmp.split_on_word);
obs_data_set_default_int(s, "max_tokens", whisper_params_tmp.max_tokens);
obs_data_set_default_bool(s, "debug_mode", whisper_params_tmp.debug_mode);
obs_data_set_default_int(s, "audio_ctx", whisper_params_tmp.audio_ctx);
obs_data_set_default_bool(s, "tdrz_enable", whisper_params_tmp.tdrz_enable);
obs_data_set_default_string(s, "suppress_regex", whisper_params_tmp.suppress_regex);
obs_data_set_default_string(s, "initial_prompt", whisper_params_tmp.initial_prompt);
// obs_data_set_default_string(s, "language", whisper_params_tmp.language);
obs_data_set_default_bool(s, "detect_language", whisper_params_tmp.detect_language);
obs_data_set_default_bool(s, "suppress_blank", false);
obs_data_set_default_bool(s, "suppress_non_speech_tokens", false);
obs_data_set_default_double(s, "temperature", whisper_params_tmp.temperature);
obs_data_set_default_double(s, "max_initial_ts", whisper_params_tmp.max_initial_ts);
obs_data_set_default_double(s, "length_penalty", whisper_params_tmp.length_penalty);
obs_data_set_default_double(s, "temperature_inc", whisper_params_tmp.temperature_inc);
obs_data_set_default_double(s, "entropy_thold", whisper_params_tmp.entropy_thold);
obs_data_set_default_double(s, "logprob_thold", whisper_params_tmp.logprob_thold);
obs_data_set_default_double(s, "no_speech_thold", whisper_params_tmp.no_speech_thold);
obs_data_set_default_int(s, "greedy.best_of", whisper_params_tmp.greedy.best_of);
obs_data_set_default_int(s, "beam_search.beam_size",
whisper_params_tmp.beam_search.beam_size);
obs_data_set_default_double(s, "beam_search.patience",
whisper_params_tmp.beam_search.patience);
}

void apply_whisper_params_from_settings(whisper_full_params &params, obs_data_t *settings)
{
params = whisper_full_default_params(
(whisper_sampling_strategy)obs_data_get_int(settings, "strategy"));
params.n_threads = obs_data_get_int(settings, "n_threads");
params.n_max_text_ctx = obs_data_get_int(settings, "n_max_text_ctx");
params.offset_ms = obs_data_get_int(settings, "offset_ms");
params.duration_ms = obs_data_get_int(settings, "duration_ms");
params.translate = obs_data_get_bool(settings, "whisper_translate");
params.no_context = obs_data_get_bool(settings, "no_context");
params.no_timestamps = obs_data_get_bool(settings, "no_timestamps");
params.single_segment = obs_data_get_bool(settings, "single_segment");
params.print_special = obs_data_get_bool(settings, "print_special");
params.print_progress = obs_data_get_bool(settings, "print_progress");
params.print_realtime = obs_data_get_bool(settings, "print_realtime");
params.print_timestamps = obs_data_get_bool(settings, "print_timestamps");
params.token_timestamps = obs_data_get_bool(settings, "token_timestamps");
params.thold_pt = obs_data_get_double(settings, "thold_pt");
params.thold_ptsum = obs_data_get_double(settings, "thold_ptsum");
params.max_len = obs_data_get_int(settings, "max_len");
params.split_on_word = obs_data_get_bool(settings, "split_on_word");
params.max_tokens = obs_data_get_int(settings, "max_tokens");
params.debug_mode = obs_data_get_bool(settings, "debug_mode");
params.audio_ctx = obs_data_get_int(settings, "audio_ctx");
params.tdrz_enable = obs_data_get_bool(settings, "tdrz_enable");
params.suppress_regex = obs_data_get_string(settings, "suppress_regex");
params.initial_prompt = obs_data_get_string(settings, "initial_prompt");
// params.language = obs_data_get_string(settings, "language");
params.detect_language = obs_data_get_bool(settings, "detect_language");
params.suppress_blank = obs_data_get_bool(settings, "suppress_blank");
params.suppress_non_speech_tokens =
obs_data_get_bool(settings, "suppress_non_speech_tokens");
params.temperature = obs_data_get_double(settings, "temperature");
params.max_initial_ts = obs_data_get_double(settings, "max_initial_ts");
params.length_penalty = obs_data_get_double(settings, "length_penalty");
params.temperature_inc = obs_data_get_double(settings, "temperature_inc");
params.entropy_thold = obs_data_get_double(settings, "entropy_thold");
params.logprob_thold = obs_data_get_double(settings, "logprob_thold");
params.no_speech_thold = obs_data_get_double(settings, "no_speech_thold");
params.greedy.best_of = obs_data_get_int(settings, "greedy.best_of");
params.beam_search.beam_size = obs_data_get_int(settings, "beam_search.beam_size");
params.beam_search.patience = obs_data_get_double(settings, "beam_search.patience");
}

void add_whisper_params_group_properties(obs_properties_t *ppts)
{
obs_properties_t *g = obs_properties_create();
obs_properties_add_group(ppts, "whisper_params_group", MT_("whisper_parameters"),
OBS_GROUP_NORMAL, g);

obs_properties_add_list(g, "strategy", MT_("whisper_sampling_strategy"),
OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT);
obs_properties_add_int(g, "n_threads", MT_("n_threads"), 1, 8, 1);
obs_properties_add_int(g, "n_max_text_ctx", MT_("n_max_text_ctx"), 1, 100, 1);
obs_properties_add_int(g, "offset_ms", MT_("offset_ms"), 0, 10000, 100);
obs_properties_add_int(g, "duration_ms", MT_("duration_ms"), 0, 30000, 500);
obs_properties_add_bool(g, "whisper_translate", MT_("whisper_translate"));
obs_properties_add_bool(g, "no_context", MT_("no_context"));
obs_properties_add_bool(g, "no_timestamps", MT_("no_timestamps"));
obs_properties_add_bool(g, "single_segment", MT_("single_segment"));
obs_properties_add_bool(g, "print_special", MT_("print_special"));
obs_properties_add_bool(g, "print_progress", MT_("print_progress"));
obs_properties_add_bool(g, "print_realtime", MT_("print_realtime"));
obs_properties_add_bool(g, "print_timestamps", MT_("print_timestamps"));
obs_properties_add_bool(g, "token_timestamps", MT_("token_timestamps"));
obs_properties_add_float(g, "thold_pt", MT_("thold_pt"), 0, 1, 0.05);
obs_properties_add_float(g, "thold_ptsum", MT_("thold_ptsum"), 0, 1, 0.05);
obs_properties_add_int(g, "max_len", MT_("max_len"), 0, 1000, 1);
obs_properties_add_bool(g, "split_on_word", MT_("split_on_word"));
obs_properties_add_int(g, "max_tokens", MT_("max_tokens"), 0, 1000, 1);
obs_properties_add_bool(g, "debug_mode", MT_("debug_mode"));
obs_properties_add_int(g, "audio_ctx", MT_("audio_ctx"), 0, 10, 1);
obs_properties_add_bool(g, "tdrz_enable", MT_("tdrz_enable"));
obs_properties_add_text(g, "suppress_regex", MT_("suppress_regex"), OBS_TEXT_DEFAULT);
obs_properties_add_text(g, "initial_prompt", MT_("initial_prompt"), OBS_TEXT_DEFAULT);
// obs_properties_add_text(g, "language", MT_("language"), OBS_TEXT_DEFAULT);
obs_properties_add_bool(g, "detect_language", MT_("detect_language"));
obs_properties_add_bool(g, "suppress_blank", MT_("suppress_blank"));
obs_properties_add_bool(g, "suppress_non_speech_tokens", MT_("suppress_non_speech_tokens"));
obs_properties_add_float(g, "temperature", MT_("temperature"), 0, 1, 0.05);
obs_properties_add_float(g, "max_initial_ts", MT_("max_initial_ts"), 0, 100, 1);
obs_properties_add_float(g, "length_penalty", MT_("length_penalty"), 0, 1, 0.05);
obs_properties_add_float(g, "temperature_inc", MT_("temperature_inc"), 0, 1, 0.05);
obs_properties_add_float(g, "entropy_thold", MT_("entropy_thold"), 0, 1, 0.05);
obs_properties_add_float(g, "logprob_thold", MT_("logprob_thold"), 0, 1, 0.05);
obs_properties_add_float(g, "no_speech_thold", MT_("no_speech_thold"), 0, 1, 0.05);
obs_properties_add_int(g, "greedy.best_of", MT_("greedy.best_of"), 1, 10, 1);
obs_properties_add_int(g, "beam_search.beam_size", MT_("beam_search.beam_size"), 1, 10, 1);
obs_properties_add_float(g, "beam_search.patience", MT_("beam_search.patience"), 0, 1,
0.05);
}
Loading

0 comments on commit 9cd26c4

Please sign in to comment.