
Commit

Merge pull request #24 from haruhi55/data_transfer_s2f
feat(unittest): Implement basic unittest for transferring 2D data tiles between global and shared memory
haruhi55 authored Apr 30, 2024
2 parents 93f053c + 30c7a8c commit 8205e7c
Showing 33 changed files with 485 additions and 100 deletions.
1 change: 0 additions & 1 deletion .gitmodules
@@ -4,4 +4,3 @@
[submodule "3rd-party/googletest"]
path = 3rd-party/googletest
url = [email protected]:google/googletest.git
branch = main
2 changes: 1 addition & 1 deletion 3rd-party/googletest
Submodule googletest updated 73 files
+43 −0 .github/workflows/gtest-ci.yml
+0 −1 .gitignore
+0 −17 BUILD.bazel
+1 −10 CMakeLists.txt
+4 −4 CONTRIBUTING.md
+0 −1 CONTRIBUTORS
+0 −61 MODULE.bazel
+8 −8 README.md
+12 −14 WORKSPACE
+0 −35 WORKSPACE.bzlmod
+4 −6 ci/linux-presubmit.sh
+1 −2 ci/macos-presubmit.sh
+2 −2 ci/windows-presubmit.bat
+12 −22 docs/advanced.md
+39 −18 docs/faq.md
+0 −6 docs/gmock_cook_book.md
+2 −4 docs/gmock_for_dummies.md
+20 −19 docs/primer.md
+1 −1 docs/reference/assertions.md
+2 −1 docs/reference/mocking.md
+4 −8 docs/reference/testing.md
+0 −33 fake_fuchsia_sdk.bzl
+13 −14 googlemock/CMakeLists.txt
+3 −3 googlemock/README.md
+10 −34 googlemock/include/gmock/gmock-actions.h
+4 −5 googlemock/include/gmock/gmock-function-mocker.h
+87 −88 googlemock/include/gmock/gmock-matchers.h
+3 −4 googlemock/include/gmock/gmock-more-actions.h
+7 −8 googlemock/include/gmock/gmock.h
+6 −8 googlemock/include/gmock/internal/gmock-internal-utils.h
+4 −4 googlemock/include/gmock/internal/gmock-port.h
+2 −3 googlemock/src/gmock-internal-utils.cc
+1 −1 googlemock/src/gmock-matchers.cc
+1 −2 googlemock/src/gmock-spec-builders.cc
+0 −9 googlemock/test/gmock-matchers-comparisons_test.cc
+1 −39 googlemock/test/gmock-more-actions_test.cc
+1 −1 googlemock/test/gmock-spec-builders_test.cc
+0 −9 googlemock/test/gmock_link_test.h
+14 −14 googletest/CMakeLists.txt
+2 −2 googletest/README.md
+0 −4 googletest/cmake/Config.cmake.in
+20 −22 googletest/cmake/internal_utils.cmake
+1 −1 googletest/include/gtest/gtest-assertion-result.h
+10 −9 googletest/include/gtest/gtest-message.h
+3 −3 googletest/include/gtest/gtest-param-test.h
+27 −63 googletest/include/gtest/gtest-printers.h
+61 −65 googletest/include/gtest/gtest-typed-test.h
+19 −36 googletest/include/gtest/gtest.h
+4 −2 googletest/include/gtest/internal/gtest-death-test-internal.h
+1 −7 googletest/include/gtest/internal/gtest-filepath.h
+69 −30 googletest/include/gtest/internal/gtest-internal.h
+75 −79 googletest/include/gtest/internal/gtest-param-util.h
+0 −2 googletest/include/gtest/internal/gtest-port-arch.h
+41 −95 googletest/include/gtest/internal/gtest-port.h
+17 −19 googletest/src/gtest-death-test.cc
+1 −1 googletest/src/gtest-filepath.cc
+17 −29 googletest/src/gtest-internal-inl.h
+28 −70 googletest/src/gtest-port.cc
+112 −163 googletest/src/gtest.cc
+0 −1 googletest/test/googletest-color-test.py
+37 −39 googletest/test/googletest-death-test-test.cc
+0 −15 googletest/test/googletest-json-output-unittest.py
+1 −4 googletest/test/googletest-options-test.cc
+5 −0 googletest/test/googletest-output-test-golden-lin.txt
+3 −3 googletest/test/googletest-port-test.cc
+0 −16 googletest/test/googletest-printers-test.cc
+44 −52 googletest/test/gtest_environment_test.cc
+41 −23 googletest/test/gtest_help_test.py
+0 −3 googletest/test/gtest_json_test_utils.py
+3 −1 googletest/test/gtest_repeat_test.cc
+27 −19 googletest/test/gtest_unittest.cc
+6 −9 googletest/test/gtest_xml_output_unittest.py
+8 −14 googletest_deps.bzl
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -7,6 +7,11 @@ if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR})
message(FATAL_ERROR "In-source builds are not supported")
endif()

option(WITH_TESTING "Build with CTests" ON)
if(WITH_TESTING)
enable_testing()
endif()

find_package(CUDA REQUIRED)
find_package(Torch REQUIRED)

12 changes: 11 additions & 1 deletion Makefile
@@ -1,18 +1,22 @@
EXAMPLE_DIR := examples
TEST_DIR := tests/python
UNIT_TEST ?= test_lstm_cell
CPP_UT ?= test_copy
CPP_UTS := scripts/unittests/run_all_cpp_tests.sh

EXAMPLE ?= $(EXAMPLE_DIR)/scatter_nd.py
UNIT ?= $(TEST_DIR)/$(UNIT_TEST).py

WITH_TEST ?= ON

BUILD_DIR := build
DYNAMIC_LIB := $(BUILD_DIR)/libtiledcuda.so

.PHONY: build example unit_test clean

build:
@mkdir -p build
@cd build && cmake .. && make -j$(proc)
@cd build && cmake -DWITH_TESTING=$(WITH_TEST) .. && make -j$(proc)

$(DYNAMIC_LIB): build

@@ -22,5 +26,11 @@ example: $(DYNAMIC_LIB)
unit_test: $(DYNAMIC_LIB)
@python3 $(UNIT)

unit_test_cpp: $(DYNAMIC_LIB)
@cd $(BUILD_DIR) && ctest -R $(CPP_UT) -V

unit_test_cpps: $(DYNAMIC_LIB)
@sh $(CPP_UTS)

clean:
@rm -rf build
10 changes: 6 additions & 4 deletions README.md
@@ -2,9 +2,9 @@

## Introduction

**TiledCUDA** is an efficient kernel template library written in **CuTe**, which provides a wrapper for cutlass CuTe and enables more efficient fusion.
**TiledCUDA** is a kernel template library designed to be highly efficient. It provides a wrapper for cutlass **CuTe** to simplify the process of implementing complex fused kernels that utilize tensor core GEMM.

TiledCUDA uses **PyTorch** as the runtime and leverages the **Tensor** class of PyTorch for convenient testing.
TiledCUDA utilizes **PyTorch** as its runtime environment and leverages the **Tensor** class of PyTorch for convenient testing.

## Quick Start

@@ -17,12 +17,14 @@ cd TiledCUDA && git submodule update --init --recursive

### Unit Test

- **Run single unit test**: `make unit_test UNIT_TEST=test_scatter_nd.py`
- **Run a single unit test**: `make unit_test UNIT_TEST=test_scatter_nd.py`
- **Run all unit tests**: `./scripts/unittests/python.sh`
- **Run a single cpp unit test**: `make unit_test_cpp CPP_UT=test_copy`
- **Run all cpp unit tests**: `make unit_test_cpps`

## Features

- Implemented `__device__` function wrapper that enables **static/dynamic** copying between different memory hierarchy.
- Implemented `__device__` function wrapper for CUDA **micro kernels**, such as `copy_async` and tensor core operations.
- Implemented template wrapper for **CuTe** to make it easier to use.
- Implemented template wrapper for **CuTe** to simplify its usage.
- Implemented fused kernels such as **GEMM**, **Back2Back GEMM**, **Batched GEMM**, **Lstm Cell**, etc.
1 change: 1 addition & 0 deletions docs/README.md
@@ -0,0 +1 @@
[TBD]
2 changes: 1 addition & 1 deletion include/cell/copy/copy.hpp
@@ -48,7 +48,7 @@ __forceinline__ __device__ auto make_s2rA(const Element* data, int tid,
// partition register
auto thr_mma = tiled_mma.get_thread_slice(tid);
auto dst = thr_mma.partition_fragment_A(tensor);
auto dst_view = thrd_copy.retile_S(dst);
auto dst_view = thrd_copy.retile_D(dst);

Shm2RegLoad loader(tiled_copy, src, dst, dst_view);
return loader;
28 changes: 14 additions & 14 deletions include/cell/copy/dyn_copy.hpp
@@ -2,16 +2,18 @@

#include "cuda_utils.hpp"

#include <cute/tensor.hpp>

namespace tiledcuda::cell::copy {

// Copy a tensor from global memory to shared memory
using namespace cute;

// Copy a 2d data tile from global memory to shared memory
template <typename Element, typename SrcLayout, typename DstLayout,
typename TiledCopy>
__forceinline__ __device__ void copy_tensor_g2s(const Element* src_data,
Element* dst_data,
SrcLayout src_layout,
DstLayout dst_layout,
TiledCopy tiled_copy, int tid) {
DEVICE void copy_2d_tile_g2s(const Element* src_data, Element* dst_data,
SrcLayout src_layout, DstLayout dst_layout,
TiledCopy tiled_copy, int tid) {
auto gtile = make_tensor(make_gmem_ptr(src_data), src_layout);
auto stile = make_tensor(make_smem_ptr(dst_data), dst_layout);

@@ -30,11 +32,9 @@ __forceinline__ __device__ void copy_tensor_g2s(const Element* src_data,
// Copy a tensor from shared memory to global memory
template <typename Element, typename SrcLayout, typename DstLayout,
typename TiledCopy>
__forceinline__ __device__ void copy_tensor_s2g(const Element* src_data,
Element* dst_data,
SrcLayout src_layout,
DstLayout dst_layout,
TiledCopy tiled_copy, int tid) {
DEVICE void copy_2d_tile_s2g(const Element* src_data, Element* dst_data,
SrcLayout src_layout, DstLayout dst_layout,
TiledCopy tiled_copy, int tid) {
auto stile = make_tensor(make_smem_ptr(src_data), src_layout);
auto gtile = make_tensor(make_gmem_ptr(dst_data), dst_layout);

cute::copy(tiled_copy, src(_, i, j), dst(_, i, j));
}

__forceinline__ __device__ void copy_tensor_s2r() {}
DEVICE void copy_2d_tile_s2r() {}

__forceinline__ __device__ void copy_tensor_r2s() {}
} // namespace tiledcuda::cell::copy
DEVICE void copy_2d_tile_r2s() {}
} // namespace tiledcuda::cell::copy
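A minimal usage sketch (not part of this diff) of the renamed helpers: a hypothetical kernel that stages a single 2D tile through shared memory. The traits types G2S and S2G, the kernel name, the include path, and the launch configuration are assumptions for illustration; their SrcLayout/DstLayout/TiledCopy members mirror the copy traits added in include/cell/traits/copy.hpp below.

#include "cell/copy/dyn_copy.hpp"  // assumed include path for the helpers above
#include <cute/tensor.hpp>

// Hypothetical kernel: stage one 2D tile global -> shared -> global.
template <typename Element, typename G2S, typename S2G>
__global__ void stage_2d_tile(const Element* src, Element* dst) {
    extern __shared__ unsigned char buf_[];
    auto* buf = reinterpret_cast<Element*>(buf_);
    int tid = threadIdx.x;

    // global -> shared; may lower to cp.async on SM80+, depending on the traits
    tiledcuda::cell::copy::copy_2d_tile_g2s(
        src, buf, typename G2S::SrcLayout{}, typename G2S::DstLayout{},
        typename G2S::TiledCopy{}, tid);
    cute::cp_async_fence();
    cute::cp_async_wait<0>();
    __syncthreads();

    // shared -> global; always uses the default copy instruction
    tiledcuda::cell::copy::copy_2d_tile_s2g(
        buf, dst, typename S2G::SrcLayout{}, typename S2G::DstLayout{},
        typename S2G::TiledCopy{}, tid);
}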
6 changes: 3 additions & 3 deletions include/cell/copy/static_copy.hpp
@@ -19,8 +19,8 @@ struct R2SCopy2D {
template <typename Engine, typename Layout>
__forceinline__ __device__ void copy(
cute::Tensor<Engine, Layout> const& acc, Element* dst_data, int tid) {
// FIXME(ying): This implementation is specifically designed
// for TCU WMMA and assumes that the ACC value has a
// FIXME(haruhi): This implementation is specifically designed
// for tcu WMMA and assumes that the ACC value has a
// floating-point precision. The code converts the ACC value
// to half-precision.
auto src_tensor = convert_type<Element>(acc);
@@ -49,4 +49,4 @@ struct R2SCopy2D {
}
};

} // namespace tiledcuda::cell::copy
} // namespace tiledcuda::cell::copy
2 changes: 1 addition & 1 deletion include/cell/traits/b2b_gemm.hpp
@@ -72,7 +72,7 @@ struct DynBack2BackGemmTraits : public Base {
CopyInst{}, ThreadLayout{},
Layout<Shape<_1, Int<Base::kNumPerAccess>>>{}));

// TODO(ying): The current implementation uses ldmatrix.x4
// TODO(haruhi): The current implementation uses ldmatrix.x4
// instruction which requires the TileMMA configuration to be
// fixed as follows. Make it able to be tuned by policy in
// future implementation.
8 changes: 1 addition & 7 deletions include/cell/traits/base.hpp
@@ -1,15 +1,9 @@
#pragma once

#include <cutlass/numeric_size.h>

namespace tiledcuda::cell::traits {

// FIXME(haruhi). The swizzle function requires a data tile with a minimal
// shape of <8, 32> for the <2, 3, 3> case, and a minimal shape of <8, 64> for
// the <3, 3, 3> case. Here requires some check to ensure that the data tile
// meets these requirements before using this function.
template <const int N>
static constexpr int kSwizzle = (N == 32 ? 2 : 3);

template <typename Element>
struct TraitsBase {
static constexpr int kAccessInBits = 128; // 128 bits
92 changes: 92 additions & 0 deletions include/cell/traits/copy.hpp
@@ -0,0 +1,92 @@
#pragma once

#include "cell/traits/base.hpp"
#include "layout.hpp"

namespace tiledcuda::cell::traits {

/// @brief Configurations for transferring a single 2D data tile from global
/// memory to shared memory, including the layouts of the data tile and the
/// thread tile.
/// @tparam Element_: the element type
/// @tparam kThreads: number of threads in a thread block
template <typename Element_, const int kRows_, const int kCols_,
const int kShmRows_, const int kShmCols_, const int kThreads,
typename Base = TraitsBase<Element_>>
struct G2S2DCopyTraits : public Base {
using Element = Element_;

static constexpr int kRows = kRows_;
static constexpr int kCols = kCols_;

static constexpr int kShmRows = kShmRows_;
static constexpr int kShmCols = kShmCols_;

using SrcLayout = RowMajor<kRows, kCols, kCols>;

// To avoid bank conflicts, the shared memory requires a swizzled layout
static constexpr int kSwizzleMode = kShmCols % 32 ? 1 : 0;
using Swizzled =
SwizzledRowMajor<Element, kShmRows, kShmCols, kSwizzleMode>;
using DstLayout = typename Swizzled::SmemLayout;

// threads in a thread block are laid out as a 2D tile
// that has a shape of kThreadsRows x kThreadsCols.
static constexpr int kThreadsCols = kShmCols / Base::kNumPerAccess;
static constexpr int kThreadsRows = kThreads / kThreadsCols;
using ThreadLayout = RowMajor<kThreadsRows, kThreadsCols, kThreadsCols>;

using ValueLayout = Layout<Shape<_1, Int<Base::kNumPerAccess>>>;

#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
using CopyInst =
Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, Element>;
#else
using CopyInst = Copy_Atom<DefaultCopy, Element>;
#endif

using TiledCopy =
decltype(make_tiled_copy(CopyInst{}, ThreadLayout{}, ValueLayout{}));
};

/// @brief Configurations for transferring a single 2D data tile from shared
/// memory to global memory, including the layouts of the data tile and the
/// thread tile.
/// @tparam Element_: the element type
/// @tparam kThreads: number of threads in a thread block
template <typename Element_, const int kRows_, const int kCols_,
const int kShmRows_, const int kShmCols_, const int kThreads,
typename Base = TraitsBase<Element_>>
struct S2G2DCopyTraits : public Base {
using Element = Element_;

static constexpr int kRows = kRows_;
static constexpr int kCols = kCols_;

static constexpr int kShmRows = kShmRows_;
static constexpr int kShmCols = kShmCols_;

// To avoid bank conflicts, the shared memory requires a swizzled layout
static constexpr int kSwizzleMode = kShmCols % 32 ? 1 : 0;
using Swizzled =
SwizzledRowMajor<Element, kShmRows, kShmCols, kSwizzleMode>;
using SrcLayout = typename Swizzled::SmemLayout;

// the destination is a row-major 2D data tile in global memory
using DstLayout = RowMajor<kRows, kCols, kCols>;

// threads in a thread block are laid out as a 2D tile
// that has a shape of kThreadsRows x kThreadsCols.
static constexpr int kThreadsCols = kShmCols / Base::kNumPerAccess;
static constexpr int kThreadsRows = kThreads / kThreadsCols;
using ThreadLayout = RowMajor<kThreadsRows, kThreadsCols, kThreadsCols>;

using ValueLayout = Layout<Shape<_1, Int<Base::kNumPerAccess>>>;

// Transferring data from global memory to shared memory can use cp.async,
// while transferring data from shared memory to global memory cannot;
// for the latter case, the default copy instruction is used.
using TiledCopy = decltype(make_tiled_copy(
Copy_Atom<DefaultCopy, Element>{}, ThreadLayout{}, ValueLayout{}));
};

} // namespace tiledcuda::cell::traits
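A compile-time sketch of how these traits resolve for a concrete tile. The element type, tile sizes, thread count, and include path below are illustrative, and the arithmetic assumes Base::kNumPerAccess == 8 (128-bit accesses of half-precision elements, per TraitsBase).

#include "cell/traits/copy.hpp"  // assumed include path

using Element = cutlass::half_t;
constexpr int kThreads = 128;

// A 64 x 64 global tile staged through a 64 x 64 shared-memory tile.
using G2S = tiledcuda::cell::traits::G2S2DCopyTraits<Element, 64, 64, 64, 64,
                                                     kThreads>;

// Each thread moves one 128-bit vector (8 halves) per access, so the
// 128 threads form a 16 x 8 row-major thread tile over the data tile:
//   kThreadsCols = kShmCols / kNumPerAccess = 64 / 8 = 8
//   kThreadsRows = kThreads / kThreadsCols  = 128 / 8 = 16
static_assert(G2S::kThreadsCols == 8, "8 vectorized accesses per row");
static_assert(G2S::kThreadsRows == 16, "16 rows of threads");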
1 change: 1 addition & 0 deletions include/cell/traits/mod.hpp
@@ -1,4 +1,5 @@
#pragma once

#include "cell/traits/b2b_gemm.hpp"
#include "cell/traits/base.hpp"
#include "cell/traits/batched_gemm.hpp"
15 changes: 15 additions & 0 deletions include/config.hpp
@@ -0,0 +1,15 @@
#pragma once

#if defined(__CUDA_ARCH__)
#define HOST_DEVICE __forceinline__ __host__ __device__
#define DEVICE __forceinline__ __device__
#define HOST __forceinline__ __host__
#else
#define HOST_DEVICE inline
#define DEVICE inline
#define HOST inline
#endif

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
#define CUTE_ARCH_CP_ASYNC_SM80_ENABLED
#endif
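For illustration (a hypothetical helper, not part of this commit): the macros let a header-defined function carry device qualifiers during device compilation while degrading to a plain inline function elsewhere.

#include "config.hpp"

// Expands to __forceinline__ __host__ __device__ when __CUDA_ARCH__ is
// defined, and to plain inline otherwise.
HOST_DEVICE int ceil_div(int a, int b) { return (a + b - 1) / b; }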
2 changes: 2 additions & 0 deletions include/cuda_utils.hpp
@@ -1,5 +1,7 @@
#pragma once

#include "config.hpp"

#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
57 changes: 51 additions & 6 deletions include/layout.hpp
@@ -1,5 +1,7 @@
#pragma once

#include "config.hpp"

#include <cute/layout.hpp>

using namespace cute;
@@ -17,15 +19,58 @@ template <const int row, const int col, const int stride = row>
using ColMajor =
cute::Layout<Shape<Int<row>, Int<col>>, Stride<_1, Int<stride>>>;

__forceinline__ __device__ auto make_row_major_layout(const int row,
const int col,
const int stride) {
HOST_DEVICE auto make_row_major_layout(const int row, const int col,
const int stride) {
return cute::make_layout(make_shape(row, col), make_stride(stride, 1));
}

__forceinline__ __device__ auto make_col_major_layout(const int row,
const int col,
const int stride) {
HOST_DEVICE auto make_col_major_layout(const int row, const int col,
const int stride) {
return cute::make_layout(make_shape(row, col), make_stride(1, stride));
}

// CuTe's swizzle function, Swizzle<B, M, S>, permutes elements in a
// 2D coordinate space. This 2D coordinate space has 2^B rows and 2^S columns,
// and each coordinate position holds 2^M elements. Therefore, to apply a
// swizzle function to a 2D data tile, the data tile should have a shape that
// is a multiple of 2^B x 2^S x 2^M.
/// @tparam Element: element type
/// @tparam kRows: number of rows
/// @tparam kCols: number of columns
/// @tparam kSwizzleMode: The value should be either 0 or 1, indicating whether
/// the size of the contiguous dimension is divisible by 32 or not.
template <typename Element, const int kRows, const int kCols,
const int kSwizzleMode>
struct SwizzledRowMajor {};

// FIXME(haruhi): This implementation is very inflexible and is almost
// equivalent to a hard-coded swizzle function for 2D data tiles that have a
// shape that is a multiple of 8x32.
// Improve the implementation to make it more general.
template <const int kRows, const int kCols>
struct SwizzledRowMajor<cutlass::half_t, kRows, kCols, 0> {
static_assert(kRows % 8 == 0,
"The number of rows must be a multiple of 8.");

using SmemLayoutAtom = decltype(composition(
Swizzle<2, 3, 3>{}, cute::Layout<Shape<_8, _32>, Stride<_32, _1>>{}));
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{},
Shape<Int<kRows>, Int<kCols>>{}));
};

// FIXME(haruhi): This implementation is very inflexible and is almost
// equivalent to a hard-coded swizzle function for 2D data tiles that have a
// shape that is a multiple of 8x64.
// Improve the implementation to make it more general.
template <const int kRows, const int kCols>
struct SwizzledRowMajor<cutlass::half_t, kRows, kCols, 1> {
static_assert(kRows % 8 == 0,
"The number of rows must be a multiple of 8.");

using SmemLayoutAtom = decltype(composition(
Swizzle<3, 3, 3>{}, cute::Layout<Shape<_8, _64>, Stride<_64, _1>>{}));
using SmemLayout = decltype(tile_to_shape(SmemLayoutAtom{},
Shape<Int<kRows>, Int<kCols>>{}));
};

} // namespace tiledcuda
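A short instantiation sketch of the two specializations above (half-precision only, as defined; the 16 x 32 and 16 x 64 shapes are illustrative, and the tiledcuda:: qualification follows the closing namespace comment).

// Tile shapes must be multiples of the swizzle atoms: 8 x 32 for mode 0
// (Swizzle<2, 3, 3>) and 8 x 64 for mode 1 (Swizzle<3, 3, 3>).
using Smem16x32 = typename tiledcuda::SwizzledRowMajor<cutlass::half_t, 16, 32,
                                                       0>::SmemLayout;
using Smem16x64 = typename tiledcuda::SwizzledRowMajor<cutlass::half_t, 16, 64,
                                                       1>::SmemLayout;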