diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu
index dc87fff3..cede1373 100644
--- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu
+++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu
@@ -57,11 +57,11 @@ efficiently.
 
 // The code section below describes datatype for input, output matrices and computation between
 // elements in input matrices.
-using ElementAccumulator = int32_t;                 // <- data type of accumulator
+using ElementAccumulator = float;                   // <- data type of accumulator
 using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
-using ElementInputA = cutlass::int4b_t;             // <- data type of elements in input matrix A
-using ElementInputB = cutlass::int4b_t;             // <- data type of elements in input matrix B
-using ElementOutput = int32_t;                      // <- data type of elements in output matrix D
+using ElementInputA = float;                        // <- data type of elements in input matrix A
+using ElementInputB = float;                        // <- data type of elements in input matrix B
+using ElementOutput = float;                        // <- data type of elements in output matrix D
 
 // The code section below describes matrix layout of input and output matrices. Row Major for
 // Matrix A, Column Major for Matrix B and Row Major for Matrix C
@@ -77,11 +77,11 @@ using SmArch = cutlass::arch::Sm80;
 
 // This code section describes the tile size a thread block will compute
 using ShapeMMAThreadBlock =
-    cutlass::gemm::GemmShape<128, 128, 256>;  // <- threadblock tile M = 128, N = 128, K = 256
+    cutlass::gemm::GemmShape<128, 64, 32>;  // <- threadblock tile M = 128, N = 64, K = 32
 // This code section describes tile size a warp will compute
-using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>;  // <- warp tile M = 64, N = 64, K = 256
+using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;  // <- warp tile M = 64, N = 64, K = 32
 // This code section describes the size of MMA op
-using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>;  // <- MMA Op tile M = 16, N = 8, K = 128
+using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>;  // <- MMA Op tile M = 16, N = 8, K = 16
 
 // This code section describes how threadblocks are scheduled on GPU
 using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- ??
@@ -130,9 +130,9 @@ constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits;
 
 int run() {
-  const int length_m = 512;
-  const int length_n = 512;
-  const int length_k = 1024;
+  const int length_m = 32;
+  const int length_n = 192;
+  const int length_k = 512;
 
   // Create a tuple of problem size for matrix multiplication
   cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);
 
@@ -174,12 +174,8 @@ int run() {
       ElementInputB(2),
       ElementInputB(-2),
       0);  // <- Fill matrix B on host with uniform-distribution random data
-  cutlass::reference::host::TensorFillRandomUniform(
-      tensor_c.host_view(),
-      1,
-      ElementOutput(2),
-      ElementOutput(-2),
-      0);  // <- Fill matrix C on host with uniform-distribution random data
+  cutlass::reference::host::TensorFill(
+      tensor_c.host_view());  // <- Fill matrix C on host with zeros
   cutlass::reference::host::TensorFillRandomSparseMeta(
       tensor_e.host_view(),
       1,
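
For reference, the aliases retuned above feed the device-level kernel type defined later in the same file, which this patch leaves unchanged. The sketch below shows how they compose, assuming the surrounding code matches upstream CUTLASS example 15 (LayoutInputA, LayoutInputB, LayoutOutput, MMAOp, EpilogueOp, and NumStages come from the unchanged portion of the file); it is illustrative, not part of the diff.

// Sketch (not part of this patch): the SparseGemm instantiation that consumes
// the retuned aliases. Names not visible in the diff (LayoutInputA, MMAOp,
// EpilogueOp, NumStages, ...) are assumed from the unchanged file.
using Gemm = cutlass::gemm::device::SparseGemm<
    ElementInputA, LayoutInputA,   // A: float, row-major, 2:4 structured-sparse
    ElementInputB, LayoutInputB,   // B: float, column-major
    ElementOutput, LayoutOutput,   // C/D: float, row-major
    ElementAccumulator,            // accumulate in float
    MMAOp,                         // tensor-core op class
    SmArch,                        // cutlass::arch::Sm80
    ShapeMMAThreadBlock,           // 128x64x32 threadblock tile
    ShapeMMAWarp,                  // 64x64x32 warp tile
    ShapeMMAOp,                    // 16x8x16 sparse tf32 MMA instruction
    EpilogueOp,
    SwizzleThreadBlock,
    NumStages>;

On SM80, float inputs map to tf32 tensor cores, and the sparse MMA doubles K relative to the dense tf32 op, which is why ShapeMMAOp becomes GemmShape<16, 8, 16> here; the smaller threadblock and warp K of 32 is sized to this new instruction K, in contrast to the K = 256 tiles used by the original int4 configuration.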