diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c0772c33f6e5d..2e9be26fb9920 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -294,7 +294,7 @@ if (onnxruntime_USE_ROCM) endif() if (NOT CMAKE_HIP_ARCHITECTURES) - set(CMAKE_HIP_ARCHITECTURES "gfx906;gfx908;gfx90a;gfx1030;gfx1100;gfx1101") + set(CMAKE_HIP_ARCHITECTURES "gfx908;gfx90a;gfx1030;gfx1100;gfx1101;gfx940;gfx941;gfx942;gfx1200;gfx1201") endif() file(GLOB rocm_cmake_components ${onnxruntime_ROCM_HOME}/lib/cmake/*) diff --git a/cmake/external/composable_kernel.cmake b/cmake/external/composable_kernel.cmake index 4230eb8f4259b..b388a01209f4e 100644 --- a/cmake/external/composable_kernel.cmake +++ b/cmake/external/composable_kernel.cmake @@ -1,10 +1,12 @@ -set(PATCH ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) +set(PATCH_CLANG ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch) +set(PATCH_GFX12X ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Add_gfx12x_support.patch) include(FetchContent) FetchContent_Declare(composable_kernel URL ${DEP_URL_composable_kernel} URL_HASH SHA1=${DEP_SHA1_composable_kernel} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_CLANG} && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X} ) FetchContent_GetProperties(composable_kernel) diff --git a/cmake/hip_fatbin_insert b/cmake/hip_fatbin_insert new file mode 100644 index 0000000000000..7d834cbf569f0 --- /dev/null +++ b/cmake/hip_fatbin_insert @@ -0,0 +1,7 @@ +SECTIONS { + .hipFatBinSegment : { *(.hipFatBinSegment) } +} INSERT AFTER .bss + +SECTIONS { + .hip_fatbin : { *(.hip_fatbin) } +} INSERT AFTER .hipFatBinSegment diff --git a/cmake/onnxruntime_providers_rocm.cmake b/cmake/onnxruntime_providers_rocm.cmake index 559204bd0df88..e003cc2f8b724 100644 --- a/cmake/onnxruntime_providers_rocm.cmake +++ b/cmake/onnxruntime_providers_rocm.cmake @@ -116,6 +116,7 @@ auto_set_source_files_hip_language(${onnxruntime_providers_rocm_src}) onnxruntime_add_shared_library_module(onnxruntime_providers_rocm ${onnxruntime_providers_rocm_src}) target_compile_options(onnxruntime_providers_rocm PRIVATE -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1) + target_link_options(onnxruntime_providers_rocm PRIVATE -T ${REPO_ROOT}/cmake/hip_fatbin_insert) if(NOT MSVC) target_compile_options(onnxruntime_providers_rocm PRIVATE -Wno-sign-compare) diff --git a/cmake/patches/composable_kernel/Add_gfx12x_support.patch b/cmake/patches/composable_kernel/Add_gfx12x_support.patch new file mode 100644 index 0000000000000..ef529184d2ed8 --- /dev/null +++ b/cmake/patches/composable_kernel/Add_gfx12x_support.patch @@ -0,0 +1,2280 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index bc326c8b5..db5ad5052 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -117,7 +117,7 @@ else() + add_definitions(-DPROFILER_ONLY) + set(GPU_TARGETS "" CACHE STRING "" FORCE) + if(GPU_TARGETS) +- message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") ++ message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12") + endif() + if(GPU_ARCH MATCHES "gfx90") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a") +@@ -127,8 +127,10 @@ else() + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") + elseif(GPU_ARCH MATCHES "gfx11") + 
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") ++ elseif(GPU_ARCH MATCHES "gfx12") ++ rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201") + else() +- message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") ++ message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12") + endif() + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) + endif() +diff --git a/Jenkinsfile b/Jenkinsfile +index 75800bfc9..b72e2ca4e 100644 +--- a/Jenkinsfile ++++ b/Jenkinsfile +@@ -493,6 +493,7 @@ def Build_CK(Map conf=[:]){ + + def variant = env.STAGE_NAME + def retimage ++ + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + try { + (retimage, image) = getDockerImage(conf) +@@ -660,9 +661,6 @@ CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCM + + pipeline { + agent none +- triggers { +- parameterizedCron(CRON_SETTINGS) +- } + options { + parallelsAlwaysFailFast() + } +diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake +index 8654170b3..42070051b 100644 +--- a/cmake/EnableCompilerWarnings.cmake ++++ b/cmake/EnableCompilerWarnings.cmake +@@ -66,7 +66,7 @@ else() + -Wunreachable-code + -Wunused + -Wno-reserved-identifier +- -Werror ++ -Werror + -Wno-option-ignored + -Wsign-compare + -Wno-extra-semi-stmt +diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp +index 8c52e4f7d..f8afe8d6d 100644 +--- a/example/01_gemm/gemm_wmma_fp16.cpp ++++ b/example/01_gemm/gemm_wmma_fp16.cpp +@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa + + // clang-format off + using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle +- < ALayout, +- BLayout, +- CLayout, +- ADataType, ++ < ALayout, ++ BLayout, ++ CLayout, ++ ADataType, + BDataType, +- CDataType, +- AccDataType, +- CShuffleDataType, +- AElementOp, +- BElementOp, +- CElementOp, +- GemmDefault, ++ CDataType, ++ AccDataType, ++ CShuffleDataType, ++ AElementOp, ++ BElementOp, ++ CElementOp, ++ GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock +- 8, // K1 ++ 2, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave +- S<4, 32, 1>, +- S<1, 0, 2>, +- S<1, 0, 2>, +- 2, +- 8, +- 8, +- true, +- S<4, 32, 1>, +- S<1, 0, 2>, +- S<1, 0, 2>, +- 2, +- 8, +- 8, +- true, ++ S<4, 32, 1>, ++ S<1, 0, 2>, ++ S<1, 0, 2>, ++ 2, ++ 2, ++ 2, ++ true, ++ S<4, 32, 1>, ++ S<1, 0, 2>, ++ S<1, 0, 2>, ++ 2, ++ 2, ++ 2, ++ true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store +- S<1, 32, 1, 4>, ++ S<1, 32, 1, 4>, + 8>; + // clang-format on + +diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc +index b04e4e53a..cb15186c3 100644 +--- a/example/01_gemm/run_gemm_example.inc ++++ b/example/01_gemm/run_gemm_example.inc +@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: +- ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); ++ ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + 
ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: +diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt +index ab19f819e..be47665a2 100644 +--- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt ++++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt +@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32) + set(target 1) + endif() +-endforeach() +\ No newline at end of file ++endforeach() +diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +index 2bbf430c4..f556be887 100644 +--- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp ++++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN = + 2, + 4, + 4, +- true, ++ false, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 4, + 4, +- true, ++ false, + 1, + 1, + S<1, 64, 1, 2>, +diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +index 4c92c5497..fac19f8b5 100644 +--- a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp ++++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp +@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial + #define CK_MHA_USE_WAVE_1 + #define CK_MHA_USE_WAVE_2 + #define CK_MHA_USE_WAVE_4 +-#define CK_MHA_USE_WAVE_8 ++//#define CK_MHA_USE_WAVE_8 + using DeviceMHAFactory = + std::tuple< + #ifdef CK_MHA_USE_WAVE_1 +@@ -277,10 +277,10 @@ using DeviceMHAFactory = + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, +- MaskingSpec>, ++ MaskingSpec> + #endif + #ifdef CK_MHA_USE_WAVE_8 +- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ++ ,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, +diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +index 8e037272b..d463cc871 100644 +--- a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp ++++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp +@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial + #define CK_MHA_USE_WAVE_1 + #define CK_MHA_USE_WAVE_2 + #define CK_MHA_USE_WAVE_4 +-#define CK_MHA_USE_WAVE_8 ++//#define CK_MHA_USE_WAVE_8 + using DeviceMHAFactory = + std::tuple< + #ifdef CK_MHA_USE_WAVE_1 +@@ -277,10 +277,10 @@ using DeviceMHAFactory = + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, +- MaskingSpec>, ++ MaskingSpec> + #endif + #ifdef CK_MHA_USE_WAVE_8 +- ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< ++ 
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, +diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt +index 5465adb77..7534bff3b 100644 +--- a/example/CMakeLists.txt ++++ b/example/CMakeLists.txt +@@ -60,7 +60,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() +@@ -134,7 +134,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() +diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp +index 55f562061..69a7abf62 100644 +--- a/include/ck/ck.hpp ++++ b/include/ck/ck.hpp +@@ -69,6 +69,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) + #define __gfx11__ + #endif ++#if defined(__gfx1200__) || defined(__gfx1201__) ++#define __gfx12__ ++#endif + + // buffer resource + #ifndef __HIP_DEVICE_COMPILE__ // for host code +@@ -77,7 +80,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 + #elif defined(__gfx103__) + #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +-#elif defined(__gfx11__) ++#elif defined(__gfx11__) || defined(__gfx12__) + #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 + #endif + +@@ -89,7 +92,7 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #define CK_USE_AMD_V_FMAC_F32 + #define CK_USE_AMD_V_DOT2_F32_F16 + #define CK_USE_AMD_V_DOT4_I32_I8 +-#elif defined(__gfx11__) ++#elif defined(__gfx11__) || defined(__gfx12__) + #define CK_USE_AMD_V_FMAC_F32 + #define CK_USE_AMD_V_DOT2_F32_F16 + #define CK_USE_AMD_V_DOT4_I32_I8_GFX11 +@@ -110,13 +113,6 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + #define CK_USE_AMD_MFMA_GFX940 + #endif + +-// WMMA instruction +-#ifndef __HIP_DEVICE_COMPILE__ // for host code +-#define CK_USE_AMD_WMMA +-#elif defined(__gfx11__) // for GPU code +-#define CK_USE_AMD_WMMA +-#endif +- + // buffer load + #define CK_USE_AMD_BUFFER_LOAD 1 + +diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp +index 116bb3ea0..83af2efe8 100644 +--- a/include/ck/host_utility/device_prop.hpp ++++ b/include/ck/host_utility/device_prop.hpp +@@ -84,4 +84,9 @@ inline bool is_gfx11_supported() + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103"; + } + ++inline bool is_gfx12_supported() ++{ ++ return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; ++} ++ + } // namespace ck +diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +index 
f8ee283c6..7eb7d42eb 100644 +--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +@@ -13,6 +13,504 @@ + + namespace ck { + ++#ifdef __gfx12__ ++template ++/* Option: Read from LDS, big buffer hold all threads required data ++ * Source ++ * A: K0PerBlock x MPerBlock x K1 ++ * B: K0PerBlock x NPerBlock x K1 ++ * Destination ++ * C, non-transpose ++ * thread level: MRepeat x NRepeat x MAccVgprs ++ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs ++ * KPACK == WMMA_K = 16 ++ * ++ * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) ++ * Source: ++ * A(if skip LDS): MRepeat x KPack ++ * B(if skip LDS): NRepeat x KPack ++ * Destination ++ * C, non-transpose ++ * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs ++ */ ++struct BlockwiseGemmWMMA ++{ ++ static constexpr auto I0 = Number<0>{}; ++ static constexpr auto I1 = Number<1>{}; ++ static constexpr auto I2 = Number<2>{}; ++ static constexpr auto I3 = Number<3>{}; ++ static constexpr auto I4 = Number<4>{}; ++ static constexpr auto I5 = Number<5>{}; ++ static constexpr auto WmmaK = Number<16>{}; ++ ++ using ThisThreadBlock = ThisThreadBlock; ++ ++ // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. ++ static constexpr index_t WaveSize = 32; ++ ++ // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer ++ // When not use LDS, each Row read half of whole data from source buffer, exchange the data via ++ // permutation ++ static constexpr index_t A_KRow = 2; ++ static constexpr index_t B_KRow = 2; ++ ++ static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); ++ static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); ++ ++ static constexpr auto wmma_gemm = ++ WmmaGemm{}; ++ ++ static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); ++ static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); ++ ++ StaticBufferTupleOfVector ++ c_thread_buf_; ++ ++ __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } ++ ++ __device__ static auto GetWaveIdx() ++ { ++ const index_t thread_id = ThisThreadBlock::GetThreadId(); ++ ++ constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( ++ make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), ++ make_tuple(Sequence<0, 1, 2>{}), ++ make_tuple(Sequence<0>{})); ++ ++ return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); ++ } ++ ++ // Default, Block buffer in LDS, thread level offset enabled ++ __device__ static auto CalculateAThreadOriginDataIndex() ++ { ++ if constexpr(AEnableLds) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ const auto waveId_m = wave_idx[I0]; ++ const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); ++ ++ // |KRepeat |MRepeat|MWave |KRow |MLane |KPack ++ return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); ++ } ++ else ++ { ++ return make_tuple(0, 0, 0, 0, 0, 0); ++ } ++ } ++ ++ __device__ static auto CalculateBThreadOriginDataIndex() ++ { ++ if constexpr(BEnableLds) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ const auto waveId_n = wave_idx[I1]; ++ const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); ++ ++ // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack ++ return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); ++ } ++ else ++ { 
++ return make_tuple(0, 0, 0, 0, 0, 0); ++ } ++ } ++ ++ template ++ __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ ++ const auto waveId_m = wave_idx[I0]; ++ const auto waveId_n = wave_idx[I1]; ++ ++ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); ++ ++ constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( ++ make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), ++ make_tuple(Sequence<0>{}), ++ make_tuple(Sequence<0, 1, 2>{})); ++ ++ constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( ++ make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), ++ make_tuple(Sequence<0>{}), ++ make_tuple(Sequence<0, 1, 2>{})); ++ ++ const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( ++ make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; ++ const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( ++ make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; ++ ++ return make_tuple(c_thread_m, c_thread_n); ++ } ++ ++ template ++ __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) ++ { ++ const auto wave_idx = GetWaveIdx(); ++ ++ const auto waveId_m = wave_idx[I0]; ++ const auto waveId_n = wave_idx[I1]; ++ ++ const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); ++ ++ return make_tuple( ++ Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); ++ } ++ ++ using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); ++ __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), ++ Tuple6 b_origin = CalculateBThreadOriginDataIndex()) ++ : a_thread_copy_(a_origin), b_thread_copy_(b_origin) ++ { ++ static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), ++ "wrong! Desc should be known at compile-time"); ++ ++ static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, ++ "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); ++ ++ static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && ++ NPerBlock % (NPerWMMA * NRepeat) == 0, ++ "wrong!"); ++ } ++ ++ // transposed WMMA output C' = B' * A' ++ __host__ __device__ static constexpr auto ++ GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() ++ { ++ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = ++ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); ++ ++ constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; ++ ++ return make_naive_tensor_descriptor_packed( ++ // |MRepeat |MWave |MSubGroup |NRepeat |NWave ++ // |NThreadPerSubGroup |MAccVgprs ++ make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); ++ } ++ ++ // Thread level, register decriptor. 
Vector-write ++ __host__ __device__ static constexpr auto ++ GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() ++ { ++ constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = ++ wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); ++ ++ constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; ++ constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; ++ return make_naive_tensor_descriptor( ++ // |MRepeat |MWave |MSubGroup |NRepeat |NWave ++ // |NThreadPerSubGroup |MAccVgprs ++ make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), ++ make_tuple(Number{} * MAccVgprs * AccStride, ++ Number{} * MAccVgprs * AccStride, ++ Number{} * MAccVgprs * AccStride, ++ MAccVgprs * AccStride, ++ MAccVgprs * AccStride, ++ MAccVgprs * AccStride, ++ AccStride)); ++ } ++ ++ template ++ __host__ __device__ static constexpr auto ++ MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( ++ const CGridDesc_M_N& c_grid_desc_m_n) ++ { ++ const auto M = c_grid_desc_m_n.GetLength(I0); ++ const auto N = c_grid_desc_m_n.GetLength(I1); ++ ++ const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = ++ transform_tensor_descriptor( ++ c_grid_desc_m_n, ++ make_tuple( ++ make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), ++ make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), ++ make_tuple(Sequence<0>{}, Sequence<1>{}), ++ make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); ++ ++ return wmma_gemm ++ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( ++ c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); ++ } ++ ++ // transposed WMMA output C' = B' * A' ++ __host__ __device__ static constexpr auto ++ GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() ++ { ++ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = ++ make_naive_tensor_descriptor_packed(make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{})); ++ ++ return wmma_gemm ++ .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( ++ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); ++ } ++ ++ // Provide dimension size ++ __host__ __device__ static constexpr auto ++ GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() ++ { ++ constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = ++ make_naive_tensor_descriptor_packed(make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{})); ++ ++ return wmma_gemm ++ .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( ++ c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); ++ } ++ ++ // Describe how data allocated in thread copy src buffer ++ // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma ++ static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; ++ static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; ++ ++ template ++ __device__ void Run(const ABlockBuffer& a_block_buf, ++ const BBlockBuffer& b_block_buf, ++ CThreadBuffer& c_thread_buf) const ++ { ++ auto a_thread_buf = make_static_buffer( ++ a_thread_desc_.GetElementSpaceSize()); ++ auto b_thread_buf = make_static_buffer( ++ b_thread_desc_.GetElementSpaceSize()); ++ ++ 
static_assert(KPack % (A_K1 * A_KRow) == 0, ""); ++ static_assert(KPack % (B_K1 * B_KRow) == 0, ""); ++ ++ // basic intrinsic to determine loopover direction ++ if constexpr(MRepeat < NRepeat) ++ { ++ static_for<0, KPerBlock / KPack, 1>{}( ++ [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... ++ static_for<0, MRepeat, 1>{}([&](auto m0) { ++ // read A ++ a_thread_copy_.Run( ++ a_block_desc_k0_m0_m1_m2_k1, ++ make_tuple(Number{}, m0, I0, I0, I0, I0), ++ a_block_buf, ++ a_thread_desc_, ++ make_tuple(I0, m0, I0, I0, I0, I0), ++ a_thread_buf); ++ ++ static_for<0, NRepeat, 1>{}([&](auto n0) { ++ // read B ++ b_thread_copy_.Run( ++ b_block_desc_k0_n0_n1_n2_k1, ++ make_tuple(Number{}, n0, I0, I0, I0, I0), ++ b_block_buf, ++ b_thread_desc_, ++ make_tuple(I0, n0, I0, I0, I0, I0), ++ b_thread_buf); ++ ++ vector_type a_thread_vec; ++ vector_type b_thread_vec; ++ ++ static_for<0, KPack / A_KRow, 1>{}([&](auto i) { ++ a_thread_vec.template AsType()(i) = ++ a_thread_buf[Number{}]; ++ }); ++ ++ static_for<0, KPack / B_KRow, 1>{}([&](auto i) { ++ b_thread_vec.template AsType()(i) = ++ b_thread_buf[Number{}]; ++ }); ++ ++ using wmma_input_type_a = ++ typename vector_type::type; ++ using wmma_input_type_b = ++ typename vector_type::type; ++ ++ constexpr index_t c_offset = ++ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); ++ ++ wmma_gemm.template Run( ++ a_thread_vec.template AsType(), ++ b_thread_vec.template AsType(), ++ c_thread_buf.GetVectorTypeReference(Number{})); ++ }); ++ }); ++ }); ++ } ++ else ++ { ++ static_for<0, NRepeat, 1>{}([&](auto n0) { ++ static_for<0, MRepeat, 1>{}([&](auto m0) { ++ static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of ++ // k=0,kpack*1, .. ++ // read B ++ b_thread_copy_.Run( ++ b_block_desc_k0_n0_n1_n2_k1, ++ make_tuple(Number{}, n0, I0, I0, I0, I0), ++ b_block_buf, ++ b_thread_desc_, ++ make_tuple(I0, n0, I0, I0, I0, I0), ++ b_thread_buf); ++ // read A ++ a_thread_copy_.Run( ++ a_block_desc_k0_m0_m1_m2_k1, ++ make_tuple(Number{}, m0, I0, I0, I0, I0), ++ a_block_buf, ++ a_thread_desc_, ++ make_tuple(I0, m0, I0, I0, I0, I0), ++ a_thread_buf); ++ ++ vector_type a_thread_vec; ++ vector_type b_thread_vec; ++ ++ static_for<0, KPack / A_KRow, 1>{}([&](auto i) { ++ a_thread_vec.template AsType()(i) = ++ a_thread_buf[Number{}]; ++ }); ++ ++ static_for<0, KPack / B_KRow, 1>{}([&](auto i) { ++ b_thread_vec.template AsType()(i) = ++ b_thread_buf[Number{}]; ++ }); ++ ++ using wmma_input_type_a = ++ typename vector_type::type; ++ using wmma_input_type_b = ++ typename vector_type::type; ++ ++ constexpr index_t c_offset = ++ c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); ++ ++ wmma_gemm.template Run( ++ a_thread_vec.template AsType(), ++ b_thread_vec.template AsType(), ++ c_thread_buf.GetVectorTypeReference(Number{})); ++ }); ++ }); ++ }); ++ } ++ } ++ ++ protected: ++ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( ++ make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), ++ make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number<1>{})); ++ ++ static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( ++ make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), ++ make_tuple(Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number{}, ++ Number<1>{})); ++ ++ // C[M, N, NumRegWMMA] ++ static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( ++ make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); ++ ++ template ++ struct AThreadCopySelector; ++ ++ template <> ++ struct 
AThreadCopySelector ++ { ++ using type = ++ ThreadwiseTensorSliceTransfer_v4, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ A_K1, ++ A_K1>; ++ }; ++ ++ template <> ++ struct AThreadCopySelector ++ { ++ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< ++ FloatA, ++ FloatA, ++ decltype(a_block_desc_k0_m0_m1_m2_k1), ++ decltype(a_thread_desc_), ++ tensor_operation::element_wise::PassThrough, ++ Sequence, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ A_K1, ++ false>; ++ }; ++ ++ template ++ struct BThreadCopySelector; ++ ++ template <> ++ struct BThreadCopySelector ++ { ++ using type = ++ ThreadwiseTensorSliceTransfer_v4, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ B_K1, ++ B_K1>; ++ }; ++ ++ template <> ++ struct BThreadCopySelector ++ { ++ using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< ++ FloatB, ++ FloatB, ++ decltype(b_block_desc_k0_n0_n1_n2_k1), ++ decltype(b_thread_desc_), ++ tensor_operation::element_wise::PassThrough, ++ Sequence, ++ Sequence<0, 1, 2, 3, 4, 5>, ++ 5, ++ B_K1, ++ false>; ++ }; ++ ++ typename AThreadCopySelector::type a_thread_copy_; ++ typename BThreadCopySelector::type b_thread_copy_; ++}; ++#else + template ::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; + }; ++#endif + + } // namespace ck +diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +index e5e6245cb..1f7d50429 100644 +--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp ++++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +@@ -488,7 +488,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + // sync point. + if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("s_barrier" ::); ++#endif + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +index a15759559..ab3f3856a 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + +- static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; +- static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; ++ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; ++ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; ++ ++ static constexpr auto AEnableLds_auto = ++ (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; ++ static constexpr auto BEnableLds_auto = ++ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? 
false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; +@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + + static bool IsSupportedArgument(const Argument& arg) + { +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + } + else + { +- if(!(arg.a_kz_stride_ == 1 && +- arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) ++ if(!(arg.a_kz_stride_ == 1)) + { +- printf("DeviceOp: Vector Access A-k check failure\n"); +- return false; ++ index_t LastK = ++ AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); ++ if(LastK % ABlockTransferSrcScalarPerVector == 0) ++ { ++ printf("DeviceOp: Vector Access A-k check failure\n"); ++ return false; ++ } + } + } + +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +index 8fd14afc0..1b487502f 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +@@ -70,8 +70,9 @@ __global__ void + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ +- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ ++ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +index 9d5b74be6..017d28641 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +@@ -601,9 +601,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle + return false; + } + +- if(ck::get_device_name() != "gfx90a" && ck::get_device_name() != "gfx940" && +- ck::get_device_name() != "gfx941" && ck::get_device_name() != "gfx942" && +- std::is_same::value) ++ if(!ck::is_lds_direct_load_supported() && std::is_same::value) + { + return false; + } +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +index b84e18130..1edae33be 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp +@@ -1393,7 +1393,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl + { + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || +- ck::is_gfx11_supported())) ++ ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } 
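The hunks above and below repeat one runtime-guard pattern: device ops that previously accepted only gfx11 now also accept gfx12 through the new ck::is_gfx12_supported() helper added in include/ck/host_utility/device_prop.hpp. A minimal standalone sketch of that dispatch pattern follows; it is illustrative only, not part of the patch. get_device_name_stub() is a hypothetical stand-in for ck::get_device_name() so the sketch compiles without HIP, the gfx11 target list is assumed to match the existing CK helper, and the gfx12 list is taken from the device_prop.hpp hunk in this diff.

#include <iostream>
#include <string>

// Hypothetical stand-in for ck::get_device_name(); the real helper queries the HIP runtime.
static std::string get_device_name_stub() { return "gfx1201"; }

// Assumed to mirror ck::is_gfx11_supported() from include/ck/host_utility/device_prop.hpp.
static bool is_gfx11_supported(const std::string& name)
{
    return name == "gfx1100" || name == "gfx1101" || name == "gfx1102" || name == "gfx1103";
}

// Mirrors the is_gfx12_supported() helper added by this patch (gfx1200, gfx1201).
static bool is_gfx12_supported(const std::string& name)
{
    return name == "gfx1200" || name == "gfx1201";
}

int main()
{
    const std::string name = get_device_name_stub();
    // Same shape as the IsSupportedArgument() checks touched throughout this diff:
    // the WMMA code path is considered when the device is gfx11 or gfx12.
    const bool wmma_ok = is_gfx11_supported(name) || is_gfx12_supported(name);
    std::cout << name << (wmma_ok ? " can" : " cannot") << " use the WMMA kernels\n";
    return 0;
}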
+diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +index bf96324d0..553143e28 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB || is_same_v || + is_same_v)) +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +index b1784b385..eb0fb55f5 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +index 93ab8a7e1..a7cc546f5 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; + +- static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); +- static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); +- static constexpr auto WmmaK = K1 == 16 ? 32 : 16; +- +- static constexpr auto AEnableLds_auto = +- (NWaves == 1 && is_same::value) ? false : true; ++ static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); ++ static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); ++ static constexpr auto WmmaK = K1 == 16 ? 32 : 16; ++ static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; ++ static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; ++ ++ static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && ++ is_same::value) ++ ? false ++ : true; + static constexpr auto BEnableLds_auto = +- (MWaves == 1 && is_same::value) ? false : true; ++ (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && ++ is_same::value) ++ ? 
false ++ : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; +@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v || + is_same_v)) +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +index 6f74838fb..6bb5d431c 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + static bool IsSupportedArgument(const Argument& arg) + { + // check device +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +index bd264a3c8..7047e1bda 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +@@ -48,8 +48,9 @@ __global__ void + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ +- defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ ++ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +index 211185dfb..5738be0fb 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + static bool IsSupportedArgument(const Argument& arg) + { + // check device +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +index 7cfbd8a8f..5d5a9de7d 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +@@ -90,8 +90,9 @@ __global__ void + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ +- defined(__gfx90a__) 
|| defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ ++ defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -666,7 +667,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || +- ck::is_gfx103_supported() || ck::is_gfx11_supported())) ++ ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +index 6a4d97d7d..c65370b51 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +@@ -107,7 +107,7 @@ __global__ void + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { + #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ +- defined(__gfx11__)) ++ defined(__gfx11__) || defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -602,7 +602,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +index ac392cddc..060a16d1e 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +@@ -39,8 +39,9 @@ __global__ void + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ +- defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ ++ defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ ++ defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); +@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +index 4e14ed3a5..cc88c1a10 100644 +--- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +@@ -60,7 +60,7 @@ __global__ void + bool input_permute, + bool output_permute) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off + // 
*************************************************** +@@ -165,6 +165,7 @@ __global__ void + ignore = O; + ignore = G0; + ignore = G1; ++ ignore = alpha; + ignore = input_permute; + ignore = output_permute; + #endif // end of if (defined(__gfx11__)) +@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma + + static bool IsSupportedArgument(const RawArg& arg) + { +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +index 16717ff81..1754e07e6 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma + if constexpr(B0EnableLds) + { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 +- constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + B0BlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma + if constexpr(B1EnableLds) + { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 +- constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); +- constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); ++ constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); ++ constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_LRow = I2; ++#else + constexpr auto B_LRow = I1; ++#endif + return transform_tensor_descriptor( + B1BlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +index 499eb7eb0..21dac6f9e 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +@@ -50,7 +50,7 @@ __global__ void + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, +@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 +- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); +- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); ++ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); 
++#ifdef __gfx12__ ++ constexpr auto A_KRow = I2; ++#else + constexpr auto A_KRow = I1; ++#endif + return transform_tensor_descriptor( + ABlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 +- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + BBlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +index 82d010a99..fdda649ef 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +@@ -54,7 +54,7 @@ __global__ void + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); +@@ -147,7 +147,7 @@ __global__ void + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + // printf("entry kernel launch"); + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + +@@ -237,7 +237,7 @@ __global__ void + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + + GridwiseOp::template Run(p_a_grid, +@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma + } + else + { ++ constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma + } + else + { ++ constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -495,12 +497,16 @@ struct 
GridwiseGemmMultipleD_Wmma + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 +- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); +- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); ++ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto A_KRow = I2; ++#else + constexpr auto A_KRow = I1; ++#endif + return transform_tensor_descriptor( + ABlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 +- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + BBlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { +- constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); +- constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); +- + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, +- Number{}, ++ Number{}, + I1, +- Number{})); ++ Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } +@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + +- const auto MBlock = M / MPerBlock; +- const auto NBlock = N / NPerBlock; ++ const auto MBlock = M / MPerBlock; ++ const auto NBlock = N / NPerBlock; ++ + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +index 8e4117593..4458b9356 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +@@ -45,7 +45,7 @@ __global__ void + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, +@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma + } + else + { ++ constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 
Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma + } + else + { ++ ++ constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; +- constexpr auto K0PerWmma = WmmaK / 2 / K1; ++ constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, +@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 +- constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); +- constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++ constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); ++ constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto A_KRow = I2; ++#else + constexpr auto A_KRow = I1; ++#endif ++ + return transform_tensor_descriptor( + ABlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 +- constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); +- constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++ constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); ++ constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); ++#ifdef __gfx12__ ++ constexpr auto B_KRow = I2; ++#else + constexpr auto B_KRow = I1; ++#endif + return transform_tensor_descriptor( + BBlockDesc_{}, +- make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), ++ make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), +@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma + c_grid_desc_m_n); + } + +- using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = +- remove_cvref_t; +- using DefaultBlock2CTileMap = +- remove_cvref_t; +- + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment +@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma + b_block_space_size_aligned * sizeof(BDataType)); + }; + ++ using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = ++ remove_cvref_t; ++ using DefaultBlock2CTileMap = ++ remove_cvref_t; ++ + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, +diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +index 6772524e0..174074990 100644 +--- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp ++++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +@@ -35,8 +35,9 @@ __global__ void + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) + { +-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ +- defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__)) ++#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ ++ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ ++ defined(__gfx12__)) + GridwiseTensorRearrangeKernel::Run(in_grid_desc, + p_in_global, + 
out_grid_desc, +diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +index bcce930fc..d7a6a3624 100644 +--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp ++++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic + ElementwiseOperation element_op_; + }; + +-// Specilized for WMMA ++// Specilized for WMMA-Navi3 + // A single Wave32 is composed by double row + // Data exchange allowed between these two rows + // This RowLane Dst buf will be filled from two Src buf +@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow + ElementwiseOperation element_op_{}; + }; + ++// Specilized for WMMA-Navi4 ++template ::type = false> ++struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow ++{ ++ static constexpr index_t nDim = SliceLengths::Size(); ++ ++ using Index = MultiIndex; ++ ++ __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) ++ { ++ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), ++ "wrong! Desc need to known at compile-time"); ++ ++ static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, ++ "wrong! Not divisible"); ++ ignore = src_idx; ++ } ++ ++ template ++ __device__ void Run(const SrcDesc&, ++ const SrcSliceOriginIdx&, ++ const SrcBuffer& src_buf, ++ const DstDesc&, ++ const DstSliceOriginIdx&, ++ DstBuffer& dst_buf) const ++ { ++ static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), ++ "wrong! Desc need to known at compile-time"); ++ ++ static_assert(is_known_at_compile_time>::value && ++ is_known_at_compile_time>::value, ++ "wrong! SliceOrigin need to known at compile-time"); ++ ++ static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), ++ "wrong! 
Buffer need to be StaticBuffer"); ++ ++ // SrcDesc and src_slice_origin_idx are known at compile-time ++ constexpr auto src_desc = remove_cvref_t{}; ++ constexpr auto dst_desc = remove_cvref_t{}; ++ constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); ++ constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); ++ ++ // scalar per access on each dim ++ constexpr auto dst_scalar_per_access = generate_sequence( ++ detail::lambda_scalar_per_access{}, Number{}); ++ ++ constexpr auto dst_scalar_step_in_vector = ++ generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); ++ ++ using SpaceFillingCurve = SpaceFillingCurve>; ++ ++ static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, ++ "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); ++ ++ constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); ++ ++ static_for<0, num_access, 1>{}([&](auto idx_1d) { ++ constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); ++ ++ // copy data from src_buf into dst_vector ++ static_for<0, DstScalarPerVector, 1>{}([&](auto i) { ++ // src_desc error, non constexpr, caused by merge transform ++ constexpr index_t src_offset = src_desc.CalculateOffset( ++ src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); ++ ++ constexpr index_t dst_offset = dst_desc.CalculateOffset( ++ dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); ++ ++ SrcData v_this_row; ++ // int type temp value due to intrinsic requirement ++ int temp = 0; ++ ++ // apply element-wise operation ++ element_op_(v_this_row, src_buf[Number{}]); ++ ++ // apply intra-row permute. ++ if constexpr(IntraRowSwizzlePerm) ++ { ++ temp = __builtin_amdgcn_permlane16( ++ temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); ++ v_this_row = type_convert_sp(temp); ++ } ++ ++ // apply type convert ++ dst_buf(Number{}) = type_convert_sp(v_this_row); ++ }); ++ }); ++ } ++ ElementwiseOperation element_op_{}; ++}; ++ + } // namespace ck +diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +index 565195f53..9a9ebf559 100644 +--- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp ++++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +@@ -11,12 +11,17 @@ namespace ck { + + enum struct WmmaInstr + { ++ // gfx11 + wmma_f32_16x16x16_f16 = 0, + wmma_f32_16x16x16_bf16, + wmma_f16_16x16x16_f16, + wmma_bf16_16x16x16_bf16, + wmma_i32_16x16x16_iu8, +- wmma_i32_16x16x16_iu4 ++ wmma_i32_16x16x16_iu4, ++ // gfx12 ++ wmma_f32_16x16x16_f16_gfx12, ++ wmma_f32_16x16x16_bf16_gfx12, ++ wmma_i32_16x16x16_iu8_gfx12, + }; + + /* +@@ -279,6 +284,122 @@ struct wmma_type ++struct wmma_type> ++{ ++ // Absolute fixing property ++ // * Data Pixel ++ static constexpr index_t m_per_wmma = 16; ++ static constexpr index_t n_per_wmma = 16; ++ static constexpr index_t k_per_wmma = 16; ++ // static constexpr index_t src_a_data_size = 2; ++ // static constexpr index_t src_b_data_size = 2; ++ // static constexpr index_t acc_data_size = 4; ++ // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction ++ static constexpr index_t acc_data_size = 4; ++ static constexpr index_t acc_pack_number = 1; ++ static constexpr index_t num_thread_per_subgroups = n_per_wmma; ++ ++ // Wave mode dependent propety ++ static constexpr index_t wave_size = Number{}; ++ // * Fixed in Navi3x, Will be wave mode dependent on Navi4x ++ // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * 
src_a_data_size / 4; ++ // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; ++ // * num_acc_vgprs_per_wave alone M direction ++ // * num_subgroups alone M direction ++ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; ++ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; ++ ++ template ++ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const ++ { ++ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); ++ if constexpr(wave_size == 32) ++ { ++ intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); ++ } ++ } ++}; ++ ++template ++struct wmma_type> ++{ ++ // Absolute fixing property ++ static constexpr index_t m_per_wmma = 16; ++ static constexpr index_t n_per_wmma = 16; ++ static constexpr index_t k_per_wmma = 16; ++ // static constexpr index_t src_a_data_size = 2; ++ // static constexpr index_t src_b_data_size = 2; ++ static constexpr index_t acc_data_size = 4; ++ static constexpr index_t acc_pack_number = 1; ++ static constexpr index_t num_thread_per_subgroups = n_per_wmma; ++ ++ // Wave mode dependent propety ++ static constexpr index_t wave_size = Number{}; ++ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; ++ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; ++ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; ++ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; ++ ++ template ++ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const ++ { ++ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); ++ if constexpr(wave_size == 32) ++ { ++ intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); ++ } ++ } ++}; ++ ++template ++struct wmma_type> ++{ ++ // Absolute fixing property ++ static constexpr index_t m_per_wmma = 16; ++ static constexpr index_t n_per_wmma = 16; ++ static constexpr index_t k_per_wmma = 16; ++ // static constexpr index_t src_a_data_size = 2; ++ // static constexpr index_t src_b_data_size = 2; ++ static constexpr index_t acc_data_size = 4; ++ static constexpr index_t acc_pack_number = 1; ++ static constexpr index_t num_thread_per_subgroups = n_per_wmma; ++ ++ // Wave mode dependent propety ++ static constexpr index_t wave_size = Number{}; ++ // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; ++ // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; ++ static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; ++ static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; ++ ++ template ++ __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const ++ { ++ static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); ++ if constexpr(wave_size == 32) ++ { ++ intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( ++ a, b, reg_c); ++ } ++ } ++}; ++ + template + static constexpr auto GetWmma() + { ++#ifdef __gfx12__ ++ return WmmaInstr::wmma_f32_16x16x16_f16_gfx12; ++#else + return WmmaInstr::wmma_f32_16x16x16_f16; ++#endif + } + + template <> + static constexpr auto GetWmma() + { ++#ifdef __gfx12__ ++ return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12; ++#else + return WmmaInstr::wmma_f32_16x16x16_bf16; ++#endif + } + + template <> +@@ -320,8 +449,13 @@ struct WmmaSelector + template <> + static constexpr auto GetWmma() + { 
++#ifdef __gfx12__ ++ return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12; ++#else + return WmmaInstr::wmma_i32_16x16x16_iu8; ++#endif + } ++ + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + static constexpr auto GetWmma() +@@ -502,6 +636,9 @@ struct WmmaGemm + + __device__ static auto GetSubGroupId() + { ++ static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == ++ wmma_instr.wave_size, ++ ""); + return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; + } + +@@ -516,12 +653,20 @@ struct WmmaGemm + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { ++#ifdef __gfx12__ ++ return GetLaneIdUnderSubGroup(); ++#else + return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); ++#endif + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { ++#ifdef __gfx12__ ++ return GetLaneIdUnderSubGroup(); ++#else + return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); ++#endif + } + + __device__ static CIndex GetBeginOfThreadBlk() +diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp +index 1bb0140f3..322a0f94b 100644 +--- a/include/ck/utility/amd_wmma.hpp ++++ b/include/ck/utility/amd_wmma.hpp +@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> + } + }; + ++// gfx12 ++/********************************WAVE32 MODE***********************************************/ ++ ++#if defined(__gfx1200__) || defined(__gfx1201__) ++#define __gfx12__ ++#endif ++ ++// src: fp16, dst: fp32 ++template ++struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; ++ ++template <> ++struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> ++{ ++ template ++ __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) ++ { ++ // * Inline assembly need to elimate the duplicated data load, compiler won't help you ++ // delete them. 
++ // amd_assembly_wmma_f32_16x16x16_f16_w32( ++ // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); ++#if defined(__gfx12__) ++ reg_c.template AsType()(Number<0>{}) = ++ __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( ++ reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); ++#else ++ ignore = reg_a; ++ ignore = reg_b; ++ ignore = reg_c; ++#endif ++ } ++}; ++ ++// src: bf16, dst: fp32 ++template ++struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; ++ ++template <> ++struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> ++{ ++ template ++ __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) ++ { ++#if defined(__gfx12__) ++ reg_c.template AsType()(Number<0>{}) = ++ __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( ++ reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); ++#else ++ ignore = reg_a; ++ ignore = reg_b; ++ ignore = reg_c; ++#endif ++ } ++}; ++ ++// src: iu8, dst: i32 ++template ++struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; ++ ++template ++struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> ++{ ++ template ++ __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) ++ { ++#if defined(__gfx12__) ++ reg_c.template AsType()(Number<0>{}) = ++ __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( ++ neg_a, ++ bit_cast(reg_a), ++ neg_b, ++ bit_cast(reg_b), ++ reg_c.template AsType()[Number<0>{}], ++ clamp); ++#else ++ ignore = reg_a; ++ ignore = reg_b; ++ ignore = reg_c; ++#endif ++ } ++}; ++ + } // namespace ck + #endif +diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp +index 93a1edefb..4df14c621 100644 +--- a/include/ck/utility/data_type.hpp ++++ b/include/ck/utility/data_type.hpp +@@ -203,7 +203,7 @@ struct vector_type + } + }; + +-int static err = 0; ++__device__ int static err = 0; + template + struct vector_type + { +diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp +index 4fe5e3950..d6b6eac26 100644 +--- a/include/ck/utility/synchronization.hpp ++++ b/include/ck/utility/synchronization.hpp +@@ -10,12 +10,20 @@ namespace ck { + __device__ void block_sync_lds() + { + #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); ++#endif + #else + __syncthreads(); + #endif +@@ -23,11 +31,20 @@ __device__ void block_sync_lds() + + __device__ void block_sync_lds_direct_load() + { ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_vmcnt 0x0 \n \ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("\ + s_waitcnt vmcnt(0) \n \ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); ++#endif + } + + __device__ void s_nop() +diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp +index 601aad19b..9dc2b072a 100644 +--- a/include/ck_tile/core/config.hpp ++++ b/include/ck_tile/core/config.hpp +@@ -17,6 +17,9 @@ + #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) + #define __gfx11__ + #endif ++#if defined(__gfx1200__) || defined(__gfx1201__) ++#define __gfx12__ ++#endif + + #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS + #include "hip/hip_runtime.h" +@@ -155,7 +158,7 @@ + #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 + #elif 
defined(__gfx103__) // for GPU code + #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +-#elif defined(__gfx11__) // for GPU code ++#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code + #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 + #endif + +diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt +index 8c5f36d2e..89c9d6dc6 100644 +--- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt ++++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt +@@ -52,7 +52,7 @@ function(add_instance_library INSTANCE_NAME) + endforeach() + # Do not build WMMA instances if gfx11 targets are not on the target list + foreach(source IN LISTS ARGN) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + message("removing wmma instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() +@@ -149,7 +149,7 @@ FOREACH(subdir_path ${dir_list}) + message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") + set(add_inst 0) + endif() +- if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) ++ if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12")) + message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") + set(add_inst 0) + endif() +@@ -157,11 +157,11 @@ FOREACH(subdir_path ${dir_list}) + message("Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") + set(add_inst 0) + endif() +- if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) ++ if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") + set(add_inst 0) + endif() +- if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) ++ if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx12") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + message("Found xdl, dl, and wmma instances, but none of those meet the target list. 
Skipping.") + set(add_inst 0) + endif() +diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt +index 1cfcbfff6..a9557a9b9 100644 +--- a/profiler/src/CMakeLists.txt ++++ b/profiler/src/CMakeLists.txt +@@ -58,7 +58,7 @@ if(GPU_TARGETS MATCHES "gfx9") + + endif() + +-if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") ++if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + endif() +@@ -133,7 +133,7 @@ if(GPU_TARGETS MATCHES "gfx9") + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + endif() + +-if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") ++if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + endif() +diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt +index 25c63ac7f..2a7c52b58 100644 +--- a/test/CMakeLists.txt ++++ b/test/CMakeLists.txt +@@ -53,7 +53,7 @@ function(add_test_executable TEST_NAME) + endif() + endforeach() + foreach(source IN LISTS ARGN) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() +@@ -118,7 +118,7 @@ function(add_gtest_executable TEST_NAME) + endif() + endforeach() + foreach(source IN LISTS ARGN) +- if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") ++ if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() +diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +index 1c8082645..21f49ec0f 100644 +--- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp ++++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +@@ -55,7 +55,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test + } + } + +- if(ck::is_gfx11_supported()) ++ if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + // on gfx11 only support for 3d is implemented + if constexpr(NDimSpatial{} != 3) +diff --git a/test/wmma_op/wmma_op_util.hpp b/test/wmma_op/wmma_op_util.hpp +index 49782bce6..d9ec94771 100644 +--- a/test/wmma_op/wmma_op_util.hpp ++++ b/test/wmma_op/wmma_op_util.hpp +@@ -140,10 +140,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) + p_shared[8 * 16 * lane_hi + 8 * lane_lo + ele + 16 * 16] = b_temp[ele]; + } + ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); ++#endif + + for(int ele = 0; ele < 16; ++ele) + { +@@ -155,10 +163,18 @@ __global__ void matmul(const src_t* a, const src_t* b, dst_t* c) + a_frag[ele] = p_shared[(ele / 8) * 16 * 8 + 8 * lane + ele % 8]; + } + ++#ifdef __gfx12__ ++ asm volatile("\ ++ s_wait_dscnt 0x0 \n \ ++ s_barrier_signal -1 \n \ ++ s_barrier_wait -1 \ ++ " ::); ++#else + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); ++#endif + + // sync threads, similar to mma_sync + // __syncthreads();