From 94bafcdcedee1f721cd8dde6a5c35459a611672e Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 18 Mar 2024 16:44:03 -0400 Subject: [PATCH 1/7] Add status label to TranslateDockWidget --- .github/scripts/.build.zsh | 4 +- .github/scripts/Package-Windows.ps1 | 19 +- CMakeLists.txt | 10 +- cmake/BuildCTranslate2.cmake | 83 +++----- cmake/BuildMyCurl.cmake | 73 +++++++ cmake/BuildSentencepiece.cmake | 15 +- cmake/common/compiler_common.cmake | 2 +- cmake/macos/xcode.cmake | 2 +- src/model-utils/model-downloader-types.h | 23 ++ src/model-utils/model-downloader-ui.cpp | 257 +++++++++++++++++++++++ src/model-utils/model-downloader-ui.h | 63 ++++++ src/model-utils/model-downloader.cpp | 86 ++++++++ src/model-utils/model-downloader.h | 15 ++ src/plugin-main.c | 4 +- src/translation-service/httpserver.cpp | 5 + src/translation-service/translation.cpp | 10 +- src/ui/registerDock.cpp | 4 + src/ui/settingsdialog.cpp | 72 ++++++- src/ui/settingsdialog.ui | 63 +++++- src/ui/translatedockwidget.cpp | 10 + src/ui/translatedockwidget.h | 1 + src/ui/translatedockwidget.ui | 7 + src/utils/config-data.cpp | 13 +- src/utils/config-data.h | 6 + 24 files changed, 756 insertions(+), 91 deletions(-) create mode 100644 cmake/BuildMyCurl.cmake create mode 100644 src/model-utils/model-downloader-types.h create mode 100644 src/model-utils/model-downloader-ui.cpp create mode 100644 src/model-utils/model-downloader-ui.h create mode 100644 src/model-utils/model-downloader.cpp create mode 100644 src/model-utils/model-downloader.h diff --git a/.github/scripts/.build.zsh b/.github/scripts/.build.zsh index c6fe52a..708fdcf 100755 --- a/.github/scripts/.build.zsh +++ b/.github/scripts/.build.zsh @@ -230,8 +230,7 @@ ${_usage_host:-}" -DCODESIGN_IDENTITY=${CODESIGN_IDENT:--} ) - # TODO: enable -arch arm64 - cmake_build_args+=(--preset ${_preset} --parallel --config ${config} -- ONLY_ACTIVE_ARCH=NO -arch x86_64 -arch arm64) + cmake_build_args+=(--preset ${_preset} --parallel --config ${config} -- ONLY_ACTIVE_ARCH=NO -arch arm64 -arch x86_64) cmake_install_args+=(build_macos --config ${config} --prefix "${project_root}/release/${config}") local -a xcbeautify_opts=() @@ -243,6 +242,7 @@ ${_usage_host:-}" -G "${generator}" -DQT_VERSION=${QT_VERSION:-6} -DCMAKE_BUILD_TYPE=${config} + -DCMAKE_INSTALL_PREFIX=/usr ) local cmake_version diff --git a/.github/scripts/Package-Windows.ps1 b/.github/scripts/Package-Windows.ps1 index 819fe54..52317be 100644 --- a/.github/scripts/Package-Windows.ps1 +++ b/.github/scripts/Package-Windows.ps1 @@ -79,17 +79,18 @@ function Package { Invoke-External iscc ${IsccFile} /O"${ProjectRoot}/release" /F"${OutputName}-Installer" Remove-Item -Path Package -Recurse Pop-Location -Stack BuildTemp - } else { - Log-Group "Archiving ${ProductName}..." - $CompressArgs = @{ - Path = (Get-ChildItem -Path "${ProjectRoot}/release/${Configuration}" -Exclude "${OutputName}*.*") - CompressionLevel = 'Optimal' - DestinationPath = "${ProjectRoot}/release/${OutputName}.zip" - Verbose = ($Env:CI -ne $null) - } + } - Compress-Archive -Force @CompressArgs + Log-Group "Archiving ${ProductName}..." + $CompressArgs = @{ + Path = (Get-ChildItem -Path "${ProjectRoot}/release/${Configuration}" -Exclude "${OutputName}*.*") + CompressionLevel = 'Optimal' + DestinationPath = "${ProjectRoot}/release/${OutputName}.zip" + Verbose = ($Env:CI -ne $null) } + + Compress-Archive -Force @CompressArgs + Log-Group } diff --git a/CMakeLists.txt b/CMakeLists.txt index afd099a..650929b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -19,7 +19,8 @@ target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE OBS::libobs) include(cmake/BuildCTranslate2.cmake) include(cmake/BuildSentencepiece.cmake) include(cmake/BuildCppHTTPLib.cmake) -target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ct2 sentencepiece cpphttplib) +include(cmake/BuildMyCurl.cmake) +target_link_libraries(${CMAKE_PROJECT_NAME} PRIVATE ct2 sentencepiece cpphttplib libcurl) if(ENABLE_FRONTEND_API) find_package(obs-frontend-api REQUIRED) @@ -40,7 +41,10 @@ if(ENABLE_QT) endif() add_subdirectory(src/ui) -target_sources(${CMAKE_PROJECT_NAME} PRIVATE src/translation-service/translation.cpp src/plugin-main.c - src/utils/config-data.cpp src/translation-service/httpserver.cpp) +target_sources( + ${CMAKE_PROJECT_NAME} + PRIVATE src/translation-service/translation.cpp src/plugin-main.c src/utils/config-data.cpp + src/translation-service/httpserver.cpp src/model-utils/model-downloader.cpp + src/model-utils/model-downloader-ui.cpp) set_target_properties_plugin(${CMAKE_PROJECT_NAME} PROPERTIES OUTPUT_NAME ${_name}) diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake index 6ac33cb..df37efb 100644 --- a/cmake/BuildCTranslate2.cmake +++ b/cmake/BuildCTranslate2.cmake @@ -7,8 +7,8 @@ if(APPLE) FetchContent_Declare( ctranslate2_fetch - URL https://github.com/obs-ai/obs-ai-ctranslate2-dep/releases/download/1.0.0/libctranslate2-macos-Release-1.0.0.tar.gz - URL_HASH SHA256=8e55a6ed4fb17ac556ad0e020ddab619584e3ceb4c9497a816f819bd8fd36443) + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/libctranslate2-macos-Release-1.1.0.tar.gz + URL_HASH SHA256=dba2eaa1b3f4e9eb1e8999e668d515aa94b115af07565e2b6797b9eda6f2f845) FetchContent_MakeAvailable(ctranslate2_fetch) add_library(ct2 INTERFACE) @@ -17,45 +17,26 @@ if(APPLE) set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include) target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32) -else() - set(CT2_VERSION "3.20.0") - set(CT2_URL "https://github.com/OpenNMT/CTranslate2.git") +elseif(WIN32) - if(WIN32) - # Build with OpenBLAS - - set(OpenBLAS_URL "https://github.com/xianyi/OpenBLAS/releases/download/v0.3.24/OpenBLAS-0.3.24-x64.zip") - set(OpenBLAS_SHA256 "6335128ee7117ea2dd2f5f96f76dafc17256c85992637189a2d5f6da0c608163") - FetchContent_Declare( - openblas_fetch - URL ${OpenBLAS_URL} - URL_HASH SHA256=${OpenBLAS_SHA256}) - FetchContent_MakeAvailable(openblas_fetch) - set(OpenBLAS_DIR ${openblas_fetch_SOURCE_DIR}) - set(OPENBLAS_INCLUDE_DIR ${OpenBLAS_DIR}/include) + FetchContent_Declare( + ctranslate2_fetch + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/libctranslate2-windows-4.1.1-Release.zip + URL_HASH SHA256=683023e9c76ac6d54e54d14c32d86c020d5486ba289b60e5336f7cc86b984d03) + FetchContent_MakeAvailable(ctranslate2_fetch) - add_library(openblas STATIC IMPORTED) - set_target_properties(openblas PROPERTIES IMPORTED_LOCATION ${OpenBLAS_DIR}/lib/libopenblas.dll.a) - install(FILES ${OpenBLAS_DIR}/bin/libopenblas.dll DESTINATION "obs-plugins/64bit") + add_library(ct2 INTERFACE) + target_link_libraries(ct2 INTERFACE "-framework Accelerate" ${ctranslate2_fetch_SOURCE_DIR}/lib/libctranslate2.a + ${ctranslate2_fetch_SOURCE_DIR}/lib/libcpu_features.a) + set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include) + target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32) - set(CT2_OPENBLAS_CMAKE_ARGS -DOPENBLAS_INCLUDE_DIR=${OPENBLAS_INCLUDE_DIR} - -DOPENBLAS_LIBRARY=${OpenBLAS_DIR}/lib/libopenblas.dll.a -DWITH_OPENBLAS=ON) - else() - set(CT2_OPENBLAS_CMAKE_ARGS -DWITH_OPENBLAS=OFF) - endif() +else() + set(CT2_VERSION "4.1.1") + set(CT2_URL "https://github.com/OpenNMT/CTranslate2.git") + set(CT2_OPENBLAS_CMAKE_ARGS -DWITH_OPENBLAS=OFF) - if(UNIX) - if(APPLE) - set(CT2_CMAKE_PLATFORM_OPTIONS -DCMAKE_OSX_DEPLOYMENT_TARGET=10.13 -DBUILD_SHARED_LIBS=OFF -DWITH_ACCELERATE=ON - -DOPENMP_RUNTIME=NONE -DCMAKE_OSX_ARCHITECTURES=arm64) - else() - set(CT2_CMAKE_PLATFORM_OPTIONS -DBUILD_SHARED_LIBS=OFF -DOPENMP_RUNTIME=NONE -DCMAKE_POSITION_INDEPENDENT_CODE=ON) - endif() - set(CT2_LIB_INSTALL_LOCATION lib/${CMAKE_SHARED_LIBRARY_PREFIX}ctranslate2${CMAKE_STATIC_LIBRARY_SUFFIX}) - else() - set(CT2_CMAKE_PLATFORM_OPTIONS -DBUILD_SHARED_LIBS=ON -DOPENMP_RUNTIME=COMP) - set(CT2_LIB_INSTALL_LOCATION bin/${CMAKE_SHARED_LIBRARY_PREFIX}ctranslate2${CMAKE_SHARED_LIBRARY_SUFFIX}) - endif() + set(CT2_CMAKE_PLATFORM_OPTIONS -DBUILD_SHARED_LIBS=OFF -DOPENMP_RUNTIME=NONE -DCMAKE_POSITION_INDEPENDENT_CODE=ON) ExternalProject_Add( ct2_build @@ -84,33 +65,17 @@ else() ${CT2_CMAKE_PLATFORM_OPTIONS}) ExternalProject_Get_Property(ct2_build INSTALL_DIR) - if(UNIX) - # Get cpu_features from the CTranslate2 build - only for x86_64 builds if(APPLE) - # ExternalProject_Get_Property(ct2_build BINARY_DIR) add_library(ct2::cpu_features STATIC IMPORTED GLOBAL) - # set_target_properties( ct2::cpu_features PROPERTIES IMPORTED_LOCATION - # ${BINARY_DIR}/third_party/cpu_features/RelWithDebInfo/libcpu_features.a) endif() + # Get cpu_features from the CTranslate2 build - only for x86_64 builds if(APPLE) + # ExternalProject_Get_Property(ct2_build BINARY_DIR) add_library(ct2::cpu_features STATIC IMPORTED GLOBAL) + # set_target_properties( ct2::cpu_features PROPERTIES IMPORTED_LOCATION + # ${BINARY_DIR}/third_party/cpu_features/RelWithDebInfo/libcpu_features.a) endif() - add_library(ct2::ct2 STATIC IMPORTED GLOBAL) - else() - add_library(ct2::ct2 SHARED IMPORTED GLOBAL) - set_target_properties( - ct2::ct2 PROPERTIES IMPORTED_IMPLIB - ${INSTALL_DIR}/lib/${CMAKE_STATIC_LIBRARY_PREFIX}ctranslate2${CMAKE_STATIC_LIBRARY_SUFFIX}) - install(FILES ${INSTALL_DIR}/${CT2_LIB_INSTALL_LOCATION} DESTINATION "obs-plugins/64bit") - endif() + add_library(ct2::ct2 STATIC IMPORTED GLOBAL) add_dependencies(ct2::ct2 ct2_build) set_target_properties(ct2::ct2 PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/${CT2_LIB_INSTALL_LOCATION}) add_library(ct2 INTERFACE) - if(APPLE) - target_link_libraries( - ct2 - INTERFACE - "-framework Accelerate /Users/roy_shilkrot/Downloads/obs-ai-ctranslate2-dep/CTranslate2-3.20.0/release/universal/RelWithDebInfo/lib/libctranslate2.a" - ) - else() - target_link_libraries(ct2 INTERFACE ct2::ct2) - endif() + target_link_libraries(ct2 INTERFACE ct2::ct2) set_target_properties(ct2::ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${INSTALL_DIR}/include) endif() diff --git a/cmake/BuildMyCurl.cmake b/cmake/BuildMyCurl.cmake new file mode 100644 index 0000000..10d3e05 --- /dev/null +++ b/cmake/BuildMyCurl.cmake @@ -0,0 +1,73 @@ +include(FetchContent) + +set(LibCurl_VERSION "8.4.0-3") +set(LibCurl_BASEURL "https://github.com/occ-ai/obs-ai-libcurl-dep/releases/download/${LibCurl_VERSION}") + +if(${CMAKE_BUILD_TYPE} STREQUAL Release OR ${CMAKE_BUILD_TYPE} STREQUAL RelWithDebInfo) + set(LibCurl_BUILD_TYPE Release) +else() + set(LibCurl_BUILD_TYPE Debug) +endif() + +if(APPLE) + if(LibCurl_BUILD_TYPE STREQUAL Release) + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-macos-${LibCurl_VERSION}-Release.tar.gz") + set(LibCurl_HASH SHA256=5ef7bfed2c2bca17ba562aede6a3c3eb465b8d7516cff86ca0f0d0337de951e1) + else() + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-macos-${LibCurl_VERSION}-Debug.tar.gz") + set(LibCurl_HASH SHA256=da0801168eac5103e6b27bfd0f56f82e0617f85e4e6c69f476071dbba273403b) + endif() +elseif(MSVC) + if(LibCurl_BUILD_TYPE STREQUAL Release) + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-windows-${LibCurl_VERSION}-Release.zip") + set(LibCurl_HASH SHA256=bf4d4cd7d741712a2913df0994258d11aabe22c9a305c9f336ed59e76f351adf) + else() + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-windows-${LibCurl_VERSION}-Debug.zip") + set(LibCurl_HASH SHA256=9fe20e677ffb0d7dd927b978d532e23574cdb1923e2d2ca7c5e42f1fff2ec529) + endif() +else() + if(LibCurl_BUILD_TYPE STREQUAL Release) + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-linux-${LibCurl_VERSION}-Release.tar.gz") + set(LibCurl_HASH SHA256=f2cd80b7d3288fe5b4c90833bcbf0bde7c9574bc60eddb13015df19c5a09f56b) + else() + set(LibCurl_URL "${LibCurl_BASEURL}/libcurl-linux-${LibCurl_VERSION}-Debug.tar.gz") + set(LibCurl_HASH SHA256=6a41d3daef98acc3172b3702118dcf1cccbde923f3836ed2f4f3ed7301e47b8b) + endif() +endif() + +FetchContent_Declare( + libcurl_fetch + URL ${LibCurl_URL} + URL_HASH ${LibCurl_HASH}) +FetchContent_MakeAvailable(libcurl_fetch) + +if(MSVC) + set(libcurl_fetch_lib_location "${libcurl_fetch_SOURCE_DIR}/lib/libcurl.lib") + set(libcurl_fetch_link_libs "\$;\$;\$;\$") +else() + find_package(ZLIB REQUIRED) + set(libcurl_fetch_lib_location "${libcurl_fetch_SOURCE_DIR}/lib/libcurl.a") + if(UNIX AND NOT APPLE) + find_package(OpenSSL REQUIRED) + set(libcurl_fetch_link_libs "\$;\$;\$") + else() + set(libcurl_fetch_link_libs + "-framework SystemConfiguration;-framework Security;-framework CoreFoundation;-framework CoreServices;ZLIB::ZLIB" + ) + endif() +endif() + +# Create imported target +add_library(libcurl STATIC IMPORTED) + +set_target_properties( + libcurl + PROPERTIES INTERFACE_COMPILE_DEFINITIONS "CURL_STATICLIB" + INTERFACE_INCLUDE_DIRECTORIES "${libcurl_fetch_SOURCE_DIR}/include" + INTERFACE_LINK_LIBRARIES "${libcurl_fetch_link_libs}") +set_property( + TARGET libcurl + APPEND + PROPERTY IMPORTED_CONFIGURATIONS RELEASE) +set_target_properties(libcurl PROPERTIES IMPORTED_LINK_INTERFACE_LANGUAGES_RELEASE "C" IMPORTED_LOCATION_RELEASE + ${libcurl_fetch_lib_location}) diff --git a/cmake/BuildSentencepiece.cmake b/cmake/BuildSentencepiece.cmake index 761b809..16bd724 100644 --- a/cmake/BuildSentencepiece.cmake +++ b/cmake/BuildSentencepiece.cmake @@ -6,8 +6,19 @@ if(APPLE) FetchContent_Declare( sentencepiece_fetch - URL https://github.com/obs-ai/obs-ai-ctranslate2-dep/releases/download/1.0.0/libsentencepiece-macos-Release-1.0.0.tar.gz - URL_HASH SHA256=67f58a8e97c14db1bc69becd507ffe69326948f371bf874fe919157d7d65aff4) + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/libsentencepiece-macos-Release-1.1.0.tar.gz + URL_HASH SHA256=168c9eead7ea77010c6e7867555da1b39433e5c002dc994f44abe96df7c71a66) + FetchContent_MakeAvailable(sentencepiece_fetch) + add_library(sentencepiece INTERFACE) + target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/libsentencepiece.a) + set_target_properties(sentencepiece PROPERTIES INTERFACE_INCLUDE_DIRECTORIES + ${sentencepiece_fetch_SOURCE_DIR}/include) +elseif(WIN32) + + FetchContent_Declare( + sentencepiece_fetch + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/sentencepiece-windows-0.2.0-Release.zip + URL_HASH SHA256=f45109b75929d1e35780d1b4ce4218002a3872352f494659b79488214daa987c) FetchContent_MakeAvailable(sentencepiece_fetch) add_library(sentencepiece INTERFACE) target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/libsentencepiece.a) diff --git a/cmake/common/compiler_common.cmake b/cmake/common/compiler_common.cmake index 8ac423f..1e54992 100644 --- a/cmake/common/compiler_common.cmake +++ b/cmake/common/compiler_common.cmake @@ -34,8 +34,8 @@ set(_obs_clang_c_options -Wfour-char-constants -Winfinite-recursion -Wint-conversion - -Wnewline-eof -Wno-conversion + -Wno-error=newline-eof -Wno-float-conversion -Wno-implicit-fallthrough -Wno-missing-braces diff --git a/cmake/macos/xcode.cmake b/cmake/macos/xcode.cmake index 3b4184c..6a43920 100644 --- a/cmake/macos/xcode.cmake +++ b/cmake/macos/xcode.cmake @@ -144,7 +144,7 @@ set(CMAKE_XCODE_ATTRIBUTE_CLANG_WARN__DUPLICATE_METHOD_MATCH YES) set(CMAKE_XCODE_ATTRIBUTE_GCC_NO_COMMON_BLOCKS YES) set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_64_TO_32_BIT_CONVERSION YES) set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_ABOUT_MISSING_FIELD_INITIALIZERS NO) -set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_ABOUT_MISSING_NEWLINE YES) +set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_ABOUT_MISSING_NEWLINE NO) set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_ABOUT_RETURN_TYPE YES_ERROR) set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_CHECK_SWITCH_STATEMENTS YES) set(CMAKE_XCODE_ATTRIBUTE_GCC_WARN_FOUR_CHARACTER_CONSTANTS YES) diff --git a/src/model-utils/model-downloader-types.h b/src/model-utils/model-downloader-types.h new file mode 100644 index 0000000..e7fcefe --- /dev/null +++ b/src/model-utils/model-downloader-types.h @@ -0,0 +1,23 @@ +#ifndef MODEL_DOWNLOADER_TYPES_H +#define MODEL_DOWNLOADER_TYPES_H + +#include +#include +#include + +// Information about a model +struct ModelInfo { + std::string name; + std::vector urls; + std::string localPath; + std::string spmUrl; + std::string localSpmPath; +}; + +// Callback for when the download is finished +typedef std::function + download_finished_callback_t; + +extern std::map models_info; + +#endif // MODEL_DOWNLOADER_TYPES_H diff --git a/src/model-utils/model-downloader-ui.cpp b/src/model-utils/model-downloader-ui.cpp new file mode 100644 index 0000000..aeea442 --- /dev/null +++ b/src/model-utils/model-downloader-ui.cpp @@ -0,0 +1,257 @@ +#include "model-downloader-ui.h" +#include "plugin-support.h" + +#include + +#include +#include +#include + +size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) +{ + size_t written = fwrite(ptr, size, nmemb, stream); + return written; +} + +ModelDownloader::ModelDownloader(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback_, + QWidget *parent) + : QDialog(parent), + download_finished_callback(download_finished_callback_), + model_info(model_info) +{ + this->setWindowTitle("Downloading model..."); + this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint); + this->setFixedSize(300, 100); + // Bring the dialog to the front + this->activateWindow(); + this->raise(); + + this->layout = new QVBoxLayout(this); + + // Add a label for the model name + QLabel *model_name_label = new QLabel(this); + model_name_label->setText(QString::fromStdString(model_info.name)); + model_name_label->setAlignment(Qt::AlignCenter); + this->layout->addWidget(model_name_label); + + this->progress_bar = new QProgressBar(this); + this->progress_bar->setRange(0, 100); + this->progress_bar->setValue(0); + this->progress_bar->setAlignment(Qt::AlignCenter); + // Show progress as a percentage + this->progress_bar->setFormat("%p%"); + this->layout->addWidget(this->progress_bar); + + this->download_thread = new QThread(); + this->download_worker = new ModelDownloadWorker(model_info); + this->download_worker->moveToThread(this->download_thread); + + connect(this->download_thread, &QThread::started, this->download_worker, + &ModelDownloadWorker::download_model); + connect(this->download_worker, &ModelDownloadWorker::download_progress, this, + &ModelDownloader::update_progress); + connect(this->download_worker, &ModelDownloadWorker::download_finished, this, + &ModelDownloader::download_finished); + connect(this->download_worker, &ModelDownloadWorker::download_finished, + this->download_thread, &QThread::quit); + connect(this->download_worker, &ModelDownloadWorker::download_finished, + this->download_worker, &ModelDownloadWorker::deleteLater); + connect(this->download_worker, &ModelDownloadWorker::download_error, this, + &ModelDownloader::show_error); + connect(this->download_thread, &QThread::finished, this->download_thread, + &QThread::deleteLater); + + this->download_thread->start(); +} + +void ModelDownloader::closeEvent(QCloseEvent *e) +{ + if (!this->mPrepareToClose) + e->ignore(); + else + QDialog::closeEvent(e); +} + +void ModelDownloader::close() +{ + this->mPrepareToClose = true; + + QDialog::close(); +} + +void ModelDownloader::update_progress(int progress) +{ + this->progress_bar->setValue(progress); +} + +void ModelDownloader::download_finished(const ModelInfo &info) +{ + this->model_info.localPath = info.localPath; + this->model_info.localSpmPath = info.localSpmPath; + + // Call the callback with the path to the downloaded model + this->download_finished_callback(0, this->model_info); + // Close the dialog + this->close(); +} + +void ModelDownloader::show_error(const std::string &reason) +{ + this->setWindowTitle("Download failed!"); + this->progress_bar->setFormat("Download failed!"); + this->progress_bar->setAlignment(Qt::AlignCenter); + this->progress_bar->setStyleSheet("QProgressBar::chunk { background-color: #FF0000; }"); + // Add a label to show the error + QLabel *error_label = new QLabel(this); + error_label->setText(QString::fromStdString(reason)); + error_label->setAlignment(Qt::AlignCenter); + // Color red + error_label->setStyleSheet("QLabel { color : red; }"); + this->layout->addWidget(error_label); + // Add a button to close the dialog + QPushButton *close_button = new QPushButton("Close", this); + this->layout->addWidget(close_button); + connect(close_button, &QPushButton::clicked, this, &ModelDownloader::close); + this->download_finished_callback(1, {}); +} + +ModelDownloadWorker::ModelDownloadWorker(const ModelInfo &model_info_) +{ + this->model_info = model_info_; +} + +std::string ModelDownloadWorker::download_file(CURL *curl, const std::string &url, + const std::string &path) +{ + // Check if the file already exists + if (std::filesystem::exists(path)) { + obs_log(LOG_INFO, "File already exists: %s", path.c_str()); + return ""; + } + + FILE *fp = fopen(path.c_str(), "wb"); + if (fp == nullptr) { + obs_log(LOG_ERROR, "Failed to open file %s.", path.c_str()); + return "Failed to open file."; + } + curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, fp); + curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, ModelDownloadWorker::progress_callback); + curl_easy_setopt(curl, CURLOPT_XFERINFODATA, this); + // Follow redirects + curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + obs_log(LOG_ERROR, "Failed to download model %s.", this->model_info.name.c_str()); + return "Failed to download model."; + } + fclose(fp); + return ""; +} + +void ModelDownloadWorker::download_model() +{ + char *module_config_path = obs_module_get_config_path(obs_current_module(), "models"); + // Check if the config folder exists + if (!std::filesystem::exists(module_config_path)) { + obs_log(LOG_WARNING, "Config folder does not exist: %s", module_config_path); + // Create the config folder + if (!std::filesystem::create_directories(module_config_path)) { + obs_log(LOG_ERROR, "Failed to create config folder: %s", + module_config_path); + emit download_error("Failed to create config folder."); + return; + } + } + + char *model_save_path_str = obs_module_get_config_path(obs_current_module(), "models"); + std::string model_save_path(model_save_path_str); + bfree(model_save_path_str); + obs_log(LOG_INFO, "Model save path: %s", model_save_path.c_str()); + + CURL *curl = curl_easy_init(); + if (curl) { + const std::string model_file_save_base_path = + model_save_path + "/" + this->model_info.name; + // Create if doesn't exist + if (!std::filesystem::exists(model_file_save_base_path)) { + if (!std::filesystem::create_directories(model_file_save_base_path)) { + obs_log(LOG_ERROR, "Failed to create model folder: %s", + model_file_save_base_path.c_str()); + emit download_error("Failed to create model folder."); + return; + } + } + + // Download the model files + for (const std::string &url : this->model_info.urls) { + // extract filename from URL that has querystring + std::string filename = url.substr(url.find_last_of("/\\") + 1); + filename = filename.substr(0, filename.find("?")); + + const std::string model_file_save_path = + model_file_save_base_path + "/" + filename; + std::string error = download_file(curl, url, model_file_save_path); + if (!error.empty()) { + emit download_error(error); + return; + } + } + obs_log(LOG_INFO, "Model downloaded to %s", model_file_save_base_path.c_str()); + + // Download the sentencepiece model + const std::string sp_model_file_save_path = + model_save_path + "/" + this->model_info.name + ".sp.model"; + std::string error = download_file(curl, model_info.spmUrl, sp_model_file_save_path); + if (!error.empty()) { + emit download_error(error); + return; + } + obs_log(LOG_INFO, "SPM downloaded to %s", sp_model_file_save_path.c_str()); + + // Save the model path to the model_info + this->model_info.localPath = model_file_save_base_path; + this->model_info.localSpmPath = sp_model_file_save_path; + + emit download_finished(this->model_info); + } else { + obs_log(LOG_ERROR, "Failed to initialize curl."); + emit download_error("Failed to initialize curl."); + } +} + +int ModelDownloadWorker::progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, + curl_off_t, curl_off_t) +{ + if (dltotal == 0) { + return 0; // Unknown progress + } + ModelDownloadWorker *worker = (ModelDownloadWorker *)clientp; + if (worker == nullptr) { + obs_log(LOG_ERROR, "Worker is null."); + return 1; + } + int progress = (int)(dlnow * 100l / dltotal); + emit worker->download_progress(progress); + return 0; +} + +ModelDownloader::~ModelDownloader() +{ + if (this->download_thread != nullptr) { + if (this->download_thread->isRunning()) { + this->download_thread->quit(); + this->download_thread->wait(); + } + delete this->download_thread; + } + delete this->download_worker; +} + +ModelDownloadWorker::~ModelDownloadWorker() +{ + // Do nothing +} diff --git a/src/model-utils/model-downloader-ui.h b/src/model-utils/model-downloader-ui.h new file mode 100644 index 0000000..7593603 --- /dev/null +++ b/src/model-utils/model-downloader-ui.h @@ -0,0 +1,63 @@ +#ifndef MODEL_DOWNLOADER_UI_H +#define MODEL_DOWNLOADER_UI_H + +#include +#include + +#include +#include + +#include + +#include "model-downloader-types.h" + +class ModelDownloadWorker : public QObject { + Q_OBJECT +public: + ModelDownloadWorker(const ModelInfo &model_info); + ~ModelDownloadWorker(); + +public slots: + void download_model(); + +signals: + void download_progress(int progress); + void download_finished(const ModelInfo &info); + void download_error(const std::string &reason); + +private: + static int progress_callback(void *clientp, curl_off_t dltotal, curl_off_t dlnow, + curl_off_t ultotal, curl_off_t ulnow); + std::string download_file(CURL *curl, const std::string &url, const std::string &path); + ModelInfo model_info; +}; + +class ModelDownloader : public QDialog { + Q_OBJECT +public: + ModelDownloader(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback, + QWidget *parent = nullptr); + ~ModelDownloader(); + +public slots: + void update_progress(int progress); + void download_finished(const ModelInfo &info); + void show_error(const std::string &reason); + +protected: + void closeEvent(QCloseEvent *e) override; + +private: + QVBoxLayout *layout; + QProgressBar *progress_bar; + QThread *download_thread; + ModelDownloadWorker *download_worker; + // Callback for when the download is finished + download_finished_callback_t download_finished_callback; + bool mPrepareToClose; + void close(); + ModelInfo model_info; +}; + +#endif // MODEL_DOWNLOADER_UI_H diff --git a/src/model-utils/model-downloader.cpp b/src/model-utils/model-downloader.cpp new file mode 100644 index 0000000..d652d4c --- /dev/null +++ b/src/model-utils/model-downloader.cpp @@ -0,0 +1,86 @@ +#include "model-downloader.h" +#include "plugin-support.h" +#include "model-downloader-ui.h" + +#include +#include + +#include +#include +#include +#include + +#include + +std::map models_info = { + {"M2M100 418M (495Mb)", + { + "m2m100-418m", + {"https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/model.bin?download=true", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/config.json?download=true", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/generation_config.json?download=true", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/shared_vocabulary.json?download=true", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/tokenizer_config.json?download=true", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/vocab.json?download=true"}, + "", + "https://huggingface.co/jncraton/m2m100_418M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "", + }}, + {"NLLB 200 Distilled 600M (625Mb)", + { + "nllb-200-distilled-600m", + {"https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/model.bin?download=true", + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/config.json?download=true", + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/shared_vocabulary.txt?download=true", + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/special_tokens_map.json?download=true", + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/tokenizer.json?download=true", + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/tokenizer_config.json?download=true"}, + "", + "https://huggingface.co/JustFrederik/nllb-200-distilled-600M-ct2-int8/resolve/main/sentencepiece.bpe.model?download=true", + "", + }}}; + +std::string find_model_file(const std::string &model_name) +{ + const char *model_name_cstr = model_name.c_str(); + obs_log(LOG_INFO, "Checking if model %s exists in data...", model_name_cstr); + + char *model_file_path = obs_module_file(model_name_cstr); + if (model_file_path == nullptr) { + obs_log(LOG_INFO, "Model %s not found in data.", model_name_cstr); + } else { + std::string model_file_path_str(model_file_path); + bfree(model_file_path); + if (!std::filesystem::exists(model_file_path_str)) { + obs_log(LOG_INFO, "Model not found in data: %s", + model_file_path_str.c_str()); + } else { + obs_log(LOG_INFO, "Model found in data: %s", model_file_path_str.c_str()); + return model_file_path_str; + } + } + + // Check if model exists in the config folder + char *model_config_path_str = + obs_module_get_config_path(obs_current_module(), model_name_cstr); + std::string model_config_path(model_config_path_str); + bfree(model_config_path_str); + obs_log(LOG_INFO, "Model path in config: %s", model_config_path.c_str()); + if (std::filesystem::exists(model_config_path)) { + obs_log(LOG_INFO, "Model exists in config folder: %s", model_config_path.c_str()); + return model_config_path; + } + + obs_log(LOG_INFO, "Model %s not found.", model_name_cstr); + return ""; +} + +void download_model_with_ui_dialog(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback) +{ + // Start the model downloader UI + ModelDownloader *model_downloader = new ModelDownloader( + model_info, download_finished_callback, (QWidget *)obs_frontend_get_main_window()); + model_downloader->show(); +} diff --git a/src/model-utils/model-downloader.h b/src/model-utils/model-downloader.h new file mode 100644 index 0000000..8de0173 --- /dev/null +++ b/src/model-utils/model-downloader.h @@ -0,0 +1,15 @@ +#ifndef MODEL_DOWNLOADER_H +#define MODEL_DOWNLOADER_H + +#include +#include + +#include "model-downloader-types.h" + +std::string find_model_file(const std::string &model_name); + +// Start the model downloader UI dialog with a callback for when the download is finished +void download_model_with_ui_dialog(const ModelInfo &model_info, + download_finished_callback_t download_finished_callback); + +#endif // MODEL_DOWNLOADER_H diff --git a/src/plugin-main.c b/src/plugin-main.c index 50e5f21..9c84c9b 100644 --- a/src/plugin-main.c +++ b/src/plugin-main.c @@ -30,6 +30,7 @@ bool obs_module_load(void) obs_log(LOG_INFO, "plugin loaded successfully (version %s)", PLUGIN_VERSION); resetContext(); + registerDock(); // load plugin settings from config if (loadConfig() == OBS_POLYGLOT_CONFIG_SUCCESS) { @@ -40,9 +41,10 @@ bool obs_module_load(void) // build the translation context if (build_translation_context() != OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS) { obs_log(LOG_ERROR, "Failed to build translation context"); + } else { + obs_log(LOG_INFO, "Built translation context"); } - registerDock(); return true; } diff --git a/src/translation-service/httpserver.cpp b/src/translation-service/httpserver.cpp index 34d3454..68a97be 100644 --- a/src/translation-service/httpserver.cpp +++ b/src/translation-service/httpserver.cpp @@ -81,6 +81,10 @@ void start_http_server() } obs_log(LOG_INFO, "Polyglot Http server stopped."); }).detach(); + + global_context.status_callback("Ready for requests at http://localhost:" + + std::to_string(global_config.http_server_port) + + "/translate"); } // stop the http server @@ -93,5 +97,6 @@ void stop_http_server() global_context.svr->stop(); delete global_context.svr; global_context.svr = nullptr; + global_context.status_callback(""); } } diff --git a/src/translation-service/translation.cpp b/src/translation-service/translation.cpp index d9eaffb..446e9f6 100644 --- a/src/translation-service/translation.cpp +++ b/src/translation-service/translation.cpp @@ -10,12 +10,17 @@ int build_translation_context() { obs_log(LOG_INFO, "Building translation context..."); + if (global_config.model_selection == 0) { + obs_log(LOG_INFO, "No model selected"); + global_context.error_callback("No model selected"); + return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; + } try { obs_log(LOG_INFO, "Loading SPM from %s", global_config.local_spm_path.c_str()); global_context.processor = new sentencepiece::SentencePieceProcessor(); const auto status = global_context.processor->Load(global_config.local_spm_path); if (!status.ok()) { - obs_log(LOG_ERROR, status.ToString().c_str()); + obs_log(LOG_ERROR, "Failed to load SPM: %s", status.ToString().c_str()); global_context.error_callback("Failed to load SPM. " + status.ToString()); return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; } @@ -44,11 +49,12 @@ int build_translation_context() global_context.options->use_vmap = true; global_context.options->return_scores = false; } catch (std::exception &e) { - obs_log(LOG_ERROR, "Error: %s", e.what()); + obs_log(LOG_ERROR, "Failed to load CT2 model: %s", e.what()); global_context.error_callback("Failed to load CT2 model. " + std::string(e.what())); return OBS_POLYGLOT_TRANSLATION_INIT_FAIL; } global_context.error_callback(""); // Clear any errors + global_context.status_callback("Translation engine ready"); return OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS; } diff --git a/src/ui/registerDock.cpp b/src/ui/registerDock.cpp index f94f977..28d9d4b 100644 --- a/src/ui/registerDock.cpp +++ b/src/ui/registerDock.cpp @@ -23,6 +23,10 @@ void registerDock() global_context.error_message = error_message; dock->updateErrorLabel(error_message); }; + global_context.status_callback = [=](const std::string &message) { + global_context.status_message = message; + dock->updateStatusLabel(message); + }; // Register the dock obs_frontend_add_dock(dock); diff --git a/src/ui/settingsdialog.cpp b/src/ui/settingsdialog.cpp index fb16817..4a7f577 100644 --- a/src/ui/settingsdialog.cpp +++ b/src/ui/settingsdialog.cpp @@ -3,6 +3,7 @@ #include "utils/config-data.h" #include "plugin-support.h" #include "translation-service/translation.h" +#include "model-utils/model-downloader.h" #include #include @@ -22,11 +23,77 @@ std::string getHomeDir() SettingsDialog::SettingsDialog(QWidget *parent) : QDialog(parent), ui(new Ui::SettingsDialog) { ui->setupUi(this); + + // add model selection from model infos + for (const auto &model : models_info) { + this->ui->comboBox_modelSelection->addItem(QString::fromStdString(model.first)); + } + // populate the UI with the current settings - this->ui->modelFile->setText(QString::fromStdString(global_config.local_model_path)); - this->ui->spmFile->setText(QString::fromStdString(global_config.local_spm_path)); + this->ui->comboBox_modelSelection->setCurrentIndex(global_config.model_selection); + this->ui->modelFile->setEnabled(false); + this->ui->modelFileBtn->setEnabled(false); + this->ui->spmFile->setEnabled(false); + this->ui->spmFileBtn->setEnabled(false); + + if (global_config.model_selection == 1) { + this->ui->modelFile->setEnabled(true); + this->ui->modelFileBtn->setEnabled(true); + this->ui->spmFile->setEnabled(true); + this->ui->spmFileBtn->setEnabled(true); + } + if (global_config.model_selection >= 1) { + this->ui->modelFile->setText( + QString::fromStdString(global_config.local_model_path)); + this->ui->spmFile->setText(QString::fromStdString(global_config.local_spm_path)); + } + this->ui->httpPort->setText(QString::number(global_config.http_server_port)); + connect(this->ui->comboBox_modelSelection, + QOverload::of(&QComboBox::currentIndexChanged), [=](int index) { + if (index == 0) { + // None + this->ui->modelFile->setEnabled(false); + this->ui->modelFileBtn->setEnabled(false); + this->ui->spmFile->setEnabled(false); + this->ui->spmFileBtn->setEnabled(false); + this->ui->modelFile->setText(""); + this->ui->spmFile->setText(""); + } else if (index == 1) { + // Custom + this->ui->modelFile->setEnabled(true); + this->ui->modelFileBtn->setEnabled(true); + this->ui->spmFile->setEnabled(true); + this->ui->spmFileBtn->setEnabled(true); + } else { + this->ui->modelFile->setEnabled(false); + this->ui->modelFileBtn->setEnabled(false); + this->ui->spmFile->setEnabled(false); + this->ui->spmFileBtn->setEnabled(false); + + // launch the model downloader + download_model_with_ui_dialog( + models_info[this->ui->comboBox_modelSelection->currentText() + .toStdString()], + [=](int download_status, const ModelInfo &modelInfo) { + if (download_status == 0) { + this->ui->modelFile->setText( + QString::fromStdString( + modelInfo.localPath)); + this->ui->spmFile->setText( + QString::fromStdString( + modelInfo.localSpmPath)); + } else { + obs_log(LOG_ERROR, + "Failed to download model"); + this->ui->comboBox_modelSelection + ->setCurrentIndex(0); + } + }); + } + }); + // Model folder selection dialog connect(this->ui->modelFileBtn, &QPushButton::clicked, this, [=]() { // Allow selection of a folder only @@ -57,6 +124,7 @@ SettingsDialog::SettingsDialog(QWidget *parent) : QDialog(parent), ui(new Ui::Se // connect to the dialog Save action to save the settings this->connect(this->ui->buttonBox, &QDialogButtonBox::accepted, this, [=]() { // get settings from UI into config struct + global_config.model_selection = this->ui->comboBox_modelSelection->currentIndex(); global_config.local_model_path = this->ui->modelFile->text().toStdString(); global_config.local_spm_path = this->ui->spmFile->text().toStdString(); global_config.http_server_port = this->ui->httpPort->text().toUShort(); diff --git a/src/ui/settingsdialog.ui b/src/ui/settingsdialog.ui index 7abd4fb..b980cbc 100644 --- a/src/ui/settingsdialog.ui +++ b/src/ui/settingsdialog.ui @@ -47,14 +47,17 @@ 3 - + + + false + - Model + Model File - + @@ -73,10 +76,17 @@ 0 - + + + false + + + + false + 0 @@ -91,14 +101,17 @@ - + + + false + - SPM + SPM File - + @@ -117,10 +130,17 @@ 0 - + + + false + + + + false + ... @@ -129,6 +149,33 @@ + + + + + 0 + 0 + + + + + Select a model + + + + + Custom (Files) + + + + + + + + Model + + + diff --git a/src/ui/translatedockwidget.cpp b/src/ui/translatedockwidget.cpp index 4bc3eda..febf759 100644 --- a/src/ui/translatedockwidget.cpp +++ b/src/ui/translatedockwidget.cpp @@ -10,6 +10,7 @@ TranslateDockWidget::TranslateDockWidget(QWidget *parent) ui->setupUi(this); ui->errorLabel->hide(); + ui->label_status->hide(); this->updateErrorLabel(global_context.error_message); // connect the settings button to the settings dialog @@ -39,6 +40,15 @@ void TranslateDockWidget::openSettingsDialog() settingsDialog->show(); } +void TranslateDockWidget::updateStatusLabel(const std::string &message) +{ + if (message.empty()) { + ui->label_status->hide(); + return; + } + ui->label_status->setText(QString::fromStdString(message)); +} + void TranslateDockWidget::updateErrorLabel(const std::string &error_message) { // if there is an error message, show the error label diff --git a/src/ui/translatedockwidget.h b/src/ui/translatedockwidget.h index 9550b5d..a3791d1 100644 --- a/src/ui/translatedockwidget.h +++ b/src/ui/translatedockwidget.h @@ -14,6 +14,7 @@ class TranslateDockWidget : public QDockWidget { explicit TranslateDockWidget(QWidget *parent = nullptr); ~TranslateDockWidget(); void updateErrorLabel(const std::string &error_message); + void updateStatusLabel(const std::string &message); private slots: void openSettingsDialog(); diff --git a/src/ui/translatedockwidget.ui b/src/ui/translatedockwidget.ui index f1e591f..62222e9 100644 --- a/src/ui/translatedockwidget.ui +++ b/src/ui/translatedockwidget.ui @@ -32,6 +32,13 @@ + + + + Status Label + + + diff --git a/src/utils/config-data.cpp b/src/utils/config-data.cpp index 67e6401..1423539 100644 --- a/src/utils/config-data.cpp +++ b/src/utils/config-data.cpp @@ -14,6 +14,7 @@ polyglot_global_context global_context; void config_defaults() { + global_config.model_selection = 0; global_config.local = true; global_config.local_model_path = ""; global_config.local_spm_path = ""; @@ -108,6 +109,7 @@ int loadConfig() std::string config_data_to_json(const polyglot_config_data &data) { nlohmann::json j; + j["model_selection"] = data.model_selection; j["local"] = data.local; j["local_model_path"] = data.local_model_path; j["local_spm_path"] = data.local_spm_path; @@ -124,6 +126,7 @@ polyglot_config_data config_data_from_json(const std::string &json) nlohmann::json j = nlohmann::json::parse(json); polyglot_config_data data; try { + data.model_selection = j["model_selection"]; data.local = j["local"]; data.local_model_path = j["local_model_path"]; data.local_spm_path = j["local_spm_path"]; @@ -146,7 +149,15 @@ void resetContext() global_context.svr = nullptr; global_context.error_callback = [](const std::string &error_message) { global_context.error_message = error_message; - obs_log(LOG_ERROR, "Error (callback): %s", error_message.c_str()); + if (!error_message.empty()) { + obs_log(LOG_ERROR, "Error (callback): %s", error_message.c_str()); + } + }; + global_context.status_callback = [](const std::string &message) { + global_context.status_message = message; + if (!message.empty()) { + obs_log(LOG_INFO, "Status (callback): %s", message.c_str()); + } }; global_context.tokenizer = [](const std::string &) { return std::vector(); }; global_context.detokenizer = [](const std::vector &) { return std::string(); }; diff --git a/src/utils/config-data.h b/src/utils/config-data.h index 5b4f90d..e8697a6 100644 --- a/src/utils/config-data.h +++ b/src/utils/config-data.h @@ -8,6 +8,8 @@ #include struct polyglot_config_data { + // model selection (0: none, 1: custom, 2+ preset models) + int model_selection; // local model path std::string local_model_path; // local spm path @@ -43,6 +45,8 @@ class Server; struct polyglot_global_context { // error message std::string error_message; + // status message + std::string status_message; // ctranslate2 options ctranslate2::TranslationOptions *options; // ctranslate2 translator @@ -55,6 +59,8 @@ struct polyglot_global_context { std::function &)> detokenizer; // error callback std::function error_callback; + // status callback + std::function status_callback; // http server httplib::Server *svr; }; From efcff5a1cd417d54ec2aab679271e67725c553ec Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 18 Mar 2024 23:48:36 -0400 Subject: [PATCH 2/7] Update function signatures and enable start/stop HTTP server button --- cmake/BuildCTranslate2.cmake | 21 +++++++--------- cmake/BuildSentencepiece.cmake | 10 ++++---- src/plugin-main.c | 4 +-- src/translation-service/httpserver.cpp | 17 +------------ src/translation-service/translation.cpp | 20 +++++++++++++++ src/translation-service/translation.h | 1 + src/ui/settingsdialog.cpp | 2 +- src/ui/translatedockwidget.cpp | 2 ++ src/utils/config-data.cpp | 33 ++++++++++++++----------- src/utils/config-data.h | 4 +-- 10 files changed, 61 insertions(+), 53 deletions(-) diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake index df37efb..36f1071 100644 --- a/cmake/BuildCTranslate2.cmake +++ b/cmake/BuildCTranslate2.cmake @@ -7,8 +7,8 @@ if(APPLE) FetchContent_Declare( ctranslate2_fetch - URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/libctranslate2-macos-Release-1.1.0.tar.gz - URL_HASH SHA256=dba2eaa1b3f4e9eb1e8999e668d515aa94b115af07565e2b6797b9eda6f2f845) + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libctranslate2-macos-Release-1.1.1.tar.gz + URL_HASH SHA256=da04d88ecc1ea105f8ee672e4eab33af96e50c999c5cc8170e105e110392182b) FetchContent_MakeAvailable(ctranslate2_fetch) add_library(ct2 INTERFACE) @@ -21,15 +21,17 @@ elseif(WIN32) FetchContent_Declare( ctranslate2_fetch - URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/libctranslate2-windows-4.1.1-Release.zip - URL_HASH SHA256=683023e9c76ac6d54e54d14c32d86c020d5486ba289b60e5336f7cc86b984d03) + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libctranslate2-windows-4.1.1-Release.zip + URL_HASH SHA256=aa87073e663e4dfbbc9b9360d83bed847212309bea433c3b1888a33a51cb0db0) FetchContent_MakeAvailable(ctranslate2_fetch) add_library(ct2 INTERFACE) - target_link_libraries(ct2 INTERFACE "-framework Accelerate" ${ctranslate2_fetch_SOURCE_DIR}/lib/libctranslate2.a - ${ctranslate2_fetch_SOURCE_DIR}/lib/libcpu_features.a) + target_link_libraries(ct2 INTERFACE ${ctranslate2_fetch_SOURCE_DIR}/lib/ctranslate2.lib) set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include) - target_compile_options(ct2 INTERFACE -Wno-shorten-64-to-32) + target_compile_options(ct2 INTERFACE /wd4267 /wd4244 /wd4305 /wd4996 /wd4099) + + install(FILES ${ctranslate2_fetch_SOURCE_DIR}/bin/ctranslate2.dll ${ctranslate2_fetch_SOURCE_DIR}/bin/libopenblas.dll + DESTINATION "obs-plugins/64bit") else() set(CT2_VERSION "4.1.1") @@ -65,11 +67,6 @@ else() ${CT2_CMAKE_PLATFORM_OPTIONS}) ExternalProject_Get_Property(ct2_build INSTALL_DIR) - # Get cpu_features from the CTranslate2 build - only for x86_64 builds if(APPLE) - # ExternalProject_Get_Property(ct2_build BINARY_DIR) add_library(ct2::cpu_features STATIC IMPORTED GLOBAL) - # set_target_properties( ct2::cpu_features PROPERTIES IMPORTED_LOCATION - # ${BINARY_DIR}/third_party/cpu_features/RelWithDebInfo/libcpu_features.a) endif() - add_library(ct2::ct2 STATIC IMPORTED GLOBAL) add_dependencies(ct2::ct2 ct2_build) set_target_properties(ct2::ct2 PROPERTIES IMPORTED_LOCATION ${INSTALL_DIR}/${CT2_LIB_INSTALL_LOCATION}) diff --git a/cmake/BuildSentencepiece.cmake b/cmake/BuildSentencepiece.cmake index 16bd724..024283e 100644 --- a/cmake/BuildSentencepiece.cmake +++ b/cmake/BuildSentencepiece.cmake @@ -6,8 +6,8 @@ if(APPLE) FetchContent_Declare( sentencepiece_fetch - URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/libsentencepiece-macos-Release-1.1.0.tar.gz - URL_HASH SHA256=168c9eead7ea77010c6e7867555da1b39433e5c002dc994f44abe96df7c71a66) + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libsentencepiece-macos-Release-1.1.1.tar.gz + URL_HASH SHA256=c911f1e84ea94925a8bc3fd3257185b2e18395075509c8659cc7003a979e0b32) FetchContent_MakeAvailable(sentencepiece_fetch) add_library(sentencepiece INTERFACE) target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/libsentencepiece.a) @@ -17,11 +17,11 @@ elseif(WIN32) FetchContent_Declare( sentencepiece_fetch - URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.0/sentencepiece-windows-0.2.0-Release.zip - URL_HASH SHA256=f45109b75929d1e35780d1b4ce4218002a3872352f494659b79488214daa987c) + URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/sentencepiece-windows-0.2.0-Release.zip + URL_HASH SHA256=846699c7fa1e8918b71ed7f2bd5cd60e47e51105e1d84e3192919b4f0f10fdeb) FetchContent_MakeAvailable(sentencepiece_fetch) add_library(sentencepiece INTERFACE) - target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/libsentencepiece.a) + target_link_libraries(sentencepiece INTERFACE ${sentencepiece_fetch_SOURCE_DIR}/lib/sentencepiece.lib) set_target_properties(sentencepiece PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${sentencepiece_fetch_SOURCE_DIR}/include) diff --git a/src/plugin-main.c b/src/plugin-main.c index 9c84c9b..5931821 100644 --- a/src/plugin-main.c +++ b/src/plugin-main.c @@ -29,7 +29,7 @@ bool obs_module_load(void) { obs_log(LOG_INFO, "plugin loaded successfully (version %s)", PLUGIN_VERSION); - resetContext(); + resetContext(true); registerDock(); // load plugin settings from config @@ -50,6 +50,6 @@ bool obs_module_load(void) void obs_module_unload(void) { - freeContext(); + freeContext(true); obs_log(LOG_INFO, "plugin unloaded"); } diff --git a/src/translation-service/httpserver.cpp b/src/translation-service/httpserver.cpp index 68a97be..473243f 100644 --- a/src/translation-service/httpserver.cpp +++ b/src/translation-service/httpserver.cpp @@ -45,24 +45,9 @@ void start_http_server() body.append(data, data_length); return true; }); - std::string input_text; - std::string source_lang; - std::string target_lang; - // parse body json - try { - nlohmann::json j = nlohmann::json::parse(body); - input_text = j["text"]; - source_lang = j["source_lang"]; - target_lang = j["target_lang"]; - } catch (std::exception &e) { - obs_log(LOG_ERROR, "Error: %s", e.what()); - res.set_content("Error parsing json", "text/plain"); - res.status = 500; - return; - } std::string result; - int ret = translate(input_text, source_lang, target_lang, result); + int ret = translate_from_json(body, result); if (ret == OBS_POLYGLOT_TRANSLATION_SUCCESS) { res.set_content(result, "text/plain"); } else { diff --git a/src/translation-service/translation.cpp b/src/translation-service/translation.cpp index 446e9f6..3cd422c 100644 --- a/src/translation-service/translation.cpp +++ b/src/translation-service/translation.cpp @@ -58,6 +58,26 @@ int build_translation_context() return OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS; } +int translate_from_json(const std::string &body, std::string &result) +{ + std::string input_text; + std::string source_lang; + std::string target_lang; + // parse body json + try { + nlohmann::json j = nlohmann::json::parse(body); + input_text = j["text"]; + source_lang = j["source_lang"]; + target_lang = j["target_lang"]; + } catch (std::exception &e) { + obs_log(LOG_ERROR, "Error: %s", e.what()); + result = "Error parsing json"; + return OBS_POLYGLOT_TRANSLATION_FAIL; + } + + return translate(input_text, source_lang, target_lang, result); +} + int translate(const std::string &text, const std::string &source_lang, const std::string &target_lang, std::string &result) { diff --git a/src/translation-service/translation.h b/src/translation-service/translation.h index 8a5f637..82393ef 100644 --- a/src/translation-service/translation.h +++ b/src/translation-service/translation.h @@ -12,6 +12,7 @@ int build_translation_context(); #include int translate(const std::string &text, const std::string &source_lang, const std::string &target_lang, std::string &result); +int translate_from_json(const std::string &body, std::string &result); #endif #define OBS_POLYGLOT_TRANSLATION_INIT_FAIL -1 diff --git a/src/ui/settingsdialog.cpp b/src/ui/settingsdialog.cpp index 4a7f577..7e5c0ae 100644 --- a/src/ui/settingsdialog.cpp +++ b/src/ui/settingsdialog.cpp @@ -134,7 +134,7 @@ SettingsDialog::SettingsDialog(QWidget *parent) : QDialog(parent), ui(new Ui::Se obs_log(LOG_INFO, "Saved settings"); // update the plugin - freeContext(); + freeContext(false); if (build_translation_context() == OBS_POLYGLOT_TRANSLATION_INIT_SUCCESS) { obs_log(LOG_INFO, "Translation context updated"); } else { diff --git a/src/ui/translatedockwidget.cpp b/src/ui/translatedockwidget.cpp index febf759..743d38a 100644 --- a/src/ui/translatedockwidget.cpp +++ b/src/ui/translatedockwidget.cpp @@ -62,5 +62,7 @@ void TranslateDockWidget::updateErrorLabel(const std::string &error_message) ui->startStopHTTPServer->setEnabled(false); } else { ui->errorLabel->hide(); + // enable the start/stop http server button + ui->startStopHTTPServer->setEnabled(true); } } diff --git a/src/utils/config-data.cpp b/src/utils/config-data.cpp index 1423539..83728ab 100644 --- a/src/utils/config-data.cpp +++ b/src/utils/config-data.cpp @@ -140,30 +140,33 @@ polyglot_config_data config_data_from_json(const std::string &json) return data; } -void resetContext() +void resetContext(bool resetCallbacks) { global_context.error_message = ""; global_context.options = nullptr; global_context.translator = nullptr; global_context.processor = nullptr; global_context.svr = nullptr; - global_context.error_callback = [](const std::string &error_message) { - global_context.error_message = error_message; - if (!error_message.empty()) { - obs_log(LOG_ERROR, "Error (callback): %s", error_message.c_str()); - } - }; - global_context.status_callback = [](const std::string &message) { - global_context.status_message = message; - if (!message.empty()) { - obs_log(LOG_INFO, "Status (callback): %s", message.c_str()); - } - }; + if (resetCallbacks) { + global_context.error_callback = [](const std::string &error_message) { + global_context.error_message = error_message; + if (!error_message.empty()) { + obs_log(LOG_ERROR, "Error (vanilla callback): %s", + error_message.c_str()); + } + }; + global_context.status_callback = [](const std::string &message) { + global_context.status_message = message; + if (!message.empty()) { + obs_log(LOG_INFO, "Status (vanilla callback): %s", message.c_str()); + } + }; + } global_context.tokenizer = [](const std::string &) { return std::vector(); }; global_context.detokenizer = [](const std::vector &) { return std::string(); }; } -void freeContext() +void freeContext(bool resetCallbacks) { if (global_context.options != nullptr) { delete global_context.options; @@ -181,5 +184,5 @@ void freeContext() delete global_context.svr; global_context.svr = nullptr; } - resetContext(); + resetContext(resetCallbacks); } diff --git a/src/utils/config-data.h b/src/utils/config-data.h index e8697a6..c4f8589 100644 --- a/src/utils/config-data.h +++ b/src/utils/config-data.h @@ -77,8 +77,8 @@ extern polyglot_global_context global_context; extern "C" { #endif -void resetContext(); -void freeContext(); +void resetContext(bool resetCallbacks); +void freeContext(bool resetCallbacks); int saveConfig(bool create_if_not_exist); int loadConfig(); From 26a1a27c56e7c87ff2407e0ca5f01d03c27229fc Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Mon, 18 Mar 2024 23:56:47 -0400 Subject: [PATCH 3/7] Update brew bundle command and fix variable naming in ModelDownloader constructor --- .github/scripts/utils.zsh/check_macos | 2 +- src/model-utils/model-downloader-ui.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/scripts/utils.zsh/check_macos b/.github/scripts/utils.zsh/check_macos index 54b5fbf..9c63496 100644 --- a/.github/scripts/utils.zsh/check_macos +++ b/.github/scripts/utils.zsh/check_macos @@ -17,6 +17,6 @@ if (( ! ${+commands[brew]} )) { return 2 } -brew bundle --file ${SCRIPT_HOME}/.Brewfile +brew bundle --no-upgrade --file ${SCRIPT_HOME}/.Brewfile rehash log_group diff --git a/src/model-utils/model-downloader-ui.cpp b/src/model-utils/model-downloader-ui.cpp index aeea442..7ed7121 100644 --- a/src/model-utils/model-downloader-ui.cpp +++ b/src/model-utils/model-downloader-ui.cpp @@ -13,12 +13,12 @@ size_t write_data(void *ptr, size_t size, size_t nmemb, FILE *stream) return written; } -ModelDownloader::ModelDownloader(const ModelInfo &model_info, +ModelDownloader::ModelDownloader(const ModelInfo &model_info_, download_finished_callback_t download_finished_callback_, QWidget *parent) : QDialog(parent), download_finished_callback(download_finished_callback_), - model_info(model_info) + model_info(model_info_) { this->setWindowTitle("Downloading model..."); this->setWindowFlags(Qt::Dialog | Qt::WindowTitleHint | Qt::CustomizeWindowHint); From 52d6345a7ccc5918cda54d1268e13885df04345c Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 19 Mar 2024 00:03:13 -0400 Subject: [PATCH 4/7] Remove deprecated declaration warning flag in compilerconfig.cmake --- cmake/linux/compilerconfig.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/linux/compilerconfig.cmake b/cmake/linux/compilerconfig.cmake index 67d2b74..bb77721 100644 --- a/cmake/linux/compilerconfig.cmake +++ b/cmake/linux/compilerconfig.cmake @@ -13,7 +13,6 @@ set(_obs_gcc_c_options # cmake-format: sortable -fno-strict-aliasing -fopenmp-simd - -Wdeprecated-declarations -Wempty-body -Wenum-conversion -Werror=return-type @@ -21,6 +20,7 @@ set(_obs_gcc_c_options -Wformat -Wformat-security -Wno-conversion + -Wno-error=deprecated-declarations -Wno-float-conversion -Wno-implicit-fallthrough -Wno-missing-braces From 3e32cb614d537302bd122f3dbad287a26ebd3d38 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 19 Mar 2024 00:14:02 -0400 Subject: [PATCH 5/7] Add -Wdeprecated-declarations flag to gcc options --- cmake/linux/compilerconfig.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/linux/compilerconfig.cmake b/cmake/linux/compilerconfig.cmake index bb77721..3d1c95f 100644 --- a/cmake/linux/compilerconfig.cmake +++ b/cmake/linux/compilerconfig.cmake @@ -13,6 +13,7 @@ set(_obs_gcc_c_options # cmake-format: sortable -fno-strict-aliasing -fopenmp-simd + -Wdeprecated-declarations -Wempty-body -Wenum-conversion -Werror=return-type From a5f7819397bb291a0603ba7c010caa19450ff972 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 19 Mar 2024 10:22:03 -0400 Subject: [PATCH 6/7] Remove deprecated declaration warning flag in compilerconfig.cmake --- cmake/linux/compilerconfig.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/linux/compilerconfig.cmake b/cmake/linux/compilerconfig.cmake index 3d1c95f..647c4b3 100644 --- a/cmake/linux/compilerconfig.cmake +++ b/cmake/linux/compilerconfig.cmake @@ -13,7 +13,6 @@ set(_obs_gcc_c_options # cmake-format: sortable -fno-strict-aliasing -fopenmp-simd - -Wdeprecated-declarations -Wempty-body -Wenum-conversion -Werror=return-type @@ -21,6 +20,7 @@ set(_obs_gcc_c_options -Wformat -Wformat-security -Wno-conversion + -Wno-deprecated-declarations -Wno-error=deprecated-declarations -Wno-float-conversion -Wno-implicit-fallthrough From 91776d60a9e41995a57520200fa59892692ba31e Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Tue, 19 Mar 2024 10:33:53 -0400 Subject: [PATCH 7/7] Add CT2_LIB_INSTALL_LOCATION to cmake options --- cmake/BuildCTranslate2.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/BuildCTranslate2.cmake b/cmake/BuildCTranslate2.cmake index 36f1071..9ce13df 100644 --- a/cmake/BuildCTranslate2.cmake +++ b/cmake/BuildCTranslate2.cmake @@ -39,6 +39,7 @@ else() set(CT2_OPENBLAS_CMAKE_ARGS -DWITH_OPENBLAS=OFF) set(CT2_CMAKE_PLATFORM_OPTIONS -DBUILD_SHARED_LIBS=OFF -DOPENMP_RUNTIME=NONE -DCMAKE_POSITION_INDEPENDENT_CODE=ON) + set(CT2_LIB_INSTALL_LOCATION lib/${CMAKE_SHARED_LIBRARY_PREFIX}ctranslate2${CMAKE_STATIC_LIBRARY_SUFFIX}) ExternalProject_Add( ct2_build