-
Notifications
You must be signed in to change notification settings - Fork 217
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update the build system to support Cuda. (#75)
* update the build system to support cuda. * add a google colab example for Cuda test. * enable CI for the cuda_draft branch. * resolve some comments.
- Loading branch information
1 parent
ae3da3e
commit fda3b83
Showing
20 changed files
with
492 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Copyright (c) 2020 Fangjun Kuang ([email protected]) | ||
# See ../LICENSE for clarification regarding multiple authors | ||
|
||
function(download_cub) | ||
if(CMAKE_VERSION VERSION_LESS 3.11) | ||
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) | ||
endif() | ||
|
||
include(FetchContent) | ||
|
||
set(cub_URL "https://github.com/NVlabs/cub/archive/1.9.10.tar.gz") | ||
set(cub_HASH "SHA256=2bd7077a3d9741f0689e6c1eb58c6278fc96eccc27d964168bc8be1bc3a9040f") | ||
|
||
FetchContent_Declare(cub | ||
URL ${cub_URL} | ||
URL_HASH ${cub_HASH} | ||
) | ||
|
||
FetchContent_GetProperties(cub) | ||
if(NOT cub) | ||
message(STATUS "Downloading cub") | ||
FetchContent_Populate(cub) | ||
endif() | ||
message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}") | ||
add_library(cub INTERFACE) | ||
target_include_directories(cub INTERFACE ${cub_SOURCE_DIR}) | ||
|
||
endfunction() | ||
|
||
download_cub() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,3 +61,5 @@ set(fsa_tests | |
foreach(name IN LISTS fsa_tests) | ||
k2_add_fsa_test(${name}) | ||
endforeach() | ||
|
||
add_subdirectory(cuda) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
|
||
add_library(context context.cu) | ||
target_include_directories(context PUBLIC ${CMAKE_SOURCE_DIR}) | ||
target_link_libraries(context PUBLIC cub) | ||
|
||
function(k2_add_cuda_test name) | ||
add_executable(${name} "${name}.cu") | ||
target_link_libraries(${name} | ||
PRIVATE | ||
context | ||
gtest | ||
gtest_main | ||
) | ||
add_test(NAME "Test.Cuda.${name}" | ||
COMMAND | ||
$<TARGET_FILE:${name}> | ||
) | ||
endfunction() | ||
|
||
# please sort the source files alphabetically | ||
set(cuda_tests | ||
utils_test | ||
) | ||
|
||
foreach(name IN LISTS cuda_tests) | ||
k2_add_cuda_test(${name}) | ||
endforeach() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// k2/csrc/cuda/context.cu | ||
|
||
// Copyright (c) 2020 Fangjun Kuang ([email protected]) | ||
|
||
// See ../../LICENSE for clarification regarding multiple authors | ||
|
||
// WARNING(fangjun): this is a naive implementation to test the build system | ||
#include "k2/csrc/cuda/context.h" | ||
|
||
#include <cstdlib> | ||
|
||
static constexpr size_t kAlignment = 64; | ||
|
||
namespace k2 { | ||
|
||
class CpuContext : public Context { | ||
public: | ||
ContextPtr Duplicate() override { return nullptr; } | ||
|
||
DeviceType GetDeviceType() const override { return kCpu; } | ||
|
||
void *Allocate(size_t bytes) override { | ||
void *p = nullptr; | ||
if (bytes) { | ||
int ret = posix_memalign(&p, kAlignment, bytes); | ||
// check the return code | ||
} | ||
return p; | ||
} | ||
|
||
bool IsSame(const Context & /*other*/) const override { return true; } | ||
|
||
void Deallocate(void *data) override { free(data); } | ||
}; | ||
|
||
class CudaContext : public Context { | ||
public: | ||
ContextPtr Duplicate() override { return nullptr; } | ||
|
||
DeviceType GetDeviceType() const override { return kCuda; } | ||
|
||
void *Allocate(size_t bytes) override { | ||
void *p = nullptr; | ||
if (bytes) { | ||
cudaError_t ret = cudaMalloc(&p, bytes); | ||
// check the return code | ||
} | ||
return p; | ||
} | ||
|
||
bool IsSame(const Context & /*other*/) const override { | ||
// TODO: change this | ||
return true; | ||
} | ||
|
||
void Deallocate(void *data) override { cudaFree(data); } | ||
}; | ||
|
||
ContextPtr GetCpuContext() { return std::make_shared<CpuContext>(); } | ||
|
||
ContextPtr GetCudaContext(int gpu_id /*= -1*/) { | ||
// TODO: select a gpu | ||
return std::make_shared<CudaContext>(); | ||
} | ||
|
||
} // namespace k2 |
Oops, something went wrong.