forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
27a1b5b
commit fb2115d
Showing
6 changed files
with
285 additions
and
0 deletions.
There are no files selected for viewing
68 changes: 68 additions & 0 deletions
68
python/perf-kernels/tools/profiler/ck-benchmarks/gemm/CMakeLists.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
cmake_minimum_required(VERSION 3.20)
project(ck-gemm-runner)

# Project-wide C++ defaults: strict C++17 without compiler extensions.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Default to an optimized build, but do not override a build type chosen by
# the user and do not touch multi-config generators (where CMAKE_BUILD_TYPE
# is unused).
get_property(is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT is_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type" FORCE)
endif()

# NOTE(review): neither option is consumed by any command in this
# CMakeLists.txt - confirm whether they are still needed.
option(USE_STREAM_PIPELINE "use stream pipeline" OFF)
option(USE_INTERWAVE "use interwave scheduling" OFF)
|
||
# Resolve the install prefixes of HIP, ROCm and Composable Kernel (CK).
# For each variable the precedence is: a value that is already defined
# (e.g. -D<VAR>=... on the command line), then the environment variable of
# the same name, then the default location.
if(NOT DEFINED HIP_PATH)
  if(DEFINED ENV{HIP_PATH})
    set(HIP_PATH "$ENV{HIP_PATH}" CACHE PATH "Path to which HIP has been installed")
  else()
    set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed")
  endif()
endif()

if(NOT DEFINED ROCM_PATH)
  if(DEFINED ENV{ROCM_PATH})
    set(ROCM_PATH "$ENV{ROCM_PATH}" CACHE PATH "Path to which ROCm has been installed")
  else()
    set(ROCM_PATH "/opt/rocm" CACHE PATH "Path to which ROCm has been installed")
  endif()
endif()

if(NOT DEFINED CK_PATH)
  if(DEFINED ENV{CK_PATH})
    set(CK_PATH "$ENV{CK_PATH}" CACHE PATH "Path to which CK has been installed")
  else()
    set(CK_PATH "/opt/rocm" CACHE PATH "Path to which CK has been installed")
  endif()
endif()

# Search the resolved prefixes first, but keep any prefixes the user supplied
# through CMAKE_PREFIX_PATH instead of discarding them.
list(PREPEND CMAKE_PREFIX_PATH "${CK_PATH}" "${HIP_PATH}" "${ROCM_PATH}")
|
||
# CLI11 (command line parsing) is downloaded at configure time.
include(FetchContent)
FetchContent_Declare(
  cli11
  GIT_REPOSITORY https://github.com/CLIUtils/CLI11
  GIT_TAG v2.2.0
)
FetchContent_MakeAvailable(cli11)

find_package(hip REQUIRED)
message(STATUS "Found HIP executable at: ${HIP_BIN_INSTALL_DIR}")

# Composable Kernel is looked up in config mode, additionally searching
# CK_PATH. The result is checked by hand so that the error can name the
# prefix that was searched.
find_package(composable_kernel 1.1.0
  COMPONENTS device_gemm_operations
  CONFIG
  PATHS "${CK_PATH}"
)
if(NOT composable_kernel_FOUND)
  message(FATAL_ERROR "failed to find composable kernels (CK_PATH: ${CK_PATH})")
else()
  message(STATUS "CK PATH: ${CK_PATH}")
endif()
|
||
set(LIBS composable_kernel::device_gemm_operations hip::device)

# Shared library that holds the GEMM instance under test.
add_library(kernel SHARED ${PROJECT_SOURCE_DIR}/kernel.cpp)
# PUBLIC: targets linking against `kernel` inherit the CK/HIP usage
# requirements as well.
target_link_libraries(kernel PUBLIC ${LIBS})
# Keep the compiler's intermediate files for the kernel sources only; PRIVATE
# stops the flag from propagating to targets that link against `kernel`.
target_compile_options(kernel PRIVATE --save-temps)

find_package(Threads REQUIRED)

# PROJECT_NAME (not CMAKE_PROJECT_NAME) keeps the target name stable when this
# directory is built as part of a larger project.
add_executable(${PROJECT_NAME} ${PROJECT_SOURCE_DIR}/main.cpp)
target_link_libraries(${PROJECT_NAME} PRIVATE kernel CLI11::CLI11 Threads::Threads)
target_include_directories(${PROJECT_NAME} PRIVATE ${PROJECT_SOURCE_DIR})
13 changes: 13 additions & 0 deletions
13
python/perf-kernels/tools/profiler/ck-benchmarks/gemm/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
## Install | ||
|
||
```bash | ||
mkdir build && cd build | ||
CK_PATH=$(realpath <CK install directory>) CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. | ||
make VERBOSE=1 -j4 | ||
``` | ||
|
||
## Example | ||
|
||
```bash | ||
./ck-gemm-runner -m 4864 -n 2048 -k 4160 | ||
``` |
61 changes: 61 additions & 0 deletions
61
python/perf-kernels/tools/profiler/ck-benchmarks/gemm/common.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#pragma once | ||
|
||
#include "ck/stream_config.hpp" | ||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" | ||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" | ||
#include "testcase.hpp" | ||
|
||
using ADataType = ck::half_t; | ||
using BDataType = ck::half_t; | ||
using CDataType = ck::half_t; | ||
|
||
using F16 = ck::half_t; | ||
using F32 = float; | ||
|
||
using AElementOp = ck::tensor_operation::element_wise::PassThrough; | ||
using BElementOp = ck::tensor_operation::element_wise::PassThrough; | ||
using CElementOp = ck::tensor_operation::element_wise::PassThrough; | ||
|
||
template <typename DeviceGemmInstance> struct Driver { | ||
static void launchKernel(real *matA, real *matB, real *matC, | ||
const TestCase::Config &config) { | ||
|
||
auto gemm = DeviceGemmInstance{}; | ||
std::cout << gemm.GetTypeString() << std::endl; | ||
std::cout << std::string(80, '-') << std::endl; | ||
|
||
auto invoker = gemm.MakeInvoker(); | ||
double aveTime = 0.0f; | ||
|
||
size_t strideA = config.transA ? config.m : config.k; | ||
size_t strideB = config.transB ? config.k : config.n; | ||
size_t strideC = config.n; | ||
|
||
auto argument = gemm.MakeArgument( | ||
static_cast<ADataType *>(matA), static_cast<BDataType *>(matB), | ||
static_cast<CDataType *>(matC), config.m, config.n, config.k, strideA, | ||
strideB, strideC, config.kbatch, AElementOp{}, BElementOp{}, | ||
CElementOp{}); | ||
|
||
if (!gemm.IsSupportedArgument(argument)) { | ||
std::cerr << gemm.GetTypeString() << " does not support this problem" | ||
<< std::endl; | ||
return; | ||
} | ||
|
||
StreamConfig streamConfig{ | ||
/*stream_id_=*/nullptr, | ||
/*time_kernel_=*/true, config.logLevel, | ||
config.coldNumIters, config.numRepeat, | ||
config.flushCache, config.rotatingCount, | ||
}; | ||
|
||
// time in milli seconds | ||
aveTime = invoker.Run(argument, streamConfig); | ||
|
||
double flops = 2.0 * (config.m * config.n * config.k); | ||
flops /= (aveTime * 1e9); | ||
|
||
std::cout << "TFLOP/s: " << flops << "\n"; | ||
} | ||
}; |
23 changes: 23 additions & 0 deletions
23
python/perf-kernels/tools/profiler/ck-benchmarks/gemm/kernel.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#include "common.hpp" | ||
#include "testcase.hpp" | ||
|
||
// Insert your GEMM kernel here: using DeviceGemmInstance = ... ; | ||
using DeviceGemmInstance = | ||
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< | ||
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::ColumnMajor, | ||
ck::tensor_layout::gemm::RowMajor, _Float16, _Float16, _Float16, float, | ||
_Float16, ck::tensor_operation::element_wise::PassThrough, | ||
ck::tensor_operation::element_wise::PassThrough, | ||
ck::tensor_operation::element_wise::PassThrough, | ||
ck::tensor_operation::device::GemmSpecialization::Default, 256, 128, | ||
128, 64, 8, 8, 32, 32, 2, 2, ck::Sequence<8, 32, 1>, | ||
ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, 8, 8, false, | ||
ck::Sequence<8, 32, 1>, ck::Sequence<1, 0, 2>, ck::Sequence<1, 0, 2>, 2, | ||
8, 8, false, 1, 1, ck::Sequence<1, 32, 1, 8>, 8, | ||
ck::BlockGemmPipelineScheduler::Interwave, | ||
ck::BlockGemmPipelineVersion::v1>; | ||
|
||
// Forwards the launch request to the driver specialised for the GEMM
// instance selected in this file.
void TestCase::launchKernel(real *matA, real *matB, real *matC,
                            const TestCase::Config &config) {
  using GemmDriver = Driver<DeviceGemmInstance>;
  GemmDriver::launchKernel(matA, matB, matC, config);
}
97 changes: 97 additions & 0 deletions
97
python/perf-kernels/tools/profiler/ck-benchmarks/gemm/main.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#include "CLI/CLI.hpp" | ||
#include "testcase.hpp" | ||
#include <iostream> | ||
#include <random> | ||
|
||
// Wraps a HIP call: on failure, reports the error together with the call
// site and throws std::runtime_error.
#define checkHIPErrors(err) __checkHIPErrors(err, __FILE__, __LINE__)
void __checkHIPErrors(hipError_t err, const char *file, const int line) {
  if (hipSuccess != err) {
    const char *errorStr = hipGetErrorString(err);

    // Diagnostics go to stderr so they do not mix with the benchmark results
    // printed on stdout.
    std::cerr << "checkHIPErrors() Driver API error = " << err << " \""
              << errorStr << "\""
              << " from file <" << file << "> line " << line << std::endl;
    throw std::runtime_error("failed to process a hip command");
  }
}
|
||
void init(std::vector<real> &mat, size_t dim0, size_t dim1) { | ||
std::random_device randomeDev; | ||
std::default_random_engine randomeEngine(randomeDev()); | ||
std::uniform_real_distribution<float> uniformDist(-5.0, 5.0); | ||
|
||
std::array<float, 256> randomNumbers; | ||
for (size_t i = 0; i < randomNumbers.size(); ++i) { | ||
randomNumbers[i] = static_cast<real>(uniformDist(randomeDev)); | ||
} | ||
static size_t startIndex = 0; | ||
startIndex += 4; | ||
startIndex = startIndex > randomNumbers.size() ? 0 : startIndex; | ||
|
||
#pragma omp paralle for collapse(2) | ||
for (size_t j = 0; j < dim0; ++j) { | ||
for (size_t i = 0; i < dim1; ++i) { | ||
const size_t index = j * dim1 + i; | ||
const size_t randomNumberIndex = | ||
(startIndex + index) % randomNumbers.size(); | ||
mat[index] = randomNumbers[randomNumberIndex]; | ||
} | ||
} | ||
} | ||
|
||
// Allocates and initialises the host matrices, mirrors them on the device,
// launches the GEMM test case and releases the device buffers.
void run(const TestCase::Config &config) {
  // Element counts; every buffer holds twice the number of matrix elements.
  const size_t sizeA = 2 * config.m * config.k;
  const size_t sizeB = 2 * config.k * config.n;
  const size_t sizeC = 2 * config.m * config.n;

  std::vector<real> hostA(sizeA);
  std::vector<real> hostB(sizeB);
  std::vector<real> hostC(sizeC);

  init(hostA, config.m, config.k);
  init(hostB, config.k, config.n);
  init(hostC, config.m, config.n);

  // Allocates a device buffer of `numElements` values of type `real`.
  const auto allocate = [](size_t numElements) {
    real *devPtr{nullptr};
    checkHIPErrors(hipMalloc((void **)&devPtr, numElements * sizeof(real)));
    return devPtr;
  };

  // Copies a complete host vector into an already allocated device buffer.
  const auto upload = [](real *devPtr, const std::vector<real> &host) {
    checkHIPErrors(hipMemcpy(devPtr, host.data(), host.size() * sizeof(real),
                             hipMemcpyKind::hipMemcpyHostToDevice));
  };

  // All allocations happen before the first copy.
  real *devA = allocate(sizeA);
  real *devB = allocate(sizeB);
  real *devC = allocate(sizeC);

  upload(devA, hostA);
  upload(devB, hostB);
  upload(devC, hostC);

  TestCase::launchKernel(devA, devB, devC, config);

  checkHIPErrors(hipFree(devA));
  checkHIPErrors(hipFree(devB));
  checkHIPErrors(hipFree(devC));
}
|
||
int main(int argc, char *argv[]) { | ||
CLI::App app{"ck gemm examples"}; | ||
TestCase::Config config{}; | ||
|
||
app.add_option("-m", config.m, "M size"); | ||
app.add_option("-n", config.n, "N size"); | ||
app.add_option("-k", config.k, "K size"); | ||
app.add_option("--kbatch", config.kbatch, "kbatch (for split-k)"); | ||
app.add_flag("--trans-a", config.transA, "transpose A"); | ||
app.add_flag("--trans-b", config.transB, "transpose B"); | ||
app.add_option("--log-level", config.logLevel, "CK's log level"); | ||
app.add_option("--cold-num-iters", config.coldNumIters, | ||
"num cold iterations"); | ||
app.add_option("--num-repeat", config.numRepeat, "num repeats"); | ||
app.add_option("--rotating-count", config.rotatingCount, "rotating count"); | ||
app.add_flag("--flush-cache", config.flushCache, "flush cache"); | ||
CLI11_PARSE(app, argc, argv); | ||
|
||
run(config); | ||
return 0; | ||
} |
23 changes: 23 additions & 0 deletions
23
python/perf-kernels/tools/profiler/ck-benchmarks/gemm/testcase.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#pragma once | ||
|
||
#include "ck/ck.hpp" | ||
#include "ck/utility/data_type.hpp" | ||
|
||
using real = ck::half_t; | ||
|
||
struct TestCase { | ||
struct Config { | ||
size_t m{1024}; | ||
size_t n{1024}; | ||
size_t k{1024}; | ||
size_t kbatch{1}; | ||
bool transA{false}; | ||
bool transB{false}; | ||
int logLevel{1}; | ||
int coldNumIters{5}; | ||
int numRepeat{50}; | ||
bool flushCache{false}; | ||
int rotatingCount{1}; | ||
}; | ||
static void launchKernel(real *, real *, real *, const Config &); | ||
}; |