From c52547acbef190253bd02d46edd9e2279bd63248 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Mon, 24 Jan 2022 17:40:28 +0800
Subject: [PATCH 01/77] test_file commit

---
 test/tensor_reorder.cpp | 542 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 542 insertions(+)
 create mode 100644 test/tensor_reorder.cpp

diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp
new file mode 100644
index 0000000000..f2e5aeaac8
--- /dev/null
+++ b/test/tensor_reorder.cpp
@@ -0,0 +1,542 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2020-2022 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include <cassert>
+#include <cfloat>
+#include <chrono>
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <ctime>
+#include <functional>
+#include <iostream>
+#include <random>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+#include <hip/hip_runtime.h>
+#include "gpu_tensor_reorder.h"
+#include "sequence.hpp"
+
+#ifndef HIP_CALL
+#define HIP_CALL(call)                                                       \
+    do {                                                                     \
+        hipError_t err = call;                                               \
+        if (err != hipSuccess) {                                             \
+            printf("[hiperror](%d) fail to call %s,(%s)\n", (int)err, #call, \
+                   hipGetErrorString(err));                                  \
+            exit(1);                                                         \
+        }                                                                    \
+    } while (0)
+#endif
+
+static inline int env_get_int(const char *var_name, int default_int) {
+    char *v = getenv(var_name);
+    int r = default_int;
+    if (v)
+        r = atoi(v);
+    return r;
+}
+
+static int gen_rand_integer()
+{
+    static int inited = 0;
+    if(inited == 0)
+    {
+        std::srand(std::time(nullptr));
+        inited = 1;
+    }
+    return std::rand();
+}
+
+static inline const char *env_get_str(const char *var_name, const char *default_str) {
+    const char *v = getenv(var_name);
+    if (v)
+        return v;
+    return default_str;
+}
+
+template <typename T>
+struct distribution_t{
+};
+
+template <>
+struct distribution_t<int8_t>{
+    distribution_t(int min, int max) : distribution(min, max) {}
+    template <class URNG>
+    int8_t operator()(URNG & rng){
+        int value = distribution(rng);
+        return *reinterpret_cast<int8_t*>(&value);
+    }
+    std::uniform_int_distribution<int> distribution;
+};
+template <>
+struct distribution_t<int>{
+    distribution_t(int min, int max) : distribution(min, max) {}
+    template <class URNG>
+    int operator()(URNG & rng){ return distribution(rng); }
+    std::uniform_int_distribution<int> distribution;
+};
+template <>
+struct distribution_t<float>{
+    distribution_t(float min, float max) : distribution(min, max) {}
+    template <class URNG>
+    float operator()(URNG & rng){ return distribution(rng); }
+    std::uniform_real_distribution<float> distribution;
+};
+
+template <typename Dst_T, typename Src_T>
+void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale)
+{
+    // per-thread seed: wall clock plus the thread id hash, so workers diverge
+    std::mt19937 rng(std::chrono::system_clock::now().time_since_epoch().count() +
+                     std::hash<std::thread::id>()(std::this_thread::get_id()));
+    distribution_t<Dst_T> distribution(min, max);
+    for (int i = tid; i < total_size; i += block_size) {
+        p[i] = static_cast<Dst_T>(scale * distribution(rng));
+    }
+}
+
+template <typename Dst_T, typename Src_T>
+void gen_rand_vector(Dst_T *vec, size_t vec_size, Src_T fmin, Src_T fmax, Src_T scale = 1) {
+    int num_threads = std::thread::hardware_concurrency();
+    if (num_threads < 4)
+        num_threads = 4;
+    // printf("total threads:%d\n",num_threads);
+    std::vector<std::thread> threads;
+    for (int t = 0; t < num_threads; t++) {
+        threads.push_back(std::thread(block_wise_rand_generator<Dst_T, Src_T>,
+                                      vec, t, num_threads, static_cast<int>(vec_size), fmin, fmax, scale));
+    }
+    for (auto &th : threads)
+        th.join();
+}
+
+static inline bool valid_float(float p)
+{
+    return !(std::isnan(p) || std::isinf(p));
+}
+#ifndef ABS
+#define ABS(b) ((b) > 0 ? (b) : -1 * (b))
+#endif
+static inline bool valid_vector(const float *ref, const float *pred, int n,
+                                double nrms = 1.5e-6) {
+    double s0 = 0.0;
+    double s1 = 0.0;
+    int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0);
+    int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1);
+    int pp_err = 0;
+
+    for (int i = 0; i < n; ++i) {
+        if(!(valid_float(ref[i]) && valid_float(pred[i]))){
+            printf(" invalid float at %d, ref:%f, pred:%f\n", i, ref[i], pred[i]);
+            return false;
+        }
+        double ri = (double)ref[i];
+        double pi = (double)pred[i];
+        double d  = ri - pi;
+        double dd = d * d;
+        double rr = 2.0 * ri * ri;
+        s0 += dd;
+        s1 += rr;
+        if(igemm_per_pixel_check){
+            double delta = ABS(ABS(ri - pi) / ri);
+            printf("[%d] ref:%lf, pred:%lf(0x%08x) [%s]\n", i, ri, pi,
+                   ((const uint32_t *)pred)[i], delta > 3e-5 ? "N" : "Y");
+            if (delta > 3e-5) {
+                if(igemm_per_pixel_check_print){
+                    if (pp_err < 100)
+                        printf("diff at %d, ref:%lf, pred:%lf(0x%08x), d:%lf\n", i, ri,
+                               pi, ((const uint32_t *)pred)[i], delta);
+                }
+                pp_err++;
+            }
+        }
+    }
+    // printf("\nnrms:%lf, s0:%lf, s1:%lf, expected_nrms is %lf\n",sqrt(s0/s1),s0,s1,nrms);
+    fflush(stdout);
+    return (sqrt(s0 / s1) < nrms)
+#ifdef PER_PIXEL_CHECK
+           && (pp_err == 0)
+#endif
+        ;
+}
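Written out, the pass criterion that valid_vector derives from its two accumulators (s0 sums the squared differences, s1 sums twice the squared references) is a normalized RMS error with a default tolerance of 1.5e-6:

    \mathrm{nrms} \;=\; \sqrt{\frac{\sum_i (\mathit{ref}_i - \mathit{pred}_i)^2}{2\sum_i \mathit{ref}_i^2}} \;<\; 1.5\times 10^{-6}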
"N":"Y"); + if (delta > 3e-5) { + if(igemm_per_pixel_check_print){ + if (pp_err < 100) + printf("diff at %d, ref:%lf, pred:%lf(0x%08x), d:%lf\n", i, ri, + pi, ((uint32_t *)pred)[i], delta); + } + pp_err++; + } + + } + } + // printf("\nnrms:%lf, s0:%lf, s1:%lf, expected_nrms is %1f\n",sqrt(s0/s1),s0,s1,nrms); + fflush(stdout); + return (sqrt(s0 / s1) < nrms) +#ifdef PER_PIXEL_CHECK + && (pp_err == 0) +#endif + ; +} + +static inline bool valid_vector_binary(int8_t *ref, int8_t *pred, size_t bytes) { + int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); + size_t err = 0; + for(int i = 0; i < bytes ; i++){ + // { + // uint32_t r = 0; + // uint32_t p = 0; + // memcpy(reinterpret_cast(&r), reinterpret_cast(&ref[i]), 1); + // memcpy(reinterpret_cast(&p), reinterpret_cast(&pred[i]), 1); + // printf("%7d, ref:0x%x, pred:0x%x, %s\n", i, r, p, r==p?"y":"n"); + // } + if(ref[i] != pred[i]){ + err ++; + if(igemm_per_pixel_check){ + uint32_t r = 0; + uint32_t p = 0; + memcpy(reinterpret_cast(&r), reinterpret_cast(&ref[i]), 1); + memcpy(reinterpret_cast(&p), reinterpret_cast(&pred[i]), 1); + printf("fail at %d, ref:0x%x, pred:0x%x\n", i, r, p); + } + } + } + return err == 0; +} + +template +void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64_t dim_2, uint64_t dim_3) +{ + constexpr auto dorder = dst_order{}; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + + const uint64_t src_stride[4] ={src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint64_t itr_src_dim[4] = {0, 0, 0, 0}; + uint64_t itr_dst_dim[4] = {0, 0, 0, 0}; + + for(itr_src_dim[0] = 0; itr_src_dim[0] < src_dim[0]; itr_src_dim[0]++){ + for(itr_src_dim[1] = 0; itr_src_dim[1] < src_dim[1]; itr_src_dim[1]++){ + for(itr_src_dim[2] = 0; itr_src_dim[2] < src_dim[2]; itr_src_dim[2]++){ + for(itr_src_dim[3] = 0; itr_src_dim[3] < src_dim[3]; itr_src_dim[3]++){ + itr_dst_dim[0] = itr_src_dim[dorder.at(0)]; + itr_dst_dim[1] = itr_src_dim[dorder.at(1)]; + itr_dst_dim[2] = itr_src_dim[dorder.at(2)]; + itr_dst_dim[3] = itr_src_dim[dorder.at(3)]; + + uint64_t idx_src = itr_src_dim[0] * src_stride[0] + + itr_src_dim[1] * src_stride[1] + + itr_src_dim[2] * src_stride[2] + + itr_src_dim[3] * src_stride[3] ; + uint64_t idx_dst = itr_dst_dim[0] * dst_stride[0] + + itr_dst_dim[1] * dst_stride[1] + + itr_dst_dim[2] * dst_stride[2] + + itr_dst_dim[3] * dst_stride[3] ; + + dst[idx_dst] = src[idx_src]; + } + } + } + } +} + +//compile time for_loop +namespace detail { + + template + constexpr void loop(std::integer_sequence, F&& f) { + (f(std::integral_constant{}), ...);// C++17 fold expression + } + +} + +template +constexpr void loop(F&& f) { + detail::loop(std::make_integer_sequence{}, std::forward(f)); +} + +#define WARMUP 3 +#define REPEAT 7 +#define BATCHED_TRANSPOSE_HSACO "out/batched_transpose.hsaco" +#define GENERAL_TENSOR_REORDER_HSACO "out/general_tensor_reorder.hsaco" + +int main(int argc, char ** argv){ + if(argc < 5){ + printf("%s Please input tensor size in order of: DIM0, DIM1, DIM2, DIM3\n", argv[0]); + return -1; + } + if(argc > 5){ + printf("Too many argument\n"); + return -1; + } + int warmup = env_get_int("IGEMM_WARMUP", WARMUP); + int repeat = env_get_int("IGEMM_REPEAT", REPEAT); + const uint64_t dim_0 = 
+
+#define WARMUP 3
+#define REPEAT 7
+#define BATCHED_TRANSPOSE_HSACO "out/batched_transpose.hsaco"
+#define GENERAL_TENSOR_REORDER_HSACO "out/general_tensor_reorder.hsaco"
+
+int main(int argc, char ** argv){
+    if(argc < 5){
+        printf("%s: please input the tensor size in the order DIM0, DIM1, DIM2, DIM3\n", argv[0]);
+        return -1;
+    }
+    if(argc > 5){
+        printf("Too many arguments\n");
+        return -1;
+    }
+    int warmup = env_get_int("IGEMM_WARMUP", WARMUP);
+    int repeat = env_get_int("IGEMM_REPEAT", REPEAT);
+    const uint64_t dim_0 = std::stoull(std::string(argv[1]));
+    const uint64_t dim_1 = std::stoull(std::string(argv[2]));
+    const uint64_t dim_2 = std::stoull(std::string(argv[3]));
+    const uint64_t dim_3 = std::stoull(std::string(argv[4]));
+
+    size_t size_byte = 4;
+    const char* fp = env_get_str("FP", "32");
+    std::string fp_str(fp);
+    if(fp_str == "32")
+        size_byte = 4;
+    else if(fp_str == "16")
+        size_byte = 2;
+    else if(fp_str == "8")
+        size_byte = 1;
+    else{
+        printf("error FP:%s\n", fp);
+        return -1;
+    }
+
+    bool batched = false;
+    bool is_kernel_valid = false;
+    const char* hsaco;
+    void * src_cpu = malloc(dim_0*dim_1*dim_2*dim_3*size_byte);
+    void * dst_cpu = malloc(dim_0*dim_1*dim_2*dim_3*size_byte);
+    void * dst_gpu_valid = malloc(dim_0*dim_1*dim_2*dim_3*size_byte);
+
+    void * src_gpu;
+    void * dst_gpu;
+
+    HIP_CALL(hipMalloc(&src_gpu, dim_0*dim_1*dim_2*dim_3*size_byte));
+    HIP_CALL(hipMalloc(&dst_gpu, dim_0*dim_1*dim_2*dim_3*size_byte));
+
+    // fill the whole buffer with random bytes; correctness below is a binary compare
+    gen_rand_vector(reinterpret_cast<int8_t*>(src_cpu), dim_0*dim_1*dim_2*dim_3*size_byte, -116, 121);
+    HIP_CALL(hipMemcpy(src_gpu, src_cpu, dim_0*dim_1*dim_2*dim_3*size_byte, hipMemcpyHostToDevice));
+
+    loop<int, 23>([&](auto i) {
+        constexpr int all_possible_sequence[23][4] = {
+            {0, 1, 3, 2}, {2, 3, 0, 1}, {3, 0, 1, 2}, {0, 2, 3, 1}, {0, 3, 1, 2}, //BATCHED TRANSPOSE
+            {0, 2, 1, 3}, {0, 3, 2, 1},
+            {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0},
+            {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 1, 0},
+            {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} };
+        using dst_order = sequence<all_possible_sequence[i.value][0],
+                                   all_possible_sequence[i.value][1],
+                                   all_possible_sequence[i.value][2],
+                                   all_possible_sequence[i.value][3]>;
+        std::cout << " Tensor reorder to (" << dst_order::at(0) << "," << dst_order::at(1) << ","
+                  << dst_order::at(2) << "," << dst_order::at(3) << ")" << std::endl;
+
+        //TODO: an API with more privacy
+        auto launch_gpu_init = [&](){
+            if((dst_order::at(0)==0 && dst_order::at(1)==1 && dst_order::at(2)==3 && dst_order::at(3)==2) ||
+               (dst_order::at(0)==0 && dst_order::at(1)==2 && dst_order::at(2)==3 && dst_order::at(3)==1) ||
+               (dst_order::at(0)==0 && dst_order::at(1)==3 && dst_order::at(2)==1 && dst_order::at(3)==2) ||
+               (dst_order::at(0)==3 && dst_order::at(1)==0 && dst_order::at(2)==1 && dst_order::at(3)==2) ||
+               (dst_order::at(0)==2 && dst_order::at(1)==3 && dst_order::at(2)==0 && dst_order::at(3)==1)
+            ){
+                printf("choose batched transpose kernel\n");
+                batched = true;
+                //batched transpose: NCHW <----> NHWC, (NC)cHW <----> (NC)HWc
+                hsaco = env_get_str("BATCHED_TRANSPOSE", BATCHED_TRANSPOSE_HSACO);
+                gpu_nhwc_nchw_transpose_init(hsaco);
+            }
+            else {
+                printf("choose general tensor reorder kernel\n");
+                hsaco = env_get_str("GENERAL_TENSOR_REORDER_HSACO", GENERAL_TENSOR_REORDER_HSACO);
+                gpu_tensor_reorder_init(hsaco);
+            }
+        };
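For reference, each order on the batched-transpose path collapses into a plain 2D transpose over a (batch, height, width) view of the tensor; the groupings below restate the transpose_kernel_is_valid calls in test_batched_transpose further down:

// (0,1,3,2): batch = dim0*dim1, height = dim2,           width = dim3
// (0,2,3,1): batch = dim0,      height = dim1,           width = dim2*dim3   (NCHW -> NHWC)
// (0,3,1,2): batch = dim0,      height = dim1*dim2,      width = dim3        (NHWC -> NCHW)
// (3,0,1,2): batch = 1,         height = dim0*dim1*dim2, width = dim3
// (2,3,0,1): batch = 1,         height = dim0*dim1,      width = dim2*dim3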
+
+        auto launch_gpu_tensor_reorder = [&](const transpose_kernel_param_t * kparam){
+            if(fp_str == "32")
+                gpu_tensor_reorder<float, dst_order>(reinterpret_cast<float*>(dst_gpu),
+                                                     reinterpret_cast<float*>(src_gpu),
+                                                     dim_0, dim_1, dim_2, dim_3, kparam);
+            else if(fp_str == "16")
+                gpu_tensor_reorder<ushort, dst_order>(reinterpret_cast<ushort*>(dst_gpu),
+                                                      reinterpret_cast<ushort*>(src_gpu),
+                                                      dim_0, dim_1, dim_2, dim_3, kparam);
+            else if(fp_str == "8")
+                gpu_tensor_reorder<int8_t, dst_order>(reinterpret_cast<int8_t*>(dst_gpu),
+                                                      reinterpret_cast<int8_t*>(src_gpu),
+                                                      dim_0, dim_1, dim_2, dim_3, kparam);
+        };
+
+        auto launch_cpu_tensor_reorder = [&](){
+            if(fp_str == "32")
+                cpu_tensor_reorder<float, dst_order>(reinterpret_cast<float*>(dst_cpu),
+                                                     reinterpret_cast<float*>(src_cpu),
+                                                     dim_0, dim_1, dim_2, dim_3);
+            else if(fp_str == "16")
+                cpu_tensor_reorder<ushort, dst_order>(reinterpret_cast<ushort*>(dst_cpu),
+                                                      reinterpret_cast<ushort*>(src_cpu),
+                                                      dim_0, dim_1, dim_2, dim_3);
+            else if(fp_str == "8")
+                cpu_tensor_reorder<int8_t, dst_order>(reinterpret_cast<int8_t*>(dst_cpu),
+                                                      reinterpret_cast<int8_t*>(src_cpu),
+                                                      dim_0, dim_1, dim_2, dim_3);
+        };
+
+        auto test_batched_transpose = [&](const transpose_kernel_param_t *transpose_kparam){
+            float kernel_time = 0;
+            bool valid = false;
+            bool is_kernel_valid = false;
+
+            if(dst_order::at(0)==0 && dst_order::at(1)==2 && dst_order::at(2)==3 && dst_order::at(3)==1){
+                is_kernel_valid = transpose_kernel_is_valid(dim_0, dim_1, dim_2 * dim_3, transpose_kparam);
+            }
+            else if(dst_order::at(0)==0 && dst_order::at(1)==1 && dst_order::at(2)==3 && dst_order::at(3)==2){
+                is_kernel_valid = transpose_kernel_is_valid(dim_0 * dim_1, dim_2, dim_3, transpose_kparam);
+            }
+            else if(dst_order::at(0)==0 && dst_order::at(1)==3 && dst_order::at(2)==1 && dst_order::at(3)==2){
+                is_kernel_valid = transpose_kernel_is_valid(dim_0, dim_1 * dim_2, dim_3, transpose_kparam);
+            }
+            else if(dst_order::at(0)==3 && dst_order::at(1)==0 && dst_order::at(2)==1 && dst_order::at(3)==2){
+                is_kernel_valid = transpose_kernel_is_valid(1, dim_0 * dim_1 * dim_2, dim_3, transpose_kparam);
+            }
+            //dst_order::at(0)==2 && dst_order::at(1)==3 && dst_order::at(2)==0 && dst_order::at(3)==1
+            else{
+                is_kernel_valid = transpose_kernel_is_valid(1, dim_0 * dim_1, dim_2 * dim_3, transpose_kparam);
+            }
+            if(is_kernel_valid){
+                hipEvent_t start, stop;
+                HIP_CALL(hipMemset(dst_gpu, 0, dim_0*dim_1*dim_2*dim_3*size_byte));
+
+                for(int i = 0; i < warmup; i++){
+                    launch_gpu_tensor_reorder(transpose_kparam);
+                }
+
+                HIP_CALL(hipEventCreate(&start));
+                HIP_CALL(hipEventCreate(&stop));
+                HIP_CALL(hipDeviceSynchronize());
+                HIP_CALL(hipEventRecord(start, 0));
+
+                for(int i = 0; i < repeat; i++){
+                    launch_gpu_tensor_reorder(transpose_kparam);
+                }
+                HIP_CALL(hipEventRecord(stop, 0));
+                HIP_CALL(hipEventSynchronize(stop));
+                HIP_CALL(hipEventElapsedTime(&kernel_time, start, stop));
+                HIP_CALL(hipEventDestroy(start));
+                HIP_CALL(hipEventDestroy(stop));
+                kernel_time = kernel_time / repeat;
+
+                launch_cpu_tensor_reorder();
+
+                HIP_CALL(hipMemcpy(dst_gpu_valid, dst_gpu, dim_0*dim_1*dim_2*dim_3*size_byte, hipMemcpyDeviceToHost));
+
+                valid = valid_vector_binary(reinterpret_cast<int8_t*>(dst_cpu), reinterpret_cast<int8_t*>(dst_gpu_valid), dim_0*dim_1*dim_2*dim_3*size_byte);
+            }
+
+            // bytes moved: one read plus one write per element
+            double flop_cnt = 2 * dim_0*dim_1*dim_2*dim_3*size_byte;
+            double bw = is_kernel_valid ?
flop_cnt / kernel_time / 1e6 : 0; + + printf("[tensor_reorder fp%s] tensor_size:(%lu, %lu, %lu, %lu), flop:%.0f, time:%fms, bw:%.4fGB/s, valid:%s (%dx%d, %dx%d, %dx%d)\n", + fp_str.c_str(), dim_0, dim_1, dim_2, dim_3, flop_cnt, kernel_time, bw, is_kernel_valid ? (valid ? "y" : "n") : "x", + transpose_kparam->tile_x, transpose_kparam->tile_y, transpose_kparam->pack_x, transpose_kparam->pack_y, transpose_kparam->ediv_x, transpose_kparam->ediv_y); + fflush(stdout); + + return valid && is_kernel_valid ? kernel_time : FLT_MAX; + }; + + auto test_general_tensor_reorder = [&](const transpose_kernel_param_t *transpose_kparam){ + float kernel_time = 0; + bool valid = false; + + bool is_kernel_valid = true; + if(is_kernel_valid){ + hipEvent_t start, stop; + HIP_CALL(hipMemset(dst_gpu, 0, dim_0*dim_1*dim_2*dim_3*size_byte)); + + for(int i=0; i< warmup; i++){ + launch_gpu_tensor_reorder(transpose_kparam); + } + + HIP_CALL(hipEventCreate(&start)); + HIP_CALL(hipEventCreate(&stop)); + HIP_CALL(hipDeviceSynchronize()); + HIP_CALL(hipEventRecord(start, 0) ); + + for(int i=0; i< repeat; i++){ + launch_gpu_tensor_reorder(transpose_kparam); + } + HIP_CALL(hipEventRecord(stop, 0) ); + HIP_CALL(hipEventSynchronize(stop) ); + HIP_CALL(hipEventElapsedTime(&kernel_time, start, stop) ); + HIP_CALL(hipEventDestroy(start) ); + HIP_CALL(hipEventDestroy(stop) ); + kernel_time = kernel_time / repeat; + + launch_cpu_tensor_reorder(); + + HIP_CALL(hipMemcpy(dst_gpu_valid, dst_gpu, dim_0*dim_1*dim_2*dim_3*size_byte, hipMemcpyDeviceToHost)); + + valid = valid_vector_binary(reinterpret_cast(dst_cpu), reinterpret_cast(dst_gpu_valid), dim_0*dim_1*dim_2*dim_3*size_byte); + } + + double flop_cnt = 2 * dim_0*dim_1*dim_2*dim_3*size_byte; + double bw = is_kernel_valid ? flop_cnt / kernel_time / 1e6 : 0; + + printf("[tensor_reorder fp%s] tensor_size:(%lu, %lu, %lu, %lu), flop:%.0f, time:%fms, bw:%.4fGB/s, valid:%s (256x%d)\n", + fp_str.c_str(), dim_0, dim_1, dim_2, dim_3, flop_cnt, kernel_time, bw, is_kernel_valid ? (valid ? "y" : "n") : "x", + transpose_kparam->tile_x); + fflush(stdout); + + return valid && is_kernel_valid ? 
kernel_time : FLT_MAX;
+        };
+
+        auto get_transpose_all_kernel = [&](){
+            if(fp_str == "32")
+                return transpose_kernel_get_all_param_t<4>::get();
+            else if(fp_str == "16")
+                return transpose_kernel_get_all_param_t<2>::get();
+            else if(fp_str == "8")
+                return transpose_kernel_get_all_param_t<1>::get();
+            assert(false);
+            abort(); // unreachable: fp_str was validated at startup
+        };
+
+        auto get_tensor_reorder_all_kernel = [&](){
+            if(fp_str == "32")
+                return tensor_reorder_kernel_get_all_param_t<4>::get();
+            else if(fp_str == "16")
+                return tensor_reorder_kernel_get_all_param_t<2>::get();
+            else if(fp_str == "8")
+                return tensor_reorder_kernel_get_all_param_t<1>::get();
+            assert(false);
+            abort(); // unreachable: fp_str was validated at startup
+        };
+
+        batched = false;
+        launch_gpu_init();
+        float min_tensor_reorder_time = FLT_MAX;
+        transpose_kernel_param_t min_tensor_reorder_kparam;
+        if(batched){
+            for(auto kparam : get_transpose_all_kernel()){
+                float current_time = test_batched_transpose(&kparam);
+                if(current_time < min_tensor_reorder_time){
+                    min_tensor_reorder_time = current_time;
+                    min_tensor_reorder_kparam = kparam;
+                }
+            }
+            printf("-> min time:%fms, kparam: %dx%d, %dx%d, %dx%d\n", min_tensor_reorder_time,
+                   min_tensor_reorder_kparam.tile_x, min_tensor_reorder_kparam.tile_y,
+                   min_tensor_reorder_kparam.pack_x, min_tensor_reorder_kparam.pack_y,
+                   min_tensor_reorder_kparam.ediv_x, min_tensor_reorder_kparam.ediv_y);
+            fflush(stdout);
+            printf("-------------------------\n");
+        }
+        else{
+            for(auto kparam : get_tensor_reorder_all_kernel()){
+                float current_time = test_general_tensor_reorder(&kparam);
+                if(current_time < min_tensor_reorder_time){
+                    min_tensor_reorder_time = current_time;
+                    min_tensor_reorder_kparam = kparam;
+                }
+            }
+            printf("-> min time:%fms, kparam: 256x%d\n", min_tensor_reorder_time, min_tensor_reorder_kparam.tile_x);
+            fflush(stdout);
+            printf("-------------------------\n");
+        }
+    });
+
+    free(src_cpu);
+    free(dst_cpu);
+    free(dst_gpu_valid);
+}

From 60d45644fa59fac9c7f08e0746cab74f34e4f3cc Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Wed, 26 Jan 2022 23:39:30 +0800
Subject: [PATCH 02/77] add all files

---
 src/hip/tensor_reorder_sol.cpp              | 287 +++
 src/include/miopen/tensor_reorder_sol.hpp   |  74 ++
 .../general_tensor_reorder.cpp              | 745 +++++++++++++++++
 src/kernels/gpu_tensor_reorder/sequence.hpp |  46 ++
 test/tensor_reorder.cpp                     | 762 ++++++++----------
 5 files changed, 1484 insertions(+), 430 deletions(-)
 create mode 100644 src/hip/tensor_reorder_sol.cpp
 create mode 100644 src/include/miopen/tensor_reorder_sol.hpp
 create mode 100644 src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp
 create mode 100644 src/kernels/gpu_tensor_reorder/sequence.hpp

diff --git a/src/hip/tensor_reorder_sol.cpp b/src/hip/tensor_reorder_sol.cpp
new file mode 100644
index 0000000000..abf6854b98
--- /dev/null
+++ b/src/hip/tensor_reorder_sol.cpp
@@ -0,0 +1,287 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2021 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define TENSOR_REORDER_BLOCK_SIZE 256 +#define TENSOR_REORDER_PERSISTENT 0 + +#if TENSOR_REORDER_PERSISTENT +#define TENSOR_REORDER_OCCUPANCY 4 +#endif + +namespace miopen { +namespace tensor_reorder { + +static inline std::string GetNameTrait(std::size_t type_size) +{ + if(type_size == 1) + return "byte"; + if(type_size == 2) + return "half"; + if(type_size == 4) + return "dword"; + MIOPEN_THROW("data type not supported"); +} + +static inline const std::vector& GetKernelList(std::size_t data_size) +{ + if(data_size == 1) + { + static const std::vector byte_kernel_list{ + // clang-format off + {1, 256, 1, 1, 1, 1}, + {2, 256, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {8, 256, 1, 1, 1, 1}, + {16, 256, 1, 1, 1, 1}, + // clang-format on + }; + return byte_kernel_list; + } + if(data_size == 2) + { + static const std::vector half_kernel_list{ + // clang-format off + {1, 256, 1, 1, 1, 1}, + {2, 256, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {8, 256, 1, 1, 1, 1}, + {16, 256, 1, 1, 1, 1}, + // clang-format on + }; + return half_kernel_list; + } + if(data_size == 4) + { + static const std::vector dword_kernel_list{ + // clang-format off + {1, 256, 1, 1, 1, 1}, + {2, 256, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {8, 256, 1, 1, 1, 1}, + {16, 256, 1, 1, 1, 1}, + // clang-format on + }; + return dword_kernel_list; + } + MIOPEN_THROW("data type not supported"); +} + +static inline bool IsApplicable(uint32_t /* batch */, + uint32_t height, + uint32_t width, + const TensorReorderParam* kparam) +{ + return width % kparam->ediv_x == 0 && height % kparam->ediv_y == 0; +} + +static inline bool IsSameSide(uint32_t height, uint32_t width, const TensorReorderParam* kparam) +{ + float radio = 0; + if(width > height) + radio = static_cast(kparam->tile_x) / kparam->tile_y; + else + radio = static_cast(kparam->tile_y) / kparam->tile_x; + + // E.g. for cases like width=1000, height=10 + // Allow at least 32x64, 64x64... 
+
+template <typename T>
+static inline float GetNormalizedRadio(T x, T y)
+{
+    if(y > x)
+        return static_cast<float>(y) / x;
+    return static_cast<float>(x) / y;
+}
+
+template <typename dst_order>
+static inline std::string GetKernelName(std::size_t data_size, const TensorReorderParam* kparam)
+{
+    std::ostringstream kernel_name;
+    std::string type_trait = GetNameTrait(data_size);
+    kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_";
+    if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1))
+    {
+        kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_"
+                    << kparam->ediv_x << "x" << kparam->ediv_y << "_";
+    }
+    // e.g. tile 16x256, fp32, order (0,1,3,2) -> "general_4d_reorder_16x256_dword_r0132"
+    kernel_name << type_trait << "_r" << dst_order::at(0) << dst_order::at(1) << dst_order::at(2)
+                << dst_order::at(3);
+    return kernel_name.str();
+}
+
+static inline std::size_t
+GetExtraPaddingSize(uint32_t height, uint32_t width, const TensorReorderParam* kparam)
+{
+    uint32_t padded_h = ((height + kparam->tile_y - 1) / kparam->tile_y) * kparam->tile_y;
+    uint32_t padded_w = ((width + kparam->tile_x - 1) / kparam->tile_x) * kparam->tile_x;
+    return static_cast<std::size_t>(padded_h) * padded_w -
+           static_cast<std::size_t>(height) * width;
+}
+
+static inline TensorReorderParam
+HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3)
+{
+    /*
+     * TODO:
+     * Design an algorithm to determine the general tensor reorder tile size.
+     */
+
+    // For now: pick the widest unroll factor the fastest-varying dimension can fill.
+    if(dim_3 >= 16)
+        return TensorReorderParam{16, 256, 1, 1, 1, 1};
+    else if(dim_3 >= 8)
+        return TensorReorderParam{8, 256, 1, 1, 1, 1};
+    else if(dim_3 >= 4)
+        return TensorReorderParam{4, 256, 1, 1, 1, 1};
+    else if(dim_3 >= 2)
+        return TensorReorderParam{2, 256, 1, 1, 1, 1};
+    else
+        return TensorReorderParam{1, 256, 1, 1, 1, 1};
+}
+
+} // namespace tensor_reorder
+
+template <typename dst_order>
+TensorReorderSolution<dst_order>::TensorReorderSolution(const ExecutionContext& ctx,
+                                                        miopenDataType_t data_type_,
+                                                        uint32_t dim_0_,
+                                                        uint32_t dim_1_,
+                                                        uint32_t dim_2_,
+                                                        uint32_t dim_3_)
+    : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_)
+{
+    if(data_type == miopenInt8x4 || data_type == miopenDouble)
+        MIOPEN_THROW("These data types are not supported");
+    num_cu                 = ctx.GetStream().GetMaxComputeUnits();
+    std::size_t data_size  = miopen::GetTypeSize(data_type);
+    kernel_param_heuristic = tensor_reorder::HeuristicGet(data_size, dim_0, dim_1, dim_2, dim_3);
+}
+
+template <typename dst_order>
+solver::KernelInfo TensorReorderSolution<dst_order>::GetKernel() const
+{
+    std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE;
+#if TENSOR_REORDER_PERSISTENT
+    std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY;
+#else
+    uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3;
+    uint32_t dim_total   = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) /
+                         (block_size * kernel_param_heuristic.tile_x);
+    std::size_t grid_size = dim_total;
+#endif
+    std::string kernel_name = GetKernelName();
+    solver::KernelInfo kernel;
+    kernel.kernel_file = "general_tensor_reorder.cpp";
+    kernel.kernel_name = kernel_name;
+    kernel.g_wk.clear();
+    kernel.g_wk.push_back(grid_size * block_size);
+    kernel.g_wk.push_back(1);
+    kernel.g_wk.push_back(1);
+    kernel.l_wk.clear();
+    kernel.l_wk.push_back(block_size);
+    kernel.l_wk.push_back(1);
+    kernel.l_wk.push_back(1);
+
+    MIOPEN_LOG_I2("TensorReorderSolution use kernel: " + kernel_name);
+
+    return kernel;
+}
+
+template <typename dst_order>
+std::vector<OpKernelArg> TensorReorderSolution<dst_order>::GetKernelArg() const
+{
+    std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE;
+    uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3;
+    uint32_t dim_total   = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) /
+                         (block_size * kernel_param_heuristic.tile_x);
+#if TENSOR_REORDER_PERSISTENT
+    std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY;
+#else
+    std::size_t grid_size = dim_total;
+#endif
+
+    magic_div_u32_t magic_stride0 = magic_div_u32_gen(dim_1 * dim_2 * dim_3);
+    magic_div_u32_t magic_stride1 = magic_div_u32_gen(dim_2 * dim_3);
+    magic_div_u32_t magic_stride2 = magic_div_u32_gen(dim_3);
+
+    std::vector<OpKernelArg> opArgs;
+    opArgs.emplace_back(0); // placeholder for dst
+    opArgs.emplace_back(0); // placeholder for src
+    opArgs.emplace_back(dim_0);
+    opArgs.emplace_back(dim_1);
+    opArgs.emplace_back(dim_2);
+    opArgs.emplace_back(dim_3);
+    opArgs.emplace_back(static_cast<uint32_t>(grid_size)); // dim_stride
+    opArgs.emplace_back(dim_total);
+    opArgs.emplace_back(magic_stride0.magic);
+    opArgs.emplace_back(static_cast<uint32_t>(magic_stride0.shift));
+    opArgs.emplace_back(magic_stride1.magic);
+    opArgs.emplace_back(static_cast<uint32_t>(magic_stride1.shift));
+    opArgs.emplace_back(magic_stride2.magic);
+    opArgs.emplace_back(static_cast<uint32_t>(magic_stride2.shift));
+
+    return opArgs;
+}
+
+template <typename dst_order>
+std::string TensorReorderSolution<dst_order>::GetKernelName() const
+{
+    std::size_t data_size = miopen::GetTypeSize(data_type);
+    return tensor_reorder::GetKernelName<dst_order>(data_size, &kernel_param_heuristic);
+}
+
+template <typename dst_order>
+bool TensorReorderSolution<dst_order>::IsSkippable() const
+{
+    // The reorder can be skipped only when the tensor is empty.
+    return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0;
+}
+
+template <typename dst_order>
+size_t TensorReorderSolution<dst_order>::GetSize() const
+{
+    return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3;
+}
+
+} // namespace miopen

diff --git a/src/include/miopen/tensor_reorder_sol.hpp b/src/include/miopen/tensor_reorder_sol.hpp
new file mode 100644
index 0000000000..4e777d97fb
--- /dev/null
+++ b/src/include/miopen/tensor_reorder_sol.hpp
@@ -0,0 +1,74 @@
+ * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_TENSOR_REORDER_SOL_HPP +#define GUARD_MIOPEN_TENSOR_REORDER_SOL_HPP + +#include +#include +#include +#include +#include + +namespace miopen { + +struct TensorReorderParam +{ + int tile_x{0}; + int tile_y{0}; + int pack_x{0}; + int pack_y{0}; + int ediv_x{0}; + int ediv_y{0}; +}; + +template +struct TensorReorderSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_); + solver::KernelInfo GetKernel() const; + std::vector GetKernelArg() const; + std::string GetKernelName() const; + bool IsSkippable() const; + size_t GetSize() const; + + miopenDataType_t data_type; + uint32_t dim_0; + uint32_t dim_1; + uint32_t dim_2; + uint32_t dim_3; + int num_cu; + + TensorReorderParam kernel_param_heuristic; +}; + +} // namespace miopen + +#endif diff --git a/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp b/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp new file mode 100644 index 0000000000..93c09f10c2 --- /dev/null +++ b/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp @@ -0,0 +1,745 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#include +#include +#include "sequence.hpp" + +#ifndef TENSOR_REORDER_OCCUPANCY +#define TENSOR_REORDER_OCCUPANCY 4 +#endif + +inline __device__ uint32_t magic_div_u32(const uint32_t& numer, + const uint32_t& magic, + const uint32_t& shift) +{ + uint32_t tmp = __umulhi(numer, magic); + return (tmp + numer) >> shift; +} + +template +inline __device__ void general_4d_reorder_1x256(T* dst, + T* src, + uint32_t dim_0, + uint32_t dim_1, + uint32_t dim_2, + uint32_t dim_3, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_stride0, + uint32_t shift_stride0, + uint32_t magic_stride1, + uint32_t shift_stride1, + uint32_t magic_stride2, + uint32_t shift_stride2) +{ + constexpr auto dorder = dst_order{}; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t src_index =0, dst_index=0; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + for (uint32_t k = 0; k < 1; k++) + { + //unroll k block thread + src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; + if(src_index < pixel_total){ + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); + i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; + + i_dst[0] = i_src[dorder.at(0)]; + i_dst[1] = i_src[dorder.at(1)]; + i_dst[2] = i_src[dorder.at(2)]; + i_dst[3] = i_src[dorder.at(3)]; + + dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; + dst[dst_index] = src[src_index]; + } + } + } +} + +template +inline __device__ void general_4d_reorder_2x256(T* dst, + T* src, + uint32_t dim_0, + uint32_t dim_1, + uint32_t dim_2, + uint32_t dim_3, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_stride0, + uint32_t shift_stride0, + uint32_t magic_stride1, + uint32_t shift_stride1, + uint32_t magic_stride2, + uint32_t shift_stride2) +{ + constexpr auto dorder = dst_order{}; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t src_index =0, dst_index=0; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; + + for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) + { + for (uint32_t k = 0; k < 2; k++) + { + //unroll k block thread + src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; + if(src_index < 
+
+template <typename T, typename dst_order>
+inline __device__ void general_4d_reorder_2x256(T* dst, T* src,
+                                                uint32_t dim_0, uint32_t dim_1,
+                                                uint32_t dim_2, uint32_t dim_3,
+                                                uint32_t dim_stride, uint32_t dim_total,
+                                                uint32_t magic_stride0, uint32_t shift_stride0,
+                                                uint32_t magic_stride1, uint32_t shift_stride1,
+                                                uint32_t magic_stride2, uint32_t shift_stride2)
+{
+    constexpr auto dorder = dst_order{};
+    uint32_t pixel_total  = dim_0 * dim_1 * dim_2 * dim_3;
+    uint32_t src_index = 0, dst_index = 0;
+    const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3};
+    const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)],
+                                 src_dim[dorder.at(2)], src_dim[dorder.at(3)]};
+    const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3],
+                                    src_dim[2] * src_dim[3],
+                                    src_dim[3],
+                                    1};
+    const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3],
+                                    dst_dim[2] * dst_dim[3],
+                                    dst_dim[3],
+                                    1};
+
+    uint32_t i_src[4] = {0, 0, 0, 0};
+    uint32_t i_dst[4] = {0, 0, 0, 0};
+
+    for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride)
+    {
+        for(uint32_t k = 0; k < 2; k++)
+        {
+            // unroll k: each iteration moves one block-sized slice of 256 pixels
+            src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x;
+            if(src_index < pixel_total)
+            {
+                i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0);
+                i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1);
+                i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2);
+                i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2];
+
+                i_dst[0] = i_src[dorder.at(0)];
+                i_dst[1] = i_src[dorder.at(1)];
+                i_dst[2] = i_src[dorder.at(2)];
+                i_dst[3] = i_src[dorder.at(3)];
+
+                dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3];
+                dst[dst_index] = src[src_index];
+            }
+        }
+    }
+}
+
+template <typename T, typename dst_order>
+inline __device__ void general_4d_reorder_4x256(T* dst, T* src,
+                                                uint32_t dim_0, uint32_t dim_1,
+                                                uint32_t dim_2, uint32_t dim_3,
+                                                uint32_t dim_stride, uint32_t dim_total,
+                                                uint32_t magic_stride0, uint32_t shift_stride0,
+                                                uint32_t magic_stride1, uint32_t shift_stride1,
+                                                uint32_t magic_stride2, uint32_t shift_stride2)
+{
+    constexpr auto dorder = dst_order{};
+    uint32_t pixel_total  = dim_0 * dim_1 * dim_2 * dim_3;
+    uint32_t src_index = 0, dst_index = 0;
+    const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3};
+    const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)],
+                                 src_dim[dorder.at(2)], src_dim[dorder.at(3)]};
+    const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3],
+                                    src_dim[2] * src_dim[3],
+                                    src_dim[3],
+                                    1};
+    const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3],
+                                    dst_dim[2] * dst_dim[3],
+                                    dst_dim[3],
+                                    1};
+
+    uint32_t i_src[4] = {0, 0, 0, 0};
+    uint32_t i_dst[4] = {0, 0, 0, 0};
+
+    for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride)
+    {
+        for(uint32_t k = 0; k < 4; k++)
+        {
+            // unroll k: each iteration moves one block-sized slice of 256 pixels
+            src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x;
+            if(src_index < pixel_total)
+            {
+                i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0);
+                i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1);
+                i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2);
+                i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2];
+
+                i_dst[0] = i_src[dorder.at(0)];
+                i_dst[1] = i_src[dorder.at(1)];
+                i_dst[2] = i_src[dorder.at(2)];
+                i_dst[3] = i_src[dorder.at(3)];
+
+                dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3];
+                dst[dst_index] = src[src_index];
+            }
+        }
+    }
+}
+
+template <typename T, typename dst_order>
+inline __device__ void general_4d_reorder_8x256(T* dst, T* src,
+                                                uint32_t dim_0, uint32_t dim_1,
+                                                uint32_t dim_2, uint32_t dim_3,
+                                                uint32_t dim_stride, uint32_t dim_total,
+                                                uint32_t magic_stride0, uint32_t shift_stride0,
+                                                uint32_t magic_stride1, uint32_t shift_stride1,
+                                                uint32_t magic_stride2, uint32_t shift_stride2)
+{
+    constexpr auto dorder = dst_order{};
+    uint32_t pixel_total  = dim_0 * dim_1 * dim_2 * dim_3;
+    uint32_t src_index = 0, dst_index = 0;
+    const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3};
+    const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)],
+                                 src_dim[dorder.at(2)], src_dim[dorder.at(3)]};
+    const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3],
+                                    src_dim[2] * src_dim[3],
+                                    src_dim[3],
+                                    1};
+    const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3],
+                                    dst_dim[2] * dst_dim[3],
+                                    dst_dim[3],
+                                    1};
+
+    uint32_t i_src[4] = {0, 0, 0, 0};
+    uint32_t i_dst[4] = {0, 0, 0, 0};
+
+    for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride)
+    {
+        for(uint32_t k = 0; k < 8; k++)
+        {
+            // unroll k: each iteration moves one block-sized slice of 256 pixels
+            src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x;
+            if(src_index < pixel_total)
+            {
+                i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0);
+                i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1);
+                i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2);
+                i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2];
+
+                i_dst[0] = i_src[dorder.at(0)];
+                i_dst[1] = i_src[dorder.at(1)];
+                i_dst[2] = i_src[dorder.at(2)];
+                i_dst[3] = i_src[dorder.at(3)];
+
+                dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3];
+                dst[dst_index] = src[src_index];
+            }
+        }
+    }
+}
+
+template <typename T, typename dst_order>
+inline __device__ void general_4d_reorder_16x256(T* dst, T* src,
+                                                 uint32_t dim_0, uint32_t dim_1,
+                                                 uint32_t dim_2, uint32_t dim_3,
+                                                 uint32_t dim_stride, uint32_t dim_total,
+                                                 uint32_t magic_stride0, uint32_t shift_stride0,
+                                                 uint32_t magic_stride1, uint32_t shift_stride1,
+                                                 uint32_t magic_stride2, uint32_t shift_stride2)
+{
+    constexpr auto dorder = dst_order{};
+    uint32_t pixel_total  = dim_0 * dim_1 * dim_2 * dim_3;
+    uint32_t src_index = 0, dst_index = 0;
+    const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3};
+    const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)],
+                                 src_dim[dorder.at(2)], src_dim[dorder.at(3)]};
+    const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3],
+                                    src_dim[2] * src_dim[3],
+                                    src_dim[3],
+                                    1};
+    const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3],
+                                    dst_dim[2] * dst_dim[3],
+                                    dst_dim[3],
+                                    1};
+
+    uint32_t i_src[4] = {0, 0, 0, 0};
+    uint32_t i_dst[4] = {0, 0, 0, 0};
+
+    for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride)
+    {
+        for(uint32_t k = 0; k < 16; k++)
+        {
+            // unroll k: each iteration moves one block-sized slice of 256 pixels
+            src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x;
+            if(src_index < pixel_total)
+            {
+                i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0);
+                i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1);
+                i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2);
+                i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2];
+
+                i_dst[0] = i_src[dorder.at(0)];
+                i_dst[1] = i_src[dorder.at(1)];
+                i_dst[2] = i_src[dorder.at(2)];
+                i_dst[3] = i_src[dorder.at(3)];
+
+                dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3];
+                dst[dst_index] = src[src_index];
+            }
+        }
+    }
+}
+
+#define DEFINE_GENERAL_4D_REORDER_KERNEL(tile_trait, dst_order, accept_data_type,                 \
+                                         cast_data_type, lb_threads_per_block, lb_blocks_per_cu)  \
+    extern "C" __global__ void __launch_bounds__(lb_threads_per_block, lb_blocks_per_cu)          \
+        general_4d_reorder_##tile_trait##_##accept_data_type##_##dst_order(                       \
+            void* dst, void* src,                                                                 \
+            uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3,                       \
+            uint32_t dim_stride, uint32_t dim_total,                                              \
+            uint32_t magic_stride0, uint32_t shift_stride0,                                       \
+            uint32_t magic_stride1, uint32_t shift_stride1,                                       \
+            uint32_t magic_stride2, uint32_t shift_stride2)                                       \
+    {                                                                                             \
+        general_4d_reorder_##tile_trait<cast_data_type, dst_order>(                               \
+            reinterpret_cast<cast_data_type*>(dst),                                               \
+            reinterpret_cast<cast_data_type*>(src),                                               \
+            dim_0, dim_1, dim_2, dim_3, dim_stride, dim_total,                                    \
+            magic_stride0, shift_stride0, magic_stride1, shift_stride1,                           \
+            magic_stride2, shift_stride2);                                                        \
+    }
+//default order is 0 1 2 3
+using r0132 = sequence<0, 1, 3, 2>;
+using r0213 = sequence<0, 2, 1, 3>;//nhwc2nchwc
+using r0231 = sequence<0, 2, 3, 1>;//nchw2nchwc
+using r0312 = sequence<0, 3, 1, 2>;//nhwc2nchw
+using r0321 = sequence<0, 3, 2, 1>;
+using r1023 = sequence<1, 0, 2, 3>;
+using r1032 = sequence<1, 0, 3, 2>;
+using r1203 = sequence<1, 2, 0, 3>;
+using r1230 = sequence<1, 2, 3, 0>;
+using r1302 = sequence<1, 3, 0, 2>;//nchw2chwnc
+using r1320 = sequence<1, 3, 2, 0>;
+using r2013 = sequence<2, 0, 1, 3>;
+using r2031 = sequence<2, 0, 3, 1>;
+using r2103 = sequence<2, 1, 0, 3>;//nhwc2chwnc
+using r2130 = sequence<2, 1, 3, 0>;
+using r2301 = sequence<2, 3, 0, 1>;
+using r2310 = sequence<2, 3, 1, 0>;
+using r3012 = sequence<3, 0, 1, 2>;
+using r3021 = sequence<3, 0, 2, 1>;
+using r3102 = sequence<3, 1, 0, 2>;
+using r3120 = sequence<3, 1, 2, 0>;
+using r3201 = sequence<3, 2, 0, 1>;
+using r3210 = sequence<3, 2, 1, 0>;
+
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0231, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0312, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0321, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1023, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1032, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1203, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1230, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1302, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1320, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2013, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2031, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2103, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2130, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2301, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2310, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3012, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3021, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3102, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3120, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3201, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3210, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+
+DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY)
+DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0213,
dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0231, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0312, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0321, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1023, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1032, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1203, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1230, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1302, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1320, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2013, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2031, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2103, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2130, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2301, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2310, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3012, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3021, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3102, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3120, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3201, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3210, dword, float, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0213, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0231, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0312, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0321, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1023, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1032, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1203, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1230, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1302, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1320, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2013, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2031, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2103, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2130, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2301, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2310, dword, float, 256, 
TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3012, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3021, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3102, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3120, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3201, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3210, dword, float, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0213, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0231, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0312, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0321, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1023, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1032, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1203, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1230, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1302, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1320, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2013, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2031, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2103, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2130, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2301, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2310, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3012, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3021, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3102, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3120, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3201, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3210, dword, float, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0213, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0231, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0312, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0321, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1023, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1032, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1203, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1230, dword, float, 256, 
TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1302, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1320, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2013, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2031, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2103, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2130, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2301, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2310, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3012, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3021, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3102, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, dword, float, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, dword, float, 256, TENSOR_REORDER_OCCUPANCY) + + +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0231, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0312, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0321, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1023, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1032, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1203, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1230, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1302, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1320, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2013, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2031, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2103, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2130, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2301, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2310, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3012, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3021, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3102, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3120, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3201, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3210, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0132, half, ushort, 256, 
TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0213, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0231, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0312, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0321, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1023, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1032, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1203, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1230, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1302, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1320, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2013, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2031, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2103, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2130, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2301, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2310, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3012, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3021, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3102, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3120, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3201, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3210, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0132, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0213, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0231, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0312, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0321, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1023, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1032, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1203, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1230, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1302, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1320, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2013, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2031, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2103, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2130, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2301, half, ushort, 256, 
TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2310, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3012, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3021, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3102, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3120, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3201, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3210, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0132, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0213, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0231, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0312, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0321, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1023, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1032, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1203, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1230, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1302, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1320, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2013, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2031, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2103, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2130, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2301, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2310, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3012, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3021, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3102, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3120, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3201, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3210, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0132, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0213, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0231, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0312, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0321, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1023, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1032, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1203, half, ushort, 256, 
TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1230, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1302, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1320, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2013, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2031, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2103, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2130, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2301, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2310, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3012, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3021, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3102, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) + + +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0231, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0312, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0321, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1023, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1032, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1203, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1230, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1302, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1320, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2013, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2031, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2103, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2130, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2301, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2310, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3012, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3021, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3102, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3120, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3201, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3210, byte, uchar, 256, 
TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0132, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0213, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0231, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0312, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0321, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1023, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1032, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1203, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1230, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1302, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1320, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2013, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2031, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2103, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2130, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2301, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2310, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3012, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3021, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3102, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3120, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3201, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3210, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0132, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0213, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0231, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0312, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0321, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1023, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1032, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1203, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1230, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1302, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1320, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2013, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2031, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2103, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2130, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) 
+DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2301, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2310, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3012, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3021, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3102, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3120, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3201, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3210, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0132, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0213, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0231, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0312, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0321, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1023, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1032, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1203, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1230, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1302, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1320, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2013, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2031, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2103, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2130, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2301, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2310, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3012, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3021, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3102, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3120, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3201, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3210, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0132, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0213, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0231, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0312, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0321, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1023, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1032, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, 
r1203, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1230, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1302, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1320, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2013, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2031, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2103, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2130, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2301, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2310, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3012, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3021, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3102, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) \ No newline at end of file diff --git a/src/kernels/gpu_tensor_reorder/sequence.hpp b/src/kernels/gpu_tensor_reorder/sequence.hpp new file mode 100644 index 0000000000..cf83e64012 --- /dev/null +++ b/src/kernels/gpu_tensor_reorder/sequence.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef SEQUENCE_HPP +#define SEQUENCE_HPP + +template +struct sequence +{ + static constexpr int m_size = sizeof...(Is); + + __host__ __device__ static constexpr auto size() { return m_size; } + + __host__ __device__ static constexpr auto get_size() { return size(); } + + __host__ __device__ static constexpr int at(int I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const int m_data[m_size + 1] = {Is..., 0}; + return m_data[I]; + } + +}; +#endif \ No newline at end of file diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index f2e5aeaac8..afb57cba3d 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -23,197 +23,33 @@ * SOFTWARE. * *******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "gpu_tensor_reorder.h" +#include +#include +#include "test.hpp" +#include "driver.hpp" +#include "random.hpp" #include "sequence.hpp" -#ifndef HIP_CALL -#define HIP_CALL(call) \ - do { \ - hipError_t err = call; \ - if (err != hipSuccess) { \ - printf("[hiperror](%d) fail to call %s,(%s)\n", (int)err, #call, \ - hipGetErrorString(err)); \ - exit(1); \ - } \ - } while (0) -#endif - -static inline int env_get_int(const char *var_name, int default_int) { - char *v = getenv(var_name); - int r = default_int; - if (v) - r = atoi(v); - return r; -} - -static int gen_rand_integer() +template <> +struct miopen_type : std::integral_constant { - static int inited = 0; - if(inited == 0) - { - std::srand(std::time(nullptr)); - inited = 1; - } - return std::rand(); -} - - -static inline char *env_get_str(char *var_name, char* default_str) { - char *v = getenv(var_name); - if (v) - return v; - return default_str; -} - -template -struct distribution_t{ }; template <> -struct distribution_t{ - distribution_t(int min, int max) : distribution(min, max) {} - template - int8_t operator()(URNG & rng){ - int value = distribution(rng); - return *reinterpret_cast(&value); - //return 0xf; - } - std::uniform_int_distribution distribution; -}; -template <> -struct distribution_t{ - distribution_t(int min, int max) : distribution(min, max) {} - template - int operator()(URNG & rng){ return distribution(rng);} - std::uniform_int_distribution distribution; -}; -template <> -struct distribution_t{ - distribution_t(float min, float max) : distribution(min, max) {} - template - float operator()(URNG & rng){ return distribution(rng);} - std::uniform_real_distribution distribution; -}; - -template -void block_wise_rand_generator(Dst_T *p, int tid, int block_size, int total_size, Src_T min, Src_T max, Src_T scale) -{ - std::mt19937 rng(std::chrono::system_clock::now() - .time_since_epoch() - .count() + - std::hash()(std::this_thread::get_id())); - distribution_t distribution(min,max); - for (int i = tid; i < total_size; i += block_size) { - p[i] = static_cast(scale * distribution(rng)); - } -} - -template -void gen_rand_vector(Dst_T *vec, size_t vec_size, Src_T fmin, Src_T fmax, Src_T scale = 1) { - int num_threads = std::thread::hardware_concurrency(); - if (num_threads < 4) - num_threads = 4; - // printf("total threads:%d\n",num_threads); - std::vector threads; - for (int t = 0; t < 
num_threads; t++) { - threads.push_back(std::thread(block_wise_rand_generator, - vec, t, num_threads, vec_size, fmin, fmax, scale)); - } - for (auto &th : threads) - th.join(); -} - -static inline bool valid_float(float p) +struct miopen_type : std::integral_constant { - return !(std::isnan(p) || std::isinf(p)); -} -#ifndef ABS -#define ABS(b) ((b) > 0 ? (b) : -1 * (b)) -#endif -static inline bool valid_vector(const float *ref, const float *pred, int n, - double nrms = 1.5e-6) { - double s0 = 0.0; - double s1 = 0.0; - int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); - int igemm_per_pixel_check_print = env_get_int("PER_PIXEL_CHECK_PRINT", 1); - int pp_err = 0; - - for (int i = 0; i < n; ++i) { - if(!(valid_float(ref[i]) && valid_float(pred[i]))){ - printf(" invalid float at %d, ref:%f, pred:%f\n", i, ref[i], pred[i]); - return -1; - } - double ri = (double)ref[i]; - double pi = (double)pred[i]; - double d = ri - pi; - double dd = d * d; - double rr = 2.0 * ri * ri; - s0 += dd; - s1 += rr; - if(igemm_per_pixel_check){ - double delta = ABS(ABS(ri - pi) / ri); - printf("[%d] ref:%lf, pred:%lf(0x%08x) [%s]\n", i, ri, pi, ((uint32_t *)pred)[i], delta > 3e-5? "N":"Y"); - if (delta > 3e-5) { - if(igemm_per_pixel_check_print){ - if (pp_err < 100) - printf("diff at %d, ref:%lf, pred:%lf(0x%08x), d:%lf\n", i, ri, - pi, ((uint32_t *)pred)[i], delta); - } - pp_err++; - } - - } - } - // printf("\nnrms:%lf, s0:%lf, s1:%lf, expected_nrms is %1f\n",sqrt(s0/s1),s0,s1,nrms); - fflush(stdout); - return (sqrt(s0 / s1) < nrms) -#ifdef PER_PIXEL_CHECK - && (pp_err == 0) -#endif - ; -} - -static inline bool valid_vector_binary(int8_t *ref, int8_t *pred, size_t bytes) { - int igemm_per_pixel_check = env_get_int("PER_PIXEL_CHECK", 0); - size_t err = 0; - for(int i = 0; i < bytes ; i++){ - // { - // uint32_t r = 0; - // uint32_t p = 0; - // memcpy(reinterpret_cast(&r), reinterpret_cast(&ref[i]), 1); - // memcpy(reinterpret_cast(&p), reinterpret_cast(&pred[i]), 1); - // printf("%7d, ref:0x%x, pred:0x%x, %s\n", i, r, p, r==p?"y":"n"); - // } - if(ref[i] != pred[i]){ - err ++; - if(igemm_per_pixel_check){ - uint32_t r = 0; - uint32_t p = 0; - memcpy(reinterpret_cast(&r), reinterpret_cast(&ref[i]), 1); - memcpy(reinterpret_cast(&p), reinterpret_cast(&pred[i]), 1); - printf("fail at %d, ref:0x%x, pred:0x%x\n", i, r, p); - } - } - } - return err == 0; -} +}; template @@ -232,7 +68,7 @@ void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64 dst_dim[3], 1 }; - uint64_t itr_src_dim[4] = {0, 0, 0, 0}; + uint64_t itr_src_dim[4] = {0, 0, 0, 0}; uint64_t itr_dst_dim[4] = {0, 0, 0, 0}; for(itr_src_dim[0] = 0; itr_src_dim[0] < src_dim[0]; itr_src_dim[0]++){ @@ -260,6 +96,128 @@ void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64 } } +template +struct cpu_reorder +{ + static void run(T* dst, T* src, uint64_t N, uint64_t C, uint64_t H, uint64_t W) + { + cpu_tensor_reorder(dst, src, N, C, H, W); + } +}; + +template +struct reorder_str +{ + static std::string get() { + return ("r" + itoa(dst_order::at(0)) + + itoa(dst_order::at(1)) + + itoa(dst_order::at(2)) + + itoa(dst_order::at(3)) ); + } +}; + +enum tensor_layout_t +{ + miopen_tensor_layout_nchw, + miopen_tensor_layout_ncdhw, + miopen_tensor_layout_nhwc, + miopen_tensor_layout_ndhwc, +}; + +std::string tensor_layout_to_string(tensor_layout_t layout) +{ + std::string layout_string("N/A"); + if(layout == miopen_tensor_layout_nchw) + layout_string = "NCHW"; + else if(layout == miopen_tensor_layout_ncdhw) + 
layout_string = "NCDHW"; + else if(layout == miopen_tensor_layout_nhwc) + layout_string = "NHWC"; + else if(layout == miopen_tensor_layout_ndhwc) + layout_string = "NDHWC"; + else + MIOPEN_THROW("Unsupported tensor layout"); + return layout_string; +} + + +template +struct to_miopen_data_type +{ +}; + +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenFloat; } +}; + +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenHalf; } // we actually didn't calculate 16bit float +}; + +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenInt8; } +}; + +#define RAND_INTEGER_MAX 120 +#define RAND_INTEGER_MIN -88 + +static int gen_rand_integer() +{ + // NOLINTNEXTLINE (cppcoreguidelines-avoid-non-const-global-variables) + static int inited = 0; + if(inited == 0) + { + std::srand(std::time(nullptr)); + inited = 1; + } + return GET_RAND(); +} + +template +void rand_tensor_integer(tensor& t, int max = RAND_INTEGER_MAX, int min = RAND_INTEGER_MIN) +{ + // use integer to random. + for(int i = 0; i < t.data.size(); i++) + t[i] = static_cast(gen_rand_integer() % (max - min) + min); +} + +template +bool compare_equal(T r1, T r2) +{ + return r1 == r2; +} + +template <> +bool compare_equal(float r1, float r2) +{ + return miopen::float_equal(r1, r2); +} + +template +bool verify_tensor(tensor& t_gpu, tensor& t_cpu) +{ + if(t_gpu.data.size() != t_cpu.data.size()) + { + MIOPEN_LOG_E("size not equal, should not happen"); + return false; + } + auto idx = miopen::mismatch_idx(t_gpu.data, t_cpu.data, compare_equal); + bool valid_result = idx >= miopen::range_distance(t_cpu); + + if(!valid_result) + { + std::cout << "diff at:" << idx << ", gpu:" << t_gpu[idx] << ", cpu:" << t_cpu[idx] + << std::endl; + } + return valid_result; +} + //compile time for_loop namespace detail { @@ -275,268 +233,212 @@ constexpr void loop(F&& f) { detail::loop(std::make_integer_sequence{}, std::forward(f)); } -#define WARMUP 3 -#define REPEAT 7 -#define BATCHED_TRANSPOSE_HSACO "out/batched_transpose.hsaco" -#define GENERAL_TENSOR_REORDER_HSACO "out/general_tensor_reorder.hsaco" +struct reorder_base +{ + miopenHandle_t handle{}; +#if MIOPEN_BACKEND_OPENCL + cl_command_queue q{}; +#endif -int main(int argc, char ** argv){ - if(argc < 5){ - printf("%s Please input tensor size in order of: DIM0, DIM1, DIM2, DIM3\n", argv[0]); - return -1; - } - if(argc > 5){ - printf("Too many argument\n"); - return -1; - } - int warmup = env_get_int("IGEMM_WARMUP", WARMUP); - int repeat = env_get_int("IGEMM_REPEAT", REPEAT); - const uint64_t dim_0 = std::stoull(std::string(argv[1])); - const uint64_t dim_1 = std::stoull(std::string(argv[2])); - const uint64_t dim_2 = std::stoull(std::string(argv[3])); - const uint64_t dim_3 = std::stoull(std::string(argv[4])); - - size_t size_byte = 4; - const char* fp = env_get_str("FP", "32"); - std::string fp_str(fp); - if(fp_str == "32") - size_byte = 4; - else if(fp_str == "16") - size_byte = 2; - else if(fp_str == "8") - size_byte = 1; - else{ - printf("error FP:%s\n", fp); - return -1; + reorder_base() + { + miopenCreate(&handle); +#if MIOPEN_BACKEND_OPENCL + miopenGetStream(handle, &q); +#endif } + ~reorder_base() { miopenDestroy(handle); } - bool batched = false; - bool is_kernel_valid = false; - const char* hsaco; - void * src_cpu = malloc(dim_0*dim_1*dim_2*dim_3*size_byte); - void * dst_cpu = malloc(dim_0*dim_1*dim_2*dim_3*size_byte); - void * dst_gpu_valid = malloc(dim_0*dim_1*dim_2*dim_3*size_byte); + 
static std::vector get_dim_3_size() { return {1, 9, 14}; } + static std::vector get_dim_2_size() { return {1, 9, 14}; } + static std::vector get_dim_1_size() { return {3, 8, 14}; } + static std::vector get_dim_0_size() { return {1, 2}; } - void * src_gpu; - void * dst_gpu; - - HIP_CALL(hipMalloc(&src_gpu, dim_0*dim_1*dim_2*dim_3*size_byte)); - HIP_CALL(hipMalloc(&dst_gpu, dim_0*dim_1*dim_2*dim_3*size_byte)); - - gen_rand_vector(reinterpret_cast(src_cpu), dim_0*dim_1*dim_2*dim_3*size_byte, -116, 121); - HIP_CALL(hipMemcpy(src_gpu, src_cpu, dim_0*dim_1*dim_2*dim_3*size_byte, hipMemcpyHostToDevice)); - -loop([&](auto i) { - constexpr int all_possible_sequence[23][4] = { - {0, 1, 3, 2}, {2, 3, 0, 1}, {3, 0, 1, 2}, {0, 2, 3, 1}, {0, 3, 1, 2}, //BATCHED TRANSPOSE - {0, 2, 1, 3}, {0, 3, 2, 1}, - {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, - {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 1, 0}, - {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; - using dst_order = sequence; - std::cout <<" Tensor reorder to ("<< dst_order::at(0)<<","<< dst_order::at(1)<<","<< dst_order::at(2)<<","<< dst_order::at(3)<<")" << std::endl; - - //TODO: an API with more privacy - auto launch_gpu_init = [&](){ - if((dst_order::at(0)==0 && dst_order::at(1)==1 && dst_order::at(2)==3 && dst_order::at(3)==2) || - (dst_order::at(0)==0 && dst_order::at(1)==2 && dst_order::at(2)==3 && dst_order::at(3)==1) || - (dst_order::at(0)==0 && dst_order::at(1)==3 && dst_order::at(2)==1 && dst_order::at(3)==2) || - (dst_order::at(0)==3 && dst_order::at(1)==0 && dst_order::at(2)==1 && dst_order::at(3)==2) || - (dst_order::at(0)==2 && dst_order::at(1)==3 && dst_order::at(2)==0 && dst_order::at(3)==1) - ){ - printf("choose batched transpose kernel\n"); - batched = true; - //batched transpose. 
NCHW <----> NHWC, (NC)cHW <----> (NC)HWc - hsaco = env_get_str("BATCHED_TRANSPOSE", BATCHED_TRANSPOSE_HSACO); - gpu_nhwc_nchw_transpose_init(hsaco); - } - else { - printf("choose general tensor reorder kernel\n"); - hsaco = env_get_str("GENERAL_TENSOR_REORDER_HSACO", GENERAL_TENSOR_REORDER_HSACO); - gpu_tensor_reorder_init(hsaco); - } - }; - - auto launch_gpu_tensor_reorder = [&](const transpose_kernel_param_t * kparam){ - if(fp_str == "32") - gpu_tensor_reorder(reinterpret_cast (dst_gpu), reinterpret_cast (src_gpu), dim_0, dim_1, dim_2, dim_3, kparam); - else if(fp_str == "16") - gpu_tensor_reorder(reinterpret_cast(dst_gpu), reinterpret_cast(src_gpu), dim_0, dim_1, dim_2, dim_3, kparam); - else if(fp_str == "8") - gpu_tensor_reorder(reinterpret_cast(dst_gpu), reinterpret_cast(src_gpu), dim_0, dim_1, dim_2, dim_3, kparam); - }; - - auto launch_cpu_tensor_reorder = [&](){ - if(fp_str == "32") - cpu_tensor_reorder(reinterpret_cast (dst_cpu), reinterpret_cast (src_cpu), dim_0, dim_1, dim_2, dim_3); - else if(fp_str == "16") - cpu_tensor_reorder(reinterpret_cast(dst_cpu), reinterpret_cast(src_cpu), dim_0, dim_1, dim_2, dim_3); - else if(fp_str == "8") - cpu_tensor_reorder(reinterpret_cast(dst_cpu), reinterpret_cast(src_cpu), dim_0, dim_1, dim_2, dim_3); - }; - - auto test_batched_transpose = [&](const transpose_kernel_param_t *transpose_kparam){ - float kernel_time = 0; - bool valid = false; - bool is_kernel_valid = false; - - if(dst_order::at(0)==0 && dst_order::at(1)==2 && dst_order::at(2)==3 && dst_order::at(3)==1){ - is_kernel_valid = transpose_kernel_is_valid(dim_0, dim_1, dim_2 * dim_3, transpose_kparam); - } - else if(dst_order::at(0)==0 && dst_order::at(1)==1 && dst_order::at(2)==3 && dst_order::at(3)==2){ - is_kernel_valid = transpose_kernel_is_valid(dim_0 * dim_1, dim_2, dim_3, transpose_kparam); - } - else if(dst_order::at(0)==0 && dst_order::at(1)==3 && dst_order::at(2)==1 && dst_order::at(3)==2){ - is_kernel_valid = transpose_kernel_is_valid(dim_0, dim_1 * dim_2, dim_3, transpose_kparam); - } - else if(dst_order::at(0)==3 && dst_order::at(1)==0 && dst_order::at(2)==1 && dst_order::at(3)==2){ - is_kernel_valid = transpose_kernel_is_valid(1, dim_0 * dim_1 * dim_2, dim_3, transpose_kparam); - } - //dst_order::at(0)==2 && dst_order::at(1)==3 && dst_order::at(2)==0 && dst_order::at(3)==1 - else{ - is_kernel_valid = transpose_kernel_is_valid(1, dim_0 * dim_1, dim_2 * dim_3, transpose_kparam); - } - if(is_kernel_valid){ - hipEvent_t start, stop; - HIP_CALL(hipMemset(dst_gpu, 0, dim_0*dim_1*dim_2*dim_3*size_byte)); - - for(int i=0; i< warmup; i++){ - launch_gpu_tensor_reorder(transpose_kparam); - } - - HIP_CALL(hipEventCreate(&start)); - HIP_CALL(hipEventCreate(&stop)); - HIP_CALL(hipDeviceSynchronize()); - HIP_CALL(hipEventRecord(start, 0) ); - - for(int i=0; i< repeat; i++){ - launch_gpu_tensor_reorder(transpose_kparam); + template + void iterate_reorder(F f) + { + std::vector dim_3_list = get_dim_3_size(); + std::vector dim_2_list = get_dim_2_size(); + std::vector dim_1_list = get_dim_1_size(); + std::vector dim_0_list = get_dim_0_size(); + + dim_3_list.push_back(gen_rand_integer() % 13 + 29); + dim_2_list.push_back(gen_rand_integer() % 13 + 29); + dim_1_list.push_back(gen_rand_integer() % 13 + 15); + dim_0_list.push_back(gen_rand_integer() % 4 + 3); + + for(uint32_t dim_3 : dim_3_list) + { + for(uint32_t dim_2 : dim_2_list) + { + for(uint32_t dim_1 : dim_1_list) + { + for(uint32_t dim_0 : dim_0_list) + { + f(dim_0, dim_1, dim_2, dim_3); + } + } } - HIP_CALL(hipEventRecord(stop, 0) ); 
- HIP_CALL(hipEventSynchronize(stop) ); - HIP_CALL(hipEventElapsedTime(&kernel_time, start, stop) ); - HIP_CALL(hipEventDestroy(start) ); - HIP_CALL(hipEventDestroy(stop) ); - kernel_time = kernel_time / repeat; - - launch_cpu_tensor_reorder(); - - HIP_CALL(hipMemcpy(dst_gpu_valid, dst_gpu, dim_0*dim_1*dim_2*dim_3*size_byte, hipMemcpyDeviceToHost)); - - valid = valid_vector_binary(reinterpret_cast(dst_cpu), reinterpret_cast(dst_gpu_valid), dim_0*dim_1*dim_2*dim_3*size_byte); } + } +}; - double flop_cnt = 2 * dim_0*dim_1*dim_2*dim_3*size_byte; - double bw = is_kernel_valid ? flop_cnt / kernel_time / 1e6 : 0; - - printf("[tensor_reorder fp%s] tensor_size:(%lu, %lu, %lu, %lu), flop:%.0f, time:%fms, bw:%.4fGB/s, valid:%s (%dx%d, %dx%d, %dx%d)\n", - fp_str.c_str(), dim_0, dim_1, dim_2, dim_3, flop_cnt, kernel_time, bw, is_kernel_valid ? (valid ? "y" : "n") : "x", - transpose_kparam->tile_x, transpose_kparam->tile_y, transpose_kparam->pack_x, transpose_kparam->pack_y, transpose_kparam->ediv_x, transpose_kparam->ediv_y); - fflush(stdout); +struct reorder_invoke_param : public miopen::InvokeParams +{ + ConstData_t src = nullptr; + Data_t dst = nullptr; - return valid && is_kernel_valid ? kernel_time : FLT_MAX; - }; + reorder_invoke_param(ConstData_t src_, Data_t dst_) : src(src_), dst(dst_) {} + reorder_invoke_param(miopen::InvokeType type_, ConstData_t src_, Data_t dst_) + : InvokeParams{type_}, src(src_), dst(dst_) + { + } +}; - auto test_general_tensor_reorder = [&](const transpose_kernel_param_t *transpose_kparam){ - float kernel_time = 0; - bool valid = false; +template +struct reorder_test : reorder__base +{ + void run() + { + auto run_reorder = [this](uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3) { + int tensor_sz = dim_0 * dim_1 * dim_2 * dim_3; + std::vector tensor_len({static_cast(dim_0), + static_cast(dim_1), + static_cast(dim_2), + static_cast(dim_3)}); + + std::vector tensor_strides; + + std::string layout_default = miopen::tensor_layout_get_default(4); + std::string layout_string = tensor_layout_to_string(miopen_tensor_layout_nchw); + + miopen::tensor_layout_to_strides( + tensor_len, layout_default, layout_string, tensor_strides); + + tensor t_src(tensor_len, tensor_strides); + tensor t_dst(tensor_len, tensor_strides); + tensor t_dst_gpu(tensor_len, tensor_strides); + rand_tensor_integer(t_src); +#if MIOPEN_BACKEND_OPENCL + cl_context cl_ctx; + clGetCommandQueueInfo(q, CL_QUEUE_CONTEXT, sizeof(cl_context), &cl_ctx, nullptr); + cl_int status = CL_SUCCESS; + cl_mem src_dev = + clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, &status); + cl_mem dst_dev = + clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, nullptr); + status |= clEnqueueWriteBuffer(q, + src_dev, + CL_TRUE, + 0, + sizeof(T) * tensor_sz, + t_src.data.data(), + 0, + nullptr, + nullptr); + EXPECT(status == CL_SUCCESS); +#elif MIOPEN_BACKEND_HIP + void* src_dev; + void* dst_dev; + EXPECT(hipMalloc(&src_dev, sizeof(T) * tensor_sz) == hipSuccess); + EXPECT(hipMalloc(&dst_dev, sizeof(T) * tensor_sz) == hipSuccess); + EXPECT(hipMemcpy( + src_dev, t_src.data.data(), sizeof(T) * tensor_sz, hipMemcpyHostToDevice) == + hipSuccess); +#endif - bool is_kernel_valid = true; - if(is_kernel_valid){ - hipEvent_t start, stop; - HIP_CALL(hipMemset(dst_gpu, 0, dim_0*dim_1*dim_2*dim_3*size_byte)); + const auto invoke_param = reorder_invoke_param{ + DataCast(static_cast(src_dev)), DataCast(dst_dev)}; + + miopen::ExecutionContext ctx; + ctx.SetStream(&miopen::deref(this->handle)); 
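+ // What follows mirrors MIOpen's normal dispatch path: build the reorder solution, wrap its kernel in an invoker, then run the invoker with the device buffers.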
+ ctx.DetectRocm(); + // ctx.SetupFloats(); + + REORDER_SOL reorder_sol(ctx, to_miopen_data_type::get(), dim_0, dim_1, dim_2, dim_3); + + std::vector opArgs = reorder_sol.GetKernelArg(); + + boost::optional invoker_factory( + [=](const std::vector& kernels) mutable { + return [=](const miopen::Handle& handle, + const miopen::AnyInvokeParams& primitive_param) mutable { + decltype(auto) invoke_params = + primitive_param.CastTo(); + + const auto k = handle.Run(kernels[0]); + + opArgs[0] = OpKernelArg(invoke_params.dst); + opArgs[1] = OpKernelArg(invoke_params.src); + + k(opArgs); + }; + }); + + std::vector construction_params{reorder_sol.GetKernel()}; + + const auto invoker = + miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params); + + // run gpu + invoker(miopen::deref(this->handle), invoke_param); + + // run cpu + cpu_reorder::run(t_dst.data.data(), t_src.data.data(), dim_0, dim_1, dim_2, dim_3); + +#if MIOPEN_BACKEND_OPENCL + status = clEnqueueReadBuffer(q, + dst_dev, + CL_TRUE, + 0, + sizeof(T) * tensor_sz, + t_dst_gpu.data.data(), + 0, + nullptr, + nullptr); + EXPECT(status == CL_SUCCESS); +#elif MIOPEN_BACKEND_HIP + EXPECT(hipMemcpy(t_dst_gpu.data.data(), + dst_dev, + sizeof(T) * tensor_sz, + hipMemcpyDeviceToHost) == hipSuccess); +#endif - for(int i=0; i< warmup; i++){ - launch_gpu_tensor_reorder(transpose_kparam); - } + // we expect an exact match, since integer data is used + bool valid_result = verify_tensor(t_dst_gpu, t_dst); - HIP_CALL(hipEventCreate(&start)); - HIP_CALL(hipEventCreate(&stop)); - HIP_CALL(hipDeviceSynchronize()); - HIP_CALL(hipEventRecord(start, 0) ); + std::cout << "[" << reorder_str::get() << ", b" << (sizeof(T) * 8) + << " ] " + << "dim_0:" << dim_0 << ", dim_1:" << dim_1 << ", dim_2:" << dim_2 << ", dim_3:" << dim_3 + << ", valid:" << valid_result << std::endl; - for(int i=0; i< repeat; i++){ - launch_gpu_tensor_reorder(transpose_kparam); - } - HIP_CALL(hipEventRecord(stop, 0) ); - HIP_CALL(hipEventSynchronize(stop) ); - HIP_CALL(hipEventElapsedTime(&kernel_time, start, stop) ); - HIP_CALL(hipEventDestroy(start) ); - HIP_CALL(hipEventDestroy(stop) ); - kernel_time = kernel_time / repeat; + EXPECT(valid_result == true); - launch_cpu_tensor_reorder(); +#if MIOPEN_BACKEND_OPENCL + clReleaseMemObject(src_dev); + clReleaseMemObject(dst_dev); +#elif MIOPEN_BACKEND_HIP + hipFree(src_dev); + hipFree(dst_dev); +#endif - HIP_CALL(hipMemcpy(dst_gpu_valid, dst_gpu, dim_0*dim_1*dim_2*dim_3*size_byte, hipMemcpyDeviceToHost)); + }; - valid = valid_vector_binary(reinterpret_cast(dst_cpu), reinterpret_cast(dst_gpu_valid), dim_0*dim_1*dim_2*dim_3*size_byte); - } + iterate_reorder(run_reorder); + } +}; - double flop_cnt = 2 * dim_0*dim_1*dim_2*dim_3*size_byte; - double bw = is_kernel_valid ? flop_cnt / kernel_time / 1e6 : 0; - - printf("[tensor_reorder fp%s] tensor_size:(%lu, %lu, %lu, %lu), flop:%.0f, time:%fms, bw:%.4fGB/s, valid:%s (256x%d)\n", - fp_str.c_str(), dim_0, dim_1, dim_2, dim_3, flop_cnt, kernel_time, bw, is_kernel_valid ? (valid ? "y" : "n") : "x", - transpose_kparam->tile_x); - fflush(stdout); - - return valid && is_kernel_valid ? 
kernel_time : FLT_MAX; - }; - - auto get_transpose_all_kernel = [&](){ - if(fp_str == "32") - return transpose_kernel_get_all_param_t<4>::get(); - else if(fp_str == "16") - return transpose_kernel_get_all_param_t<2>::get(); - else if(fp_str == "8") - return transpose_kernel_get_all_param_t<1>::get(); - else - assert(false); - }; - - auto get_tensor_reorder_all_kernel = [&](){ - if(fp_str == "32") - return tensor_reorder_kernel_get_all_param_t<4>::get(); - else if(fp_str == "16") - return tensor_reorder_kernel_get_all_param_t<2>::get(); - else if(fp_str == "8") - return tensor_reorder_kernel_get_all_param_t<1>::get(); - else - assert(false); - }; - - batched = false; - launch_gpu_init(); - float min_tensor_reorder_time = FLT_MAX; - transpose_kernel_param_t min_tensor_reorder_kparam; - if(batched){ - for(auto kparam : get_transpose_all_kernel()){ - float current_time = test_batched_transpose(&kparam); - if(current_time < min_tensor_reorder_time){ - min_tensor_reorder_time = current_time; - min_tensor_reorder_kparam = kparam; - } - } - printf("-> min time:%fms, kparam: %dx%d, %dx%d, %dx%d\n", min_tensor_reorder_time, - min_tensor_reorder_kparam.tile_x, min_tensor_reorder_kparam.tile_y, min_tensor_reorder_kparam.pack_x, min_tensor_reorder_kparam.pack_y, min_tensor_reorder_kparam.ediv_x, min_tensor_reorder_kparam.ediv_y); - fflush(stdout); - printf("-------------------------\n"); - } - else{ - for(auto kparam : get_tensor_reorder_all_kernel()){ - float current_time = test_general_tensor_reorder(&kparam); - if(current_time < min_tensor_reorder_time){ - min_tensor_reorder_time = current_time; - min_tensor_reorder_kparam = kparam; - } - } - printf("-> min time:%fms, kparam: 256x%d\n", min_tensor_reorder_time, min_tensor_reorder_kparam.tile_x); - fflush(stdout); - printf("-------------------------\n"); - } +int main() +{ +loop([&](auto i) { + constexpr int all_possible_sequence[23][4] = { + {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, + {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, + {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, + {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; + using dst_order = sequence; + run_test >>(); + run_test >>(); + run_test >>(); }); - - free(src_cpu); - free(dst_cpu); - free(dst_gpu_valid); -} +} \ No newline at end of file From 569044f7e6bf1554cc28c07681e302f67cdeaf34 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 12:05:20 +0800 Subject: [PATCH 03/77] fix some bugs and try --- ...sol.cpp => general_tensor_reorder_sol.cpp} | 59 ++++---- ...sol.hpp => general_tensor_reorder_sol.hpp} | 8 +- src/include/miopen/tensor_reorder_util.hpp | 139 ++++++++++++++++++ test/sequence.hpp | 46 ++++++ test/tensor_reorder.cpp | 6 +- 5 files changed, 222 insertions(+), 36 deletions(-) rename src/hip/{tensor_reorder_sol.cpp => general_tensor_reorder_sol.cpp} (83%) rename src/include/miopen/{tensor_reorder_sol.hpp => general_tensor_reorder_sol.hpp} (93%) create mode 100644 src/include/miopen/tensor_reorder_util.hpp create mode 100644 test/sequence.hpp diff --git a/src/hip/tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp similarity index 83% rename from src/hip/tensor_reorder_sol.cpp rename to src/hip/general_tensor_reorder_sol.cpp index abf6854b98..5c4f7b12d1 100644 --- a/src/hip/tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -24,7 +24,7 @@ * 
*******************************************************************************/ -#include +#include #include #include #include @@ -55,11 +55,11 @@ static inline std::string GetNameTrait(std::size_t type_size) MIOPEN_THROW("data type not supported"); } -static inline const std::vector& GetKernelList(std::size_t data_size) +static inline const std::vector& GetKernelList(std::size_t data_size) { if(data_size == 1) { - static const std::vector byte_kernel_list{ + static const std::vector byte_kernel_list{ // clang-format off {1, 256, 1, 1, 1, 1}, {2, 256, 1, 1, 1, 1}, @@ -72,7 +72,7 @@ static inline const std::vector& GetKernelList(std::size_t d } if(data_size == 2) { - static const std::vector half_kernel_list{ + static const std::vector half_kernel_list{ // clang-format off {1, 256, 1, 1, 1, 1}, {2, 256, 1, 1, 1, 1}, @@ -85,7 +85,7 @@ static inline const std::vector& GetKernelList(std::size_t d } if(data_size == 4) { - static const std::vector dword_kernel_list{ + static const std::vector dword_kernel_list{ // clang-format off {1, 256, 1, 1, 1, 1}, {2, 256, 1, 1, 1, 1}, @@ -102,12 +102,12 @@ static inline const std::vector& GetKernelList(std::size_t d static inline bool IsApplicable(uint32_t /* batch */, uint32_t height, uint32_t width, - const TensorReorderParam* kparam) + const GeneralReorderParam* kparam) { return width % kparam->ediv_x == 0 && height % kparam->ediv_y == 0; } -static inline bool IsSameSide(uint32_t height, uint32_t width, const TensorReorderParam* kparam) +static inline bool IsSameSide(uint32_t height, uint32_t width, const GeneralReorderParam* kparam) { float radio = 0; if(width > height) @@ -127,9 +127,8 @@ static inline float GetNormalizedRadio(T x, T y) return static_cast(y) / x; return static_cast(x) / y; } - template -static inline std::string GetKernelName(std::size_t data_size, const TensorReorderParam* kparam) +static inline std::string GetKernelName(std::size_t data_size, const GeneralReorderParam* kparam) { std::ostringstream kernel_name; std::string type_trait = GetNameTrait(data_size); @@ -146,7 +145,7 @@ static inline std::string GetKernelName(std::size_t data_size, const TensorReord static inline std::size_t GetExtraPaddingSize(uint32_t /* batch */, uint32_t height, uint32_t width, - const TensorReorderParam* kparam) + const GeneralReorderParam* kparam) { // For simplicity and speed, we ignore batch, only compute h*w uint32_t padded_h = ((height + kparam->tile_y - 1) / kparam->tile_y) * kparam->tile_y; @@ -154,7 +153,7 @@ static inline std::size_t GetExtraPaddingSize(uint32_t /* batch */, return static_cast(padded_h) * padded_w - static_cast(height) * width; } -static inline TensorReorderParam +static inline GeneralReorderParam HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3) { /* @@ -166,30 +165,30 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim { if(dim_3 >= 16) { - return TensorReorderParam{16, 256, 1, 1, 1, 1}; + return GeneralReorderParam{16, 256, 1, 1, 1, 1}; } else if(dim_3 >= 8) { - return TensorReorderParam{8, 256, 1, 1, 1, 1}; + return GeneralReorderParam{8, 256, 1, 1, 1, 1}; } else if(dim_3 >= 4) { - return TensorReorderParam{4, 256, 1, 1, 1, 1}; + return GeneralReorderParam{4, 256, 1, 1, 1, 1}; } else if(dim_3 >= 2) { - return TensorReorderParam{2, 256, 1, 1, 1, 1}; + return GeneralReorderParam{2, 256, 1, 1, 1, 1}; } else { - return TensorReorderParam{1, 256, 1, 1, 1, 1}; + return GeneralReorderParam{1, 256, 1, 1, 1, 1}; } } } } // namespace tensor_reorder - 
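A minimal standalone sketch of the selection logic above, for reference; it is not part of the patch, and the kernel-name format is inferred from the DEFINE_GENERAL_4D_REORDER_KERNEL instantiations earlier in the series:

#include <cstdint>
#include <iostream>

// Largest power-of-two tile width <= dim_3, capped at 16; the same
// staircase HeuristicGet walks for the fastest-changing dimension.
static int pick_tile_x(uint32_t dim_3)
{
    for(int t = 16; t >= 2; t /= 2)
        if(dim_3 >= static_cast<uint32_t>(t))
            return t;
    return 1;
}

int main()
{
    // e.g. dim_3 == 14 selects the 8x256 variant; with 4-byte data the kernel
    // name would be along the lines of "general_4d_reorder_8x256_dword_r0231".
    std::cout << pick_tile_x(14) << "\n"; // prints 8
    return 0;
}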
-TensorReorderSolution::TensorReorderSolution(const ExecutionContext& ctx, +template +GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, @@ -204,7 +203,8 @@ TensorReorderSolution::TensorReorderSolution(const ExecutionContext& ctx, kernel_param_heuristic = tensor_reorder::HeuristicGet(data_size, dim_0, dim_1, dim_2, dim_3); } -solver::KernelInfo TensorReorderSolution::GetKernel() const +template +solver::KernelInfo GeneralReorderSolution::GetKernel() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; #if TENSOR_REORDER_PERSISTENT @@ -214,9 +214,9 @@ solver::KernelInfo TensorReorderSolution::GetKernel() const uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); std::size_t grid_size = dim_total; #endif - std::string kernel_name = GetKernelName(); + std::string kernel_name = GetKernelName(); solver::KernelInfo kernel; - kernel.kernel_file = "tensor_reorder.cpp"; + kernel.kernel_file = "general_tensor_reorder.cpp"; kernel.kernel_name = kernel_name; kernel.g_wk.clear(); kernel.g_wk.push_back(grid_size * block_size); @@ -227,12 +227,12 @@ solver::KernelInfo TensorReorderSolution::GetKernel() const kernel.l_wk.push_back(1); kernel.l_wk.push_back(1); - MIOPEN_LOG_I2("TensorReorderSolution use kernel: " + kernel_name); + MIOPEN_LOG_I2("GeneralReorderSolution use kernel: " + kernel_name); return kernel; } - -std::vector TensorReorderSolution::GetKernelArg() const +template +std::vector GeneralReorderSolution::GetKernelArg() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; @@ -265,21 +265,22 @@ std::vector TensorReorderSolution::GetKernelArg() const return opArgs; } - template -std::string TensorReorderSolution::GetKernelName() const +std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); + return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); } -bool TensorReorderSolution::IsSkippable() const +template +bool GeneralReorderSolution::IsSkippable() const { // Disable the IsSkippable function return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0 ; } -size_t TensorReorderSolution::GetSize() const +template +size_t GeneralReorderSolution::GetSize() const { return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; } diff --git a/src/include/miopen/tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp similarity index 93% rename from src/include/miopen/tensor_reorder_sol.hpp rename to src/include/miopen/general_tensor_reorder_sol.hpp index 4e777d97fb..5c6c6010a5 100644 --- a/src/include/miopen/tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -34,7 +34,7 @@ namespace miopen { -struct TensorReorderParam +struct GeneralReorderParam { int tile_x{0}; int tile_y{0}; @@ -45,9 +45,9 @@ struct TensorReorderParam }; template -struct TensorReorderSolution +struct GeneralReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, + GeneralReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, @@ -66,7 +66,7 @@ struct TensorReorderSolution uint32_t dim_3; int num_cu; - TensorReorderParam kernel_param_heuristic; + GeneralReorderParam kernel_param_heuristic; }; } // namespace miopen #endif diff --git 
a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp new file mode 100644 index 0000000000..7a28b2821d --- /dev/null +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -0,0 +1,139 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_TENSOR_REORDER_UTIL_HPP_ +#define MIOPEN_TENSOR_REORDER_UTIL_HPP_ + +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +template +struct TensorReorderSolution : public GeneralReorderSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) + { + } +}; + +template<> +struct TensorReorderSolution> : public BatchedTransposeSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) + { + } +}; + +template<> +struct TensorReorderSolution> : public BatchedTransposeSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) + { + } +}; + +template<> +struct TensorReorderSolution> : public BatchedTransposeSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) + { + } +}; + +template<> +struct TensorReorderSolution> : public BatchedTransposeSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0 *dim_1_, dim_2_ * dim_3_) + { + } +}; + +template<> +struct TensorReorderSolution> : public BatchedTransposeSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + 
miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0 *dim_1_, dim_2_ * dim_3_) + { + } +}; + +template<> +struct TensorReorderSolution> : public BatchedTransposeSolution +{ + TensorReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0 * dim_1_ * dim_2_,dim_3_) + { + } +}; + +} // namespace miopen + +#endif // MIOPEN_TENSOR_REORDER_UTIL_HPP_ diff --git a/test/sequence.hpp b/test/sequence.hpp new file mode 100644 index 0000000000..cf83e64012 --- /dev/null +++ b/test/sequence.hpp @@ -0,0 +1,46 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef SEQUENCE_HPP +#define SEQUENCE_HPP + +template +struct sequence +{ + static constexpr int m_size = sizeof...(Is); + + __host__ __device__ static constexpr auto size() { return m_size; } + + __host__ __device__ static constexpr auto get_size() { return size(); } + + __host__ __device__ static constexpr int at(int I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const int m_data[m_size + 1] = {Is..., 0}; + return m_data[I]; + } + +}; +#endif \ No newline at end of file diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index afb57cba3d..a72fc3bad0 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -25,10 +25,10 @@ *******************************************************************************/ #include #include -#include +#include #include #include -#include +//#include #include #include #include @@ -294,7 +294,7 @@ struct reorder_invoke_param : public miopen::InvokeParams { } }; - +//The template parameter dst_order is just for CPU verification template struct reorder_test : reorder__base { From 682a7250e7709434aa37d19e677177bea6909fc0 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 12:12:36 +0800 Subject: [PATCH 04/77] fix bug --- src/include/miopen/tensor_reorder_util.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 7a28b2821d..61d54f727a 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -32,12 +32,13 @@ #include #include #include +#include #include namespace miopen { template -struct TensorReorderSolution : public GeneralReorderSolution +struct TensorReorderSolution : public GeneralReorderSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, @@ -45,7 +46,7 @@ struct TensorReorderSolution : public GeneralReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) + : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) { } }; From 7fc0de721c91872233b056e628ddef88f5c592f0 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:14:13 +0800 Subject: [PATCH 05/77] fix bug --- src/include/miopen/tensor_reorder_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 61d54f727a..61bac97ff3 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -32,7 +32,7 @@ #include #include #include -#include +#include <../kernels/gpu_tensor_reorder/sequence.hpp> #include namespace miopen { From b1f5c89bd2cbfa345e19bfa34af5d1fbd5dc0882 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:17:50 +0800 Subject: [PATCH 06/77] fix bugs --- src/include/miopen/tensor_reorder_util.hpp | 18 ++---------------- test/tensor_reorder.cpp | 2 +- 2 files changed, 3 insertions(+), 17 deletions(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 61bac97ff3..cf3cd991e1 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -102,21 +102,7 @@ struct TensorReorderSolution> : public BatchedTransposeSolu uint32_t dim_1_, uint32_t 
dim_2_, uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0 *dim_1_, dim_2_ * dim_3_) - { - } -}; - -template<> -struct TensorReorderSolution> : public BatchedTransposeSolution -{ - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0 *dim_1_, dim_2_ * dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ *dim_1_, dim_2_ * dim_3_) { } }; @@ -130,7 +116,7 @@ struct TensorReorderSolution> : public BatchedTransposeSolu uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0 * dim_1_ * dim_2_,dim_3_) + : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) { } }; diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index a72fc3bad0..1edea2119b 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -296,7 +296,7 @@ struct reorder_invoke_param : public miopen::InvokeParams }; //The template parameter dst_order is just for CPU verification template -struct reorder_test : reorder__base +struct reorder_test : reorder_base { void run() { From 95738617c16cdfe67b9031cb66dc5fbcfa66afa9 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:19:26 +0800 Subject: [PATCH 07/77] fix bugs --- test/sequence.hpp | 2 +- test/tensor_reorder.cpp | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/sequence.hpp b/test/sequence.hpp index cf83e64012..8ce9a874d1 100644 --- a/test/sequence.hpp +++ b/test/sequence.hpp @@ -43,4 +43,4 @@ struct sequence } }; -#endif \ No newline at end of file +#endif diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 1edea2119b..1dfa726a94 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -437,8 +437,8 @@ loop([&](auto i) { {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; using dst_order = sequence; - run_test >>(); - run_test >>(); - run_test >>(); + run_test >>(); + run_test >>(); + run_test >>(); }); -} \ No newline at end of file +} From ca1bb57a0258b273cf4759e576086f0418e6e88d Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:26:55 +0800 Subject: [PATCH 08/77] fix bug --- test/tensor_reorder.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 1dfa726a94..651a06d995 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -109,10 +109,10 @@ template struct reorder_str { static std::string get() { - return ("r" + itoa(dst_order::at(0)) - + itoa(dst_order::at(1)) - + itoa(dst_order::at(2)) - + itoa(dst_order::at(3)) ); + return ("r" + std::to_string(dst_order::at(0)) + + std::to_string(dst_order::at(1)) + + std::to_string(dst_order::at(2)) + + std::to_string(dst_order::at(3)) ); } }; From b0c188ca950d85a7b2b129524247ecd32718c958 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:36:50 +0800 Subject: [PATCH 09/77] fix bugs --- src/hip/general_tensor_reorder_sol.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 5c4f7b12d1..d48d790f79 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -188,12 +188,12 @@ 
HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim } // namespace tensor_reorder template -GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) +GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_) { if(data_type == miopenInt8x4 || data_type == miopenDouble) @@ -232,7 +232,7 @@ solver::KernelInfo GeneralReorderSolution::GetKernel() const return kernel; } template -std::vector GeneralReorderSolution::GetKernelArg() const +std::vector GeneralReorderSolution::GetKernelArg() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; @@ -273,14 +273,14 @@ std::string GeneralReorderSolution::GetKernelName() const } template -bool GeneralReorderSolution::IsSkippable() const +bool GeneralReorderSolution::IsSkippable() const { // Disable the IsSkippable funciton return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0 ; } template -size_t GeneralReorderSolution::GetSize() const +size_t GeneralReorderSolution::GetSize() const { return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; } From 57dab09009373234633b425176c61e2e5e43849b Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:44:59 +0800 Subject: [PATCH 10/77] fix bug --- src/hip/general_tensor_reorder_sol.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index d48d790f79..904e660804 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -189,11 +189,11 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim } // namespace tensor_reorder template GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_) { if(data_type == miopenInt8x4 || data_type == miopenDouble) From 45894a7915c96c318441ca13a4d364ea785f70f9 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:45:13 +0800 Subject: [PATCH 11/77] fixbug --- src/include/miopen/tensor_reorder_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index cf3cd991e1..0912cf04b5 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -46,7 +46,7 @@ struct TensorReorderSolution : public GeneralReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) + : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) { } }; From 84863c4e95f5caef73128862ecf589fa87d03e9d Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:50:08 +0800 Subject: [PATCH 12/77] fixbug --- src/include/miopen/tensor_reorder_util.hpp | 2 +- src/kernels/gpu_tensor_reorder/sequence.hpp | 2 +- 2 files changed, 
2 insertions(+), 2 deletions(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 0912cf04b5..cf3cd991e1 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -46,7 +46,7 @@ struct TensorReorderSolution : public GeneralReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) + : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) { } }; diff --git a/src/kernels/gpu_tensor_reorder/sequence.hpp b/src/kernels/gpu_tensor_reorder/sequence.hpp index cf83e64012..8ce9a874d1 100644 --- a/src/kernels/gpu_tensor_reorder/sequence.hpp +++ b/src/kernels/gpu_tensor_reorder/sequence.hpp @@ -43,4 +43,4 @@ struct sequence } }; -#endif \ No newline at end of file +#endif From 54d1f2e5a04a9538ed16b37d9e50766aa992e447 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:51:32 +0800 Subject: [PATCH 13/77] test 1 --- test/tensor_reorder.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 651a06d995..0f34b3f091 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -430,9 +430,9 @@ struct reorder_test : reorder_base int main() { -loop([&](auto i) { +loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, + {0, 2, 3, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 3, 1, 2}, {0, 3, 2, 1}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From e5f861701ac26f30f4c7cd74445a0d28f1ec69a1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:53:17 +0800 Subject: [PATCH 14/77] General test, (Batched passed) --- test/tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 0f34b3f091..22ccb261e7 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -432,7 +432,7 @@ int main() { loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 2, 3, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 3, 1, 2}, {0, 3, 2, 1}, + {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From 4dba45ce13b09f0dbd983ea2836e2bedcf349a36 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 13:54:34 +0800 Subject: [PATCH 15/77] 0321 test --- test/tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 22ccb261e7..895d7d2db5 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -432,7 +432,7 @@ int main() { loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, + {0, 3, 2, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 
0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From c3c530323da3d6a81eebbab1c398c711658134b2 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 14:40:21 +0800 Subject: [PATCH 16/77] explicit template instance --- src/hip/general_tensor_reorder_sol.cpp | 15 +++++++++++++++ src/include/miopen/general_tensor_reorder_sol.hpp | 3 ++- src/include/miopen/tensor_reorder_util.hpp | 4 ++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 904e660804..a81c3605f2 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -33,6 +33,7 @@ #include #include #include +#include <../kernels/gpu_tensor_reorder/sequence.hpp> #define TENSOR_REORDER_BLOCK_SIZE 256 #define TENSOR_REORDER_PERSISTENT 0 @@ -231,6 +232,7 @@ solver::KernelInfo GeneralReorderSolution::GetKernel() const return kernel; } + template std::vector GeneralReorderSolution::GetKernelArg() const { @@ -265,6 +267,7 @@ std::vector GeneralReorderSolution::GetKernelArg() const return opArgs; } + template std::string GeneralReorderSolution::GetKernelName() const { @@ -286,3 +289,15 @@ size_t GeneralReorderSolution::GetSize() const } } // namespace miopen +//explicit instance +template void GeneralReorderSolution>::GeneralReorderSolution(const ExecutionContext& ctx, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_); +template solver::KernelInfo GeneralReorderSolution>::GetKernel() const; +template std::vector GeneralReorderSolution>::GetKernelArg() const; +template std::string GeneralReorderSolution>::GetKernelName() const; +template bool GeneralReorderSolution>::IsSkippable() const; +template size_t GeneralReorderSolution>::GetSize() const; diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 5c6c6010a5..13f9bb181f 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -31,6 +31,7 @@ #include #include #include +#include <../kernels/gpu_tensor_reorder/sequence.hpp> namespace miopen { @@ -70,5 +71,5 @@ struct GeneralReorderSolution }; } // namespace miopen - +template struct GeneralReorderSolution>; #endif diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index cf3cd991e1..67484b97a0 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -47,8 +47,8 @@ struct TensorReorderSolution : public GeneralReorderSolution uint32_t dim_2_, uint32_t dim_3_) : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) - { - } + { + } }; template<> From b539c9c619ea2105f1aa78904bc8814733615141 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 14:41:50 +0800 Subject: [PATCH 17/77] fix bug --- src/hip/general_tensor_reorder_sol.cpp | 2 +- src/include/miopen/general_tensor_reorder_sol.hpp | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index a81c3605f2..cef0aea7dd 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -288,7 +288,6 @@ size_t GeneralReorderSolution::GetSize() const return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; } -} // namespace miopen //explicit instance template void 
GeneralReorderSolution>::GeneralReorderSolution(const ExecutionContext& ctx, miopenDataType_t data_type_, @@ -301,3 +300,4 @@ template std::vector GeneralReorderSolution>:: template std::string GeneralReorderSolution>::GetKernelName() const; template bool GeneralReorderSolution>::IsSkippable() const; template size_t GeneralReorderSolution>::GetSize() const; +} // namespace miopen diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 13f9bb181f..b555beb3a9 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -69,7 +69,6 @@ struct GeneralReorderSolution GeneralReorderParam kernel_param_heuristic; }; - -} // namespace miopen template struct GeneralReorderSolution>; +} // namespace miopen #endif From b9e86848f211dfcbe0bbd8b3223418057732bd17 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 14:44:16 +0800 Subject: [PATCH 18/77] fix bug --- src/include/miopen/tensor_reorder_util.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 67484b97a0..1b3ddf62bc 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -120,7 +120,8 @@ struct TensorReorderSolution> : public BatchedTransposeSolu { } }; - +//explicit instance +template struct TensorReorderSolution>; } // namespace miopen #endif // MIOPEN_TENSOR_REORDER_UTIL_HPP_ From c766a69c59c6a978f2f4520ced9b2a83fa549fc1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 15:23:20 +0800 Subject: [PATCH 19/77] move instantiation into sol.hpp --- src/hip/general_tensor_reorder_sol.cpp | 14 +- .../miopen/general_tensor_reorder_sol.hpp | 262 ++++++++++++++++++ 2 files changed, 263 insertions(+), 13 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index cef0aea7dd..e0eb306cbf 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -99,7 +99,6 @@ static inline const std::vector& GetKernelList(std::size_t } MIOPEN_THROW("data type not supported"); } - static inline bool IsApplicable(uint32_t /* batch */, uint32_t height, uint32_t width, @@ -288,16 +287,5 @@ size_t GeneralReorderSolution::GetSize() const return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; } -//explicit instance -template void GeneralReorderSolution>::GeneralReorderSolution(const ExecutionContext& ctx, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_); -template solver::KernelInfo GeneralReorderSolution>::GetKernel() const; -template std::vector GeneralReorderSolution>::GetKernelArg() const; -template std::string GeneralReorderSolution>::GetKernelName() const; -template bool GeneralReorderSolution>::IsSkippable() const; -template size_t GeneralReorderSolution>::GetSize() const; } // namespace miopen +*/ diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index b555beb3a9..65e1e0d0b1 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -33,6 +33,21 @@ #include #include <../kernels/gpu_tensor_reorder/sequence.hpp> +#include +#include +#include +#include +#include +#include +#include + +#define TENSOR_REORDER_BLOCK_SIZE 256 +#define TENSOR_REORDER_PERSISTENT 0 + +#if 
TENSOR_REORDER_PERSISTENT +#define TENSOR_REORDER_OCCUPANCY 4 +#endif + namespace miopen { struct GeneralReorderParam @@ -71,4 +86,251 @@ struct GeneralReorderSolution }; template struct GeneralReorderSolution>; } // namespace miopen + +namespace miopen { +namespace tensor_reorder { + +static inline std::string GetNameTrait(std::size_t type_size) +{ + if(type_size == 1) + return "byte"; + if(type_size == 2) + return "half"; + if(type_size == 4) + return "dword"; + MIOPEN_THROW("data type not supported"); +} + +static inline const std::vector& GetKernelList(std::size_t data_size) +{ + if(data_size == 1) + { + static const std::vector byte_kernel_list{ + // clang-format off + {1, 256, 1, 1, 1, 1}, + {2, 256, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {8, 256, 1, 1, 1, 1}, + {16, 256, 1, 1, 1, 1}, + // clang-format on + }; + return byte_kernel_list; + } + if(data_size == 2) + { + static const std::vector half_kernel_list{ + // clang-format off + {1, 256, 1, 1, 1, 1}, + {2, 256, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {8, 256, 1, 1, 1, 1}, + {16, 256, 1, 1, 1, 1}, + // clang-format on + }; + return half_kernel_list; + } + if(data_size == 4) + { + static const std::vector dword_kernel_list{ + // clang-format off + {1, 256, 1, 1, 1, 1}, + {2, 256, 1, 1, 1, 1}, + {4, 256, 1, 1, 1, 1}, + {8, 256, 1, 1, 1, 1}, + {16, 256, 1, 1, 1, 1}, + // clang-format on + }; + return dword_kernel_list; + } + MIOPEN_THROW("data type not supported"); +} + +static inline bool IsApplicable(uint32_t /* batch */, + uint32_t height, + uint32_t width, + const GeneralReorderParam* kparam) +{ + return width % kparam->ediv_x == 0 && height % kparam->ediv_y == 0; +} + +static inline bool IsSameSide(uint32_t height, uint32_t width, const GeneralReorderParam* kparam) +{ + float radio = 0; + if(width > height) + radio = static_cast(kparam->tile_x) / kparam->tile_y; + else + radio = static_cast(kparam->tile_y) / kparam->tile_x; + + // E.g. for cases like width=1000, height=10 + // Allow at least 32x64, 64x64... 16x64 not allowed + return radio >= 0.4; +} + +template +static inline float GetNormalizedRadio(T x, T y) +{ + if(y > x) + return static_cast(y) / x; + return static_cast(x) / y; +} +template +static inline std::string GetKernelName(std::size_t data_size, const GeneralReorderParam* kparam) +{ + std::ostringstream kernel_name; + std::string type_trait = GetNameTrait(data_size); + kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_"; + if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1)) + { + kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" + << kparam->ediv_x << "x" << kparam->ediv_y << "_"; + } + kernel_name << type_trait<<"_r"<tile_y - 1) / kparam->tile_y) * kparam->tile_y; + uint32_t padded_w = ((width + kparam->tile_x - 1) / kparam->tile_x) * kparam->tile_x; + return static_cast(padded_h) * padded_w - static_cast(height) * width; +} + +static inline GeneralReorderParam +HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3) +{ + /* + * TODO: + * Design a algorithm to determine general tensor reorder tile size. 
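Note on this TODO: HeuristicGet() (just below) keys the tile width off dim_3 alone, and GetKernel() then sizes the grid so each 256-thread workgroup covers block_size * tile_x pixels. A minimal host-side sketch of that arithmetic, using a hypothetical 2x8x512x1024 float tensor reordered to <0, 2, 3, 1> (the names mirror this header; the shape is made up):

#include <cstdint>

uint32_t grid_size_sketch()
{
    const uint32_t dim_0 = 2, dim_1 = 8, dim_2 = 512, dim_3 = 1024;
    const uint32_t block_size = 256; // TENSOR_REORDER_BLOCK_SIZE
    // Same cascade as HeuristicGet(): the widest tile not exceeding dim_3.
    const uint32_t tile_x =
        dim_3 >= 16 ? 16 : dim_3 >= 8 ? 8 : dim_3 >= 4 ? 4 : dim_3 >= 2 ? 2 : 1;
    const uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3;
    // 8388608 pixels / (256 threads * 16 pixels each) = 2048 workgroups; per
    // GetKernelName() the kernel resolves to "general_4d_reorder_16x256_dword_r0231".
    return (pixel_total + block_size * tile_x - 1) / (block_size * tile_x);
}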
+ */ + + if(dim_3 >= 1 ) + { + if(dim_3 >= 16) + { + return GeneralReorderParam{16, 256, 1, 1, 1, 1}; + } + else if(dim_3 >= 8) + { + return GeneralReorderParam{8, 256, 1, 1, 1, 1}; + } + else if(dim_3 >= 4) + { + return GeneralReorderParam{4, 256, 1, 1, 1, 1}; + } + else if(dim_3 >= 2) + { + return GeneralReorderParam{2, 256, 1, 1, 1, 1}; + } + else + { + return GeneralReorderParam{1, 256, 1, 1, 1, 1}; + } + } +} + +} // namespace tensor_reorder +template +GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_) +{ + if(data_type == miopenInt8x4 || data_type == miopenDouble) + MIOPEN_THROW("These data type are not supported"); + num_cu = ctx.GetStream().GetMaxComputeUnits(); + std::size_t data_size = miopen::GetTypeSize(data_type); + kernel_param_heuristic = tensor_reorder::HeuristicGet(data_size, dim_0, dim_1, dim_2, dim_3); +} + +template +solver::KernelInfo GeneralReorderSolution::GetKernel() const +{ + std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; +#if TENSOR_REORDER_PERSISTENT + std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY; +#else + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); + std::size_t grid_size = dim_total; +#endif + std::string kernel_name = GetKernelName(); + solver::KernelInfo kernel; + kernel.kernel_file = "general_tensor_reorder.cpp"; + kernel.kernel_name = kernel_name; + kernel.g_wk.clear(); + kernel.g_wk.push_back(grid_size * block_size); + kernel.g_wk.push_back(1); + kernel.g_wk.push_back(1); + kernel.l_wk.clear(); + kernel.l_wk.push_back(block_size); + kernel.l_wk.push_back(1); + kernel.l_wk.push_back(1); + + MIOPEN_LOG_I2("GeneralReorderSolution use kernel: " + kernel_name); + + return kernel; +} + +template +std::vector GeneralReorderSolution::GetKernelArg() const +{ + std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); +#if TENSOR_REORDER_PERSISTENT + std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY; +#else + std::size_t grid_size = dim_total; +#endif + + magic_div_u32_t magic_stride0 = magic_div_u32_gen(dim_1 * dim_2 * dim_3); + magic_div_u32_t magic_stride1 = magic_div_u32_gen(dim_2 * dim_3); + magic_div_u32_t magic_stride2 = magic_div_u32_gen(dim_3); + + std::vector opArgs; + opArgs.emplace_back(0); // placeholder + opArgs.emplace_back(0); // placeholder + opArgs.emplace_back(dim_0); + opArgs.emplace_back(dim_1); + opArgs.emplace_back(dim_2); + opArgs.emplace_back(dim_3); + opArgs.emplace_back(static_cast(grid_size)); + opArgs.emplace_back(dim_total); + opArgs.emplace_back(magic_stride0.magic); + opArgs.emplace_back(static_cast(magic_stride0.shift)); + opArgs.emplace_back(magic_stride1.magic); + opArgs.emplace_back(static_cast(magic_stride1.shift)); + opArgs.emplace_back(magic_stride2.magic); + opArgs.emplace_back(static_cast(magic_stride2.shift)); + + return opArgs; +} + +template +std::string GeneralReorderSolution::GetKernelName() const +{ + std::size_t data_size = miopen::GetTypeSize(data_type); + return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); +} + +template +bool 
GeneralReorderSolution::IsSkippable() const +{ + // Disable the IsSkippable funciton + return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0 ; +} + +template +size_t GeneralReorderSolution::GetSize() const +{ + return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; +} +} // namespace miopen #endif From 37b19265430c58c77d83fef5e001d9feda7d141b Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 15:24:09 +0800 Subject: [PATCH 20/77] fix bug --- src/include/miopen/general_tensor_reorder_sol.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 65e1e0d0b1..fde6f30d59 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -84,7 +84,6 @@ struct GeneralReorderSolution GeneralReorderParam kernel_param_heuristic; }; -template struct GeneralReorderSolution>; } // namespace miopen namespace miopen { From a36ce9867d7d9babcaca69547885edc6a330b387 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 15:28:44 +0800 Subject: [PATCH 21/77] fixbug --- src/include/miopen/general_tensor_reorder_sol.hpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index fde6f30d59..5439f012ef 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -84,9 +84,6 @@ struct GeneralReorderSolution GeneralReorderParam kernel_param_heuristic; }; -} // namespace miopen - -namespace miopen { namespace tensor_reorder { static inline std::string GetNameTrait(std::size_t type_size) @@ -206,7 +203,7 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim * Design a algorithm to determine general tensor reorder tile size. 
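The (magic, shift) pairs that GetKernelArg() packs per stride above let the kernel recover 4-d coordinates from a flat pixel index without hardware integer division. A device-side sketch of the consumer, assuming the usual multiply-high formulation (the helper actually used by the general_4d_reorder_* kernels may differ in detail):

#include <cstdint>
#include <hip/hip_runtime.h>

// n / d computed as (hi32(n * magic) + n) >> shift, with (magic, shift)
// precomputed on the host by magic_div_u32_gen(d).
__device__ inline uint32_t magic_div_u32_sketch(uint32_t n, uint32_t magic, uint32_t shift)
{
    const uint32_t tmp = __umulhi(n, magic);
    return (tmp + n) >> shift;
}

// Peeling coordinates off a flat index, matching the three strides passed in
// GetKernelArg(): dim_1*dim_2*dim_3, dim_2*dim_3, dim_3.
__device__ inline void decode_index(uint32_t idx,
                                    uint32_t dim_1, uint32_t dim_2, uint32_t dim_3,
                                    uint32_t m0, uint32_t s0,
                                    uint32_t m1, uint32_t s1,
                                    uint32_t m2, uint32_t s2,
                                    uint32_t& i0, uint32_t& i1, uint32_t& i2, uint32_t& i3)
{
    i0         = magic_div_u32_sketch(idx, m0, s0);
    uint32_t r = idx - i0 * (dim_1 * dim_2 * dim_3);
    i1         = magic_div_u32_sketch(r, m1, s1);
    r          = r - i1 * (dim_2 * dim_3);
    i2         = magic_div_u32_sketch(r, m2, s2);
    i3         = r - i2 * dim_3;
}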
*/ - if(dim_3 >= 1 ) + if(dim_0 >= 1 && dim_1 >= 1 && dim_2 >= 1 && dim_3 >= 1 && data_size<=4) { if(dim_3 >= 16) { From 923e4b3544d89a2f9f5a286c152675a482087cef Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 15:36:21 +0800 Subject: [PATCH 22/77] fix bug --- src/include/miopen/general_tensor_reorder_sol.hpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 5439f012ef..5078b534f5 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -313,7 +313,17 @@ template std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); + std::ostringstream kernel_name; + std::string type_trait = GetNameTrait(data_size); + kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_"; + if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1)) + { + kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" + << kparam->ediv_x << "x" << kparam->ediv_y << "_"; + } + kernel_name << type_trait<<"_r"< From 7802205a1a4bcfc63c7b167ee103c4b82487cd45 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 15:38:48 +0800 Subject: [PATCH 23/77] fix bug --- src/include/miopen/general_tensor_reorder_sol.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 5078b534f5..039b52d42f 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -314,12 +314,12 @@ std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); std::ostringstream kernel_name; - std::string type_trait = GetNameTrait(data_size); - kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_"; - if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1)) + std::string type_trait = tensor_reorder::GetNameTrait(data_size); + kernel_name << "general_4d_reorder_" << kernel_param_heuristic->tile_x << "x" << kernel_param_heuristic->tile_y << "_"; + if(!(kernel_param_heuristic->pack_x == 1 && kernel_param_heuristic->pack_y == 1 && kernel_param_heuristic->ediv_x == 1 && kernel_param_heuristic->ediv_y == 1)) { - kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" - << kparam->ediv_x << "x" << kparam->ediv_y << "_"; + kernel_name << "pack_" << kernel_param_heuristic->pack_x << "x" << kernel_param_heuristic->pack_y << "_ediv_" + << kernel_param_heuristic->ediv_x << "x" << kernel_param_heuristic->ediv_y << "_"; } kernel_name << type_trait<<"_r"< Date: Thu, 27 Jan 2022 15:40:06 +0800 Subject: [PATCH 24/77] fixbug --- src/include/miopen/general_tensor_reorder_sol.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 039b52d42f..faf090d7ca 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -315,11 +315,11 @@ std::string GeneralReorderSolution::GetKernelName() const std::size_t data_size = 
miopen::GetTypeSize(data_type); std::ostringstream kernel_name; std::string type_trait = tensor_reorder::GetNameTrait(data_size); - kernel_name << "general_4d_reorder_" << kernel_param_heuristic->tile_x << "x" << kernel_param_heuristic->tile_y << "_"; - if(!(kernel_param_heuristic->pack_x == 1 && kernel_param_heuristic->pack_y == 1 && kernel_param_heuristic->ediv_x == 1 && kernel_param_heuristic->ediv_y == 1)) + kernel_name << "general_4d_reorder_" << kernel_param_heuristic.tile_x << "x" << kernel_param_heuristic.tile_y << "_"; + if(!(kernel_param_heuristic.pack_x == 1 && kernel_param_heuristic.pack_y == 1 && kernel_param_heuristic.ediv_x == 1 && kernel_param_heuristic.ediv_y == 1)) { - kernel_name << "pack_" << kernel_param_heuristic->pack_x << "x" << kernel_param_heuristic->pack_y << "_ediv_" - << kernel_param_heuristic->ediv_x << "x" << kernel_param_heuristic->ediv_y << "_"; + kernel_name << "pack_" << kernel_param_heuristic.pack_x << "x" << kernel_param_heuristic.pack_y << "_ediv_" + << kernel_param_heuristic.ediv_x << "x" << kernel_param_heuristic.ediv_y << "_"; } kernel_name << type_trait<<"_r"< Date: Thu, 27 Jan 2022 15:47:50 +0800 Subject: [PATCH 25/77] fix bug --- src/include/miopen/tensor_reorder_util.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 1b3ddf62bc..3e9b11adfc 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -120,8 +120,6 @@ struct TensorReorderSolution> : public BatchedTransposeSolu { } }; -//explicit instance -template struct TensorReorderSolution>; } // namespace miopen #endif // MIOPEN_TENSOR_REORDER_UTIL_HPP_ From 3cc7c61081ea63ce05450d9a63f90c585ab563c1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 15:49:59 +0800 Subject: [PATCH 26/77] fixbug --- src/include/miopen/tensor_reorder_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 3e9b11adfc..f0264e2957 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -28,11 +28,11 @@ #include #include +#include <../kernels/gpu_tensor_reorder/sequence.hpp> #include #include #include #include -#include <../kernels/gpu_tensor_reorder/sequence.hpp> #include namespace miopen { From 08a9c82c753f9efb84b21d6aa93e945b69a042d2 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 16:00:37 +0800 Subject: [PATCH 27/77] fixbug --- src/include/miopen/tensor_reorder_util.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index f0264e2957..fe19f36d0c 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -28,12 +28,12 @@ #include #include -#include <../kernels/gpu_tensor_reorder/sequence.hpp> #include #include #include #include #include +#include <../kernels/gpu_tensor_reorder/sequence.hpp> namespace miopen { @@ -52,7 +52,7 @@ struct TensorReorderSolution : public GeneralReorderSolution }; template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct TensorReorderSolution<(sequence<0, 2, 3, 1>)> : public BatchedTransposeSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, From 3374fa6cd7e937cc8c7880daac0ab5101d96f263 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: 
Thu, 27 Jan 2022 16:05:24 +0800 Subject: [PATCH 28/77] fixbug --- src/hip/general_tensor_reorder_sol.cpp | 2 +- .../miopen/general_tensor_reorder_sol.hpp | 2 +- src/include/miopen/tensor_reorder_util.hpp | 12 ++--- .../general_tensor_reorder.cpp | 48 +++++++++---------- .../{sequence.hpp => order.hpp} | 6 +-- test/{sequence.hpp => order.hpp} | 6 +-- test/tensor_reorder.cpp | 4 +- 7 files changed, 40 insertions(+), 40 deletions(-) rename src/kernels/gpu_tensor_reorder/{sequence.hpp => order.hpp} (96%) rename test/{sequence.hpp => order.hpp} (96%) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index e0eb306cbf..91104455c3 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -33,7 +33,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder/sequence.hpp> +#include <../kernels/gpu_tensor_reorder/order.hpp> #define TENSOR_REORDER_BLOCK_SIZE 256 #define TENSOR_REORDER_PERSISTENT 0 diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index faf090d7ca..4a42c2bf0c 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -31,7 +31,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder/sequence.hpp> +#include <../kernels/gpu_tensor_reorder/order.hpp> #include #include diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index fe19f36d0c..0a31408484 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -33,7 +33,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder/sequence.hpp> +#include <../kernels/gpu_tensor_reorder/order.hpp> namespace miopen { @@ -52,7 +52,7 @@ struct TensorReorderSolution : public GeneralReorderSolution }; template<> -struct TensorReorderSolution<(sequence<0, 2, 3, 1>)> : public BatchedTransposeSolution +struct TensorReorderSolution> : public BatchedTransposeSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, @@ -66,7 +66,7 @@ struct TensorReorderSolution<(sequence<0, 2, 3, 1>)> : public BatchedTransposeSo }; template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct TensorReorderSolution> : public BatchedTransposeSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, @@ -80,7 +80,7 @@ struct TensorReorderSolution> : public BatchedTransposeSolu }; template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct TensorReorderSolution> : public BatchedTransposeSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, @@ -94,7 +94,7 @@ struct TensorReorderSolution> : public BatchedTransposeSolu }; template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct TensorReorderSolution> : public BatchedTransposeSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, @@ -108,7 +108,7 @@ struct TensorReorderSolution> : public BatchedTransposeSolu }; template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct TensorReorderSolution> : public BatchedTransposeSolution { TensorReorderSolution(const ExecutionContext& ctx_, miopenDataType_t data_type_, diff --git a/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp b/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp index 
93c09f10c2..31e73e3a49 100644 --- a/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp +++ b/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp @@ -25,7 +25,7 @@ *******************************************************************************/ #include #include -#include "sequence.hpp" +#include "order.hpp" #ifndef TENSOR_REORDER_OCCUPANCY #define TENSOR_REORDER_OCCUPANCY 4 @@ -358,29 +358,29 @@ inline __device__ void general_4d_reorder_16x256(T* dst, shift_stride2); \ } //default order is 0 1 2 3 -using r0132 = sequence<0, 1, 3, 2>; -using r0213 = sequence<0, 2, 1, 3>;//nhwc2nchwc -using r0231 = sequence<0, 2, 3, 1>;//nchw2nchwc -using r0312 = sequence<0, 3, 1, 2>;//nhwc2nchw -using r0321 = sequence<0, 3, 2, 1>; -using r1023 = sequence<1, 0, 2, 3>; -using r1032 = sequence<1, 0, 3, 2>; -using r1203 = sequence<1, 2, 0, 3>; -using r1230 = sequence<1, 2, 3, 0>; -using r1302 = sequence<1, 3, 0, 2>;//nchw2chwnc -using r1320 = sequence<1, 3, 2, 0>; -using r2013 = sequence<2, 0, 1, 3>; -using r2031 = sequence<2, 0, 3, 1>; -using r2103 = sequence<2, 1, 0, 3>;//nhwc2chwnc -using r2130 = sequence<2, 1, 3, 0>; -using r2301 = sequence<2, 3, 0, 1>; -using r2310 = sequence<2, 3, 1, 0>; -using r3012 = sequence<3, 0, 1, 2>; -using r3021 = sequence<3, 0, 2, 1>; -using r3102 = sequence<3, 1, 0, 2>; -using r3120 = sequence<3, 1, 2, 0>; -using r3201 = sequence<3, 2, 0, 1>; -using r3210 = sequence<3, 2, 1, 0>; +using r0132 = order<0, 1, 3, 2>; +using r0213 = order<0, 2, 1, 3>;//nhwc2nchwc +using r0231 = order<0, 2, 3, 1>;//nchw2nchwc +using r0312 = order<0, 3, 1, 2>;//nhwc2nchw +using r0321 = order<0, 3, 2, 1>; +using r1023 = order<1, 0, 2, 3>; +using r1032 = order<1, 0, 3, 2>; +using r1203 = order<1, 2, 0, 3>; +using r1230 = order<1, 2, 3, 0>; +using r1302 = order<1, 3, 0, 2>;//nchw2chwnc +using r1320 = order<1, 3, 2, 0>; +using r2013 = order<2, 0, 1, 3>; +using r2031 = order<2, 0, 3, 1>; +using r2103 = order<2, 1, 0, 3>;//nhwc2chwnc +using r2130 = order<2, 1, 3, 0>; +using r2301 = order<2, 3, 0, 1>; +using r2310 = order<2, 3, 1, 0>; +using r3012 = order<3, 0, 1, 2>; +using r3021 = order<3, 0, 2, 1>; +using r3102 = order<3, 1, 0, 2>; +using r3120 = order<3, 1, 2, 0>; +using r3201 = order<3, 2, 0, 1>; +using r3210 = order<3, 2, 1, 0>; DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, dword, float, 256, TENSOR_REORDER_OCCUPANCY) diff --git a/src/kernels/gpu_tensor_reorder/sequence.hpp b/src/kernels/gpu_tensor_reorder/order.hpp similarity index 96% rename from src/kernels/gpu_tensor_reorder/sequence.hpp rename to src/kernels/gpu_tensor_reorder/order.hpp index 8ce9a874d1..6ec61f0912 100644 --- a/src/kernels/gpu_tensor_reorder/sequence.hpp +++ b/src/kernels/gpu_tensor_reorder/order.hpp @@ -23,11 +23,11 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef SEQUENCE_HPP -#define SEQUENCE_HPP +#ifndef ORDER_HPP +#define ORDER_HPP template -struct sequence +struct order { static constexpr int m_size = sizeof...(Is); diff --git a/test/sequence.hpp b/test/order.hpp similarity index 96% rename from test/sequence.hpp rename to test/order.hpp index 8ce9a874d1..6ec61f0912 100644 --- a/test/sequence.hpp +++ b/test/order.hpp @@ -23,11 +23,11 @@ * SOFTWARE. 
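The order<> template being renamed here is the entire compile-time contract between host and kernel: output dimension i of the reorder is fed by input dimension dst_order::at(i), which is why order<0, 3, 1, 2> reads as nhwc2nchw in the alias table above. A self-contained sketch of how it evaluates (the same struct as this header, minus the __host__ __device__ qualifiers):

template <int... Is>
struct order
{
    static constexpr int m_size = sizeof...(Is);
    static constexpr int at(int I)
    {
        // trailing 0 keeps the array non-empty when m_size == 0
        const int m_data[m_size + 1] = {Is..., 0};
        return m_data[I];
    }
};

using r0312 = order<0, 3, 1, 2>; // NHWC -> NCHW
static_assert(r0312::m_size == 4, "rank-4 permutation");
static_assert(r0312::at(1) == 3, "output dim 1 (C) is fed by input dim 3 (C of NHWC)");
static_assert(r0312::at(3) == 2, "output dim 3 (W) is fed by input dim 2 (W of NHWC)");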
* *******************************************************************************/ -#ifndef SEQUENCE_HPP -#define SEQUENCE_HPP +#ifndef ORDER_HPP +#define ORDER_HPP template -struct sequence +struct order { static constexpr int m_size = sizeof...(Is); diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 895d7d2db5..d5c69aa602 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -38,7 +38,7 @@ #include "test.hpp" #include "driver.hpp" #include "random.hpp" -#include "sequence.hpp" +#include "order.hpp" template <> @@ -436,7 +436,7 @@ loop([&](auto i) { {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; - using dst_order = sequence; + using dst_order = order; run_test >>(); run_test >>(); run_test >>(); From 0dfac32302c735c862c953007961cf3794821655 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 16:13:17 +0800 Subject: [PATCH 29/77] fixbug --- src/include/miopen/general_tensor_reorder_sol.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 4a42c2bf0c..4e84a13605 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -226,6 +226,7 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim return GeneralReorderParam{1, 256, 1, 1, 1, 1}; } } + MIOPEN_THROW("data type not supported"); } } // namespace tensor_reorder From e9ac702495725eb4b060e80ba2cf1e7fcf0c7d9f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 16:24:54 +0800 Subject: [PATCH 30/77] batched test --- test/tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index d5c69aa602..6ab54dac0d 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -432,7 +432,7 @@ int main() { loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 3, 2, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, + {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From 879694f3d308600281596d4ff234c4a63a376e68 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 16:29:18 +0800 Subject: [PATCH 31/77] test batch --- test/tensor_reorder.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 6ab54dac0d..40776d92b7 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -430,9 +430,9 @@ struct reorder_test : reorder_base int main() { -loop([&](auto i) { +loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, + {0, 1, 3, 2}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {0, 2, 1, 3}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From 
183e728ac9591c55c0a5a3fe879fc4c6458ba193 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 16:36:09 +0800 Subject: [PATCH 32/77] test --- test/tensor_reorder.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 40776d92b7..d5c69aa602 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -430,9 +430,9 @@ struct reorder_test : reorder_base int main() { -loop([&](auto i) { +loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 1, 3, 2}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {0, 2, 1, 3}, + {0, 3, 2, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From 978f8e963443083f19ea2ee62f4835b1b73128bf Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 17:13:03 +0800 Subject: [PATCH 33/77] add kernel --- src/CMakeLists.txt | 2 ++ .../general_tensor_reorder.cpp | 0 .../order.hpp | 0 3 files changed, 2 insertions(+) rename src/kernels/{gpu_tensor_reorder => gpu_general_tensor_reorder_kernel}/general_tensor_reorder.cpp (100%) rename src/kernels/{gpu_tensor_reorder => gpu_general_tensor_reorder_kernel}/order.hpp (100%) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2c1cecc396..efeacd6191 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -230,6 +230,8 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN file(GLOB_RECURSE GPU_REFERENCE_KERNEL_HIP "kernels/gpu_reference_kernel/*.cpp") file(GLOB_RECURSE GPU_REFERENCE_KERNEL_ASM "kernels/gpu_reference_kernel/*.s") file(GLOB_RECURSE GPU_BATCHED_TRANSPOSE_KERNEL_HIP "kernels/gpu_batched_transpose_kernel/*.cpp") + file(GLOB_RECURSE GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP "kernels/gpu_general_tensor_reorder_kernel/*.cpp") + set(MIOPEN_KERNEL_INCLUDES ${STATIC_COMPOSABLE_KERNEL_INCLUDE} diff --git a/src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp similarity index 100% rename from src/kernels/gpu_tensor_reorder/general_tensor_reorder.cpp rename to src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp diff --git a/src/kernels/gpu_tensor_reorder/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp similarity index 100% rename from src/kernels/gpu_tensor_reorder/order.hpp rename to src/kernels/gpu_general_tensor_reorder_kernel/order.hpp From ff5e47e7e5f0b51abe5bc7c4f0b62e5369785d3c Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 17:15:26 +0800 Subject: [PATCH 34/77] fixbugs --- src/hip/general_tensor_reorder_sol.cpp | 2 +- src/include/miopen/general_tensor_reorder_sol.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 91104455c3..82fd9b5f20 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -33,7 +33,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder/order.hpp> +#include <../kernels/gpu_tensor_reorder_kernel/order.hpp> #define TENSOR_REORDER_BLOCK_SIZE 256 #define TENSOR_REORDER_PERSISTENT 0 diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 
4e84a13605..c579cc7e7b 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -31,7 +31,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder/order.hpp> +#include <../kernels/gpu_tensor_reorder_kernel/order.hpp> #include #include From 7845771c47472c98350f35f46200a00d30d8337a Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 17:17:02 +0800 Subject: [PATCH 35/77] fixtypo --- src/hip/general_tensor_reorder_sol.cpp | 2 +- src/include/miopen/general_tensor_reorder_sol.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 82fd9b5f20..4ce0e5d7b5 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -33,7 +33,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder_kernel/order.hpp> +#include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> #define TENSOR_REORDER_BLOCK_SIZE 256 #define TENSOR_REORDER_PERSISTENT 0 diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index c579cc7e7b..e93619fffb 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -31,7 +31,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder_kernel/order.hpp> +#include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> #include #include From 16dfe07767f80db6373e502d30c24f0c1402baf6 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 17:19:04 +0800 Subject: [PATCH 36/77] fixtypo --- src/include/miopen/tensor_reorder_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 0a31408484..9465e2a158 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -33,7 +33,7 @@ #include #include #include -#include <../kernels/gpu_tensor_reorder/order.hpp> +#include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> namespace miopen { From 9ef53c063c2787a53ff09c6329a7b6ae1b347575 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 17:39:29 +0800 Subject: [PATCH 37/77] addkerneltest --- src/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index efeacd6191..e85247b69d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -458,6 +458,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN ocl/rnn_util_ocl.cpp hip/hip_build_utils.cpp hip/batched_transpose_sol.cpp + hip/general_tensor_reorder_sol.cpp pooling.cpp ocl/fusionopconvocl.cpp ocl/fusionopbiasbnactivocl.cpp From aa6a09d819b18e4a862fed7377f2afc5f5dbb6b5 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 23:08:32 +0800 Subject: [PATCH 38/77] try separated solution --- .../miopen/general_tensor_reorder_sol.hpp | 254 ------------------ 1 file changed, 254 deletions(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index e93619fffb..3a6cdaa7ef 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -84,260 +84,6 @@ struct GeneralReorderSolution GeneralReorderParam kernel_param_heuristic; }; -namespace tensor_reorder { -static inline std::string 
GetNameTrait(std::size_t type_size) -{ - if(type_size == 1) - return "byte"; - if(type_size == 2) - return "half"; - if(type_size == 4) - return "dword"; - MIOPEN_THROW("data type not supported"); -} - -static inline const std::vector& GetKernelList(std::size_t data_size) -{ - if(data_size == 1) - { - static const std::vector byte_kernel_list{ - // clang-format off - {1, 256, 1, 1, 1, 1}, - {2, 256, 1, 1, 1, 1}, - {4, 256, 1, 1, 1, 1}, - {8, 256, 1, 1, 1, 1}, - {16, 256, 1, 1, 1, 1}, - // clang-format on - }; - return byte_kernel_list; - } - if(data_size == 2) - { - static const std::vector half_kernel_list{ - // clang-format off - {1, 256, 1, 1, 1, 1}, - {2, 256, 1, 1, 1, 1}, - {4, 256, 1, 1, 1, 1}, - {8, 256, 1, 1, 1, 1}, - {16, 256, 1, 1, 1, 1}, - // clang-format on - }; - return half_kernel_list; - } - if(data_size == 4) - { - static const std::vector dword_kernel_list{ - // clang-format off - {1, 256, 1, 1, 1, 1}, - {2, 256, 1, 1, 1, 1}, - {4, 256, 1, 1, 1, 1}, - {8, 256, 1, 1, 1, 1}, - {16, 256, 1, 1, 1, 1}, - // clang-format on - }; - return dword_kernel_list; - } - MIOPEN_THROW("data type not supported"); -} - -static inline bool IsApplicable(uint32_t /* batch */, - uint32_t height, - uint32_t width, - const GeneralReorderParam* kparam) -{ - return width % kparam->ediv_x == 0 && height % kparam->ediv_y == 0; -} - -static inline bool IsSameSide(uint32_t height, uint32_t width, const GeneralReorderParam* kparam) -{ - float radio = 0; - if(width > height) - radio = static_cast(kparam->tile_x) / kparam->tile_y; - else - radio = static_cast(kparam->tile_y) / kparam->tile_x; - - // E.g. for cases like width=1000, height=10 - // Allow at least 32x64, 64x64... 16x64 not allowed - return radio >= 0.4; -} - -template -static inline float GetNormalizedRadio(T x, T y) -{ - if(y > x) - return static_cast(y) / x; - return static_cast(x) / y; -} -template -static inline std::string GetKernelName(std::size_t data_size, const GeneralReorderParam* kparam) -{ - std::ostringstream kernel_name; - std::string type_trait = GetNameTrait(data_size); - kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_"; - if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1)) - { - kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" - << kparam->ediv_x << "x" << kparam->ediv_y << "_"; - } - kernel_name << type_trait<<"_r"<tile_y - 1) / kparam->tile_y) * kparam->tile_y; - uint32_t padded_w = ((width + kparam->tile_x - 1) / kparam->tile_x) * kparam->tile_x; - return static_cast(padded_h) * padded_w - static_cast(height) * width; -} - -static inline GeneralReorderParam -HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3) -{ - /* - * TODO: - * Design a algorithm to determine general tensor reorder tile size. 
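For reference while reading this removal: the tile-fitness helpers duplicated here mirror the batched-transpose solver. GetExtraPaddingSize() counts the dead work a tile choice creates, and IsSameSide() rejects tiles whose aspect ratio ("radio" in the code) fights the tensor's. A worked example with height = 10, width = 1000 and a hypothetical 64x16 tile:

    padded_h = ((10 + 16 - 1) / 16) * 16   = 16
    padded_w = ((1000 + 64 - 1) / 64) * 64 = 1024
    extra    = 16 * 1024 - 10 * 1000       = 6384 padded elements
    IsSameSide: width > height, so ratio = 64.0 / 16 = 4.0 >= 0.4 -> tile allowed
    (a 16x64 tile gives ratio = 0.25 and is rejected, matching the code comment)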
- */ - - if(dim_0 >= 1 && dim_1 >= 1 && dim_2 >= 1 && dim_3 >= 1 && data_size<=4) - { - if(dim_3 >= 16) - { - return GeneralReorderParam{16, 256, 1, 1, 1, 1}; - } - else if(dim_3 >= 8) - { - return GeneralReorderParam{8, 256, 1, 1, 1, 1}; - } - else if(dim_3 >= 4) - { - return GeneralReorderParam{4, 256, 1, 1, 1, 1}; - } - else if(dim_3 >= 2) - { - return GeneralReorderParam{2, 256, 1, 1, 1, 1}; - } - else - { - return GeneralReorderParam{1, 256, 1, 1, 1, 1}; - } - } - MIOPEN_THROW("data type not supported"); -} - -} // namespace tensor_reorder -template -GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_) -{ - if(data_type == miopenInt8x4 || data_type == miopenDouble) - MIOPEN_THROW("These data type are not supported"); - num_cu = ctx.GetStream().GetMaxComputeUnits(); - std::size_t data_size = miopen::GetTypeSize(data_type); - kernel_param_heuristic = tensor_reorder::HeuristicGet(data_size, dim_0, dim_1, dim_2, dim_3); -} - -template -solver::KernelInfo GeneralReorderSolution::GetKernel() const -{ - std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; -#if TENSOR_REORDER_PERSISTENT - std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY; -#else - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); - std::size_t grid_size = dim_total; -#endif - std::string kernel_name = GetKernelName(); - solver::KernelInfo kernel; - kernel.kernel_file = "general_tensor_reorder.cpp"; - kernel.kernel_name = kernel_name; - kernel.g_wk.clear(); - kernel.g_wk.push_back(grid_size * block_size); - kernel.g_wk.push_back(1); - kernel.g_wk.push_back(1); - kernel.l_wk.clear(); - kernel.l_wk.push_back(block_size); - kernel.l_wk.push_back(1); - kernel.l_wk.push_back(1); - - MIOPEN_LOG_I2("GeneralReorderSolution use kernel: " + kernel_name); - - return kernel; -} - -template -std::vector GeneralReorderSolution::GetKernelArg() const -{ - std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); -#if TENSOR_REORDER_PERSISTENT - std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY; -#else - std::size_t grid_size = dim_total; -#endif - - magic_div_u32_t magic_stride0 = magic_div_u32_gen(dim_1 * dim_2 * dim_3); - magic_div_u32_t magic_stride1 = magic_div_u32_gen(dim_2 * dim_3); - magic_div_u32_t magic_stride2 = magic_div_u32_gen(dim_3); - - std::vector opArgs; - opArgs.emplace_back(0); // placeholder - opArgs.emplace_back(0); // placeholder - opArgs.emplace_back(dim_0); - opArgs.emplace_back(dim_1); - opArgs.emplace_back(dim_2); - opArgs.emplace_back(dim_3); - opArgs.emplace_back(static_cast(grid_size)); - opArgs.emplace_back(dim_total); - opArgs.emplace_back(magic_stride0.magic); - opArgs.emplace_back(static_cast(magic_stride0.shift)); - opArgs.emplace_back(magic_stride1.magic); - opArgs.emplace_back(static_cast(magic_stride1.shift)); - opArgs.emplace_back(magic_stride2.magic); - opArgs.emplace_back(static_cast(magic_stride2.shift)); - - return opArgs; -} - -template -std::string GeneralReorderSolution::GetKernelName() const -{ - std::size_t data_size = miopen::GetTypeSize(data_type); - 
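// The assembled name encodes tile size, optional pack/ediv factors, the
// GetNameTrait type suffix, and the destination order. With the default
// pack/ediv of 1 (which suppresses the middle segment) the result looks
// something like "general_4d_reorder_16x256_dword_r0231"; that string is
// illustrative only, assuming the 16x256 dword kernel and destination
// order (0,2,3,1).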
std::ostringstream kernel_name; - std::string type_trait = tensor_reorder::GetNameTrait(data_size); - kernel_name << "general_4d_reorder_" << kernel_param_heuristic.tile_x << "x" << kernel_param_heuristic.tile_y << "_"; - if(!(kernel_param_heuristic.pack_x == 1 && kernel_param_heuristic.pack_y == 1 && kernel_param_heuristic.ediv_x == 1 && kernel_param_heuristic.ediv_y == 1)) - { - kernel_name << "pack_" << kernel_param_heuristic.pack_x << "x" << kernel_param_heuristic.pack_y << "_ediv_" - << kernel_param_heuristic.ediv_x << "x" << kernel_param_heuristic.ediv_y << "_"; - } - kernel_name << type_trait<<"_r"< -bool GeneralReorderSolution::IsSkippable() const -{ - // Disable the IsSkippable funciton - return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0 ; -} - -template -size_t GeneralReorderSolution::GetSize() const -{ - return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; -} } // namespace miopen #endif From 108b80cc662f02978d6d419effe0b43892cef295 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 27 Jan 2022 23:10:41 +0800 Subject: [PATCH 39/77] fixbug --- src/hip/general_tensor_reorder_sol.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 4ce0e5d7b5..58ef8b0f1a 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -271,7 +271,7 @@ template std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); + return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); } template @@ -287,5 +287,4 @@ size_t GeneralReorderSolution::GetSize() const return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; } -} // namespace miopen -*/ +} // namespace miopen \ No newline at end of file From 47d4b3d79ba06a31158ab531706ff556a2db7044 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 00:18:57 +0800 Subject: [PATCH 40/77] fix bug --- src/hip/general_tensor_reorder_sol.cpp | 2 ++ src/include/miopen/general_tensor_reorder_sol.hpp | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 58ef8b0f1a..35a050776c 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -287,4 +287,6 @@ size_t GeneralReorderSolution::GetSize() const return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3; } +//Explicit instance +template struct GeneralReorderSolution>; } // namespace miopen \ No newline at end of file diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 3a6cdaa7ef..6a04578008 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -84,6 +84,5 @@ struct GeneralReorderSolution GeneralReorderParam kernel_param_heuristic; }; - } // namespace miopen #endif From 74a7545449fd9dfcd9078ca1033eb7798b8e89e1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 00:53:20 +0800 Subject: [PATCH 41/77] elimate some warnings --- src/hip/general_tensor_reorder_sol.cpp | 33 ++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 35a050776c..3ce71635ec 100644 --- 
a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -160,8 +160,8 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim * TODO: * Design a algorithm to determine general tensor reorder tile size. */ - - if(dim_3 >= 1 ) + GeneralReorderParam default_kernel; + if(data_size <= 4 && dim_0 >= 1 && dim_1 >= 1 && dim_2 >= 1 && dim_3 >= 1) { if(dim_3 >= 16) { @@ -184,6 +184,10 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim return GeneralReorderParam{1, 256, 1, 1, 1, 1}; } } + else{ + return default_kernel; + } + } } // namespace tensor_reorder @@ -288,5 +292,30 @@ size_t GeneralReorderSolution::GetSize() const } //Explicit instance +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; template struct GeneralReorderSolution>; + +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; + +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; + +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; +template struct GeneralReorderSolution>; } // namespace miopen \ No newline at end of file From 86c21afa35bdba4a08d39ff43392a70aed685d46 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 00:55:25 +0800 Subject: [PATCH 42/77] fix some warnings --- src/hip/general_tensor_reorder_sol.cpp | 40 +------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 3ce71635ec..05d1bdc508 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -99,34 +99,7 @@ static inline const std::vector& GetKernelList(std::size_t } MIOPEN_THROW("data type not supported"); } -static inline bool IsApplicable(uint32_t /* batch */, - uint32_t height, - uint32_t width, - const GeneralReorderParam* kparam) -{ - return width % kparam->ediv_x == 0 && height % kparam->ediv_y == 0; -} - -static inline bool IsSameSide(uint32_t height, uint32_t width, const GeneralReorderParam* kparam) -{ - float radio = 0; - if(width > height) - radio = static_cast(kparam->tile_x) / kparam->tile_y; - else - radio = static_cast(kparam->tile_y) / kparam->tile_x; - - // E.g. for cases like width=1000, height=10 - // Allow at least 32x64, 64x64... 
16x64 not allowed - return radio >= 0.4; -} -template -static inline float GetNormalizedRadio(T x, T y) -{ - if(y > x) - return static_cast(y) / x; - return static_cast(x) / y; -} template static inline std::string GetKernelName(std::size_t data_size, const GeneralReorderParam* kparam) { @@ -142,17 +115,6 @@ static inline std::string GetKernelName(std::size_t data_size, const GeneralReor return kernel_name.str(); } -static inline std::size_t GetExtraPaddingSize(uint32_t /* batch */, - uint32_t height, - uint32_t width, - const GeneralReorderParam* kparam) -{ - // For simplicity and speed, we ignore batch, only compute h*w - uint32_t padded_h = ((height + kparam->tile_y - 1) / kparam->tile_y) * kparam->tile_y; - uint32_t padded_w = ((width + kparam->tile_x - 1) / kparam->tile_x) * kparam->tile_x; - return static_cast(padded_h) * padded_w - static_cast(height) * width; -} - static inline GeneralReorderParam HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3) { @@ -318,4 +280,4 @@ template struct GeneralReorderSolution>; template struct GeneralReorderSolution>; template struct GeneralReorderSolution>; template struct GeneralReorderSolution>; -} // namespace miopen \ No newline at end of file +} // namespace miopen From 096c66149196fac4764c03cf65eae6aa31228dda Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 00:56:36 +0800 Subject: [PATCH 43/77] fix some warnings --- src/hip/general_tensor_reorder_sol.cpp | 44 -------------------------- 1 file changed, 44 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 05d1bdc508..c7225a4729 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -56,50 +56,6 @@ static inline std::string GetNameTrait(std::size_t type_size) MIOPEN_THROW("data type not supported"); } -static inline const std::vector& GetKernelList(std::size_t data_size) -{ - if(data_size == 1) - { - static const std::vector byte_kernel_list{ - // clang-format off - {1, 256, 1, 1, 1, 1}, - {2, 256, 1, 1, 1, 1}, - {4, 256, 1, 1, 1, 1}, - {8, 256, 1, 1, 1, 1}, - {16, 256, 1, 1, 1, 1}, - // clang-format on - }; - return byte_kernel_list; - } - if(data_size == 2) - { - static const std::vector half_kernel_list{ - // clang-format off - {1, 256, 1, 1, 1, 1}, - {2, 256, 1, 1, 1, 1}, - {4, 256, 1, 1, 1, 1}, - {8, 256, 1, 1, 1, 1}, - {16, 256, 1, 1, 1, 1}, - // clang-format on - }; - return half_kernel_list; - } - if(data_size == 4) - { - static const std::vector dword_kernel_list{ - // clang-format off - {1, 256, 1, 1, 1, 1}, - {2, 256, 1, 1, 1, 1}, - {4, 256, 1, 1, 1, 1}, - {8, 256, 1, 1, 1, 1}, - {16, 256, 1, 1, 1, 1}, - // clang-format on - }; - return dword_kernel_list; - } - MIOPEN_THROW("data type not supported"); -} - template static inline std::string GetKernelName(std::size_t data_size, const GeneralReorderParam* kparam) { From 32f21b0e9ee9151a9f2d38c8ea8351a7c86c7df4 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 15:00:41 +0800 Subject: [PATCH 44/77] fork should not call CI --- src/include/miopen/general_tensor_reorder_sol.hpp | 4 ++-- test/tensor_reorder.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 6a04578008..74be6414ee 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -23,8 +23,8 @@ * SOFTWARE. 
* *******************************************************************************/ -#ifndef GUARD_MIOPEN_TENSOR_REORDER_SOL_HPP -#define GUARD_MIOPEN_TENSOR_REORDER_SOL_HPP +#ifndef GUARD_GENERAL_MIOPEN_TENSOR_REORDER_SOL_HPP +#define GUARD_GENERAL_MIOPEN_TENSOR_REORDER_SOL_HPP #include #include diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index d5c69aa602..aaf87e7c65 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -28,7 +28,7 @@ #include #include #include -//#include +#include #include #include #include From eff9686b690764e2e0d41e7a5bdb712722c27388 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 15:23:31 +0800 Subject: [PATCH 45/77] push & pull test on forked repo --- test/tensor_reorder.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index aaf87e7c65..ec66ced51c 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -430,6 +430,7 @@ struct reorder_test : reorder_base int main() { + //push & pull test loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { {0, 3, 2, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, From c81cb86e8c4d3d1a59c517855b86817e93d73871 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 17:23:37 +0800 Subject: [PATCH 46/77] try --- src/hip/general_tensor_reorder_sol.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index c7225a4729..2bd22ed893 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -193,7 +193,7 @@ template std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); + return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); } template From 6b9f145410da31526f8ab844f764fdd81c0996e8 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 17:24:57 +0800 Subject: [PATCH 47/77] try --- src/hip/general_tensor_reorder_sol.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 2bd22ed893..ad104c7424 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -68,6 +68,7 @@ static inline std::string GetKernelName(std::size_t data_size, const GeneralReor << kparam->ediv_x << "x" << kparam->ediv_y << "_"; } kernel_name << type_trait<<"_r"< std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); + return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); } template From ff7a441f9956e515c3e41fb11e4e1eb1d670595f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 17:25:33 +0800 Subject: [PATCH 48/77] fix typo --- src/hip/general_tensor_reorder_sol.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index ad104c7424..b22ea10495 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -68,7 +68,7 @@ static inline std::string GetKernelName(std::size_t data_size, const GeneralReor << kparam->ediv_x << "x" << kparam->ediv_y << "_"; } kernel_name << type_trait<<"_r"< Date: Fri, 
28 Jan 2022 22:38:56 +0800 Subject: [PATCH 49/77] debug --- test/tensor_reorder.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index ec66ced51c..346c5bd7da 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -354,6 +354,7 @@ struct reorder_test : reorder_base ctx.SetStream(&miopen::deref(this->handle)); ctx.DetectRocm(); // ctx.SetupFloats(); + std::cout<<"check point 1"<::get(), dim_0, dim_1, dim_2, dim_3); From acd2877725cb836b17b9f5719f49e1b2577f5658 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 22:40:17 +0800 Subject: [PATCH 50/77] add debug points --- test/tensor_reorder.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 346c5bd7da..63d6980c50 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -357,9 +357,9 @@ struct reorder_test : reorder_base std::cout<<"check point 1"<::get(), dim_0, dim_1, dim_2, dim_3); - + std::cout<<"check point 2"< opArgs = reorder_sol.GetKernelArg(); - + std::cout<<"check point 3"< invoker_factory( [=](const std::vector& kernels) mutable { return [=](const miopen::Handle& handle, From 3453bba5c1229f8156009b974dc05398854e50df Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 22:41:21 +0800 Subject: [PATCH 51/77] add checkpoints --- test/tensor_reorder.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 63d6980c50..0ec0eaed69 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -375,15 +375,15 @@ struct reorder_test : reorder_base k(opArgs); }; }); - + std::cout<<"check point 4"< construction_params{reorder_sol.GetKernel()}; - + std::cout<<"check point 5"<handle).PrepareInvoker(*invoker_factory, construction_params); - + std::cout<<"check point 6"<handle), invoke_param); - + std::cout<<"check point 7"<::run(t_dst.data.data(), t_src.data.data(), dim_0, dim_1, dim_2, dim_3); From 0651536eb64aac01649b1c95b77d3c3e2befd116 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 28 Jan 2022 22:47:36 +0800 Subject: [PATCH 52/77] add check point --- src/hip/general_tensor_reorder_sol.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index b22ea10495..7d08cf41ec 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -137,7 +137,9 @@ solver::KernelInfo GeneralReorderSolution::GetKernel() const uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); std::size_t grid_size = dim_total; #endif + std::cout<<"check point before GetKernelName"< Date: Sat, 29 Jan 2022 00:52:59 +0800 Subject: [PATCH 53/77] fixbugs --- src/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index e85247b69d..5ce0eb87d8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -333,6 +333,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN ${GPU_REFERENCE_KERNEL_HIP} ${GPU_REFERENCE_KERNEL_ASM} ${GPU_BATCHED_TRANSPOSE_KERNEL_HIP} + ${GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP} kernels/detect_llvm_amdgcn_buffer_atomic_fadd_f32_float.cpp kernels/MIOpenCheckNumerics.cl kernels/MIOpenBatchNormActivBwdPerAct.cl From 3196935f4a07dfb2f6a155d67ba8cc41793b7add Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Sat, 29 Jan 2022 
01:34:05 +0800 Subject: [PATCH 54/77] fixbug try --- .../general_tensor_reorder.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp index 31e73e3a49..7d1c23f8cd 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp @@ -23,8 +23,10 @@ * SOFTWARE. * *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include +#endif #include "order.hpp" #ifndef TENSOR_REORDER_OCCUPANCY From e1244fd5743068d1649a252a2407bbc5997c9e57 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Sat, 29 Jan 2022 09:50:54 +0800 Subject: [PATCH 55/77] debug --- .../general_tensor_reorder.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp index 7d1c23f8cd..31e73e3a49 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp @@ -23,10 +23,8 @@ * SOFTWARE. * *******************************************************************************/ -#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include -#endif #include "order.hpp" #ifndef TENSOR_REORDER_OCCUPANCY From 8652207577ba53bdbae8689efbe7eb572a3a9ba1 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Sat, 29 Jan 2022 12:15:20 +0800 Subject: [PATCH 56/77] cmake debug --- src/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5ce0eb87d8..fec308eabc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -230,13 +230,15 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN file(GLOB_RECURSE GPU_REFERENCE_KERNEL_HIP "kernels/gpu_reference_kernel/*.cpp") file(GLOB_RECURSE GPU_REFERENCE_KERNEL_ASM "kernels/gpu_reference_kernel/*.s") file(GLOB_RECURSE GPU_BATCHED_TRANSPOSE_KERNEL_HIP "kernels/gpu_batched_transpose_kernel/*.cpp") - file(GLOB_RECURSE GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP "kernels/gpu_general_tensor_reorder_kernel/*.cpp") + file(GLOB_RECURSE GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP_INCLUDE "kernels/gpu_general_tensor_reorder_kernel/*.hpp") + file(GLOB_RECURSE GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP_SOURCE "kernels/gpu_general_tensor_reorder_kernel/*.cpp") set(MIOPEN_KERNEL_INCLUDES ${STATIC_COMPOSABLE_KERNEL_INCLUDE} ${COMPOSABLE_KERNEL_INCLUDE} ${COMPOSABLE_KERNEL_DYNAMIC_ASM_INCLUDE} + ${GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP_INCLUDE} include/miopen/implicitgemm_params.hpp kernels/Conv_Winograd_v13_3_12_fp16dot_stride1.inc kernels/Conv_Winograd_v13_3_12_fp16dot_stride2_dec.inc @@ -333,7 +335,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN ${GPU_REFERENCE_KERNEL_HIP} ${GPU_REFERENCE_KERNEL_ASM} ${GPU_BATCHED_TRANSPOSE_KERNEL_HIP} - ${GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP} + ${GPU_GENERAL_TENSOR_REORDER_KERNEL_HIP_SOURCE} kernels/detect_llvm_amdgcn_buffer_atomic_fadd_f32_float.cpp kernels/MIOpenCheckNumerics.cl kernels/MIOpenBatchNormActivBwdPerAct.cl From 6e98bfb6c02622b395cc533539e8f5e89704498f Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 8 Feb 2022 06:46:41 +0000 Subject: [PATCH 57/77] Before warning fixed --- 
src/hip/general_tensor_reorder_sol.cpp | 3 --- src/include/miopen/general_tensor_reorder_sol.hpp | 15 --------------- src/include/miopen/tensor_reorder_util.hpp | 2 +- .../gpu_general_tensor_reorder_kernel/order.hpp | 4 ++-- test/tensor_reorder.cpp | 14 ++------------ 5 files changed, 5 insertions(+), 33 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 7d08cf41ec..c7225a4729 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -68,7 +68,6 @@ static inline std::string GetKernelName(std::size_t data_size, const GeneralReor << kparam->ediv_x << "x" << kparam->ediv_y << "_"; } kernel_name << type_trait<<"_r"<::GetKernel() const uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); std::size_t grid_size = dim_total; #endif - std::cout<<"check point before GetKernelName"< #include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> -#include -#include -#include -#include -#include -#include -#include - -#define TENSOR_REORDER_BLOCK_SIZE 256 -#define TENSOR_REORDER_PERSISTENT 0 - -#if TENSOR_REORDER_PERSISTENT -#define TENSOR_REORDER_OCCUPANCY 4 -#endif - namespace miopen { struct GeneralReorderParam diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 9465e2a158..f7a1bf9f2d 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -46,7 +46,7 @@ struct TensorReorderSolution : public GeneralReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) + : GeneralReorderSolution::GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) { } }; diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp index 6ec61f0912..00916595c2 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp @@ -31,9 +31,9 @@ struct order { static constexpr int m_size = sizeof...(Is); - __host__ __device__ static constexpr auto size() { return m_size; } + __host__ __device__ static constexpr int size() { return m_size; } - __host__ __device__ static constexpr auto get_size() { return size(); } + __host__ __device__ static constexpr int get_size() { return size(); } __host__ __device__ static constexpr int at(int I) { diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 0ec0eaed69..9d95587a63 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -354,12 +354,9 @@ struct reorder_test : reorder_base ctx.SetStream(&miopen::deref(this->handle)); ctx.DetectRocm(); // ctx.SetupFloats(); - std::cout<<"check point 1"<::get(), dim_0, dim_1, dim_2, dim_3); - std::cout<<"check point 2"< opArgs = reorder_sol.GetKernelArg(); - std::cout<<"check point 3"< invoker_factory( [=](const std::vector& kernels) mutable { return [=](const miopen::Handle& handle, @@ -375,15 +372,11 @@ struct reorder_test : reorder_base k(opArgs); }; }); - std::cout<<"check point 4"< construction_params{reorder_sol.GetKernel()}; - std::cout<<"check point 5"<handle).PrepareInvoker(*invoker_factory, construction_params); - std::cout<<"check point 6"<handle), invoke_param); - std::cout<<"check point 7"<::run(t_dst.data.data(), t_src.data.data(), dim_0, dim_1, dim_2, dim_3); @@ -407,12 +400,10 @@ struct 
reorder_test : reorder_base // we expect excact match, since use integer bool valid_result = verify_tensor(t_dst_gpu, t_dst); - std::cout << "[" << reorder_str::get() << ", b" << (sizeof(T) * 8) << " ] " << "dim_0:" << dim_0 << ", dim_1:" << dim_1 << ", dim_2:" << dim_2 << ", dim_3:" << dim_3 << ", valid:" << valid_result << std::endl; - EXPECT(valid_result == true); #if MIOPEN_BACKEND_OPENCL @@ -431,10 +422,9 @@ struct reorder_test : reorder_base int main() { - //push & pull test -loop([&](auto i) { +loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { - {0, 3, 2, 1}, {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, + {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; From d97fc28b6e876c0a04e68e8d232a40db2036af38 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 8 Feb 2022 07:52:38 +0000 Subject: [PATCH 58/77] test all cases --- test/tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 9d95587a63..49e8b28dc5 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -422,7 +422,7 @@ struct reorder_test : reorder_base int main() { -loop([&](auto i) { +loop([&](auto i) { constexpr int all_possible_sequence[23][4] = { {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, From a3aab19879ab002b726d8182784c65d94d924da4 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 10 Feb 2022 07:00:48 +0000 Subject: [PATCH 59/77] local analyze passed --- src/hip/general_tensor_reorder_sol.cpp | 70 ++--- .../miopen/general_tensor_reorder_sol.hpp | 15 +- src/include/miopen/tensor_reorder_util.hpp | 211 ++++++++++----- .../general_tensor_reorder.cpp | 254 +++++++++--------- .../order.hpp | 6 +- test/order.hpp | 46 ---- test/tensor_reorder.cpp | 137 ++++++---- 7 files changed, 385 insertions(+), 354 deletions(-) delete mode 100644 test/order.hpp diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index c7225a4729..96cea2f196 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2021 Advanced Micro Devices, Inc. + * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -25,6 +25,7 @@ *******************************************************************************/ #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include #include -#include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> #define TENSOR_REORDER_BLOCK_SIZE 256 #define TENSOR_REORDER_PERSISTENT 0 @@ -56,8 +56,12 @@ static inline std::string GetNameTrait(std::size_t type_size) MIOPEN_THROW("data type not supported"); } -template -static inline std::string GetKernelName(std::size_t data_size, const GeneralReorderParam* kparam) +static inline std::string GetKernelName(std::size_t data_size, + uint32_t order_0, + uint32_t order_1, + uint32_t order_2, + uint32_t order_3, + const GeneralReorderParam* kparam) { std::ostringstream kernel_name; std::string type_trait = GetNameTrait(data_size); @@ -67,7 +71,7 @@ static inline std::string GetKernelName(std::size_t data_size, const GeneralReor kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" << kparam->ediv_x << "x" << kparam->ediv_y << "_"; } - kernel_name << type_trait<<"_r"< -GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, +GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, - uint32_t dim_3_) - : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_) + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_ ) + : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_), + order_0(order_0_), order_1(order_1_), order_2(order_2_), order_3(order_3_) { if(data_type == miopenInt8x4 || data_type == miopenDouble) MIOPEN_THROW("These data type are not supported"); @@ -125,8 +133,7 @@ GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext kernel_param_heuristic = tensor_reorder::HeuristicGet(data_size, dim_0, dim_1, dim_2, dim_3); } -template -solver::KernelInfo GeneralReorderSolution::GetKernel() const +solver::KernelInfo GeneralReorderSolution::GetKernel() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; #if TENSOR_REORDER_PERSISTENT @@ -154,8 +161,7 @@ solver::KernelInfo GeneralReorderSolution::GetKernel() const return kernel; } -template -std::vector GeneralReorderSolution::GetKernelArg() const +std::vector GeneralReorderSolution::GetKernelArg() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; @@ -189,51 +195,21 @@ std::vector GeneralReorderSolution::GetKernelArg() const return opArgs; } -template -std::string GeneralReorderSolution::GetKernelName() const +std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, &kernel_param_heuristic); + return tensor_reorder::GetKernelName(data_size, order_0, order_1, order_2, order_3, &kernel_param_heuristic); } -template -bool GeneralReorderSolution::IsSkippable() const +bool GeneralReorderSolution::IsSkippable() const { // Disable the IsSkippable funciton return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0 ; } -template -size_t GeneralReorderSolution::GetSize() const +size_t GeneralReorderSolution::GetSize() const { return miopen::GetTypeSize(data_type) * dim_0 * 
dim_1 * dim_2 * dim_3; } -//Explicit instance -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; - -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; - -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; - -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; -template struct GeneralReorderSolution>; } // namespace miopen diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 711b557f67..0514665102 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c) 2021 Advanced Micro Devices, Inc. + * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -31,7 +31,6 @@ #include #include #include -#include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> namespace miopen { @@ -45,7 +44,6 @@ struct GeneralReorderParam int ediv_y{0}; }; -template struct GeneralReorderSolution { GeneralReorderSolution(const ExecutionContext& ctx_, @@ -53,7 +51,12 @@ struct GeneralReorderSolution uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, - uint32_t dim_3_); + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_ ); + //TODO batched transpose API solver::KernelInfo GetKernel() const; std::vector GetKernelArg() const; std::string GetKernelName() const; @@ -65,6 +68,10 @@ struct GeneralReorderSolution uint32_t dim_1; uint32_t dim_2; uint32_t dim_3; + uint32_t order_0; + uint32_t order_1; + uint32_t order_2; + uint32_t order_3; int num_cu; GeneralReorderParam kernel_param_heuristic; diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index f7a1bf9f2d..39d3ecd39b 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (c_) 202 Advanced Micro Devices, Inc. + * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. 
* * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -33,93 +33,168 @@ #include #include #include -#include <../kernels/gpu_general_tensor_reorder_kernel/order.hpp> namespace miopen { +struct TensorReorderSolution{ -template -struct TensorReorderSolution : public GeneralReorderSolution + virtual ~TensorReorderSolution() = default; + virtual solver::KernelInfo GetKernel() const = 0; + virtual std::vector GetKernelArg() const = 0; + virtual std::string GetKernelName() const = 0; + virtual bool IsSkippable() const = 0; + virtual size_t GetSize() const = 0; +}; + +struct WrapperBatchedTransposeSolution_0132 : TensorReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : GeneralReorderSolution::GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_) - { - } + BatchedTransposeSolution m_BatchedTransposeSolution; + WrapperBatchedTransposeSolution_0132(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) + { + } + solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} + std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} + std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} + bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} + size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} }; -template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct WrapperBatchedTransposeSolution_0231 : TensorReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) - { - } + BatchedTransposeSolution m_BatchedTransposeSolution; + WrapperBatchedTransposeSolution_0231(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) + { + } + solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} + std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} + std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} + bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} + size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} }; -template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct WrapperBatchedTransposeSolution_0312 : TensorReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) - { - } + BatchedTransposeSolution m_BatchedTransposeSolution; + WrapperBatchedTransposeSolution_0312(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + 
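// (Note on the mapping below, assuming BatchedTransposeSolution's three
// size arguments mean (batch, height, width): with batch = dim_0,
// height = dim_1 * dim_2 and width = dim_3, transposing height against
// width per batch yields exactly the (0, 3, 1, 2) destination order; the
// other wrappers merge adjacent dimensions analogously.)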
uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) + { + } + solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} + std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} + std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} + bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} + size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} }; -template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct WrapperBatchedTransposeSolution_2301 : TensorReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) - { - } + BatchedTransposeSolution m_BatchedTransposeSolution; + WrapperBatchedTransposeSolution_2301(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ *dim_1_, dim_2_ * dim_3_) + { + } + solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} + std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} + std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} + bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} + size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} }; -template<> -struct TensorReorderSolution> : public BatchedTransposeSolution + +struct WrapperBatchedTransposeSolution_3012 : TensorReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ *dim_1_, dim_2_ * dim_3_) - { - } + BatchedTransposeSolution m_BatchedTransposeSolution; + WrapperBatchedTransposeSolution_3012(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) + { + } + solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} + std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} + std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} + bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} + size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} }; -template<> -struct TensorReorderSolution> : public BatchedTransposeSolution +struct WrapperGeneralReorderSolution : TensorReorderSolution { - TensorReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_) - : BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) - { - } + GeneralReorderSolution m_GeneralReorderSolution; + WrapperGeneralReorderSolution(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, 
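// (Fallback path: orders with no batched-transpose equivalent keep all
// four order_* values, which end up in the "_r" suffix of the kernel name
// and so select the matching templated general_4d_reorder kernel.)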
+ uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) + : m_GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_, order_0_, order_1_, order_2_, order_3_) + { + } + solver::KernelInfo GetKernel() const override{ return m_GeneralReorderSolution.GetKernel();} + std::vector GetKernelArg() const override{ return m_GeneralReorderSolution.GetKernelArg();} + std::string GetKernelName() const override{ return m_GeneralReorderSolution.GetKernelName();} + bool IsSkippable() const override{ return m_GeneralReorderSolution.IsSkippable();} + size_t GetSize() const override{ return m_GeneralReorderSolution.GetSize();} }; + +__inline__ std::unique_ptr TensorReorderSolutionConstructor(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) { + //Default using general reorder + int which = 0; + if( (order_0_ == 0) && (order_1_ == 1) && (order_2_ == 3) && (order_3_ == 2) ) which = 1; + if( (order_0_ == 0) && (order_1_ == 2) && (order_2_ == 3) && (order_3_ == 1) ) which = 2; + if( (order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && (order_3_ == 2) ) which = 3; + if( (order_0_ == 2) && (order_1_ == 3) && (order_2_ == 0) && (order_3_ == 1) ) which = 4; + if( (order_0_ == 3) && (order_1_ == 0) && (order_2_ == 1) && (order_3_ == 2) ) which = 5; + + switch (which) { + case 0: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_, + order_0_, order_1_, order_2_, order_3_); + case 1: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 2: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 3: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 4: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 5: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + default : return nullptr; + } + return nullptr; +} + } // namespace miopen #endif // MIOPEN_TENSOR_REORDER_UTIL_HPP_ diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp index 31e73e3a49..a507969d65 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp @@ -41,36 +41,36 @@ inline __device__ uint32_t magic_div_u32(const uint32_t& numer, template inline __device__ void general_4d_reorder_1x256(T* dst, - T* src, - uint32_t dim_0, - uint32_t dim_1, - uint32_t dim_2, - uint32_t dim_3, - uint32_t dim_stride, - uint32_t dim_total, - uint32_t magic_stride0, - uint32_t shift_stride0, - uint32_t magic_stride1, - uint32_t shift_stride1, - uint32_t magic_stride2, - uint32_t shift_stride2) + T* src, + uint32_t dim_0, + uint32_t dim_1, + uint32_t dim_2, + uint32_t dim_3, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_stride0, + uint32_t shift_stride0, + uint32_t magic_stride1, + uint32_t shift_stride1, + uint32_t magic_stride2, + uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t src_index =0, dst_index=0; - const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + uint32_t src_index, dst_index; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; const 
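// dst_dim is src_dim permuted by the compile-time destination order,
// i.e. dst_dim[i] = src_dim[dorder.at(i)]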
uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; - - uint32_t i_src[4] = {0, 0, 0, 0}; - uint32_t i_dst[4] = {0, 0, 0, 0}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { @@ -79,8 +79,8 @@ inline __device__ void general_4d_reorder_1x256(T* dst, //unroll k block thread src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; @@ -98,36 +98,36 @@ inline __device__ void general_4d_reorder_1x256(T* dst, template inline __device__ void general_4d_reorder_2x256(T* dst, - T* src, - uint32_t dim_0, - uint32_t dim_1, - uint32_t dim_2, - uint32_t dim_3, - uint32_t dim_stride, - uint32_t dim_total, - uint32_t magic_stride0, - uint32_t shift_stride0, - uint32_t magic_stride1, - uint32_t shift_stride1, - uint32_t magic_stride2, - uint32_t shift_stride2) + T* src, + uint32_t dim_0, + uint32_t dim_1, + uint32_t dim_2, + uint32_t dim_3, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_stride0, + uint32_t shift_stride0, + uint32_t magic_stride1, + uint32_t shift_stride1, + uint32_t magic_stride2, + uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t src_index =0, dst_index=0; - const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + uint32_t src_index, dst_index; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; - - uint32_t i_src[4] = {0, 0, 0, 0}; - uint32_t i_dst[4] = {0, 0, 0, 0}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { @@ -136,8 +136,8 @@ inline __device__ void general_4d_reorder_2x256(T* dst, //unroll k block thread src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; if(src_index < 
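// tail guard: at unroll step k, thread threadIdx.x of the block handling
// dim_id covers flat pixel k * dim_total * 256 + dim_id * 256 + threadIdx.x
// (the formula above), so the last tile can run past the end of the tensor
// and out-of-range pixels must be masked off here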
pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; @@ -155,36 +155,36 @@ inline __device__ void general_4d_reorder_2x256(T* dst, template inline __device__ void general_4d_reorder_4x256(T* dst, - T* src, - uint32_t dim_0, - uint32_t dim_1, - uint32_t dim_2, - uint32_t dim_3, - uint32_t dim_stride, - uint32_t dim_total, - uint32_t magic_stride0, - uint32_t shift_stride0, - uint32_t magic_stride1, - uint32_t shift_stride1, - uint32_t magic_stride2, - uint32_t shift_stride2) + T* src, + uint32_t dim_0, + uint32_t dim_1, + uint32_t dim_2, + uint32_t dim_3, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_stride0, + uint32_t shift_stride0, + uint32_t magic_stride1, + uint32_t shift_stride1, + uint32_t magic_stride2, + uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t src_index =0, dst_index=0; - const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + uint32_t src_index, dst_index; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; - - uint32_t i_src[4] = {0, 0, 0, 0}; - uint32_t i_dst[4] = {0, 0, 0, 0}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { @@ -193,8 +193,8 @@ inline __device__ void general_4d_reorder_4x256(T* dst, //unroll k block thread src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; @@ -212,36 +212,36 @@ inline __device__ void general_4d_reorder_4x256(T* dst, template inline __device__ void general_4d_reorder_8x256(T* dst, - T* src, - uint32_t dim_0, - uint32_t dim_1, - uint32_t dim_2, - uint32_t dim_3, - uint32_t dim_stride, - uint32_t dim_total, - uint32_t magic_stride0, - uint32_t shift_stride0, - uint32_t magic_stride1, - uint32_t shift_stride1, - uint32_t magic_stride2, - uint32_t 
shift_stride2) + T* src, + uint32_t dim_0, + uint32_t dim_1, + uint32_t dim_2, + uint32_t dim_3, + uint32_t dim_stride, + uint32_t dim_total, + uint32_t magic_stride0, + uint32_t shift_stride0, + uint32_t magic_stride1, + uint32_t shift_stride1, + uint32_t magic_stride2, + uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t src_index =0, dst_index=0; - const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + uint32_t src_index, dst_index; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; - - uint32_t i_src[4] = {0, 0, 0, 0}; - uint32_t i_dst[4] = {0, 0, 0, 0}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { @@ -250,8 +250,8 @@ inline __device__ void general_4d_reorder_8x256(T* dst, //unroll k block thread src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; @@ -285,20 +285,20 @@ inline __device__ void general_4d_reorder_16x256(T* dst, { constexpr auto dorder = dst_order{}; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t src_index =0, dst_index=0; - const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; + uint32_t src_index, dst_index; + const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; - - uint32_t i_src[4] = {0, 0, 0, 0}; - uint32_t i_dst[4] = {0, 0, 0, 0}; + const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], + src_dim[2] * src_dim[3], + src_dim[3], + 1 }; + const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], + dst_dim[2] * dst_dim[3], + dst_dim[3], + 1 }; + + uint32_t i_src[4] = {0, 0, 0, 0}; + uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { @@ -307,8 +307,8 @@ inline __device__ void general_4d_reorder_16x256(T* dst, //unroll k block thread src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; if(src_index < pixel_total){ - i_src[0] = 
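// decompose the flat src_index into four source coordinates without
// hardware integer division: each (magic_strideN, shift_strideN) pair is
// precomputed on the host by magic_div_u32_gen for the corresponding source
// stride; one common encoding, assumed here for illustration, evaluates
// floor(n / d) as
//   (uint32_t)(((uint64_t)n * magic) >> 32) >> shift
// and the subtractions that follow recover the remainders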
magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp index 00916595c2..927b4ba46f 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp @@ -29,11 +29,11 @@ template struct order { - static constexpr int m_size = sizeof...(Is); + __host__ __device__ static constexpr uint64_t m_size = sizeof...(Is); - __host__ __device__ static constexpr int size() { return m_size; } + __host__ __device__ static constexpr uint64_t size() { return m_size; } - __host__ __device__ static constexpr int get_size() { return size(); } + __host__ __device__ static constexpr uint64_t get_size() { return size(); } __host__ __device__ static constexpr int at(int I) { diff --git a/test/order.hpp b/test/order.hpp deleted file mode 100644 index 6ec61f0912..0000000000 --- a/test/order.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2020-2022 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- * - *******************************************************************************/ -#ifndef ORDER_HPP -#define ORDER_HPP - -template -struct order -{ - static constexpr int m_size = sizeof...(Is); - - __host__ __device__ static constexpr auto size() { return m_size; } - - __host__ __device__ static constexpr auto get_size() { return size(); } - - __host__ __device__ static constexpr int at(int I) - { - // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 - const int m_data[m_size + 1] = {Is..., 0}; - return m_data[I]; - } - -}; -#endif diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 49e8b28dc5..7a1436cd4e 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -38,7 +38,6 @@ #include "test.hpp" #include "driver.hpp" #include "random.hpp" -#include "order.hpp" template <> @@ -51,13 +50,12 @@ struct miopen_type : std::integral_constant -void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64_t dim_2, uint64_t dim_3) +template +void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64_t dim_2, uint64_t dim_3, + uint64_t order_0, uint64_t order_1, uint64_t order_2, uint64_t order_3) { - constexpr auto dorder = dst_order{}; const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t dst_dim[4] = {src_dim[order_0], src_dim[order_1], src_dim[order_2], src_dim[order_3]}; const uint64_t src_stride[4] ={src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], @@ -75,10 +73,10 @@ void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64 for(itr_src_dim[1] = 0; itr_src_dim[1] < src_dim[1]; itr_src_dim[1]++){ for(itr_src_dim[2] = 0; itr_src_dim[2] < src_dim[2]; itr_src_dim[2]++){ for(itr_src_dim[3] = 0; itr_src_dim[3] < src_dim[3]; itr_src_dim[3]++){ - itr_dst_dim[0] = itr_src_dim[dorder.at(0)]; - itr_dst_dim[1] = itr_src_dim[dorder.at(1)]; - itr_dst_dim[2] = itr_src_dim[dorder.at(2)]; - itr_dst_dim[3] = itr_src_dim[dorder.at(3)]; + itr_dst_dim[0] = itr_src_dim[order_0]; + itr_dst_dim[1] = itr_src_dim[order_1]; + itr_dst_dim[2] = itr_src_dim[order_2]; + itr_dst_dim[3] = itr_src_dim[order_3]; uint64_t idx_src = itr_src_dim[0] * src_stride[0] + itr_src_dim[1] * src_stride[1] + @@ -96,23 +94,23 @@ void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64 } } -template +template struct cpu_reorder { - static void run(T* dst, T* src, uint64_t N, uint64_t C, uint64_t H, uint64_t W) + static void run(T* dst, T* src, uint64_t dim_0, uint64_t dim_1, uint64_t dim_2, uint64_t dim_3, + uint64_t order_0, uint64_t order_1, uint64_t order_2, uint64_t order_3) { - cpu_tensor_reorder(dst, src, N, C, H, W); + cpu_tensor_reorder(dst, src, dim_0, dim_1, dim_2, dim_3, order_0, order_1, order_2, order_3); } }; -template struct reorder_str { - static std::string get() { - return ("r" + std::to_string(dst_order::at(0)) - + std::to_string(dst_order::at(1)) - + std::to_string(dst_order::at(2)) - + std::to_string(dst_order::at(3)) ); + static std::string get(uint32_t order_0, uint32_t order_1, uint32_t order_2, uint32_t order_3) { + return ("r" + std::to_string(order_0) + + std::to_string(order_1) + + std::to_string(order_2) + + std::to_string(order_3) ); } }; @@ -140,6 +138,40 @@ std::string tensor_layout_to_string(tensor_layout_t layout) return layout_string; } +std::string supported_reorder_to_string(uint32_t 
order_0, + uint32_t order_1, + uint32_t order_2, + uint32_t order_3) +{ + std::string layout_string("N/A"); + if((order_0==0) && (order_1==1) && (order_2==3) && (order_3==2)) layout_string = "r0132"; + else if((order_0==0) && (order_1==2) && (order_2==1) && (order_3==3)) layout_string = "r0213"; + else if((order_0==0) && (order_1==2) && (order_2==3) && (order_3==1)) layout_string = "r0231"; + else if((order_0==0) && (order_1==3) && (order_2==1) && (order_3==2)) layout_string = "r0312"; + else if((order_0==0) && (order_1==3) && (order_2==2) && (order_3==1)) layout_string = "r0321"; + else if((order_0==1) && (order_1==0) && (order_2==2) && (order_3==3)) layout_string = "r1023"; + else if((order_0==1) && (order_1==0) && (order_2==3) && (order_3==2)) layout_string = "r1032"; + else if((order_0==1) && (order_1==2) && (order_2==0) && (order_3==3)) layout_string = "r1203"; + else if((order_0==1) && (order_1==2) && (order_2==3) && (order_3==0)) layout_string = "r1230"; + else if((order_0==1) && (order_1==3) && (order_2==0) && (order_3==2)) layout_string = "r1302"; + else if((order_0==1) && (order_1==3) && (order_2==2) && (order_3==0)) layout_string = "r1320"; + else if((order_0==2) && (order_1==0) && (order_2==1) && (order_3==3)) layout_string = "r2013"; + else if((order_0==2) && (order_1==0) && (order_2==3) && (order_3==1)) layout_string = "r2031"; + else if((order_0==2) && (order_1==1) && (order_2==0) && (order_3==3)) layout_string = "r2103"; + else if((order_0==2) && (order_1==1) && (order_2==3) && (order_3==0)) layout_string = "r2130"; + else if((order_0==2) && (order_1==3) && (order_2==0) && (order_3==1)) layout_string = "r2301"; + else if((order_0==2) && (order_1==3) && (order_2==1) && (order_3==0)) layout_string = "r2310"; + else if((order_0==3) && (order_1==0) && (order_2==1) && (order_3==2)) layout_string = "r3012"; + else if((order_0==3) && (order_1==0) && (order_2==2) && (order_3==1)) layout_string = "r3021"; + else if((order_0==3) && (order_1==1) && (order_2==0) && (order_3==2)) layout_string = "r3102"; + else if((order_0==3) && (order_1==1) && (order_2==2) && (order_3==0)) layout_string = "r3120"; + else if((order_0==3) && (order_1==2) && (order_2==0) && (order_3==1)) layout_string = "r3201"; + else if((order_0==3) && (order_1==2) && (order_2==1) && (order_3==0)) layout_string = "r3210"; + else + MIOPEN_THROW("Unsupported reorder layout"); + return layout_string; +} + template struct to_miopen_data_type @@ -218,21 +250,6 @@ bool verify_tensor(tensor& t_gpu, tensor& t_cpu) return valid_result; } -//compile time for_loop -namespace detail { - - template - constexpr void loop(std::integer_sequence, F&& f) { - (f(std::integral_constant{}), ...);// C++17 fold expression - } - -} - -template -constexpr void loop(F&& f) { - detail::loop(std::make_integer_sequence{}, std::forward(f)); -} - struct reorder_base { miopenHandle_t handle{}; @@ -267,15 +284,24 @@ struct reorder_base dim_1_list.push_back(gen_rand_integer() % 13 + 15); dim_0_list.push_back(gen_rand_integer() % 4 + 3); - for(uint32_t dim_3 : dim_3_list) + constexpr int all_possible_order[23][4] = { + {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, + {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, + {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, + {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; + + for(auto order : all_possible_order) { - for(uint32_t dim_2 : dim_2_list) + for(uint32_t dim_3 : dim_3_list) 
{ - for(uint32_t dim_1 : dim_1_list) + for(uint32_t dim_2 : dim_2_list) { - for(uint32_t dim_0 : dim_0_list) + for(uint32_t dim_1 : dim_1_list) { - f(dim_0, dim_1, dim_2, dim_3); + for(uint32_t dim_0 : dim_0_list) + { + f(dim_0, dim_1, dim_2, dim_3, order[0], order[1], order[2], order[3]); + } } } } @@ -294,13 +320,13 @@ struct reorder_invoke_param : public miopen::InvokeParams { } }; -//The template parameter dst_order is just for CPU verification -template +template struct reorder_test : reorder_base { void run() { - auto run_reorder = [this](uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3) { + auto run_reorder = [this](uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3, + uint32_t order_0, uint32_t order_1, uint32_t order_2, uint32_t order_3) { int tensor_sz = dim_0 * dim_1 * dim_2 * dim_3; std::vector tensor_len({static_cast(dim_0), static_cast(dim_1), @@ -311,6 +337,7 @@ struct reorder_test : reorder_base std::string layout_default = miopen::tensor_layout_get_default(4); std::string layout_string = tensor_layout_to_string(miopen_tensor_layout_nchw); + std::string reorder_string = supported_reorder_to_string(order_0, order_1, order_2, order_3); miopen::tensor_layout_to_strides( tensor_len, layout_default, layout_string, tensor_strides); @@ -354,9 +381,9 @@ struct reorder_test : reorder_base ctx.SetStream(&miopen::deref(this->handle)); ctx.DetectRocm(); // ctx.SetupFloats(); - - REORDER_SOL reorder_sol(ctx, to_miopen_data_type::get(), dim_0, dim_1, dim_2, dim_3); - std::vector opArgs = reorder_sol.GetKernelArg(); + auto reorder_sol = TensorReorderSolutionConstructor(ctx, to_miopen_data_type::get(), dim_0, dim_1, dim_2, dim_3, + order_0, order_1, order_2, order_3); + std::vector opArgs = reorder_sol->GetKernelArg(); boost::optional invoker_factory( [=](const std::vector& kernels) mutable { return [=](const miopen::Handle& handle, @@ -372,13 +399,13 @@ struct reorder_test : reorder_base k(opArgs); }; }); - std::vector construction_params{reorder_sol.GetKernel()}; + std::vector construction_params{reorder_sol->GetKernel()}; const auto invoker = miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params); // run gpu invoker(miopen::deref(this->handle), invoke_param); // run cpu - cpu_reorder::run(t_dst.data.data(), t_src.data.data(), dim_0, dim_1, dim_2, dim_3); + cpu_reorder::run(t_dst.data.data(), t_src.data.data(), dim_0, dim_1, dim_2, dim_3, order_0, order_1, order_2, order_3); #if MIOPEN_BACKEND_OPENCL status = clEnqueueReadBuffer(q, @@ -400,7 +427,7 @@ struct reorder_test : reorder_base // we expect exact match, since we use integers bool valid_result = verify_tensor(t_dst_gpu, t_dst); - std::cout << "[" << reorder_str::get() << ", b" << (sizeof(T) * 8) + std::cout << "[" << reorder_str::get(order_0, order_1, order_2, order_3) << ", b" << (sizeof(T) * 8) << " ] " << "dim_0:" << dim_0 << ", dim_1:" << dim_1 << ", dim_2:" << dim_2 << ", dim_3:" << dim_3 << ", valid:" << valid_result << std::endl; @@ -422,15 +449,7 @@ struct reorder_test : reorder_base int main() { -loop([&](auto i) { - constexpr int all_possible_sequence[23][4] = { - {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, - {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, - {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, - {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; - using dst_order = order; - run_test >>(); - run_test >>(); - run_test >>(); -}); + 
run_test>(); + run_test>(); + run_test>(); } From 35aa269fa1965b1d7beca79a58c140549237a6ed Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 10 Feb 2022 07:33:09 +0000 Subject: [PATCH 60/77] fix typo --- test/tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 7a1436cd4e..09ba8c1474 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -144,7 +144,7 @@ std::string supported_reorder_to_string(uint32_t order_0, uint32_t order_3) { std::string layout_string("N/A"); - if((order_0==0) && (order_1==1) && (order_2==3) && (order_3==2)) layout_string = "r0132"; + if ((order_0==0) && (order_1==1) && (order_2==3) && (order_3==2)) layout_string = "r0132"; else if((order_0==0) && (order_1==2) && (order_2==1) && (order_3==3)) layout_string = "r0213"; else if((order_0==0) && (order_1==2) && (order_2==3) && (order_3==1)) layout_string = "r0231"; else if((order_0==0) && (order_1==3) && (order_2==1) && (order_3==2)) layout_string = "r0312"; From 2870b321c72412d70b1cd9784c87f56f2508a034 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 10 Feb 2022 07:43:04 +0000 Subject: [PATCH 61/77] fix typo --- test/tensor_reorder.cpp | 69 +++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 09ba8c1474..2bb412a7e8 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -144,29 +144,52 @@ std::string supported_reorder_to_string(uint32_t order_0, uint32_t order_3) { std::string layout_string("N/A"); - if ((order_0==0) && (order_1==1) && (order_2==3) && (order_3==2)) layout_string = "r0132"; - else if((order_0==0) && (order_1==2) && (order_2==1) && (order_3==3)) layout_string = "r0213"; - else if((order_0==0) && (order_1==2) && (order_2==3) && (order_3==1)) layout_string = "r0231"; - else if((order_0==0) && (order_1==3) && (order_2==1) && (order_3==2)) layout_string = "r0312"; - else if((order_0==0) && (order_1==3) && (order_2==2) && (order_3==1)) layout_string = "r0321"; - else if((order_0==1) && (order_1==0) && (order_2==2) && (order_3==3)) layout_string = "r1023"; - else if((order_0==1) && (order_1==0) && (order_2==3) && (order_3==2)) layout_string = "r1032"; - else if((order_0==1) && (order_1==2) && (order_2==0) && (order_3==3)) layout_string = "r1203"; - else if((order_0==1) && (order_1==2) && (order_2==3) && (order_3==0)) layout_string = "r1230"; - else if((order_0==1) && (order_1==3) && (order_2==0) && (order_3==2)) layout_string = "r1302"; - else if((order_0==1) && (order_1==3) && (order_2==2) && (order_3==0)) layout_string = "r1320"; - else if((order_0==2) && (order_1==0) && (order_2==1) && (order_3==3)) layout_string = "r2013"; - else if((order_0==2) && (order_1==0) && (order_2==3) && (order_3==1)) layout_string = "r2031"; - else if((order_0==2) && (order_1==1) && (order_2==0) && (order_3==3)) layout_string = "r2103"; - else if((order_0==2) && (order_1==1) && (order_2==3) && (order_3==0)) layout_string = "r2130"; - else if((order_0==2) && (order_1==3) && (order_2==0) && (order_3==1)) layout_string = "r2301"; - else if((order_0==2) && (order_1==3) && (order_2==1) && (order_3==0)) layout_string = "r2310"; - else if((order_0==3) && (order_1==0) && (order_2==1) && (order_3==2)) layout_string = "r3012"; - else if((order_0==3) && (order_1==0) && (order_2==2) && (order_3==1)) layout_string = "r3021"; - else if((order_0==3) && (order_1==1) && (order_2==0) && (order_3==2)) layout_string = "r3102"; 
- else if((order_0==3) && (order_1==1) && (order_2==2) && (order_3==0)) layout_string = "r3120"; - else if((order_0==3) && (order_1==2) && (order_2==0) && (order_3==1)) layout_string = "r3201"; - else if((order_0==3) && (order_1==2) && (order_2==1) && (order_3==0)) layout_string = "r3210"; + if((order_0==0) && (order_1==1) && (order_2==3) && (order_3==2)) + layout_string = "r0132"; + else if((order_0==0) && (order_1==2) && (order_2==1) && (order_3==3)) + layout_string = "r0213"; + else if((order_0==0) && (order_1==2) && (order_2==3) && (order_3==1)) + layout_string = "r0231"; + else if((order_0==0) && (order_1==3) && (order_2==1) && (order_3==2)) + layout_string = "r0312"; + else if((order_0==0) && (order_1==3) && (order_2==2) && (order_3==1)) + layout_string = "r0321"; + else if((order_0==1) && (order_1==0) && (order_2==2) && (order_3==3)) + layout_string = "r1023"; + else if((order_0==1) && (order_1==0) && (order_2==3) && (order_3==2)) + layout_string = "r1032"; + else if((order_0==1) && (order_1==2) && (order_2==0) && (order_3==3)) + layout_string = "r1203"; + else if((order_0==1) && (order_1==2) && (order_2==3) && (order_3==0)) + layout_string = "r1230"; + else if((order_0==1) && (order_1==3) && (order_2==0) && (order_3==2)) + layout_string = "r1302"; + else if((order_0==1) && (order_1==3) && (order_2==2) && (order_3==0)) + layout_string = "r1320"; + else if((order_0==2) && (order_1==0) && (order_2==1) && (order_3==3)) + layout_string = "r2013"; + else if((order_0==2) && (order_1==0) && (order_2==3) && (order_3==1)) + layout_string = "r2031"; + else if((order_0==2) && (order_1==1) && (order_2==0) && (order_3==3)) + layout_string = "r2103"; + else if((order_0==2) && (order_1==1) && (order_2==3) && (order_3==0)) + layout_string = "r2130"; + else if((order_0==2) && (order_1==3) && (order_2==0) && (order_3==1)) + layout_string = "r2301"; + else if((order_0==2) && (order_1==3) && (order_2==1) && (order_3==0)) + layout_string = "r2310"; + else if((order_0==3) && (order_1==0) && (order_2==1) && (order_3==2)) + layout_string = "r3012"; + else if((order_0==3) && (order_1==0) && (order_2==2) && (order_3==1)) + layout_string = "r3021"; + else if((order_0==3) && (order_1==1) && (order_2==0) && (order_3==2)) + layout_string = "r3102"; + else if((order_0==3) && (order_1==1) && (order_2==2) && (order_3==0)) + layout_string = "r3120"; + else if((order_0==3) && (order_1==2) && (order_2==0) && (order_3==1)) + layout_string = "r3201"; + else if((order_0==3) && (order_1==2) && (order_2==1) && (order_3==0)) + layout_string = "r3210"; else MIOPEN_THROW("Unsupported reorder layout"); return layout_string; From 24d8916ca5cf54f7935eea6f71417e4a3abd8cb6 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 10 Feb 2022 13:43:23 +0000 Subject: [PATCH 62/77] fix bug in order.hpp --- src/kernels/gpu_general_tensor_reorder_kernel/order.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp index 927b4ba46f..0ff1fbe2aa 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp @@ -29,14 +29,13 @@ template struct order { - __host__ __device__ static constexpr uint64_t m_size = sizeof...(Is); + static constexpr uint64_t m_size = sizeof...(Is); __host__ __device__ static constexpr uint64_t size() { return m_size; } __host__ __device__ static constexpr uint64_t get_size() { return size(); } - __host__ __device__ static 
constexpr int at(int I) - { + __host__ __device__ static constexpr int at(int I) { // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 const int m_data[m_size + 1] = {Is..., 0}; return m_data[I]; From 8d6f9954e7d688ef1273628ceed52c7b414ebce3 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 10 Feb 2022 16:32:29 +0000 Subject: [PATCH 63/77] fix bug in order.hpp to satisfy cxx11 --- src/kernels/gpu_general_tensor_reorder_kernel/order.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp index 0ff1fbe2aa..00312e9b38 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp @@ -30,14 +30,14 @@ template struct order { static constexpr uint64_t m_size = sizeof...(Is); + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + static constexpr int m_data[m_size + 1] = {Is..., 0}; __host__ __device__ static constexpr uint64_t size() { return m_size; } __host__ __device__ static constexpr uint64_t get_size() { return size(); } __host__ __device__ static constexpr int at(int I) { - // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 - const int m_data[m_size + 1] = {Is..., 0}; return m_data[I]; } From 04f48d6d2e650e057ab2a8bdf488c6644ced7b77 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 11 Feb 2022 02:10:58 +0000 Subject: [PATCH 64/77] fix format: add a new line --- .../general_tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp index a507969d65..4cb69e5874 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp @@ -742,4 +742,4 @@ DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3021, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3102, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) -DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) \ No newline at end of file +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) From e42f13fa1fd6fc6ef4c7d445ac1c5c465e6814aa Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 14 Feb 2022 07:56:36 +0000 Subject: [PATCH 65/77] [skip ci] Update: add double data type support. 
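This change routes 8-byte elements through the general reorder path: GetNameTrait() gains a "dwordx2" trait, general_tensor_reorder.cpp instantiates the kernels for double, and TensorReorderSolutionConstructor() picks the batched-transpose wrappers only for non-double types. The kernels themselves still recover a 4-D coordinate from a flat offset with precomputed magic-number division. For reviewers, below is a minimal host-side sketch of one common magic-division construction (for illustration only; the in-tree miopen::magic_div_u32_gen / magic_div_u32 may differ in detail):

    #include <cassert>
    #include <cstdint>

    // Sketch, not the exact library implementation.
    struct magic_div_u32_t { uint32_t magic; uint32_t shift; };

    // Pick (magic, shift) for a divisor 1 <= d <= 2^31 so that, for any
    // 0 <= n < 2^31:  n / d == ((uint64_t(n) * magic >> 32) + n) >> shift
    inline magic_div_u32_t magic_div_u32_gen(uint32_t d)
    {
        assert(d >= 1 && d <= 0x80000000u);
        uint32_t shift = 0;
        while((1u << shift) < d) // smallest shift with 2^shift >= d
            shift++;
        uint64_t one   = 1;
        uint64_t magic = ((one << 32) * ((one << shift) - d)) / d + 1;
        assert(magic <= 0xffffffffull); // must fit in 32 bits
        return {static_cast<uint32_t>(magic), shift};
    }

    // Device-side counterpart: one 32x32->64 multiply, an add and a shift,
    // instead of a hardware integer divide per element.
    inline uint32_t magic_div_u32(uint32_t n, uint32_t magic, uint32_t shift)
    {
        uint32_t q = (static_cast<uint64_t>(n) * magic) >> 32;
        return (q + n) >> shift;
    }

With (magic, shift) precomputed per stride on the host (GetKernelArg() passes magic_stride0/1/2 and shift_stride0/1/2), the kernel decomposes src_index without a divide instruction, e.g. i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0), then peels the remaining dimensions off the remainder the same way.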
--- src/hip/general_tensor_reorder_sol.cpp | 81 ++- .../miopen/general_tensor_reorder_sol.hpp | 20 +- src/include/miopen/tensor_reorder_util.hpp | 232 ++++++--- .../general_tensor_reorder.cpp | 469 ++++++++++++------ .../order.hpp | 11 +- test/tensor_reorder.cpp | 215 ++++---- 6 files changed, 645 insertions(+), 383 deletions(-) diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 96cea2f196..2b004ca851 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -25,7 +25,6 @@ *******************************************************************************/ #include -#include #include #include #include @@ -36,11 +35,6 @@ #include #define TENSOR_REORDER_BLOCK_SIZE 256 -#define TENSOR_REORDER_PERSISTENT 0 - -#if TENSOR_REORDER_PERSISTENT -#define TENSOR_REORDER_OCCUPANCY 4 -#endif namespace miopen { namespace tensor_reorder { @@ -53,14 +47,16 @@ static inline std::string GetNameTrait(std::size_t type_size) return "half"; if(type_size == 4) return "dword"; + if(type_size == 8) + return "dwordx2"; MIOPEN_THROW("data type not supported"); } -static inline std::string GetKernelName(std::size_t data_size, - uint32_t order_0, - uint32_t order_1, - uint32_t order_2, - uint32_t order_3, +static inline std::string GetKernelName(std::size_t data_size, + uint32_t order_0, + uint32_t order_1, + uint32_t order_2, + uint32_t order_3, const GeneralReorderParam* kparam) { std::ostringstream kernel_name; @@ -71,7 +67,7 @@ static inline std::string GetKernelName(std::size_t data_size, kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_" << kparam->ediv_x << "x" << kparam->ediv_y << "_"; } - kernel_name << type_trait<<"_r"<= 1 && dim_1 >= 1 && dim_2 >= 1 && dim_3 >= 1) + if(data_size <= 8 && dim_0 >= 1 && dim_1 >= 1 && dim_2 >= 1 && dim_3 >= 1) { if(dim_3 >= 16) { @@ -106,27 +102,34 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim return GeneralReorderParam{1, 256, 1, 1, 1, 1}; } } - else{ + else + { return default_kernel; } - } } // namespace tensor_reorder GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_ ) - : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), dim_2(dim_2_), dim_3(dim_3_), - order_0(order_0_), order_1(order_1_), order_2(order_2_), order_3(order_3_) + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) + : data_type(data_type_), + dim_0(dim_0_), + dim_1(dim_1_), + dim_2(dim_2_), + dim_3(dim_3_), + order_0(order_0_), + order_1(order_1_), + order_2(order_2_), + order_3(order_3_) { - if(data_type == miopenInt8x4 || data_type == miopenDouble) + if(data_type == miopenInt8x4) MIOPEN_THROW("These data type are not supported"); num_cu = ctx.GetStream().GetMaxComputeUnits(); std::size_t data_size = miopen::GetTypeSize(data_type); @@ -136,13 +139,11 @@ GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, solver::KernelInfo GeneralReorderSolution::GetKernel() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; -#if TENSOR_REORDER_PERSISTENT - std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY; -#else - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t 
dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / + (block_size * kernel_param_heuristic.tile_x); std::size_t grid_size = dim_total; -#endif + std::string kernel_name = GetKernelName(); solver::KernelInfo kernel; kernel.kernel_file = "general_tensor_reorder.cpp"; @@ -164,13 +165,10 @@ solver::KernelInfo GeneralReorderSolution::GetKernel() const std::vector GeneralReorderSolution::GetKernelArg() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; - uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / (block_size * kernel_param_heuristic.tile_x); -#if TENSOR_REORDER_PERSISTENT - std::size_t grid_size = num_cu * TENSOR_REORDER_OCCUPANCY; -#else + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t dim_total = (pixel_total + block_size * kernel_param_heuristic.tile_x - 1) / + (block_size * kernel_param_heuristic.tile_x); std::size_t grid_size = dim_total; -#endif magic_div_u32_t magic_stride0 = magic_div_u32_gen(dim_1 * dim_2 * dim_3); magic_div_u32_t magic_stride1 = magic_div_u32_gen(dim_2 * dim_3); @@ -198,13 +196,14 @@ std::vector GeneralReorderSolution::GetKernelArg() const std::string GeneralReorderSolution::GetKernelName() const { std::size_t data_size = miopen::GetTypeSize(data_type); - return tensor_reorder::GetKernelName(data_size, order_0, order_1, order_2, order_3, &kernel_param_heuristic); + return tensor_reorder::GetKernelName( + data_size, order_0, order_1, order_2, order_3, &kernel_param_heuristic); } bool GeneralReorderSolution::IsSkippable() const { // Disable the IsSkippable function - return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0 ; + return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0; } size_t GeneralReorderSolution::GetSize() const diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index 0514665102..28ce62668d 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ -47,16 +47,16 @@ struct GeneralReorderParam struct GeneralReorderSolution { GeneralReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_ ); - //TODO batched transpose API + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_); + // TODO batched transpose API solver::KernelInfo GetKernel() const; std::vector GetKernelArg() const; std::string GetKernelName() const; diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 39d3ecd39b..299801d6f2 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -35,14 +35,15 @@ #include namespace miopen { -struct TensorReorderSolution{ +struct TensorReorderSolution +{ - virtual ~TensorReorderSolution() = default; - virtual solver::KernelInfo GetKernel() const = 0; + virtual ~TensorReorderSolution() = default; + virtual solver::KernelInfo GetKernel() const = 0; virtual std::vector GetKernelArg() 
const = 0; - virtual std::string GetKernelName() const = 0; - virtual bool IsSkippable() const = 0; - virtual size_t GetSize() const = 0; + virtual std::string GetKernelName() const = 0; + virtual bool IsSkippable() const = 0; + virtual size_t GetSize() const = 0; }; struct WrapperBatchedTransposeSolution_0132 : TensorReorderSolution @@ -54,14 +55,20 @@ struct WrapperBatchedTransposeSolution_0132 : TensorReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) + { + } + solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + std::vector GetKernelArg() const override { + return m_BatchedTransposeSolution.GetKernelArg(); } - solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} - std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} - std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} - bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} - size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} + std::string GetKernelName() const override + { + return m_BatchedTransposeSolution.GetKernelName(); + } + bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } + size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } }; struct WrapperBatchedTransposeSolution_0231 : TensorReorderSolution @@ -73,14 +80,20 @@ struct WrapperBatchedTransposeSolution_0231 : TensorReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) + { + } + solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + std::vector GetKernelArg() const override { + return m_BatchedTransposeSolution.GetKernelArg(); } - solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} - std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} - std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} - bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} - size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} + std::string GetKernelName() const override + { + return m_BatchedTransposeSolution.GetKernelName(); + } + bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } + size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } }; struct WrapperBatchedTransposeSolution_0312 : TensorReorderSolution @@ -92,14 +105,20 @@ struct WrapperBatchedTransposeSolution_0312 : TensorReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) + { + } + solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + std::vector GetKernelArg() const override { + return m_BatchedTransposeSolution.GetKernelArg(); } - solver::KernelInfo GetKernel() const override{ return 
m_BatchedTransposeSolution.GetKernel();} - std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} - std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} - bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} - size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} + std::string GetKernelName() const override + { + return m_BatchedTransposeSolution.GetKernelName(); + } + bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } + size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } }; struct WrapperBatchedTransposeSolution_2301 : TensorReorderSolution @@ -111,17 +130,22 @@ struct WrapperBatchedTransposeSolution_2301 : TensorReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ *dim_1_, dim_2_ * dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_, dim_2_ * dim_3_) + { + } + solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + std::vector GetKernelArg() const override + { + return m_BatchedTransposeSolution.GetKernelArg(); + } + std::string GetKernelName() const override { + return m_BatchedTransposeSolution.GetKernelName(); } - solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} - std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} - std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} - bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} - size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} + bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } + size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } }; - struct WrapperBatchedTransposeSolution_3012 : TensorReorderSolution { BatchedTransposeSolution m_BatchedTransposeSolution; @@ -131,66 +155,116 @@ struct WrapperBatchedTransposeSolution_3012 : TensorReorderSolution uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) + : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) + { + } + solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + std::vector GetKernelArg() const override + { + return m_BatchedTransposeSolution.GetKernelArg(); + } + std::string GetKernelName() const override { + return m_BatchedTransposeSolution.GetKernelName(); } - solver::KernelInfo GetKernel() const override{ return m_BatchedTransposeSolution.GetKernel();} - std::vector GetKernelArg() const override{ return m_BatchedTransposeSolution.GetKernelArg();} - std::string GetKernelName() const override{ return m_BatchedTransposeSolution.GetKernelName();} - bool IsSkippable() const override{ return m_BatchedTransposeSolution.IsSkippable();} - size_t GetSize() const override{ return m_BatchedTransposeSolution.GetSize();} + bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } + size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } }; struct WrapperGeneralReorderSolution : TensorReorderSolution { GeneralReorderSolution m_GeneralReorderSolution; WrapperGeneralReorderSolution(const 
ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_) - : m_GeneralReorderSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_, order_0_, order_1_, order_2_, order_3_) - { - } - solver::KernelInfo GetKernel() const override{ return m_GeneralReorderSolution.GetKernel();} - std::vector GetKernelArg() const override{ return m_GeneralReorderSolution.GetKernelArg();} - std::string GetKernelName() const override{ return m_GeneralReorderSolution.GetKernelName();} - bool IsSkippable() const override{ return m_GeneralReorderSolution.IsSkippable();} - size_t GetSize() const override{ return m_GeneralReorderSolution.GetSize();} + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) + : m_GeneralReorderSolution(ctx_, + data_type_, + dim_0_, + dim_1_, + dim_2_, + dim_3_, + order_0_, + order_1_, + order_2_, + order_3_) + { + } + solver::KernelInfo GetKernel() const override { return m_GeneralReorderSolution.GetKernel(); } + std::vector GetKernelArg() const override + { + return m_GeneralReorderSolution.GetKernelArg(); + } + std::string GetKernelName() const override { return m_GeneralReorderSolution.GetKernelName(); } + bool IsSkippable() const override { return m_GeneralReorderSolution.IsSkippable(); } + size_t GetSize() const override { return m_GeneralReorderSolution.GetSize(); } }; -__inline__ std::unique_ptr TensorReorderSolutionConstructor(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_) { - //Default using general reorder +__inline__ std::unique_ptr +TensorReorderSolutionConstructor(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) +{ + // Default using general reorder int which = 0; - if( (order_0_ == 0) && (order_1_ == 1) && (order_2_ == 3) && (order_3_ == 2) ) which = 1; - if( (order_0_ == 0) && (order_1_ == 2) && (order_2_ == 3) && (order_3_ == 1) ) which = 2; - if( (order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && (order_3_ == 2) ) which = 3; - if( (order_0_ == 2) && (order_1_ == 3) && (order_2_ == 0) && (order_3_ == 1) ) which = 4; - if( (order_0_ == 3) && (order_1_ == 0) && (order_2_ == 1) && (order_3_ == 2) ) which = 5; + if((data_type_ != miopenDouble) && (order_0_ == 0) && (order_1_ == 1) && (order_2_ == 3) && + (order_3_ == 2)) + which = 1; + if((data_type_ != miopenDouble) && (order_0_ == 0) && (order_1_ == 2) && (order_2_ == 3) && + (order_3_ == 1)) + which = 2; + if((data_type_ != miopenDouble) && (order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && + (order_3_ == 2)) + which = 3; + if((data_type_ != miopenDouble) && (order_0_ == 2) && (order_1_ == 3) && (order_2_ == 0) && + (order_3_ == 1)) + which = 4; + if((data_type_ != miopenDouble) && (order_0_ == 3) && (order_1_ == 0) && (order_2_ == 1) && + (order_3_ == 2)) + which = 5; - switch (which) { - case 0: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_, - order_0_, order_1_, order_2_, order_3_); - case 1: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, 
dim_2_, dim_3_); - case 2: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); - case 3: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); - case 4: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); - case 5: return std::make_unique(ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); - default : return nullptr; + switch(which) + { + case 0: + return std::make_unique(ctx_, + data_type_, + dim_0_, + dim_1_, + dim_2_, + dim_3_, + order_0_, + order_1_, + order_2_, + order_3_); + case 1: + return std::make_unique( + ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 2: + return std::make_unique( + ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 3: + return std::make_unique( + ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 4: + return std::make_unique( + ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + case 5: + return std::make_unique( + ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); + default: return nullptr; } return nullptr; } diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp index 4cb69e5874..7bdf24e14b 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp @@ -56,40 +56,44 @@ inline __device__ void general_4d_reorder_1x256(T* dst, uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; uint32_t src_index, dst_index; const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; + const uint64_t dst_dim[4] = { + src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = { + src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], src_dim[3], 1}; + const uint64_t dst_stride[4] = { + dst_dim[1] * dst_dim[2] * dst_dim[3], dst_dim[2] * dst_dim[3], dst_dim[3], 1}; uint32_t i_src[4] = {0, 0, 0, 0}; uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { - for (uint32_t k = 0; k < 1; k++) + for(uint32_t k = 0; k < 1; k++) { - //unroll k block thread - src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; - if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); - i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); - i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; - + // unroll k block thread + src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x; + if(src_index < pixel_total) + { + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32( + src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[2] = + magic_div_u32(src_index - i_src[0] * src_stride[0] - 
i_src[1] * src_stride[1], + magic_stride2, + shift_stride2); + i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - + i_src[2] * src_stride[2]; + i_dst[0] = i_src[dorder.at(0)]; i_dst[1] = i_src[dorder.at(1)]; i_dst[2] = i_src[dorder.at(2)]; i_dst[3] = i_src[dorder.at(3)]; - - dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; + + dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; dst[dst_index] = src[src_index]; } } @@ -113,40 +117,44 @@ inline __device__ void general_4d_reorder_2x256(T* dst, uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; uint32_t src_index, dst_index; const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; + const uint64_t dst_dim[4] = { + src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = { + src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], src_dim[3], 1}; + const uint64_t dst_stride[4] = { + dst_dim[1] * dst_dim[2] * dst_dim[3], dst_dim[2] * dst_dim[3], dst_dim[3], 1}; uint32_t i_src[4] = {0, 0, 0, 0}; uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { - for (uint32_t k = 0; k < 2; k++) + for(uint32_t k = 0; k < 2; k++) { - //unroll k block thread - src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; - if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); - i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); - i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; - + // unroll k block thread + src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x; + if(src_index < pixel_total) + { + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32( + src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[2] = + magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], + magic_stride2, + shift_stride2); + i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - + i_src[2] * src_stride[2]; + i_dst[0] = i_src[dorder.at(0)]; i_dst[1] = i_src[dorder.at(1)]; i_dst[2] = i_src[dorder.at(2)]; i_dst[3] = i_src[dorder.at(3)]; - - dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; + + dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; dst[dst_index] = src[src_index]; } } @@ -170,40 +178,44 @@ inline __device__ void general_4d_reorder_4x256(T* dst, uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t pixel_total = dim_0 * 
dim_1 * dim_2 * dim_3; uint32_t src_index, dst_index; const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; + const uint64_t dst_dim[4] = { + src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = { + src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], src_dim[3], 1}; + const uint64_t dst_stride[4] = { + dst_dim[1] * dst_dim[2] * dst_dim[3], dst_dim[2] * dst_dim[3], dst_dim[3], 1}; uint32_t i_src[4] = {0, 0, 0, 0}; uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { - for (uint32_t k = 0; k < 4; k++) + for(uint32_t k = 0; k < 4; k++) { - //unroll k block thread - src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; - if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); - i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); - i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; - + // unroll k block thread + src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x; + if(src_index < pixel_total) + { + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32( + src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[2] = + magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], + magic_stride2, + shift_stride2); + i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - + i_src[2] * src_stride[2]; + i_dst[0] = i_src[dorder.at(0)]; i_dst[1] = i_src[dorder.at(1)]; i_dst[2] = i_src[dorder.at(2)]; i_dst[3] = i_src[dorder.at(3)]; - - dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; + + dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; dst[dst_index] = src[src_index]; } } @@ -227,40 +239,44 @@ inline __device__ void general_4d_reorder_8x256(T* dst, uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; uint32_t src_index, dst_index; const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; + const uint64_t dst_dim[4] = { + src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = { + src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], src_dim[3], 1}; + const uint64_t dst_stride[4] = { + dst_dim[1] * dst_dim[2] * dst_dim[3], dst_dim[2] * 
dst_dim[3], dst_dim[3], 1}; uint32_t i_src[4] = {0, 0, 0, 0}; uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { - for (uint32_t k = 0; k < 8; k++) + for(uint32_t k = 0; k < 8; k++) { - //unroll k block thread - src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; - if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); - i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); - i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; - + // unroll k block thread + src_index = k * dim_total * 256 + dim_id * 256 + threadIdx.x; + if(src_index < pixel_total) + { + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32( + src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[2] = + magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], + magic_stride2, + shift_stride2); + i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - + i_src[2] * src_stride[2]; + i_dst[0] = i_src[dorder.at(0)]; i_dst[1] = i_src[dorder.at(1)]; i_dst[2] = i_src[dorder.at(2)]; i_dst[3] = i_src[dorder.at(3)]; - - dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; + + dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; dst[dst_index] = src[src_index]; } } @@ -284,103 +300,232 @@ inline __device__ void general_4d_reorder_16x256(T* dst, uint32_t shift_stride2) { constexpr auto dorder = dst_order{}; - uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; + uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; uint32_t src_index, dst_index; const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; - const uint64_t src_stride[4] = {src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; + const uint64_t dst_dim[4] = { + src_dim[dorder.at(0)], src_dim[dorder.at(1)], src_dim[dorder.at(2)], src_dim[dorder.at(3)]}; + const uint64_t src_stride[4] = { + src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], src_dim[3], 1}; + const uint64_t dst_stride[4] = { + dst_dim[1] * dst_dim[2] * dst_dim[3], dst_dim[2] * dst_dim[3], dst_dim[3], 1}; uint32_t i_src[4] = {0, 0, 0, 0}; uint32_t i_dst[4] = {0, 0, 0, 0}; for(uint32_t dim_id = blockIdx.x; dim_id < dim_total; dim_id += dim_stride) { - for (uint32_t k = 0; k < 16; k++) + for(uint32_t k = 0; k < 16; k++) { - //unroll k block thread - src_index = k*dim_total*256 + dim_id * 256 + threadIdx.x; - if(src_index < pixel_total){ - i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); - i_src[1] = magic_div_u32(src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); - i_src[2] = magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], magic_stride2, shift_stride2); - i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - i_src[2] * src_stride[2]; - + // unroll k block thread + src_index = k * dim_total 
* 256 + dim_id * 256 + threadIdx.x; + if(src_index < pixel_total) + { + i_src[0] = magic_div_u32(src_index, magic_stride0, shift_stride0); + i_src[1] = magic_div_u32( + src_index - i_src[0] * src_stride[0], magic_stride1, shift_stride1); + i_src[2] = + magic_div_u32(src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1], + magic_stride2, + shift_stride2); + i_src[3] = src_index - i_src[0] * src_stride[0] - i_src[1] * src_stride[1] - + i_src[2] * src_stride[2]; + i_dst[0] = i_src[dorder.at(0)]; i_dst[1] = i_src[dorder.at(1)]; i_dst[2] = i_src[dorder.at(2)]; i_dst[3] = i_src[dorder.at(3)]; - - dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; + + dst_index = i_dst[0] * dst_stride[0] + i_dst[1] * dst_stride[1] + + i_dst[2] * dst_stride[2] + i_dst[3] * dst_stride[3]; dst[dst_index] = src[src_index]; } } } } -#define DEFINE_GENERAL_4D_REORDER_KERNEL( \ - tile_trait, dst_order, accept_data_type, cast_data_type, lb_threads_per_block, lb_blocks_per_cu) \ - extern "C" __global__ void __launch_bounds__(lb_threads_per_block, lb_blocks_per_cu) \ - general_4d_reorder_##tile_trait##_##accept_data_type##_##dst_order(void* dst, \ - void* src, \ - uint32_t dim_0, \ - uint32_t dim_1, \ - uint32_t dim_2, \ - uint32_t dim_3, \ - uint32_t dim_stride, \ - uint32_t dim_total, \ - uint32_t magic_stride0, \ - uint32_t shift_stride0, \ - uint32_t magic_stride1, \ - uint32_t shift_stride1, \ - uint32_t magic_stride2, \ - uint32_t shift_stride2) \ - { \ - general_4d_reorder_##tile_trait(reinterpret_cast(dst), \ - reinterpret_cast(src), \ - dim_0, \ - dim_1, \ - dim_2, \ - dim_3, \ - dim_stride, \ - dim_total, \ - magic_stride0, \ - shift_stride0, \ - magic_stride1, \ - shift_stride1, \ - magic_stride2, \ - shift_stride2); \ +#define DEFINE_GENERAL_4D_REORDER_KERNEL(tile_trait, \ + dst_order, \ + accept_data_type, \ + cast_data_type, \ + lb_threads_per_block, \ + lb_blocks_per_cu) \ + extern "C" __global__ void __launch_bounds__(lb_threads_per_block, lb_blocks_per_cu) \ + general_4d_reorder_##tile_trait##_##accept_data_type##_##dst_order(void* dst, \ + void* src, \ + uint32_t dim_0, \ + uint32_t dim_1, \ + uint32_t dim_2, \ + uint32_t dim_3, \ + uint32_t dim_stride, \ + uint32_t dim_total, \ + uint32_t magic_stride0, \ + uint32_t shift_stride0, \ + uint32_t magic_stride1, \ + uint32_t shift_stride1, \ + uint32_t magic_stride2, \ + uint32_t shift_stride2) \ + { \ + general_4d_reorder_##tile_trait( \ + reinterpret_cast(dst), \ + reinterpret_cast(src), \ + dim_0, \ + dim_1, \ + dim_2, \ + dim_3, \ + dim_stride, \ + dim_total, \ + magic_stride0, \ + shift_stride0, \ + magic_stride1, \ + shift_stride1, \ + magic_stride2, \ + shift_stride2); \ } -//default order is 0 1 2 3 -using r0132 = order<0, 1, 3, 2>; -using r0213 = order<0, 2, 1, 3>;//nhwc2nchwc -using r0231 = order<0, 2, 3, 1>;//nchw2nchwc -using r0312 = order<0, 3, 1, 2>;//nhwc2nchw -using r0321 = order<0, 3, 2, 1>; -using r1023 = order<1, 0, 2, 3>; -using r1032 = order<1, 0, 3, 2>; -using r1203 = order<1, 2, 0, 3>; -using r1230 = order<1, 2, 3, 0>; -using r1302 = order<1, 3, 0, 2>;//nchw2chwnc -using r1320 = order<1, 3, 2, 0>; -using r2013 = order<2, 0, 1, 3>; -using r2031 = order<2, 0, 3, 1>; -using r2103 = order<2, 1, 0, 3>;//nhwc2chwnc -using r2130 = order<2, 1, 3, 0>; -using r2301 = order<2, 3, 0, 1>; -using r2310 = order<2, 3, 1, 0>; -using r3012 = order<3, 0, 1, 2>; -using r3021 = order<3, 0, 2, 1>; -using r3102 = order<3, 1, 0, 2>; -using r3120 = order<3, 1, 2, 0>; -using r3201 = 
order<3, 2, 0, 1>; -using r3210 = order<3, 2, 1, 0>; +// default order is 0 1 2 3 +using r0132 = order<0, 1, 3, 2>; +using r0213 = order<0, 2, 1, 3>; // nhwc2nchwc +using r0231 = order<0, 2, 3, 1>; // nchw2nchwc +using r0312 = order<0, 3, 1, 2>; // nhwc2nchw +using r0321 = order<0, 3, 2, 1>; +using r1023 = order<1, 0, 2, 3>; +using r1032 = order<1, 0, 3, 2>; +using r1203 = order<1, 2, 0, 3>; +using r1230 = order<1, 2, 3, 0>; +using r1302 = order<1, 3, 0, 2>; // nchw2chwnc +using r1320 = order<1, 3, 2, 0>; +using r2013 = order<2, 0, 1, 3>; +using r2031 = order<2, 0, 3, 1>; +using r2103 = order<2, 1, 0, 3>; // nhwc2chwnc +using r2130 = order<2, 1, 3, 0>; +using r2301 = order<2, 3, 0, 1>; +using r2310 = order<2, 3, 1, 0>; +using r3012 = order<3, 0, 1, 2>; +using r3021 = order<3, 0, 2, 1>; +using r3102 = order<3, 1, 0, 2>; +using r3120 = order<3, 1, 2, 0>; +using r3201 = order<3, 2, 0, 1>; +using r3210 = order<3, 2, 1, 0>; + +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0231, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0312, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0321, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1023, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1032, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1203, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1230, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1302, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r1320, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2013, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2031, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2103, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2130, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2301, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r2310, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3012, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3021, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3102, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3120, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3201, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r3210, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0132, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0213, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0231, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0312, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) 
+DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r0321, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1023, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1032, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1203, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1230, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1302, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r1320, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2013, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2031, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2103, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2130, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2301, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r2310, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3012, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3021, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3102, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3120, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3201, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(2x256, r3210, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0132, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0213, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0231, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0312, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r0321, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1023, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1032, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1203, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1230, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1302, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r1320, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2013, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2031, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2103, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2130, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2301, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r2310, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3012, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) 
+DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3021, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3102, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3120, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3201, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(4x256, r3210, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0132, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0213, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0231, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0312, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r0321, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1023, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1032, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1203, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1230, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1302, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r1320, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2013, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2031, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2103, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2130, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2301, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r2310, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3012, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3021, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3102, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3120, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3201, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(8x256, r3210, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) + +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0132, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0213, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0231, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0312, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r0321, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1023, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1032, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1203, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1230, dwordx2, double, 256, 
TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1302, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r1320, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2013, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2031, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2103, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2130, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2301, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r2310, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3012, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3021, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3102, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) +DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, dwordx2, double, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, dword, float, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, dword, float, 256, TENSOR_REORDER_OCCUPANCY) @@ -502,7 +647,6 @@ DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, dword, float, 256, TENSOR_REORDE DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, dword, float, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, dword, float, 256, TENSOR_REORDER_OCCUPANCY) - DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0231, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) @@ -623,7 +767,6 @@ DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3120, half, ushort, 256, TENSOR_REORDE DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3201, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(16x256, r3210, half, ushort, 256, TENSOR_REORDER_OCCUPANCY) - DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0132, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0213, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) DEFINE_GENERAL_4D_REORDER_KERNEL(1x256, r0231, byte, uchar, 256, TENSOR_REORDER_OCCUPANCY) diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp index 00312e9b38..cc989da928 100644 --- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp +++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp @@ -29,17 +29,14 @@ template struct order { - static constexpr uint64_t m_size = sizeof...(Is); - // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 - static constexpr int m_data[m_size + 1] = {Is..., 0}; + static constexpr uint64_t m_size = sizeof...(Is); + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + static constexpr int m_data[m_size + 1] = {Is..., 0}; __host__ __device__ static constexpr uint64_t size() { return m_size; } __host__ __device__ static constexpr uint64_t 
get_size() { return size(); } - __host__ __device__ static constexpr int at(int I) { - return m_data[I]; - } - + __host__ __device__ static constexpr int at(int I) { return m_data[I]; } }; #endif diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 2bb412a7e8..20707d3d00 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -39,7 +39,6 @@ #include "driver.hpp" #include "random.hpp" - template <> struct miopen_type : std::integral_constant { @@ -50,44 +49,51 @@ struct miopen_type : std::integral_constant -void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64_t dim_2, uint64_t dim_3, - uint64_t order_0, uint64_t order_1, uint64_t order_2, uint64_t order_3) +template +void cpu_tensor_reorder(T* dst, + T* src, + uint64_t dim_0, + uint64_t dim_1, + uint64_t dim_2, + uint64_t dim_3, + uint64_t order_0, + uint64_t order_1, + uint64_t order_2, + uint64_t order_3) { const uint64_t src_dim[4] = {dim_0, dim_1, dim_2, dim_3}; - const uint64_t dst_dim[4] = {src_dim[order_0], src_dim[order_1], src_dim[order_2], src_dim[order_3]}; - - const uint64_t src_stride[4] ={src_dim[1] * src_dim[2] * src_dim[3], - src_dim[2] * src_dim[3], - src_dim[3], - 1 }; - const uint64_t dst_stride[4] = {dst_dim[1] * dst_dim[2] * dst_dim[3], - dst_dim[2] * dst_dim[3], - dst_dim[3], - 1 }; + const uint64_t dst_dim[4] = { + src_dim[order_0], src_dim[order_1], src_dim[order_2], src_dim[order_3]}; + + const uint64_t src_stride[4] = { + src_dim[1] * src_dim[2] * src_dim[3], src_dim[2] * src_dim[3], src_dim[3], 1}; + const uint64_t dst_stride[4] = { + dst_dim[1] * dst_dim[2] * dst_dim[3], dst_dim[2] * dst_dim[3], dst_dim[3], 1}; uint64_t itr_src_dim[4] = {0, 0, 0, 0}; uint64_t itr_dst_dim[4] = {0, 0, 0, 0}; - for(itr_src_dim[0] = 0; itr_src_dim[0] < src_dim[0]; itr_src_dim[0]++){ - for(itr_src_dim[1] = 0; itr_src_dim[1] < src_dim[1]; itr_src_dim[1]++){ - for(itr_src_dim[2] = 0; itr_src_dim[2] < src_dim[2]; itr_src_dim[2]++){ - for(itr_src_dim[3] = 0; itr_src_dim[3] < src_dim[3]; itr_src_dim[3]++){ + for(itr_src_dim[0] = 0; itr_src_dim[0] < src_dim[0]; itr_src_dim[0]++) + { + for(itr_src_dim[1] = 0; itr_src_dim[1] < src_dim[1]; itr_src_dim[1]++) + { + for(itr_src_dim[2] = 0; itr_src_dim[2] < src_dim[2]; itr_src_dim[2]++) + { + for(itr_src_dim[3] = 0; itr_src_dim[3] < src_dim[3]; itr_src_dim[3]++) + { itr_dst_dim[0] = itr_src_dim[order_0]; itr_dst_dim[1] = itr_src_dim[order_1]; itr_dst_dim[2] = itr_src_dim[order_2]; itr_dst_dim[3] = itr_src_dim[order_3]; - uint64_t idx_src = itr_src_dim[0] * src_stride[0] + - itr_src_dim[1] * src_stride[1] + - itr_src_dim[2] * src_stride[2] + - itr_src_dim[3] * src_stride[3] ; - uint64_t idx_dst = itr_dst_dim[0] * dst_stride[0] + - itr_dst_dim[1] * dst_stride[1] + - itr_dst_dim[2] * dst_stride[2] + - itr_dst_dim[3] * dst_stride[3] ; - - dst[idx_dst] = src[idx_src]; + uint64_t idx_src = + itr_src_dim[0] * src_stride[0] + itr_src_dim[1] * src_stride[1] + + itr_src_dim[2] * src_stride[2] + itr_src_dim[3] * src_stride[3]; + uint64_t idx_dst = + itr_dst_dim[0] * dst_stride[0] + itr_dst_dim[1] * dst_stride[1] + + itr_dst_dim[2] * dst_stride[2] + itr_dst_dim[3] * dst_stride[3]; + + dst[idx_dst] = src[idx_src]; } } } @@ -97,21 +103,29 @@ void cpu_tensor_reorder(T * dst, T * src, uint64_t dim_0, uint64_t dim_1, uint64 template struct cpu_reorder { - static void run(T* dst, T* src, uint64_t dim_0, uint64_t dim_1, uint64_t dim_2, uint64_t dim_3, - uint64_t order_0, uint64_t order_1, uint64_t order_2, uint64_t order_3) + static void run(T* dst, + T* src, 
+ uint64_t dim_0, + uint64_t dim_1, + uint64_t dim_2, + uint64_t dim_3, + uint64_t order_0, + uint64_t order_1, + uint64_t order_2, + uint64_t order_3) { - cpu_tensor_reorder(dst, src, dim_0, dim_1, dim_2, dim_3, order_0, order_1, order_2, order_3); + cpu_tensor_reorder( + dst, src, dim_0, dim_1, dim_2, dim_3, order_0, order_1, order_2, order_3); } }; struct reorder_str { - static std::string get(uint32_t order_0, uint32_t order_1, uint32_t order_2, uint32_t order_3) { - return ("r" + std::to_string(order_0) - + std::to_string(order_1) - + std::to_string(order_2) - + std::to_string(order_3) ); - } + static std::string get(uint32_t order_0, uint32_t order_1, uint32_t order_2, uint32_t order_3) + { + return ("r" + std::to_string(order_0) + std::to_string(order_1) + std::to_string(order_2) + + std::to_string(order_3)); + } }; enum tensor_layout_t @@ -138,69 +152,72 @@ std::string tensor_layout_to_string(tensor_layout_t layout) return layout_string; } -std::string supported_reorder_to_string(uint32_t order_0, - uint32_t order_1, - uint32_t order_2, - uint32_t order_3) +std::string +supported_reorder_to_string(uint32_t order_0, uint32_t order_1, uint32_t order_2, uint32_t order_3) { std::string layout_string("N/A"); - if((order_0==0) && (order_1==1) && (order_2==3) && (order_3==2)) + if((order_0 == 0) && (order_1 == 1) && (order_2 == 3) && (order_3 == 2)) layout_string = "r0132"; - else if((order_0==0) && (order_1==2) && (order_2==1) && (order_3==3)) + else if((order_0 == 0) && (order_1 == 2) && (order_2 == 1) && (order_3 == 3)) layout_string = "r0213"; - else if((order_0==0) && (order_1==2) && (order_2==3) && (order_3==1)) + else if((order_0 == 0) && (order_1 == 2) && (order_2 == 3) && (order_3 == 1)) layout_string = "r0231"; - else if((order_0==0) && (order_1==3) && (order_2==1) && (order_3==2)) + else if((order_0 == 0) && (order_1 == 3) && (order_2 == 1) && (order_3 == 2)) layout_string = "r0312"; - else if((order_0==0) && (order_1==3) && (order_2==2) && (order_3==1)) + else if((order_0 == 0) && (order_1 == 3) && (order_2 == 2) && (order_3 == 1)) layout_string = "r0321"; - else if((order_0==1) && (order_1==0) && (order_2==2) && (order_3==3)) + else if((order_0 == 1) && (order_1 == 0) && (order_2 == 2) && (order_3 == 3)) layout_string = "r1023"; - else if((order_0==1) && (order_1==0) && (order_2==3) && (order_3==2)) + else if((order_0 == 1) && (order_1 == 0) && (order_2 == 3) && (order_3 == 2)) layout_string = "r1032"; - else if((order_0==1) && (order_1==2) && (order_2==0) && (order_3==3)) + else if((order_0 == 1) && (order_1 == 2) && (order_2 == 0) && (order_3 == 3)) layout_string = "r1203"; - else if((order_0==1) && (order_1==2) && (order_2==3) && (order_3==0)) + else if((order_0 == 1) && (order_1 == 2) && (order_2 == 3) && (order_3 == 0)) layout_string = "r1230"; - else if((order_0==1) && (order_1==3) && (order_2==0) && (order_3==2)) + else if((order_0 == 1) && (order_1 == 3) && (order_2 == 0) && (order_3 == 2)) layout_string = "r1302"; - else if((order_0==1) && (order_1==3) && (order_2==2) && (order_3==0)) + else if((order_0 == 1) && (order_1 == 3) && (order_2 == 2) && (order_3 == 0)) layout_string = "r1320"; - else if((order_0==2) && (order_1==0) && (order_2==1) && (order_3==3)) + else if((order_0 == 2) && (order_1 == 0) && (order_2 == 1) && (order_3 == 3)) layout_string = "r2013"; - else if((order_0==2) && (order_1==0) && (order_2==3) && (order_3==1)) + else if((order_0 == 2) && (order_1 == 0) && (order_2 == 3) && (order_3 == 1)) layout_string = "r2031"; - else if((order_0==2) && 
(order_1==1) && (order_2==0) && (order_3==3)) + else if((order_0 == 2) && (order_1 == 1) && (order_2 == 0) && (order_3 == 3)) layout_string = "r2103"; - else if((order_0==2) && (order_1==1) && (order_2==3) && (order_3==0)) + else if((order_0 == 2) && (order_1 == 1) && (order_2 == 3) && (order_3 == 0)) layout_string = "r2130"; - else if((order_0==2) && (order_1==3) && (order_2==0) && (order_3==1)) + else if((order_0 == 2) && (order_1 == 3) && (order_2 == 0) && (order_3 == 1)) layout_string = "r2301"; - else if((order_0==2) && (order_1==3) && (order_2==1) && (order_3==0)) + else if((order_0 == 2) && (order_1 == 3) && (order_2 == 1) && (order_3 == 0)) layout_string = "r2310"; - else if((order_0==3) && (order_1==0) && (order_2==1) && (order_3==2)) + else if((order_0 == 3) && (order_1 == 0) && (order_2 == 1) && (order_3 == 2)) layout_string = "r3012"; - else if((order_0==3) && (order_1==0) && (order_2==2) && (order_3==1)) + else if((order_0 == 3) && (order_1 == 0) && (order_2 == 2) && (order_3 == 1)) layout_string = "r3021"; - else if((order_0==3) && (order_1==1) && (order_2==0) && (order_3==2)) + else if((order_0 == 3) && (order_1 == 1) && (order_2 == 0) && (order_3 == 2)) layout_string = "r3102"; - else if((order_0==3) && (order_1==1) && (order_2==2) && (order_3==0)) + else if((order_0 == 3) && (order_1 == 1) && (order_2 == 2) && (order_3 == 0)) layout_string = "r3120"; - else if((order_0==3) && (order_1==2) && (order_2==0) && (order_3==1)) + else if((order_0 == 3) && (order_1 == 2) && (order_2 == 0) && (order_3 == 1)) layout_string = "r3201"; - else if((order_0==3) && (order_1==2) && (order_2==1) && (order_3==0)) + else if((order_0 == 3) && (order_1 == 2) && (order_2 == 1) && (order_3 == 0)) layout_string = "r3210"; else MIOPEN_THROW("Unsupported reorder layout"); return layout_string; } - template struct to_miopen_data_type { }; +template <> +struct to_miopen_data_type +{ + static miopenDataType_t get() { return miopenDouble; } +}; + template <> struct to_miopen_data_type { @@ -248,6 +265,12 @@ bool compare_equal(T r1, T r2) return r1 == r2; } +template <> +bool compare_equal(double r1, double r2) +{ + return miopen::float_equal(r1, r2); +} + template <> bool compare_equal(float r1, float r2) { @@ -301,17 +324,17 @@ struct reorder_base std::vector dim_2_list = get_dim_2_size(); std::vector dim_1_list = get_dim_1_size(); std::vector dim_0_list = get_dim_0_size(); - + dim_3_list.push_back(gen_rand_integer() % 13 + 29); dim_2_list.push_back(gen_rand_integer() % 13 + 29); dim_1_list.push_back(gen_rand_integer() % 13 + 15); dim_0_list.push_back(gen_rand_integer() % 4 + 3); constexpr int all_possible_order[23][4] = { - {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, - {1, 0, 2, 3}, {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, - {2, 0, 1, 3}, {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, - {3, 0, 1, 2}, {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0} }; + {0, 1, 3, 2}, {0, 2, 1, 3}, {0, 2, 3, 1}, {0, 3, 1, 2}, {0, 3, 2, 1}, {1, 0, 2, 3}, + {1, 0, 3, 2}, {1, 2, 0, 3}, {1, 2, 3, 0}, {1, 3, 0, 2}, {1, 3, 2, 0}, {2, 0, 1, 3}, + {2, 0, 3, 1}, {2, 1, 0, 3}, {2, 1, 3, 0}, {2, 3, 0, 1}, {2, 3, 1, 0}, {3, 0, 1, 2}, + {3, 0, 2, 1}, {3, 1, 0, 2}, {3, 1, 2, 0}, {3, 2, 0, 1}, {3, 2, 1, 0}}; for(auto order : all_possible_order) { @@ -348,8 +371,14 @@ struct reorder_test : reorder_base { void run() { - auto run_reorder = [this](uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3, - uint32_t order_0, uint32_t 
order_1, uint32_t order_2, uint32_t order_3) {
+ auto run_reorder = [this](uint32_t dim_0,
+ uint32_t dim_1,
+ uint32_t dim_2,
+ uint32_t dim_3,
+ uint32_t order_0,
+ uint32_t order_1,
+ uint32_t order_2,
+ uint32_t order_3) {
int tensor_sz = dim_0 * dim_1 * dim_2 * dim_3;
std::vector tensor_len({static_cast(dim_0),
static_cast(dim_1),
@@ -360,7 +389,8 @@ struct reorder_test : reorder_base
std::string layout_default = miopen::tensor_layout_get_default(4);
std::string layout_string = tensor_layout_to_string(miopen_tensor_layout_nchw);
- std::string reorder_string = supported_reorder_to_string(order_0, order_1, order_2, order_3);
+ std::string reorder_string =
+ supported_reorder_to_string(order_0, order_1, order_2, order_3);
miopen::tensor_layout_to_strides(
tensor_len, layout_default, layout_string, tensor_strides);
@@ -404,8 +434,16 @@ struct reorder_test : reorder_base
ctx.SetStream(&miopen::deref(this->handle));
ctx.DetectRocm();
// ctx.SetupFloats();
- auto reorder_sol = TensorReorderSolutionConstructor(ctx, to_miopen_data_type::get(), dim_0, dim_1, dim_2, dim_3,
- order_0, order_1, order_2, order_3);
+ auto reorder_sol = TensorReorderSolutionConstructor(ctx,
+ to_miopen_data_type::get(),
+ dim_0,
+ dim_1,
+ dim_2,
+ dim_3,
+ order_0,
+ order_1,
+ order_2,
+ order_3);
std::vector opArgs = reorder_sol->GetKernelArg();
boost::optional invoker_factory(
[=](const std::vector& kernels) mutable {
@@ -428,7 +466,16 @@ struct reorder_test : reorder_base
// run gpu
invoker(miopen::deref(this->handle), invoke_param);
// run cpu
- cpu_reorder::run(t_dst.data.data(), t_src.data.data(), dim_0, dim_1, dim_2, dim_3, order_0, order_1, order_2, order_3);
+ cpu_reorder::run(t_dst.data.data(),
+ t_src.data.data(),
+ dim_0,
+ dim_1,
+ dim_2,
+ dim_3,
+ order_0,
+ order_1,
+ order_2,
+ order_3);
#if MIOPEN_BACKEND_OPENCL
status = clEnqueueReadBuffer(q,
@@ -450,10 +497,10 @@ struct reorder_test : reorder_base
// we expect an exact match, since we use integer data
bool valid_result = verify_tensor(t_dst_gpu, t_dst);
- std::cout << "[" << reorder_str::get(order_0, order_1, order_2, order_3) << ", b" << (sizeof(T) * 8)
- << " ] "
- << "dim_0:" << dim_0 << ", dim_1:" << dim_1 << ", dim_2:" << dim_2 << ", dim_3:" << dim_3
- << ", valid:" << valid_result << std::endl;
+ std::cout << "[" << reorder_str::get(order_0, order_1, order_2, order_3) << ", b"
+ << (sizeof(T) * 8) << " ] "
+ << "dim_0:" << dim_0 << ", dim_1:" << dim_1 << ", dim_2:" << dim_2
+ << ", dim_3:" << dim_3 << ", valid:" << valid_result << std::endl;
EXPECT(valid_result == true);
#if MIOPEN_BACKEND_OPENCL
@@ -469,10 +516,12 @@ struct reorder_test : reorder_base
}
};
-
int main()
{
- run_test>();
+ run_test>(); // DOUBLE only supports the general
+ // reorder solution; it does not
+ // support batched transpose.
+ run_test>();
run_test>();
- run_test>();
+ run_test>();
}

From a5099b099255ffe09e5d4160774bc57a74e8d182 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Mon, 14 Feb 2022 08:51:03 +0000
Subject: [PATCH 66/77] Update: add explanation comments on specific order.
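The mapping documented by these comments is easy to sanity-check on the host: fusing
adjacent dimensions turns each of the five supported reorders into a plain batched 2D
transpose. Below is a minimal standalone sketch (plain C++ with illustrative names, not
the MIOpen API) that verifies the [0, 2, 3, 1] case, where index [0] stays fixed and
(2, 3) is fused, so the reorder reduces to a [d0, d1, d2*d3] -> [d0, d2*d3, d1] transpose:

#include <cassert>
#include <cstdint>
#include <vector>

// Batched transpose: in is batch x height x width, out is batch x width x height.
template <typename T>
void batched_transpose(T* out, const T* in, uint64_t batch, uint64_t height, uint64_t width)
{
    for(uint64_t b = 0; b < batch; b++)
        for(uint64_t h = 0; h < height; h++)
            for(uint64_t w = 0; w < width; w++)
                out[b * width * height + w * height + h] =
                    in[b * height * width + h * width + w];
}

int main()
{
    const uint64_t d0 = 2, d1 = 3, d2 = 4, d3 = 5;
    std::vector<int> src(d0 * d1 * d2 * d3), dst_ref(src.size()), dst_bt(src.size());
    for(size_t i = 0; i < src.size(); i++)
        src[i] = static_cast<int>(i);

    // Reference: dst[i0][i2][i3][i1] = src[i0][i1][i2][i3], i.e. order (0, 2, 3, 1).
    for(uint64_t i0 = 0; i0 < d0; i0++)
        for(uint64_t i1 = 0; i1 < d1; i1++)
            for(uint64_t i2 = 0; i2 < d2; i2++)
                for(uint64_t i3 = 0; i3 < d3; i3++)
                    dst_ref[((i0 * d2 + i2) * d3 + i3) * d1 + i1] =
                        src[((i0 * d1 + i1) * d2 + i2) * d3 + i3];

    // Same result via one batched transpose after fusing (d2, d3) into a single width.
    batched_transpose(dst_bt.data(), src.data(), d0, d1, d2 * d3);
    assert(dst_ref == dst_bt);
    return 0;
}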
---
 src/include/miopen/tensor_reorder_util.hpp | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp
index 299801d6f2..2b106625f8 100644
--- a/src/include/miopen/tensor_reorder_util.hpp
+++ b/src/include/miopen/tensor_reorder_util.hpp
@@ -235,6 +235,18 @@ TensorReorderSolutionConstructor(const ExecutionContext& ctx_,
if((data_type_ != miopenDouble) && (order_0_ == 3) && (order_1_ == 0) && (order_2_ == 1) &&
(order_3_ == 2))
which = 5;
+ // Orders [0, 1, 3, 2], [0, 2, 3, 1], [0, 3, 1, 2], [2, 3, 0, 1] and [3, 0, 1, 2] use the batched
+ // transpose kernel to achieve higher performance. Details are as follows:
+ // To reorder to [0, 1, 3, 2] from [0, 1, 2, 3], we can fix layout indices [0] and [1] and transpose [2,
+ // 3] to [3, 2]. To reorder to [0, 2, 3, 1] from [0, 1, 2, 3], we can fix layout index [0], see [2,
+ // 3] as an entity, then transpose [1, (2, 3)] to [(2, 3), 1]. To reorder to [0, 3, 1, 2] from [0,
+ // 1, 2, 3], we can fix layout index [0], see [1, 2] as an entity, then transpose [(1, 2), 3]
+ // to [3, (1, 2)]. To reorder to [2, 3, 0, 1] from [0, 1, 2, 3], we can add a fixed layout index,
+ // see [0, 1] and [2, 3] as entities, then transpose [(0, 1), (2, 3)] to [(2, 3), (0, 1)].
+ // To reorder to [3, 0, 1, 2] from [0, 1, 2, 3], we can add a fixed layout index, see [0, 1, 2] as
+ // an entity, then transpose [(0, 1, 2), 3] to [3, (0, 1, 2)]. The reason we have different APIs
+ // like WrapperBatchedTransposeSolution_0132 is that each chooses a different fixed index and
+ // a different pair of dimensions to transpose.
switch(which)
{

From 2c44205e727c0978057dcda80092d97b55a5e942 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Fri, 25 Mar 2022 19:07:03 +0000
Subject: [PATCH 67/77] Respond to review suggestions

---
 .gitignore | 1 +
 .vscode/settings.json | 72 +++++
 src/conv/invokers/impl_gemm_dynamic.cpp | 12 +-
 src/hip/batched_transpose_sol.cpp | 7 +-
 src/hip/general_tensor_reorder_sol.cpp | 57 ++--
 src/include/miopen/batched_transpose_sol.hpp | 4 +-
 .../miopen/general_tensor_reorder_sol.hpp | 28 +-
 src/include/miopen/tensor_reorder_util.hpp | 247 +++++++++---------
 .../general_tensor_reorder.cpp | 2 +
 .../order.hpp | 3 +-
 .../conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp | 12 +-
 .../conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp | 12 +-
 .../conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp | 18 +-
 test/gpu_nchw_nhwc_transpose.cpp | 2 +-
 test/tensor_reorder.cpp | 159 ++++++-----
 15 files changed, 366 insertions(+), 270 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 .vscode/settings.json

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000..d16386367f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+build/
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000000..b824eccc08
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,72 @@
+{
+ "files.associations": {
+ "limits": "cpp",
+ "array": "cpp",
+ "atomic": "cpp",
+ "strstream": "cpp",
+ "*.tcc": "cpp",
+ "bitset": "cpp",
+ "cctype": "cpp",
+ "chrono": "cpp",
+ "clocale": "cpp",
+ "cmath": "cpp",
+ "codecvt": "cpp",
+ "complex": "cpp",
+ "condition_variable": "cpp",
+ "csignal": "cpp",
+ "cstdarg": "cpp",
+ "cstddef": "cpp",
+ "cstdint": "cpp",
+ "cstdio": "cpp",
+ "cstdlib": "cpp",
+ "cstring": "cpp",
+ "ctime": "cpp",
+ "cwchar": "cpp",
+ "cwctype": "cpp",
+ "deque": "cpp",
+ "list": "cpp",
+ "unordered_map": "cpp",
+ "unordered_set": "cpp",
+ "vector": "cpp",
+ "exception":
"cpp", + "algorithm": "cpp", + "filesystem": "cpp", + "functional": "cpp", + "iterator": "cpp", + "map": "cpp", + "memory": "cpp", + "memory_resource": "cpp", + "numeric": "cpp", + "optional": "cpp", + "random": "cpp", + "ratio": "cpp", + "set": "cpp", + "string": "cpp", + "string_view": "cpp", + "system_error": "cpp", + "tuple": "cpp", + "type_traits": "cpp", + "utility": "cpp", + "fstream": "cpp", + "future": "cpp", + "initializer_list": "cpp", + "iomanip": "cpp", + "iosfwd": "cpp", + "iostream": "cpp", + "istream": "cpp", + "mutex": "cpp", + "new": "cpp", + "ostream": "cpp", + "shared_mutex": "cpp", + "sstream": "cpp", + "stdexcept": "cpp", + "streambuf": "cpp", + "thread": "cpp", + "cfenv": "cpp", + "cinttypes": "cpp", + "typeindex": "cpp", + "typeinfo": "cpp", + "valarray": "cpp", + "variant": "cpp" + } +} \ No newline at end of file diff --git a/src/conv/invokers/impl_gemm_dynamic.cpp b/src/conv/invokers/impl_gemm_dynamic.cpp index 4e81e96bb7..f094063b96 100644 --- a/src/conv/invokers/impl_gemm_dynamic.cpp +++ b/src/conv/invokers/impl_gemm_dynamic.cpp @@ -568,9 +568,9 @@ InvokerFactory MakeImplGemmDynamicForwardXdlopsNHWCInvokerFactory( if(!trans_output_skippable) opArgsTrans.emplace_back(trans_output.GetKernelArg()); - trans_input_size = trans_input_skippable ? 0 : trans_input.GetSize(); - trans_weight_size = trans_weight_skippable ? 0 : trans_weight.GetSize(); - trans_output_size = trans_output_skippable ? 0 : trans_output.GetSize(); + trans_input_size = trans_input_skippable ? 0 : trans_input.GetOutputTensorSize(); + trans_weight_size = trans_weight_skippable ? 0 : trans_weight.GetOutputTensorSize(); + trans_output_size = trans_output_skippable ? 0 : trans_output.GetOutputTensorSize(); int idx = 0; if(!trans_input_skippable) @@ -885,9 +885,9 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory( if(!trans_output_skippable) opArgsTrans.emplace_back(trans_output.GetKernelArg()); - trans_input_size = trans_input_skippable ? 0 : trans_input.GetSize(); - trans_weight_size = trans_weight_skippable ? 0 : trans_weight.GetSize(); - trans_output_size = trans_output_skippable ? 0 : trans_output.GetSize(); + trans_input_size = trans_input_skippable ? 0 : trans_input.GetOutputTensorSize(); + trans_weight_size = trans_weight_skippable ? 0 : trans_weight.GetOutputTensorSize(); + trans_output_size = trans_output_skippable ? 
0 : trans_output.GetOutputTensorSize();
 int idx = 0;
 if(!trans_input_skippable)
diff --git a/src/hip/batched_transpose_sol.cpp b/src/hip/batched_transpose_sol.cpp
index 51a6a99359..ca96688f15 100644
--- a/src/hip/batched_transpose_sol.cpp
+++ b/src/hip/batched_transpose_sol.cpp
@@ -304,7 +304,7 @@ BatchedTransposeSolution::BatchedTransposeSolution(const ExecutionContext& ctx,
kernel_param_heuristic = batched_transpose::HeuristicGet(data_size, batch, height, width);
}
-solver::KernelInfo BatchedTransposeSolution::GetKernel() const
+solver::KernelInfo BatchedTransposeSolution::GetKernelInfo() const
{
std::size_t block_size = BATCHED_TRANSPOSE_BLOCK_SIZE;
#if BATCHED_TRANSPOSE_PERSISTENT
@@ -327,7 +327,7 @@ solver::KernelInfo BatchedTransposeSolution::GetKernel() const
kernel.l_wk.push_back(1);
kernel.l_wk.push_back(1);
- MIOPEN_LOG_I2("BatchedTransposeSolution use kernel: " + kernel_name);
+ MIOPEN_LOG_T(kernel_name);
return kernel;
}
@@ -351,6 +351,7 @@ std::vector BatchedTransposeSolution::GetKernelArg() const
opArgs.emplace_back(0); // placeholder
opArgs.emplace_back(height);
opArgs.emplace_back(width);
+ if(grid_size != static_cast(grid_size)) MIOPEN_THROW("Variable grid size can't be casted to uint32_t safely");
opArgs.emplace_back(static_cast(grid_size));
opArgs.emplace_back(dim_total);
opArgs.emplace_back(magic_h.magic);
@@ -374,7 +375,7 @@ bool BatchedTransposeSolution::IsSkippable() const
return height == 1 || width == 1;
}
-size_t BatchedTransposeSolution::GetSize() const
+size_t BatchedTransposeSolution::GetOutputTensorSize() const
{
return miopen::GetTypeSize(data_type) * batch * height * width;
}
diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp
index 2b004ca851..75b5e17ed3 100644
--- a/src/hip/general_tensor_reorder_sol.cpp
+++ b/src/hip/general_tensor_reorder_sol.cpp
@@ -39,7 +39,7 @@ namespace miopen {
namespace tensor_reorder {
-static inline std::string GetNameTrait(std::size_t type_size)
+static inline std::string GetKernelNameType(std::size_t type_size)
{
if(type_size == 1)
return "byte";
@@ -59,47 +59,44 @@ static inline std::string GetKernelName(std::size_t data_size,
uint32_t order_3,
const GeneralReorderParam* kparam)
{
+ if(kparam == nullptr) MIOPEN_THROW("kparam must not be a nullptr");
std::ostringstream kernel_name;
- std::string type_trait = GetNameTrait(data_size);
kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_";
if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1))
{
kernel_name << "pack_" << kparam->pack_x << "x" << kparam->pack_y << "_ediv_"
<< kparam->ediv_x << "x" << kparam->ediv_y << "_";
}
- kernel_name << type_trait << "_r" << order_0 << order_1 << order_2 << order_3;
+ kernel_name << GetKernelNameType(data_size) << "_r" << order_0 << order_1 << order_2 << order_3;
return kernel_name.str();
}
static inline GeneralReorderParam
HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim_2, uint32_t dim_3)
{
- /*
- * TODO:
- * Design a algorithm to determine general tensor reorder tile size.
- */
+ ///\todo Design an algorithm to determine the general tensor reorder tile size.
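+ // For orientation (a sketch of the relationship, not part of the tuned heuristic):
+ // tile_x below is the per-thread unroll factor, so a 256-thread workgroup moves
+ // tile_x * 256 pixels per grid step and the launcher sizes the grid as roughly
+ //     grid_size = (pixel_total + tile_x * 256 - 1) / (tile_x * 256);
+ // Choosing a larger tile_x when the fastest-varying dim_3 is large amortizes the
+ // magic-division index math over more pixels per thread.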
GeneralReorderParam default_kernel; if(data_size <= 8 && dim_0 >= 1 && dim_1 >= 1 && dim_2 >= 1 && dim_3 >= 1) { if(dim_3 >= 16) { - return GeneralReorderParam{16, 256, 1, 1, 1, 1}; + return GeneralReorderParam{16, TENSOR_REORDER_BLOCK_SIZE, 1, 1, 1, 1}; } else if(dim_3 >= 8) { - return GeneralReorderParam{8, 256, 1, 1, 1, 1}; + return GeneralReorderParam{8, TENSOR_REORDER_BLOCK_SIZE, 1, 1, 1, 1}; } else if(dim_3 >= 4) { - return GeneralReorderParam{4, 256, 1, 1, 1, 1}; + return GeneralReorderParam{4, TENSOR_REORDER_BLOCK_SIZE, 1, 1, 1, 1}; } else if(dim_3 >= 2) { - return GeneralReorderParam{2, 256, 1, 1, 1, 1}; + return GeneralReorderParam{2, TENSOR_REORDER_BLOCK_SIZE, 1, 1, 1, 1}; } else { - return GeneralReorderParam{1, 256, 1, 1, 1, 1}; + return GeneralReorderParam{1, TENSOR_REORDER_BLOCK_SIZE, 1, 1, 1, 1}; } } else @@ -109,16 +106,15 @@ HeuristicGet(std::size_t data_size, uint32_t dim_0, uint32_t dim_1, uint32_t dim } } // namespace tensor_reorder -GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_) +GenericReorderSolutionImpl::GenericReorderSolutionImpl(miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) : data_type(data_type_), dim_0(dim_0_), dim_1(dim_1_), @@ -131,12 +127,11 @@ GeneralReorderSolution::GeneralReorderSolution(const ExecutionContext& ctx, { if(data_type == miopenInt8x4) MIOPEN_THROW("These data type are not supported"); - num_cu = ctx.GetStream().GetMaxComputeUnits(); std::size_t data_size = miopen::GetTypeSize(data_type); kernel_param_heuristic = tensor_reorder::HeuristicGet(data_size, dim_0, dim_1, dim_2, dim_3); } -solver::KernelInfo GeneralReorderSolution::GetKernel() const +solver::KernelInfo GenericReorderSolutionImpl::GetKernelInfo() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; @@ -157,12 +152,12 @@ solver::KernelInfo GeneralReorderSolution::GetKernel() const kernel.l_wk.push_back(1); kernel.l_wk.push_back(1); - MIOPEN_LOG_I2("GeneralReorderSolution use kernel: " + kernel_name); + MIOPEN_LOG_T(kernel_name); return kernel; } -std::vector GeneralReorderSolution::GetKernelArg() const +std::vector GenericReorderSolutionImpl::GetKernelArg() const { std::size_t block_size = TENSOR_REORDER_BLOCK_SIZE; uint32_t pixel_total = dim_0 * dim_1 * dim_2 * dim_3; @@ -181,6 +176,7 @@ std::vector GeneralReorderSolution::GetKernelArg() const opArgs.emplace_back(dim_1); opArgs.emplace_back(dim_2); opArgs.emplace_back(dim_3); + if(grid_size != static_cast(grid_size)) MIOPEN_THROW("Variable grid size can't be casted to uint32_t safely"); opArgs.emplace_back(static_cast(grid_size)); opArgs.emplace_back(dim_total); opArgs.emplace_back(magic_stride0.magic); @@ -193,20 +189,19 @@ std::vector GeneralReorderSolution::GetKernelArg() const return opArgs; } -std::string GeneralReorderSolution::GetKernelName() const +std::string GenericReorderSolutionImpl::GetKernelName() const { - std::size_t data_size = miopen::GetTypeSize(data_type); return tensor_reorder::GetKernelName( - data_size, order_0, order_1, order_2, order_3, &kernel_param_heuristic); + miopen::GetTypeSize(data_type), order_0, order_1, order_2, order_3, &kernel_param_heuristic); } -bool 
GeneralReorderSolution::IsSkippable() const
+bool GenericReorderSolutionImpl::IsSkippable() const
{
// Effectively disable the IsSkippable function: only an empty tensor (some dim == 0) is skippable
return dim_0 == 0 || dim_1 == 0 || dim_2 == 0 || dim_3 == 0;
}
-size_t GeneralReorderSolution::GetSize() const
+size_t GenericReorderSolutionImpl::GetOutputTensorSize() const
{
return miopen::GetTypeSize(data_type) * dim_0 * dim_1 * dim_2 * dim_3;
}
diff --git a/src/include/miopen/batched_transpose_sol.hpp b/src/include/miopen/batched_transpose_sol.hpp
index c912669c63..dedbf4f73e 100644
--- a/src/include/miopen/batched_transpose_sol.hpp
+++ b/src/include/miopen/batched_transpose_sol.hpp
@@ -51,11 +51,11 @@ struct BatchedTransposeSolution
uint32_t batch_,
uint32_t height_,
uint32_t width_);
- solver::KernelInfo GetKernel() const;
+ solver::KernelInfo GetKernelInfo() const;
std::vector GetKernelArg() const;
std::string GetKernelName() const;
bool IsSkippable() const;
- size_t GetSize() const;
+ size_t GetOutputTensorSize() const;
miopenDataType_t data_type;
uint32_t batch;
diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp
index 28ce62668d..cac41370e9 100644
--- a/src/include/miopen/general_tensor_reorder_sol.hpp
+++ b/src/include/miopen/general_tensor_reorder_sol.hpp
@@ -26,10 +26,10 @@
#ifndef GUARD_GENERAL_MIOPEN_TENSOR_REORDER_SOL_HPP
#define GUARD_GENERAL_MIOPEN_TENSOR_REORDER_SOL_HPP
-#include
#include
#include
#include
+#include
#include
namespace miopen {
@@ -44,24 +44,23 @@ struct GeneralReorderParam
int ediv_y{0};
};
-struct GeneralReorderSolution
+struct GenericReorderSolutionImpl
{
- GeneralReorderSolution(const ExecutionContext& ctx_,
- miopenDataType_t data_type_,
- uint32_t dim_0_,
- uint32_t dim_1_,
- uint32_t dim_2_,
- uint32_t dim_3_,
- uint32_t order_0_,
- uint32_t order_1_,
- uint32_t order_2_,
- uint32_t order_3_);
+ GenericReorderSolutionImpl(miopenDataType_t data_type_,
+ uint32_t dim_0_,
+ uint32_t dim_1_,
+ uint32_t dim_2_,
+ uint32_t dim_3_,
+ uint32_t order_0_,
+ uint32_t order_1_,
+ uint32_t order_2_,
+ uint32_t order_3_);
// TODO batched transpose API
- solver::KernelInfo GetKernel() const;
+ solver::KernelInfo GetKernelInfo() const;
std::vector GetKernelArg() const;
std::string GetKernelName() const;
bool IsSkippable() const;
- size_t GetSize() const;
+ size_t GetOutputTensorSize() const;
miopenDataType_t data_type;
uint32_t dim_0;
@@ -72,7 +71,6 @@ struct GeneralReorderSolution
uint32_t order_1;
uint32_t order_2;
uint32_t order_3;
- int num_cu;
GeneralReorderParam kernel_param_heuristic;
};
diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp
index 2b106625f8..efb23157a2 100644
--- a/src/include/miopen/tensor_reorder_util.hpp
+++ b/src/include/miopen/tensor_reorder_util.hpp
@@ -28,213 +28,218 @@
#include
#include
-#include
#include
#include
#include
#include
namespace miopen {
-struct TensorReorderSolution
+struct TensorReorderAttributesBase
{
- virtual ~TensorReorderSolution() = default;
- virtual solver::KernelInfo GetKernel() const = 0;
+ virtual ~TensorReorderAttributesBase() = default;
+ virtual solver::KernelInfo GetKernelInfo() const = 0;
virtual std::vector GetKernelArg() const = 0;
virtual std::string GetKernelName() const = 0;
+ // Used on the HOST side to check the special cases where the tensor height or width equals 1.
+ // In such cases, we don't need to conduct the batched transpose operation,
+ // since the transposed tensor has exactly the same memory layout as before.
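+ // (For example, a batched transpose with height == 1 or width == 1 degenerates into a
+ // contiguous copy; BatchedTransposeSolution::IsSkippable() returns exactly that condition.)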
virtual bool IsSkippable() const = 0; - virtual size_t GetSize() const = 0; + // workspace (buffer) to hold output tensor of transform Pre/Post convolution + virtual size_t GetOutputTensorSize() const = 0; }; -struct WrapperBatchedTransposeSolution_0132 : TensorReorderSolution +struct BatchedTransposeSolution_0132 : TensorReorderAttributesBase { - BatchedTransposeSolution m_BatchedTransposeSolution; - WrapperBatchedTransposeSolution_0132(const ExecutionContext& ctx_, + BatchedTransposeSolution impl; + BatchedTransposeSolution_0132(const ExecutionContext& ctx_, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) + : impl(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_) { } - solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); } std::vector GetKernelArg() const override { - return m_BatchedTransposeSolution.GetKernelArg(); + return impl.GetKernelArg(); } std::string GetKernelName() const override { - return m_BatchedTransposeSolution.GetKernelName(); + return impl.GetKernelName(); } - bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } - size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } + bool IsSkippable() const override { return impl.IsSkippable(); } + size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); } }; -struct WrapperBatchedTransposeSolution_0231 : TensorReorderSolution +struct BatchedTransposeSolution_0231 : TensorReorderAttributesBase { - BatchedTransposeSolution m_BatchedTransposeSolution; - WrapperBatchedTransposeSolution_0231(const ExecutionContext& ctx_, + BatchedTransposeSolution impl; + BatchedTransposeSolution_0231(const ExecutionContext& ctx_, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) + : impl(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_) { } - solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); } std::vector GetKernelArg() const override { - return m_BatchedTransposeSolution.GetKernelArg(); + return impl.GetKernelArg(); } std::string GetKernelName() const override { - return m_BatchedTransposeSolution.GetKernelName(); + return impl.GetKernelName(); } - bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } - size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } + bool IsSkippable() const override { return impl.IsSkippable(); } + size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); } }; -struct WrapperBatchedTransposeSolution_0312 : TensorReorderSolution +struct BatchedTransposeSolution_0312 : TensorReorderAttributesBase { - BatchedTransposeSolution m_BatchedTransposeSolution; - WrapperBatchedTransposeSolution_0312(const ExecutionContext& ctx_, + BatchedTransposeSolution impl; + BatchedTransposeSolution_0312(const ExecutionContext& ctx_, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) + : impl(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_) { } - 
solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); } std::vector GetKernelArg() const override { - return m_BatchedTransposeSolution.GetKernelArg(); + return impl.GetKernelArg(); } std::string GetKernelName() const override { - return m_BatchedTransposeSolution.GetKernelName(); + return impl.GetKernelName(); } - bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } - size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } + bool IsSkippable() const override { return impl.IsSkippable(); } + size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); } }; -struct WrapperBatchedTransposeSolution_2301 : TensorReorderSolution +struct BatchedTransposeSolution_2301 : TensorReorderAttributesBase { - BatchedTransposeSolution m_BatchedTransposeSolution; - WrapperBatchedTransposeSolution_2301(const ExecutionContext& ctx_, + BatchedTransposeSolution impl; + BatchedTransposeSolution_2301(const ExecutionContext& ctx_, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_, dim_2_ * dim_3_) + : impl(ctx_, data_type_, 1, dim_0_ * dim_1_, dim_2_ * dim_3_) { } - solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); } std::vector GetKernelArg() const override { - return m_BatchedTransposeSolution.GetKernelArg(); + return impl.GetKernelArg(); } std::string GetKernelName() const override { - return m_BatchedTransposeSolution.GetKernelName(); + return impl.GetKernelName(); } - bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } - size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } + bool IsSkippable() const override { return impl.IsSkippable(); } + size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); } }; -struct WrapperBatchedTransposeSolution_3012 : TensorReorderSolution +struct BatchedTransposeSolution_3012 : TensorReorderAttributesBase { - BatchedTransposeSolution m_BatchedTransposeSolution; - WrapperBatchedTransposeSolution_3012(const ExecutionContext& ctx_, + BatchedTransposeSolution impl; + BatchedTransposeSolution_3012(const ExecutionContext& ctx_, miopenDataType_t data_type_, uint32_t dim_0_, uint32_t dim_1_, uint32_t dim_2_, uint32_t dim_3_) - : m_BatchedTransposeSolution(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) + : impl(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_) { } - solver::KernelInfo GetKernel() const override { return m_BatchedTransposeSolution.GetKernel(); } + solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); } std::vector GetKernelArg() const override { - return m_BatchedTransposeSolution.GetKernelArg(); + return impl.GetKernelArg(); } std::string GetKernelName() const override { - return m_BatchedTransposeSolution.GetKernelName(); + return impl.GetKernelName(); } - bool IsSkippable() const override { return m_BatchedTransposeSolution.IsSkippable(); } - size_t GetSize() const override { return m_BatchedTransposeSolution.GetSize(); } + bool IsSkippable() const override { return impl.IsSkippable(); } + size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); } }; -struct 
WrapperGeneralReorderSolution : TensorReorderSolution +struct GenericReorderSolution : TensorReorderAttributesBase { - GeneralReorderSolution m_GeneralReorderSolution; - WrapperGeneralReorderSolution(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_) - : m_GeneralReorderSolution(ctx_, - data_type_, - dim_0_, - dim_1_, - dim_2_, - dim_3_, - order_0_, - order_1_, - order_2_, - order_3_) + GenericReorderSolutionImpl impl; + GenericReorderSolution(miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) + : impl( data_type_, + dim_0_, + dim_1_, + dim_2_, + dim_3_, + order_0_, + order_1_, + order_2_, + order_3_) { } - solver::KernelInfo GetKernel() const override { return m_GeneralReorderSolution.GetKernel(); } + solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); } std::vector GetKernelArg() const override { - return m_GeneralReorderSolution.GetKernelArg(); + return impl.GetKernelArg(); } - std::string GetKernelName() const override { return m_GeneralReorderSolution.GetKernelName(); } - bool IsSkippable() const override { return m_GeneralReorderSolution.IsSkippable(); } - size_t GetSize() const override { return m_GeneralReorderSolution.GetSize(); } + std::string GetKernelName() const override { return impl.GetKernelName(); } + bool IsSkippable() const override { return impl.IsSkippable(); } + size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); } }; -__inline__ std::unique_ptr -TensorReorderSolutionConstructor(const ExecutionContext& ctx_, - miopenDataType_t data_type_, - uint32_t dim_0_, - uint32_t dim_1_, - uint32_t dim_2_, - uint32_t dim_3_, - uint32_t order_0_, - uint32_t order_1_, - uint32_t order_2_, - uint32_t order_3_) +inline std::unique_ptr +MakeTensorReorderAttributes(const ExecutionContext& ctx_, + miopenDataType_t data_type_, + uint32_t dim_0_, + uint32_t dim_1_, + uint32_t dim_2_, + uint32_t dim_3_, + uint32_t order_0_, + uint32_t order_1_, + uint32_t order_2_, + uint32_t order_3_) { + std::unique_ptr default_ptr; + if(ctx_.use_hip_kernels == false){ + return default_ptr; + } // Default using general reorder + if(data_type_ == miopenBFloat16){ + MIOPEN_THROW("Unsupported reorder data type"); + } int which = 0; - if((data_type_ != miopenDouble) && (order_0_ == 0) && (order_1_ == 1) && (order_2_ == 3) && - (order_3_ == 2)) - which = 1; - if((data_type_ != miopenDouble) && (order_0_ == 0) && (order_1_ == 2) && (order_2_ == 3) && - (order_3_ == 1)) - which = 2; - if((data_type_ != miopenDouble) && (order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && - (order_3_ == 2)) - which = 3; - if((data_type_ != miopenDouble) && (order_0_ == 2) && (order_1_ == 3) && (order_2_ == 0) && - (order_3_ == 1)) - which = 4; - if((data_type_ != miopenDouble) && (order_0_ == 3) && (order_1_ == 0) && (order_2_ == 1) && - (order_3_ == 2)) - which = 5; + if(data_type_ != miopenDouble){ + if((order_0_ == 0) && (order_1_ == 1) && (order_2_ == 3) && (order_3_ == 2)) + which = 1; + else if((order_0_ == 0) && (order_1_ == 2) && (order_2_ == 3) && (order_3_ == 1)) + which = 2; + else if((order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && (order_3_ == 2)) + which = 3; + else if((order_0_ == 2) && (order_1_ == 3) && (order_2_ == 0) && (order_3_ == 1)) + 
which = 4;
+ else if((order_0_ == 3) && (order_1_ == 0) && (order_2_ == 1) && (order_3_ == 2))
+ which = 5;
+ }
// Orders [0, 1, 3, 2], [0, 2, 3, 1], [0, 3, 1, 2], [2, 3, 0, 1] and [3, 0, 1, 2] use the batched
// transpose kernel to achieve higher performance. Details are as follows:
// To reorder to [0, 1, 3, 2] from [0, 1, 2, 3], we can fix layout indices [0] and [1] and transpose [2,
@@ -245,40 +250,40 @@ TensorReorderSolutionConstructor(const ExecutionContext& ctx_,
// see [0, 1] and [2, 3] as entities, then transpose [(0, 1), (2, 3)] to [(2, 3), (0, 1)].
// To reorder to [3, 0, 1, 2] from [0, 1, 2, 3], we can add a fixed layout index, see [0, 1, 2] as
// an entity, then transpose [(0, 1, 2), 3] to [3, (0, 1, 2)]. The reason we have different APIs
- // like WrapperBatchedTransposeSolution_0132 is that each chooses a different fixed index and
+ // like BatchedTransposeSolution_0132 is that each chooses a different fixed index and
// a different pair of dimensions to transpose.
switch(which)
{
case 0:
- return std::make_unique(ctx_,
- data_type_,
- dim_0_,
- dim_1_,
- dim_2_,
- dim_3_,
- order_0_,
- order_1_,
- order_2_,
- order_3_);
+ return std::make_unique(data_type_,
+ dim_0_,
+ dim_1_,
+ dim_2_,
+ dim_3_,
+ order_0_,
+ order_1_,
+ order_2_,
+ order_3_);
case 1:
- return std::make_unique(
+ return std::make_unique(
ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
case 2:
- return std::make_unique(
+ return std::make_unique(
ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
case 3:
- return std::make_unique(
+ return std::make_unique(
ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
case 4:
- return std::make_unique(
+ return std::make_unique(
ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
case 5:
- return std::make_unique(
+ return std::make_unique(
ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
- default: return nullptr;
+ default:
+ MIOPEN_THROW("Unsupported reorder sequence");
+ break;
}
- return nullptr;
}
} // namespace miopen
diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
index 7bdf24e14b..d16739d185 100644
--- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
+++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
@@ -23,7 +23,9 @@
 * SOFTWARE.
 *
 *******************************************************************************/
+#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
#include
+#endif
#include
#include "order.hpp"
diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp
index cc989da928..c8e80f7e7f 100644
--- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp
+++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp
@@ -23,13 +23,14 @@
 * SOFTWARE.
diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
index 7bdf24e14b..d16739d185 100644
--- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
+++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
@@ -23,7 +23,9 @@
  * SOFTWARE.
  *
  *******************************************************************************/
+#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
 #include <hip/hip_runtime.h>
+#endif
 #include <hip/hip_fp16.h>
 #include "order.hpp"

diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp
index cc989da928..c8e80f7e7f 100644
--- a/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp
+++ b/src/kernels/gpu_general_tensor_reorder_kernel/order.hpp
@@ -23,13 +23,14 @@
  * SOFTWARE.
  *
  *******************************************************************************/
+#include <cstddef>
 #ifndef ORDER_HPP
 #define ORDER_HPP

 template <int... Is>
 struct order
 {
-    static constexpr uint64_t m_size = sizeof...(Is);
+    static constexpr std::size_t m_size = sizeof...(Is);

     // the last dummy element prevents the compiler from complaining about an empty array
     // when m_size == 0
     static constexpr int m_data[m_size + 1] = {Is..., 0};
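For reference, the order<> trick above can be exercised entirely at compile time. This is a sketch under the same definitions (the at() accessor is assumed here for exposition):

    #include <cstddef>

    template <int... Is>
    struct order
    {
        static constexpr std::size_t m_size = sizeof...(Is);
        // the trailing dummy element keeps the array non-empty when m_size == 0
        static constexpr int m_data[m_size + 1] = {Is..., 0};
        static constexpr int at(std::size_t i) { return m_data[i]; }
    };

    using nhwc_order = order<0, 2, 3, 1>; // dst dimension i is taken from src dimension at(i)
    static_assert(nhwc_order::m_size == 4, "four dimensions");
    static_assert(nhwc_order::at(1) == 2, "dst dim 1 comes from src dim 2");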
diff --git a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
index 0ec6eb507c..ba5630f1cd 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
@@ -954,11 +954,11 @@ ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetWorkspaceSize(const ConvolutionCo
             x); // group * k_per_group as batch for weight
         TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo);
         if(!trans_input.IsSkippable())
-            size_trans_input = trans_input.GetSize();
+            size_trans_input = trans_input.GetOutputTensorSize();
         if(!trans_weight.IsSkippable())
-            size_trans_weight = trans_weight.GetSize();
+            size_trans_weight = trans_weight.GetOutputTensorSize();
         if(!trans_output.IsSkippable())
-            size_trans_output = trans_output.GetSize();
+            size_trans_output = trans_output.GetOutputTensorSize();
     }

     if(!ctx.IsFp32())
@@ -1055,19 +1055,19 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::GetSolution(

         if(!trans_input.IsSkippable())
         {
-            result.construction_params.push_back(trans_input.GetKernel());
+            result.construction_params.push_back(trans_input.GetKernelInfo());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", inp trans:" << trans_input.GetKernelName();
         }
         if(!trans_weight.IsSkippable())
         {
-            result.construction_params.push_back(trans_weight.GetKernel());
+            result.construction_params.push_back(trans_weight.GetKernelInfo());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", wei trans:" << trans_weight.GetKernelName();
         }
         if(!trans_output.IsSkippable())
         {
-            result.construction_params.push_back(trans_output.GetKernel());
+            result.construction_params.push_back(trans_output.GetKernelInfo());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", out trans:" << trans_output.GetKernelName();
         }

diff --git a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
index 9eb37bc0dd..fd88dc4b45 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
@@ -796,11 +796,11 @@ ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetWorkspaceSize(const ConvolutionCo
         TransposeSolutionNhwc2Default trans_output(ctx, ctx.out_data_type, n, k, ho, wo);

         if(!trans_input.IsSkippable())
-            size_trans_input = trans_input.GetSize();
+            size_trans_input = trans_input.GetOutputTensorSize();
         if(!trans_weight.IsSkippable())
-            size_trans_weight = trans_weight.GetSize();
+            size_trans_weight = trans_weight.GetOutputTensorSize();
         if(!trans_output.IsSkippable())
-            size_trans_output = trans_output.GetSize();
+            size_trans_output = trans_output.GetOutputTensorSize();
     }

     if(!ctx.IsFp32())
@@ -938,19 +938,19 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::GetSolution(

         if(!trans_input.IsSkippable())
         {
-            result.construction_params.push_back(trans_input.GetKernel());
+            result.construction_params.push_back(trans_input.GetKernelInfo());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", inp trans:" << trans_input.GetKernelName();
         }
         if(!trans_weight.IsSkippable())
         {
-            result.construction_params.push_back(trans_weight.GetKernel());
+            result.construction_params.push_back(trans_weight.GetKernelInfo());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", wei trans:" << trans_weight.GetKernelName();
         }
         if(!trans_output.IsSkippable())
         {
-            result.construction_params.push_back(trans_output.GetKernel());
+            result.construction_params.push_back(trans_output.GetKernelInfo());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", out trans:" << trans_output.GetKernelName();
         }

diff --git a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
index c79372b9f7..a0b3298aee 100644
--- a/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
+++ b/src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
@@ -917,11 +917,11 @@ ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetWorkspaceSize(const ConvolutionCo
             x); // group * k_per_group as batch for weight
         TransposeSolutionDefault2Nhwc trans_output(ctx, ctx.in_data_type, n, k, ho, wo);
         if(!trans_input.IsSkippable())
-            size_trans_input = trans_input.GetSize();
+            size_trans_input = trans_input.GetOutputTensorSize();
         if(!trans_weight.IsSkippable())
-            size_trans_weight = trans_weight.GetSize();
+            size_trans_weight = trans_weight.GetOutputTensorSize();
         if(!trans_output.IsSkippable())
-            size_trans_output = trans_output.GetSize();
+            size_trans_output = trans_output.GetOutputTensorSize();
     }

@@ -1065,27 +1065,27 @@ ConvSolution ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::GetSolution(
         trans_output_skippable = trans_output.IsSkippable();

         if(!trans_input_skippable){
-            result.construction_params.push_back(trans_input.GetKernel());
+            result.construction_params.push_back(trans_input.GetKernelInfo());
             opArgsTrans.emplace_back(trans_input.GetKernelArg());
             if(miopen::IsLogging(LoggingLevel::Info2))
                 msg << ", inp trans:" << trans_input.GetKernelName();

diff --git a/test/gpu_nchw_nhwc_transpose.cpp b/test/gpu_nchw_nhwc_transpose.cpp
-        std::vector<miopen::solver::KernelInfo> construction_params{transpose_sol.GetKernel()};
+        std::vector<miopen::solver::KernelInfo> construction_params{transpose_sol.GetKernelInfo()};

         const auto invoker =
             miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params);
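A short illustration of why IsSkippable() lets these solvers drop a transpose outright (standalone sketch, not code from the tree): for a row-major [batch, height, width] tensor with height == 1 or width == 1, the transposed [batch, width, height] tensor has an identical linear layout, so no kernel launch and no workspace are needed.

    #include <cstdint>

    inline bool transpose_is_identity(uint32_t height, uint32_t width)
    {
        // offset b*h*w + y*w + x collapses to the same value in both layouts
        // whenever one of the two transposed extents is 1
        return height == 1 || width == 1;
    }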
diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp
index 20707d3d00..ce335760b1 100644
--- a/test/tensor_reorder.cpp
+++ b/test/tensor_reorder.cpp
@@ -49,6 +49,8 @@ struct miopen_type : std::integral_constant

 template <typename T>
 void cpu_tensor_reorder(T* dst,
                         T* src,
@@ -225,17 +227,23 @@ struct to_miopen_data_type
 };

 template <>
-struct to_miopen_data_type
+struct to_miopen_data_type
 {
     static miopenDataType_t get() { return miopenHalf; } // we don't actually compute in 16-bit float
 };

 template <>
-struct to_miopen_data_type
+struct to_miopen_data_type
 {
     static miopenDataType_t get() { return miopenInt8; }
 };

+template <>
+struct to_miopen_data_type<bfloat16>
+{
+    static miopenDataType_t get() { return miopenBFloat16; }
+};
+
 #define RAND_INTEGER_MAX 120
 #define RAND_INTEGER_MIN -88

@@ -280,11 +288,7 @@ bool compare_equal(float r1, float r2)
 template <typename T>
 bool verify_tensor(tensor<T>& t_gpu, tensor<T>& t_cpu)
 {
-    if(t_gpu.data.size() != t_cpu.data.size())
-    {
-        MIOPEN_LOG_E("size not equal, should not happen");
-        return false;
-    }
+    EXPECT(t_gpu.data.size() == t_cpu.data.size());
     auto idx = miopen::mismatch_idx(t_gpu.data, t_cpu.data, compare_equal);

     bool valid_result = idx >= miopen::range_distance(t_cpu);
@@ -296,21 +300,21 @@ bool verify_tensor(tensor<T>& t_gpu, tensor<T>& t_cpu)
     return valid_result;
 }

-struct reorder_base
+struct tensor_reorder_base_driver : test_driver
 {
     miopenHandle_t handle{};
 #if MIOPEN_BACKEND_OPENCL
     cl_command_queue q{};
 #endif

-    reorder_base()
+    tensor_reorder_base_driver()
     {
         miopenCreate(&handle);
 #if MIOPEN_BACKEND_OPENCL
         miopenGetStream(handle, &q);
 #endif
     }
-    ~reorder_base() { miopenDestroy(handle); }
+    ~tensor_reorder_base_driver() { miopenDestroy(handle); }

     static std::vector<uint32_t> get_dim_3_size() { return {1, 9, 14}; }
     static std::vector<uint32_t> get_dim_2_size() { return {1, 9, 14}; }
@@ -366,8 +370,8 @@ struct reorder_invoke_param : public miopen::InvokeParams
     {
     }
 };

-template <typename T>
-struct reorder_test : reorder_base
+template <typename T>
+struct tensor_reorder_driver : tensor_reorder_base_driver
 {
     void run()
     {
@@ -405,8 +409,6 @@ struct reorder_test : reorder_base
             cl_int status = CL_SUCCESS;
             cl_mem src_dev =
                 clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, &status);
-            cl_mem dst_dev =
-                clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, nullptr);
             status |= clEnqueueWriteBuffer(q,
                                            src_dev,
                                            CL_TRUE,
@@ -419,52 +421,79 @@ struct reorder_test : reorder_base
             EXPECT(status == CL_SUCCESS);
 #elif MIOPEN_BACKEND_HIP
             void* src_dev;
-            void* dst_dev;
             EXPECT(hipMalloc(&src_dev, sizeof(T) * tensor_sz) == hipSuccess);
-            EXPECT(hipMalloc(&dst_dev, sizeof(T) * tensor_sz) == hipSuccess);
             EXPECT(hipMemcpy(
                        src_dev, t_src.data.data(), sizeof(T) * tensor_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
 #endif
-
-            const auto invoke_param = reorder_invoke_param{
-                DataCast(static_cast<void*>(src_dev)), DataCast(dst_dev)};
-
             miopen::ExecutionContext ctx;
             ctx.SetStream(&miopen::deref(this->handle));
             ctx.DetectRocm();
             // ctx.SetupFloats();
-            auto reorder_sol = TensorReorderSolutionConstructor(ctx,
-                                                                to_miopen_data_type<T>::get(),
-                                                                dim_0,
-                                                                dim_1,
-                                                                dim_2,
-                                                                dim_3,
-                                                                order_0,
-                                                                order_1,
-                                                                order_2,
-                                                                order_3);
-            std::vector<OpKernelArg> opArgs = reorder_sol->GetKernelArg();
-            boost::optional<miopen::InvokerFactory> invoker_factory(
-                [=](const std::vector<miopen::Kernel>& kernels) mutable {
-                    return [=](const miopen::Handle& handle,
-                               const miopen::AnyInvokeParams& primitive_param) mutable {
-                        decltype(auto) invoke_params =
-                            primitive_param.CastTo<reorder_invoke_param>();
-
-                        const auto k = handle.Run(kernels[0]);
-
-                        opArgs[0] = OpKernelArg(invoke_params.dst);
-                        opArgs[1] = OpKernelArg(invoke_params.src);
-
-                        k(opArgs);
-                    };
-                });
-            std::vector<miopen::solver::KernelInfo> construction_params{reorder_sol->GetKernel()};
-            const auto invoker =
-                miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params);
-            // run gpu
-            invoker(miopen::deref(this->handle), invoke_param);
+            auto reorder_sol = MakeTensorReorderAttributes(ctx,
+                                                           to_miopen_data_type<T>::get(),
+                                                           dim_0,
+                                                           dim_1,
+                                                           dim_2,
+                                                           dim_3,
+                                                           order_0,
+                                                           order_1,
+                                                           order_2,
+                                                           order_3);
+            EXPECT(reorder_sol != nullptr);
+            size_t workspace = reorder_sol->GetOutputTensorSize();
+            if(!reorder_sol->IsSkippable()){
+#if MIOPEN_BACKEND_OPENCL
+                cl_mem dst_dev =
+                    clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
+#elif MIOPEN_BACKEND_HIP
+                void* dst_dev;
+                EXPECT(hipMalloc(&dst_dev, workspace) == hipSuccess);
+#endif
+                const auto invoke_param = reorder_invoke_param{
+                    DataCast(static_cast<void*>(src_dev)), DataCast(dst_dev)};
+
+                std::vector<OpKernelArg> opArgs = reorder_sol->GetKernelArg();
+                boost::optional<miopen::InvokerFactory> invoker_factory(
+                    [=](const std::vector<miopen::Kernel>& kernels) mutable {
+                        return [=](const miopen::Handle& handle,
+                                   const miopen::AnyInvokeParams& primitive_param) mutable {
+                            decltype(auto) invoke_params =
+                                primitive_param.CastTo<reorder_invoke_param>();
+
+                            const auto k = handle.Run(kernels[0]);
+
+                            opArgs[0] = OpKernelArg(invoke_params.dst);
+                            opArgs[1] = OpKernelArg(invoke_params.src);
+
+                            k(opArgs);
+                        };
+                    });
+                std::vector<miopen::solver::KernelInfo> construction_params{reorder_sol->GetKernelInfo()};
+                const auto invoker =
+                    miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params);
+                // run gpu
+                invoker(miopen::deref(this->handle), invoke_param);
+#if MIOPEN_BACKEND_OPENCL
+                status = clEnqueueReadBuffer(q,
+                                             dst_dev,
+                                             CL_TRUE,
+                                             0,
+                                             workspace,
+                                             t_dst_gpu.data.data(),
+                                             0,
+                                             nullptr,
+                                             nullptr);
+                EXPECT(status == CL_SUCCESS);
+                clReleaseMemObject(dst_dev);
+#elif MIOPEN_BACKEND_HIP
+                EXPECT(hipMemcpy(t_dst_gpu.data.data(),
+                                 dst_dev,
+                                 workspace,
+                                 hipMemcpyDeviceToHost) == hipSuccess);
+                hipFree(dst_dev);
+#endif
+            }
             // run cpu
             cpu_reorder::run(t_dst.data.data(),
                              t_src.data.data(),
@@ -478,6 +507,7 @@ struct reorder_test : reorder_base
                              order_3);

 #if MIOPEN_BACKEND_OPENCL
+        if(reorder_sol->IsSkippable()){
         status = clEnqueueReadBuffer(q,
                                      dst_dev,
                                      CL_TRUE,
                                      0,
                                      sizeof(T) * tensor_sz,
                                      t_dst_gpu.data.data(),
                                      0,
                                      nullptr,
                                      nullptr);
         EXPECT(status == CL_SUCCESS);
+        clReleaseMemObject(src_dev);
+        }
 #elif MIOPEN_BACKEND_HIP
+        if(reorder_sol->IsSkippable()){
         EXPECT(hipMemcpy(t_dst_gpu.data.data(),
-                         dst_dev,
+                         src_dev,
                          sizeof(T) * tensor_sz,
                          hipMemcpyDeviceToHost) == hipSuccess);
+        hipFree(src_dev);
+        }
 #endif

         // we expect an exact match, since we use integers
@@ -502,26 +537,12 @@ struct reorder_test : reorder_base
                   << "dim_0:" << dim_0 << ", dim_1:" << dim_1 << ", dim_2:" << dim_2
                   << ", dim_3:" << dim_3 << ", valid:" << valid_result << std::endl;
         EXPECT(valid_result == true);
-
-#if MIOPEN_BACKEND_OPENCL
-        clReleaseMemObject(src_dev);
-        clReleaseMemObject(dst_dev);
-#elif MIOPEN_BACKEND_HIP
-        hipFree(src_dev);
-        hipFree(dst_dev);
-#endif
         };

         iterate_reorder(run_reorder);
     }
 };

-int main()
-{
-    run_test<reorder_test<double>>(); // DOUBLE only supports the general
-                                      // reorder solution; it does not
-                                      // support batched transpose.
-    run_test<reorder_test<float>>();
-    run_test<reorder_test<half_float::half>>();
-    run_test<reorder_test<int8_t>>();
-}
+
+
+int main(int argc, const char* argv[]){test_drive<tensor_reorder_driver>(argc, argv);}
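A hedged usage sketch of the factory this test drives; QueryReorderWorkspace is a hypothetical helper and the dimensions are arbitrary, but the calls match the signatures introduced in this series:

    #include <miopen/tensor_reorder_util.hpp>

    size_t QueryReorderWorkspace(const miopen::ExecutionContext& ctx) // hypothetical helper
    {
        // order {0, 2, 3, 1} turns NCHW into NHWC
        auto sol = miopen::MakeTensorReorderAttributes(ctx, miopenFloat, 2, 8, 9, 14, 0, 2, 3, 1);
        if(sol == nullptr || sol->IsSkippable()) // nullptr when HIP kernels are disabled
            return 0;
        return sol->GetOutputTensorSize(); // bytes to reserve for the reordered output
    }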
"numeric": "cpp", - "optional": "cpp", - "random": "cpp", - "ratio": "cpp", - "set": "cpp", - "string": "cpp", - "string_view": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "fstream": "cpp", - "future": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "istream": "cpp", - "mutex": "cpp", - "new": "cpp", - "ostream": "cpp", - "shared_mutex": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "streambuf": "cpp", - "thread": "cpp", - "cfenv": "cpp", - "cinttypes": "cpp", - "typeindex": "cpp", - "typeinfo": "cpp", - "valarray": "cpp", - "variant": "cpp" - } -} \ No newline at end of file From ffb5a107180761c7e8d79cbc19581a89bd467664 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Sat, 26 Mar 2022 07:04:05 +0000 Subject: [PATCH 70/77] clang-format check --- .gitignore | 4 +- src/hip/batched_transpose_sol.cpp | 3 +- src/hip/general_tensor_reorder_sol.cpp | 14 +- src/include/miopen/tensor_reorder_util.hpp | 149 ++++++------------ .../general_tensor_reorder.cpp | 2 +- test/tensor_reorder.cpp | 84 +++++----- 6 files changed, 105 insertions(+), 151 deletions(-) diff --git a/.gitignore b/.gitignore index d16386367f..e62317aacc 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ -build/ \ No newline at end of file +build/ +.vscode/ +.gitignore \ No newline at end of file diff --git a/src/hip/batched_transpose_sol.cpp b/src/hip/batched_transpose_sol.cpp index ca96688f15..87e0bd9c8c 100644 --- a/src/hip/batched_transpose_sol.cpp +++ b/src/hip/batched_transpose_sol.cpp @@ -351,7 +351,8 @@ std::vector BatchedTransposeSolution::GetKernelArg() const opArgs.emplace_back(0); // placeholder opArgs.emplace_back(height); opArgs.emplace_back(width); - if(grid_size != static_cast(grid_size)) MIOPEN_THROW("Variable grid size can't be casted to uint32_t safely"); + if(grid_size != static_cast(grid_size)) + MIOPEN_THROW("Variable grid size can't be casted to uint32_t safely"); opArgs.emplace_back(static_cast(grid_size)); opArgs.emplace_back(dim_total); opArgs.emplace_back(magic_h.magic); diff --git a/src/hip/general_tensor_reorder_sol.cpp b/src/hip/general_tensor_reorder_sol.cpp index 75b5e17ed3..2012a574a0 100644 --- a/src/hip/general_tensor_reorder_sol.cpp +++ b/src/hip/general_tensor_reorder_sol.cpp @@ -59,7 +59,8 @@ static inline std::string GetKernelName(std::size_t data_size, uint32_t order_3, const GeneralReorderParam* kparam) { - if(kparam == nullptr) MIOPEN_THROW("Memory access fault, kparam is a nullptr"); + if(kparam == nullptr) + MIOPEN_THROW("Memory access fault, kparam is a nullptr"); std::ostringstream kernel_name; kernel_name << "general_4d_reorder_" << kparam->tile_x << "x" << kparam->tile_y << "_"; if(!(kparam->pack_x == 1 && kparam->pack_y == 1 && kparam->ediv_x == 1 && kparam->ediv_y == 1)) @@ -176,7 +177,8 @@ std::vector GenericReorderSolutionImpl::GetKernelArg() const opArgs.emplace_back(dim_1); opArgs.emplace_back(dim_2); opArgs.emplace_back(dim_3); - if(grid_size != static_cast(grid_size)) MIOPEN_THROW("Variable grid size can't be casted to uint32_t safely"); + if(grid_size != static_cast(grid_size)) + MIOPEN_THROW("Variable grid size can't be casted to uint32_t safely"); opArgs.emplace_back(static_cast(grid_size)); opArgs.emplace_back(dim_total); opArgs.emplace_back(magic_stride0.magic); @@ -191,8 +193,12 @@ std::vector GenericReorderSolutionImpl::GetKernelArg() const std::string GenericReorderSolutionImpl::GetKernelName() const { - return tensor_reorder::GetKernelName( - 
diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp
index efb23157a2..05ca4ba9cb 100644
--- a/src/include/miopen/tensor_reorder_util.hpp
+++ b/src/include/miopen/tensor_reorder_util.hpp
@@ -41,35 +41,29 @@ struct TensorReorderAttributesBase
     virtual solver::KernelInfo GetKernelInfo() const = 0;
     virtual std::vector<OpKernelArg> GetKernelArg() const = 0;
     virtual std::string GetKernelName() const = 0;
-    // used in HOST side to check the special cases that either tensor height or width equal = 1.
-    // In such cases, we don't need to conduct batched transpose operation,
+    // used on the HOST side to check the special case where either tensor height or width equals 1.
+    // In such cases, we don't need to perform the batched transpose operation,
     // since the transposed tensor layout has exactly the same memory layout as before.
-    virtual bool IsSkippable() const = 0;
+    virtual bool IsSkippable() const = 0;
     // workspace (buffer) to hold the output tensor of the pre-/post-convolution transform
-    virtual size_t GetOutputTensorSize() const = 0;
+    virtual size_t GetOutputTensorSize() const = 0;
 };

 struct BatchedTransposeSolution_0132 : TensorReorderAttributesBase
 {
     BatchedTransposeSolution impl;
     BatchedTransposeSolution_0132(const ExecutionContext& ctx_,
-                       miopenDataType_t data_type_,
-                       uint32_t dim_0_,
-                       uint32_t dim_1_,
-                       uint32_t dim_2_,
-                       uint32_t dim_3_)
+                                  miopenDataType_t data_type_,
+                                  uint32_t dim_0_,
+                                  uint32_t dim_1_,
+                                  uint32_t dim_2_,
+                                  uint32_t dim_3_)
         : impl(ctx_, data_type_, dim_0_ * dim_1_, dim_2_, dim_3_)
     {
     }
     solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); }
-    std::vector<OpKernelArg> GetKernelArg() const override
-    {
-        return impl.GetKernelArg();
-    }
-    std::string GetKernelName() const override
-    {
-        return impl.GetKernelName();
-    }
+    std::vector<OpKernelArg> GetKernelArg() const override { return impl.GetKernelArg(); }
+    std::string GetKernelName() const override { return impl.GetKernelName(); }
     bool IsSkippable() const override { return impl.IsSkippable(); }
     size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); }
 };
@@ -78,23 +72,17 @@ struct BatchedTransposeSolution_0231 : TensorReorderAttributesBase
 {
     BatchedTransposeSolution impl;
     BatchedTransposeSolution_0231(const ExecutionContext& ctx_,
-                       miopenDataType_t data_type_,
-                       uint32_t dim_0_,
-                       uint32_t dim_1_,
-                       uint32_t dim_2_,
-                       uint32_t dim_3_)
+                                  miopenDataType_t data_type_,
+                                  uint32_t dim_0_,
+                                  uint32_t dim_1_,
+                                  uint32_t dim_2_,
+                                  uint32_t dim_3_)
         : impl(ctx_, data_type_, dim_0_, dim_1_, dim_2_ * dim_3_)
     {
     }
     solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); }
-    std::vector<OpKernelArg> GetKernelArg() const override
-    {
-        return impl.GetKernelArg();
-    }
-    std::string GetKernelName() const override
-    {
-        return impl.GetKernelName();
-    }
+    std::vector<OpKernelArg> GetKernelArg() const override { return impl.GetKernelArg(); }
+    std::string GetKernelName() const override { return impl.GetKernelName(); }
     bool IsSkippable() const override { return impl.IsSkippable(); }
     size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); }
 };
@@ -103,23 +91,17 @@ struct BatchedTransposeSolution_0312 : TensorReorderAttributesBase
 {
     BatchedTransposeSolution impl;
     BatchedTransposeSolution_0312(const ExecutionContext& ctx_,
-                       miopenDataType_t data_type_,
-                       uint32_t dim_0_,
-                       uint32_t dim_1_,
-                       uint32_t dim_2_,
-                       uint32_t dim_3_)
+                                  miopenDataType_t data_type_,
+                                  uint32_t dim_0_,
+                                  uint32_t dim_1_,
+                                  uint32_t dim_2_,
+                                  uint32_t dim_3_)
         : impl(ctx_, data_type_, dim_0_, dim_1_ * dim_2_, dim_3_)
     {
     }
     solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); }
-    std::vector<OpKernelArg> GetKernelArg() const override
-    {
-        return impl.GetKernelArg();
-    }
-    std::string GetKernelName() const override
-    {
-        return impl.GetKernelName();
-    }
+    std::vector<OpKernelArg> GetKernelArg() const override { return impl.GetKernelArg(); }
+    std::string GetKernelName() const override { return impl.GetKernelName(); }
     bool IsSkippable() const override { return impl.IsSkippable(); }
     size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); }
 };
@@ -128,23 +110,17 @@ struct BatchedTransposeSolution_2301 : TensorReorderAttributesBase
 {
     BatchedTransposeSolution impl;
     BatchedTransposeSolution_2301(const ExecutionContext& ctx_,
-                       miopenDataType_t data_type_,
-                       uint32_t dim_0_,
-                       uint32_t dim_1_,
-                       uint32_t dim_2_,
-                       uint32_t dim_3_)
+                                  miopenDataType_t data_type_,
+                                  uint32_t dim_0_,
+                                  uint32_t dim_1_,
+                                  uint32_t dim_2_,
+                                  uint32_t dim_3_)
         : impl(ctx_, data_type_, 1, dim_0_ * dim_1_, dim_2_ * dim_3_)
     {
     }
     solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); }
-    std::vector<OpKernelArg> GetKernelArg() const override
-    {
-        return impl.GetKernelArg();
-    }
-    std::string GetKernelName() const override
-    {
-        return impl.GetKernelName();
-    }
+    std::vector<OpKernelArg> GetKernelArg() const override { return impl.GetKernelArg(); }
+    std::string GetKernelName() const override { return impl.GetKernelName(); }
     bool IsSkippable() const override { return impl.IsSkippable(); }
     size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); }
 };
@@ -153,23 +129,17 @@ struct BatchedTransposeSolution_3012 : TensorReorderAttributesBase
 {
     BatchedTransposeSolution impl;
     BatchedTransposeSolution_3012(const ExecutionContext& ctx_,
-                       miopenDataType_t data_type_,
-                       uint32_t dim_0_,
-                       uint32_t dim_1_,
-                       uint32_t dim_2_,
-                       uint32_t dim_3_)
+                                  miopenDataType_t data_type_,
+                                  uint32_t dim_0_,
+                                  uint32_t dim_1_,
+                                  uint32_t dim_2_,
+                                  uint32_t dim_3_)
         : impl(ctx_, data_type_, 1, dim_0_ * dim_1_ * dim_2_, dim_3_)
     {
     }
     solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); }
-    std::vector<OpKernelArg> GetKernelArg() const override
-    {
-        return impl.GetKernelArg();
-    }
-    std::string GetKernelName() const override
-    {
-        return impl.GetKernelName();
-    }
+    std::vector<OpKernelArg> GetKernelArg() const override { return impl.GetKernelArg(); }
+    std::string GetKernelName() const override { return impl.GetKernelName(); }
     bool IsSkippable() const override { return impl.IsSkippable(); }
     size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); }
 };
@@ -186,22 +156,11 @@ struct GenericReorderSolution : TensorReorderAttributesBase
                            uint32_t order_1_,
                            uint32_t order_2_,
                            uint32_t order_3_)
-        : impl( data_type_,
-                dim_0_,
-                dim_1_,
-                dim_2_,
-                dim_3_,
-                order_0_,
-                order_1_,
-                order_2_,
-                order_3_)
+        : impl(data_type_, dim_0_, dim_1_, dim_2_, dim_3_, order_0_, order_1_, order_2_, order_3_)
     {
     }
     solver::KernelInfo GetKernelInfo() const override { return impl.GetKernelInfo(); }
-    std::vector<OpKernelArg> GetKernelArg() const override
-    {
-        return impl.GetKernelArg();
-    }
+    std::vector<OpKernelArg> GetKernelArg() const override { return impl.GetKernelArg(); }
     std::string GetKernelName() const override { return impl.GetKernelName(); }
     bool IsSkippable() const override { return impl.IsSkippable(); }
     size_t GetOutputTensorSize() const override { return impl.GetOutputTensorSize(); }
@@ -220,20 +179,23 @@ MakeTensorReorderAttributes(const ExecutionContext& ctx_,
                             uint32_t order_3_)
 {
     std::unique_ptr<TensorReorderAttributesBase> default_ptr;
-    if(ctx_.use_hip_kernels == false){
+    if(ctx_.use_hip_kernels == false)
+    {
         return default_ptr;
     }
     // Default using general reorder
-    if(data_type_ == miopenBFloat16){
+    if(data_type_ == miopenBFloat16)
+    {
         MIOPEN_THROW("Unsupported reorder data type");
     }
     int which = 0;
-    if(data_type_ != miopenDouble){
+    if(data_type_ != miopenDouble)
+    {
         if((order_0_ == 0) && (order_1_ == 1) && (order_2_ == 3) && (order_3_ == 2))
             which = 1;
         else if((order_0_ == 0) && (order_1_ == 2) && (order_2_ == 3) && (order_3_ == 1))
             which = 2;
-        else if((order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && (order_3_ == 2))
+        else if((order_0_ == 0) && (order_1_ == 3) && (order_2_ == 1) && (order_3_ == 2))
             which = 3;
         else if((order_0_ == 2) && (order_1_ == 3) && (order_2_ == 0) && (order_3_ == 1))
             which = 4;
@@ -256,15 +218,8 @@ MakeTensorReorderAttributes(const ExecutionContext& ctx_,
     switch(which)
     {
     case 0:
-        return std::make_unique<GenericReorderSolution>(data_type_,
-                                                        dim_0_,
-                                                        dim_1_,
-                                                        dim_2_,
-                                                        dim_3_,
-                                                        order_0_,
-                                                        order_1_,
-                                                        order_2_,
-                                                        order_3_);
+        return std::make_unique<GenericReorderSolution>(
+            data_type_, dim_0_, dim_1_, dim_2_, dim_3_, order_0_, order_1_, order_2_, order_3_);
     case 1:
         return std::make_unique<BatchedTransposeSolution_0132>(
             ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
@@ -280,9 +235,7 @@ MakeTensorReorderAttributes(const ExecutionContext& ctx_,
     case 5:
         return std::make_unique<BatchedTransposeSolution_3012>(
             ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_);
-    default:
-        MIOPEN_THROW("Unsupported reorder sequence");
-        break;
+    default: MIOPEN_THROW("Unsupported reorder sequence"); break;
     }
 }

diff --git a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
index d16739d185..e8f236c36e 100644
--- a/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
+++ b/src/kernels/gpu_general_tensor_reorder_kernel/general_tensor_reorder.cpp
@@ -25,7 +25,7 @@
  *******************************************************************************/
 #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
 #include <hip/hip_runtime.h>
-#endif
+#endif
 #include <hip/hip_fp16.h>
 #include "order.hpp"

diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp
index ce335760b1..5e1acb052a 100644
--- a/test/tensor_reorder.cpp
+++ b/test/tensor_reorder.cpp
@@ -49,8 +49,6 @@ struct miopen_type : std::integral_constant

 template <typename T>
 void cpu_tensor_reorder(T* dst,
                         T* src,
@@ -442,17 +440,18 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
             order_3);
         EXPECT(reorder_sol != nullptr);
         size_t workspace = reorder_sol->GetOutputTensorSize();
-        if(!reorder_sol->IsSkippable()){
+        if(!reorder_sol->IsSkippable())
+        {
 #if MIOPEN_BACKEND_OPENCL
             cl_mem dst_dev =
-               clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
+                clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
 #elif MIOPEN_BACKEND_HIP
             void* dst_dev;
             EXPECT(hipMalloc(&dst_dev, workspace) == hipSuccess);
 #endif
             const auto invoke_param = reorder_invoke_param{
                 DataCast(static_cast<void*>(src_dev)), DataCast(dst_dev)};
-
+
             std::vector<OpKernelArg> opArgs = reorder_sol->GetKernelArg();
             boost::optional<miopen::InvokerFactory> invoker_factory(
                 [=](const std::vector<miopen::Kernel>& kernels) mutable {
                     return [=](const miopen::Handle& handle,
                                const miopen::AnyInvokeParams& primitive_param) mutable {
                         decltype(auto) invoke_params =
                             primitive_param.CastTo<reorder_invoke_param>();
-
+
                         const auto k = handle.Run(kernels[0]);
-
+
                         opArgs[0] = OpKernelArg(invoke_params.dst);
                         opArgs[1] = OpKernelArg(invoke_params.src);
-
+
                         k(opArgs);
                     };
                 });
-            std::vector<miopen::solver::KernelInfo> construction_params{reorder_sol->GetKernelInfo()};
-            const auto invoker =
-                miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params);
+            std::vector<miopen::solver::KernelInfo> construction_params{
+                reorder_sol->GetKernelInfo()};
+            const auto invoker = miopen::deref(this->handle)
+                                     .PrepareInvoker(*invoker_factory, construction_params);
             // run gpu
             invoker(miopen::deref(this->handle), invoke_param);
 #if MIOPEN_BACKEND_OPENCL
-            status = clEnqueueReadBuffer(q,
-                                         dst_dev,
-                                         CL_TRUE,
-                                         0,
-                                         workspace,
-                                         t_dst_gpu.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
+            status = clEnqueueReadBuffer(
+                q, dst_dev, CL_TRUE, 0, workspace, t_dst_gpu.data.data(), 0, nullptr, nullptr);
             EXPECT(status == CL_SUCCESS);
             clReleaseMemObject(dst_dev);
-#elif MIOPEN_BACKEND_HIP
-            EXPECT(hipMemcpy(t_dst_gpu.data.data(),
-                             dst_dev,
-                             workspace,
-                             hipMemcpyDeviceToHost) == hipSuccess);
+#elif MIOPEN_BACKEND_HIP
+            EXPECT(
+                hipMemcpy(t_dst_gpu.data.data(), dst_dev, workspace, hipMemcpyDeviceToHost) ==
+                hipSuccess);
             hipFree(dst_dev);
 #endif
         }
@@ -507,26 +499,28 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
             order_3);

 #if MIOPEN_BACKEND_OPENCL
-        if(reorder_sol->IsSkippable()){
-        status = clEnqueueReadBuffer(q,
-                                     dst_dev,
-                                     CL_TRUE,
-                                     0,
-                                     sizeof(T) * tensor_sz,
-                                     t_dst_gpu.data.data(),
-                                     0,
-                                     nullptr,
-                                     nullptr);
-        EXPECT(status == CL_SUCCESS);
-        clReleaseMemObject(src_dev);
+        if(reorder_sol->IsSkippable())
+        {
+            status = clEnqueueReadBuffer(q,
+                                         src_dev,
+                                         CL_TRUE,
+                                         0,
+                                         sizeof(T) * tensor_sz,
+                                         t_dst_gpu.data.data(),
+                                         0,
+                                         nullptr,
+                                         nullptr);
+            EXPECT(status == CL_SUCCESS);
+            clReleaseMemObject(src_dev);
         }
 #elif MIOPEN_BACKEND_HIP
-        if(reorder_sol->IsSkippable()){
-        EXPECT(hipMemcpy(t_dst_gpu.data.data(),
-                         src_dev,
-                         sizeof(T) * tensor_sz,
-                         hipMemcpyDeviceToHost) == hipSuccess);
-        hipFree(src_dev);
+        if(reorder_sol->IsSkippable())
+        {
+            EXPECT(hipMemcpy(t_dst_gpu.data.data(),
+                             src_dev,
+                             sizeof(T) * tensor_sz,
+                             hipMemcpyDeviceToHost) == hipSuccess);
+            hipFree(src_dev);
         }
 #endif

@@ -543,6 +537,4 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
     }
 };

-
-
-int main(int argc, const char* argv[]){test_drive<tensor_reorder_driver>(argc, argv);}
+int main(int argc, const char* argv[]) { test_drive<tensor_reorder_driver>(argc, argv); }
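A hypothetical condensation (not in the patch) of the two read-back paths above: when the reorder is skippable no kernel ever writes dst_dev, so the source buffer already holds the result and is the one to copy back.

    template <typename Buf>
    Buf result_buffer(bool skippable, Buf src_dev, Buf dst_dev)
    {
        return skippable ? src_dev : dst_dev;
    }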
From d124707a37176a883531604fa69f592ac15dfcb8 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Sat, 26 Mar 2022 18:02:59 +0000
Subject: [PATCH 71/77] post review

---
 .gitignore              |   4 +-
 .vscode/settings.json   |   2 +-
 test/tensor_reorder.cpp | 136 +++++++++++++++++-----------------------
 3 files changed, 59 insertions(+), 83 deletions(-)

diff --git a/.gitignore b/.gitignore
index e62317aacc..e94418e610 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,3 @@
 build/
-.vscode/
-.gitignore
\ No newline at end of file
+.vscode/settings.json
+.gitignore

diff --git a/.vscode/settings.json b/.vscode/settings.json
index b824eccc08..212ae34517 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -67,6 +67,6 @@
         "typeindex": "cpp",
         "typeinfo": "cpp",
         "valarray": "cpp",
-        "variant": "cpp"
+        "variant": "cpp",
     }
 }
\ No newline at end of file

diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp
index 5e1acb052a..399721be6d 100644
--- a/test/tensor_reorder.cpp
+++ b/test/tensor_reorder.cpp
@@ -401,12 +401,27 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
             tensor<T> t_dst(tensor_len, tensor_strides);
             tensor<T> t_dst_gpu(tensor_len, tensor_strides);
             rand_tensor_integer(t_src);
+
+            auto reorder_sol = MakeTensorReorderAttributes(ctx,
+                                                           to_miopen_data_type<T>::get(),
+                                                           dim_0,
+                                                           dim_1,
+                                                           dim_2,
+                                                           dim_3,
+                                                           order_0,
+                                                           order_1,
+                                                           order_2,
+                                                           order_3);
+            EXPECT(reorder_sol != nullptr);
+            size_t workspace = reorder_sol->IsSkippable()? sizeof(T) * tensor_sz : reorder_sol->GetOutputTensorSize();
 #if MIOPEN_BACKEND_OPENCL
             cl_context cl_ctx;
             clGetCommandQueueInfo(q, CL_QUEUE_CONTEXT, sizeof(cl_context), &cl_ctx, nullptr);
             cl_int status = CL_SUCCESS;
             cl_mem src_dev =
                 clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, &status);
+            cl_mem dst_dev =
+                clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
             status |= clEnqueueWriteBuffer(q,
                                            src_dev,
                                            CL_TRUE,
                                            0,
                                            sizeof(T) * tensor_sz,
                                            t_src.data.data(),
                                            0,
                                            nullptr,
                                            nullptr);
             EXPECT(status == CL_SUCCESS);
 #elif MIOPEN_BACKEND_HIP
             void* src_dev;
+            void* dst_dev;
             EXPECT(hipMalloc(&src_dev, sizeof(T) * tensor_sz) == hipSuccess);
+            EXPECT(hipMalloc(&dst_dev, workspace) == hipSuccess);
             EXPECT(hipMemcpy(
                        src_dev, t_src.data.data(), sizeof(T) * tensor_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
 #endif
-
             miopen::ExecutionContext ctx;
             ctx.SetStream(&miopen::deref(this->handle));
             ctx.DetectRocm();
             // ctx.SetupFloats();
-            auto reorder_sol = MakeTensorReorderAttributes(ctx,
-                                                           to_miopen_data_type<T>::get(),
-                                                           dim_0,
-                                                           dim_1,
-                                                           dim_2,
-                                                           dim_3,
-                                                           order_0,
-                                                           order_1,
-                                                           order_2,
-                                                           order_3);
-            EXPECT(reorder_sol != nullptr);
-            size_t workspace = reorder_sol->GetOutputTensorSize();
-            if(!reorder_sol->IsSkippable())
-            {
-#if MIOPEN_BACKEND_OPENCL
-                cl_mem dst_dev =
-                    clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
-#elif MIOPEN_BACKEND_HIP
-                void* dst_dev;
-                EXPECT(hipMalloc(&dst_dev, workspace) == hipSuccess);
-#endif
-                const auto invoke_param = reorder_invoke_param{
-                    DataCast(static_cast<void*>(src_dev)), DataCast(dst_dev)};
-
+            const auto invoke_param = reorder_invoke_param{
+                DataCast(static_cast<void*>(src_dev)), DataCast(dst_dev)};
+            std::vector<OpKernelArg> opArgs = reorder_sol->GetKernelArg();
+            boost::optional<miopen::InvokerFactory> invoker_factory(
+                [=](const std::vector<miopen::Kernel>& kernels) mutable {
+                    return [=](const miopen::Handle& handle,
+                               const miopen::AnyInvokeParams& primitive_param) mutable {
+                        decltype(auto) invoke_params =
+                            primitive_param.CastTo<reorder_invoke_param>();
+                        const auto k = handle.Run(kernels[0]);
+                        opArgs[0] = OpKernelArg(invoke_params.dst);
+                        opArgs[1] = OpKernelArg(invoke_params.src);
+                        k(opArgs);
+                    };
+                });
+            std::vector<miopen::solver::KernelInfo> construction_params{
+                reorder_sol->GetKernelInfo()};
+            const auto invoker = miopen::deref(this->handle)
+                                     .PrepareInvoker(*invoker_factory, construction_params);
+            // run gpu
+            invoker(miopen::deref(this->handle), invoke_param);
             // run cpu
             cpu_reorder::run(t_dst.data.data(),
                              t_src.data.data(),
@@ -499,29 +479,25 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
                              order_3);

 #if MIOPEN_BACKEND_OPENCL
-        if(reorder_sol->IsSkippable())
-        {
-            status = clEnqueueReadBuffer(q,
-                                         src_dev,
-                                         CL_TRUE,
-                                         0,
-                                         sizeof(T) * tensor_sz,
-                                         t_dst_gpu.data.data(),
-                                         0,
-                                         nullptr,
-                                         nullptr);
-            EXPECT(status == CL_SUCCESS);
-            clReleaseMemObject(src_dev);
-        }
+        status = clEnqueueReadBuffer(q,
+                                     dst_dev,
+                                     CL_TRUE,
+                                     0,
+                                     workspace,
+                                     t_dst_gpu.data.data(),
+                                     0,
+                                     nullptr,
+                                     nullptr);
+        EXPECT(status == CL_SUCCESS);
+        clReleaseMemObject(dst_dev);
+        clReleaseMemObject(src_dev);
 #elif MIOPEN_BACKEND_HIP
-        if(reorder_sol->IsSkippable())
-        {
-            EXPECT(hipMemcpy(t_dst_gpu.data.data(),
-                             src_dev,
-                             sizeof(T) * tensor_sz,
-                             hipMemcpyDeviceToHost) == hipSuccess);
-            hipFree(src_dev);
-        }
+        EXPECT(hipMemcpy(t_dst_gpu.data.data(),
+                         src_dev,
+                         workspace,
+                         hipMemcpyDeviceToHost) == hipSuccess);
+        hipFree(dst_dev);
+        hipFree(src_dev);
 #endif

         // we expect an exact match, since we use integers

From a2c3202edff022b8f0dcb064302ee03307b8a62e Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Sun, 27 Mar 2022 10:56:24 +0000
Subject: [PATCH 72/77] update on ctest

---
 .gitignore                       |  3 ---
 test/gpu_nchw_nhwc_transpose.cpp |  3 ++-
 test/tensor_reorder.cpp          | 43 +++++++++++++-------------------
 3 files changed, 19 insertions(+), 30 deletions(-)
 delete mode 100644 .gitignore

diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index e94418e610..0000000000
--- a/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-build/
-.vscode/settings.json
-.gitignore

diff --git a/test/gpu_nchw_nhwc_transpose.cpp b/test/gpu_nchw_nhwc_transpose.cpp
index d40fff3f9e..650d401fdb 100644
--- a/test/gpu_nchw_nhwc_transpose.cpp
+++ b/test/gpu_nchw_nhwc_transpose.cpp
@@ -370,7 +370,8 @@ struct transpose_test : transpose_base
             };
         });

-    std::vector<miopen::solver::KernelInfo> construction_params{transpose_sol.GetKernelInfo()};
+    std::vector<miopen::solver::KernelInfo> construction_params{
+        transpose_sol.GetKernelInfo()};

     const auto invoker =
         miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params);

diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp
index 399721be6d..7045bd1d82 100644
--- a/test/tensor_reorder.cpp
+++ b/test/tensor_reorder.cpp
@@ -401,7 +401,11 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
             tensor<T> t_dst(tensor_len, tensor_strides);
             tensor<T> t_dst_gpu(tensor_len, tensor_strides);
             rand_tensor_integer(t_src);
-
+
+            miopen::ExecutionContext ctx;
+            ctx.SetStream(&miopen::deref(this->handle));
+            ctx.DetectRocm();
+            // ctx.SetupFloats();
             auto reorder_sol = MakeTensorReorderAttributes(ctx,
                                                            to_miopen_data_type<T>::get(),
                                                            dim_0,
                                                            dim_1,
                                                            dim_2,
                                                            dim_3,
                                                            order_0,
                                                            order_1,
                                                            order_2,
                                                            order_3);
             EXPECT(reorder_sol != nullptr);
-            size_t workspace = reorder_sol->IsSkippable()? sizeof(T) * tensor_sz : reorder_sol->GetOutputTensorSize();
+            size_t workspace = reorder_sol->IsSkippable() ? sizeof(T) * tensor_sz
+                                                          : reorder_sol->GetOutputTensorSize();
 #if MIOPEN_BACKEND_OPENCL
             cl_context cl_ctx;
             clGetCommandQueueInfo(q, CL_QUEUE_CONTEXT, sizeof(cl_context), &cl_ctx, nullptr);
             cl_int status = CL_SUCCESS;
             cl_mem src_dev =
                 clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, sizeof(T) * tensor_sz, nullptr, &status);
-            cl_mem dst_dev =
-                clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
+            cl_mem dst_dev = clCreateBuffer(cl_ctx, CL_MEM_READ_WRITE, workspace, nullptr, nullptr);
             status |= clEnqueueWriteBuffer(q,
                                            src_dev,
                                            CL_TRUE,
@@ -441,10 +445,6 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
                        src_dev, t_src.data.data(), sizeof(T) * tensor_sz, hipMemcpyHostToDevice) ==
                    hipSuccess);
 #endif
-            miopen::ExecutionContext ctx;
-            ctx.SetStream(&miopen::deref(this->handle));
-            ctx.DetectRocm();
-            // ctx.SetupFloats();
             const auto invoke_param = reorder_invoke_param{
                 DataCast(static_cast<void*>(src_dev)), DataCast(dst_dev)};
             std::vector<OpKernelArg> opArgs = reorder_sol->GetKernelArg();
@@ -455,15 +455,15 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
                         decltype(auto) invoke_params =
                             primitive_param.CastTo<reorder_invoke_param>();
                         const auto k = handle.Run(kernels[0]);
-                        opArgs[0] = OpKernelArg(invoke_params.dst);
-                        opArgs[1] = OpKernelArg(invoke_params.src);
+                        opArgs[0] = OpKernelArg(invoke_params.dst);
+                        opArgs[1] = OpKernelArg(invoke_params.src);
                         k(opArgs);
                     };
                 });
             std::vector<miopen::solver::KernelInfo> construction_params{
                 reorder_sol->GetKernelInfo()};
-            const auto invoker = miopen::deref(this->handle)
-                                     .PrepareInvoker(*invoker_factory, construction_params);
+            const auto invoker =
+                miopen::deref(this->handle).PrepareInvoker(*invoker_factory, construction_params);
             // run gpu
             invoker(miopen::deref(this->handle), invoke_param);
             // run cpu
@@ -479,23 +479,14 @@ struct tensor_reorder_driver : tensor_reorder_base_driver
                              order_3);

 #if MIOPEN_BACKEND_OPENCL
-        status = clEnqueueReadBuffer(q,
-                                     dst_dev,
-                                     CL_TRUE,
-                                     0,
-                                     workspace,
-                                     t_dst_gpu.data.data(),
-                                     0,
-                                     nullptr,
-                                     nullptr);
+        status = clEnqueueReadBuffer(
+            q, dst_dev, CL_TRUE, 0, workspace, t_dst_gpu.data.data(), 0, nullptr, nullptr);
         EXPECT(status == CL_SUCCESS);
         clReleaseMemObject(dst_dev);
-        clReleaseMemObject(src_dev);
+        clReleaseMemObject(src_dev);
 #elif MIOPEN_BACKEND_HIP
-        EXPECT(hipMemcpy(t_dst_gpu.data.data(),
-                         src_dev,
-                         workspace,
-                         hipMemcpyDeviceToHost) == hipSuccess);
+        EXPECT(hipMemcpy(t_dst_gpu.data.data(), dst_dev, workspace, hipMemcpyDeviceToHost) ==
+               hipSuccess);
         hipFree(dst_dev);
         hipFree(src_dev);
 #endif

         // we expect an exact match, since we use integers
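The hunk above also repairs an ordering bug from the previous commit: the ExecutionContext must be constructed and configured before the factory call that consumes it. A fragment-style sketch with illustrative dimensions (handle comes from the surrounding driver, so this is not standalone):

    miopen::ExecutionContext ctx; // 1. create and configure the context first
    ctx.SetStream(&miopen::deref(handle));
    ctx.DetectRocm();
    auto reorder_sol = miopen::MakeTensorReorderAttributes( // 2. only then build the solution
        ctx, miopenFloat, 2, 8, 9, 14, 0, 2, 3, 1);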
From edfbbba3420344ab999bbdc18509a0965ed4bf39 Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Sun, 27 Mar 2022 10:58:45 +0000
Subject: [PATCH 73/77] resolve M/D conflict

---
 .vscode/settings.json | 72 ------------------------------------------
 1 file changed, 72 deletions(-)
 delete mode 100644 .vscode/settings.json

diff --git a/.vscode/settings.json b/.vscode/settings.json
deleted file mode 100644
index 212ae34517..0000000000
--- a/.vscode/settings.json
+++ /dev/null
@@ -1,72 +0,0 @@
-{
-    "files.associations": {
-        "limits": "cpp",
-        "array": "cpp",
-        "atomic": "cpp",
-        "strstream": "cpp",
-        "*.tcc": "cpp",
-        "bitset": "cpp",
-        "cctype": "cpp",
-        "chrono": "cpp",
-        "clocale": "cpp",
-        "cmath": "cpp",
-        "codecvt": "cpp",
-        "complex": "cpp",
-        "condition_variable": "cpp",
-        "csignal": "cpp",
-        "cstdarg": "cpp",
-        "cstddef": "cpp",
-        "cstdint": "cpp",
-        "cstdio": "cpp",
-        "cstdlib": "cpp",
-        "cstring": "cpp",
-        "ctime": "cpp",
-        "cwchar": "cpp",
-        "cwctype": "cpp",
-        "deque": "cpp",
-        "list": "cpp",
-        "unordered_map": "cpp",
-        "unordered_set": "cpp",
-        "vector": "cpp",
-        "exception": "cpp",
-        "algorithm": "cpp",
-        "filesystem": "cpp",
-        "functional": "cpp",
-        "iterator": "cpp",
-        "map": "cpp",
-        "memory": "cpp",
-        "memory_resource": "cpp",
-        "numeric": "cpp",
-        "optional": "cpp",
-        "random": "cpp",
-        "ratio": "cpp",
-        "set": "cpp",
-        "string": "cpp",
-        "string_view": "cpp",
-        "system_error": "cpp",
-        "tuple": "cpp",
-        "type_traits": "cpp",
-        "utility": "cpp",
-        "fstream": "cpp",
-        "future": "cpp",
-        "initializer_list": "cpp",
-        "iomanip": "cpp",
-        "iosfwd": "cpp",
-        "iostream": "cpp",
-        "istream": "cpp",
-        "mutex": "cpp",
-        "new": "cpp",
-        "ostream": "cpp",
-        "shared_mutex": "cpp",
-        "sstream": "cpp",
-        "stdexcept": "cpp",
-        "streambuf": "cpp",
-        "thread": "cpp",
-        "cfenv": "cpp",
-        "cinttypes": "cpp",
-        "typeindex": "cpp",
-        "typeinfo": "cpp",
-        "valarray": "cpp",
-        "variant": "cpp",
-    }
-}
\ No newline at end of file
- "vector": "cpp", - "exception": "cpp", - "algorithm": "cpp", - "filesystem": "cpp", - "functional": "cpp", - "iterator": "cpp", - "map": "cpp", - "memory": "cpp", - "memory_resource": "cpp", - "numeric": "cpp", - "optional": "cpp", - "random": "cpp", - "ratio": "cpp", - "set": "cpp", - "string": "cpp", - "string_view": "cpp", - "system_error": "cpp", - "tuple": "cpp", - "type_traits": "cpp", - "utility": "cpp", - "fstream": "cpp", - "future": "cpp", - "initializer_list": "cpp", - "iomanip": "cpp", - "iosfwd": "cpp", - "iostream": "cpp", - "istream": "cpp", - "mutex": "cpp", - "new": "cpp", - "ostream": "cpp", - "shared_mutex": "cpp", - "sstream": "cpp", - "stdexcept": "cpp", - "streambuf": "cpp", - "thread": "cpp", - "cfenv": "cpp", - "cinttypes": "cpp", - "typeindex": "cpp", - "typeinfo": "cpp", - "valarray": "cpp", - "variant": "cpp", - } -} \ No newline at end of file From 5fb50b7a04be7d9ae7c036a1d270df54dce79d25 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Sun, 27 Mar 2022 16:42:36 +0000 Subject: [PATCH 74/77] re-clang format check --- test/tensor_reorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tensor_reorder.cpp b/test/tensor_reorder.cpp index 7045bd1d82..d1e2e1d1ce 100644 --- a/test/tensor_reorder.cpp +++ b/test/tensor_reorder.cpp @@ -401,7 +401,7 @@ struct tensor_reorder_driver : tensor_reorder_base_driver tensor t_dst(tensor_len, tensor_strides); tensor t_dst_gpu(tensor_len, tensor_strides); rand_tensor_integer(t_src); - + miopen::ExecutionContext ctx; ctx.SetStream(&miopen::deref(this->handle)); ctx.DetectRocm(); From 5cc064cb1af0aaa723fa7c5f4c594ce85621ea67 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Mon, 28 Mar 2022 03:21:06 +0000 Subject: [PATCH 75/77] fix opencl tidy --- src/include/miopen/tensor_reorder_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 05ca4ba9cb..4b20c4aa7b 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -179,7 +179,7 @@ MakeTensorReorderAttributes(const ExecutionContext& ctx_, uint32_t order_3_) { std::unique_ptr default_ptr; - if(ctx_.use_hip_kernels == false) + if(!ctx_.use_hip_kernels) { return default_ptr; } From f5411ca08045f8ceb5df5cfa2b926cee0eb52d37 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 29 Mar 2022 09:37:37 +0000 Subject: [PATCH 76/77] bug fix --- src/include/miopen/tensor_reorder_util.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/include/miopen/tensor_reorder_util.hpp b/src/include/miopen/tensor_reorder_util.hpp index 4b20c4aa7b..1010b98e95 100644 --- a/src/include/miopen/tensor_reorder_util.hpp +++ b/src/include/miopen/tensor_reorder_util.hpp @@ -237,6 +237,7 @@ MakeTensorReorderAttributes(const ExecutionContext& ctx_, ctx_, data_type_, dim_0_, dim_1_, dim_2_, dim_3_); default: MIOPEN_THROW("Unsupported reorder sequence"); break; } + return default_ptr; } } // namespace miopen From 321086793a45ca2b4776941ef655acaaadb3ddaf Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Wed, 30 Mar 2022 02:36:01 +0000 Subject: [PATCH 77/77] header file fix --- src/include/miopen/general_tensor_reorder_sol.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp index cac41370e9..41f387f995 100644 --- a/src/include/miopen/general_tensor_reorder_sol.hpp +++ b/src/include/miopen/general_tensor_reorder_sol.hpp @@ 
From 321086793a45ca2b4776941ef655acaaadb3ddaf Mon Sep 17 00:00:00 2001
From: aska-0096
Date: Wed, 30 Mar 2022 02:36:01 +0000
Subject: [PATCH 77/77] header file fix

---
 src/include/miopen/general_tensor_reorder_sol.hpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/include/miopen/general_tensor_reorder_sol.hpp b/src/include/miopen/general_tensor_reorder_sol.hpp
index cac41370e9..41f387f995 100644
--- a/src/include/miopen/general_tensor_reorder_sol.hpp
+++ b/src/include/miopen/general_tensor_reorder_sol.hpp
@@ -29,7 +29,7 @@
 #include
 #include
 #include
-#include
+#include
 #include

 namespace miopen {