Skip to content

Commit

Permalink
remove cuda sync from cinn_call_cuda_kernel (PaddlePaddle#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Nov 24, 2020
1 parent 828d4e2 commit 98d2ff3
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 5 deletions.
3 changes: 3 additions & 0 deletions cinn/backends/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ endif()
cc_library(__x86_source_fake_lib SRCS _x86_builtin_source.cc)
add_subdirectory(llvm)

if (WITH_CUDA)
nv_test(test_raw_cuda_code SRCS raw_cuda_code_test.cu DEPS core)
endif()

cc_test(test_codegen_c SRCS codegen_c_test.cc DEPS core ARGS ${global_test_args})
cc_test(test_codegen_c_x86 SRCS codegen_c_x86_test.cc DEPS core ARGS ${global_test_args})
Expand Down
29 changes: 26 additions & 3 deletions cinn/backends/compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
#include "cinn/hlir/pe/elementwise.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/runtime/use_extern_funcs.h"
#include "cinn/utils/timer.h"

namespace cinn {
namespace backends {

TEST(Compiler, x86) {
Expr M(10), N(20);
Expr M(1024), N(1024);

auto create_module = [&]() {
Placeholder<float> A("A", {M, N});
Expand Down Expand Up @@ -56,8 +57,21 @@ TEST(Compiler, x86) {
ASSERT_NEAR(Ad[i] + Bd[i], Cd[i], 1e-5);
}
}
}

#ifdef CINN_WITH_CUDA
TEST(Compiler, cuda) {
Expr M(1024), N(1024);

auto create_module = [&]() {
Placeholder<float> A("A", {M, N});
Placeholder<float> B("B", {M, N});

auto C = Compute(
{M, N}, [=](Expr i, Expr j) { return A(i, j) + B(i, j); }, "C");
return std::make_tuple(A, B, C);
};

{ // cuda
auto [A, B, C] = create_module(); // NOLINT
auto stages = CreateStages({C});
Expand Down Expand Up @@ -99,7 +113,16 @@ TEST(Compiler, x86) {
Cbb.memory = reinterpret_cast<uint8_t*>(Cg);

auto args = common::ArgsBuilder().Add(&Abb).Add(&Bbb).Add(&Cbb).Build();
fnp(args.data(), args.size());

utils::Timer timer;
timer.Start();
for (int i = 0; i < 1000; i++) {
fnp(args.data(), args.size());
}

CUDA_CALL(cudaDeviceSynchronize());
float latency = timer.Stop();
LOG(INFO) << "latency: " << latency / 1000;

std::vector<float> ch(M.as_int32() * N.as_int32(), 0.f);
CUDA_CALL(cudaMemcpy(ch.data(), Cg, ch.size() * sizeof(float), cudaMemcpyDeviceToHost));
Expand All @@ -110,8 +133,8 @@ TEST(Compiler, x86) {
ASSERT_NEAR(Ad[i] + Bd[i], ch[i], 1e-5);
}
}
#endif
}
#endif

TEST(Compiler, sqrt) {
Expr N(100);
Expand Down
39 changes: 39 additions & 0 deletions cinn/backends/raw_cuda_code_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "cinn/backends/cuda_util.h"
#include "cinn/utils/timer.h"

// Hand-written CUDA kernel mirroring CINN-generated elementwise-add code:
// computes C = A + B over a 1024x1024 buffer, one element per thread.
// Expected launch configuration: <<<1024, 1024>>> (block = row, thread = column);
// out-of-range blocks/threads exit without touching memory.
__global__ void elementwise_add_kernel(const float* __restrict__ A,
                                       const float* __restrict__ B,
                                       float* __restrict__ C) {
  if (blockIdx.x >= 1024 || threadIdx.x >= 1024) return;
  const int idx = 1024 * blockIdx.x + threadIdx.x;
  C[idx] = A[idx] + B[idx];
}

// Benchmarks the raw elementwise-add kernel over a 1024x1024 float buffer.
// Launches the kernel kRepeat times back-to-back (no per-launch sync) and
// reports the average per-launch latency, matching the async behavior this
// commit introduces in cinn_call_cuda_kernel.
TEST(raw_cuda, basic) {
  const int M = 1024;
  const int N = 1024;
  // allocate CUDA buffers; check every runtime call so a failed allocation
  // does not surface later as a confusing kernel fault
  float *Ag = nullptr, *Bg = nullptr, *Cg = nullptr;
  const size_t num_bytes = static_cast<size_t>(M) * N * sizeof(float);
  CUDA_CALL(cudaMalloc(&Ag, num_bytes));
  CUDA_CALL(cudaMalloc(&Bg, num_bytes));
  CUDA_CALL(cudaMalloc(&Cg, num_bytes));

  const int kRepeat = 1000;
  cinn::utils::Timer timer;
  timer.Start();
  for (int i = 0; i < kRepeat; i++) {
    elementwise_add_kernel<<<1024, 1024>>>(Ag, Bg, Cg);
  }
  // a <<<>>> launch does not report configuration errors itself;
  // they must be fetched explicitly
  CUDA_CALL(cudaGetLastError());
  // single sync at the end so the timer covers all asynchronous launches
  CUDA_CALL(cudaDeviceSynchronize());
  float latency = timer.Stop();
  LOG(INFO) << "latency: " << latency / kRepeat;

  // release device memory so the test does not leak across runs
  CUDA_CALL(cudaFree(Ag));
  CUDA_CALL(cudaFree(Bg));
  CUDA_CALL(cudaFree(Cg));
}
2 changes: 0 additions & 2 deletions cinn/runtime/cuda/cuda_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ void cinn_call_cuda_kernel(void *kernel_fn,
static_cast<CUstream>(stream),
reinterpret_cast<void **>(arr),
nullptr))

CUDA_CALL(cudaDeviceSynchronize());
}

} // namespace cuda
Expand Down

0 comments on commit 98d2ff3

Please sign in to comment.