Skip to content
This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Commit

Permalink
Merge pull request #9 from occ-ai/roy.cuda_build_windows
Browse files Browse the repository at this point in the history
Add CUDA matrix strategy and CUDA toolkit setup
  • Loading branch information
royshil authored Mar 26, 2024
2 parents 716e53f + 5a47bc9 commit 1cb4708
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 22 deletions.
9 changes: 8 additions & 1 deletion .github/scripts/Package-Windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ function Package {
$ProductName = $BuildSpec.name
$ProductVersion = $BuildSpec.version

$OutputName = "${ProductName}-${ProductVersion}-windows-${Target}"
# check the CPU_OR_CUDA env variable to determine the target
if ( $Env:CPU_OR_CUDA -eq 'cpu' ) {
$cudaName = 'cpu'
} else {
$cudaName = "cuda${Env:CPU_OR_CUDA}"
}

$OutputName = "${ProductName}-${ProductVersion}-windows-${Target}-${cudaName}"

if ( ! $SkipDeps ) {
Install-BuildDependencies -WingetFile "${ScriptHome}/.Wingetfile"
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/build-project.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ jobs:
name: Build for Windows 🪟
runs-on: windows-2022
needs: check-event
strategy:
matrix:
cuda: [ 'cpu', '12.2.0', '11.8.0' ]
defaults:
run:
shell: pwsh
Expand Down Expand Up @@ -248,13 +251,17 @@ jobs:
with:
target: x64
config: ${{ needs.check-event.outputs.config }}
env:
CPU_OR_CUDA: ${{ matrix.cuda }}

- name: Package Plugin 📀
uses: ./.github/actions/package-plugin
with:
target: x64
config: ${{ needs.check-event.outputs.config }}
package: ${{ fromJSON(needs.check-event.outputs.package) }}
env:
CPU_OR_CUDA: ${{ matrix.cuda }}

- name: Upload Artifacts 📡
uses: actions/upload-artifact@v3
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ jobs:
commit_hash="${GITHUB_SHA:0:9}"
variants=(
'windows-x64;zip|exe'
'windows-x64-cpu;zip|exe'
'windows-x64-cuda12.2.0;zip|exe'
'windows-x64-cuda11.8.0;zip|exe'
'macos-universal;tar.xz|pkg'
'ubuntu-22.04-x86_64;tar.xz|deb|ddeb'
'sources;tar.xz'
Expand Down
22 changes: 11 additions & 11 deletions buildspec.json
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
{
"dependencies": {
"obs-studio": {
"version": "29.1.2",
"version": "29.1.3",
"baseUrl": "https://github.com/obsproject/obs-studio/archive/refs/tags",
"label": "OBS sources",
"hashes": {
"macos": "215f1fa5772c5dd9f3d6e35b0cb573912b00320149666a77864f9d305525504b",
"windows-x64": "46d451f3f42b9d2c59339ec268165849c7b7904cdf1cc2a8d44c015815a9e37d"
"macos": "9D9CFBDBDD255F48A23FEEEFB60089769A65F52BBCA24FA31D74125F3BBB0E90",
"windows-x64": "965334470E447DC164801F8812D583260761521E6E3C5EBEE1DA7CD8F6EC4A95"
}
},
"prebuilt": {
"version": "2023-04-12",
"version": "2023-06-22",
"baseUrl": "https://github.com/obsproject/obs-deps/releases/download",
"label": "Pre-Built obs-deps",
"hashes": {
"macos": "9535c6e1ad96f7d49960251e85a245774088d48da1d602bb82f734b10219125a",
"windows-x64": "c13a14a1acc4224b21304d97b63da4121de1ed6981297e50496fbc474abc0503"
"macos": "a0d2e03f0ea79681634c31627430a220d9b62113d6ff58174d0bdab6fafdd32b",
"windows-x64": "1b12e86e2d62a97a889866d66b95fe47ddc6f7fa9b13e88aedfab4ea9e298ea2"
}
},
"qt6": {
"version": "2023-04-12",
"version": "2023-06-22",
"baseUrl": "https://github.com/obsproject/obs-deps/releases/download",
"label": "Pre-Built Qt6",
"hashes": {
"macos": "eb7614544ab4f3d2c6052c797635602280ca5b028a6b987523d8484222ce45d1",
"windows-x64": "4d39364b8a8dee5aa24fcebd8440d5c22bb4551c6b440ffeacce7d61f2ed1add"
"macos": "f890d258a1afa7ba409b79c8ee55d53155e5c72105b8b18a3f52047ee70fc0aa",
"windows-x64": "1907fbcbcef69527154b29316c425b0885afb77ad69a9a2af7a1471d79512195"
},
"debugSymbols": {
"windows-x64": "f34ee5067be19ed370268b15c53684b7b8aaa867dc800b68931df905d679e31f"
"windows-x64": "b461a7ade0c099505baea857fa5b98c4f8e9b702681be019ea354735d062e065"
}
}
},
Expand All @@ -45,7 +45,7 @@
}
},
"name": "obs-polyglot",
"version": "0.0.2",
"version": "0.0.3",
"author": "Roy Shilkrot",
"website": "https://github.com/occ-ai/obs-polyglot",
"email": "[email protected]",
Expand Down
39 changes: 32 additions & 7 deletions cmake/BuildCTranslate2.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,45 @@ if(APPLE)

elseif(WIN32)

FetchContent_Declare(
ctranslate2_fetch
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libctranslate2-windows-4.1.1-Release.zip
URL_HASH SHA256=aa87073e663e4dfbbc9b9360d83bed847212309bea433c3b1888a33a51cb0db0)
# check CPU_OR_CUDA environment variable
if(NOT DEFINED ENV{CPU_OR_CUDA})
message(FATAL_ERROR "Please set the CPU_OR_CUDA environment variable to either CPU or CUDA")
endif()

if($ENV{CPU_OR_CUDA} STREQUAL "cpu")
FetchContent_Declare(
ctranslate2_fetch
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cpu.zip
URL_HASH SHA256=30ff8b2499b8d3b5a6c4d6f7f8ddbc89e745ff06e0050b645e3b7c9b369451a3)
else()
# add compile definitions for CUDA
add_compile_definitions(POLYGLOT_WITH_CUDA)
add_compile_definitions(POLYGLOT_CUDA_VERSION=$ENV{CPU_OR_CUDA})

if($ENV{CPU_OR_CUDA} STREQUAL "12.2.0")
FetchContent_Declare(
ctranslate2_fetch
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cuda12.2.0.zip
URL_HASH SHA256=131724d510f9f2829970953a1bc9e4e8fb7b4cbc8218e32270dcfe6172a51558)
elseif($ENV{CPU_OR_CUDA} STREQUAL "11.8.0")
FetchContent_Declare(
ctranslate2_fetch
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cuda11.8.0.zip
URL_HASH SHA256=a120bee82f821df35a4646add30ac18b5c23e4e16b56fa7ba338eeae336e0d81)
else()
message(FATAL_ERROR "Unsupported CUDA version: $ENV{CPU_OR_CUDA}")
endif()
endif()

FetchContent_MakeAvailable(ctranslate2_fetch)

add_library(ct2 INTERFACE)
target_link_libraries(ct2 INTERFACE ${ctranslate2_fetch_SOURCE_DIR}/lib/ctranslate2.lib)
set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include)
target_compile_options(ct2 INTERFACE /wd4267 /wd4244 /wd4305 /wd4996 /wd4099)

install(FILES ${ctranslate2_fetch_SOURCE_DIR}/bin/ctranslate2.dll ${ctranslate2_fetch_SOURCE_DIR}/bin/libopenblas.dll
DESTINATION "obs-plugins/64bit")

file(GLOB CT2_DLLS ${ctranslate2_fetch_SOURCE_DIR}/bin/*.dll)
install(FILES ${CT2_DLLS} DESTINATION "obs-plugins/64bit")
else()
set(CT2_VERSION "4.1.1")
set(CT2_URL "https://github.com/OpenNMT/CTranslate2.git")
Expand Down
9 changes: 9 additions & 0 deletions src/translation-service/httpserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ void start_http_server()
}
global_context.svr = new httplib::Server();

global_context.svr->set_pre_routing_handler([](const httplib::Request &,
httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
res.set_header("Access-Control-Allow-Headers",
"Content-Type, Authorization");
return httplib::Server::HandlerResponse::Unhandled;
});

// set an echo handler
global_context.svr->Post("/echo", [](const httplib::Request &req,
httplib::Response &res,
Expand Down
12 changes: 10 additions & 2 deletions src/translation-service/translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ int build_translation_context()

obs_log(LOG_INFO, "Loading CT2 model from %s",
global_config.local_model_path.c_str());

#ifdef POLYGLOT_WITH_CUDA
ctranslate2::Device device = ctranslate2::Device::CUDA;
obs_log(LOG_INFO, "Using CUDA");
#else
ctranslate2::Device device = ctranslate2::Device::CPU;
obs_log(LOG_INFO, "Using CPU");
#endif

global_context.translator = new ctranslate2::Translator(
global_config.local_model_path, ctranslate2::Device::CPU,
ctranslate2::ComputeType::AUTO);
global_config.local_model_path, device, ctranslate2::ComputeType::AUTO);
obs_log(LOG_INFO, "CT2 Model loaded");

global_context.options = new ctranslate2::TranslationOptions;
Expand Down

0 comments on commit 1cb4708

Please sign in to comment.