Add Cuda tests for SplitWithMask (pytorch#88)
zheng-xq authored and Mikhail Zolotukhin committed Feb 7, 2020
1 parent d42abd0 commit 75e9993
Showing 3 changed files with 61 additions and 9 deletions.
test/cpp/tensorexpr/test_cuda.cpp (57 additions, 6 deletions)
@@ -77,14 +77,65 @@ void testCudaTestVectorAdd01() {
   cudaFree(b_dev);
   cudaFree(c_dev);
 }
-} // namespace jit
-} // namespace torch
 
-#else // USE_CUDA
-namespace torch {
-namespace jit {
-void testCudaTestVectorAdd01() { }
+static void testCudaTestVectorAdd02_impl(int N, int block_size) {
+  Buffer a_buf("a", kFloat32, {N});
+  Buffer b_buf("b", kFloat32, {N});
+  Tensor c = Compute(
+      "c",
+      {
+          {N, "N"},
+      },
+      [&](const Var& n) { return a_buf(n) + b_buf(n); });
+  Schedule sch({c});
+  const Var& n = c.arg(0);
+  Var n_outer;
+  Var n_inner;
+  c.SplitWithMask(n, block_size, true, &n_outer, &n_inner);
+  c.GPUExecConfig({n_outer}, {n_inner});
+  Stmt stmt = sch.Lower();
+  CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf);
+  PaddedBuffer<float> a_v(N);
+  PaddedBuffer<float> b_v(N);
+  PaddedBuffer<float> c_v(N);
+  PaddedBuffer<float> c_ref(N);
+
+  for (int i = 0; i < N; i++) {
+    a_v(i) = i;
+    b_v(i) = i * 3 + 7;
+    c_ref(i) = a_v(i) + b_v(i);
+  }
+
+  // TODO: move gpu support into PaddedBuffer
+  float* a_dev = nullptr;
+  cudaMalloc(&a_dev, N * sizeof(float));
+  float* b_dev = nullptr;
+  cudaMalloc(&b_dev, N * sizeof(float));
+  float* c_dev = nullptr;
+  cudaMalloc(&c_dev, N * sizeof(float));
+  cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice);
+  cudaDeviceSynchronize();
+
+  cuda_cg(c_dev, a_dev, b_dev);
+
+  cudaDeviceSynchronize();
+  cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost);
+  cudaDeviceSynchronize();
+
+  ExpectAllNear(c_v, c_ref, 1e-5);
+
+  cudaFree(a_dev);
+  cudaFree(b_dev);
+  cudaFree(c_dev);
+}
+
+void testCudaTestVectorAdd02() {
+  testCudaTestVectorAdd02_impl(1024, 128);
+  testCudaTestVectorAdd02_impl(1030, 128);
+}
 } // namespace jit
 } // namespace torch
 
 #endif
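Note on the two test sizes: SplitWithMask splits the length-N axis into an outer loop (mapped to CUDA blocks by GPUExecConfig) and an inner loop of extent block_size (mapped to threads). When N is not a multiple of block_size, as in the 1030/128 call above, the launch is rounded up to ceil(N/block_size) blocks and a mask guards the out-of-range tail. A minimal sketch of the kind of masked kernel this schedule should lower to; the kernel name and signature are illustrative assumptions, not the literal CudaCodeGen output:

// Illustrative only: approximates what SplitWithMask + GPUExecConfig
// should produce for c(n) = a(n) + b(n); names and signature are assumed.
__global__ void vector_add_masked(float* c, const float* a, const float* b, int N) {
  // n_outer -> blockIdx.x, n_inner -> threadIdx.x
  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < N) {  // the mask: inert when N % block_size == 0, trims the tail otherwise
    c[n] = a[n] + b[n];
  }
}

// Host side for N = 1030, block_size = 128:
// grid = (1030 + 127) / 128 = 9 blocks; the last block is partially masked.
// vector_add_masked<<<9, 128>>>(c_dev, a_dev, b_dev, 1030);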
test/cpp/tensorexpr/tests.h (3 additions, 2 deletions)
@@ -73,7 +73,6 @@ namespace jit {
   _(LLVMBroadcastAdd) \
   _(LLVMDynamicShapeAdd) \
   _(LLVMBindDynamicShapeAdd) \
-  _(CudaTestVectorAdd01) \
   _(Cond01) \
   _(ATen_cast_Float) \
   _(ATennegInt) \
@@ -110,7 +109,9 @@ namespace jit {
   _(ATenleInt) \
   _(ATenltInt)
 
-#define TH_FORALL_TESTS_CUDA(_)
+#define TH_FORALL_TESTS_CUDA(_) \
+  _(CudaTestVectorAdd01)        \
+  _(CudaTestVectorAdd02)
 
 #define DECLARE_TENSOREXPR_TEST(name) void test##name();
 TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST)
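For reference, each `_(Name)` entry is expanded by the DECLARE_TENSOREXPR_TEST macro shown above (and presumably by a matching registration macro further down the header, which this view truncates). A sketch of the expansion for the new CUDA list:

// Hand-written expansion of the macros from the diff above.
#define TH_FORALL_TESTS_CUDA(_) \
  _(CudaTestVectorAdd01)        \
  _(CudaTestVectorAdd02)

#define DECLARE_TENSOREXPR_TEST(name) void test##name();
TH_FORALL_TESTS_CUDA(DECLARE_TENSOREXPR_TEST)
// ...expands to:
//   void testCudaTestVectorAdd01();
//   void testCudaTestVectorAdd02();

Moving CudaTestVectorAdd01 from TH_FORALL_TESTS into TH_FORALL_TESTS_CUDA is also why the empty non-CUDA stub in test_cuda.cpp could be deleted above: the test is now declared and run only in CUDA builds.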
torch/csrc/jit/tensorexpr/cuda_codegen.cpp (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
#include "torch/csrc/jit/tensorexpr/cuda_codegen.h"

#define DEBUG_PRINT 0
#define DEBUG_PRINT 1

namespace torch {
namespace jit {
Expand Down
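DEBUG_PRINT is a compile-time switch in the CUDA codegen; setting it to 1 enables verbose dumping while a kernel is generated. A sketch of the usual guard pattern, assuming a typical use of the flag (the exact statements and stream names in cuda_codegen.cpp may differ):

#if DEBUG_PRINT
  // Hypothetical example of output gated by the flag; "oss" is an assumed
  // name for the stream holding the generated CUDA source.
  std::cout << "generated CUDA source:\n" << oss.str() << std::endl;
#endif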
