Update the build system to support CUDA. (#75)
* Update the build system to support CUDA.

* Add a Google Colab example for the CUDA test.

* Enable CI for the cuda_draft branch.

* Resolve some review comments.
csukuangfj authored Jul 29, 2020
1 parent ae3da3e commit fda3b83
Showing 20 changed files with 492 additions and 98 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/build.yml
@@ -10,15 +10,20 @@ on:
push:
branches:
- master
- cuda
pull_request:
branches:
- master
- cuda

env:
BUILD_TYPE: Debug

jobs:
build:
# disable CI now since GitHub action does not support CUDA
# and it always fails
if: false
runs-on: ${{ matrix.os }}
strategy:
matrix:
@@ -55,4 +60,4 @@ jobs:
- name: Test
shell: bash
working-directory: ${{runner.workspace}}/build
run: ctest --verbose --build-config $BUILD_TYPE
run: ctest --verbose --exclude-regex Cuda --build-config $BUILD_TYPE
36 changes: 34 additions & 2 deletions CMakeLists.txt
@@ -9,9 +9,9 @@ to build this project"
)
endif()

cmake_minimum_required(VERSION 3.5 FATAL_ERROR)
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)

project(k2)
project(k2 CUDA CXX)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

@@ -26,19 +26,51 @@ if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING
"Set the build type. Available values are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS
Debug Release RelWithDebInfo MinSizeRel
)
endif()

if(WIN32 AND BUILD_SHARED_LIBS)
message(STATUS "Set BUILD_SHARED_LIBS to OFF for Windows")
set(BUILD_SHARED_LIBS OFF CACHE BOOL "" FORCE)
endif()

# the following settings are modified from cub/CMakeLists.txt
#[[ start settings for CUB ]]

set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
set(CMAKE_CXX_EXTENSIONS OFF)

message(STATUS "C++ Standard version: ${CMAKE_CXX_STANDARD}")

# Force CUDA C++ standard to be the same as the C++ standard used.
#
# Now, CMake is unaligned with reality on standard versions: https://gitlab.kitware.com/cmake/cmake/issues/18597
# which means that using standard CMake methods, it's impossible to actually sync the CXX and CUDA versions for pre-11
# versions of C++; CUDA accepts 98 but translates that to 03, while CXX doesn't accept 03 (and doesn't translate that to 03).
# In case this gives You, dear user, any trouble, please escalate the above CMake bug, so we can support reality properly.
if(DEFINED CMAKE_CUDA_STANDARD)
message(WARNING "You've set CMAKE_CUDA_STANDARD; please note that this variable is ignored, and CMAKE_CXX_STANDARD"
" is used as the C++ standard version for both C++ and CUDA.")
endif()
unset(CMAKE_CUDA_STANDARD CACHE)
set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})

set(K2_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72 75)
foreach(COMPUTE_ARCH IN LISTS K2_COMPUTE_ARCHS)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode arch=compute_${COMPUTE_ARCH},code=sm_${COMPUTE_ARCH}")
endforeach()

#[[ end settings for CUB ]]

enable_testing()

list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(cpplint)
include(glog)
include(googletest)
include(pybind11)
include(cub)

add_subdirectory(k2)
6 changes: 6 additions & 0 deletions README.md
@@ -4,3 +4,9 @@

# k2
FSA/FST algorithms, intended to (eventually) be interoperable with PyTorch and similar.

## Quick start

Want to try it out without installing anything? We have setup a [Google Colab][1].

[1]: https://colab.research.google.com/drive/1qbHUhNZUX7AYEpqnZyf29Lrz2IPHBGlX?usp=sharing
30 changes: 30 additions & 0 deletions cmake/cub.cmake
@@ -0,0 +1,30 @@
# Copyright (c) 2020 Fangjun Kuang ([email protected])
# See ../LICENSE for clarification regarding multiple authors

function(download_cub)
if(CMAKE_VERSION VERSION_LESS 3.11)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()

include(FetchContent)

set(cub_URL "https://github.com/NVlabs/cub/archive/1.9.10.tar.gz")
set(cub_HASH "SHA256=2bd7077a3d9741f0689e6c1eb58c6278fc96eccc27d964168bc8be1bc3a9040f")

FetchContent_Declare(cub
URL ${cub_URL}
URL_HASH ${cub_HASH}
)

FetchContent_GetProperties(cub)
if(NOT cub)
message(STATUS "Downloading cub")
FetchContent_Populate(cub)
endif()
message(STATUS "cub is downloaded to ${cub_SOURCE_DIR}")
add_library(cub INTERFACE)
target_include_directories(cub INTERFACE ${cub_SOURCE_DIR})

endfunction()

download_cub()
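
As an aside (not part of this commit), a tiny standalone CUDA program along the following lines could be used to check that the downloaded CUB headers are usable from a target that links against the cub interface library above; the file name and the choice of DeviceReduce::Sum are purely illustrative:

// cub_check.cu -- illustrative only; not part of this commit.
#include <cassert>
#include <cstdio>

#include <cub/cub.cuh>

int main() {
  const int n = 4;
  int h_in[n] = {1, 2, 3, 4};
  int *d_in = nullptr;
  int *d_out = nullptr;
  cudaMalloc(&d_in, n * sizeof(int));
  cudaMalloc(&d_out, sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);

  void *d_temp = nullptr;
  size_t temp_bytes = 0;
  // The first call only computes the required temporary storage size.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);
  cudaMalloc(&d_temp, temp_bytes);
  // The second call performs the actual reduction.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n);

  int h_out = 0;
  cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  assert(h_out == 10);  // 1 + 2 + 3 + 4
  std::printf("sum = %d\n", h_out);

  cudaFree(d_in);
  cudaFree(d_out);
  cudaFree(d_temp);
  return 0;
}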
2 changes: 2 additions & 0 deletions k2/csrc/CMakeLists.txt
@@ -61,3 +61,5 @@ set(fsa_tests
foreach(name IN LISTS fsa_tests)
k2_add_fsa_test(${name})
endforeach()

add_subdirectory(cuda)
27 changes: 27 additions & 0 deletions k2/csrc/cuda/CMakeLists.txt
@@ -0,0 +1,27 @@

add_library(context context.cu)
target_include_directories(context PUBLIC ${CMAKE_SOURCE_DIR})
target_link_libraries(context PUBLIC cub)

function(k2_add_cuda_test name)
add_executable(${name} "${name}.cu")
target_link_libraries(${name}
PRIVATE
context
gtest
gtest_main
)
add_test(NAME "Test.Cuda.${name}"
COMMAND
$<TARGET_FILE:${name}>
)
endfunction()

# please sort the source files alphabetically
set(cuda_tests
utils_test
)

foreach(name IN LISTS cuda_tests)
k2_add_cuda_test(${name})
endforeach()
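
The utils_test.cu source registered above is not included on this page; a minimal CUDA gtest that k2_add_cuda_test could build might look roughly like the following hypothetical sketch (the kernel and test body are invented for illustration, and main() comes from the linked gtest_main):

// Hypothetical sketch only; the real k2/csrc/cuda/utils_test.cu is not shown here.
#include <cstdint>

#include <cuda_runtime.h>

#include "gtest/gtest.h"

namespace k2 {

__global__ void AddOneKernel(int32_t *data, int32_t n) {
  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += 1;
}

TEST(Cuda, AddOne) {
  constexpr int32_t kN = 8;
  int32_t h[kN] = {0};
  int32_t *d = nullptr;
  ASSERT_EQ(cudaMalloc(&d, kN * sizeof(int32_t)), cudaSuccess);
  ASSERT_EQ(cudaMemcpy(d, h, kN * sizeof(int32_t), cudaMemcpyHostToDevice), cudaSuccess);
  AddOneKernel<<<1, kN>>>(d, kN);
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
  ASSERT_EQ(cudaMemcpy(h, d, kN * sizeof(int32_t), cudaMemcpyDeviceToHost), cudaSuccess);
  for (int32_t i = 0; i != kN; ++i) EXPECT_EQ(h[i], 1);
  cudaFree(d);
}

}  // namespace k2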
2 changes: 1 addition & 1 deletion k2/csrc/cuda/README.md
@@ -2,7 +2,7 @@

So far this directory just contains some notes on implementation; all the code
is just a VERY EARLY DRAFT. The goal here is to show *in principle* how we parallelize
things, building up from low-level primitives, but without acutally creating any
things, building up from low-level primitives, but without actually creating any
CUDA code.

Actually we probably shouldn't separate this into a separate directory from the CPU code,
25 changes: 19 additions & 6 deletions k2/csrc/cuda/array.h
@@ -1,5 +1,15 @@
// k2/csrc/cuda/array.h

// Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey)

// See ../../LICENSE for clarification regarding multiple authors

#ifndef K2_CSRC_CUDA_ARRAY_H_
#define K2_CSRC_CUDA_ARRAY_H_

#include "k2/csrc/cuda/context.h"

namespace k2 {

/*
Array1* is a 1-dimensional contiguous array (that doesn't support a stride).
@@ -12,16 +12,22 @@ template <typename T> class Array1 {



// generally L will be some kind of lambda or function object; it should be
// generally Callable will be some kind of lambda or function object; it should be
// possible to evaluate it on the CUDA device (if we're compiling with CUDA)
// and also on the CPU. We'll do src(i) to evaluate element i.
// NOTE: we assume this thread is already set to use the device associated with the
// context in 'ctx', if it's a CUDA context.
template <typename L>
Array1(ContextPtr ctx, int size, L lambda) {
template <typename Callable>
Array1(ContextPtr ctx, int size, Callable &&callable) {
Init(ctx, size);

Eval(ctx->DeviceType(), data, size, lambda);
Eval(ctx->DeviceType(), data(), size, std::forward<Callable>(callable));
}

/* Return sub-part of this array
@@ -30,7 +30,7 @@ template <typename T> class Array1 {
*/
Array1 Range(int start, int size);

DeviceType Device() { return region->device; }
DeviceType Device() const { return region->device; }

// Resizes, copying old contents if we could not re-use the same memory location.
// It will always at least double the allocated size if it has to reallocate.
@@ -53,5 +63,8 @@ template <typename T> class Array1 {
void Init(DeviceType d, int size) {
// .. takes care of allocation etc.
}

};

} // namespace k2

#endif // K2_CSRC_CUDA_ARRAY_H_
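
For reference, the Callable-based constructor in the hunk above could be exercised roughly as follows; this is a hypothetical snippet against the draft API, not code from the commit, and the element type and evaluation details may well change:

// Hypothetical usage of the draft Array1 API sketched above.
ContextPtr ctx = GetCpuContext();  // a CUDA context would evaluate the callable on the device
// The callable is invoked once per index, so a[i] == 2 * i afterwards.
Array1<int32_t> a(ctx, /*size=*/10, [](int i) { return 2 * i; });
Array1<int32_t> head = a.Range(/*start=*/0, /*size=*/5);  // sub-array of the first 5 elements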
18 changes: 11 additions & 7 deletions k2/csrc/cuda/compose.cc
@@ -1,4 +1,12 @@
#include <compose.h>
// k2/csrc/cuda/compose.cc

// Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey)

// See ../../LICENSE for clarification regarding multiple authors

#include "k2/csrc/cuda/compose.h"

namespace k2 {

// Caution: this is really a .cu file. It contains mixed host and device code.

@@ -500,10 +508,6 @@ void IntersectDensePruned(FsaVec &a_fsas,
Array1<HashKeyType> state_repr_hash; // hash-value of corresponding elements of a_fsas and b_fsas

Hash<HashKeyType, int, Hasher> repr_hash_to_id; // Maps from (fsa_index, hash of state_repr) to






}

} // namespace k2
16 changes: 16 additions & 0 deletions k2/csrc/cuda/compose.h
@@ -1,3 +1,15 @@
// k2/csrc/cuda/compose.h

// Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey)

// See ../../LICENSE for clarification regarding multiple authors

#ifndef K2_CSRC_CUDA_COMPOSE_H_
#define K2_CSRC_CUDA_COMPOSE_H_

#include "k2/csrc/cuda/array.h"

namespace k2 {


// Note: b is FsaVec<Arc>.
@@ -19,3 +31,7 @@ void IntersectDensePruned(Array3<Arc> &a_fsas,
FsaVec *ofsa,
Array<int> *arc_map_a,
Array<int> *arc_map_b);

} // namespace k2

#endif // K2_CSRC_CUDA_COMPOSE_H_
66 changes: 66 additions & 0 deletions k2/csrc/cuda/context.cu
@@ -0,0 +1,66 @@
// k2/csrc/cuda/context.cu

// Copyright (c) 2020 Fangjun Kuang ([email protected])

// See ../../LICENSE for clarification regarding multiple authors

// WARNING(fangjun): this is a naive implementation to test the build system
#include "k2/csrc/cuda/context.h"

#include <cstdlib>

static constexpr size_t kAlignment = 64;

namespace k2 {

class CpuContext : public Context {
public:
ContextPtr Duplicate() override { return nullptr; }

DeviceType GetDeviceType() const override { return kCpu; }

void *Allocate(size_t bytes) override {
void *p = nullptr;
if (bytes) {
int ret = posix_memalign(&p, kAlignment, bytes);
// check the return code
}
return p;
}

bool IsSame(const Context & /*other*/) const override { return true; }

void Deallocate(void *data) override { free(data); }
};

class CudaContext : public Context {
public:
ContextPtr Duplicate() override { return nullptr; }

DeviceType GetDeviceType() const override { return kCuda; }

void *Allocate(size_t bytes) override {
void *p = nullptr;
if (bytes) {
cudaError_t ret = cudaMalloc(&p, bytes);
// check the return code
}
return p;
}

bool IsSame(const Context & /*other*/) const override {
// TODO: change this
return true;
}

void Deallocate(void *data) override { cudaFree(data); }
};

ContextPtr GetCpuContext() { return std::make_shared<CpuContext>(); }

ContextPtr GetCudaContext(int gpu_id /*= -1*/) {
// TODO: select a gpu
return std::make_shared<CudaContext>();
}

} // namespace k2
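
A quick, hypothetical illustration of how these naive contexts are meant to be used (again, not part of the commit itself):

// Hypothetical usage of the naive contexts defined above.
ContextPtr cpu = GetCpuContext();
void *hp = cpu->Allocate(1024);     // 64-byte-aligned host memory via posix_memalign
cpu->Deallocate(hp);                // released with free()

ContextPtr gpu = GetCudaContext();  // the gpu_id argument is ignored for now
void *dp = gpu->Allocate(1024);     // device memory via cudaMalloc
gpu->Deallocate(dp);                // released with cudaFree()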
