diff --git a/src/transcription-filter.cpp b/src/transcription-filter.cpp index d8c8ef2..d2b2d44 100644 --- a/src/transcription-filter.cpp +++ b/src/transcription-filter.cpp @@ -193,6 +193,40 @@ void acquire_weak_text_source_ref(struct transcription_filter_data *gf) } } +#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0) +#define is_trail_byte(c) (((c)&0xc0) == 0x80) + +inline int lead_byte_length(const uint8_t c) +{ + if ((c & 0xe0) == 0xc0) { + return 2; + } else if ((c & 0xf0) == 0xe0) { + return 3; + } else if ((c & 0xf8) == 0xf0) { + return 4; + } else { + return 1; + } +} + +inline bool is_valid_lead_byte(const uint8_t *c) +{ + const int length = lead_byte_length(c[0]); + if (length == 1) { + return true; + } + if (length == 2 && is_trail_byte(c[1])) { + return true; + } + if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) { + return true; + } + if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) { + return true; + } + return false; +} + void set_text_callback(struct transcription_filter_data *gf, const DetectionResultWithText &resultIn) { @@ -212,18 +246,39 @@ void set_text_callback(struct transcription_filter_data *gf, #ifdef _WIN32 // Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs - // 0xf?, and 0xc? becomes 0xe?, so we need to replace it. - std::string str_copy = result.text; - for (size_t i = 0; i < str_copy.size(); ++i) { - // if the char MSBs starts with 0xf replace the MSBs with 0xd - if ((str_copy.c_str()[i] & 0xf0) == 0xf0) { - str_copy[i] = (str_copy.c_str()[i] & 0x0f) | 0xd0; - } - // if the char MSBs starts with 0xe replace the char with 0xc - if ((str_copy.c_str()[i] & 0xf0) == 0xe0) { - str_copy[i] = (str_copy.c_str()[i] & 0x0f) | 0xc0; + // 0xf?, and 0xc? becomes 0xe?, so we need to fix it. + std::stringstream ss; + uint8_t *c_str = (uint8_t *)result.text.c_str(); + for (size_t i = 0; i < result.text.size(); ++i) { + if (is_lead_byte(c_str[i])) { + // this is a unicode leading byte + // if the next char is 0xff - it's a bug char, replace it with 0x9f + if (c_str[i + 1] == 0xff) { + c_str[i + 1] = 0x9f; + } + if (!is_valid_lead_byte(c_str + i)) { + // This is a bug lead byte, because it's length 3 and the i+2 byte is also + // a lead byte + c_str[i] = c_str[i] - 0x20; + } + } else { + if (c_str[i] >= 0xf8) { + // this may be a malformed lead byte. + // lets see if it becomes a valid lead byte if we "fix" it + uint8_t buf_[4]; + buf_[0] = c_str[i] - 0x20; + buf_[1] = c_str[i + 1]; + buf_[2] = c_str[i + 2]; + buf_[3] = c_str[i + 3]; + if (is_valid_lead_byte(buf_)) { + // this is a malformed lead byte, fix it + c_str[i] = c_str[i] - 0x20; + } + } } } + + std::string str_copy = (char *)c_str; #else std::string str_copy = result.text; #endif diff --git a/src/whisper-processing.cpp b/src/whisper-processing.cpp index 63eb357..844aff5 100644 --- a/src/whisper-processing.cpp +++ b/src/whisper-processing.cpp @@ -9,6 +9,11 @@ #include #include +#ifdef _WIN32 +#include +#include +#endif + #define VAD_THOLD 0.0001f #define FREQ_THOLD 100.0f @@ -73,7 +78,34 @@ bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float v struct whisper_context *init_whisper_context(const std::string &model_path) { obs_log(LOG_INFO, "Loading whisper model from %s", model_path.c_str()); + +#ifdef _WIN32 + // convert model path UTF8 to wstring (wchar_t) for whisper + int count = MultiByteToWideChar(CP_UTF8, 0, model_path.c_str(), (int)model_path.length(), + NULL, 0); + std::wstring model_path_ws(count, 0); + MultiByteToWideChar(CP_UTF8, 0, model_path.c_str(), (int)model_path.length(), + &model_path_ws[0], count); + + // Read model into buffer + std::ifstream modelFile(model_path_ws, std::ios::binary); + if (!modelFile.is_open()) { + obs_log(LOG_ERROR, "Failed to open whisper model file %s", model_path.c_str()); + return nullptr; + } + modelFile.seekg(0, std::ios::end); + const size_t modelFileSize = modelFile.tellg(); + modelFile.seekg(0, std::ios::beg); + std::vector modelBuffer(modelFileSize); + modelFile.read(modelBuffer.data(), modelFileSize); + modelFile.close(); + + // Initialize whisper + struct whisper_context *ctx = whisper_init_from_buffer(modelBuffer.data(), modelFileSize); +#else struct whisper_context *ctx = whisper_init_from_file(model_path.c_str()); +#endif + if (ctx == nullptr) { obs_log(LOG_ERROR, "Failed to load whisper model"); return nullptr;