Commit
Added wrapped C CUDA code and runnable examples (#1)
initial
ArrogantGao authored Jul 5, 2023
1 parent 7f35b0b commit 895cfa5
Showing 13 changed files with 1,508 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@
*.jl.cov
*.jl.mem
/Manifest.toml
lib
.vscode
6 changes: 6 additions & 0 deletions Artifacts.toml
@@ -0,0 +1,6 @@
[CUDA_lib]
git-tree-sha1 = "2918fba865582556e219191a7f393c47c2e822e0"

[[CUDA_lib.download]]
sha256 = "751bf9d1f2921d4176ffb8ed1ddbd59bb60d6a517e6784bb71d61b62357c0007"
url = "https://gist.github.com/ArrogantGao/c38791f143d36d4b2481ac7e4aa4ecce/raw/2918fba865582556e219191a7f393c47c2e822e0.tar.gz"
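For context, this artifact is meant to be consumed through Artifacts.jl at load time. The sketch below shows roughly how the downloaded tree could be resolved to a library path; the file name libtropicalgemm.so is a hypothetical placeholder, not necessarily what the tarball actually ships.

using Artifacts, Libdl

# Resolve the CUDA_lib entry from Artifacts.toml to a local directory (downloading it if needed).
lib_dir = artifact"CUDA_lib"

# Hypothetical library file name inside the artifact; adjust to the actual .so in the tarball.
lib_path = joinpath(lib_dir, "libtropicalgemm.so")
handle = Libdl.dlopen(lib_path)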
695 changes: 674 additions & 21 deletions LICENSE

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions Project.toml
@@ -3,6 +3,12 @@ uuid = "c2b282c3-c9c2-431d-80f7-a1a0561ebe55"
authors = ["Xuanzhao Gao <[email protected]> and contributors"]
version = "1.0.0-DEV"

[deps]
ArtifactUtils = "8b73e784-e7d8-4ea5-973d-377fed4e3bce"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

[compat]
julia = "1"

20 changes: 20 additions & 0 deletions benchmark/benchmark_CUDA_mapreduce.jl
@@ -0,0 +1,20 @@
using TropicalNumbers, CUDA, BenchmarkTools

function map_reduce_benchmark(m::T, n::T, k::T) where {T}
A = Tropical.(CUDA.randn(Float32, (m, k)))
B = Tropical.(CUDA.randn(Float32, (k, n)))
C = Tropical.(CUDA.randn(Float32, (m, n)))

elapsed_time = @belapsed CUDA.@sync begin
$C = $A * $B
end

work_load = 2 * m * n * k
gflops = work_load / elapsed_time / 1e9
@show m, n, k, elapsed_time, gflops
return nothing
end

map_reduce_benchmark(2560, 2048, 2048)
map_reduce_benchmark(2 * 2560, 2 * 2048, 2 * 2048)
map_reduce_benchmark(4 * 2560, 4 * 2048, 4 * 2048)
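As background, TropicalNumbers.jl defines Tropical as the (max, +) semiring, so the A * B above is a tropical matrix product evaluated through CUDA.jl's generic map-reduce machinery. A minimal scalar illustration:

using TropicalNumbers

# In the (max, +) semiring, "multiplication" adds the underlying values and "addition" takes their maximum.
Tropical(1.0) * Tropical(2.0)   # Tropical(3.0)
Tropical(1.0) + Tropical(2.0)   # Tropical(2.0)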
29 changes: 29 additions & 0 deletions benchmark/benchmark_CuTropicalGemm.jl
@@ -0,0 +1,29 @@
using CUDA
using BenchmarkTools
using CuTropicalGEMM

function benchmark_CuTropicalGemmFP32(m::T, n::T, k::T) where {T}
A = rand(Float32, (m * k))
B = rand(Float32, (k * n))
C = rand(Float32, (m * n))

CuA = CuArray(A)
CuB = CuArray(B)
CuC = CuArray(C)

# Note: @belapsed with CUDA.@sync alone does not benchmark the function from the .so library correctly;
# the surrounding dummy statements are a workaround for that.
elapsed_time = @belapsed CUDA.@sync begin
1 + 1
CuTropicalGemmMatmulFP32!($m, $n, $k, $CuA, $CuB, $CuC)
1 + 1
end

work_load = 2 * m * n * k
gflops = work_load / elapsed_time / 1e9
@show m, n, k, elapsed_time, gflops
return nothing
end

benchmark_CuTropicalGemmFP32(2560, 2048, 2048)
benchmark_CuTropicalGemmFP32(2 * 2560, 2 * 2048, 2 * 2048)
benchmark_CuTropicalGemmFP32(4 * 2560, 4 * 2048, 4 * 2048)
6 changes: 6 additions & 0 deletions benchmark/benchmark_CuTropicalGemm_C.sh
@@ -0,0 +1,6 @@
#!/bin/bash

# Compile the standalone CUDA FP32 tropical GEMM benchmark for sm_80 (Ampere) and run it.
nvcc -arch=sm_80 ../src/TropicalSGemmFP32.cu
./a.out

rm a.out
83 changes: 83 additions & 0 deletions benchmark/benchmark_GemmKerenls_Tropical.jl
@@ -0,0 +1,83 @@
# Note: this script benchmarks the Tropical matmul operator in GemmKernels.jl, which is not yet part of the latest release and is only supported in [email protected].
# To run it, manually clone the latest GemmKernels.jl repository and activate its environment.

import CUDA
import InteractiveUtils

using CUDA
using GemmKernels
using LinearAlgebra
using BenchmarkTools
using Test

CUDA.allowscalar(false)

function try_tropical(M, N, K)
for (A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128)],
transpose_a = [true, false],
transpose_b = [true, false],
(OP_M, OP_N, OP_K) in [(8, 16, 2)]

a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
c_h = rand(CD_type, (M, N))
d_h = similar(c_h)


# Transpose input if necessary
a_h = transpose_a ? transpose(a_h) : a_h
b_h = transpose_b ? transpose(b_h) : b_h

a = CuArray(a_h)
b = CuArray(b_h)
c = CuArray(c_h)
d = similar(c)

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
block_shape = (M = 64, N = 64, K = 32),
operator = Operator.TropicalFPUOp{OP_M, OP_N, OP_K, CD_type, A_type},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

global_c_layout = Layout.AlignedColMajor{CD_type},
global_d_layout = Layout.AlignedColMajor{CD_type},

is_a_col_major = !transpose_a,
is_b_col_major = !transpose_b,
)

elapsed_time = @belapsed CUDA.@sync begin
GemmKernels.matmul($a, $b, $c, $d, $conf; kernel = Kernel.matmul_pipelined)
end
# 2 * M * N * K operations per matmul, reported in GFLOPS
GFlops = (2 * M * N * K / elapsed_time) / 1e9
@show GFlops, elapsed_time, transpose_a, transpose_b, M, N, K


d_c = Array(d)

# randomly sample 1600 entries and check them against a CPU reference
if transpose_a == transpose_b == false
@testset begin
for _ in 1 : 40
for _ in 1 : 40
i = rand(1:M)
j = rand(1:N)
d_h[i, j] = c_h[i, j]
for k in 1 : K
d_h[i, j] = max(a_h[i, k] + b_h[k, j], d_h[i, j])
end
@test isapprox(d_h[i, j], d_c[i, j]; rtol = sqrt(eps(A_type)))
end
end
end
end
end
return nothing
end


try_tropical(2560, 2048, 2048)
try_tropical(2 * 2560, 2 * 2048, 2 * 2048)
try_tropical(4 * 2560, 4 * 2048, 4 * 2048)
7 changes: 6 additions & 1 deletion src/CuTropicalGEMM.jl
@@ -1,5 +1,10 @@
module CuTropicalGEMM

# Write your package code here.
export CuTropicalGemmMatmulFP32!

using CUDA
using Artifacts

include("TropicalGemm_Cuda_wrapper.jl")

end
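The included TropicalGemm_Cuda_wrapper.jl is not shown in this diff. A rough sketch of such a wrapper is given below; the library file name and the C symbol CuTropicalGemmMatmulFP32, along with its signature, are assumptions for illustration rather than the actual interface shipped in the artifact.

using CUDA, Artifacts

# Hypothetical path to the shared library unpacked from the CUDA_lib artifact.
const libtropicalgemm = joinpath(artifact"CUDA_lib", "libtropicalgemm.so")

function CuTropicalGemmMatmulFP32!(m::Integer, n::Integer, k::Integer,
                                   A::CuArray{Float32}, B::CuArray{Float32}, C::CuArray{Float32})
    # Hand the raw device pointers to the C entry point; the assumed signature is
    # void CuTropicalGemmMatmulFP32(int m, int n, int k, float *A, float *B, float *C).
    ccall((:CuTropicalGemmMatmulFP32, libtropicalgemm), Cvoid,
          (Cint, Cint, Cint, CuPtr{Cfloat}, CuPtr{Cfloat}, CuPtr{Cfloat}),
          m, n, k, A, B, C)
    return C
end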