Skip to content
This repository has been archived by the owner on Dec 18, 2024. It is now read-only.

Add CUDA matrix strategy and CUDA toolkit setup #9

Merged
merged 8 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .github/scripts/Package-Windows.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,14 @@ function Package {
$ProductName = $BuildSpec.name
$ProductVersion = $BuildSpec.version

$OutputName = "${ProductName}-${ProductVersion}-windows-${Target}"
# Derive the artifact-name suffix from the CPU_OR_CUDA environment variable:
# 'cpu' is used as-is, any other value is treated as a CUDA version and
# prefixed with 'cuda'.
$cudaName = if ( $Env:CPU_OR_CUDA -eq 'cpu' ) { 'cpu' } else { "cuda${Env:CPU_OR_CUDA}" }

$OutputName = "${ProductName}-${ProductVersion}-windows-${Target}-${cudaName}"

if ( ! $SkipDeps ) {
Install-BuildDependencies -WingetFile "${ScriptHome}/.Wingetfile"
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/build-project.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ jobs:
name: Build for Windows 🪟
runs-on: windows-2022
needs: check-event
strategy:
matrix:
cuda: [ 'cpu', '12.2.0', '11.8.0' ]
defaults:
run:
shell: pwsh
Expand Down Expand Up @@ -248,13 +251,17 @@ jobs:
with:
target: x64
config: ${{ needs.check-event.outputs.config }}
env:
CPU_OR_CUDA: ${{ matrix.cuda }}

- name: Package Plugin 📀
uses: ./.github/actions/package-plugin
with:
target: x64
config: ${{ needs.check-event.outputs.config }}
package: ${{ fromJSON(needs.check-event.outputs.package) }}
env:
CPU_OR_CUDA: ${{ matrix.cuda }}

- name: Upload Artifacts 📡
uses: actions/upload-artifact@v3
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ jobs:
commit_hash="${GITHUB_SHA:0:9}"

variants=(
'windows-x64;zip|exe'
'windows-x64-cpu;zip|exe'
'windows-x64-cuda12.2.0;zip|exe'
'windows-x64-cuda11.8.0;zip|exe'
'macos-universal;tar.xz|pkg'
'ubuntu-22.04-x86_64;tar.xz|deb|ddeb'
'sources;tar.xz'
Expand Down
22 changes: 11 additions & 11 deletions buildspec.json
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
{
"dependencies": {
"obs-studio": {
"version": "29.1.2",
"version": "29.1.3",
"baseUrl": "https://github.com/obsproject/obs-studio/archive/refs/tags",
"label": "OBS sources",
"hashes": {
"macos": "215f1fa5772c5dd9f3d6e35b0cb573912b00320149666a77864f9d305525504b",
"windows-x64": "46d451f3f42b9d2c59339ec268165849c7b7904cdf1cc2a8d44c015815a9e37d"
"macos": "9D9CFBDBDD255F48A23FEEEFB60089769A65F52BBCA24FA31D74125F3BBB0E90",
"windows-x64": "965334470E447DC164801F8812D583260761521E6E3C5EBEE1DA7CD8F6EC4A95"
}
},
"prebuilt": {
"version": "2023-04-12",
"version": "2023-06-22",
"baseUrl": "https://github.com/obsproject/obs-deps/releases/download",
"label": "Pre-Built obs-deps",
"hashes": {
"macos": "9535c6e1ad96f7d49960251e85a245774088d48da1d602bb82f734b10219125a",
"windows-x64": "c13a14a1acc4224b21304d97b63da4121de1ed6981297e50496fbc474abc0503"
"macos": "a0d2e03f0ea79681634c31627430a220d9b62113d6ff58174d0bdab6fafdd32b",
"windows-x64": "1b12e86e2d62a97a889866d66b95fe47ddc6f7fa9b13e88aedfab4ea9e298ea2"
}
},
"qt6": {
"version": "2023-04-12",
"version": "2023-06-22",
"baseUrl": "https://github.com/obsproject/obs-deps/releases/download",
"label": "Pre-Built Qt6",
"hashes": {
"macos": "eb7614544ab4f3d2c6052c797635602280ca5b028a6b987523d8484222ce45d1",
"windows-x64": "4d39364b8a8dee5aa24fcebd8440d5c22bb4551c6b440ffeacce7d61f2ed1add"
"macos": "f890d258a1afa7ba409b79c8ee55d53155e5c72105b8b18a3f52047ee70fc0aa",
"windows-x64": "1907fbcbcef69527154b29316c425b0885afb77ad69a9a2af7a1471d79512195"
},
"debugSymbols": {
"windows-x64": "f34ee5067be19ed370268b15c53684b7b8aaa867dc800b68931df905d679e31f"
"windows-x64": "b461a7ade0c099505baea857fa5b98c4f8e9b702681be019ea354735d062e065"
}
}
},
Expand All @@ -45,7 +45,7 @@
}
},
"name": "obs-polyglot",
"version": "0.0.2",
"version": "0.0.3",
"author": "Roy Shilkrot",
"website": "https://github.com/occ-ai/obs-polyglot",
"email": "[email protected]",
Expand Down
39 changes: 32 additions & 7 deletions cmake/BuildCTranslate2.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,45 @@ if(APPLE)

elseif(WIN32)

FetchContent_Declare(
ctranslate2_fetch
URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.1.1/libctranslate2-windows-4.1.1-Release.zip
URL_HASH SHA256=aa87073e663e4dfbbc9b9360d83bed847212309bea433c3b1888a33a51cb0db0)
# Select which prebuilt CTranslate2 package to download, based on the
# CPU_OR_CUDA environment variable. Accepted values are "cpu" (CPU-only build)
# or a supported CUDA toolkit version: "12.2.0" or "11.8.0".
if(NOT DEFINED ENV{CPU_OR_CUDA})
  message(FATAL_ERROR "Please set the CPU_OR_CUDA environment variable to one of: cpu, 12.2.0, 11.8.0")
endif()

# The environment value is quoted in every comparison below. Unquoted, an empty
# value would leave if() without a left-hand operand (a hard syntax error), and
# an unusual value could be re-interpreted as a variable name.
if("$ENV{CPU_OR_CUDA}" STREQUAL "cpu")
  FetchContent_Declare(
    ctranslate2_fetch
    URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cpu.zip
    URL_HASH SHA256=30ff8b2499b8d3b5a6c4d6f7f8ddbc89e745ff06e0050b645e3b7c9b369451a3)
else()
  # Any non-"cpu" value is treated as a CUDA build: expose that to the sources
  # through compile definitions.
  add_compile_definitions(POLYGLOT_WITH_CUDA)
  add_compile_definitions(POLYGLOT_CUDA_VERSION=$ENV{CPU_OR_CUDA})

  if("$ENV{CPU_OR_CUDA}" STREQUAL "12.2.0")
    FetchContent_Declare(
      ctranslate2_fetch
      URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cuda12.2.0.zip
      URL_HASH SHA256=131724d510f9f2829970953a1bc9e4e8fb7b4cbc8218e32270dcfe6172a51558)
  elseif("$ENV{CPU_OR_CUDA}" STREQUAL "11.8.0")
    FetchContent_Declare(
      ctranslate2_fetch
      URL https://github.com/occ-ai/obs-ai-ctranslate2-dep/releases/download/1.2.0/libctranslate2-windows-4.1.1-Release-cuda11.8.0.zip
      URL_HASH SHA256=a120bee82f821df35a4646add30ac18b5c23e4e16b56fa7ba338eeae336e0d81)
  else()
    # Only the CUDA versions with a published prebuilt package are accepted.
    message(FATAL_ERROR "Unsupported CUDA version: $ENV{CPU_OR_CUDA}")
  endif()
endif()

FetchContent_MakeAvailable(ctranslate2_fetch)

add_library(ct2 INTERFACE)
target_link_libraries(ct2 INTERFACE ${ctranslate2_fetch_SOURCE_DIR}/lib/ctranslate2.lib)
set_target_properties(ct2 PROPERTIES INTERFACE_INCLUDE_DIRECTORIES ${ctranslate2_fetch_SOURCE_DIR}/include)
target_compile_options(ct2 INTERFACE /wd4267 /wd4244 /wd4305 /wd4996 /wd4099)

install(FILES ${ctranslate2_fetch_SOURCE_DIR}/bin/ctranslate2.dll ${ctranslate2_fetch_SOURCE_DIR}/bin/libopenblas.dll
DESTINATION "obs-plugins/64bit")

file(GLOB CT2_DLLS ${ctranslate2_fetch_SOURCE_DIR}/bin/*.dll)
install(FILES ${CT2_DLLS} DESTINATION "obs-plugins/64bit")
else()
set(CT2_VERSION "4.1.1")
set(CT2_URL "https://github.com/OpenNMT/CTranslate2.git")
Expand Down
9 changes: 9 additions & 0 deletions src/translation-service/httpserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ void start_http_server()
}
global_context.svr = new httplib::Server();

// Register a pre-routing handler that attaches CORS headers to the response:
// any origin ("*") is allowed, with the GET/POST/OPTIONS methods and the
// Content-Type/Authorization request headers.
// NOTE(review): this only sets headers; whether OPTIONS preflight requests get
// a successful response depends on routes registered elsewhere - confirm.
global_context.svr->set_pre_routing_handler([](const httplib::Request &,
httplib::Response &res) {
res.set_header("Access-Control-Allow-Origin", "*");
res.set_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS");
res.set_header("Access-Control-Allow-Headers",
"Content-Type, Authorization");
// Report the request as not handled here; presumably httplib then continues
// with its normal route matching - verify against the cpp-httplib docs.
return httplib::Server::HandlerResponse::Unhandled;
});

// set an echo handler
global_context.svr->Post("/echo", [](const httplib::Request &req,
httplib::Response &res,
Expand Down
12 changes: 10 additions & 2 deletions src/translation-service/translation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,17 @@ int build_translation_context()

obs_log(LOG_INFO, "Loading CT2 model from %s",
global_config.local_model_path.c_str());

// Pick the CTranslate2 device at compile time: CUDA when the build defines
// POLYGLOT_WITH_CUDA, CPU otherwise. The choice is logged either way.
// NOTE(review): there is no runtime fallback to CPU visible here if CUDA is
// unavailable on the user's machine - confirm how that case is handled.
#ifdef POLYGLOT_WITH_CUDA
ctranslate2::Device device = ctranslate2::Device::CUDA;
obs_log(LOG_INFO, "Using CUDA");
#else
ctranslate2::Device device = ctranslate2::Device::CPU;
obs_log(LOG_INFO, "Using CPU");
#endif

global_context.translator = new ctranslate2::Translator(
global_config.local_model_path, ctranslate2::Device::CPU,
ctranslate2::ComputeType::AUTO);
global_config.local_model_path, device, ctranslate2::ComputeType::AUTO);
obs_log(LOG_INFO, "CT2 Model loaded");

global_context.options = new ctranslate2::TranslationOptions;
Expand Down
Loading