diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 4bb54b5614..e752a1a448 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -186,6 +186,9 @@ set( MIOpen_Source include/miopen/reduce_common.hpp include/miopen/sequences.hpp include/miopen/rocm_features.hpp + include/miopen/batched_transpose_sol.hpp + include/miopen/magic_div.hpp + include/miopen/util_sol.hpp md_graph.cpp mdg_expr.cpp conv/invokers/gcn_asm_1x1u.cpp @@ -279,6 +282,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN file(GLOB_RECURSE COMPOSABLE_KERNEL_DYNAMIC_CPP_SOURCE "kernels/dynamic_igemm/*.cpp") file(GLOB_RECURSE GPU_REFERENCE_KERNEL_HIP "kernels/gpu_reference_kernel/*.cpp") file(GLOB_RECURSE GPU_REFERENCE_KERNEL_ASM "kernels/gpu_reference_kernel/*.s") + file(GLOB_RECURSE GPU_BATCHED_TRANSPOSE_KERNEL_HIP "kernels/gpu_batched_transpose_kernel/*.cpp") set(MIOPEN_KERNEL_INCLUDES ${STATIC_COMPOSABLE_KERNEL_INCLUDE} @@ -379,6 +383,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN ${COMPOSABLE_KERNEL_DYNAMIC_CPP_SOURCE} ${GPU_REFERENCE_KERNEL_HIP} ${GPU_REFERENCE_KERNEL_ASM} + ${GPU_BATCHED_TRANSPOSE_KERNEL_HIP} kernels/detect_llvm_amdgcn_buffer_atomic_fadd_f32_float.cpp kernels/MIOpenCheckNumerics.cl kernels/MIOpenBatchNormActivBwdPerAct.cl @@ -503,6 +508,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN ocl/gcn_asm_utils.cpp ocl/rnn_util_ocl.cpp hip/hip_build_utils.cpp + hip/batched_transpose_sol.cpp pooling.cpp ocl/fusionopconvocl.cpp ocl/fusionopbiasbnactivocl.cpp diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index 801d905650..241fa4fd59 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace miopen { @@ -437,6 +438,9 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( int y = conv_problem.GetWeightsHeight(); int x = conv_problem.GetWeightsWidth(); int group = conv_problem.GetGroupCount(); + int c_karg = c / group; + int y_karg = y; + int x_karg = x; uint32_t gemm_m = n * ho * wo; uint32_t gemm_n = k / group; @@ -460,9 +464,9 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( uint32_t s_move_slice_k_y = (config.gemm_k_per_block / (x * (c / group))) % y; uint32_t s_move_slice_k_x = (config.gemm_k_per_block / (c / group)) % x; uint32_t s_move_slice_k_c = config.gemm_k_per_block % (c / group); - y = static_cast((s_move_slice_k_y << 24) | y); - x = static_cast((s_move_slice_k_x << 24) | x); - c = static_cast((s_move_slice_k_c << 24) | c); + y_karg = static_cast((s_move_slice_k_y << 24) | y); + x_karg = static_cast((s_move_slice_k_x << 24) | x); + c_karg = static_cast((s_move_slice_k_c << 24) | (c / group)); } else { @@ -482,7 +486,7 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( opArgs.emplace_back(wi); opArgs.emplace_back(n); opArgs.emplace_back(k / group); - opArgs.emplace_back(c / group); + opArgs.emplace_back(c_karg); opArgs.emplace_back(ho); opArgs.emplace_back(wo); opArgs.emplace_back(stride_h); @@ -491,8 +495,8 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( opArgs.emplace_back(dilation_w); opArgs.emplace_back(pad_h); opArgs.emplace_back(pad_w); - opArgs.emplace_back(y); - opArgs.emplace_back(x); + opArgs.emplace_back(y_karg); + opArgs.emplace_back(x_karg); opArgs.emplace_back(group); opArgs.emplace_back(mdiv_0.magic); opArgs.emplace_back(mdiv_1.magic); @@ 
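The hunk above packs the per-iteration slice step into the top byte of the y/x/c kernel arguments, e.g. "(s_move_slice_k_y << 24) | y". A minimal standalone sketch of that packing, assuming the low field fits in 24 bits as the shift-by-24 implies; pack_step24 and unpack_step24 are hypothetical names, not part of the patch:

#include <cassert>
#include <cstdint>

// Hypothetical helpers mirroring "(s_move_slice_k_y << 24) | y" above:
// the 8-bit slice step rides in the top byte, the original value in the low 24 bits.
static uint32_t pack_step24(uint32_t step, uint32_t value)
{
    assert(step <= 0xffu && value <= 0xffffffu); // assumed ranges
    return (step << 24) | value;
}

static void unpack_step24(uint32_t packed, uint32_t& step, uint32_t& value)
{
    step  = packed >> 24;       // s_move_slice_k_*
    value = packed & 0xffffffu; // y, x, or c / group
}

int main()
{
    uint32_t step = 0, value = 0;
    unpack_step24(pack_step24(2, 3), step, value); // e.g. s_move_slice_k_y = 2, y = 3
    assert(step == 2 && value == 3);
    return 0;
}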
-505,10 +509,86 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( opArgs.emplace_back(config.gemm_k_global_split); opArgs.emplace_back(pack0); - const auto& lowp_quant = ctx.conv_problem.GetConv().lowp_quant; + std::vector> opArgsTrans; + + const auto lowp_quant = ctx.conv_problem.GetConv().lowp_quant; const auto isGfx90aFp16altSupport = (ctx.GetStream().GetDeviceName() == "gfx90a") && conv_problem.IsFp16(); + const bool need_cast = [&]() { + if(ctx.conv_problem.GetOut().GetType() == miopenHalf) + return use_fp32_global_split_on_fp16; + if(ctx.conv_problem.GetOut().GetType() == miopenBFloat16) + return need_set_zero; + return false; + }(); + const auto is_nchw = ctx.IsLayoutDefault(); + + size_t trans_input_offset = 0; + size_t trans_input_size = 0; + + size_t trans_weight_offset = 0; + size_t trans_weight_size = 0; + + size_t trans_output_offset = 0; + size_t trans_output_size = 0; + + bool trans_input_skippable = false; + bool trans_weight_skippable = false; + bool trans_output_skippable = false; + + int trans_input_idx = -1; + int trans_weight_idx = -1; + int trans_output_idx = -1; + + if(is_nchw) + { + TransposeSolutionDefault2Nhwc trans_input(ctx, ctx.in_data_type, n, c, hi, wi); + TransposeSolutionDefault2Nhwc trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionNhwc2Default trans_output(ctx, ctx.out_data_type, n, k, ho, wo); + + trans_input_skippable = trans_input.IsSkippable(); + trans_weight_skippable = trans_weight.IsSkippable(); + trans_output_skippable = trans_output.IsSkippable(); + + if(!trans_input_skippable) + opArgsTrans.emplace_back(trans_input.GetKernelArg()); + if(!trans_weight_skippable) + opArgsTrans.emplace_back(trans_weight.GetKernelArg()); + if(!trans_output_skippable) + opArgsTrans.emplace_back(trans_output.GetKernelArg()); + + trans_input_size = trans_input_skippable ? 0 : trans_input.GetSize(); + trans_weight_size = trans_weight_skippable ? 0 : trans_weight.GetSize(); + trans_output_size = trans_output_skippable ? 0 : trans_output.GetSize(); + + trans_weight_offset = trans_input_offset + trans_input_size; + trans_output_offset = trans_weight_offset + trans_weight_size; + + int idx = 0; + if(!trans_input_skippable) + trans_input_idx = idx++; + if(!trans_weight_skippable) + trans_weight_idx = idx++; + if(!trans_output_skippable) + trans_output_idx = idx++; + } + + const size_t cast_offset = is_nchw ? (trans_output_offset + trans_output_size) : 0; + const size_t cast_size = need_cast ? miopen::GetTypeSize(miopenFloat) * n * k * ho * wo : 0; + + const int kID_trans_start = isGfx90aFp16altSupport ? 2 : 1; + + const TensorDescriptor cast_desc(miopenFloat, + ctx.conv_problem.GetOut().GetLengths(), + ctx.conv_problem.GetOut().GetStrides()); + auto null_buf = shared{}; + return [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) data_ctx = primitive_parameters.CastTo(); @@ -517,42 +597,71 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( const auto ker = handle.Run(kernels[(isGfx90aFp16altSupport && data_ctx.gfx90aFp16alt) ? 
1 : 0]); float elapsed = 0; - TensorDescriptor workspaceDesc( - miopenFloat, tensors.outDesc.GetLengths(), tensors.outDesc.GetStrides()); - const bool need_cast = [&]() { - if(tensors.outDesc.GetType() == miopenHalf) - return use_fp32_global_split_on_fp16; - if(tensors.outDesc.GetType() == miopenBFloat16) - return need_set_zero; - return false; - }(); - - if(need_cast) - { - opArgs[0] = OpKernelArg(tensors.in); - opArgs[1] = OpKernelArg(tensors.w); - opArgs[2] = OpKernelArg(workSpace); - } - else - { - opArgs[0] = OpKernelArg(tensors.in); - opArgs[1] = OpKernelArg(tensors.w); - opArgs[2] = OpKernelArg(tensors.out); - } + auto trans_input_buf = + trans_input_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, trans_input_offset, trans_input_size); + auto trans_weight_buf = + trans_weight_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, trans_weight_offset, trans_weight_size); + auto trans_output_buf = + trans_output_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, trans_output_offset, trans_output_size); + auto cast_buf = cast_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, cast_offset, cast_size); if(need_set_zero) { + auto zero_buf = need_cast + ? cast_buf.get() + : ((is_nchw && !trans_output_skippable) ? trans_output_buf.get() + : tensors.out); + auto& zero_desc = + need_cast + ? cast_desc + : tensors.outDesc; // use the same desc for NCHW/NHWC for this dense tensor float zero = 0.f; - if(need_cast) - SetTensor(handle, workspaceDesc, workSpace, &zero); - else - SetTensor(handle, tensors.outDesc, tensors.out, &zero); + SetTensor(handle, zero_desc, zero_buf, &zero); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); } + if(is_nchw) + { + if(!trans_input_skippable) + { + auto& karg_input = opArgsTrans[trans_input_idx]; + karg_input[0] = OpKernelArg(trans_input_buf.get()); + karg_input[1] = OpKernelArg(tensors.in); + handle.Run(kernels[kID_trans_start + trans_input_idx])(karg_input); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(!trans_weight_skippable) + { + auto& karg_weight = opArgsTrans[trans_weight_idx]; + karg_weight[0] = OpKernelArg(trans_weight_buf.get()); + karg_weight[1] = OpKernelArg(tensors.w); + handle.Run(kernels[kID_trans_start + trans_weight_idx])(karg_weight); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + } + + opArgs[0] = (is_nchw && !trans_input_skippable) ? OpKernelArg(trans_input_buf.get()) + : OpKernelArg(tensors.in); + opArgs[1] = (is_nchw && !trans_weight_skippable) ? OpKernelArg(trans_weight_buf.get()) + : OpKernelArg(tensors.w); + + opArgs[2] = need_cast ? OpKernelArg(cast_buf.get()) + : ((is_nchw && !trans_output_skippable) + ? OpKernelArg(trans_output_buf.get()) + : OpKernelArg(tensors.out)); ker(opArgs); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); @@ -561,16 +670,27 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( { CastTensor(handle, &lowp_quant, - workspaceDesc, - workSpace, + cast_desc, + cast_buf.get(), tensors.outDesc, - tensors.out, + (is_nchw && !trans_output_skippable) ? 
trans_output_buf.get() + : tensors.out, 0, 0); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); } + if(is_nchw && !trans_output_skippable) + { + auto& karg_output = opArgsTrans[trans_output_idx]; + karg_output[0] = OpKernelArg(tensors.out); + karg_output[1] = OpKernelArg(trans_output_buf.get()); + handle.Run(kernels[kID_trans_start + trans_output_idx])(karg_output); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(handle.IsProfilingEnabled()) { handle.ResetKernelTime(); @@ -652,7 +772,6 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( if(y < stride_h || x < stride_w || dilation_h != 1 || dilation_w != 1) need_set_zero = true; need_set_zero |= config.gemm_k_global_split > 0; - bool use_global_split = config.gemm_k_global_split > 0; std::vector opArgs; opArgs.emplace_back(0); // placeholder @@ -698,9 +817,84 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( opArgs.emplace_back(shift_pack_0); opArgs.emplace_back(config.gemm_k_global_split); - const auto& lowp_quant = ctx.conv_problem.GetConv().lowp_quant; + std::vector> opArgsTrans; + + const auto lowp_quant = ctx.conv_problem.GetConv().lowp_quant; const auto isGfx90aFp16altSupport = (ctx.GetStream().GetDeviceName() == "gfx90a") && conv_problem.IsFp16(); + const bool need_cast = [&]() { + if(ctx.conv_problem.GetOut().GetType() == miopenHalf) + return use_fp32_global_split_on_fp16; + if(ctx.conv_problem.GetOut().GetType() == miopenBFloat16) + return need_set_zero; + return false; + }(); + const auto is_nchw = ctx.IsLayoutDefault(); + + size_t trans_input_offset = 0; + size_t trans_input_size = 0; + + size_t trans_weight_offset = 0; + size_t trans_weight_size = 0; + + size_t trans_output_offset = 0; + size_t trans_output_size = 0; + + bool trans_input_skippable = false; + bool trans_weight_skippable = false; + bool trans_output_skippable = false; + + int trans_input_idx = -1; + int trans_weight_idx = -1; + int trans_output_idx = -1; + + if(is_nchw) + { + TransposeSolutionNhwc2Default trans_input(ctx, ctx.out_data_type, n, c, hi, wi); + TransposeSolutionDefault2Nhwc trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo); + + trans_input_skippable = trans_input.IsSkippable(); + trans_weight_skippable = trans_weight.IsSkippable(); + trans_output_skippable = trans_output.IsSkippable(); + + if(!trans_input_skippable) + opArgsTrans.emplace_back(trans_input.GetKernelArg()); + if(!trans_weight_skippable) + opArgsTrans.emplace_back(trans_weight.GetKernelArg()); + if(!trans_output_skippable) + opArgsTrans.emplace_back(trans_output.GetKernelArg()); + + trans_input_size = trans_input_skippable ? 0 : trans_input.GetSize(); + trans_weight_size = trans_weight_skippable ? 0 : trans_weight.GetSize(); + trans_output_size = trans_output_skippable ? 0 : trans_output.GetSize(); + + trans_weight_offset = trans_input_offset + trans_input_size; + trans_output_offset = trans_weight_offset + trans_weight_size; + + int idx = 0; + if(!trans_input_skippable) + trans_input_idx = idx++; + if(!trans_weight_skippable) + trans_weight_idx = idx++; + if(!trans_output_skippable) + trans_output_idx = idx++; + } + + const size_t cast_offset = is_nchw ? (trans_output_offset + trans_output_size) : 0; + const size_t cast_size = need_cast ? 
miopen::GetTypeSize(miopenFloat) * n * c * hi * wi : 0; + + const int kID_trans_start = isGfx90aFp16altSupport ? 2 : 1; + + const TensorDescriptor cast_desc(miopenFloat, + ctx.conv_problem.GetOut().GetLengths(), + ctx.conv_problem.GetOut().GetStrides()); + auto null_buf = shared{}; return [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { @@ -710,42 +904,71 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( const auto ker = handle.Run(kernels[(isGfx90aFp16altSupport && data_ctx.gfx90aFp16alt) ? 1 : 0]); float elapsed = 0; - TensorDescriptor workspaceDesc( - miopenFloat, tensors.outDesc.GetLengths(), tensors.outDesc.GetStrides()); - const bool need_cast = [&]() { - if(tensors.outDesc.GetType() == miopenHalf) - return use_fp32_global_split_on_fp16; - if(tensors.outDesc.GetType() == miopenBFloat16) - return use_global_split; - return false; - }(); - - if(need_cast) - { - opArgs[0] = OpKernelArg(workSpace); - opArgs[1] = OpKernelArg(tensors.w); - opArgs[2] = OpKernelArg(tensors.in); - } - else - { - opArgs[0] = OpKernelArg(tensors.out); - opArgs[1] = OpKernelArg(tensors.w); - opArgs[2] = OpKernelArg(tensors.in); - } + auto trans_input_buf = + trans_input_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, trans_input_offset, trans_input_size); + auto trans_weight_buf = + trans_weight_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, trans_weight_offset, trans_weight_size); + auto trans_output_buf = + trans_output_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, trans_output_offset, trans_output_size); + auto cast_buf = cast_size == 0 + ? null_buf + : handle.CreateSubBuffer(workSpace, cast_offset, cast_size); if(need_set_zero) { + auto zero_buf = need_cast + ? cast_buf.get() + : ((is_nchw && !trans_input_skippable) ? trans_input_buf.get() + : tensors.out); + auto& zero_desc = + need_cast + ? cast_desc + : tensors.outDesc; // use the same desc for NCHW/NHWC for this dense tensor float zero = 0.f; - if(need_cast) - SetTensor(handle, workspaceDesc, workSpace, &zero); - else - SetTensor(handle, tensors.outDesc, tensors.out, &zero); + SetTensor(handle, zero_desc, zero_buf, &zero); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); } + if(is_nchw) + { + if(!trans_output_skippable) + { + auto& karg_output = opArgsTrans[trans_output_idx]; + karg_output[0] = OpKernelArg(trans_output_buf.get()); + karg_output[1] = OpKernelArg(tensors.in); + handle.Run(kernels[kID_trans_start + trans_output_idx])(karg_output); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(!trans_weight_skippable) + { + auto& karg_weight = opArgsTrans[trans_weight_idx]; + karg_weight[0] = OpKernelArg(trans_weight_buf.get()); + karg_weight[1] = OpKernelArg(tensors.w); + handle.Run(kernels[kID_trans_start + trans_weight_idx])(karg_weight); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + } + + opArgs[0] = need_cast ? OpKernelArg(cast_buf.get()) + : ((is_nchw && !trans_input_skippable) + ? OpKernelArg(trans_input_buf.get()) + : OpKernelArg(tensors.out)); + opArgs[1] = (is_nchw && !trans_weight_skippable) ? OpKernelArg(trans_weight_buf.get()) + : OpKernelArg(tensors.w); + opArgs[2] = (is_nchw && !trans_output_skippable) ? 
OpKernelArg(trans_output_buf.get()) + : OpKernelArg(tensors.in); + ker(opArgs); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); @@ -754,15 +977,25 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( { CastTensor(handle, &lowp_quant, - workspaceDesc, - workSpace, + cast_desc, + cast_buf.get(), tensors.outDesc, - tensors.out, + (is_nchw && !trans_input_skippable) ? trans_input_buf.get() + : tensors.out, 0, 0); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); } + if((is_nchw && !trans_input_skippable)) + { + auto& karg_input = opArgsTrans[trans_input_idx]; + karg_input[0] = OpKernelArg(tensors.out); + karg_input[1] = OpKernelArg(trans_input_buf.get()); + handle.Run(kernels[kID_trans_start + trans_input_idx])(karg_input); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } if(handle.IsProfilingEnabled()) { diff --git a/src/hip/batched_transpose_sol.cpp b/src/hip/batched_transpose_sol.cpp new file mode 100644 index 0000000000..51a6a99359 --- /dev/null +++ b/src/hip/batched_transpose_sol.cpp @@ -0,0 +1,382 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
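Both invoker factories above carve one workspace into consecutive regions, [trans_input | trans_weight | trans_output | cast], where a skippable transpose contributes zero bytes. A small standalone sketch of that offset chaining; WorkspaceLayout and ChainOffsets are hypothetical names, not part of the patch:

#include <cassert>
#include <cstddef>

// Hypothetical mirror of the offset chaining used above:
// each region starts where the previous one ends; skippable regions have size 0.
struct WorkspaceLayout
{
    std::size_t input_offset, weight_offset, output_offset, cast_offset, total;
};

static WorkspaceLayout ChainOffsets(std::size_t trans_input_size,
                                    std::size_t trans_weight_size,
                                    std::size_t trans_output_size,
                                    std::size_t cast_size)
{
    WorkspaceLayout l{};
    l.input_offset  = 0;
    l.weight_offset = l.input_offset + trans_input_size;   // trans_weight_offset
    l.output_offset = l.weight_offset + trans_weight_size; // trans_output_offset
    l.cast_offset   = l.output_offset + trans_output_size; // cast_offset
    l.total         = l.cast_offset + cast_size;
    return l;
}

int main()
{
    // A skippable weight transpose contributes 0, so the later regions shift down.
    const auto l = ChainOffsets(1024, 0, 2048, 4096);
    assert(l.weight_offset == 1024 && l.output_offset == 1024);
    assert(l.cast_offset == 3072 && l.total == 7168);
    return 0;
}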
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define BATCHED_TRANSPOSE_BLOCK_SIZE 256 +#define BATCHED_TRANSPOSE_PERSISTENT 0 + +#if BATCHED_TRANSPOSE_PERSISTENT +#define BATCHED_TRANSPOSE_OCCUPANCY 4 +#endif + +namespace miopen { +namespace batched_transpose { + +static inline std::string GetNameTrait(std::size_t type_size) +{ + if(type_size == 1) + return "byte"; + if(type_size == 2) + return "half"; + if(type_size == 4) + return "dword"; + MIOPEN_THROW("data type not supported"); +} + +static inline const std::vector& GetKernelList(std::size_t data_size) +{ + if(data_size == 1) + { + static const std::vector byte_kernel_list{ + // clang-format off + {16, 16, 1, 1, 1, 1}, + {16, 32, 1, 1, 1, 1}, + {32, 16, 1, 1, 1, 1}, + {32, 32, 1, 1, 1, 1}, + + {4, 64, 1, 1, 1, 1}, + {64, 4, 1, 1, 1, 1}, + {4, 128, 1, 1, 1, 1}, + {128, 4, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {256, 4, 1, 1, 1, 1}, + // clang-format on + }; + return byte_kernel_list; + } + if(data_size == 2) + { + static const std::vector half_kernel_list{ + // clang-format off + {16, 16, 1, 1, 1, 1}, + {32, 16, 1, 1, 1, 1}, + {16, 32, 1, 1, 1, 1}, + {32, 32, 1, 1, 1, 1}, + + {4, 64, 1, 1, 1, 1}, + {64, 4, 1, 1, 1, 1}, + {4, 128, 1, 1, 1, 1}, + {128, 4, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {256, 4, 1, 1, 1, 1}, + + {32, 32, 2, 2, 1, 1}, + {32, 32, 2, 2, 1, 2}, + {32, 32, 2, 2, 2, 1}, + {32, 32, 2, 2, 2, 2}, + + {16, 64, 1, 4, 1, 2}, + {64, 16, 4, 1, 2, 1}, + + {32, 64, 2, 4, 1, 2}, + {32, 64, 2, 4, 2, 2}, + {32, 64, 2, 4, 2, 4}, + + {64, 32, 4, 2, 2, 1}, + {64, 32, 4, 2, 2, 2}, + {64, 32, 4, 2, 4, 2}, + + {64, 64, 4, 4, 2, 2}, + {64, 64, 4, 4, 4, 4}, + // clang-format on + }; + return half_kernel_list; + } + if(data_size == 4) + { + static const std::vector dword_kernel_list{ + // clang-format off + {16, 16, 1, 1, 1, 1}, + {16, 32, 1, 1, 1, 1}, + {32, 16, 1, 1, 1, 1}, + {32, 32, 1, 1, 1, 1}, + + {4, 64, 1, 1, 1, 1}, + {64, 4, 1, 1, 1, 1}, + {4, 128, 1, 1, 1, 1}, + {128, 4, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {256, 4, 1, 1, 1, 1}, + // clang-format on + }; + return dword_kernel_list; + } + MIOPEN_THROW("data type not supported"); +} + +static inline bool IsApplicable(uint32_t /* batch */, + uint32_t height, + uint32_t width, + const BatchedTransposeParam* kparam) +{ + return width % kparam->ediv_x == 0 && height % kparam->ediv_y == 0; +} + +static inline bool IsSameSide(uint32_t height, uint32_t width, const BatchedTransposeParam* kparam) +{ + float radio = 0; + if(width > height) + radio = static_cast(kparam->tile_x) / kparam->tile_y; + else + radio = static_cast(kparam->tile_y) / kparam->tile_x; + + // E.g. for cases like width=1000, height=10 + // Allow at least 32x64, 64x64... 
16x64 not allowed + return radio >= 0.4; +} + +template +static inline float GetNormalizedRadio(T x, T y) +{ + if(y > x) + return static_cast(y) / x; + return static_cast(x) / y; +} + +static inline std::string GetKernelName(std::size_t data_size, const BatchedTransposeParam* kparam) +{ + std::ostringstream kernel_name; + std::string type_trait = GetNameTrait(data_size); + kernel_name << "batched_transpose_" << kparam->tile_x << "x" << kparam->tile_y << "_"; + if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1)) + { + kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" + << kparam->ediv_x << "x" << kparam->ediv_y << "_"; + } + kernel_name << type_trait; + return kernel_name.str(); +} + +static inline std::size_t GetExtraPaddingSize(uint32_t /* batch */, + uint32_t height, + uint32_t width, + const BatchedTransposeParam* kparam) +{ + // For simplicity and speed, we ignore batch, only compute h*w + uint32_t padded_h = ((height + kparam->tile_y - 1) / kparam->tile_y) * kparam->tile_y; + uint32_t padded_w = ((width + kparam->tile_x - 1) / kparam->tile_x) * kparam->tile_x; + return static_cast(padded_h) * padded_w - static_cast(height) * width; +} + +static inline BatchedTransposeParam +HeuristicGet(std::size_t data_size, uint32_t batch, uint32_t height, uint32_t width) +{ + /* + * Iterate from big tile size to small tile size, and try match ediv first + * If every kernel is applicable, then will pick up the bigest one + * If need extra padding in h/w (due to tile size), then will pick up kernel that waste the + * samllest. + */ + + const auto& kernel_list = GetKernelList(data_size); + BatchedTransposeParam best_kernel; + std::size_t extra_padding_size = std::numeric_limits::max(); + float hw_radio = GetNormalizedRadio(height, width); + + if(hw_radio >= 12 && (height <= 8 || width <= 8)) + { + // Early heuristic for cases that has very large width, very small height (or vice versa) + if(hw_radio <= 48) + { + return (width <= 8) ? BatchedTransposeParam{4, 64, 1, 1, 1, 1} + : BatchedTransposeParam{64, 4, 1, 1, 1, 1}; + } + else if(hw_radio <= 128) + { + return (width <= 8) ? BatchedTransposeParam{4, 128, 1, 1, 1, 1} + : BatchedTransposeParam{128, 4, 1, 1, 1, 1}; + } + else + { + return (width <= 8) ? 
BatchedTransposeParam{4, 256, 1, 1, 1, 1} + : BatchedTransposeParam{256, 4, 1, 1, 1, 1}; + } + } + + for(auto it = kernel_list.rbegin(); it != kernel_list.rend(); it++) + { + if(it->tile_x == 4 || it->tile_y == 4) // We don't want such kernel to be selected here, + // they should be used in above cases + continue; + if(!IsApplicable(batch, height, width, &(*it))) + continue; + std::size_t current_padding_size = GetExtraPaddingSize(batch, height, width, &(*it)); + bool replace_current = false; + if(best_kernel.tile_x == 0 && best_kernel.tile_y == 0) + { + // 1st applicable case + replace_current = true; + } + if(hw_radio > 128) + { + // This is for cases that h, w have a great difference + if(!IsSameSide(height, width, &(*it))) + continue; + float prev_radio = GetNormalizedRadio( + GetNormalizedRadio(best_kernel.tile_y, best_kernel.tile_x), hw_radio); + float curr_radio = + GetNormalizedRadio(GetNormalizedRadio(it->tile_y, it->tile_x), hw_radio); + + if(curr_radio * current_padding_size < prev_radio * extra_padding_size) + { + if(curr_radio <= prev_radio) + { + replace_current = true; + } + } + else if(float_equal(curr_radio * current_padding_size, prev_radio * extra_padding_size)) + { + // If width == height, a greate chance is that the kernel performance would be + // almost the same, so ignore this case + if((width > height && it->tile_x > it->tile_y && + best_kernel.tile_x < best_kernel.tile_y) || + (width < height && it->tile_x < it->tile_y && + best_kernel.tile_x > best_kernel.tile_y)) + { + replace_current = true; + } + } + } + else + { + if(current_padding_size < extra_padding_size) + { + replace_current = true; + } + } + + if(replace_current) + { + extra_padding_size = current_padding_size; + best_kernel = *it; + } + } + + assert(extra_padding_size != std::numeric_limits::max()); // Impossible + return best_kernel; +} + +} // namespace batched_transpose + +BatchedTransposeSolution::BatchedTransposeSolution(const ExecutionContext& ctx, + miopenDataType_t data_type_, + uint32_t batch_, + uint32_t height_, + uint32_t width_) + : data_type(data_type_), batch(batch_), height(height_), width(width_) +{ + if(data_type == miopenInt8x4 || data_type == miopenDouble) + MIOPEN_THROW("These data type are not supported"); + num_cu = ctx.GetStream().GetMaxComputeUnits(); + std::size_t data_size = miopen::GetTypeSize(data_type); + kernel_param_heuristic = batched_transpose::HeuristicGet(data_size, batch, height, width); +} + +solver::KernelInfo BatchedTransposeSolution::GetKernel() const +{ + std::size_t block_size = BATCHED_TRANSPOSE_BLOCK_SIZE; +#if BATCHED_TRANSPOSE_PERSISTENT + std::size_t grid_size = num_cu * BATCHED_TRANSPOSE_OCCUPANCY; +#else + uint32_t dim_h = (height + kernel_param_heuristic.tile_y - 1) / kernel_param_heuristic.tile_y; + uint32_t dim_w = (width + kernel_param_heuristic.tile_x - 1) / kernel_param_heuristic.tile_x; + std::size_t grid_size = batch * dim_h * dim_w; +#endif + std::string kernel_name = GetKernelName(); + solver::KernelInfo kernel; + kernel.kernel_file = "batched_transpose.cpp"; + kernel.kernel_name = kernel_name; + kernel.g_wk.clear(); + kernel.g_wk.push_back(grid_size * block_size); + kernel.g_wk.push_back(1); + kernel.g_wk.push_back(1); + kernel.l_wk.clear(); + kernel.l_wk.push_back(block_size); + kernel.l_wk.push_back(1); + kernel.l_wk.push_back(1); + + MIOPEN_LOG_I2("BatchedTransposeSolution use kernel: " + kernel_name); + + return kernel; +} + +std::vector BatchedTransposeSolution::GetKernelArg() const +{ + uint32_t dim_h = (height + 
kernel_param_heuristic.tile_y - 1) / kernel_param_heuristic.tile_y; + uint32_t dim_w = (width + kernel_param_heuristic.tile_x - 1) / kernel_param_heuristic.tile_x; + uint32_t dim_total = batch * dim_h * dim_w; +#if BATCHED_TRANSPOSE_PERSISTENT + std::size_t grid_size = num_cu * BATCHED_TRANSPOSE_OCCUPANCY; +#else + std::size_t grid_size = batch * dim_h * dim_w; +#endif + + magic_div_u32_t magic_h = magic_div_u32_gen(dim_h); + magic_div_u32_t magic_w = magic_div_u32_gen(dim_w); + + std::vector opArgs; + opArgs.emplace_back(0); // placeholder + opArgs.emplace_back(0); // placeholder + opArgs.emplace_back(height); + opArgs.emplace_back(width); + opArgs.emplace_back(static_cast(grid_size)); + opArgs.emplace_back(dim_total); + opArgs.emplace_back(magic_h.magic); + opArgs.emplace_back(static_cast(magic_h.shift)); + opArgs.emplace_back(magic_w.magic); + opArgs.emplace_back(static_cast(magic_w.shift)); + + return opArgs; +} + +std::string BatchedTransposeSolution::GetKernelName() const +{ + std::size_t data_size = miopen::GetTypeSize(data_type); + return batched_transpose::GetKernelName(data_size, &kernel_param_heuristic); +} + +bool BatchedTransposeSolution::IsSkippable() const +{ + // If height or width is 1, actually no need to do transpose. + // But nonthing prevent you from DO transpose... + return height == 1 || width == 1; +} + +size_t BatchedTransposeSolution::GetSize() const +{ + return miopen::GetTypeSize(data_type) * batch * height * width; +} + +} // namespace miopen diff --git a/src/hip/handlehip.cpp b/src/hip/handlehip.cpp index 7905632ab1..caeca4bce4 100644 --- a/src/hip/handlehip.cpp +++ b/src/hip/handlehip.cpp @@ -572,13 +572,13 @@ std::ostream& Handle::Print(std::ostream& os) const return os; } -shared Handle::CreateSubBuffer(Data_t data, std::size_t offset, std::size_t) +shared Handle::CreateSubBuffer(Data_t data, std::size_t offset, std::size_t) const { auto cdata = reinterpret_cast(data); return {cdata + offset, null_deleter{}}; } -shared Handle::CreateSubBuffer(ConstData_t data, std::size_t offset, std::size_t) +shared Handle::CreateSubBuffer(ConstData_t data, std::size_t offset, std::size_t) const { auto cdata = reinterpret_cast(data); return {cdata + offset, null_deleter{}}; diff --git a/src/include/miopen/batched_transpose_sol.hpp b/src/include/miopen/batched_transpose_sol.hpp new file mode 100644 index 0000000000..c912669c63 --- /dev/null +++ b/src/include/miopen/batched_transpose_sol.hpp @@ -0,0 +1,71 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_BATCHED_TRANSPOSE_SOL_HPP +#define GUARD_MIOPEN_BATCHED_TRANSPOSE_SOL_HPP + +#include +#include +#include +#include +#include + +namespace miopen { + +struct BatchedTransposeParam +{ + int tile_x{0}; + int tile_y{0}; + int pack_x{0}; + int pack_y{0}; + int ediv_x{0}; + int ediv_y{0}; +}; + +struct BatchedTransposeSolution +{ + BatchedTransposeSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t batch_, + uint32_t height_, + uint32_t width_); + solver::KernelInfo GetKernel() const; + std::vector GetKernelArg() const; + std::string GetKernelName() const; + bool IsSkippable() const; + size_t GetSize() const; + + miopenDataType_t data_type; + uint32_t batch; + uint32_t height; + uint32_t width; + int num_cu; + + BatchedTransposeParam kernel_param_heuristic; +}; + +} // namespace miopen + +#endif diff --git a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp index 73ab47580a..f30fbe4cff 100644 --- a/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp +++ b/src/include/miopen/conv/invokers/impl_gemm_dynamic.hpp @@ -34,69 +34,12 @@ #include #include #include +#include #include namespace miopen { namespace conv { -struct magic_div_u32 -{ - uint32_t magic; - uint8_t shift; -}; - -using magic_div_u32_t = magic_div_u32; - -/* - * - * numer / denom = quotient, reminder - * - * use magic number to do integer division of uint32 (acctually INT32_MAX, the 31 bit divisoin) - * most algorithm to compute uint32 need branching if cover all 32 bit of uint32. - * since we compute the magic number on host side, implement the division in gpu side, it is better - * not use branching - * hence add more restriction to numer and denom, to be 1 bit less. hence need less-or-equal than - * INT32_MAX - * - * magic_div_u32_gen() compute from input arg d, to get a magic and a shift. 
- * to use the value, below is a example host-side code to do this - * - * // host side version - * static inline uint32_t magic_div_mulhi_u32(uint32_t x, uint32_t y) { - * uint64_t xl = x, yl = y; - * uint64_t rl = xl * yl; - * return (uint32_t)(rl >> 32); - * } - * uint32_t magic_div_u32_do(uint32_t numer, const struct magic_div_u32_t *denom) { - * uint32_t tmp = magic_div_mulhi_u32(denom->magic, numer); - * return (tmp + numer) >> denom->shift; - * } - * - */ -static inline magic_div_u32_t magic_div_u32_gen(uint32_t d) -{ - assert(d >= 1 && d <= INT32_MAX); - uint8_t shift; - for(shift = 0; shift < 32; shift++) - if((1U << shift) >= d) - break; - - constexpr uint64_t one = 1; - uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1; - assert(magic <= 0xffffffffUL); - - return {static_cast(magic), shift}; -} - -static inline uint32_t magic_div_u32_pack_shift(uint8_t s0, uint8_t s1, uint8_t s2, uint8_t s3) -{ - uint32_t shift_0 = s0; - uint32_t shift_1 = s1; - uint32_t shift_2 = s2; - uint32_t shift_3 = s3; - return (shift_3 << 24) | (shift_2 << 16) | (shift_1 << 8) | shift_0; -} - template inline std::vector ComputeDynamicIGemmForwardKernelArgs(const ProblemDescription& conv_problem, const T& cfg); diff --git a/src/include/miopen/handle.hpp b/src/include/miopen/handle.hpp index e1732874c2..9f81a24216 100644 --- a/src/include/miopen/handle.hpp +++ b/src/include/miopen/handle.hpp @@ -169,9 +169,10 @@ struct Handle : miopenHandle Allocator::ManageDataPtr& WriteTo(const void* data, Allocator::ManageDataPtr& ddata, std::size_t sz) const; void ReadTo(void* data, const Allocator::ManageDataPtr& ddata, std::size_t sz) const; - shared CreateSubBuffer(Data_t data, std::size_t offset, std::size_t size); + shared CreateSubBuffer(Data_t data, std::size_t offset, std::size_t size) const; #if MIOPEN_BACKEND_HIP - shared CreateSubBuffer(ConstData_t data, std::size_t offset, std::size_t size); + shared + CreateSubBuffer(ConstData_t data, std::size_t offset, std::size_t size) const; #endif template diff --git a/src/include/miopen/magic_div.hpp b/src/include/miopen/magic_div.hpp new file mode 100644 index 0000000000..c9353443cf --- /dev/null +++ b/src/include/miopen/magic_div.hpp @@ -0,0 +1,95 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#ifndef GUARD_MLOPEN_MAGIC_DIV_HPP_
+#define GUARD_MLOPEN_MAGIC_DIV_HPP_
+
+#include <cassert>
+#include <cstdint>
+
+namespace miopen {
+
+struct magic_div_u32
+{
+    uint32_t magic;
+    uint8_t shift;
+};
+
+using magic_div_u32_t = magic_div_u32;
+
+/*
+ *
+ * numer / denom = quotient, remainder
+ *
+ * Use a magic number to do integer division of uint32 (actually INT32_MAX, i.e. 31-bit division).
+ * Most algorithms need branching if they have to cover all 32 bits of uint32.
+ * Since we compute the magic number on the host side and implement the division on the GPU side,
+ * it is better not to use branching.
+ * Hence we add one more restriction to numer and denom: they are 1 bit less, i.e. they must be
+ * less than or equal to INT32_MAX.
+ *
+ * magic_div_u32_gen() computes, from the input arg d, a magic and a shift.
+ * To use the values, below is example host-side code that does this:
+ *
+ * // host side version
+ * static inline uint32_t magic_div_mulhi_u32(uint32_t x, uint32_t y) {
+ *     uint64_t xl = x, yl = y;
+ *     uint64_t rl = xl * yl;
+ *     return (uint32_t)(rl >> 32);
+ * }
+ * uint32_t magic_div_u32_do(uint32_t numer, const struct magic_div_u32_t *denom) {
+ *     uint32_t tmp = magic_div_mulhi_u32(denom->magic, numer);
+ *     return (tmp + numer) >> denom->shift;
+ * }
+ *
+ */
+static inline magic_div_u32_t magic_div_u32_gen(uint32_t d)
+{
+    assert(d >= 1 && d <= INT32_MAX);
+    uint8_t shift;
+    for(shift = 0; shift < 32; shift++)
+        if((1U << shift) >= d)
+            break;
+
+    constexpr uint64_t one = 1;
+    uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
+    assert(magic <= 0xffffffffUL);
+
+    return {static_cast<uint32_t>(magic), shift};
+}
+
+static inline uint32_t magic_div_u32_pack_shift(uint8_t s0, uint8_t s1, uint8_t s2, uint8_t s3)
+{
+    uint32_t shift_0 = s0;
+    uint32_t shift_1 = s1;
+    uint32_t shift_2 = s2;
+    uint32_t shift_3 = s3;
+    return (shift_3 << 24) | (shift_2 << 16) | (shift_1 << 8) | shift_0;
+}
+
+} // namespace miopen
+
+#endif
diff --git a/src/include/miopen/util_sol.hpp b/src/include/miopen/util_sol.hpp
new file mode 100644
index 0000000000..c6b5f1e6e7
--- /dev/null
+++ b/src/include/miopen/util_sol.hpp
@@ -0,0 +1,65 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2021 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
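A self-contained check of the magic-number division defined in magic_div.hpp above, using only standard C++; gen and do_div are local names that restate the header's construction and the host-side reference from its comment, not MIOpen API:

#include <cassert>
#include <cstdint>
#include <initializer_list>

struct magic_div_u32_t
{
    uint32_t magic;
    uint8_t shift;
};

// Same construction as magic_div_u32_gen() above (asserts omitted for brevity).
static magic_div_u32_t gen(uint32_t d)
{
    uint8_t shift = 0;
    while((1U << shift) < d)
        shift++;
    uint64_t magic = ((uint64_t{1} << 32) * ((uint64_t{1} << shift) - d)) / d + 1;
    return {static_cast<uint32_t>(magic), shift};
}

// Host restatement of the division from the header's comment block.
static uint32_t do_div(uint32_t numer, const magic_div_u32_t& m)
{
    uint32_t tmp = static_cast<uint32_t>((uint64_t{m.magic} * numer) >> 32); // mulhi
    return (tmp + numer) >> m.shift;
}

int main()
{
    for(uint32_t d : {1u, 3u, 7u, 17u, 640u})
    {
        const magic_div_u32_t m = gen(d);
        for(uint32_t n : {0u, 1u, 2u, 100u, 123457u, 0x7fffffffu})
            assert(do_div(n, m) == n / d); // holds for numer, denom <= INT32_MAX
    }
    return 0;
}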
+ * + *******************************************************************************/ +#ifndef MIOPEN_UTIL_SOL_HPP_ +#define MIOPEN_UTIL_SOL_HPP_ + +#include +#include +#include +#include +#include +#include + +namespace miopen { + +struct TransposeSolutionDefault2Nhwc : public BatchedTransposeSolution +{ + TransposeSolutionDefault2Nhwc(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t n_, + uint32_t c_, + uint32_t h_, + uint32_t w_) + : BatchedTransposeSolution(ctx_, data_type_, n_, c_, h_ * w_) + { + } +}; + +struct TransposeSolutionNhwc2Default : public BatchedTransposeSolution +{ + TransposeSolutionNhwc2Default(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t n_, + uint32_t c_, + uint32_t h_, + uint32_t w_) + : BatchedTransposeSolution(ctx_, data_type_, n_, h_ * w_, c_) + { + } +}; +} // namespace miopen + +#endif // MIOPEN_UTIL_SOL_HPP_ diff --git a/src/kernels/gpu_batched_transpose_kernel/batched_transpose.cpp b/src/kernels/gpu_batched_transpose_kernel/batched_transpose.cpp new file mode 100644 index 0000000000..9383ca3c3e --- /dev/null +++ b/src/kernels/gpu_batched_transpose_kernel/batched_transpose.cpp @@ -0,0 +1,2450 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020-2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
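util_sol.hpp above expresses the layout conversions as batched 2-D transposes: Default2Nhwc uses batch = N, height = C, width = H*W, and Nhwc2Default swaps the last two. A small host reference, assuming dense row-major storage, of the mapping dst[b][w][h] = src[b][h][w] that the GPU kernels below implement; batched_transpose_host is a hypothetical name:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Host reference for a batched 2-D transpose: dst[b][w][h] = src[b][h][w].
// With batch = N, height = C, width = H*W this is exactly NCHW -> NHWC,
// matching how TransposeSolutionDefault2Nhwc above sets up its arguments.
template <typename T>
void batched_transpose_host(std::vector<T>& dst,
                            const std::vector<T>& src,
                            uint32_t batch,
                            uint32_t height,
                            uint32_t width)
{
    for(uint32_t b = 0; b < batch; b++)
        for(uint32_t h = 0; h < height; h++)
            for(uint32_t w = 0; w < width; w++)
                dst[static_cast<std::size_t>(b) * width * height +
                    static_cast<std::size_t>(w) * height + h] =
                    src[static_cast<std::size_t>(b) * height * width +
                        static_cast<std::size_t>(h) * width + w];
}

int main()
{
    // N=1, C=2, H=2, W=3 in NCHW, flattened as batch=1, height=C=2, width=H*W=6.
    const std::vector<int> nchw = {0, 1, 2, 3, 4, 5,    // c = 0
                                   6, 7, 8, 9, 10, 11}; // c = 1
    std::vector<int> nhwc(nchw.size());
    batched_transpose_host(nhwc, nchw, 1, 2, 6);
    assert(nhwc[0] == 0 && nhwc[1] == 6); // in NHWC the two channel values of pixel (0,0) are adjacent
    return 0;
}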
+ * + *******************************************************************************/ +#include +#include + +#ifndef BATCHED_TRANSPOSE_OCCUPANCY +#define BATCHED_TRANSPOSE_OCCUPANCY 4 +#endif + +// Disable specific warnings +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wconditional-uninitialized" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wsometimes-uninitialized" +#endif + +inline __device__ uint32_t magic_div_u32(const uint32_t& numer, + const uint32_t& magic, + const uint32_t& shift) +{ + uint32_t tmp = __umulhi(numer, magic); + return (tmp + numer) >> shift; +} + +inline __device__ void v_pack_b32_f16_00(float& c, const float& a, const float& b) +{ +#if 0 + asm volatile("v_pack_b32_f16 %0, %1, %2\n" + : "=v"(c) + : "v"(a), "v"(b)); +#else + // cppcheck-suppress invalidPointerCast + const uint32_t x = *reinterpret_cast(&a); + // cppcheck-suppress invalidPointerCast + const uint32_t y = *reinterpret_cast(&b); + uint32_t z = (x & 0xffff) | ((y & 0xffff) << 16); + // cppcheck-suppress invalidPointerCast + c = *reinterpret_cast(&z); +#endif +} + +inline __device__ void v_pack_b32_f16_11(float& c, const float& a, const float& b) +{ +#if 0 + asm volatile("v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1]\n" + : "=v"(c) + : "v"(a), "v"(b)); +#else + // cppcheck-suppress invalidPointerCast + const uint32_t x = *reinterpret_cast(&a); + // cppcheck-suppress invalidPointerCast + const uint32_t y = *reinterpret_cast(&b); + uint32_t z = ((x & 0xffff0000) >> 16) | (y & 0xffff0000); + // cppcheck-suppress invalidPointerCast + c = *reinterpret_cast(&z); +#endif +} + +inline __device__ void v_pack_b32_f16_2x2(float& y0, float& y1, const float& x0, const float& x1) +{ +#if 0 + asm volatile("\n \ + v_pack_b32_f16 %0, %2, %3\n \ + v_pack_b32_f16 %1, %2, %3 op_sel:[1, 1]\n" + : "=v"(y0), "=v"(y1) + : "v"(x0), "v"(x1), "0"(y0), "1"(y1)); +#else + // cppcheck-suppress invalidPointerCast + const uint32_t a0 = *reinterpret_cast(&x0); + // cppcheck-suppress invalidPointerCast + const uint32_t a1 = *reinterpret_cast(&x1); + uint32_t b0 = (a0 & 0xffff) | ((a1 & 0xffff) << 16); + uint32_t b1 = ((a0 & 0xffff0000) >> 16) | (a1 & 0xffff0000); + // cppcheck-suppress invalidPointerCast + y0 = *reinterpret_cast(&b0); + // cppcheck-suppress invalidPointerCast + y1 = *reinterpret_cast(&b1); +#endif +} + +inline __device__ void v_pack_b32_f16_2x2_half_x0( + float& y0, float& y1, const ushort& x0_lo, const ushort& x0_hi, const float& x1) +{ + // cppcheck-suppress invalidPointerCast + const uint32_t a1 = *reinterpret_cast(&x1); + uint32_t b0 = x0_lo | ((a1 & 0xffff) << 16); + uint32_t b1 = x0_hi | (a1 & 0xffff0000); + // cppcheck-suppress invalidPointerCast + y0 = *reinterpret_cast(&b0); + // cppcheck-suppress invalidPointerCast + y1 = *reinterpret_cast(&b1); +} + +inline __device__ void v_pack_b32_f16_2x2_half_x1( + float& y0, float& y1, const float& x0, const ushort& x1_lo, const ushort& x1_hi) +{ + // cppcheck-suppress invalidPointerCast + const uint32_t a0 = *reinterpret_cast(&x0); + uint32_t b0 = (a0 & 0xffff) | (x1_lo << 16); + uint32_t b1 = ((a0 & 0xffff0000) >> 16) | (x1_hi << 16); + // cppcheck-suppress invalidPointerCast + y0 = *reinterpret_cast(&b0); + // cppcheck-suppress invalidPointerCast + y1 = *reinterpret_cast(&b1); +} + +inline __device__ void v_pack_b32_f16_2x2_half_x0_half_x1(float& y0, + float& y1, + const ushort& x0_lo, + const ushort& x0_hi, + const ushort& x1_lo, + const ushort& x1_hi) +{ + uint32_t b0 = x0_lo | (x1_lo << 16); + uint32_t b1 = x0_hi 
| (x1_hi << 16); + // cppcheck-suppress invalidPointerCast + y0 = *reinterpret_cast(&b0); + // cppcheck-suppress invalidPointerCast + y1 = *reinterpret_cast(&b1); +} + +template +struct mapped_vector_type +{ +}; + +template <> +struct mapped_vector_type +{ + using type = float4; +}; + +template <> +struct mapped_vector_type +{ + using type = float2; +}; + +template <> +struct mapped_vector_type +{ + using type = float; +}; + +template <> +struct mapped_vector_type +{ + using type = ushort4; +}; + +template <> +struct mapped_vector_type +{ + using type = ushort2; +}; + +template <> +struct mapped_vector_type +{ + using type = ushort; +}; + +template <> +struct mapped_vector_type +{ + using type = uchar4; +}; + +template <> +struct mapped_vector_type +{ + using type = uchar2; +}; + +template <> +struct mapped_vector_type +{ + using type = uchar; +}; + +template +inline __device__ void batched_transpose_16x16(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + /* + * assume src is batch * height * width, dst is batch * width * height + */ + constexpr auto element_byte = sizeof(T); + constexpr auto padding_element = 4 / element_byte; + constexpr auto smem_stride = 16 + padding_element; + __shared__ T smem[16 * smem_stride]; + + uint32_t h_dim = (height + 15) >> 4; + uint32_t w_dim = (width + 15) >> 4; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 4) + i_src_h; + + __syncthreads(); + if(g_src_h < height && g_src_w < width) + { + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + smem[i_src_h * smem_stride + i_src_w] = src[src_index]; + } + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 4) + i_dst_w; + + if(g_dst_h < height && g_dst_w < width) + { + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + dst[dst_index] = smem[i_dst_h * smem_stride + i_dst_w]; + } + } +} + +template +inline __device__ void batched_transpose_32x16(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + /* + * assume src is batch * height * width, dst is batch * width * height + */ + constexpr auto element_byte = sizeof(T); + constexpr auto padding_element = 4 / element_byte; + constexpr auto smem_stride = 16 + padding_element; + __shared__ T smem[32 * smem_stride]; + + uint32_t h_dim = (height + 15) >> 4; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + 
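Each kernel in this file turns the flat dim_id back into (batch, tile row, tile column) with two magic divisions. A host-side sketch of that decomposition using plain division in place of magic_div_u32; decompose_tile_id is a hypothetical name, not part of the patch:

#include <cassert>
#include <cstdint>

// dim_id enumerates tiles as n * (h_tiles * w_tiles) + ih * w_tiles + iw;
// the kernels above recover (n, ih, iw) with two magic divisions by
// w_tiles and then h_tiles. Plain '/' is used here for clarity.
static void decompose_tile_id(uint32_t dim_id,
                              uint32_t h_tiles,
                              uint32_t w_tiles,
                              uint32_t& n,
                              uint32_t& ih,
                              uint32_t& iw)
{
    uint32_t tmp = dim_id / w_tiles; // magic_div_u32(dim_id, magic_w, shift_w)
    iw           = dim_id - tmp * w_tiles;
    n            = tmp / h_tiles;    // magic_div_u32(tmp, magic_h, shift_h)
    ih           = tmp - n * h_tiles;
}

int main()
{
    uint32_t n = 0, ih = 0, iw = 0;
    decompose_tile_id(/*dim_id=*/23, /*h_tiles=*/3, /*w_tiles=*/4, n, ih, iw);
    assert(n == 1 && ih == 2 && iw == 3); // 23 = 1*(3*4) + 2*4 + 3
    return 0;
}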
uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 4) + i_src_h; + + __syncthreads(); + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + T v_src[2]; + if(g_src_h < height && g_src_w < width) + { + v_src[0] = src[src_index]; + } + if(g_src_h < height && (g_src_w + 16) < width) + { + v_src[1] = src[src_index + 16]; + } + smem[i_src_h * smem_stride + i_src_w] = v_src[0]; + smem[i_src_h * smem_stride + i_src_w + 16 * smem_stride] = v_src[1]; + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + T v_dst[2]; + v_dst[0] = smem[i_dst_h * smem_stride + i_dst_w]; + v_dst[1] = smem[i_dst_h * smem_stride + i_dst_w + 16 * smem_stride]; + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_dst[0]; + } + if(g_dst_h < height && (g_dst_w + 16) < width) + { + dst[dst_index + 16 * height] = v_dst[1]; + } + } +} + +template +inline __device__ void batched_transpose_16x32(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + /* + * assume src is batch * height * width, dst is batch * width * height + */ + constexpr auto element_byte = sizeof(T); + constexpr auto padding_element = 4 / element_byte; + constexpr auto smem_stride = 16 + padding_element; + __shared__ T smem[32 * smem_stride]; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 15) >> 4; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + i_src_h; + + __syncthreads(); + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + T v_src[2]; + if(g_src_h < height && g_src_w < width) + { + v_src[0] = src[src_index]; + } + if((g_src_h + 16) < height && g_src_w < width) + { + v_src[1] = src[src_index + 16 * width]; + } + smem[i_src_h * smem_stride + i_src_w] = v_src[0]; + smem[i_src_h * smem_stride + i_src_w + 16 * smem_stride] = v_src[1]; + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 4) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + T v_dst[2]; + v_dst[0] = smem[i_dst_h * smem_stride + i_dst_w]; + v_dst[1] = smem[i_dst_h * smem_stride + i_dst_w + 16 * smem_stride]; + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_dst[0]; + } + if((g_dst_h + 16) < height && g_dst_w < width) + { + dst[dst_index + 16] = v_dst[1]; + } + } +} + +template +inline __device__ void batched_transpose_32x32(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t 
shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + /* + * assume src is batch * height * width, dst is batch * width * height + */ + constexpr auto smem_stride = 17; + using vec_t = typename mapped_vector_type::type; + __shared__ vec_t smem[16 * smem_stride]; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + i_src_h; + + __syncthreads(); + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + vec_t v_src; + if(g_src_h < height && g_src_w < width) + { + v_src.x = src[src_index]; + } + if(g_src_h < height && (g_src_w + 16) < width) + { + v_src.z = src[src_index + 16]; + } + if((g_src_h + 16) < height && g_src_w < width) + { + v_src.y = src[src_index + 16 * width]; + } + if((g_src_h + 16) < height && (g_src_w + 16) < width) + { + v_src.w = src[src_index + 16 * width + 16]; + } + smem[i_src_h * smem_stride + i_src_w] = v_src; + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + vec_t v_dst = smem[i_dst_h * smem_stride + i_dst_w]; + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_dst.x; + } + if((g_dst_h + 16) < height && g_dst_w < width) + { + dst[dst_index + 16] = v_dst.y; + } + if(g_dst_h < height && (g_dst_w + 16) < width) + { + dst[dst_index + 16 * height] = v_dst.z; + } + if((g_dst_h + 16) < height && (g_dst_w + 16) < width) + { + dst[dst_index + 16 * height + 16] = v_dst.w; + } + } +} + +template +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_2x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_2x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[32 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + float* p_src = reinterpret_cast(src); + + uint32_t height_2 = height >> 1; + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + (i_src_h << 1); + + __syncthreads(); + 
if(g_src_h < height && g_src_w < width_2) + { + float v_a, v_b, v_a_pack, v_b_pack; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + + static_cast(g_src_w); + v_a = p_src[src_index]; + v_b = p_src[src_index + width_2]; + v_pack_b32_f16_2x2(v_a_pack, v_b_pack, v_a, v_b); + + smem[i_src_w * smem_stride + i_src_h] = v_a_pack; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_b_pack; + } + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + (i_dst_w << 1); + + if(g_dst_h < height_2 && g_dst_w < width) + { + float v_a, v_b; + v_a = smem[i_dst_w * smem_stride + i_dst_h]; + v_b = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + + static_cast(g_dst_h); + p_dst[dst_index] = v_a; + p_dst[dst_index + height_2] = v_b; + } + } +} + +template +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_1x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_1x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[32 * smem_stride]; + + ushort* p_src = src; + float* p_dst = reinterpret_cast(dst); + + uint32_t height_2 = height >> 1; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + (i_src_h << 1); + + ushort v_src[4]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + __syncthreads(); + if(g_src_h < height && g_src_w < width) + { + v_src[0] = p_src[src_index]; + v_src[2] = p_src[src_index + width]; + } + if(g_src_h < height && (g_src_w + 16) < width) + { + v_src[1] = p_src[src_index + 16]; + v_src[3] = p_src[src_index + width + 16]; + } + + float v_pack[2]; + v_pack_b32_f16_2x2_half_x0_half_x1( + v_pack[0], v_pack[1], v_src[0], v_src[1], v_src[2], v_src[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_pack[1]; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + static_cast(g_dst_h); + + float v_a, v_b; + v_a = smem[i_dst_w * smem_stride + i_dst_h]; + v_b = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + if(g_dst_h < height_2 && g_dst_w < width) + { + p_dst[dst_index] = v_a; + } + + if(g_dst_h 
< height_2 && (g_dst_w + 16) < width) + { + p_dst[dst_index + 16 * height_2] = v_b; + } + } +} + +template +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_2x1(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_2x1(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[32 * smem_stride]; + + float* p_src = reinterpret_cast(src); + ushort* p_dst = dst; + + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + i_src_h; + + float v_src[2]; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + __syncthreads(); + + if(g_src_h < height && g_src_w < width_2) + { + v_src[0] = p_src[src_index]; + } + if((g_src_h + 16) < height && g_src_w < width_2) + { + v_src[1] = p_src[src_index + 16 * width_2]; + } + + float v_pack[2]; + v_pack_b32_f16_2x2(v_pack[0], v_pack[1], v_src[0], v_src[1]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_pack[1]; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + (i_dst_w << 1); + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + float v_dst[2]; + v_dst[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v_dst[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + + ushort2 v_dst_buf[2]; + v_dst_buf[0] = *reinterpret_cast(&v_dst[0]); + v_dst_buf[1] = *reinterpret_cast(&v_dst[1]); + if(g_dst_h < height && g_dst_w < width) + { + p_dst[dst_index] = v_dst_buf[0].x; + p_dst[dst_index + height] = v_dst_buf[1].x; + } + + if((g_dst_h + 16) < height && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst_buf[0].y; + p_dst[dst_index + height + 16] = v_dst_buf[1].y; + } + } +} + +template +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_1x1(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x32_pack_2x2_ediv_1x1(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[32 * smem_stride]; + + ushort* p_src = src; + ushort* p_dst = dst; + + 
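+ /*
+  * ediv 1x1 path: no evenness guarantee on either dimension, so elements are loaded and stored
+  * one ushort at a time with a bounds check per element; the 2x2 repack into packed floats still
+  * happens in registers (v_pack_b32_f16_2x2_half_x0_half_x1) before the LDS stage.
+  * As in the other tile variants, dim_id enumerates batch * h_tiles * w_tiles and is decoded with
+  * magic_div_u32, i.e. integer division by a host-precomputed magic/shift pair.
+  */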
uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + i_src_h; + + ushort v_src[4]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + __syncthreads(); + /* + * 4x4 -> 4x4 transpose: (0, 1, 2, 3 is in ushort, a, b in float) + * lo hi + * |0|1| lo |a|b| + * |2|3| -> hi |_|_| + */ + + if(g_src_h < height && g_src_w < width) + { + v_src[0] = p_src[src_index]; + } + if(g_src_h < height && (g_src_w + 16) < width) + { + v_src[1] = p_src[src_index + 16]; + } + if((g_src_h + 16) < height && g_src_w < width) + { + v_src[2] = p_src[src_index + 16 * width]; + } + if((g_src_h + 16) < height && (g_src_w + 16) < width) + { + v_src[3] = p_src[src_index + 16 * width + 16]; + } + + float v_pack[2]; + v_pack_b32_f16_2x2_half_x0_half_x1( + v_pack[0], v_pack[1], v_src[0], v_src[1], v_src[2], v_src[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_pack[1]; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + float v_dst[2]; + v_dst[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v_dst[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + + ushort2 v_dst_buf[2]; + v_dst_buf[0] = *reinterpret_cast(&v_dst[0]); + v_dst_buf[1] = *reinterpret_cast(&v_dst[1]); + if(g_dst_h < height && g_dst_w < width) + { + p_dst[dst_index] = v_dst_buf[0].x; + } + if((g_dst_h + 16) < height && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst_buf[0].y; + } + if(g_dst_h < height && (g_dst_w + 16) < width) + { + p_dst[dst_index + 16 * height] = v_dst_buf[1].x; + } + if((g_dst_h + 16) < height && (g_dst_w + 16) < width) + { + p_dst[dst_index + 16 * height + 16] = v_dst_buf[1].y; + } + } +} + +template +inline __device__ void batched_transpose_64x32_pack_4x2_ediv_4x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_64x32_pack_4x2_ediv_4x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[64 * smem_stride]; + //__shared__ float4 smem[16 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + float2* p_src = reinterpret_cast(src); + + uint32_t height_2 = height >> 1; + uint32_t width_4 = width >> 2; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, 
shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + (i_src_h << 1); + + __syncthreads(); + if(g_src_h < height && g_src_w < width_4) + { +#if 1 + float2 v_a, v_b; + float v_pack[4]; + size_t src_index = static_cast(dim_in) * height * width_4 + + static_cast(g_src_h) * width_4 + + static_cast(g_src_w); + v_a = p_src[src_index]; + v_b = p_src[src_index + width_4]; + v_pack_b32_f16_2x2(v_pack[0], v_pack[1], v_a.x, v_b.x); + v_pack_b32_f16_2x2(v_pack[2], v_pack[3], v_a.y, v_b.y); + + smem[i_src_w * smem_stride + i_src_h] = v_pack[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_pack[1]; + smem[i_src_w * smem_stride + i_src_h + 32 * smem_stride] = v_pack[2]; + smem[i_src_w * smem_stride + i_src_h + 48 * smem_stride] = v_pack[3]; +#else + float2 v_a, v_b; + float4 v_pack; + size_t src_index = static_cast(dim_in) * height * width_4 + + static_cast(g_src_h) * width_4 + + static_cast(g_src_w); + v_a = p_src[src_index]; + v_b = p_src[src_index + width_4]; + v_pack_b32_f16_2x2(v_pack.x, v_pack.y, v_a.x, v_b.x); + v_pack_b32_f16_2x2(v_pack.z, v_pack.w, v_a.y, v_b.y); + + smem[i_src_w * smem_stride + i_src_h] = v_pack; +#endif + } + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 6) + (i_dst_w << 2); + + if(g_dst_h < height_2 && g_dst_w < width) + { +#if 1 + float v[4]; + v[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + v[2] = smem[i_dst_w * smem_stride + i_dst_h + 32 * smem_stride]; + v[3] = smem[i_dst_w * smem_stride + i_dst_h + 48 * smem_stride]; + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + + static_cast(g_dst_h); + p_dst[dst_index] = v[0]; + p_dst[dst_index + height_2] = v[1]; + p_dst[dst_index + 2 * height_2] = v[2]; + p_dst[dst_index + 3 * height_2] = v[3]; +#else + float4 v; + v = smem[i_dst_w * smem_stride + i_dst_h]; + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + + static_cast(g_dst_h); + p_dst[dst_index] = v.x; + p_dst[dst_index + height_2] = v.y; + p_dst[dst_index + 2 * height_2] = v.z; + p_dst[dst_index + 3 * height_2] = v.w; +#endif + } + } +} + +template +inline __device__ void batched_transpose_64x32_pack_4x2_ediv_2x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_64x32_pack_4x2_ediv_2x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + //__shared__ float smem[64 * smem_stride]; + __shared__ float4 smem[16 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + float* p_src = reinterpret_cast(src); + + uint32_t height_2 = height >> 1; + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id 
< dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + (i_src_h << 1); + + __syncthreads(); + /* + * 4x2 -> 2x4 transpose + * lo hi + * |0|_|2|_| lo |0|1|2|3| + * |1|_|3|_| -> hi |_|_|_|_| + */ + float v_src[4]; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width_2) + { + v_src[0] = p_src[src_index]; + v_src[1] = p_src[src_index + width_2]; + } + if(g_src_h < height && (g_src_w + 16) < width_2) + { + v_src[2] = p_src[src_index + 16]; + v_src[3] = p_src[src_index + width_2 + 16]; + } + + float4 v_pack; + v_pack_b32_f16_2x2(v_pack.x, v_pack.y, v_src[0], v_src[1]); + v_pack_b32_f16_2x2(v_pack.z, v_pack.w, v_src[2], v_src[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack; + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 6) + (i_dst_w << 1); + + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + static_cast(g_dst_h); + + float4 v_dst = smem[i_dst_w * smem_stride + i_dst_h]; + if(g_dst_h < height_2 && g_dst_w < width) + { + p_dst[dst_index] = v_dst.x; + p_dst[dst_index + height_2] = v_dst.y; + } + if(g_dst_h < height_2 && (g_dst_w + 32) < width) + { + p_dst[dst_index + 32 * height_2] = v_dst.z; + p_dst[dst_index + 33 * height_2] = v_dst.w; + } + } +} + +template +inline __device__ void batched_transpose_64x32_pack_4x2_ediv_2x1(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_64x32_pack_4x2_ediv_2x1(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float4 smem[16 * smem_stride]; + + ushort* p_dst = reinterpret_cast(dst); + float* p_src = reinterpret_cast(src); + + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 31) >> 5; + uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 5) + i_src_h; + + __syncthreads(); + /* + * 4x2 -> 2x4 transpose + * lo hi + * |0|_|2|_| lo |0|1|2|3| + * |1|_|3|_| -> hi |_|_|_|_| + */ + float v_src[4]; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width_2) + { + v_src[0] = p_src[src_index]; + } + if(g_src_h < height && (g_src_w + 16) < 
width_2) + { + v_src[2] = p_src[src_index + 16]; + } + if((g_src_h + 16) < height && g_src_w < width_2) + { + v_src[1] = p_src[src_index + 16 * width_2]; + } + if((g_src_h + 16) < height && (g_src_w + 16) < width_2) + { + v_src[3] = p_src[src_index + 16 * width_2 + 16]; + } + + float4 v_pack; + v_pack_b32_f16_2x2(v_pack.x, v_pack.y, v_src[0], v_src[1]); + v_pack_b32_f16_2x2(v_pack.z, v_pack.w, v_src[2], v_src[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack; + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 6) + (i_dst_w << 1); + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + float4 v_dst = smem[i_dst_w * smem_stride + i_dst_h]; + ushort2 v_dst_buf[4]; + v_dst_buf[0] = *reinterpret_cast(&v_dst.x); + v_dst_buf[1] = *reinterpret_cast(&v_dst.y); + v_dst_buf[2] = *reinterpret_cast(&v_dst.z); + v_dst_buf[3] = *reinterpret_cast(&v_dst.w); + if(g_dst_h < height && g_dst_w < width) + { + p_dst[dst_index] = v_dst_buf[0].x; + p_dst[dst_index + height] = v_dst_buf[1].x; + } + if((g_dst_h + 16) < height && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst_buf[0].y; + p_dst[dst_index + height + 16] = v_dst_buf[1].y; + } + if(g_dst_h < height && (g_dst_w + 32) < width) + { + p_dst[dst_index + 32 * height] = v_dst_buf[2].x; + p_dst[dst_index + 33 * height] = v_dst_buf[3].x; + } + if((g_dst_h + 16) < height && (g_dst_w + 32) < width) + { + p_dst[dst_index + 32 * height + 16] = v_dst_buf[2].y; + p_dst[dst_index + 33 * height + 16] = v_dst_buf[3].y; + } + } +} + +template +inline __device__ void batched_transpose_32x64_pack_2x4_ediv_2x4(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x64_pack_2x4_ediv_2x4(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + //__shared__ float smem[64 * smem_stride]; + __shared__ float4 smem[16 * smem_stride]; + + float2* p_dst = reinterpret_cast(dst); + float* p_src = reinterpret_cast(src); + + uint32_t height_4 = height >> 2; + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 63) >> 6; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 6) + (i_src_h << 2); + + __syncthreads(); + /* + * 4x2 -> 2x4 transpose: (0, 1, 2, 3 is in float) + * lo hi + * 0 |_|_| lo |0|2| + * 1 |_|_| -> hi |_|_| + * 2 |_|_| lo |1|3| + * 3 |_|_| hi |_|_| + */ + if(g_src_h < height && g_src_w < width_2) + { +#if 0 + float v[4]; + float v_pack[4]; + size_t src_index = static_cast(dim_in) * height * width_2 + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + v[0] = p_src[src_index]; + v[1] = p_src[src_index + width_2]; + 
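+ /* rows 2 and 3 of the 4-row column follow; note this whole #if 0 branch is only kept as a
+  * scalar-LDS reference, the float4 variant in the #else branch below is what gets compiled */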
v[2] = p_src[src_index + 2 * width_2]; + v[3] = p_src[src_index + 3 * width_2]; + v_pack_b32_f16_2x2(v_pack[0], v_pack[2], v[0], v[1]); + v_pack_b32_f16_2x2(v_pack[1], v_pack[3], v[2], v[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_pack[1]; + smem[i_src_w * smem_stride + i_src_h + 32 * smem_stride] = v_pack[2]; + smem[i_src_w * smem_stride + i_src_h + 48 * smem_stride] = v_pack[3]; +#else + float v[4]; + float4 v_pack; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + + static_cast(g_src_w); + v[0] = p_src[src_index]; + v[1] = p_src[src_index + width_2]; + v[2] = p_src[src_index + 2 * width_2]; + v[3] = p_src[src_index + 3 * width_2]; + v_pack_b32_f16_2x2(v_pack.x, v_pack.z, v[0], v[1]); + v_pack_b32_f16_2x2(v_pack.y, v_pack.w, v[2], v[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack; +#endif + } + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + (i_dst_w << 1); + + if(g_dst_h < height_4 && g_dst_w < width) + { +#if 0 + float v[4]; + v[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + v[2] = smem[i_dst_w * smem_stride + i_dst_h + 32 * smem_stride]; + v[3] = smem[i_dst_w * smem_stride + i_dst_h + 48 * smem_stride]; + size_t dst_index = static_cast(dim_in) * width * height_4 + static_cast(g_dst_w) * height_4 + static_cast(g_dst_h); + p_dst[dst_index] = v[0]; + p_dst[dst_index + height_4] = v[1]; + p_dst[dst_index + 2 * height_4] = v[2]; + p_dst[dst_index + 3 * height_4] = v[3]; +#else + float4 v; + v = smem[i_dst_w * smem_stride + i_dst_h]; + size_t dst_index = static_cast(dim_in) * width * height_4 + + static_cast(g_dst_w) * height_4 + + static_cast(g_dst_h); + p_dst[dst_index] = make_float2(v.x, v.y); + p_dst[dst_index + height_4] = make_float2(v.z, v.w); +#endif + } + } +} + +template +inline __device__ void batched_transpose_32x64_pack_2x4_ediv_2x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x64_pack_2x4_ediv_2x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + //__shared__ float smem[64 * smem_stride]; + __shared__ float4 smem[16 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + float* p_src = reinterpret_cast(src); + + uint32_t height_2 = height >> 1; + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 63) >> 6; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 6) + (i_src_h << 1); + + __syncthreads(); + + float v_src[4]; + size_t src_index = static_cast(dim_in) * height * width_2 
+ + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width_2) + { + v_src[0] = p_src[src_index]; + v_src[1] = p_src[src_index + width_2]; + } + if((g_src_h + 32) < height && g_src_w < width_2) + { + v_src[2] = p_src[src_index + 32 * width_2]; + v_src[3] = p_src[src_index + 33 * width_2]; + } + + float4 v_pack; + v_pack_b32_f16_2x2(v_pack.x, v_pack.z, v_src[0], v_src[1]); + v_pack_b32_f16_2x2(v_pack.y, v_pack.w, v_src[2], v_src[3]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + (i_dst_w << 1); + + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + static_cast(g_dst_h); + + float4 v_dst = smem[i_dst_w * smem_stride + i_dst_h]; + if(g_dst_h < height_2 && g_dst_w < width) + { + p_dst[dst_index] = v_dst.x; + p_dst[dst_index + height_2] = v_dst.z; + } + if((g_dst_h + 16) < height_2 && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst.y; + p_dst[dst_index + height_2 + 16] = v_dst.w; + } + } +} + +template +inline __device__ void batched_transpose_32x64_pack_2x4_ediv_1x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_32x64_pack_2x4_ediv_1x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + //__shared__ float smem[64 * smem_stride]; + __shared__ float4 smem[16 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + ushort* p_src = src; + + uint32_t height_2 = height >> 1; + + uint32_t h_dim = (height + 63) >> 6; + uint32_t w_dim = (width + 31) >> 5; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 6) + (i_src_h << 1); + + __syncthreads(); + /* + * 4x2 -> 2x4 transpose: + * lo hi + * |0|1| lo |0|2| + * |2|3| -> hi |_|_| + * |4|5| lo |1|3| + * |6|7| hi |_|_| + */ + + ushort v_src[8]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_src[0] = p_src[src_index]; + v_src[2] = p_src[src_index + width]; + } + if(g_src_h < height && (g_src_w + 16) < width) + { + v_src[1] = p_src[src_index + 16]; + v_src[3] = p_src[src_index + width + 16]; + } + if((g_src_h + 32) < height && g_src_w < width) + { + v_src[4] = p_src[src_index + 32 * width]; + v_src[6] = p_src[src_index + 33 * width]; + } + if((g_src_h + 32) < height && (g_src_w + 16) < width) + { + v_src[5] = p_src[src_index + 32 * width + 16]; + v_src[7] = p_src[src_index + 33 * width + 16]; + } + + float4 v_pack; + v_pack_b32_f16_2x2_half_x0_half_x1( + v_pack.x, v_pack.z, v_src[0], v_src[1], v_src[2], v_src[3]); + 
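+ /* the call above packed the two rows at g_src_h / g_src_h + 1 into the .x/.z lanes; the call
+  * below does the same for the rows 32 further down, filling the .y/.w lanes of the float4 that
+  * is staged into LDS */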
v_pack_b32_f16_2x2_half_x0_half_x1( + v_pack.y, v_pack.w, v_src[4], v_src[5], v_src[6], v_src[7]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 5) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + static_cast(g_dst_h); + + float4 v_dst = smem[i_dst_w * smem_stride + i_dst_h]; + if(g_dst_h < height_2 && g_dst_w < width) + { + p_dst[dst_index] = v_dst.x; + } + if((g_dst_h + 16) < height_2 && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst.y; + } + if(g_dst_h < height_2 && (g_dst_w + 16) < width) + { + p_dst[dst_index + 16 * height_2] = v_dst.z; + } + if((g_dst_h + 16) < height_2 && (g_dst_w + 16) < width) + { + p_dst[dst_index + 16 * height_2 + 16] = v_dst.w; + } + } +} + +template +inline __device__ void batched_transpose_16x64_pack_1x4_ediv_1x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_16x64_pack_1x4_ediv_1x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[32 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + ushort* p_src = src; + + uint32_t height_2 = height >> 1; + + uint32_t h_dim = (height + 63) >> 6; + uint32_t w_dim = (width + 15) >> 4; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 6) + (i_src_h << 1); + + __syncthreads(); + /* + * 4x1 -> 1x4 transpose: + * lo hi + * |0| lo |0| + * |1| -> hi |_| + * |2| lo |1| + * |3| hi |_| + */ + + ushort v_src[4]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_src[0] = p_src[src_index]; + v_src[1] = p_src[src_index + width]; + } + if((g_src_h + 32) < height && g_src_w < width) + { + v_src[2] = p_src[src_index + 32 * width]; + v_src[3] = p_src[src_index + 33 * width]; + } + + ushort2 v_pack_tmp[2]; + v_pack_tmp[0] = make_ushort2(v_src[0], v_src[1]); + v_pack_tmp[1] = make_ushort2(v_src[2], v_src[3]); + + float v_pack[2]; + v_pack[0] = *reinterpret_cast(&v_pack_tmp[0]); + v_pack[1] = *reinterpret_cast(&v_pack_tmp[1]); + + smem[i_src_w * smem_stride + i_src_h] = v_pack[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_pack[1]; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 4) + i_dst_w; + + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + static_cast(g_dst_h); + + float v_dst[2]; + v_dst[0] = smem[i_dst_w * smem_stride + i_dst_h]; + 
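+ /* v_dst[1] below comes from the LDS slots written for source rows +32/+33 of this 64-row tile;
+  * it is stored 16 packed elements further along the height_2 axis, under its own bounds check */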
v_dst[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + if(g_dst_h < height_2 && g_dst_w < width) + { + p_dst[dst_index] = v_dst[0]; + } + if((g_dst_h + 16) < height_2 && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst[1]; + } + } +} + +template +inline __device__ void batched_transpose_64x16_pack_4x1_ediv_2x1(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_64x16_pack_4x1_ediv_2x1(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float smem[32 * smem_stride]; + + ushort* p_dst = dst; + float* p_src = reinterpret_cast(src); + + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 15) >> 4; + uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 4) + i_src_h; + + __syncthreads(); + float v_src[2]; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width_2) + { + v_src[0] = p_src[src_index]; + } + if(g_src_h < height && (g_src_w + 16) < width_2) + { + v_src[1] = p_src[src_index + 16]; + } + + smem[i_src_w * smem_stride + i_src_h] = v_src[0]; + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = v_src[1]; + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 6) + (i_dst_w << 1); + + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + float v_dst[2]; + v_dst[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v_dst[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + + ushort2 v_dst_buf[2]; + v_dst_buf[0] = *reinterpret_cast(&v_dst[0]); + v_dst_buf[1] = *reinterpret_cast(&v_dst[1]); + if(g_dst_h < height && g_dst_w < width) + { + p_dst[dst_index] = v_dst_buf[0].x; + p_dst[dst_index + height] = v_dst_buf[0].y; + } + if(g_dst_h < height && (g_dst_w + 32) < width) + { + p_dst[dst_index + 32 * height] = v_dst_buf[1].x; + p_dst[dst_index + 33 * height] = v_dst_buf[1].y; + } + } +} + +template +inline __device__ void batched_transpose_64x64_pack_4x4_ediv_4x4(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_64x64_pack_4x4_ediv_4x4(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + //__shared__ float 
smem[64 * smem_stride]; + __shared__ float4 smem[32 * smem_stride]; + + float2* p_dst = reinterpret_cast(dst); + float2* p_src = reinterpret_cast(src); + + uint32_t height_4 = height >> 2; + uint32_t width_4 = width >> 2; + + uint32_t h_dim = (height + 63) >> 6; + uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 4) + i_src_w; + uint32_t g_src_h = (dim_ih << 6) + (i_src_h << 2); + + __syncthreads(); + /* + * 4x2 -> 2x4 transpose: (0, 1, 2, 3 is in float2) + * lo hi + * 0 |_|_|_|_| lo |0|1|2|3| + * 1 |_|_|_|_| -> hi |_|_|_|_| + * 2 |_|_|_|_| lo | | | | | + * 3 |_|_|_|_| hi |_|_|_|_| + */ + + float2 v_src[4]; + size_t src_index = static_cast(dim_in) * height * width_4 + + static_cast(g_src_h) * width_4 + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width_4) + { + v_src[0] = p_src[src_index]; + v_src[1] = p_src[src_index + width_4]; + v_src[2] = p_src[src_index + 2 * width_4]; + v_src[3] = p_src[src_index + 3 * width_4]; + } + + float2 v_pack[4]; + v_pack_b32_f16_2x2(v_pack[0].x, v_pack[1].x, v_src[0].x, v_src[1].x); + v_pack_b32_f16_2x2(v_pack[2].x, v_pack[3].x, v_src[0].y, v_src[1].y); + v_pack_b32_f16_2x2(v_pack[0].y, v_pack[1].y, v_src[2].x, v_src[3].x); + v_pack_b32_f16_2x2(v_pack[2].y, v_pack[3].y, v_src[2].y, v_src[3].y); + + smem[i_src_w * smem_stride + i_src_h] = + make_float4(v_pack[0].x, v_pack[0].y, v_pack[1].x, v_pack[1].y); + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = + make_float4(v_pack[2].x, v_pack[2].y, v_pack[3].x, v_pack[3].y); + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 4) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 6) + (i_dst_w << 2); + + size_t dst_index = static_cast(dim_in) * width * height_4 + + static_cast(g_dst_w) * height_4 + static_cast(g_dst_h); + + float4 v_dst[2]; + v_dst[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v_dst[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + if(g_dst_h < height_4 && g_dst_w < width) + { + p_dst[dst_index] = make_float2(v_dst[0].x, v_dst[0].y); + p_dst[dst_index + height_4] = make_float2(v_dst[0].z, v_dst[0].w); + p_dst[dst_index + 2 * height_4] = make_float2(v_dst[1].x, v_dst[1].y); + p_dst[dst_index + 3 * height_4] = make_float2(v_dst[1].z, v_dst[1].w); + } + } +} + +template +inline __device__ void batched_transpose_64x64_pack_4x4_ediv_2x2(T* /*dst*/, + T* /*src*/, + uint32_t /*height*/, + uint32_t /*width*/, + uint32_t /*dim_stride*/, + uint32_t /*dim_total*/, + uint32_t /*magic_h*/, + uint32_t /*shift_h*/, + uint32_t /*magic_w*/, + uint32_t /*shift_w*/) +{ +} + +template <> +inline __device__ void batched_transpose_64x64_pack_4x4_ediv_2x2(ushort* dst, + ushort* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + constexpr auto smem_stride = 17; + __shared__ float4 smem[32 * smem_stride]; + + float* p_dst = reinterpret_cast(dst); + float* p_src = reinterpret_cast(src); + + uint32_t height_2 = height >> 1; + uint32_t width_2 = width >> 1; + + uint32_t h_dim = (height + 63) >> 6; 
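+ /* h_dim / w_dim: number of 64x64 tiles along each axis (ceil divide); partial edge tiles are
+  * handled by the per-element guards inside the loop rather than by a separate kernel */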
+ uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t i_src_w = threadIdx.x & 15; + uint32_t i_src_h = threadIdx.x >> 4; + uint32_t g_src_w = (dim_iw << 5) + i_src_w; + uint32_t g_src_h = (dim_ih << 6) + (i_src_h << 1); + + __syncthreads(); + /* + * 4x4 -> 4x4 transpose: (0, 1, 2, 3 is in float, a, b, c, d is in float2) + * lo hi + * |0|_|1|_| lo |a|b|c|d| + * |2|_|3|_| -> hi |_|_|_|_| + * |4|_|5|_| lo | | | | | + * |6|_|7|_| hi |_|_|_|_| + */ + + float v_src[8]; + size_t src_index = static_cast(dim_in) * height * width_2 + + static_cast(g_src_h) * width_2 + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width_2) + { + v_src[0] = p_src[src_index]; + v_src[2] = p_src[src_index + width_2]; + } + if(g_src_h < height && (g_src_w + 16) < width_2) + { + v_src[1] = p_src[src_index + 16]; + v_src[3] = p_src[src_index + width_2 + 16]; + } + if((g_src_h + 32) < height && g_src_w < width_2) + { + v_src[4] = p_src[src_index + 32 * width_2]; + v_src[6] = p_src[src_index + 33 * width_2]; + } + if((g_src_h + 32) < height && (g_src_w + 16) < width_2) + { + v_src[5] = p_src[src_index + 32 * width_2 + 16]; + v_src[7] = p_src[src_index + 33 * width_2 + 16]; + } + + float2 v_pack[4]; + v_pack_b32_f16_2x2(v_pack[0].x, v_pack[1].x, v_src[0], v_src[2]); + v_pack_b32_f16_2x2(v_pack[2].x, v_pack[3].x, v_src[1], v_src[3]); + v_pack_b32_f16_2x2(v_pack[0].y, v_pack[1].y, v_src[4], v_src[6]); + v_pack_b32_f16_2x2(v_pack[2].y, v_pack[3].y, v_src[5], v_src[7]); + + smem[i_src_w * smem_stride + i_src_h] = + make_float4(v_pack[0].x, v_pack[0].y, v_pack[1].x, v_pack[1].y); + smem[i_src_w * smem_stride + i_src_h + 16 * smem_stride] = + make_float4(v_pack[2].x, v_pack[2].y, v_pack[3].x, v_pack[3].y); + + __syncthreads(); + + uint32_t i_dst_h = threadIdx.x & 15; + uint32_t i_dst_w = threadIdx.x >> 4; + uint32_t g_dst_h = (dim_ih << 5) + i_dst_h; + uint32_t g_dst_w = (dim_iw << 6) + (i_dst_w << 1); + + size_t dst_index = static_cast(dim_in) * width * height_2 + + static_cast(g_dst_w) * height_2 + static_cast(g_dst_h); + + float4 v_dst[2]; + v_dst[0] = smem[i_dst_w * smem_stride + i_dst_h]; + v_dst[1] = smem[i_dst_w * smem_stride + i_dst_h + 16 * smem_stride]; + if(g_dst_h < height_2 && g_dst_w < width) + { + p_dst[dst_index] = v_dst[0].x; + p_dst[dst_index + height_2] = v_dst[0].z; + } + if((g_dst_h + 16) < height_2 && g_dst_w < width) + { + p_dst[dst_index + 16] = v_dst[0].y; + p_dst[dst_index + height_2 + 16] = v_dst[0].w; + } + if(g_dst_h < height_2 && (g_dst_w + 32) < width) + { + p_dst[dst_index + 32 * height_2] = v_dst[1].x; + p_dst[dst_index + 33 * height_2] = v_dst[1].z; + } + if((g_dst_h + 16) < height_2 && (g_dst_w + 32) < width) + { + p_dst[dst_index + 32 * height_2 + 16] = v_dst[1].y; + p_dst[dst_index + 33 * height_2 + 16] = v_dst[1].w; + } + } +} + +template +inline __device__ void batched_transpose_4x256(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + uint32_t h_dim = (height + 255) >> 8; + uint32_t w_dim = (width + 3) >> 2; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, 
shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t g_src_w = (dim_iw << 2); + uint32_t g_src_h = (dim_ih << 8) + threadIdx.x; + + T v_buf[4]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_buf[0] = src[src_index]; + } + if(g_src_h < height && (g_src_w + 1) < width) + { + v_buf[1] = src[src_index + 1]; + } + if(g_src_h < height && (g_src_w + 2) < width) + { + v_buf[2] = src[src_index + 2]; + } + if(g_src_h < height && (g_src_w + 3) < width) + { + v_buf[3] = src[src_index + 3]; + } + + uint32_t g_dst_h = (dim_ih << 8) + threadIdx.x; + uint32_t g_dst_w = (dim_iw << 2); + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_buf[0]; + } + if(g_dst_h < height && (g_dst_w + 1) < width) + { + dst[dst_index + height] = v_buf[1]; + } + if(g_dst_h < height && (g_dst_w + 2) < width) + { + dst[dst_index + 2 * height] = v_buf[2]; + } + if(g_dst_h < height && (g_dst_w + 3) < width) + { + dst[dst_index + 3 * height] = v_buf[3]; + } + } +} + +template +inline __device__ void batched_transpose_256x4(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + uint32_t h_dim = (height + 3) >> 2; + uint32_t w_dim = (width + 255) >> 8; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t g_src_w = (dim_iw << 8) + threadIdx.x; + uint32_t g_src_h = (dim_ih << 2); + + T v_buf[4]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_buf[0] = src[src_index]; + } + if((g_src_h + 1) < height && g_src_w < width) + { + v_buf[1] = src[src_index + width]; + } + if((g_src_h + 2) < height && g_src_w < width) + { + v_buf[2] = src[src_index + 2 * width]; + } + if((g_src_h + 3) < height && g_src_w < width) + { + v_buf[3] = src[src_index + 3 * width]; + } + + uint32_t g_dst_h = (dim_ih << 2); + uint32_t g_dst_w = (dim_iw << 8) + threadIdx.x; + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_buf[0]; + } + if((g_dst_h + 1) < height && g_dst_w < width) + { + dst[dst_index + 1] = v_buf[1]; + } + if((g_dst_h + 2) < height && g_dst_w < width) + { + dst[dst_index + 2] = v_buf[2]; + } + if((g_dst_h + 3) < height && g_dst_w < width) + { + dst[dst_index + 3] = v_buf[3]; + } + } +} + +template +inline __device__ void batched_transpose_4x128(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + uint32_t h_dim = (height + 127) >> 7; + uint32_t w_dim = (width + 3) >> 2; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, 
shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t g_src_w = (dim_iw << 2) + (threadIdx.x >> 7); + uint32_t g_src_h = (dim_ih << 7) + (threadIdx.x & 127); + + T v_buf[2]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_buf[0] = src[src_index]; + } + if(g_src_h < height && (g_src_w + 2) < width) + { + v_buf[1] = src[src_index + 2]; + } + + uint32_t g_dst_h = (dim_ih << 7) + (threadIdx.x & 127); + uint32_t g_dst_w = (dim_iw << 2) + (threadIdx.x >> 7); + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_buf[0]; + } + if(g_dst_h < height && (g_dst_w + 2) < width) + { + dst[dst_index + 2 * height] = v_buf[1]; + } + } +} + +template +inline __device__ void batched_transpose_128x4(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + uint32_t h_dim = (height + 3) >> 2; + uint32_t w_dim = (width + 127) >> 7; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t g_src_w = (dim_iw << 7) + (threadIdx.x & 127); + uint32_t g_src_h = (dim_ih << 2) + (threadIdx.x >> 7); + + T v_buf[2]; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_buf[0] = src[src_index]; + } + if((g_src_h + 2) < height && g_src_w < width) + { + v_buf[1] = src[src_index + 2 * width]; + } + + uint32_t g_dst_h = (dim_ih << 2) + (threadIdx.x >> 7); + uint32_t g_dst_w = (dim_iw << 7) + (threadIdx.x & 127); + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_buf[0]; + } + if((g_dst_h + 2) < height && g_dst_w < width) + { + dst[dst_index + 2] = v_buf[1]; + } + } +} + +template +inline __device__ void batched_transpose_4x64(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + uint32_t h_dim = (height + 63) >> 6; + uint32_t w_dim = (width + 3) >> 2; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t g_src_w = (dim_iw << 2) + (threadIdx.x >> 6); + uint32_t g_src_h = (dim_ih << 6) + (threadIdx.x & 63); + + T v_buf; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_buf = src[src_index]; + } + + uint32_t g_dst_h = (dim_ih << 6) + (threadIdx.x & 63); + uint32_t g_dst_w = (dim_iw << 2) + (threadIdx.x >> 6); + size_t 
dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_buf; + } + } +} + +template +inline __device__ void batched_transpose_64x4(T* dst, + T* src, + uint32_t height, + uint32_t width, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_h, + uint32_t shift_h, + uint32_t magic_w, + uint32_t shift_w) +{ + uint32_t h_dim = (height + 3) >> 2; + uint32_t w_dim = (width + 63) >> 6; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + uint32_t dim_ih_tmp = magic_div_u32(dim_id, magic_w, shift_w); + uint32_t dim_iw = dim_id - dim_ih_tmp * w_dim; + uint32_t dim_in = magic_div_u32(dim_ih_tmp, magic_h, shift_h); + uint32_t dim_ih = dim_ih_tmp - dim_in * h_dim; + + uint32_t g_src_w = (dim_iw << 6) + (threadIdx.x & 63); + uint32_t g_src_h = (dim_ih << 2) + (threadIdx.x >> 6); + + T v_buf; + size_t src_index = static_cast(dim_in) * height * width + + static_cast(g_src_h) * width + static_cast(g_src_w); + if(g_src_h < height && g_src_w < width) + { + v_buf = src[src_index]; + } + + uint32_t g_dst_h = (dim_ih << 2) + (threadIdx.x >> 6); + uint32_t g_dst_w = (dim_iw << 6) + (threadIdx.x & 63); + size_t dst_index = static_cast(dim_in) * width * height + + static_cast(g_dst_w) * height + static_cast(g_dst_h); + + if(g_dst_h < height && g_dst_w < width) + { + dst[dst_index] = v_buf; + } + } +} + +#define DEFINE_BATCHED_TRANSPOSE_KERNEL( \ + tile_trait, accept_data_type, cast_data_type, lb_threads_per_block, lb_blocks_per_cu) \ + extern "C" __global__ void __launch_bounds__(lb_threads_per_block, lb_blocks_per_cu) \ + batched_transpose_##tile_trait##_##accept_data_type(void* dst, \ + void* src, \ + uint32_t height, \ + uint32_t width, \ + uint32_t dim_stride, \ + uint32_t dim_total, \ + uint32_t magic_h, \ + uint32_t shift_h, \ + uint32_t magic_w, \ + uint32_t shift_w) \ + { \ + batched_transpose_##tile_trait(reinterpret_cast(dst), \ + reinterpret_cast(src), \ + height, \ + width, \ + dim_stride, \ + dim_total, \ + magic_h, \ + shift_h, \ + magic_w, \ + shift_w); \ + } + +DEFINE_BATCHED_TRANSPOSE_KERNEL(16x16, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(16x16, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(16x16, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(32x16, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(32x16, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(32x16, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(16x32, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(16x32, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(16x32, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(32x32, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(32x32, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(32x32, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x256, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x256, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x256, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(256x4, dword, float, 256, 
BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(256x4, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(256x4, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x128, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x128, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x128, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(128x4, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(128x4, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(128x4, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x64, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x64, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(4x64, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL(64x4, dword, float, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(64x4, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL(64x4, byte, uchar, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x32_pack_2x2_ediv_2x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x32_pack_2x2_ediv_1x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x32_pack_2x2_ediv_2x1, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x32_pack_2x2_ediv_1x1, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 64x32_pack_4x2_ediv_4x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 64x32_pack_4x2_ediv_2x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 64x32_pack_4x2_ediv_2x1, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x64_pack_2x4_ediv_2x4, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x64_pack_2x4_ediv_2x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 32x64_pack_2x4_ediv_1x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 16x64_pack_1x4_ediv_1x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 64x16_pack_4x1_ediv_2x1, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) + +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 64x64_pack_4x4_ediv_4x4, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) +DEFINE_BATCHED_TRANSPOSE_KERNEL( + 64x64_pack_4x4_ediv_2x2, half, ushort, 256, BATCHED_TRANSPOSE_OCCUPANCY) diff --git a/src/nogpu/handle.cpp b/src/nogpu/handle.cpp index 4faed48dd2..f86dc405ba 100644 --- a/src/nogpu/handle.cpp +++ b/src/nogpu/handle.cpp @@ -275,13 +275,13 @@ std::ostream& Handle::Print(std::ostream& os) const return os; } -shared<Data_t> Handle::CreateSubBuffer(Data_t data, std::size_t offset, std::size_t) +shared<Data_t> Handle::CreateSubBuffer(Data_t data, std::size_t offset, std::size_t) const { auto cdata = reinterpret_cast<char*>(data); return {cdata + offset, null_deleter{}}; } -shared<ConstData_t> Handle::CreateSubBuffer(ConstData_t data, std::size_t offset, std::size_t) +shared<ConstData_t> Handle::CreateSubBuffer(ConstData_t data, std::size_t offset, std::size_t) const { auto cdata = reinterpret_cast<const char*>(data); return {cdata + offset, null_deleter{}}; diff --git 
a/src/ocl/handleocl.cpp b/src/ocl/handleocl.cpp index 6678e60c10..902bc560c7 100644 --- a/src/ocl/handleocl.cpp +++ b/src/ocl/handleocl.cpp @@ -550,7 +550,7 @@ void Handle::Copy(ConstData_t src, Data_t dest, std::size_t size) const } } -shared Handle::CreateSubBuffer(Data_t data, std::size_t offset, std::size_t size) +shared Handle::CreateSubBuffer(Data_t data, std::size_t offset, std::size_t size) const { MIOPEN_HANDLE_LOCK struct region diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp index b64f127dc0..881fd4ac05 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp @@ -30,6 +30,7 @@ #include #include #include +#include MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS_NHWC) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) @@ -880,9 +881,6 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(const ConvolutionC if(!ctx.rmv.IsV3()) return false; - if(!ctx.IsLayoutNHWC()) - return false; - const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) @@ -893,16 +891,42 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(const ConvolutionC size_t ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetWorkspaceSize(const ConvolutionContext& ctx) const { - if(ctx.IsFp32()) - return 0; - - // FP16 or BF16 - const auto& hi = ctx.out_height; - const auto& wi = ctx.out_width; - const auto& n = ctx.batch_sz; - const auto& c = ctx.n_outputs; - return miopen::GetTypeSize(miopenFloat) // The intermediate output of the 1st kernel is FP32. - * n * c * hi * wi; + const auto& hi = ctx.out_height; + const auto& wi = ctx.out_width; + const auto& n = ctx.batch_sz; + const auto& k = ctx.n_inputs; + const auto& c = ctx.n_outputs; + const auto& ho = ctx.in_height; + const auto& wo = ctx.in_width; + const auto& y = ctx.kernel_size_h; + const auto& x = ctx.kernel_size_w; + const auto& group = ctx.group_counts; + const auto is_nchw = ctx.IsLayoutDefault(); + size_t workspace_size = 0; + if(is_nchw) + { + TransposeSolutionNhwc2Default trans_input(ctx, ctx.out_data_type, n, c, hi, wi); + TransposeSolutionDefault2Nhwc trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo); + if(!trans_input.IsSkippable()) + workspace_size += trans_input.GetSize(); + if(!trans_weight.IsSkippable()) + workspace_size += trans_weight.GetSize(); + if(!trans_output.IsSkippable()) + workspace_size += trans_output.GetSize(); + } + + if(!ctx.IsFp32()) + workspace_size += miopen::GetTypeSize(miopenFloat) // The intermediate output of the 1st + // kernel is FP32, when using FP32 atomic + * n * c * hi * wi; + + return workspace_size; } ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( @@ -938,6 +962,8 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( const auto isGfx90aFp16altSupport = (ctx.GetStream().GetDeviceName() == "gfx90a") && ctx.conv_problem.IsFp16(); + const auto is_nchw = ctx.IsLayoutDefault(); + result.construction_params.push_back(kernel); std::ostringstream options; GenerateClangDefsym(options, "ROCM_METADATA_VERSION", ctx.rmv.UseV3() ? 
5 : 4); @@ -946,6 +972,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( if(isGfx90aFp16altSupport) GenerateClangDefsym(opts_0, "igemm_bwd_fp16_alt_impl", 0); result.construction_params[0].comp_options = opts_0.str(); + std::ostringstream msg; if(isGfx90aFp16altSupport) { @@ -953,9 +980,53 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution( std::ostringstream opts_1(options.str(), std::ios_base::ate); GenerateClangDefsym(opts_1, "igemm_bwd_fp16_alt_impl", 1); result.construction_params[1].comp_options = opts_1.str(); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", fp16_alt:" << ctx.conv_problem.GetConv().attribute.gfx90aFp16alt.GetBwd(); + } + + if(is_nchw) + { + const auto& hi = ctx.out_height; + const auto& wi = ctx.out_width; + const auto& n = ctx.batch_sz; + const auto& k = ctx.n_inputs; + const auto& c = ctx.n_outputs; + const auto& ho = ctx.in_height; + const auto& wo = ctx.in_width; + const auto& y = ctx.kernel_size_h; + const auto& x = ctx.kernel_size_w; + const auto& group = ctx.group_counts; + + TransposeSolutionNhwc2Default trans_input(ctx, ctx.out_data_type, n, c, hi, wi); + TransposeSolutionDefault2Nhwc trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo); + + if(!trans_input.IsSkippable()) + { + result.construction_params.push_back(trans_input.GetKernel()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", inp trans:" << trans_input.GetKernelName(); + } + if(!trans_weight.IsSkippable()) + { + result.construction_params.push_back(trans_weight.GetKernel()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", wei trans:" << trans_weight.GetKernelName(); + } + if(!trans_output.IsSkippable()) + { + result.construction_params.push_back(trans_output.GetKernel()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", out trans:" << trans_output.GetKernelName(); + } } - MIOPEN_LOG_I2("ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC: " + config.ToString()); + MIOPEN_LOG_I2(SolverDbId(*this) << ": " << config.ToString() << msg.str()); result.invoker_factory = conv::MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(ctx, config); diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp index 54add6b943..266055c8ee 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp @@ -30,6 +30,7 @@ #include #include #include +#include MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_FWD_GTC_XDLOPS_NHWC) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) @@ -738,16 +739,44 @@ ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::Search(const ConvolutionContext& ctx size_t ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetWorkspaceSize(const ConvolutionContext& ctx) const { - if(ctx.IsFp32()) - return 0; - - // FP16 or BF16 - const auto& n = ctx.batch_sz; - const auto& k = ctx.n_outputs; - const auto& ho = ctx.out_height; - const auto& wo = ctx.out_width; - return miopen::GetTypeSize(miopenFloat) // The intermediate output of the 1st kernel is FP32. 
- * n * k * ho * wo; + const auto& hi = ctx.in_height; + const auto& wi = ctx.in_width; + const auto& n = ctx.batch_sz; + const auto& k = ctx.n_outputs; + const auto& c = ctx.n_inputs; + const auto& ho = ctx.out_height; + const auto& wo = ctx.out_width; + const auto& y = ctx.kernel_size_h; + const auto& x = ctx.kernel_size_w; + const auto& group = ctx.group_counts; + const auto is_nchw = ctx.IsLayoutDefault(); + size_t workspace_size = 0; + if(is_nchw) + { + + TransposeSolutionDefault2Nhwc trans_input(ctx, ctx.in_data_type, n, c, hi, wi); + TransposeSolutionDefault2Nhwc trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionNhwc2Default trans_output(ctx, ctx.out_data_type, n, k, ho, wo); + + if(!trans_input.IsSkippable()) + workspace_size += trans_input.GetSize(); + if(!trans_weight.IsSkippable()) + workspace_size += trans_weight.GetSize(); + if(!trans_output.IsSkippable()) + workspace_size += trans_output.GetSize(); + } + + if(!ctx.IsFp32()) + workspace_size += miopen::GetTypeSize(miopenFloat) // The intermediate output of the 1st + // kernel is FP32, when using FP32 atomic + * n * k * ho * wo; + + return workspace_size; } bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(const ConvolutionContext& ctx) const @@ -774,9 +803,6 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(const ConvolutionC if(!ctx.rmv.IsV3()) return false; - if(!ctx.IsLayoutNHWC()) - return false; - const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) @@ -816,6 +842,8 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution( const auto isGfx90aFp16altSupport = (ctx.GetStream().GetDeviceName() == "gfx90a") && ctx.conv_problem.IsFp16(); + const auto is_nchw = ctx.IsLayoutDefault(); + result.construction_params.push_back(kernel); std::ostringstream options; GenerateClangDefsym(options, "ROCM_METADATA_VERSION", ctx.rmv.UseV3() ? 
5 : 4); @@ -825,15 +853,61 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution( GenerateClangDefsym(opts_0, "igemm_fwd_fp16_alt_impl", 0); result.construction_params[0].comp_options = opts_0.str(); + std::ostringstream msg; + if(isGfx90aFp16altSupport) { result.construction_params.push_back(kernel); std::ostringstream opts_1(options.str(), std::ios_base::ate); GenerateClangDefsym(opts_1, "igemm_fwd_fp16_alt_impl", 1); result.construction_params[1].comp_options = opts_1.str(); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", fp16_alt:" << ctx.conv_problem.GetConv().attribute.gfx90aFp16alt.GetFwd(); + } + + if(is_nchw) + { + const auto& hi = ctx.in_height; + const auto& wi = ctx.in_width; + const auto& n = ctx.batch_sz; + const auto& k = ctx.n_outputs; + const auto& c = ctx.n_inputs; + const auto& ho = ctx.out_height; + const auto& wo = ctx.out_width; + const auto& y = ctx.kernel_size_h; + const auto& x = ctx.kernel_size_w; + const auto& group = ctx.group_counts; + + TransposeSolutionDefault2Nhwc trans_input(ctx, ctx.in_data_type, n, c, hi, wi); + TransposeSolutionDefault2Nhwc trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionNhwc2Default trans_output(ctx, ctx.out_data_type, n, k, ho, wo); + + if(!trans_input.IsSkippable()) + { + result.construction_params.push_back(trans_input.GetKernel()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", inp trans:" << trans_input.GetKernelName(); + } + if(!trans_weight.IsSkippable()) + { + result.construction_params.push_back(trans_weight.GetKernel()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", wei trans:" << trans_weight.GetKernelName(); + } + if(!trans_output.IsSkippable()) + { + result.construction_params.push_back(trans_output.GetKernel()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", out trans:" << trans_output.GetKernelName(); + } } - MIOPEN_LOG_I2("ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC: " + config.ToString()); + MIOPEN_LOG_I2(SolverDbId(*this) << ": " << config.ToString() << msg.str()); result.invoker_factory = conv::MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory(ctx, config); return result; diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp index e27fb1ed90..7b73d811df 100644 --- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp +++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp @@ -33,6 +33,7 @@ #include #include #include +#include MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_WRW_GTC_XDLOPS_NHWC) MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_PK_ATOMIC_ADD_FP16) @@ -804,9 +805,6 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(const ConvolutionC if(!ctx.rmv.IsV3()) return false; - if(!ctx.IsLayoutNHWC()) - return false; - const auto target = ctx.GetStream().GetTargetProperties(); if(target.Xnack() && *target.Xnack()) return false; // NOLINT (readability-simplify-boolean-expr) @@ -865,19 +863,41 @@ ComputeDynamicIGemmWrwKernelArgsNHWC(const conv::ProblemDescription& conv_proble size_t ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetWorkspaceSize(const ConvolutionContext& ctx) const { - if(ctx.IsFp32()) - return 0; - else + const auto& hi = ctx.out_height; + const auto& wi = ctx.out_width; + const auto& n = ctx.batch_sz; + const auto& k = ctx.n_inputs; + const auto& c = ctx.n_outputs; + const auto& ho = ctx.in_height; + const auto& wo = ctx.in_width; + const auto& y = 
ctx.kernel_size_h; + const auto& x = ctx.kernel_size_w; + const auto& group = ctx.group_counts; + const auto is_nchw = ctx.IsLayoutDefault(); + size_t workspace_size = 0; + if(is_nchw) { - const auto k = ctx.n_inputs; - const auto c = ctx.n_outputs; - const auto y = ctx.kernel_size_h; - const auto x = ctx.kernel_size_w; - const auto ngroups = ctx.group_counts; - - return static_cast(ngroups) * (k / ngroups) * (c / ngroups) * y * x * - miopen::GetTypeSize(miopenFloat); + TransposeSolutionDefault2Nhwc trans_input(ctx, ctx.out_data_type, n, c, hi, wi); + TransposeSolutionNhwc2Default trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo); + if(!trans_input.IsSkippable()) + workspace_size += trans_input.GetSize(); + if(!trans_weight.IsSkippable()) + workspace_size += trans_weight.GetSize(); + if(!trans_output.IsSkippable()) + workspace_size += trans_output.GetSize(); } + + if(!ctx.IsFp32()) + workspace_size += miopen::GetTypeSize(miopenFloat) // The intermediate output of the 1st + // kernel is FP32, when using FP32 atomic + * (k / group) * c * y * x; + return workspace_size; } ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( @@ -932,6 +952,19 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( const auto& conv_problem = ctx.conv_problem; const auto isFp16 = conv_problem.IsFp16(); const auto isGfx90aFp16altSupport = (ctx.GetStream().GetDeviceName() == "gfx90a") && isFp16; + const bool need_cast = (conv_problem.IsBfp16() && gemm_k_global_splits >= 1) || (isFp16 && gemm_k_global_splits >= 1 && (config.tensor_b_thread_lengths[3] == 1 || config.vector_store == 1)); + + const auto& hi = ctx.out_height; + const auto& wi = ctx.out_width; + const auto& n = ctx.batch_sz; + const auto& k = ctx.n_inputs; + const auto& c = ctx.n_outputs; + const auto& ho = ctx.in_height; + const auto& wo = ctx.in_width; + const auto& y = ctx.kernel_size_h; + const auto& x = ctx.kernel_size_w; + const auto& group = ctx.group_counts; + const auto is_nchw = ctx.IsLayoutDefault(); result.construction_params.push_back(kernel); // Intentionally without options. std::ostringstream options; // Common options for both kernels. @@ -941,6 +974,7 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( if(isGfx90aFp16altSupport) GenerateClangDefsym(opts_0, "igemm_wrw_fp16_alt_impl", 0); result.construction_params[0].comp_options = opts_0.str(); + std::ostringstream msg; if(isGfx90aFp16altSupport) { @@ -948,26 +982,100 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( std::ostringstream opts_1(options.str(), std::ios_base::ate); // Options for alt kernel. 
GenerateClangDefsym(opts_1, "igemm_wrw_fp16_alt_impl", 1); result.construction_params[1].comp_options = opts_1.str(); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", fp16_alt:" <> opArgsTrans; + size_t trans_input_offset = 0; + size_t trans_input_size = 0; + + size_t trans_weight_offset = 0; + size_t trans_weight_size = 0; - if((conv_problem.IsBfp16() && gemm_k_global_splits >= 1) || (isFp16 && gemm_k_global_splits >= 1 && (config.tensor_b_thread_lengths[3] == 1 || config.vector_store == 1))) + size_t trans_output_offset = 0; + size_t trans_output_size = 0; + + bool trans_input_skippable = false; + bool trans_weight_skippable = false; + bool trans_output_skippable = false; + + int trans_input_idx = -1; + int trans_weight_idx = -1; + int trans_output_idx = -1; + if(is_nchw) + { + TransposeSolutionDefault2Nhwc trans_input(ctx, ctx.out_data_type, n, c, hi, wi); + TransposeSolutionNhwc2Default trans_weight(ctx, + ctx.weights_data_type, + k, + c / group, + y, + x); // group * k_per_group as batch for weight + TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo); + + trans_input_skippable = trans_input.IsSkippable(); + trans_weight_skippable = trans_weight.IsSkippable(); + trans_output_skippable = trans_output.IsSkippable(); + + if(!trans_input_skippable){ + result.construction_params.push_back(trans_input.GetKernel()); + opArgsTrans.emplace_back(trans_input.GetKernelArg()); + if(miopen::IsLogging(LoggingLevel::Info2)) + msg << ", inp trans:"< {}; + + if(need_cast) { - TensorDescriptor workspaceDesc(miopenFloat, - conv_problem.GetWeights().GetLengths(), - conv_problem.GetWeights().GetStrides()); result.invoker_factory = [=](const std::vector& kernels) mutable { return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) mutable { decltype(auto) wrw_invoke_params = primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; - const auto k = handle.Run( + const auto ker = handle.Run( kernels[(isGfx90aFp16altSupport && wrw_invoke_params.gfx90aFp16alt) ? 1 : 0]); const auto& workSpace = wrw_invoke_params.workSpace; const auto& workSpaceSize = wrw_invoke_params.workSpaceSize; @@ -978,28 +1086,66 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( MIOPEN_THROW("Not enough workspace has been provided for " "ConvAsmImplicitGemmGTCDynamicWrwXdlops with fp16 and atomic " "add."); - - SetTensor(handle, workspaceDesc, workSpace, &zero); + auto trans_input_buf = trans_input_size== 0 ?null_buf : handle.CreateSubBuffer( + workSpace, trans_input_offset, trans_input_size); + auto trans_weight_buf = trans_weight_size==0 ? null_buf : handle.CreateSubBuffer( + workSpace, trans_weight_offset, trans_weight_size); + auto trans_output_buf = trans_output_size ==0 ? null_buf : handle.CreateSubBuffer( + workSpace, trans_output_offset, trans_output_size); + auto cast_buf = cast_size == 0 ? 
null_buf : handle.CreateSubBuffer( + workSpace, cast_offset, cast_size); + + SetTensor(handle, cast_desc, cast_buf.get(), &zero); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); - opArgs[0] = OpKernelArg(tensors.x); - opArgs[1] = OpKernelArg(workSpace); - opArgs[2] = OpKernelArg(tensors.dy); + if(is_nchw) + { + if(!trans_input_skippable){ + auto& karg_input = opArgsTrans[trans_input_idx]; + karg_input[0] = OpKernelArg(trans_input_buf.get()); + karg_input[1] = OpKernelArg(tensors.x); + handle.Run(kernels[kID_trans_start + trans_input_idx])(karg_input); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(!trans_output_skippable){ + auto& karg_output = opArgsTrans[trans_output_idx]; + karg_output[0] = OpKernelArg(trans_output_buf.get()); + karg_output[1] = OpKernelArg(tensors.dy); + handle.Run(kernels[kID_trans_start + trans_output_idx])(karg_output); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + } + + opArgs[0] = (is_nchw && !trans_input_skippable) ? OpKernelArg(trans_input_buf.get()) : OpKernelArg(tensors.x); + opArgs[1] = OpKernelArg(cast_buf.get()); + opArgs[2] = (is_nchw && !trans_output_skippable) ? OpKernelArg(trans_output_buf.get()) : OpKernelArg(tensors.dy); - k(opArgs); + ker(opArgs); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); CastTensor(handle, &lowp_quant, - workspaceDesc, - workSpace, + cast_desc, + cast_buf.get(), tensors.dwDesc, - tensors.dw, + (is_nchw && !trans_weight_skippable) ? trans_weight_buf.get() : tensors.dw, 0, 0); + if(is_nchw && !trans_weight_skippable) + { + auto& karg_weight = opArgsTrans[trans_weight_idx]; + karg_weight[0] = OpKernelArg(tensors.dw); + karg_weight[1] = OpKernelArg(trans_weight_buf.get()); + handle.Run(kernels[kID_trans_start + trans_weight_idx])(karg_weight); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); @@ -1018,23 +1164,65 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution( decltype(auto) wrw_invoke_params = primitive_parameters.CastTo(); const auto& tensors = wrw_invoke_params.tensors; - const auto k = handle.Run( + const auto ker = handle.Run( kernels[(isGfx90aFp16altSupport && wrw_invoke_params.gfx90aFp16alt) ? 1 : 0]); + const auto& workSpace = wrw_invoke_params.workSpace; float elapsed = 0; float zero = 0.f; - opArgs[0] = OpKernelArg(tensors.x); - opArgs[1] = OpKernelArg(tensors.dw); - opArgs[2] = OpKernelArg(tensors.dy); + auto trans_input_buf = trans_input_size== 0 ?null_buf : handle.CreateSubBuffer( + workSpace, trans_input_offset, trans_input_size); + auto trans_weight_buf = trans_weight_size==0 ? null_buf : handle.CreateSubBuffer( + workSpace, trans_weight_offset, trans_weight_size); + auto trans_output_buf = trans_output_size ==0 ? null_buf : handle.CreateSubBuffer( + workSpace, trans_output_offset, trans_output_size); + auto cast_buf = cast_size == 0 ? null_buf : handle.CreateSubBuffer( + workSpace, cast_offset, cast_size); - SetTensor(handle, tensors.dwDesc, tensors.dw, &zero); + opArgs[0] = (is_nchw && !trans_input_skippable) ? OpKernelArg(trans_input_buf.get()) : OpKernelArg(tensors.x); + opArgs[1] = (is_nchw && !trans_weight_skippable)? OpKernelArg(trans_weight_buf.get()) : OpKernelArg(tensors.dw); + opArgs[2] = (is_nchw && !trans_output_skippable) ? OpKernelArg(trans_output_buf.get()) : OpKernelArg(tensors.dy); + + SetTensor(handle, tensors.dwDesc, (is_nchw && !trans_weight_skippable) ? 
trans_weight_buf.get() : tensors.dw, &zero); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); - k(opArgs); + if(is_nchw) + { + if(!trans_input_skippable){ + + auto& karg_input = opArgsTrans[trans_input_idx]; + karg_input[0] = OpKernelArg(trans_input_buf.get()); + karg_input[1] = OpKernelArg(tensors.x); + handle.Run(kernels[kID_trans_start + trans_input_idx])(karg_input); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(!trans_output_skippable){ + + auto& karg_output = opArgsTrans[trans_output_idx]; + karg_output[0] = OpKernelArg(trans_output_buf.get()); + karg_output[1] = OpKernelArg(tensors.dy); + handle.Run(kernels[kID_trans_start + trans_output_idx])(karg_output); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + } + + ker(opArgs); if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); + if(is_nchw && !trans_weight_skippable) + { + auto& karg_weight = opArgsTrans[trans_weight_idx]; + karg_weight[0] = OpKernelArg(tensors.dw); + karg_weight[1] = OpKernelArg(trans_weight_buf.get()); + handle.Run(kernels[kID_trans_start + trans_weight_idx])(karg_weight); + if(handle.IsProfilingEnabled()) + elapsed += handle.GetKernelTime(); + } + if(handle.IsProfilingEnabled()) { handle.ResetKernelTime(); diff --git a/test/gpu_nchw_nhwc_transpose.cpp b/test/gpu_nchw_nhwc_transpose.cpp new file mode 100644 index 0000000000..43004f8163 --- /dev/null +++ b/test/gpu_nchw_nhwc_transpose.cpp @@ -0,0 +1,435 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "test.hpp" +#include "driver.hpp" +#include "random.hpp" + +template <> +struct miopen_type : std::integral_constant +{ +}; + +template <> +struct miopen_type : std::integral_constant +{ +}; + +template +void cpu_nchw2nhwc(T* dst, T* src, uint64_t N, uint64_t C, uint64_t H, uint64_t W) +{ + for(uint64_t i_n = 0; i_n < N; i_n++) + { + for(uint64_t i_h = 0; i_h < H; i_h++) + { + for(uint64_t i_w = 0; i_w < W; i_w++) + { + for(uint64_t i_c = 0; i_c < C; i_c++) + { + uint64_t idx_nhwc = i_n * H * W * C + i_h * W * C + i_w * C + i_c; + uint64_t idx_nchw = i_n * C * H * W + i_c * H * W + i_h * W + i_w; + dst[idx_nhwc] = src[idx_nchw]; + } + } + } + } +} + +template +void cpu_nhwc2nchw(T* dst, T* src, uint64_t N, uint64_t C, uint64_t H, uint64_t W) +{ + for(uint64_t i_n = 0; i_n < N; i_n++) + { + for(uint64_t i_c = 0; i_c < C; i_c++) + { + for(uint64_t i_h = 0; i_h < H; i_h++) + { + for(uint64_t i_w = 0; i_w < W; i_w++) + { + uint64_t idx_nhwc = i_n * H * W * C + i_h * W * C + i_w * C + i_c; + uint64_t idx_nchw = i_n * C * H * W + i_c * H * W + i_h * W + i_w; + dst[idx_nchw] = src[idx_nhwc]; + } + } + } + } +} + +template +struct cpu_transpose +{ +}; + +template +struct cpu_transpose +{ + static void run(T* dst, T* src, uint64_t N, uint64_t C, uint64_t H, uint64_t W) + { + cpu_nchw2nhwc(dst, src, N, C, H, W); + } +}; + +template +struct cpu_transpose +{ + static void run(T* dst, T* src, uint64_t N, uint64_t C, uint64_t H, uint64_t W) + { + cpu_nhwc2nchw(dst, src, N, C, H, W); + } +}; + +template +struct transpose_str +{ +}; + +template <> +struct transpose_str +{ + static std::string get() { return "nchw2nhwc"; } +}; + +template <> +struct transpose_str +{ + static std::string get() { return "nhwc2nchw"; } +}; + +enum tensor_layout_t +{ + miopen_tensor_layout_nchw, + miopen_tensor_layout_ncdhw, + miopen_tensor_layout_nhwc, + miopen_tensor_layout_ndhwc, +}; + +std::string tensor_layout_to_string(tensor_layout_t layout) +{ + std::string layout_string("N/A"); + if(layout == miopen_tensor_layout_nchw) + layout_string = "NCHW"; + else if(layout == miopen_tensor_layout_ncdhw) + layout_string = "NCDHW"; + else if(layout == miopen_tensor_layout_nhwc) + layout_string = "NHWC"; + else if(layout == miopen_tensor_layout_ndhwc) + layout_string = "NDHWC"; + else + MIOPEN_THROW("Unsupported tensor layout"); + return layout_string; +} + +template +struct to_miopen_data_type +{ +}; + +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenFloat; } +}; + +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenHalf; } // we actually didn't calculate 16bit float +}; + +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenInt8; } +}; + +#define RAND_INTEGER_MAX 120 +#define RAND_INTEGER_MIN -88 + +static int gen_rand_integer() +{ + // NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) + static int inited = 0; + if(inited == 0) + { + std::srand(std::time(nullptr)); + inited = 1; + } + return GET_RAND(); +} + +template +void rand_tensor_integer(tensor& t, int max = RAND_INTEGER_MAX, int min = RAND_INTEGER_MIN) +{ + // use integer to random. 
+ for(int i = 0; i < t.data.size(); i++) + t[i] = static_cast(gen_rand_integer() % (max - min) + min); +} + +template +bool compare_equal(T r1, T r2) +{ + return r1 == r2; +} + +template <> +bool compare_equal(float r1, float r2) +{ + return miopen::float_equal(r1, r2); +} + +template +bool verify_tensor(tensor& t_gpu, tensor& t_cpu) +{ + if(t_gpu.data.size() != t_cpu.data.size()) + { + MIOPEN_LOG_E("size not equal, should not happen"); + return false; + } + auto idx = miopen::mismatch_idx(t_gpu.data, t_cpu.data, compare_equal); + bool valid_result = idx >= miopen::range_distance(t_cpu); + + if(!valid_result) + { + std::cout << "diff at:" << idx << ", gpu:" << t_gpu[idx] << ", cpu:" << t_cpu[idx] + << std::endl; + } + return valid_result; +} + +struct transpose_base +{ + miopenHandle_t handle{}; +#if MIOPEN_BACKEND_OPENCL + cl_command_queue q{}; +#endif + + transpose_base() + { + miopenCreate(&handle); +#if MIOPEN_BACKEND_OPENCL + miopenGetStream(handle, &q); +#endif + } + ~transpose_base() { miopenDestroy(handle); } + + static std::vector get_image_size() { return {1, 9, 14}; } + + static std::vector get_channel_size() { return {3, 8, 14}; } + + static std::vector get_batch_size() { return {1, 2}; } + + template + void iterate_transpose(F f) + { + std::vector channel_list = get_channel_size(); + std::vector image_list = get_image_size(); + std::vector batch_list = get_batch_size(); + channel_list.push_back(gen_rand_integer() % 13 + 29); + image_list.push_back(gen_rand_integer() % 13 + 15); + batch_list.push_back(gen_rand_integer() % 4 + 3); + + for(uint32_t c : channel_list) + { + for(uint32_t h : image_list) + { + for(uint32_t w : image_list) + { + for(uint32_t n : batch_list) + { + f(n, c, h, w); + } + } + } + } + } +}; + +struct transpose_invoke_param : public miopen::InvokeParams +{ + ConstData_t src = nullptr; + Data_t dst = nullptr; + + transpose_invoke_param(ConstData_t src_, Data_t dst_) : src(src_), dst(dst_) {} + transpose_invoke_param(miopen::InvokeType type_, ConstData_t src_, Data_t dst_) + : InvokeParams{type_}, src(src_), dst(dst_) + { + } +}; + +template +struct transpose_test : transpose_base +{ + void run() + { + auto run_transpose = [this](uint32_t n, uint32_t c, uint32_t h, uint32_t w) { + int tensor_sz = n * c * h * w; + std::vector tensor_len({static_cast(n), + static_cast(c), + static_cast(h), + static_cast(w)}); + + std::vector tensor_strides; + + std::string layout_default = miopen::tensor_layout_get_default(4); + std::string layout_string = tensor_layout_to_string(miopen_tensor_layout_nchw); + + miopen::tensor_layout_to_strides( + tensor_len, layout_default, layout_string, tensor_strides); + + tensor t_src(tensor_len, tensor_strides); + tensor t_dst(tensor_len, tensor_strides); + tensor t_dst_gpu(tensor_len, tensor_strides); + rand_tensor_integer(t_src); +#if MIOPEN_BACKEND_OPENCL + cl_context cl_ctx; + clGetCommandQueueInfo(q, CL_QUEUE_CONTEXT, sizeof(cl_context), &cl_ctx, nullptr); + cl_int status = CL_SUCCESS; + cl_mem src_dev = + clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, &status); + cl_mem dst_dev = + clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, nullptr); + status |= clEnqueueWriteBuffer(q, + src_dev, + CL_TRUE, + 0, + sizeof(T) * tensor_sz, + t_src.data.data(), + 0, + nullptr, + nullptr); + EXPECT(status == CL_SUCCESS); +#elif MIOPEN_BACKEND_HIP + void* src_dev; + void* dst_dev; + EXPECT(hipMalloc(&src_dev, sizeof(T) * tensor_sz) == hipSuccess); + EXPECT(hipMalloc(&dst_dev, sizeof(T) * tensor_sz) 
== hipSuccess); + EXPECT(hipMemcpy( + src_dev, t_src.data.data(), sizeof(T) * tensor_sz, hipMemcpyHostToDevice) == + hipSuccess); +#endif + + const auto invoke_param = transpose_invoke_param{ + DataCast(static_cast(src_dev)), DataCast(dst_dev)}; + + miopen::ExecutionContext ctx; + ctx.SetStream(&miopen::deref(this->handle)); + ctx.DetectRocm(); + // ctx.SetupFloats(); + + TRANSPOSE_SOL transpose_sol(ctx, to_miopen_data_type::get(), n, c, h, w); + + std::vector opArgs = transpose_sol.GetKernelArg(); + + boost::optional invoker_factory( + [=](const std::vector& kernels) mutable { + return [=](const miopen::Handle& handle, + const miopen::AnyInvokeParams& primitive_param) mutable { + decltype(auto) invoke_params = + primitive_param.CastTo(); + + const auto k = handle.Run(kernels[0]); + + opArgs[0] = OpKernelArg(invoke_params.dst); + opArgs[1] = OpKernelArg(invoke_params.src); + + k(opArgs); + }; + }); + + std::vector construction_params{transpose_sol.GetKernel()}; + + const auto invoker = + miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params); + + // run gpu + invoker(miopen::deref(this->handle), invoke_param); + + // run cpu + cpu_transpose::run(t_dst.data.data(), t_src.data.data(), n, c, h, w); + +#if MIOPEN_BACKEND_OPENCL + status = clEnqueueReadBuffer(q, + dst_dev, + CL_TRUE, + 0, + sizeof(T) * tensor_sz, + t_dst_gpu.data.data(), + 0, + nullptr, + nullptr); + EXPECT(status == CL_SUCCESS); +#elif MIOPEN_BACKEND_HIP + EXPECT(hipMemcpy(t_dst_gpu.data.data(), + dst_dev, + sizeof(T) * tensor_sz, + hipMemcpyDeviceToHost) == hipSuccess); +#endif + + // we expect excact match, since use integer + bool valid_result = verify_tensor(t_dst_gpu, t_dst); + + std::cout << "[" << transpose_str::get() << ", b" << (sizeof(T) * 8) + << " ] " + << "n:" << n << ", c:" << c << ", h:" << h << ", w:" << w + << ", valid:" << valid_result << std::endl; + + EXPECT(valid_result == true); + +#if MIOPEN_BACKEND_OPENCL + clReleaseMemObject(src_dev); + clReleaseMemObject(dst_dev); +#elif MIOPEN_BACKEND_HIP + hipFree(src_dev); + hipFree(dst_dev); +#endif + }; + + iterate_transpose(run_transpose); + } +}; + +int main() +{ + run_test>(); + run_test>(); + run_test>(); + + run_test>(); + run_test>(); + run_test>(); +}
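For reference, a minimal standalone sketch (not MIOpen code) of how the NCHW wrapper path above carves a single workspace allocation into staging regions: the input, weight, and output transpose buffers are packed back to back, followed by the optional FP32 cast buffer. The struct names and the byte sizes used in main are illustrative placeholders; in the solvers the sizes come from TransposeSolution*::GetSize() (zero when a transpose is skippable) and from GetTypeSize(miopenFloat) times the output element count, exactly as in the diff.

#include <cstddef>
#include <cstdio>

// Hypothetical per-region sizes; in the solvers these come from the
// TransposeSolution{Default2Nhwc,Nhwc2Default} objects and the FP32 cast
// buffer requirement (0 when a region is not needed).
struct WorkspaceSizes
{
    std::size_t trans_input;
    std::size_t trans_weight;
    std::size_t trans_output;
    std::size_t cast;
};

// Offsets of the four sub-buffers carved out of one workspace allocation,
// matching the packing order used by the invoker: input, weight, output, cast.
struct WorkspaceOffsets
{
    std::size_t trans_input  = 0;
    std::size_t trans_weight = 0;
    std::size_t trans_output = 0;
    std::size_t cast         = 0;
    std::size_t total        = 0;
};

WorkspaceOffsets PackWorkspace(const WorkspaceSizes& s)
{
    WorkspaceOffsets o;
    o.trans_input  = 0;
    o.trans_weight = o.trans_input + s.trans_input;
    o.trans_output = o.trans_weight + s.trans_weight;
    o.cast         = o.trans_output + s.trans_output;
    o.total        = o.cast + s.cast;
    return o;
}

int main()
{
    // Placeholder sizes, for illustration only.
    WorkspaceSizes s{4096, 1024, 8192, 16384};
    const WorkspaceOffsets o = PackWorkspace(s);
    std::printf("input  @ %zu (%zu bytes)\n", o.trans_input, s.trans_input);
    std::printf("weight @ %zu (%zu bytes)\n", o.trans_weight, s.trans_weight);
    std::printf("output @ %zu (%zu bytes)\n", o.trans_output, s.trans_output);
    std::printf("cast   @ %zu (%zu bytes)\n", o.cast, s.cast);
    std::printf("total  = %zu bytes\n", o.total);
    // In the solvers each region is then exposed to the kernels via
    // handle.CreateSubBuffer(workSpace, offset, size), with a null buffer
    // substituted whenever the corresponding size is zero.
}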
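A second sketch (again illustrative only, with hypothetical names) of the kernel-slot bookkeeping used by the invoker factories: slot 0 holds the main GEMM kernel, slot 1 its gfx90a fp16-alt twin when that variant is built, and the non-skippable transpose kernels follow in input, weight, output order. This is what kID_trans_start and the trans_*_idx values in the diff encode.

#include <cassert>

struct TransKernelIndices
{
    int start  = -1; // first transpose kernel slot in the kernels vector
    int input  = -1; // -1 means the corresponding transpose was skippable
    int weight = -1;
    int output = -1;
};

// Mirrors the index assignment in the diff: slots 0..(start-1) hold the main
// kernel (plus its fp16-alt twin on gfx90a); transpose kernels are appended
// only when not skippable, in input -> weight -> output order.
TransKernelIndices AssignTransKernelIndices(bool has_fp16_alt,
                                            bool input_skippable,
                                            bool weight_skippable,
                                            bool output_skippable)
{
    TransKernelIndices t;
    t.start = has_fp16_alt ? 2 : 1;
    int idx = 0;
    if(!input_skippable)
        t.input = idx++;
    if(!weight_skippable)
        t.weight = idx++;
    if(!output_skippable)
        t.output = idx++;
    return t;
}

int main()
{
    // Example: fp16-alt kernel present, weight transpose skippable.
    const auto t = AssignTransKernelIndices(true, false, true, false);
    assert(t.start == 2);
    assert(t.input == 0 && t.weight == -1 && t.output == 1);
    // The invoker then launches kernels[t.start + t.input], and likewise for
    // the output and weight transposes, skipping any index that stayed -1.
    return 0;
}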