-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added wrapped C cuda code and runable examples (#1)
initial
- Loading branch information
1 parent
7f35b0b
commit 895cfa5
Showing
13 changed files
with
1,508 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,5 @@ | |
*.jl.cov | ||
*.jl.mem | ||
/Manifest.toml | ||
lib | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
[CUDA_lib] | ||
git-tree-sha1 = "2918fba865582556e219191a7f393c47c2e822e0" | ||
|
||
[[CUDA_lib.download]] | ||
sha256 = "751bf9d1f2921d4176ffb8ed1ddbd59bb60d6a517e6784bb71d61b62357c0007" | ||
url = "https://gist.github.com/ArrogantGao/c38791f143d36d4b2481ac7e4aa4ecce/raw/2918fba865582556e219191a7f393c47c2e822e0.tar.gz" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,12 @@ uuid = "c2b282c3-c9c2-431d-80f7-a1a0561ebe55" | |
authors = ["Xuanzhao Gao <[email protected]> and contributors"] | ||
version = "1.0.0-DEV" | ||
|
||
[deps] | ||
ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce" | ||
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" | ||
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | ||
|
||
[compat] | ||
julia = "1" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
using TropicalNumbers, CUDA, BenchmarkTools | ||
|
||
"""
    map_reduce_benchmark(m::T, n::T, k::T) where {T}

Benchmark the product of an `m × k` and a `k × n` matrix of `Tropical`-wrapped
`Float32` values created on the GPU with `CUDA.randn`.

Prints `m`, `n`, `k`, the time measured by `@belapsed` and the throughput
`2 * m * n * k / elapsed_time / 1e9`. Note that the printed `flops` value is therefore
expressed in units of 1e9 operations per second.

Returns `nothing`.
"""
function map_reduce_benchmark(m::T, n::T, k::T) where {T}
    A = Tropical.(CUDA.randn(Float32, (m, k)))
    B = Tropical.(CUDA.randn(Float32, (k, n)))

    # `A * B` returns a freshly allocated result, so no output buffer is created up
    # front; a pre-allocated one would only waste device memory.
    elapsed_time = @belapsed CUDA.@sync begin
        $A * $B
    end

    work_load = 2 * m * n * k
    flops = work_load / elapsed_time / 1e9
    @show m, n, k, elapsed_time, flops
    return nothing
end

# Run the benchmark at 1x, 2x and 4x the base problem size.
for scale in (1, 2, 4)
    map_reduce_benchmark(scale * 2560, scale * 2048, scale * 2048)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
using CUDA | ||
using BenchmarkTools | ||
using CuTropicalGEMM | ||
|
||
"""
    benchmakr_CuTropicalGemmFP32(m::T, n::T, k::T) where {T}

Benchmark `CuTropicalGemmMatmulFP32!` on three device vectors of lengths `m * k`,
`k * n` and `m * n`, each filled with `rand(Float32, ...)` on the host and copied to
the GPU with `CuArray`.

Prints `m`, `n`, `k`, the time measured by `@belapsed` and the throughput
`2 * m * n * k / elapsed_time / 1e9` (shown under the name `flops`).

Returns `nothing`.
"""
function benchmakr_CuTropicalGemmFP32(m::T, n::T, k::T) where {T}
    CuA = CuArray(rand(Float32, m * k))
    CuB = CuArray(rand(Float32, k * n))
    CuC = CuArray(rand(Float32, m * n))

    # Author's note: @belapsed together with CUDA.@sync was found not to benchmark the
    # function from the .so library properly, so treat the timings with care.
    # NOTE(review): the `1 + 1` statements look like filler around the library call;
    # they are kept exactly as they were - confirm whether they are still needed.
    elapsed_time = @belapsed CUDA.@sync begin
        1 + 1
        CuTropicalGemmMatmulFP32!($m, $n, $k, $CuA, $CuB, $CuC)
        1 + 1
    end

    flops = 2 * m * n * k / elapsed_time / 1e9
    @show m, n, k, elapsed_time, flops
    return nothing
end

# Run the benchmark at 1x, 2x and 4x the base problem size.
for scale in (1, 2, 4)
    benchmakr_CuTropicalGemmFP32(scale * 2560, scale * 2048, scale * 2048)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/bin/bash
# Build ../src/TropicalSGemmFP32.cu with nvcc for the sm_80 architecture, run the
# resulting executable once, and delete it afterwards.
# The source path is relative to the current working directory.

# Stop at the first failing command so a failed build never runs a stale ./a.out.
set -euo pipefail

# Remove the binary on every exit path, including a failed build or a failed run.
trap 'rm -f a.out' EXIT

nvcc -arch=sm_80 ../src/TropicalSGemmFP32.cu
./a.out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# notice: this code is used to benchmark the Tropical matmul in GemmKernels.jl, which is not yet released in the latest version and only supported in [email protected] | ||
# To run the code, you need to manually download the latest version of the GemmKernels.jl repository and activate the environment | ||
|
||
import CUDA | ||
import InteractiveUtils | ||
|
||
using CUDA | ||
using GemmKernels | ||
using LinearAlgebra | ||
using BenchmarkTools | ||
using Test | ||
|
||
CUDA.allowscalar(false) | ||
|
||
"""
    try_tropical(M, N, K)

Benchmark `GemmKernels.matmul` with `Operator.TropicalFPUOp` for an `M × N × K`
problem, for every combination of transposed / non-transposed `A` and `B`.

For each combination the function prints the throughput
`2 * M * N * K / elapsed_time / 1e9` (labelled `GFlops`), the time measured by
`@belapsed`, the transpose flags and the problem size.

When neither input is transposed, 1600 randomly chosen entries of the device result
are compared against the host reference
`max(c[i, j], maximum(a[i, k] + b[k, j] for k in 1:K))` inside a `@testset`.

Returns `nothing`.
"""
function try_tropical(M, N, K)
    for (A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128)],
        transpose_a = [true, false],
        transpose_b = [true, false],
        (OP_M, OP_N, OP_K) in [(8, 16, 2)]

        a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
        b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
        c_h = rand(CD_type, (M, N))

        # Transpose input if necessary
        a_h = transpose_a ? transpose(a_h) : a_h
        b_h = transpose_b ? transpose(b_h) : b_h

        a = CuArray(a_h)
        b = CuArray(b_h)
        c = CuArray(c_h)
        d = similar(c)

        conf = GemmKernels.get_config(
            gemm_shape = (M = M, N = N, K = K),
            block_shape = (M = 64, N = 64, K = 32),
            operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, CD_type, A_type},
            global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
            global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

            global_c_layout = Layout.AlignedColMajor{CD_type},
            global_d_layout = Layout.AlignedColMajor{CD_type},

            is_a_col_major = !transpose_a,
            is_b_col_major = !transpose_b,
        )

        elapsed_time = @belapsed CUDA.@sync begin
            GemmKernels.matmul($a, $b, $c, $d, $conf; kernel = Kernel.matmul_pipelined)
        end

        # Dividing by 1e9 yields giga-operations per second, so the value is labelled
        # GFlops.
        GFlops = 2 * M * N * K / elapsed_time / 1e9
        @show GFlops, elapsed_time, transpose_a, transpose_b, M, N, K

        # Verify 1600 randomly sampled entries; only done for the non-transposed case,
        # where `a_h` and `b_h` can be indexed as (M, K) and (K, N) directly.
        if !transpose_a && !transpose_b
            # Copy the result back to the host only when it is actually checked.
            d_c = Array(d)
            @testset "tropical matmul M=$M, N=$N, K=$K" begin
                for _ in 1:1600
                    i = rand(1:M)
                    j = rand(1:N)
                    expected = c_h[i, j]
                    for k in 1:K
                        expected = max(a_h[i, k] + b_h[k, j], expected)
                    end
                    @test isapprox(expected, d_c[i, j]; rtol = sqrt(eps(A_type)))
                end
            end
        end
    end
    return nothing
end

# Run the benchmark at 1x, 2x and 4x the base problem size.
for scale in (1, 2, 4)
    try_tropical(scale * 2560, scale * 2048, scale * 2048)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,10 @@ | ||
"""
    CuTropicalGEMM

Package module that loads `CUDA` and `Artifacts`, includes
`TropicalGemm_Cuda_wrapper.jl`, and exports `CuTropicalGemmMatmulFP32!`.

NOTE(review): `CuTropicalGemmMatmulFP32!` is presumably defined in the included wrapper
file - confirm.
"""
module CuTropicalGEMM

using CUDA
using Artifacts

export CuTropicalGemmMatmulFP32!

include("TropicalGemm_Cuda_wrapper.jl")

end
Oops, something went wrong.