From c7e0d377ec8d0b7c348bd81d81a64a8d73fe3647 Mon Sep 17 00:00:00 2001 From: mentat <108366729+bghimireamd@users.noreply.github.com> Date: Thu, 5 Oct 2023 09:38:20 -0500 Subject: [PATCH 1/2] bg/lwpmiopen 193 : Integrate CK's batch norm backward training into non-tunable MIOpen solver (#2385) --- src/CMakeLists.txt | 2 + src/batch_norm_api.cpp | 7 - src/include/miopen/batchnorm/solvers.hpp | 20 + .../miopen/solver/implicitgemm_ck_util.hpp | 65 +- src/ocl/batchnormocl.cpp | 9 +- src/solver.cpp | 2 + src/solver/batchnorm/backward_ck.cpp | 251 ++++++ src/solver/batchnorm/forward_training_ck.cpp | 239 ++++++ test/bn_spatial_nhwc_test.cpp | 749 ------------------ test/fusionHost.hpp | 31 +- test/gtest/bn.hpp | 171 ++++ test/gtest/bn_bwd.cpp | 73 ++ test/gtest/bn_fwd_train.cpp | 73 ++ test/gtest/bn_infer.cpp | 8 +- test/gtest/bn_test_data.hpp | 223 +++++- test/gtest/test_operations.hpp | 35 + 16 files changed, 1164 insertions(+), 794 deletions(-) create mode 100644 src/solver/batchnorm/backward_ck.cpp create mode 100644 src/solver/batchnorm/forward_training_ck.cpp delete mode 100644 test/bn_spatial_nhwc_test.cpp create mode 100644 test/gtest/bn_bwd.cpp create mode 100644 test/gtest/bn_fwd_train.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 71289c8b42..abc0679a8a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -152,6 +152,7 @@ set( MIOpen_Source solver/activ/bwd_1.cpp solver/activ/fwd_0.cpp solver/activ/fwd_1.cpp + solver/batchnorm/backward_ck.cpp solver/batchnorm/backward_per_activation.cpp solver/batchnorm/backward_per_activation_fused.cpp solver/batchnorm/backward_spatial_multiple.cpp @@ -163,6 +164,7 @@ set( MIOpen_Source solver/batchnorm/forward_per_activation_fused.cpp solver/batchnorm/forward_spatial_multiple.cpp solver/batchnorm/forward_spatial_single.cpp + solver/batchnorm/forward_training_ck.cpp solver/conv_asm_1x1u.cpp solver/conv_asm_1x1u_bias_activ_fused.cpp solver/conv_asm_1x1u_stride2.cpp diff --git a/src/batch_norm_api.cpp b/src/batch_norm_api.cpp index 03db138945..69454b185a 100644 --- a/src/batch_norm_api.cpp +++ b/src/batch_norm_api.cpp @@ -243,13 +243,6 @@ miopenBatchNormalizationBackward(miopenHandle_t handle, const void* savedMean, const void* savedInvVariance) { - // bfloat16 not supported for batchnorm operation - if(miopen::deref(xDesc).GetType() == miopenBFloat16 || - miopen::deref(dyDesc).GetType() == miopenBFloat16 || - miopen::deref(dxDesc).GetType() == miopenBFloat16) - { - return miopenStatusNotImplemented; - } MIOPEN_LOG_FUNCTION(handle, bn_mode, diff --git a/src/include/miopen/batchnorm/solvers.hpp b/src/include/miopen/batchnorm/solvers.hpp index c7d050abeb..70d64bb204 100644 --- a/src/include/miopen/batchnorm/solvers.hpp +++ b/src/include/miopen/batchnorm/solvers.hpp @@ -142,6 +142,26 @@ struct BnCKFwdInference final : BatchnormSolver const miopen::batchnorm::ProblemDescription& problem) const override; }; +struct BnCKBwdBackward final : BatchnormSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; +}; + +struct BnCKFwdTraining final : BatchnormSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const 
override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::batchnorm::ProblemDescription& problem) const override; +}; + } // namespace batchnorm } // namespace solver diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp index 8656bdbabc..318d970170 100644 --- a/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -41,8 +41,10 @@ typename ConvPtrsType::iterator FindConvPtrByID(ConvPtrsType& conv_ptrs, }); } -template -std::vector FillValidKernelsIDs(const ProblemDescription& problem) +template +std::vector FillValidKernelsIDs(const ProblemDescriptionType& problem) { const auto args = CKArgsType{problem}; const auto conv_ptrs = DeviceOpType::GetInstances(); @@ -59,8 +61,10 @@ std::vector FillValidKernelsIDs(const ProblemDescription& problem) return valid_kernels; } -template -bool IsCKArgsSupported(const ProblemDescription& problem, const std::string& kernel_id) +template +bool IsCKArgsSupported(const ProblemDescriptionType& problem, const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); @@ -68,20 +72,25 @@ bool IsCKArgsSupported(const ProblemDescription& problem, const std::string& ker return (ptr_iter != conv_ptrs.end()) && CKArgsType{problem}.IsSupportedBy(*ptr_iter); } -template -bool IsCKApplicable(const ProblemDescription& problem) +template +bool IsCKApplicable(const ProblemDescriptionType& problem) { const auto args = CKArgsType{problem}; - if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; })) - return false; + // if(!std::all_of(args.strides.begin(), args.strides.end(), [](auto x) { return x == 1; })) + // return false; const auto ptrs = DeviceOpType::GetInstances(); return std::any_of( ptrs.begin(), ptrs.end(), [&args](auto& ptr) { return args.IsSupportedBy(ptr); }); } -template -ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::string& kernel_id) +template +ConvSolution InitInvokerFactory(const ProblemDescriptionType& problem, const std::string& kernel_id) { auto conv_ptrs = DeviceOpType::GetInstances(); auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); @@ -112,5 +121,41 @@ ConvSolution InitInvokerFactory(const ProblemDescription& problem, const std::st return result; } +template +ConvSolution InitAnyInvokerFactory(const ProblemDescriptionType& problem, + const std::string& kernel_id) +{ + auto conv_ptrs = DeviceOpType::GetInstances(); + auto ptr_iter = FindConvPtrByID(conv_ptrs, kernel_id); + + if(ptr_iter == conv_ptrs.end()) + return {miopenStatusInvalidValue}; + + ConvSolution result; + result.invoker_factory = + [ck_args = CKArgsType{problem}, + sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}](const std::vector&) mutable { + return [ck_args = std::move(ck_args), sh_conv_ptr = std::move(sh_conv_ptr)]( + const Handle& handle, const AnyInvokeParams& primitive_parameters) { + const auto& data_ctx = primitive_parameters.CastTo(); + auto argument_ptr = ck_args.MakeArgPtr(sh_conv_ptr, data_ctx); + auto invoker_ptr = sh_conv_ptr->MakeInvokerPointer(); + + const auto enable_profiling = handle.IsProfilingEnabled(); + float elapsed_time = + invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling}); + if(enable_profiling) + { + handle.ResetKernelTime(); + handle.AccumKernelTime(elapsed_time); + } + }; + }; + return result; +} + } // namespace solver } // namespace miopen 
diff --git a/src/ocl/batchnormocl.cpp b/src/ocl/batchnormocl.cpp index 6c8a079a2a..6147a827b8 100644 --- a/src/ocl/batchnormocl.cpp +++ b/src/ocl/batchnormocl.cpp @@ -131,7 +131,8 @@ void BatchNormForwardTraining(Handle& handle, return tmp; }(); - const auto solvers = solver::SolverContainer{}; @@ -300,7 +301,7 @@ void BatchNormBackward(Handle& handle, { MIOPEN_THROW(miopenStatusBadParm); } - if(dxDesc.GetType() != dyDesc.GetType() || dyDesc.GetType() != xDesc.GetType()) + if(dxDesc.GetType() != dyDesc.GetType()) { MIOPEN_THROW(miopenStatusBadParm); } @@ -338,7 +339,6 @@ void BatchNormBackward(Handle& handle, tmp.dx = dx; tmp.bnScale = bnScale; tmp.resultBnScaleDiff = resultBnScaleDiff; - tmp.resultBnScaleDiff = resultBnScaleDiff; tmp.resultBnBiasDiff = resultBnBiasDiff; tmp.epsilon = epsilon; tmp.savedMean = savedMean; @@ -346,7 +346,8 @@ void BatchNormBackward(Handle& handle, return tmp; }(); - const auto solvers = solver::SolverContainer{}; diff --git a/src/solver.cpp b/src/solver.cpp index d83935e646..4cd680dd9c 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -569,6 +569,8 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) RegisterWithSolver( registry, ++id, ConvHipImplicitGemm3DGroupBwdXdlops{}, miopenConvolutionAlgoImplicitGEMM); Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdInference{}.SolverDbId()); + Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKBwdBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Batchnorm, batchnorm::BnCKFwdTraining{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/batchnorm/backward_ck.cpp b/src/solver/batchnorm/backward_ck.cpp new file mode 100644 index 0000000000..fba8724990 --- /dev/null +++ b/src/solver/batchnorm/backward_ck.cpp @@ -0,0 +1,251 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL +#include +#include +#include +#endif +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_CK_BN_BACK) + +namespace miopen { +namespace solver { +namespace batchnorm { +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using index_t = int32_t; + +constexpr index_t Rank = 4; +constexpr index_t NumBatchNormReduceDim = 3; + +using F16 = ck::half_t; +using F32 = float; +using F64 = double; +using BF16 = ushort; + +template +using DeviceOpBNBwdPtrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchNormBwd>; + +struct CKArgsBNormBwd +{ + CKArgsBNormBwd(const miopen::batchnorm::ProblemDescription& problem) + { + std::copy(problem.GetXDesc().GetLengths().begin(), + problem.GetXDesc().GetLengths().end(), + lens.begin()); + + std::copy(problem.GetXDesc().GetStrides().begin(), + problem.GetXDesc().GetStrides().end(), + strides.begin()); + arrScaleBiasMeanVarLengths[0] = lens[1]; // get channel + arrScaleBiasMeanVarStrides[0] = 1; + + // prep for CK + std::sort(strides.begin(), strides.end(), std::greater<>()); + std::rotate(lens.begin() + 1, lens.begin() + 2, lens.end()); + } + + CKArgsBNormBwd(const CKArgsBNormBwd&) = default; + CKArgsBNormBwd(CKArgsBNormBwd&&) = default; + CKArgsBNormBwd& operator=(const CKArgsBNormBwd&) = default; + + template + auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const + { + return invoker_ptr->MakeArgumentPointer(lens, + strides, + strides, + strides, + reduceDims, + arrScaleBiasMeanVarLengths, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + arrScaleBiasMeanVarStrides, + data_ctx.x, + data_ctx.dy, + data_ctx.bnScale, + data_ctx.savedMean, + data_ctx.savedInvVariance, + epsilon, + PassThrough{}, + data_ctx.dx, + data_ctx.resultBnScaleDiff, + data_ctx.resultBnBiasDiff); + } + + template + bool IsSupportedBy(const ConvPtr& invoker_ptr) const + { + auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::BwdInvokeParams{}); + return invoker_ptr->IsSupportedArgument(arg_ptr.get()); + } + + std::array lens; // inOutLengths + std::array strides; // inOutStrides + std::vector invariantDims; + + std::array arrScaleBiasMeanVarLengths; + std::array arrScaleBiasMeanVarStrides; + + double epsilon = 1e-5; + std::array reduceDims{0, 1, 2}; +}; + +template +static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem) +{ + return IsCKApplicable, + CKArgsBNormBwd>(problem); +} + +#endif + +bool BnCKBwdBackward::IsApplicable(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& bn_problem) const +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = ctx; + std::ignore = fdesc_problem; + return false; +#else + if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_BACK{})) + return false; + if(!bn_problem.IsLayoutNHWC()) + return false; + if(!ck_utility::is_ck_supported_hardware(ctx.GetStream())) + return false; + if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType()) + return false; + + switch(bn_problem.GetXDesc().GetType()) + { + case miopenFloat: return CheckCKApplicability(bn_problem); + case miopenDouble: return CheckCKApplicability(bn_problem); + case miopenHalf: return CheckCKApplicability(bn_problem); + case miopenBFloat16: + return 
CheckCKApplicability(bn_problem); + case miopenInt32: + case miopenInt8: + case miopenInt8x4: + case miopenBFloat8: + case miopenFloat8: + default: MIOPEN_THROW("Unsupported datatype"); + } + return false; +#endif +} + +template +ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem) +{ + const auto& valid_kernel_ids = FillValidKernelsIDs, + CKArgsBNormBwd>(bn_problem); + assert(!valid_kernel_ids.empty()); + const auto& kernel_id = valid_kernel_ids[0]; + return InitAnyInvokerFactory, + CKArgsBNormBwd, + miopen::batchnorm::BwdInvokeParams>(bn_problem, kernel_id); +} + +ConvSolution BnCKBwdBackward::GetSolution( + [[maybe_unused]] const ExecutionContext& context, + [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(bn_problem.GetXDesc().GetType()) + { + + case miopenFloat: return MakeAnyInvokerFactory(bn_problem); + case miopenDouble: return MakeAnyInvokerFactory(bn_problem); + case miopenHalf: return MakeAnyInvokerFactory(bn_problem); + case miopenBFloat16: + return MakeAnyInvokerFactory(bn_problem); + case miopenInt8: + case miopenInt32: + case miopenInt8x4: + case miopenBFloat8: + case miopenFloat8: + default: + MIOPEN_THROW(miopenStatusInternalError, "BnCKBwdBackward operation not for this data type"); + } +#endif + return {}; +} + +} // namespace batchnorm +} // namespace solver +} // namespace miopen diff --git a/src/solver/batchnorm/forward_training_ck.cpp b/src/solver/batchnorm/forward_training_ck.cpp new file mode 100644 index 0000000000..a65cec14a9 --- /dev/null +++ b/src/solver/batchnorm/forward_training_ck.cpp @@ -0,0 +1,239 @@ + +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+#include
+#include
+#include
+#endif
+MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING)
+
+namespace miopen {
+namespace solver {
+namespace batchnorm {
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+
+using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
+using index_t = int32_t;
+
+constexpr index_t Rank = 4;
+constexpr index_t NumBatchNormReduceDim = 3;
+
+using F16 = ck::half_t;
+using F32 = float;
+using F64 = double;
+using BF16 = ushort;
+
+template
+using DeviceOpBNFwdTrainingPtrs =
+    ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        ck::tensor_operation::device::DeviceBatchNormFwd>;
+
+struct CKArgsBNormFwdTraining
+{
+    CKArgsBNormFwdTraining(const miopen::batchnorm::ProblemDescription& problem)
+    {
+        std::copy(problem.GetXDesc().GetLengths().begin(),
+                  problem.GetXDesc().GetLengths().end(),
+                  xyLengths.begin());
+
+        std::copy(problem.GetXDesc().GetStrides().begin(),
+                  problem.GetXDesc().GetStrides().end(),
+                  xyStrides.begin());
+        arrScaleBiasMeanVarLengths[0] = xyLengths[1]; // get channel
+        arrScaleBiasMeanVarStrides[0] = 1;
+
+        // prep for CK
+        std::sort(xyStrides.begin(), xyStrides.end(), std::greater<>());
+        std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end());
+    }
+
+    CKArgsBNormFwdTraining(const CKArgsBNormFwdTraining&) = default;
+    CKArgsBNormFwdTraining(CKArgsBNormFwdTraining&&) = default;
+    CKArgsBNormFwdTraining& operator=(const CKArgsBNormFwdTraining&) = default;
+
+    template
+    auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const
+    {
+        return invoker_ptr->MakeArgumentPointer(xyLengths,
+                                                xyStrides,
+                                                xyStrides,
+                                                reduceDims,
+                                                arrScaleBiasMeanVarLengths,
+                                                arrScaleBiasMeanVarStrides,
+                                                arrScaleBiasMeanVarStrides,
+                                                arrScaleBiasMeanVarStrides,
+                                                data_ctx.x,
+                                                data_ctx.bnScale,
+                                                data_ctx.bnBias,
+                                                data_ctx.epsilon,
+                                                PassThroughOp{},
+                                                data_ctx.y,
+                                                data_ctx.resultSaveMean,
+                                                data_ctx.resultSaveInvVariance,
+                                                data_ctx.expAvgFactor,
+                                                data_ctx.resultRunningMean,
+                                                data_ctx.resultRunningVariance);
+    }
+
+    template
+    bool IsSupportedBy(const ConvPtr& invoker_ptr) const
+    {
+        auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::InvokeParams{});
+        return invoker_ptr->IsSupportedArgument(arg_ptr.get());
+    }
+
+    std::array xyLengths;
+    std::array xyStrides;
+    std::vector invariantDims;
+
+    std::array arrScaleBiasMeanVarLengths;
+    std::array arrScaleBiasMeanVarStrides;
+
+    std::array reduceDims{0, 1, 2};
+};
+
+template
+static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem)
+{
+    return IsCKApplicable,
+                          CKArgsBNormFwdTraining>(problem);
+}
+#endif
+
+bool BnCKFwdTraining::IsApplicable(const ExecutionContext& context,
+                                   const miopen::batchnorm::ProblemDescription& bn_problem) const
+{
+#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL
+    std::ignore = context;
+    std::ignore = bn_problem;
+    return false;
+#else
+    if(miopen::IsDisabled(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING{}))
+        return false;
+    if(!bn_problem.IsLayoutNHWC())
+        return false;
+    if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
+        return false;
+
+    switch(bn_problem.GetXDesc().GetType())
+    {
+    case miopenHalf: return CheckCKApplicability(bn_problem);
+    case miopenFloat: return CheckCKApplicability(bn_problem);
+    case miopenDouble: return CheckCKApplicability(bn_problem);
+    case miopenBFloat16: return CheckCKApplicability(bn_problem);
+    case miopenInt32:
+    case miopenInt8:
+    case miopenInt8x4:
+    case miopenBFloat8:
+    case miopenFloat8:
+    default: MIOPEN_THROW("BnCKFwdTraining operation does not support this data type");
+    }
+    return false;
+#endif
+}
+
+template
+ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem)
+{
+    const auto& valid_kernel_ids = FillValidKernelsIDs,
+                                                       CKArgsBNormFwdTraining>(bn_problem);
+    assert(!valid_kernel_ids.empty());
+    const auto& kernel_id = valid_kernel_ids[0];
+    return InitAnyInvokerFactory,
+                                 CKArgsBNormFwdTraining,
+                                 miopen::batchnorm::InvokeParams>(bn_problem, kernel_id);
+}
+
+ConvSolution BnCKFwdTraining::GetSolution(
+    [[maybe_unused]] const ExecutionContext& context,
+    [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const
+{
+#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
+    switch(bn_problem.GetXDesc().GetType())
+    {
+
+    case miopenFloat: return MakeAnyInvokerFactory(bn_problem);
+    case miopenDouble: return MakeAnyInvokerFactory(bn_problem);
+    case miopenHalf: return MakeAnyInvokerFactory(bn_problem);
+    case miopenBFloat16: return MakeAnyInvokerFactory(bn_problem);
+    case miopenInt8:
+    case miopenInt32:
+    case miopenInt8x4:
+    case miopenBFloat8:
+    case miopenFloat8:
+    default:
+        MIOPEN_THROW(miopenStatusInternalError,
+                     "BnCKFwdTraining operation not supported for this data type");
+    }
+#endif
+    return {};
+}
+
+} // namespace batchnorm
+} // namespace solver
+} // namespace miopen
diff --git a/test/bn_spatial_nhwc_test.cpp b/test/bn_spatial_nhwc_test.cpp
deleted file mode 100644
index abca57e7ce..0000000000
--- a/test/bn_spatial_nhwc_test.cpp
+++ /dev/null
@@ -1,749 +0,0 @@
-/*******************************************************************************
- *
- * MIT License
- *
- * Copyright (c) 2021 Advanced Micro Devices, Inc.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- * - *******************************************************************************/ - -#include "driver.hpp" -#include "get_handle.hpp" -#include "tensor_holder.hpp" -#include "test.hpp" -#include "verify.hpp" -#include "random.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define MIO_BN_TEST_EXPAVGFACTOR 0.1 -#define MIO_BN_TEST_EPSILON 1e-5 -#define MIO_BN_USE_MIX_PREC 1 -#if MIO_BN_USE_MIX_PREC == 1 -#define PREC_TYPE float -#else -#define PREC_TYPE T -#endif - -template -struct verify_forward_train_bn_spatial -{ - const tensor input; - const tensor scale; - const tensor shift; - - std::tuple, tensor, tensor, tensor, tensor> cpu() const - { - double epsilon = MIO_BN_TEST_EPSILON; - double expAvgFactor = MIO_BN_TEST_EXPAVGFACTOR; - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(input.desc.GetLengths()); - - std::size_t rs_n_batch, rs_channels, rs_height, rs_width; - auto derivedBnDesc = - miopen::TensorDescriptor(input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(rs_n_batch, rs_height, rs_width, rs_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - tensor runMean; - tensor runVar; - if(input.desc.GetType() == miopenFloat) - { - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - } - else - { - prng::reset_seed(); - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - - const U Data_scale = static_cast(0.001); - for(std::size_t i = 0; i < runMean.desc.GetElementSize(); i++) - { - runMean[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - runVar[i] = prng::gen_descreet_unsigned(Data_scale, 100); - } - } - auto saveMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - auto saveInvVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - auto out = input; - std::fill(out.begin(), out.end(), 0); - - const auto nhw = double(height * width * n_batch); - par_for(channels, 1, [&](int cidx) { - double elemStd = 0.; - double variance_accum = 0.; - double mean_accum = 0.; - double invVar = 0.; - double newRunMean = 0.; - double adjust = 0.; - - std::vector variance_accum_arr(height, 0.0); - std::vector mean_accum_arr(height, 0.0); - std::vector dshift_accum_arr(height, 0.0); - std::vector dscale_accum_arr(height, 0.0); - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - mean_accum_arr[row] += input(bidx, cidx, row, column); - } - } - } - for(std::size_t i = 0; i < height; i++) - mean_accum += mean_accum_arr[i]; - - mean_accum /= nhw; - - elemStd = 0.; - variance_accum = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - out(bidx, cidx, row, column) = elemStd = - input(bidx, cidx, row, column) - mean_accum; - variance_accum_arr[row] += elemStd * elemStd; - } - } - } - for(std::size_t i = 0; i < height; i++) - variance_accum += variance_accum_arr[i]; - - variance_accum /= nhw; - invVar = 1.0 / sqrt(variance_accum + epsilon); - - for(std::size_t bidx = 0; bidx < 
n_batch; bidx++) - { - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - out(bidx, cidx, row, column) = - scale(0, 0, 0, cidx) * (invVar * out(bidx, cidx, row, column)) + - shift(0, 0, 0, cidx); - } - } - } - - saveMean(0, 0, 0, cidx) = mean_accum; - saveInvVar(0, 0, 0, cidx) = invVar; - - newRunMean = runMean(0, 0, 0, cidx) * (1 - expAvgFactor); - runMean(0, 0, 0, cidx) = mean_accum * expAvgFactor + newRunMean; - adjust = (n_batch * height * width == 1) ? variance_accum - : (nhw / (nhw - 1)) * variance_accum; - runVar(0, 0, 0, cidx) = - (1 - expAvgFactor) * runVar(0, 0, 0, cidx) + expAvgFactor * adjust; - }); - - return std::make_tuple(out, runMean, runVar, saveMean, saveInvVar); - } - - std::tuple, tensor, tensor, tensor, tensor> gpu() const - { - auto&& handle = get_handle(); - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(input.desc.GetLengths()); - - auto out = input; - std::fill(out.begin(), out.end(), 0); - - std::size_t rs_n_batch, rs_channels, rs_height, rs_width; - auto derivedBnDesc = - miopen::TensorDescriptor(input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(rs_n_batch, rs_height, rs_width, rs_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - tensor runMean; - tensor runVar; - if(input.desc.GetType() == miopenFloat) - { - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}.generate( - tensor_elem_gen_integer{17}); - } - else - { - prng::reset_seed(); - runMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - runVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - - const U Data_scale = static_cast(0.001); - for(std::size_t i = 0; i < runMean.desc.GetElementSize(); i++) - { - runMean[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - runVar[i] = prng::gen_descreet_unsigned(Data_scale, 100); - } - } - - auto saveMean = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - auto saveInvVar = tensor{rs_n_batch, rs_height, rs_width, rs_channels}; - - auto in_dev = handle.Write(input.data); - auto scale_dev = handle.Write(scale.data); - auto shift_dev = handle.Write(shift.data); - - auto runMean_dev = handle.Write(runMean.data); - auto runVar_dev = handle.Write(runVar.data); - auto saveMean_dev = handle.Create(channels); - auto saveInvVar_dev = handle.Create(channels); - auto out_dev = handle.Create(n_batch * channels * height * width); - - double epsilon = MIO_BN_TEST_EPSILON; - double expAvgFactor = MIO_BN_TEST_EXPAVGFACTOR; - - float alpha = 1.0; - float beta = 0.0; - - miopen::BatchNormForwardTraining(handle, - miopenBNSpatial, - &alpha, - &beta, - input.desc, - in_dev.get(), - out.desc, - out_dev.get(), - scale.desc, - scale_dev.get(), - shift_dev.get(), - expAvgFactor, - runMean_dev.get(), - runVar_dev.get(), - epsilon, - saveMean_dev.get(), - saveInvVar_dev.get()); - - saveMean.data = handle.Read(saveMean_dev, saveMean.data.size()); - saveInvVar.data = handle.Read(saveInvVar_dev, saveInvVar.data.size()); - runMean.data = handle.Read(runMean_dev, runMean.data.size()); - runVar.data = handle.Read(runVar_dev, runVar.data.size()); - out.data = handle.Read(out_dev, out.data.size()); - - return std::make_tuple(out, runMean, runVar, saveMean, saveInvVar); - } - - void fail(int badtensor) const - { - std::cout << "Forward Train Spatial Batch 
Normalization: " << std::endl; - std::cout << "Input tensor: " << input.desc.ToString() << std::endl; - - switch(badtensor) - { - case(0): std::cout << "Output tensor output failed verification." << std::endl; break; - case(1): std::cout << "Running Mean output tensor failed verification." << std::endl; break; - case(2): - std::cout << "Running Variance output tensor failed verification." << std::endl; - break; - case(3): std::cout << "Saved Mean tensor failed verification." << std::endl; break; - case(4): std::cout << "Saved Variance tensor failed verification." << std::endl; break; - default: break; - } - } -}; - -template -struct verify_backward_bn_spatial_recalc -{ - const tensor x_input; - const tensor dy_input; - const tensor scale; - - std::tuple, tensor, tensor> cpu() const - { - double epsilon = MIO_BN_TEST_EPSILON; - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - const auto nhw = double(height * width * n_batch); - - par_for(channels, 1, [&](int cidx) { - double elemStd = 0.; - unsigned int xhat_index; - double mean = 0.; - double invVar = 0.; - double dyelem = 0.; - double variance = 0.; - - std::vector xhat(height * width * n_batch, 0.0); - std::vector variance_accum_arr(height, 0.0); - std::vector mean_accum_arr(height, 0.0); - std::vector dshift_accum_arr(height, 0.0); - std::vector dscale_accum_arr(height, 0.0); - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - mean_accum_arr[row] += x_input(bidx, cidx, row, column); - } - } - } - for(std::size_t i = 0; i < height; i++) - mean += mean_accum_arr[i]; - - mean /= nhw; - - elemStd = 0.; - variance = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - elemStd = x_input(bidx, cidx, row, column) - mean; - variance_accum_arr[row] += elemStd * elemStd; - } - } - } - for(std::size_t i = 0; i < height; i++) - variance += variance_accum_arr[i]; - - variance /= nhw; - invVar = 1. 
/ double(sqrt(variance + epsilon)); - - dscale(0, cidx, 0, 0) = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - elemStd = x_input(bidx, cidx, row, column) - mean; - xhat[xhat_index] = elemStd * invVar; - dyelem = dy_input(bidx, cidx, row, column); - dshift_accum_arr[row] += dyelem; - dscale_accum_arr[row] += xhat[xhat_index] * dyelem; - } - } - } - for(std::size_t i = 0; i < height; i++) - { - dshift(0, cidx, 0, 0) += dshift_accum_arr[i]; - dscale(0, cidx, 0, 0) += dscale_accum_arr[i]; - } - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - - double tmp1 = - nhw * dy_input(bidx, cidx, row, column) - dshift(0, cidx, 0, 0); - double tmp2 = -xhat[xhat_index] * dscale(0, cidx, 0, 0); - double tmp3 = (scale(0, 0, 0, cidx) * invVar) / nhw; - dx_out(bidx, cidx, row, column) = tmp3 * (tmp2 + tmp1); - } - } - } - }); - - return std::make_tuple(dx_out, dscale, dshift); - } - - std::tuple, tensor, tensor> gpu() const - { - auto&& handle = get_handle(); - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - float alpha = 1.0; - float beta = 0.0; - - auto xin_dev = handle.Write(x_input.data); - auto dyin_dev = handle.Write(dy_input.data); - auto scale_dev = handle.Write(scale.data); - auto dscale_dev = handle.Write(dscale.data); - auto dshift_dev = handle.Write(dshift.data); - auto dx_out_dev = handle.Write(dx_out.data); - - double epsilon = MIO_BN_TEST_EPSILON; - - miopen::BatchNormBackward(handle, - miopenBNSpatial, - &alpha, - &beta, - &alpha, - &beta, - x_input.desc, - xin_dev.get(), - dy_input.desc, - dyin_dev.get(), - dx_out.desc, - dx_out_dev.get(), - scale.desc, - scale_dev.get(), - dscale_dev.get(), - dshift_dev.get(), - epsilon, - nullptr, - nullptr); - - dx_out.data = handle.Read(dx_out_dev, dx_out.data.size()); - dscale.data = handle.Read(dscale_dev, dscale.data.size()); - dshift.data = handle.Read(dshift_dev, dshift.data.size()); - - return std::make_tuple(dx_out, dscale, dshift); - } - - void fail(int badtensor) const - { - std::cout << "Backward Batch Spatial Normalization Recalc Mean and Variance: " << std::endl; - std::cout << "X Input tensor: " << x_input.desc.ToString() << std::endl; - std::cout << "Delta Y Input tensor: " << dy_input.desc.ToString() << std::endl; - switch(badtensor) - { - case(0): - std::cout << "Delta X output tensor output failed verification." << std::endl; - break; - case(1): std::cout << "Delta scale output tensor failed verification." 
<< std::endl; break; - case(2): std::cout << "Delta shift output tensor failed verification." << std::endl; break; - default: break; - } - } -}; - -template -struct verify_backward_bn_spatial_use_saved -{ - const tensor x_input; - const tensor dy_input; - const tensor scale; - const tensor savedMean; - const tensor savedInvVar; - std::tuple, tensor, tensor> cpu() const - { - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - const auto nhw = double(height * width * n_batch); - - par_for(channels, 1, [&](int cidx) { - double elemStd = 0.; - unsigned int xhat_index; - double mean = savedMean(0, 0, 0, cidx); - double invVar = savedInvVar(0, 0, 0, cidx); - double dyelem = 0.; - - std::vector xhat(n_batch * height * width, 0.0); - std::vector dshift_accum_arr(height, 0.0); - std::vector dscale_accum_arr(height, 0.0); - dscale(0, cidx, 0, 0) = 0.; - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - elemStd = x_input(bidx, cidx, row, column) - mean; - xhat[xhat_index] = elemStd * invVar; - dyelem = dy_input(bidx, cidx, row, column); - dshift_accum_arr[row] += dyelem; - dscale_accum_arr[row] += xhat[xhat_index] * dyelem; - } - } - } - for(std::size_t i = 0; i < height; i++) - { - dshift(0, cidx, 0, 0) += dshift_accum_arr[i]; - dscale(0, cidx, 0, 0) += dscale_accum_arr[i]; - } - - for(std::size_t row = 0; row < height; row++) - { - for(std::size_t column = 0; column < width; column++) - { - for(std::size_t bidx = 0; bidx < n_batch; bidx++) - { - xhat_index = height * width * bidx + (width * row + column); - - double tmp1 = - nhw * dy_input(bidx, cidx, row, column) - dshift(0, cidx, 0, 0); - double tmp2 = -xhat[xhat_index] * dscale(0, cidx, 0, 0); - double tmp3 = (scale(0, 0, 0, cidx) * invVar) / nhw; - dx_out(bidx, cidx, row, column) = tmp3 * (tmp2 + tmp1); - } - } - } - }); - - return std::make_tuple(dx_out, dscale, dshift); - } - - std::tuple, tensor, tensor> gpu() const - { - auto&& handle = get_handle(); - - std::size_t n_batch, channels, height, width; - std::tie(n_batch, channels, height, width) = miopen::tien<4>(x_input.desc.GetLengths()); - - auto dx_out = dy_input; - std::fill(dx_out.begin(), dx_out.end(), 0); - - std::size_t ss_n_batch, ss_channels, ss_height, ss_width; - auto derivedBnDesc = - miopen::TensorDescriptor(x_input.desc.GetType(), - std::vector{1, 1, 1, channels}, - std::vector{channels, channels, channels, 1}); - std::tie(ss_n_batch, ss_height, ss_width, ss_channels) = - miopen::tien<4>(derivedBnDesc.GetLengths()); - - auto dscale = tensor{ss_n_batch, ss_channels, ss_height, ss_width}; - std::fill(dscale.begin(), dscale.end(), 0); - - auto dshift = tensor{ss_n_batch, ss_channels, 
ss_height, ss_width}; - std::fill(dshift.begin(), dshift.end(), 0); - - float alpha = 1.0; - float beta = 0.0; - - auto xin_dev = handle.Write(x_input.data); - auto dyin_dev = handle.Write(dy_input.data); - auto scale_dev = handle.Write(scale.data); - auto dscale_dev = handle.Write(dscale.data); - auto dshift_dev = handle.Write(dshift.data); - auto dx_out_dev = handle.Write(dx_out.data); - auto savedMean_dev = handle.Write(savedMean.data); - auto savedInvVar_dev = handle.Write(savedInvVar.data); - - double epsilon = MIO_BN_TEST_EPSILON; - - miopen::BatchNormBackward(handle, - miopenBNSpatial, - &alpha, - &beta, - &alpha, - &beta, - x_input.desc, - xin_dev.get(), - dy_input.desc, - dyin_dev.get(), - dx_out.desc, - dx_out_dev.get(), - scale.desc, - scale_dev.get(), - dscale_dev.get(), - dshift_dev.get(), - epsilon, - savedMean_dev.get(), - savedInvVar_dev.get()); - - dx_out.data = handle.Read(dx_out_dev, dx_out.data.size()); - dscale.data = handle.Read(dscale_dev, dscale.data.size()); - dshift.data = handle.Read(dshift_dev, dshift.data.size()); - - return std::make_tuple(dx_out, dscale, dshift); - } - - void fail(int badtensor) const - { - std::cout << "Backward Batch Spatial Normalization Use Saved Mean and Variance: " - << std::endl; - std::cout << "X Input tensor: " << x_input.desc.ToString() << std::endl; - std::cout << "Delta Y Input tensor: " << dy_input.desc.ToString() << std::endl; - switch(badtensor) - { - case(0): - std::cout << "Delta X output tensor output failed verification." << std::endl; - break; - case(1): std::cout << "Delta scale output tensor failed verification." << std::endl; break; - case(2): std::cout << "Delta shift output tensor failed verification." << std::endl; break; - default: break; - } - } -}; - -template -struct batch_norm_spatial_nhwc_driver : test_driver -{ - tensor input; - tensor scale; - tensor shift; - batch_norm_spatial_nhwc_driver() - { - this->batch_factor = 4; - add(input, - "input", - get_bn_spatial_input_tensor( - tensor_elem_gen_integer{miopen_type{} == miopenHalf ? 
5 : 17})); - } - - void run() - { - std::size_t n, c, h, w; - std::tie(n, c, h, w) = miopen::tien<4>(input.desc.GetLengths()); - - std::size_t ssn, ssc, ssh, ssw; - auto derivedBnDesc = miopen::TensorDescriptor(input.desc.GetType(), - std::vector{1, 1, 1, c}, - std::vector{c, c, c, 1}); - std::tie(ssn, ssh, ssw, ssc) = miopen::tien<4>(derivedBnDesc.GetLengths()); - - std::vector new_len = input.desc.GetLengths(); - std::vector new_str; - miopen::tensor_layout_to_strides(new_len, "NCHW", "NHWC", new_str); - input.desc = miopen::TensorDescriptor(miopen_type{}, new_len, new_str); - - if(input.desc.GetType() == miopenFloat) - { - scale = tensor{ssn, ssh, ssw, ssc}.generate(tensor_elem_gen_integer{17}); - shift = tensor{ssn, ssh, ssw, ssc}.generate(tensor_elem_gen_integer{17}); - } - else - { - scale = tensor{ssn, ssh, ssw, ssc}; - shift = tensor{ssn, ssh, ssw, ssc}; - - const PREC_TYPE Data_scale = static_cast(1e-4); - for(std::size_t i = 0; i < scale.desc.GetElementSize(); i++) - { - scale[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - shift[i] = prng::gen_descreet_uniform_sign(Data_scale, 100); - } - for(std::size_t i = 0; i < input.desc.GetElementSize(); i++) - { - input[i] = prng::gen_descreet_uniform_sign(static_cast(1e-5), 100); - } - } - - auto outpair = verify(verify_forward_train_bn_spatial{input, scale, shift}); - - auto dy_input = std::get<0>(outpair.second); - for(std::size_t bidx = 0; bidx < n; bidx++) - { - for(std::size_t cidx = 0; cidx < c; cidx++) - { - for(std::size_t row = 0; row < h; row++) - { - for(std::size_t column = 0; column < w; column++) - { - dy_input(bidx, cidx, row, column) *= 0.1; - } - } - } - } - this->tolerance = 80 * input.desc.GetElementSize(); - verify(verify_backward_bn_spatial_recalc{input, dy_input, scale}); - - auto savedMean = std::get<3>(outpair.second); - auto savedInvVar = std::get<4>(outpair.second); - verify(verify_backward_bn_spatial_use_saved{ - input, dy_input, scale, savedMean, savedInvVar}); - } -}; - -int main(int argc, const char* argv[]) -{ - test_drive(argc, argv); - return 0; -} diff --git a/test/fusionHost.hpp b/test/fusionHost.hpp index cffefea0e2..5374abd1fa 100644 --- a/test/fusionHost.hpp +++ b/test/fusionHost.hpp @@ -36,7 +36,6 @@ #include #include #include -// #include "driver.hpp" #include "get_handle.hpp" #include "tensor_holder.hpp" #include "verify.hpp" @@ -203,17 +202,17 @@ void batchNormPerActivHostInference(const tensor& input, }); } -template +template void batchNormSpatialHostFwdTrain(const tensor& input, tensor& out, const tensor& scale, const tensor& bias, double epsilon, double expAvgFactor, - tensor& saveMean, - tensor& saveInvVar, - tensor& runMean, - tensor& runVar) + tensor& saveMean, + tensor& saveInvVar, + tensor& runMean, + tensor& runVar) { int height, width, n_batch, channels; @@ -279,15 +278,15 @@ void batchNormSpatialHostFwdTrain(const tensor& input, }); } -template -void batchNormSpatialHostBwdTrain(const tensor& x_input, - const tensor& dy_input, - tensor& dx_out, - const tensor& scale, - tensor& dscale, - tensor& dbias, - const tensor& savedMean, - const tensor& savedInvVar) +template +void batchNormSpatialHostBwdTrain(const tensor& x_input, + const tensor& dy_input, + tensor& dx_out, + const tensor& scale, + tensor& dscale, + tensor& dbias, + const tensor& savedMean, + const tensor& savedInvVar) { int height, width, n_batch, channels; @@ -335,7 +334,7 @@ void batchNormSpatialHostBwdTrain(const tensor& x_input, double tmp1 = nhw * dy_input(bidx, cidx, row, column) - dbias(0, cidx, 0, 0); double 
tmp2 = -xhat[xhat_index] * dscale(0, cidx, 0, 0); double tmp3 = (scale(0, cidx, 0, 0) * invVar) / nhw; - dx_out(bidx, cidx, row, column) = static_cast(tmp3 * (tmp2 + tmp1)); + dx_out(bidx, cidx, row, column) = static_cast(tmp3 * (tmp2 + tmp1)); } // end for(n_batchs) } // for (column) } // for (row) diff --git a/test/gtest/bn.hpp b/test/gtest/bn.hpp index 0b763da411..22f8391fe6 100644 --- a/test/gtest/bn.hpp +++ b/test/gtest/bn.hpp @@ -84,3 +84,174 @@ struct BNInferTest : public ::testing::TestWithParam +struct BNBwdTest : public ::testing::TestWithParam> +{ +protected: + void SetUp() override + { + test_skipped = false; + std::tie(bn_config, tensor_layout) = GetParam(); + bn_bwd_test_data.SetUpImpl(bn_config, tensor_layout); + + auto&& handle = get_handle(); + miopenBatchNormalizationBackward(&handle, + bn_config.mode, + &bn_bwd_test_data.alphaDataDiff, + &bn_bwd_test_data.betaDataDiff, + &bn_bwd_test_data.alphaParamDiff, + &bn_bwd_test_data.betaParamDiff, + &bn_bwd_test_data.input.desc, + bn_bwd_test_data.in_dev.get(), + &bn_bwd_test_data.dy.desc, + bn_bwd_test_data.dy_dev.get(), + &bn_bwd_test_data.output.desc, + bn_bwd_test_data.out_dev.get(), + &bn_bwd_test_data.bnScale.desc, + bn_bwd_test_data.bnScale_dev.get(), + bn_bwd_test_data.dScale_dev.get(), + bn_bwd_test_data.dBias_dev.get(), + bn_bwd_test_data.epsilon, + bn_bwd_test_data.savedMean_dev.get(), + bn_bwd_test_data.savedInvVar_dev.get()); + + std::fill(bn_bwd_test_data.output.begin(), + bn_bwd_test_data.output.end(), + std::numeric_limits::quiet_NaN()); + } + + void TearDown() override + { + if(test_skipped) + return; + auto&& handle = get_handle(); + bn_bwd_test_data.output.data = + handle.Read(bn_bwd_test_data.out_dev, bn_bwd_test_data.output.data.size()); + bn_bwd_test_data.dScale.data = handle.Read(bn_bwd_test_data.dScale_dev, + bn_bwd_test_data.dScale.data.size()); + bn_bwd_test_data.dBias.data = + handle.Read(bn_bwd_test_data.dBias_dev, bn_bwd_test_data.dBias.data.size()); + + test::ComputeCPUBNBwd(bn_bwd_test_data); + + // using tolerance = 1e-4 since this the tolerance CK uses + test::CompareTensor(bn_bwd_test_data.output, bn_bwd_test_data.ref_out, 1e-4); + test::CompareTensor(bn_bwd_test_data.dScale, bn_bwd_test_data.dScale_ref, 1e-4); + test::CompareTensor(bn_bwd_test_data.dBias, bn_bwd_test_data.dBias_ref, 1e-4); + } + + BNTestCase bn_config; + bool test_skipped = false; + BNBwdTestData + bn_bwd_test_data; + miopenTensorLayout_t tensor_layout; +}; + +template +struct BNFwdTrainTest + : public ::testing::TestWithParam> +{ +protected: + void SetUp() override + { + test_skipped = false; + std::tie(bn_config, tensor_layout) = GetParam(); + bn_fwd_train_test_data.SetUpImpl(bn_config, tensor_layout); + + auto&& handle = get_handle(); + miopenBatchNormalizationForwardTraining(&handle, + bn_config.mode, + &bn_fwd_train_test_data.alpha, + &bn_fwd_train_test_data.beta, + &bn_fwd_train_test_data.input.desc, + bn_fwd_train_test_data.in_dev.get(), + &bn_fwd_train_test_data.output.desc, + bn_fwd_train_test_data.out_dev.get(), + &bn_fwd_train_test_data.scale.desc, + bn_fwd_train_test_data.scale_dev.get(), + bn_fwd_train_test_data.shift_dev.get(), + bn_fwd_train_test_data.averageFactor, + bn_fwd_train_test_data.runMean_dev.get(), + bn_fwd_train_test_data.runVariance_dev.get(), + bn_fwd_train_test_data.epsilon, + bn_fwd_train_test_data.saveMean_dev.get(), + bn_fwd_train_test_data.saveVariance_dev.get()); + + std::fill(bn_fwd_train_test_data.output.begin(), + bn_fwd_train_test_data.output.end(), + std::numeric_limits::quiet_NaN()); + 
std::fill(bn_fwd_train_test_data.saveMean_ref.begin(), + bn_fwd_train_test_data.saveMean_ref.end(), + std::numeric_limits::quiet_NaN()); + std::fill(bn_fwd_train_test_data.saveVariance_ref.begin(), + bn_fwd_train_test_data.saveVariance_ref.end(), + std::numeric_limits::quiet_NaN()); + } + + void TearDown() override + { + if(test_skipped) + return; + auto&& handle = get_handle(); + bn_fwd_train_test_data.output.data = handle.Read( + bn_fwd_train_test_data.out_dev, bn_fwd_train_test_data.output.data.size()); + + bn_fwd_train_test_data.saveMean.data = handle.Read( + bn_fwd_train_test_data.saveMean_dev, bn_fwd_train_test_data.saveMean.data.size()); + bn_fwd_train_test_data.saveVariance.data = + handle.Read(bn_fwd_train_test_data.saveVariance_dev, + bn_fwd_train_test_data.saveVariance_ref.data.size()); + bn_fwd_train_test_data.runMean.data = handle.Read( + bn_fwd_train_test_data.runMean_dev, bn_fwd_train_test_data.runMean_ref.data.size()); + bn_fwd_train_test_data.runVariance.data = + handle.Read(bn_fwd_train_test_data.runVariance_dev, + bn_fwd_train_test_data.runVariance_ref.data.size()); + test::ComputeCPUBNFwdTrain(bn_fwd_train_test_data); + + // 4e-3 is tolerance used by CK kernel. + test::CompareTensor( + bn_fwd_train_test_data.output, bn_fwd_train_test_data.ref_out, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.saveMean, bn_fwd_train_test_data.saveMean_ref, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.saveVariance, bn_fwd_train_test_data.saveVariance_ref, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.runMean, bn_fwd_train_test_data.runMean_ref, 4e-3); + test::CompareTensor( + bn_fwd_train_test_data.runVariance, bn_fwd_train_test_data.runVariance_ref, 4e-3); + } + + BNTestCase bn_config; + bool test_skipped = false; + BNFwdTrainTestData + bn_fwd_train_test_data; + miopenTensorLayout_t tensor_layout; +}; diff --git a/test/gtest/bn_bwd.cpp b/test/gtest/bn_bwd.cpp new file mode 100644 index 0000000000..722b42e872 --- /dev/null +++ b/test/gtest/bn_bwd.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include "bn.hpp" + +struct BNBwdTestTestHalf + : BNBwdTest +{ +}; + +struct BNBwdTestFloat : BNBwdTest +{ +}; + +struct BNBwdTestBFloat16 : BNBwdTest +{ +}; + +struct BNBwdTestDouble : BNBwdTest +{ +}; + +TEST_P(BNBwdTestTestHalf, BnBwdCKHalf) {} + +TEST_P(BNBwdTestFloat, BnBwdCKFloat) {} + +// Currently disabled since miopen::batchnorm::MakeForwardTrainingNetworkConfig +// only supports half and float +TEST_P(BNBwdTestBFloat16, DISABLED_BnBwdCKBFloat16) {} +TEST_P(BNBwdTestDouble, DISABLED_BnBwdCKDouble) {} + +INSTANTIATE_TEST_SUITE_P(BNBwdTestTestHalfNHWCSuite, + BNBwdTestTestHalf, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNBwdTestFloatNHWCSuite, + BNBwdTestFloat, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNBwdTestBFloat16NHWCSuite, + BNBwdTestBFloat16, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNBwdTestDoubleNHWCSuite, + BNBwdTestDouble, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); diff --git a/test/gtest/bn_fwd_train.cpp b/test/gtest/bn_fwd_train.cpp new file mode 100644 index 0000000000..4a4dd4c728 --- /dev/null +++ b/test/gtest/bn_fwd_train.cpp @@ -0,0 +1,73 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include "bn.hpp" + +struct BNFwdTrainTestHalf + : BNFwdTrainTest +{ +}; + +struct BNFwdTrainTestFloat : BNFwdTrainTest +{ +}; + +struct BNFwdTrainTestDouble : BNFwdTrainTest +{ +}; + +struct BNFwdTrainTestBFloat16 : BNFwdTrainTest +{ +}; + +TEST_P(BNFwdTrainTestHalf, BnFwdTrainCKHalf) {} + +TEST_P(BNFwdTrainTestFloat, BnFwdTrainCKFloat) {} + +// Currently disabled since miopen::batchnorm::MakeForwardTrainingNetworkConfig +// only supports half and float +TEST_P(BNFwdTrainTestDouble, DISABLED_BnFwdTrainCKDouble) {} +TEST_P(BNFwdTrainTestBFloat16, DISABLED_BnFwdTrainCKBFloat16) {} + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestHalfNHWCSuite, + BNFwdTrainTestHalf, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestFloatNHWCSuite, + BNFwdTrainTestFloat, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestFloatNHWCSuite, + BNFwdTrainTestDouble, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); + +INSTANTIATE_TEST_SUITE_P(BNFwdTrainTestFloatNHWCSuite, + BNFwdTrainTestBFloat16, + testing::Combine(testing::ValuesIn(Network1()), + testing::Values(miopenTensorNHWC))); diff --git a/test/gtest/bn_infer.cpp b/test/gtest/bn_infer.cpp index 6598ef7169..0dceaa1ba5 100644 --- a/test/gtest/bn_infer.cpp +++ b/test/gtest/bn_infer.cpp @@ -43,14 +43,14 @@ struct BNInferTestBFloat16 : BNInferTest +#include "random.hpp" #include #include @@ -60,7 +59,8 @@ std::vector Network1() { // pyt_mlperf_resnet50v1.5 return { - {16, 8, 128, 256, miopenBNSpatial, miopen::batchnorm::Direction::ForwardInference, 1, 0}, + {192, 1, 8, 8, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 1, 0}, + {16, 8, 128, 256, miopenBNSpatial, miopen::batchnorm::Direction::ForwardTraining, 1, 0}, {16, 8, 128, 256, miopenBNSpatial, miopen::batchnorm::Direction::ForwardInference, 1, 0}, {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::Backward, 0, 1}, {64, 2048, 7, 7, miopenBNSpatial, miopen::batchnorm::Direction::ForwardTraining, 1, 1}, @@ -125,7 +125,7 @@ struct BNTestData { input = tensor{miopen_type{}, tensor_layout, bn_config.GetInput()}; output = tensor{miopen_type{}, tensor_layout, bn_config.GetInput()}; - ref_out = output; + ref_out = tensor{miopen_type{}, tensor_layout, bn_config.GetInput()}; } void InitTensorsWithRandValue() @@ -226,3 +226,218 @@ struct BNInferTestData : public BNTestData estVariance_dev = handle.Write(estVariance.data); } }; + +template +struct BNBwdTestData : public BNTestData +{ + void SetUpImpl(const TConfig& config, miopenTensorLayout_t t_layout) + { + BNTestData::SetUpImpl(config, t_layout); + CreateTensors(); + InitTensorsWithRandValue(); + WriteToGPU(); + } + + tensor bnScale; + + tensor savedMean; + tensor savedInvVar; + + tensor dy; + tensor dScale; + tensor dBias; + tensor dScale_ref; + tensor dBias_ref; + + miopen::Allocator::ManageDataPtr bnScale_dev; + miopen::Allocator::ManageDataPtr savedMean_dev; + miopen::Allocator::ManageDataPtr savedInvVar_dev; + + miopen::Allocator::ManageDataPtr dy_dev; + miopen::Allocator::ManageDataPtr dScale_dev; + miopen::Allocator::ManageDataPtr dBias_dev; + miopen::Allocator::ManageDataPtr dScale_ref_dev; + miopen::Allocator::ManageDataPtr dBias_ref_dev; + double epsilon = std::numeric_limits::epsilon(); + + float alphaDataDiff = static_cast(1), betaDataDiff = 
static_cast(0); + float alphaParamDiff = static_cast(1), betaParamDiff = static_cast(0); + +private: + void CreateTensors() + { + dy = tensor{miopen_type{}, + BNTestData::tensor_layout, + BNTestData::bn_config.GetInput()}; + + auto derivedBnDesc = miopen::TensorDescriptor{}; + miopen::DeriveBNTensorDescriptor(derivedBnDesc, + BNTestData::input.desc, + BNTestData::bn_mode); + bnScale = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + savedMean = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + savedInvVar = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + dScale = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + dBias = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + dScale_ref = dScale; + dBias_ref = dBias; + } + + void InitTensorsWithRandValue() + { + auto gen_value = [](auto...) { + return prng::gen_descreet_uniform_sign(static_cast(1e-2), 100); + }; + dy.generate(gen_value); + bnScale.generate(gen_value); + savedMean.generate(gen_value); + + auto gen_var = [](auto...) { + return static_cast(1e-2) * + static_cast(prng::gen_0_to_B(100) + 1); + }; + savedInvVar.generate(gen_var); + + std::fill(dScale.begin(), dScale.end(), 0.); + std::fill(dBias.begin(), dBias.end(), 0.); + + std::fill(dScale_ref.begin(), dScale_ref.end(), 0.); + std::fill(dBias_ref.begin(), dBias_ref.end(), 0.); + } + void WriteToGPU() + { + auto&& handle = get_handle(); + + bnScale_dev = handle.Write(bnScale.data); + savedMean_dev = handle.Write(savedMean.data); + savedInvVar_dev = handle.Write(savedInvVar.data); + dy_dev = handle.Write(dy.data); + + dScale_dev = handle.Write(dScale.data); + dBias_dev = handle.Write(dBias.data); + } +}; + +template +struct BNFwdTrainTestData : public BNTestData +{ + void SetUpImpl(const TConfig& config, miopenTensorLayout_t t_layout) + { + BNTestData::SetUpImpl(config, t_layout); + CreateTensors(); + InitTensorsWithRandValue(); + WriteToGPU(); + } + + tensor scale; + tensor shift; + tensor saveMean; + tensor saveVariance; + tensor runMean; + tensor runVariance; + + tensor saveMean_ref; + tensor saveVariance_ref; + tensor runMean_ref; + tensor runVariance_ref; + + miopen::Allocator::ManageDataPtr scale_dev; + miopen::Allocator::ManageDataPtr shift_dev; // bias + miopen::Allocator::ManageDataPtr saveMean_dev; + miopen::Allocator::ManageDataPtr saveVariance_dev; + miopen::Allocator::ManageDataPtr runMean_dev; + miopen::Allocator::ManageDataPtr runVariance_dev; + double epsilon = 1.0e-5; + double averageFactor = 0.1; + float alpha = static_cast(1.0f); + float beta = static_cast(0); + const float activ_alpha = static_cast(0.5f); + const float activ_beta = static_cast(0.5f); + const float activ_gamma = static_cast(0.5f); + +private: + void CreateTensors() + { + auto derivedBnDesc = miopen::TensorDescriptor{}; + miopen::DeriveBNTensorDescriptor(derivedBnDesc, + BNTestData::input.desc, + BNTestData::bn_mode); + scale = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + shift = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + saveMean = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + saveVariance = + tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + runMean = tensor{miopen_type{}, + BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + runVariance = + tensor{miopen_type{}, + 
BNTestData::tensor_layout, + derivedBnDesc.GetLengths()}; + } + + void InitTensorsWithRandValue() + { + auto gen_value = [](auto...) { + return prng::gen_descreet_uniform_sign(static_cast(1e-2), 100); + }; + scale.generate(gen_value); + shift.generate(gen_value); + + auto gen_var = [](auto...) { + return static_cast(1e-2) * + static_cast(prng::gen_0_to_B(100) + 1); + }; + runMean.generate(gen_var); + runVariance.generate(gen_var); + + saveMean_ref = saveMean; + saveVariance_ref = saveVariance; + runMean_ref = runMean; + runVariance_ref = runVariance; + } + void WriteToGPU() + { + auto&& handle = get_handle(); + scale_dev = handle.Write(scale.data); + shift_dev = handle.Write(shift.data); + saveMean_dev = handle.Write(saveMean.data); + saveVariance_dev = handle.Write(saveVariance.data); + runMean_dev = handle.Write(runMean.data); + runVariance_dev = handle.Write(runVariance.data); + } +}; diff --git a/test/gtest/test_operations.hpp b/test/gtest/test_operations.hpp index d1528fe2bb..da41212302 100644 --- a/test/gtest/test_operations.hpp +++ b/test/gtest/test_operations.hpp @@ -38,6 +38,41 @@ void ComputeCPUBNInference(DLModule& dl_module) dl_module.estVariance); } +template +void ComputeCPUBNBwd(DLModule& dl_module) +{ + batchNormSpatialHostBwdTrain(dl_module.input, + dl_module.dy, + dl_module.ref_out, + dl_module.bnScale, + dl_module.dScale_ref, + dl_module.dBias_ref, + dl_module.savedMean, + dl_module.savedInvVar); +} + +template +void ComputeCPUBNFwdTrain(DLModule& dl_module) +{ + batchNormSpatialHostFwdTrain(dl_module.input, + dl_module.ref_out, + dl_module.scale, + dl_module.shift, + dl_module.epsilon, + dl_module.averageFactor, + dl_module.saveMean_ref, + dl_module.saveVariance_ref, + dl_module.runMean_ref, + dl_module.runVariance_ref); +} + template void CompareTensor(const tensor& output, const tensor& ref_out, From 14118a413eec00071800d4efa48ef0199bbbabd5 Mon Sep 17 00:00:00 2001 From: amberhassaan Date: Thu, 5 Oct 2023 18:27:19 -0400 Subject: [PATCH 2/2] Reference kernel for 3D convolution for non-packed tensors (#2334) --- src/CMakeLists.txt | 1 + src/hip/hip_build_utils.cpp | 2 +- src/include/miopen/hipoc_kernel.hpp | 24 +- .../miopen/solver/conv_direct_naive_conv.hpp | 95 +- .../gpu_reference_kernel/fp8_kern_types.h | 6 +- .../gpu_reference_kernel/naive_conv.cpp | 1719 +++++++++++------ src/kernels/stride_array.hpp | 86 + src/solver/conv_direct_naive_conv.cpp | 57 +- src/solver/conv_direct_naive_conv_bwd.cpp | 39 + src/solver/conv_direct_naive_conv_fwd.cpp | 31 +- src/solver/conv_direct_naive_conv_wrw.cpp | 35 + test/gpu_reference_kernel.cpp | 3 +- test/gtest/conv3d_test_case.hpp | 112 ++ test/gtest/group_conv3d_bwd.cpp | 2 +- test/gtest/group_conv3d_bwd.hpp | 88 +- test/gtest/group_conv3d_fwd.cpp | 2 +- test/gtest/group_conv3d_fwd.hpp | 88 +- test/gtest/group_conv3d_wrw.cpp | 2 +- test/gtest/group_conv3d_wrw.hpp | 88 +- test/gtest/group_solver.hpp | 6 +- 20 files changed, 1633 insertions(+), 853 deletions(-) create mode 100644 src/kernels/stride_array.hpp create mode 100644 test/gtest/conv3d_test_case.hpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index abc0679a8a..7866ad1a5a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -390,6 +390,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/workaround_issue_1431.hpp kernels/hip_f8_impl.hpp kernels/hip_float8.hpp + kernels/stride_array.hpp ) set(MIOPEN_KERNELS diff --git a/src/hip/hip_build_utils.cpp b/src/hip/hip_build_utils.cpp index 8f6f9f0c50..86cf3a7272 100644 --- 
a/src/hip/hip_build_utils.cpp +++ b/src/hip/hip_build_utils.cpp @@ -73,7 +73,7 @@ static boost::filesystem::path HipBuildImpl(boost::optional& tmp_dir, auto env = std::string(""); if(params.find("-std=") == std::string::npos) - params += " --std=c++11"; + params += " --std=c++17"; #if HIP_PACKAGE_VERSION_FLAT < 4001000000ULL params += " --cuda-gpu-arch=" + lots.device; diff --git a/src/include/miopen/hipoc_kernel.hpp b/src/include/miopen/hipoc_kernel.hpp index ba9992bab3..73ac77f160 100644 --- a/src/include/miopen/hipoc_kernel.hpp +++ b/src/include/miopen/hipoc_kernel.hpp @@ -26,14 +26,15 @@ #ifndef GUARD_MIOPEN_HIPOC_KERNEL_HPP #define GUARD_MIOPEN_HIPOC_KERNEL_HPP -#include -#include #include #include #include #include + +#include +#include +#include #include -#include namespace miopen { @@ -47,29 +48,20 @@ inline HipEventPtr make_hip_event() #if 1 // Keep around other storage techinques -- @pfultz2 27.03.2017 -#if 1 // Keep around other storage techinques -- @pfultz2 27.03.2017 template struct KernelArgsPair { - static const int alignment = sizeof(U); - static const int padding = (alignment - sizeof(T) % alignment) % alignment; - static const int second_index = sizeof(T) + padding; + constexpr static auto alignU = alignof(U); + constexpr static auto padding = (alignU - (sizeof(T) % alignU)) % alignU; + constexpr static auto second_index = sizeof(T) + padding; KernelArgsPair(T x, U y) { new(buffer) T(x); // NOLINT (clang-analyzer-cplusplus.PlacementNew) new(buffer + second_index) U(y); } + alignas(U) char buffer[second_index + sizeof(U)] = {}; }; -#else -template -struct KernelArgsPair -{ - KernelArgsPair(T x, U y) : first(x), second(y) {} - T first; - U second; -}; -#endif template struct KernelArgsPack; diff --git a/src/include/miopen/solver/conv_direct_naive_conv.hpp b/src/include/miopen/solver/conv_direct_naive_conv.hpp index 7bad52ff9e..6d935b249d 100644 --- a/src/include/miopen/solver/conv_direct_naive_conv.hpp +++ b/src/include/miopen/solver/conv_direct_naive_conv.hpp @@ -25,9 +25,15 @@ *******************************************************************************/ #pragma once -#include #include #include +#include "miopen/../../kernels/stride_array.hpp" + +#include +#include +#include +#include +#include namespace miopen { @@ -54,5 +60,92 @@ bool IsOutputBfp16(const ProblemDescription&); bool IsOutputInt8(const ProblemDescription&); bool IsOutputInt32(const ProblemDescription&); +namespace conv_internal { + +void DebugPrintTensorStrides(const TensorDescriptor& inDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& outDesc); + +/** + * Get the index where group (G) stride should go. For NCHW, we want to convert + * its strides to NGCHW, and for NHWC, we want to convert its strides to NHWGC. + * Same applies for the 3D case. + */ +int GetGroupStrideIndex(const ProblemDescription& problem); + +/** + * split the strides for C dimension in a tensor descriptor into (G, C_per_group). + * Normally, (in packed case) num channels is a multiplying factor in the stride of + * whatever lies to the left of C, e.g., in NCHW, N's stride contains C as a + * factor. We output NGCHW for NCHW (and NHWGC for NHWC) + * where the stride[G] = stride[N] / num_groups + */ +template +V SplitStrideCtoGC(int num_groups, const V& orig_strides, int G_stride_idx) +{ + assert(G_stride_idx > 0 && G_stride_idx <= orig_strides.size()); + // (G_stride_idx - 1) is the stride index of whatever lies to the left and + // contains C or K as a multiplying factor. 
We divide this value by num_groups + // to get G_stride_val + assert(orig_strides[G_stride_idx - 1] % num_groups == 0); + + V ret{orig_strides}; + auto G_stride_val = orig_strides[G_stride_idx - 1] / num_groups; + + ret.insert(ret.begin() + G_stride_idx, G_stride_val); + + return ret; +} + +/** + * Weight tensor has original dims: [K, C_per_group, Y, X] (2D case) + * We return a new stride vector with strides for [G, K_per_group, C_per_group, Y, X] + * Stride for G is computed as stride[C_per_group] * K_per_group and inserted at + * left most position + */ +template +V SplitWeiStrideKtoGK(int k_per_group, const V& wei_strides) +{ + V ret{wei_strides}; + ret.insert(ret.begin(), wei_strides[0] * k_per_group); + return ret; +} + +template +struct ChooseStride +{ +}; + +template <> +struct ChooseStride<5u> +{ + using type = Strides5D; +}; + +template <> +struct ChooseStride<6u> +{ + using type = Strides6D; +}; + +template +auto MakeStrideArray(V vec) +{ + typename ChooseStride::type ret; + assert(vec.size() == N); + + // MIOpen stores strides for NHWC in NCHW order, i.e. C stride in 2nd from left. + // We sort the input stride vector so that smallest stride is at index 0. This + // (little-endian) order is what naive convolution kernel expects for strides + std::sort(vec.begin(), vec.end()); + + for(unsigned i = 0; i < N; ++i) + { + ret[i] = static_cast(vec[i]); + } + return ret; +} +} // end namespace conv_internal + } // namespace solver } // namespace miopen diff --git a/src/kernels/gpu_reference_kernel/fp8_kern_types.h b/src/kernels/gpu_reference_kernel/fp8_kern_types.h index 3bac0a31f7..b14302e0c2 100644 --- a/src/kernels/gpu_reference_kernel/fp8_kern_types.h +++ b/src/kernels/gpu_reference_kernel/fp8_kern_types.h @@ -58,6 +58,6 @@ #define KERNEL_NAME_SUFFIX CAT(CAT(INPUT_TYPE, _), CAT(CAT(WEIGHTS_TYPE, _), OUTPUT_TYPE)) -#define FWD_KERNEL_NAME CAT(naive_conv_fwd_nchw_, KERNEL_NAME_SUFFIX) -#define BWD_KERNEL_NAME CAT(naive_conv_bwd_nchw_, KERNEL_NAME_SUFFIX) -#define WRW_KERNEL_NAME CAT(naive_conv_wrw_nchw_, KERNEL_NAME_SUFFIX) +#define FWD_KERNEL_NAME CAT(naive_conv_packed_fwd_nchw_, KERNEL_NAME_SUFFIX) +#define BWD_KERNEL_NAME CAT(naive_conv_packed_bwd_nchw_, KERNEL_NAME_SUFFIX) +#define WRW_KERNEL_NAME CAT(naive_conv_packed_wrw_nchw_, KERNEL_NAME_SUFFIX) diff --git a/src/kernels/gpu_reference_kernel/naive_conv.cpp b/src/kernels/gpu_reference_kernel/naive_conv.cpp index 24d7cd489e..b243b1234a 100644 --- a/src/kernels/gpu_reference_kernel/naive_conv.cpp +++ b/src/kernels/gpu_reference_kernel/naive_conv.cpp @@ -46,6 +46,8 @@ typedef float float_t; #endif #endif // __HIPCC_RTC__ +#include "stride_array.hpp" + // hcc seems need __device__ __host__ together to compile, and no extern "C" typedef union value_bf16_fp32_t { @@ -114,10 +116,27 @@ inline __device__ __host__ int8_t cast_to(const int32_t& val) return static_cast(val & 0xff); } -template +/// \todo remove template parameter 'bool ASSUME_PACKED' in a follow up PR +/// --amberhassaan +/// Notes (Amber): +/// - The following code used to assume that group (G) is an implicit +/// dimension, i.e. c= c_per_group * group and k = k_per_group * group. This is not +/// true for non-packed case because group (G) dimension needs to have its stride +/// explicitly specified for address math to make sense. This is also how +/// composable_kernel (CK) treats G dimension. Which is why nchw should be ngchw, +/// and nhwc should be nhwgc. Same follows for the 3D case. 
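[Editor's sketch, not part of the patch] To make the stride-splitting contract above concrete, the following minimal C++ sketch works through one small packed NCHW case. The shape, group count, and the G_stride_idx value are assumptions chosen only for illustration; in the solver they come from the problem description and GetGroupStrideIndex().

// Illustrative only; assumes the conv_internal helpers declared above are in scope.
#include <cstddef>
#include <vector>

void stride_split_example()
{
    using namespace miopen::solver::conv_internal;

    // Packed NCHW input, N=2, C=8, H=4, W=4, num_groups = 2.
    // Descriptor strides (big-endian): {C*H*W, H*W, W, 1} = {128, 16, 4, 1}.
    std::vector<std::size_t> nchw_strides = {128, 16, 4, 1};

    // For NCHW the G stride is inserted where C sits (index 1), turning NCHW into NGCHW:
    // stride[G] = stride[N] / num_groups = 128 / 2 = 64.
    auto ngchw = SplitStrideCtoGC(2, nchw_strides, 1); // {128, 64, 16, 4, 1}

    // Weights [K=8, C_per_group=4, Y=3, X=3] with strides {36, 9, 3, 1} and k_per_group = 4:
    // stride[G] = 36 * 4 = 144 is prepended, giving [G, K_per_group, C_per_group, Y, X].
    auto gkc_wei = SplitWeiStrideKtoGK(4, std::vector<std::size_t>{36, 9, 3, 1}); // {144, 36, 9, 3, 1}

    // MakeStrideArray<5> then sorts into the little-endian order the naive kernels
    // index with, i.e. smallest stride first: Strides5D{1, 4, 16, 64, 128}.
    auto in_strides = MakeStrideArray<5>(ngchw);
}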
+/// +/// - strides here are in the little-endian order, i.e., for NHWC, stride for N is +/// at index 3 while stride for C is at index 0. This is reverse of how strides are +/// stored in tensor descriptors, which are big-endian. + +template inline __device__ void naive_conv_fwd_nchw(const src_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, dst_data_t* __restrict__ p_out, + Strides5D in_strides, + Strides5D wei_strides, + Strides5D out_strides, int hi, int wi, int n, @@ -148,18 +167,36 @@ inline __device__ void naive_conv_fwd_nchw(const src_data_t* __restrict__ p_in, int in = (bid / k_per_group) % n; int ig = bid / (n * k_per_group); - p_in += static_cast(in) * c * hi * wi + static_cast(ig) * c_per_group * hi * wi; - p_wei += static_cast(ig) * k_per_group * c_per_group * fy * fx + - static_cast(ik) * c_per_group * fy * fx; - p_out += static_cast(in) * k * ho * wo + - static_cast(ig) * k_per_group * ho * wo + static_cast(ik) * ho * wo; + if constexpr(ASSUME_PACKED) + { + p_in += + static_cast(in) * c * hi * wi + static_cast(ig) * c_per_group * hi * wi; + + p_wei += static_cast(ig) * k_per_group * c_per_group * fy * fx + + static_cast(ik) * c_per_group * fy * fx; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += static_cast(in) * k * ho * wo + + static_cast(ig) * k_per_group * ho * wo + + static_cast(ik) * ho * wo; + } + else + { + p_in += static_cast(in) * in_strides[4] + static_cast(ig) * in_strides[3]; + + p_wei += + static_cast(ig) * wei_strides[4] + static_cast(ik) * wei_strides[3]; + + p_out += static_cast(in) * out_strides[4] + + static_cast(ig) * out_strides[3] + + static_cast(ik) * out_strides[2]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int iho = tid / wo; int iwo = tid % wo; - double value = .0f; + acc_data_t value = 0; for(int ic = 0; ic < c_per_group; ic++) { @@ -178,25 +215,58 @@ inline __device__ void naive_conv_fwd_nchw(const src_data_t* __restrict__ p_in, if(valid_w & valid_h) { - size_t i_idx = static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + static_cast(cur_w); - size_t f_idx = static_cast(ic) * fy * fx + - static_cast(iy) * fx + static_cast(ix); - value += cast_to(p_in[i_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + + size_t f_idx = static_cast(ic) * fy * fx + + static_cast(iy) * fx + static_cast(ix); + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t i_idx = static_cast(ic) * in_strides[2] + + static_cast(cur_h) * in_strides[1] + + static_cast(cur_w) * in_strides[0]; + + size_t f_idx = static_cast(ic) * wei_strides[2] + + static_cast(iy) * wei_strides[1] + + static_cast(ix) * wei_strides[0]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } } } } } - size_t o_idx = static_cast(iho) * wo + static_cast(iwo); - p_out[o_idx] = cast_to(value); + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(iho) * wo + static_cast(iwo); + + p_out[o_idx] = cast_to(value); + } + else + { + size_t o_idx = static_cast(iho) * out_strides[1] + + static_cast(iwo) * out_strides[0]; + + p_out[o_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_bwd_nchw(dst_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides5D in_strides, + Strides5D wei_strides, + Strides5D out_strides, int hi, int wi, int n, @@ -227,19 +297,35 @@ inline __device__ void 
naive_conv_bwd_nchw(dst_data_t* __restrict__ p_in, int in = (bid / c_per_group) % n; int ig = bid / (n * c_per_group); - p_in += static_cast(in) * c * hi * wi + - static_cast(ig) * c_per_group * hi * wi + static_cast(ic) * hi * wi; - p_wei += static_cast(ig) * k_per_group * c_per_group * fy * fx + - static_cast(ic) * fy * fx; - p_out += - static_cast(in) * k * ho * wo + static_cast(ig) * k_per_group * ho * wo; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * c * hi * wi + + static_cast(ig) * c_per_group * hi * wi + static_cast(ic) * hi * wi; + + p_wei += static_cast(ig) * k_per_group * c_per_group * fy * fx + + static_cast(ic) * fy * fx; + + p_out += + static_cast(in) * k * ho * wo + static_cast(ig) * k_per_group * ho * wo; + } + else + { + p_in += static_cast(in) * in_strides[4] + static_cast(ig) * in_strides[3] + + static_cast(ic) * in_strides[2]; + + p_wei += + static_cast(ig) * wei_strides[4] + static_cast(ic) * wei_strides[2]; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += + static_cast(in) * out_strides[4] + static_cast(ig) * out_strides[3]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ihi = tid / wi; int iwi = tid % wi; - double value = .0f; + acc_data_t value = 0; for(int ik = 0; ik < k_per_group; ik++) { @@ -264,26 +350,59 @@ inline __device__ void naive_conv_bwd_nchw(dst_data_t* __restrict__ p_in, if(valid_h & valid_w) { - size_t o_idx = static_cast(ik) * ho * wo + - static_cast(cur_ho) * wo + - static_cast(cur_wo); - size_t f_idx = static_cast(ik) * c_per_group * fy * fx + - static_cast(iy) * fx + static_cast(ix); - value += cast_to(p_out[o_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(ik) * ho * wo + + static_cast(cur_ho) * wo + + static_cast(cur_wo); + + size_t f_idx = static_cast(ik) * c_per_group * fy * fx + + static_cast(iy) * fx + static_cast(ix); + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t o_idx = static_cast(ik) * out_strides[2] + + static_cast(cur_ho) * out_strides[1] + + static_cast(cur_wo) * out_strides[0]; + + size_t f_idx = static_cast(ik) * wei_strides[3] + + static_cast(iy) * wei_strides[1] + + static_cast(ix) * wei_strides[0]; + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } } } } } - size_t i_idx = static_cast(ihi) * wi + static_cast(iwi); - p_in[i_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(ihi) * wi + static_cast(iwi); + + p_in[i_idx] = cast_to(value); + } + else + { + size_t i_idx = + static_cast(ihi) * in_strides[1] + static_cast(iwi) * in_strides[0]; + + p_in[i_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_wrw_nchw(const src_data_t* __restrict__ p_in, dst_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides5D in_strides, + Strides5D wei_strides, + Strides5D out_strides, int hi, int wi, int n, @@ -315,18 +434,34 @@ inline __device__ void naive_conv_wrw_nchw(const src_data_t* __restrict__ p_in, int ik = bid % k_per_group; int ig = bid / k_per_group; - p_in += static_cast(ig) * c_per_group * hi * wi; - p_wei += static_cast(ig) * k_per_group * c_per_group * fy * fx + - static_cast(ik) * c_per_group * fy * fx; - p_out += static_cast(ig) * k_per_group * ho * wo + static_cast(ik) * ho * wo; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(ig) * c_per_group * hi * wi; + + p_wei += static_cast(ig) * k_per_group * c_per_group * fy * fx + + static_cast(ik) * c_per_group * fy * 
fx; + + p_out += + static_cast(ig) * k_per_group * ho * wo + static_cast(ik) * ho * wo; + } + else + { + p_in += static_cast(ig) * in_strides[3]; + + p_wei += + static_cast(ig) * wei_strides[4] + static_cast(ik) * wei_strides[3]; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += + static_cast(ig) * out_strides[3] + static_cast(ik) * out_strides[2]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ix = tid % fx; int iy = (tid / fx) % fy; int ic = tid / (fx * fy); - double value = .0f; + acc_data_t value = 0; for(int in = 0; in < n; in++) { @@ -345,28 +480,64 @@ inline __device__ void naive_conv_wrw_nchw(const src_data_t* __restrict__ p_in, if(valid_h & valid_w) { - size_t i_idx = static_cast(in) * c * hi * wi + - static_cast(ic) * hi * wi + - static_cast(cur_h) * wi + static_cast(cur_w); - size_t o_idx = static_cast(in) * k * ho * wo + - static_cast(iho) * wo + static_cast(iwo); - value += cast_to(p_in[i_idx]) * - cast_to(p_out[o_idx]); + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(in) * c * hi * wi + + static_cast(ic) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + + size_t o_idx = static_cast(in) * k * ho * wo + + static_cast(iho) * wo + static_cast(iwo); + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } + else + { + size_t i_idx = static_cast(in) * in_strides[4] + + static_cast(ic) * in_strides[2] + + static_cast(cur_h) * in_strides[1] + + static_cast(cur_w) * in_strides[0]; + + size_t o_idx = static_cast(in) * out_strides[4] + + static_cast(iho) * out_strides[1] + + static_cast(iwo) * out_strides[0]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } } } } } - size_t f_idx = static_cast(ic) * fy * fx + static_cast(iy) * fx + - static_cast(ix); - p_wei[f_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t f_idx = static_cast(ic) * fy * fx + static_cast(iy) * fx + + static_cast(ix); + + p_wei[f_idx] = cast_to(value); + } + else + { + size_t f_idx = static_cast(ic) * wei_strides[2] + + static_cast(iy) * wei_strides[1] + + static_cast(ix) * wei_strides[0]; + + p_wei[f_idx] = cast_to(value); + } } } // design block_size 256 -template +template inline __device__ void naive_conv_fwd_ncdhw(const src_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, dst_data_t* __restrict__ p_out, + Strides6D in_strides, + Strides6D wei_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -405,21 +576,37 @@ inline __device__ void naive_conv_fwd_ncdhw(const src_data_t* __restrict__ p_in, int in = (bid / k_per_group) % n; int ig = bid / (n * k_per_group); - p_in += static_cast(in) * c * di * hi * wi + - static_cast(ig) * c_per_group * di * hi * wi; - p_wei += static_cast(ig) * k_per_group * c_per_group * fz * fy * fx + - static_cast(ik) * c_per_group * fz * fy * fx; - p_out += static_cast(in) * k * do_ * ho * wo + - static_cast(ig) * k_per_group * do_ * ho * wo + - static_cast(ik) * do_ * ho * wo; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * c * di * hi * wi + + static_cast(ig) * c_per_group * di * hi * wi; + + p_wei += static_cast(ig) * k_per_group * c_per_group * fz * fy * fx + + static_cast(ik) * c_per_group * fz * fy * fx; + + p_out += static_cast(in) * k * do_ * ho * wo + + static_cast(ig) * k_per_group * do_ * ho * wo + + static_cast(ik) * do_ * ho * wo; + } + else + { + p_in += static_cast(in) * in_strides[5] + static_cast(ig) * in_strides[4]; + + p_wei += + static_cast(ig) * wei_strides[5] + static_cast(ik) * wei_strides[4]; - for(int 
tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += static_cast(in) * out_strides[5] + + static_cast(ig) * out_strides[4] + + static_cast(ik) * out_strides[3]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int iwo = tid % wo; int iho = (tid / wo) % ho; int ido = tid / (ho * wo); - double value = .0f; + acc_data_t value = 0; for(int ic = 0; ic < c_per_group; ic++) { @@ -444,30 +631,67 @@ inline __device__ void naive_conv_fwd_ncdhw(const src_data_t* __restrict__ p_in, if(valid_d & valid_w & valid_h) { - size_t i_idx = static_cast(ic) * di * hi * wi + - static_cast(cur_d) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - size_t f_idx = static_cast(ic) * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + static_cast(ix); - value += cast_to(p_in[i_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(ic) * di * hi * wi + + static_cast(cur_d) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + + size_t f_idx = static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + + static_cast(ix); + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t i_idx = static_cast(ic) * in_strides[3] + + static_cast(cur_d) * in_strides[2] + + static_cast(cur_h) * in_strides[1] + + static_cast(cur_w) * in_strides[0]; + + size_t f_idx = static_cast(ic) * wei_strides[3] + + static_cast(iz) * wei_strides[2] + + static_cast(iy) * wei_strides[1] + + static_cast(ix) * wei_strides[0]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } } } } } } - size_t o_idx = static_cast(ido) * ho * wo + static_cast(iho) * wo + - static_cast(iwo); - p_out[o_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(ido) * ho * wo + static_cast(iho) * wo + + static_cast(iwo); + + p_out[o_idx] = cast_to(value); + } + else + { + size_t o_idx = static_cast(ido) * out_strides[2] + + static_cast(iho) * out_strides[1] + + static_cast(iwo) * out_strides[0]; + + p_out[o_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_bwd_ncdhw(dst_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides6D in_strides, + Strides6D wei_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -506,21 +730,37 @@ inline __device__ void naive_conv_bwd_ncdhw(dst_data_t* __restrict__ p_in, int in = (bid / c_per_group) % n; int ig = bid / (n * c_per_group); - p_in += static_cast(in) * c * di * hi * wi + - static_cast(ig) * c_per_group * di * hi * wi + - static_cast(ic) * di * hi * wi; - p_wei += static_cast(ig) * k_per_group * c_per_group * fz * fy * fx + - static_cast(ic) * fz * fy * fx; - p_out += static_cast(in) * k * do_ * ho * wo + - static_cast(ig) * k_per_group * do_ * ho * wo; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * c * di * hi * wi + + static_cast(ig) * c_per_group * di * hi * wi + + static_cast(ic) * di * hi * wi; + + p_wei += static_cast(ig) * k_per_group * c_per_group * fz * fy * fx + + static_cast(ic) * fz * fy * fx; + + p_out += static_cast(in) * k * do_ * ho * wo + + static_cast(ig) * k_per_group * do_ * ho * wo; + } + else + { + p_in += static_cast(in) * in_strides[5] + static_cast(ig) * in_strides[4] + + static_cast(ic) * in_strides[3]; + + p_wei += + static_cast(ig) * wei_strides[5] + static_cast(ic) * wei_strides[3]; + + p_out += + static_cast(in) * out_strides[5] + static_cast(ig) * out_strides[4]; + } - 
for(int tid = threadIdx.x; tid < thread_length; tid += 256) + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int iwi = tid % wi; int ihi = (tid / wi) % hi; int idi = tid / (hi * wi); - double value = .0f; + acc_data_t value = 0; for(int ik = 0; ik < k_per_group; ik++) { @@ -554,30 +794,67 @@ inline __device__ void naive_conv_bwd_ncdhw(dst_data_t* __restrict__ p_in, if(valid_d & valid_h & valid_w) { - size_t o_idx = static_cast(ik) * do_ * ho * wo + - static_cast(cur_do) * ho * wo + - static_cast(cur_ho) * wo + - static_cast(cur_wo); - size_t f_idx = static_cast(ik) * c_per_group * fz * fy * fx + - static_cast(iz) * fy * fx + - static_cast(iy) * fx + static_cast(ix); - value += cast_to(p_out[o_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(ik) * do_ * ho * wo + + static_cast(cur_do) * ho * wo + + static_cast(cur_ho) * wo + + static_cast(cur_wo); + + size_t f_idx = + static_cast(ik) * c_per_group * fz * fy * fx + + static_cast(iz) * fy * fx + + static_cast(iy) * fx + static_cast(ix); + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t o_idx = static_cast(ik) * out_strides[3] + + static_cast(cur_do) * out_strides[2] + + static_cast(cur_ho) * out_strides[1] + + static_cast(cur_wo) * out_strides[0]; + + size_t f_idx = static_cast(ik) * wei_strides[4] + + static_cast(iz) * wei_strides[2] + + static_cast(iy) * wei_strides[1] + + static_cast(ix) * wei_strides[0]; + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } } } } } } - size_t i_idx = static_cast(idi) * hi * wi + static_cast(ihi) * wi + - static_cast(iwi); - p_in[i_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(idi) * hi * wi + static_cast(ihi) * wi + + static_cast(iwi); + + p_in[i_idx] = cast_to(value); + } + else + { + size_t i_idx = static_cast(idi) * in_strides[2] + + static_cast(ihi) * in_strides[1] + + static_cast(iwi) * in_strides[0]; + + p_in[i_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_wrw_ncdhw(const src_data_t* __restrict__ p_in, dst_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides6D in_strides, + Strides6D wei_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -615,20 +892,35 @@ inline __device__ void naive_conv_wrw_ncdhw(const src_data_t* __restrict__ p_in, int ik = bid % k_per_group; int ig = bid / k_per_group; - p_in += static_cast(ig) * c_per_group * di * hi * wi; - p_wei += static_cast(ig) * k_per_group * c_per_group * fz * fy * fx + - static_cast(ik) * c_per_group * fz * fy * fx; - p_out += static_cast(ig) * k_per_group * do_ * ho * wo + - static_cast(ik) * do_ * ho * wo; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(ig) * c_per_group * di * hi * wi; + + p_wei += static_cast(ig) * k_per_group * c_per_group * fz * fy * fx + + static_cast(ik) * c_per_group * fz * fy * fx; + + p_out += static_cast(ig) * k_per_group * do_ * ho * wo + + static_cast(ik) * do_ * ho * wo; + } + else + { + p_in += static_cast(ig) * in_strides[4]; + + p_wei += + static_cast(ig) * wei_strides[5] + static_cast(ik) * wei_strides[4]; + + p_out += + static_cast(ig) * out_strides[4] + static_cast(ik) * out_strides[3]; + } - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ix = tid % fx; int iy = (tid / fx) % fy; int iz = (tid / (fx * fy)) % fz; int ic = tid / (fx * fy * fz); - double value = .0f; + acc_data_t value = 0; for(int 
in = 0; in < n; in++) { @@ -653,33 +945,73 @@ inline __device__ void naive_conv_wrw_ncdhw(const src_data_t* __restrict__ p_in, if(valid_d & valid_h & valid_w) { - size_t i_idx = static_cast(in) * c * di * hi * wi + - static_cast(ic) * di * hi * wi + - static_cast(cur_d) * hi * wi + - static_cast(cur_h) * wi + - static_cast(cur_w); - size_t o_idx = static_cast(in) * k * do_ * ho * wo + - static_cast(ido) * ho * wo + - static_cast(iho) * wo + static_cast(iwo); - value += cast_to(p_in[i_idx]) * - cast_to(p_out[o_idx]); + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(in) * c * di * hi * wi + + static_cast(ic) * di * hi * wi + + static_cast(cur_d) * hi * wi + + static_cast(cur_h) * wi + + static_cast(cur_w); + + size_t o_idx = static_cast(in) * k * do_ * ho * wo + + static_cast(ido) * ho * wo + + static_cast(iho) * wo + + static_cast(iwo); + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } + else + { + size_t i_idx = static_cast(in) * in_strides[5] + + static_cast(ic) * in_strides[3] + + static_cast(cur_d) * in_strides[2] + + static_cast(cur_h) * in_strides[1] + + static_cast(cur_w) * in_strides[0]; + + size_t o_idx = static_cast(in) * out_strides[5] + + static_cast(ido) * out_strides[2] + + static_cast(iho) * out_strides[1] + + static_cast(iwo) * out_strides[0]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } } } } } } - size_t f_idx = static_cast(ic) * fz * fy * fx + static_cast(iz) * fy * fx + - static_cast(iy) * fx + static_cast(ix); - p_wei[f_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t f_idx = static_cast(ic) * fz * fy * fx + + static_cast(iz) * fy * fx + static_cast(iy) * fx + + static_cast(ix); + + p_wei[f_idx] = cast_to(value); + } + else + { + size_t f_idx = static_cast(ic) * wei_strides[3] + + static_cast(iz) * wei_strides[2] + + static_cast(iy) * wei_strides[1] + + static_cast(ix) * wei_strides[0]; + + p_wei[f_idx] = cast_to(value); + } } } /***************************** nhwc *****************************/ // design block_size 256 -template +template inline __device__ void naive_conv_fwd_nhwc(const src_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, dst_data_t* __restrict__ p_out, + Strides5D in_strides, + Strides5D wei_strides, + Strides5D out_strides, int hi, int wi, int n, @@ -711,17 +1043,32 @@ inline __device__ void naive_conv_fwd_nhwc(const src_data_t* __restrict__ p_in, int in = (bid / ho) % n; int ig = bid / (n * ho); - p_in += static_cast(in) * hi * wi * c + static_cast(ig) * c_per_group; - p_wei += static_cast(ig) * k_per_group * fy * fx * c_per_group; - p_out += static_cast(in) * ho * wo * k + static_cast(ig) * k_per_group + - static_cast(iho) * wo * k; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * hi * wi * c + static_cast(ig) * c_per_group; + + p_wei += static_cast(ig) * k_per_group * fy * fx * c_per_group; + + p_out += static_cast(in) * ho * wo * k + static_cast(iho) * wo * k + + static_cast(ig) * k_per_group; + } + else + { + p_in += static_cast(in) * in_strides[4] + static_cast(ig) * in_strides[1]; + + p_wei += static_cast(ig) * wei_strides[4]; + + p_out += static_cast(in) * out_strides[4] + + static_cast(iho) * out_strides[3] + + static_cast(ig) * out_strides[1]; + } - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int iwo = tid / k_per_group; int ik = tid % k_per_group; - double value = .0f; + acc_data_t value = 0; for(int iy = 0; iy < fy; iy++) { @@ -740,27 +1087,61 @@ inline 
__device__ void naive_conv_fwd_nhwc(const src_data_t* __restrict__ p_in, if(valid_w & valid_h) { - size_t i_idx = static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + static_cast(ic); - size_t f_idx = static_cast(ik) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); - value += cast_to(p_in[i_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + static_cast(ic); + + size_t f_idx = static_cast(ik) * fy * fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t i_idx = static_cast(cur_h) * in_strides[3] + + static_cast(cur_w) * in_strides[2] + + static_cast(ic) * in_strides[0]; + + size_t f_idx = static_cast(ik) * wei_strides[3] + + static_cast(iy) * wei_strides[2] + + static_cast(ix) * wei_strides[1] + + static_cast(ic) * wei_strides[0]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } } } } } - size_t o_idx = static_cast(iwo) * k + static_cast(ik); - p_out[o_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(iwo) * k + static_cast(ik); + + p_out[o_idx] = cast_to(value); + } + else + { + size_t o_idx = static_cast(iwo) * out_strides[2] + + static_cast(ik) * out_strides[0]; + + p_out[o_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_bwd_nhwc(dst_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides5D in_strides, + Strides5D wei_strides, + Strides5D out_strides, int hi, int wi, int n, @@ -792,17 +1173,32 @@ inline __device__ void naive_conv_bwd_nhwc(dst_data_t* __restrict__ p_in, int in = (bid / hi) % n; int ig = bid / (n * hi); - p_in += static_cast(in) * hi * wi * c + static_cast(ihi) * wi * c + - static_cast(ig) * c_per_group; - p_wei += static_cast(ig) * k_per_group * fy * fx * c_per_group; - p_out += static_cast(in) * ho * wo * k + static_cast(ig) * k_per_group; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * hi * wi * c + static_cast(ihi) * wi * c + + static_cast(ig) * c_per_group; + + p_wei += static_cast(ig) * k_per_group * fy * fx * c_per_group; + + p_out += static_cast(in) * ho * wo * k + static_cast(ig) * k_per_group; + } + else + { + p_in += static_cast(in) * in_strides[4] + static_cast(ihi) * in_strides[3] + + static_cast(ig) * in_strides[1]; + + p_wei += static_cast(ig) * wei_strides[4]; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += + static_cast(in) * out_strides[4] + static_cast(ig) * out_strides[1]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int iwi = tid / c_per_group; int ic = tid % c_per_group; - double value = .0f; + acc_data_t value = 0; for(int iy = 0; iy < fy; iy++) { @@ -827,27 +1223,61 @@ inline __device__ void naive_conv_bwd_nhwc(dst_data_t* __restrict__ p_in, if(valid_h & valid_w) { - size_t o_idx = static_cast(cur_ho) * wo * k + - static_cast(cur_wo) * k + static_cast(ik); - size_t f_idx = static_cast(ik) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); - value += cast_to(p_out[o_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(cur_ho) * wo * k + + static_cast(cur_wo) * k + + static_cast(ik); + + size_t f_idx = static_cast(ik) * fy * fx * 
c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + + static_cast(ic); + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t o_idx = static_cast(cur_ho) * out_strides[3] + + static_cast(cur_wo) * out_strides[2] + + static_cast(ik) * out_strides[0]; + + size_t f_idx = static_cast(ik) * wei_strides[3] + + static_cast(iy) * wei_strides[2] + + static_cast(ix) * wei_strides[1] + + static_cast(ic) * wei_strides[0]; + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } } } } } - size_t i_idx = static_cast(iwi) * c + static_cast(ic); - p_in[i_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(iwi) * c + static_cast(ic); + + p_in[i_idx] = cast_to(value); + } + else + { + size_t i_idx = + static_cast(iwi) * in_strides[2] + static_cast(ic) * in_strides[0]; + p_in[i_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_wrw_nhwc(const src_data_t* __restrict__ p_in, dst_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides5D in_strides, + Strides5D wei_strides, + Strides5D out_strides, int hi, int wi, int n, @@ -879,18 +1309,33 @@ inline __device__ void naive_conv_wrw_nhwc(const src_data_t* __restrict__ p_in, int ik = bid % k_per_group; int ig = bid / k_per_group; - p_in += static_cast(ig) * c_per_group; - p_wei += static_cast(ig) * k_per_group * fy * fx * c_per_group + - static_cast(ik) * fy * fx * c_per_group; - p_out += static_cast(ig) * k_per_group + static_cast(ik); + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(ig) * c_per_group; + + p_wei += static_cast(ig) * k_per_group * fy * fx * c_per_group + + static_cast(ik) * fy * fx * c_per_group; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += static_cast(ig) * k_per_group + static_cast(ik); + } + else + { + p_in += static_cast(ig) * in_strides[1]; + + p_wei += + static_cast(ig) * wei_strides[4] + static_cast(ik) * wei_strides[3]; + + p_out += + static_cast(ig) * out_strides[1] + static_cast(ik) * out_strides[0]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ic = tid % c_per_group; int ix = (tid / c_per_group) % fx; int iy = tid / (c_per_group * fx); - double value = .0f; + acc_data_t value = 0; for(int in = 0; in < n; in++) { @@ -909,29 +1354,65 @@ inline __device__ void naive_conv_wrw_nhwc(const src_data_t* __restrict__ p_in, if(valid_h & valid_w) { - size_t i_idx = static_cast(in) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + static_cast(ic); - size_t o_idx = static_cast(in) * ho * wo * k + - static_cast(iho) * wo * k + - static_cast(iwo) * k; - value += cast_to(p_in[i_idx]) * - cast_to(p_out[o_idx]); + + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(in) * hi * wi * c + + static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + static_cast(ic); + + size_t o_idx = static_cast(in) * ho * wo * k + + static_cast(iho) * wo * k + + static_cast(iwo) * k; + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } + else + { + size_t i_idx = static_cast(in) * in_strides[4] + + static_cast(cur_h) * in_strides[3] + + static_cast(cur_w) * in_strides[2] + + static_cast(ic) * in_strides[0]; + + size_t o_idx = static_cast(in) * out_strides[4] + + static_cast(iho) * out_strides[3] + + static_cast(iwo) * out_strides[2]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } } } } } - size_t f_idx = static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + 
static_cast(ic); - p_wei[f_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t f_idx = static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + static_cast(ic); + + p_wei[f_idx] = cast_to(value); + } + else + { + size_t f_idx = static_cast(iy) * wei_strides[2] + + static_cast(ix) * wei_strides[1] + + static_cast(ic) * wei_strides[0]; + + p_wei[f_idx] = cast_to(value); + } } } // design block_size 256 -template +template inline __device__ void naive_conv_fwd_ndhwc(const src_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, dst_data_t* __restrict__ p_out, + Strides6D in_strides, + Strides6D wei_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -970,18 +1451,36 @@ inline __device__ void naive_conv_fwd_ndhwc(const src_data_t* __restrict__ p_in, int in = (bid / do_) % n; int ig = bid / (n * do_); - p_in += static_cast(in) * di * hi * wi * c + static_cast(ig) * c_per_group; - p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group; - p_out += static_cast(in) * do_ * ho * wo * k + static_cast(ido) * ho * wo * k + - static_cast(ig) * k_per_group; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * di * hi * wi * c + static_cast(ig) * c_per_group; + + p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += static_cast(in) * do_ * ho * wo * k + + static_cast(ido) * ho * wo * k + static_cast(ig) * k_per_group; + } + else + { + // dim order NDHWGC + // replace C and K with G * C_per_G and G * K_per_G + p_in += static_cast(in) * in_strides[5] + static_cast(ig) * in_strides[1]; + + // Assumes that group G is the highest dimension in the layout + p_wei += static_cast(ig) * wei_strides[5]; + + p_out += static_cast(in) * out_strides[5] + + static_cast(ido) * out_strides[4] + + static_cast(ig) * out_strides[1]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ik = tid % k_per_group; int iwo = (tid / k_per_group) % wo; int iho = tid / (k_per_group * wo); - double value = .0f; + acc_data_t value = 0; for(int iz = 0; iz < fz; iz++) { @@ -1005,30 +1504,69 @@ inline __device__ void naive_conv_fwd_ndhwc(const src_data_t* __restrict__ p_in, { if(valid_d & valid_w & valid_h) { - size_t i_idx = static_cast(cur_d) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + static_cast(ic); - size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + - static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); - value += cast_to(p_in[i_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(cur_d) * hi * wi * c + + static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + + size_t f_idx = + static_cast(ik) * fz * fy * fx * c_per_group + + static_cast(iz) * fy * fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + static_cast(ic); + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t i_idx = static_cast(cur_d) * in_strides[4] + + static_cast(cur_h) * in_strides[3] + + static_cast(cur_w) * in_strides[2] + + static_cast(ic) * in_strides[0]; + + size_t f_idx = static_cast(ik) * wei_strides[4] + + static_cast(iz) * wei_strides[3] + + static_cast(iy) * wei_strides[2] + + static_cast(ix) * wei_strides[1] + + static_cast(ic) * wei_strides[0]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_wei[f_idx]); + } } } } } 
} - size_t o_idx = static_cast(iho) * wo * k + static_cast(iwo) * k + - static_cast(ik); - p_out[o_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(iho) * wo * k + static_cast(iwo) * k + + static_cast(ik); + + p_out[o_idx] = cast_to(value); + } + else + { + size_t o_idx = static_cast(iho) * out_strides[3] + + static_cast(iwo) * out_strides[2] + + static_cast(ik) * out_strides[0]; + + p_out[o_idx] = cast_to(value); + } } } -template + +template inline __device__ void naive_conv_bwd_ndhwc(dst_data_t* __restrict__ p_in, const src_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides6D in_strides, + Strides6D wei_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -1052,6 +1590,7 @@ inline __device__ void naive_conv_bwd_ndhwc(dst_data_t* __restrict__ p_in, int fx, int group) { + /* * need to compute total input pixel: `group * n * di * hi * wi * * c_per_group`. @@ -1067,18 +1606,34 @@ inline __device__ void naive_conv_bwd_ndhwc(dst_data_t* __restrict__ p_in, int in = (bid / di) % n; int ig = bid / (n * di); - p_in += static_cast(in) * di * hi * wi * c + static_cast(idi) * hi * wi * c + - static_cast(ig) * c_per_group; - p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group; - p_out += static_cast(in) * do_ * ho * wo * k + static_cast(ig) * k_per_group; + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(in) * di * hi * wi * c + + static_cast(idi) * hi * wi * c + static_cast(ig) * c_per_group; + + p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group; + + p_out += + static_cast(in) * do_ * ho * wo * k + static_cast(ig) * k_per_group; + } + else + { + p_in += static_cast(in) * in_strides[5] + static_cast(idi) * in_strides[4] + + static_cast(ig) * in_strides[1]; + + p_wei += static_cast(ig) * wei_strides[5]; + + p_out += + static_cast(in) * out_strides[5] + static_cast(ig) * out_strides[1]; + } - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ic = tid % c_per_group; int iwi = (tid / c_per_group) % wi; int ihi = (tid / (c_per_group * wi)); - double value = .0f; + acc_data_t value = 0; for(int iz = 0; iz < fz; iz++) { @@ -1111,32 +1666,69 @@ inline __device__ void naive_conv_bwd_ndhwc(dst_data_t* __restrict__ p_in, { if(valid_d & valid_h & valid_w) { - size_t o_idx = static_cast(cur_do) * ho * wo * k + - static_cast(cur_ho) * wo * k + - static_cast(cur_wo) * k + - static_cast(ik); - size_t f_idx = static_cast(ik) * fz * fy * fx * c_per_group + - static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - static_cast(ix) * c_per_group + - static_cast(ic); - value += cast_to(p_out[o_idx]) * - cast_to(p_wei[f_idx]); + if constexpr(ASSUME_PACKED) + { + size_t o_idx = static_cast(cur_do) * ho * wo * k + + static_cast(cur_ho) * wo * k + + static_cast(cur_wo) * k + + static_cast(ik); + + size_t f_idx = + static_cast(ik) * fz * fy * fx * c_per_group + + static_cast(iz) * fy * fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + static_cast(ic); + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } + else + { + size_t o_idx = static_cast(cur_do) * out_strides[4] + + static_cast(cur_ho) * out_strides[3] + + static_cast(cur_wo) * out_strides[2] + + static_cast(ik) * out_strides[0]; + + size_t f_idx = static_cast(ik) * wei_strides[4] + + static_cast(iz) * wei_strides[3] + + static_cast(iy) * wei_strides[2] + + static_cast(ix) * wei_strides[1] + + 
static_cast(ic) * wei_strides[0]; + + value += cast_to(p_out[o_idx]) * + cast_to(p_wei[f_idx]); + } } } } } } - size_t i_idx = static_cast(ihi) * wi * c + static_cast(iwi) * c + - static_cast(ic); - p_in[i_idx] = cast_to(value); + + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(ihi) * wi * c + static_cast(iwi) * c + + static_cast(ic); + + p_in[i_idx] = cast_to(value); + } + else + { + size_t i_idx = static_cast(ihi) * in_strides[3] + + static_cast(iwi) * in_strides[2] + + static_cast(ic) * in_strides[0]; + + p_in[i_idx] = cast_to(value); + } } } -template +template inline __device__ void naive_conv_wrw_ndhwc(const src_data_t* __restrict__ p_in, dst_data_t* __restrict__ p_wei, const src_data_t* __restrict__ p_out, + Strides6D in_strides, + Strides6D wei_strides, + Strides6D out_strides, int di, int hi, int wi, @@ -1174,19 +1766,34 @@ inline __device__ void naive_conv_wrw_ndhwc(const src_data_t* __restrict__ p_in, int ik = bid % k_per_group; int ig = bid / k_per_group; - p_in += static_cast(ig) * c_per_group; - p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group + - static_cast(ik) * fz * fy * fx * c_per_group; - p_out += static_cast(ig) * k_per_group + static_cast(ik); + if constexpr(ASSUME_PACKED) + { + p_in += static_cast(ig) * c_per_group; + + p_wei += static_cast(ig) * k_per_group * fz * fy * fx * c_per_group + + static_cast(ik) * fz * fy * fx * c_per_group; - for(int tid = threadIdx.x; tid < thread_length; tid += 256) + p_out += static_cast(ig) * k_per_group + static_cast(ik); + } + else + { + p_in += static_cast(ig) * in_strides[1]; + + p_wei += + static_cast(ig) * wei_strides[5] + static_cast(ik) * wei_strides[4]; + + p_out += + static_cast(ig) * out_strides[1] + static_cast(ik) * out_strides[0]; + } + + for(int tid = threadIdx.x; tid < thread_length; tid += blockDim.x) { int ic = tid % c_per_group; int ix = (tid / c_per_group) % fx; int iy = (tid / (c_per_group * fx)) % fy; int iz = (tid / (c_per_group * fx * fy)); - double value = .0f; + acc_data_t value = 0; for(int in = 0; in < n; in++) { @@ -1211,374 +1818,340 @@ inline __device__ void naive_conv_wrw_ndhwc(const src_data_t* __restrict__ p_in, if(valid_d & valid_h & valid_w) { - size_t i_idx = static_cast(in) * di * hi * wi * c + - static_cast(cur_d) * hi * wi * c + - static_cast(cur_h) * wi * c + - static_cast(cur_w) * c + static_cast(ic); - size_t o_idx = static_cast(in) * do_ * ho * wo * k + - static_cast(ido) * ho * wo * k + - static_cast(iho) * wo * k + - static_cast(iwo) * k; - value += cast_to(p_in[i_idx]) * - cast_to(p_out[o_idx]); + + if constexpr(ASSUME_PACKED) + { + size_t i_idx = static_cast(in) * di * hi * wi * c + + static_cast(cur_d) * hi * wi * c + + static_cast(cur_h) * wi * c + + static_cast(cur_w) * c + + static_cast(ic); + + size_t o_idx = static_cast(in) * do_ * ho * wo * k + + static_cast(ido) * ho * wo * k + + static_cast(iho) * wo * k + + static_cast(iwo) * k; + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } + else + { + + size_t i_idx = static_cast(in) * in_strides[5] + + static_cast(cur_d) * in_strides[4] + + static_cast(cur_h) * in_strides[3] + + static_cast(cur_w) * in_strides[2] + + static_cast(ic) * in_strides[0]; + + size_t o_idx = static_cast(in) * out_strides[5] + + static_cast(ido) * out_strides[4] + + static_cast(iho) * out_strides[3] + + static_cast(iwo) * out_strides[2]; + + value += cast_to(p_in[i_idx]) * + cast_to(p_out[o_idx]); + } } } } } } - size_t f_idx = static_cast(iz) * fy * fx * c_per_group + - static_cast(iy) * fx * c_per_group + - 
static_cast(ix) * c_per_group + static_cast(ic); - p_wei[f_idx] = cast_to(value); - } -} - -#define DEFINE_2D_NAIVE_FWD_CONV_KERNEL(tensor_layout, src_data_t, acc_data_t, dst_data_t) \ - extern "C" __global__ void \ - naive_conv_fwd_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ - src_data_t* __restrict__ p_in, \ - src_data_t* __restrict__ p_wei, \ - dst_data_t* __restrict__ p_out, \ - int hi, \ - int wi, \ - int n, \ - int k_per_group, \ - int c_per_group, \ - int ho, \ - int wo, \ - int sy, \ - int sx, \ - int dy, \ - int dx, \ - int py, \ - int px, \ - int fy, \ - int fx, \ - int group) \ - { \ - naive_conv_fwd_##tensor_layout(p_in, \ - p_wei, \ - p_out, \ - hi, \ - wi, \ - n, \ - k_per_group, \ - c_per_group, \ - ho, \ - wo, \ - sy, \ - sx, \ - dy, \ - dx, \ - py, \ - px, \ - fy, \ - fx, \ - group); \ - } -#define DEFINE_2D_NAIVE_BWD_CONV_KERNEL(tensor_layout, src_data_t, acc_data_t, dst_data_t) \ - extern "C" __global__ void \ - naive_conv_bwd_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ - dst_data_t* __restrict__ p_in, \ - src_data_t* __restrict__ p_wei, \ - src_data_t* __restrict__ p_out, \ - int hi, \ - int wi, \ - int n, \ - int k_per_group, \ - int c_per_group, \ - int ho, \ - int wo, \ - int sy, \ - int sx, \ - int dy, \ - int dx, \ - int py, \ - int px, \ - int fy, \ - int fx, \ - int group) \ - { \ - naive_conv_bwd_##tensor_layout(p_in, \ - p_wei, \ - p_out, \ - hi, \ - wi, \ - n, \ - k_per_group, \ - c_per_group, \ - ho, \ - wo, \ - sy, \ - sx, \ - dy, \ - dx, \ - py, \ - px, \ - fy, \ - fx, \ - group); \ - } + if constexpr(ASSUME_PACKED) + { + size_t f_idx = static_cast(iz) * fy * fx * c_per_group + + static_cast(iy) * fx * c_per_group + + static_cast(ix) * c_per_group + static_cast(ic); -#define DEFINE_2D_NAIVE_WRW_CONV_KERNEL(tensor_layout, src_data_t, acc_data_t, dst_data_t) \ - extern "C" __global__ void \ - naive_conv_wrw_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ - src_data_t* __restrict__ p_in, \ - dst_data_t* __restrict__ p_wei, \ - src_data_t* __restrict__ p_out, \ - int hi, \ - int wi, \ - int n, \ - int k_per_group, \ - int c_per_group, \ - int ho, \ - int wo, \ - int sy, \ - int sx, \ - int dy, \ - int dx, \ - int py, \ - int px, \ - int fy, \ - int fx, \ - int group) \ - { \ - naive_conv_wrw_##tensor_layout(p_in, \ - p_wei, \ - p_out, \ - hi, \ - wi, \ - n, \ - k_per_group, \ - c_per_group, \ - ho, \ - wo, \ - sy, \ - sx, \ - dy, \ - dx, \ - py, \ - px, \ - fy, \ - fx, \ - group); \ - } + p_wei[f_idx] = cast_to(value); + } + else + { + size_t f_idx = static_cast(iz) * wei_strides[3] + + static_cast(iy) * wei_strides[2] + + static_cast(ix) * wei_strides[1] + + static_cast(ic) * wei_strides[0]; -#define DEFINE_3D_NAIVE_FWD_CONV_KERNEL(tensor_layout, src_data_t, acc_data_t, dst_data_t) \ - extern "C" __global__ void \ - naive_conv_fwd_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ - src_data_t* __restrict__ p_in, \ - src_data_t* __restrict__ p_wei, \ - dst_data_t* __restrict__ p_out, \ - int di, \ - int hi, \ - int wi, \ - int n, \ - int k_per_group, \ - int c_per_group, \ - int do_, \ - int ho, \ - int wo, \ - int sz, \ - int sy, \ - int sx, \ - int dz, \ - int dy, \ - int dx, \ - int pz, \ - int py, \ - int px, \ - int fz, \ - int fy, \ - int fx, \ - int group) \ - { \ - naive_conv_fwd_##tensor_layout(p_in, \ - p_wei, \ - p_out, \ - di, \ - hi, \ - wi, \ - n, \ - k_per_group, \ - c_per_group, \ - do_, \ - ho, \ - wo, \ - sz, \ - sy, \ - sx, \ - dz, \ - dy, \ - dx, \ - pz, \ - py, \ - px, \ 
- fz, \ - fy, \ - fx, \ - group); \ + p_wei[f_idx] = cast_to(value); + } } +} -#define DEFINE_3D_NAIVE_BWD_CONV_KERNEL(tensor_layout, src_data_t, acc_data_t, dst_data_t) \ - extern "C" __global__ void \ - naive_conv_bwd_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ - dst_data_t* __restrict__ p_in, \ - src_data_t* __restrict__ p_wei, \ - src_data_t* __restrict__ p_out, \ - int di, \ - int hi, \ - int wi, \ - int n, \ - int k_per_group, \ - int c_per_group, \ - int do_, \ - int ho, \ - int wo, \ - int sz, \ - int sy, \ - int sx, \ - int dz, \ - int dy, \ - int dx, \ - int pz, \ - int py, \ - int px, \ - int fz, \ - int fy, \ - int fx, \ - int group) \ - { \ - naive_conv_bwd_##tensor_layout(p_in, \ - p_wei, \ - p_out, \ - di, \ - hi, \ - wi, \ - n, \ - k_per_group, \ - c_per_group, \ - do_, \ - ho, \ - wo, \ - sz, \ - sy, \ - sx, \ - dz, \ - dy, \ - dx, \ - pz, \ - py, \ - px, \ - fz, \ - fy, \ - fx, \ - group); \ +#define DEFINE_2D_NAIVE_CONV_KERNEL(direction, tensor_layout, src_data_t, acc_data_t, dst_data_t) \ + extern "C" __global__ void \ + naive_conv_packed_##direction##_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ + src_data_t* __restrict__ p_in, \ + src_data_t* __restrict__ p_wei, \ + dst_data_t* __restrict__ p_out, \ + Strides5D in_strides, \ + Strides5D wei_strides, \ + Strides5D out_strides, \ + int hi, \ + int wi, \ + int n, \ + int k_per_group, \ + int c_per_group, \ + int ho, \ + int wo, \ + int sy, \ + int sx, \ + int dy, \ + int dx, \ + int py, \ + int px, \ + int fy, \ + int fx, \ + int group) \ + { \ + naive_conv_##direction##_##tensor_layout( \ + p_in, \ + p_wei, \ + p_out, \ + in_strides, \ + wei_strides, \ + out_strides, \ + hi, \ + wi, \ + n, \ + k_per_group, \ + c_per_group, \ + ho, \ + wo, \ + sy, \ + sx, \ + dy, \ + dx, \ + py, \ + px, \ + fy, \ + fx, \ + group); \ + } \ + extern "C" __global__ void \ + naive_conv_nonpacked_##direction##_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ + src_data_t* __restrict__ p_in, \ + src_data_t* __restrict__ p_wei, \ + dst_data_t* __restrict__ p_out, \ + Strides5D in_strides, \ + Strides5D wei_strides, \ + Strides5D out_strides, \ + int hi, \ + int wi, \ + int n, \ + int k_per_group, \ + int c_per_group, \ + int ho, \ + int wo, \ + int sy, \ + int sx, \ + int dy, \ + int dx, \ + int py, \ + int px, \ + int fy, \ + int fx, \ + int group) \ + { \ + naive_conv_##direction##_##tensor_layout( \ + p_in, \ + p_wei, \ + p_out, \ + in_strides, \ + wei_strides, \ + out_strides, \ + hi, \ + wi, \ + n, \ + k_per_group, \ + c_per_group, \ + ho, \ + wo, \ + sy, \ + sx, \ + dy, \ + dx, \ + py, \ + px, \ + fy, \ + fx, \ + group); \ } -#define DEFINE_3D_NAIVE_WRW_CONV_KERNEL(tensor_layout, src_data_t, acc_data_t, dst_data_t) \ - extern "C" __global__ void \ - naive_conv_wrw_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ - src_data_t* __restrict__ p_in, \ - dst_data_t* __restrict__ p_wei, \ - src_data_t* __restrict__ p_out, \ - int di, \ - int hi, \ - int wi, \ - int n, \ - int k_per_group, \ - int c_per_group, \ - int do_, \ - int ho, \ - int wo, \ - int sz, \ - int sy, \ - int sx, \ - int dz, \ - int dy, \ - int dx, \ - int pz, \ - int py, \ - int px, \ - int fz, \ - int fy, \ - int fx, \ - int group) \ - { \ - naive_conv_wrw_##tensor_layout(p_in, \ - p_wei, \ - p_out, \ - di, \ - hi, \ - wi, \ - n, \ - k_per_group, \ - c_per_group, \ - do_, \ - ho, \ - wo, \ - sz, \ - sy, \ - sx, \ - dz, \ - dy, \ - dx, \ - pz, \ - py, \ - px, \ - fz, \ - fy, \ - fx, \ - group); \ +#define 
DEFINE_3D_NAIVE_CONV_KERNEL(direction, tensor_layout, src_data_t, acc_data_t, dst_data_t) \ + extern "C" __global__ void \ + naive_conv_packed_##direction##_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ + src_data_t* __restrict__ p_in, \ + src_data_t* __restrict__ p_wei, \ + dst_data_t* __restrict__ p_out, \ + Strides6D in_strides, \ + Strides6D wei_strides, \ + Strides6D out_strides, \ + int di, \ + int hi, \ + int wi, \ + int n, \ + int k_per_group, \ + int c_per_group, \ + int do_, \ + int ho, \ + int wo, \ + int sz, \ + int sy, \ + int sx, \ + int dz, \ + int dy, \ + int dx, \ + int pz, \ + int py, \ + int px, \ + int fz, \ + int fy, \ + int fx, \ + int group) \ + { \ + naive_conv_##direction##_##tensor_layout( \ + p_in, \ + p_wei, \ + p_out, \ + in_strides, \ + wei_strides, \ + out_strides, \ + di, \ + hi, \ + wi, \ + n, \ + k_per_group, \ + c_per_group, \ + do_, \ + ho, \ + wo, \ + sz, \ + sy, \ + sx, \ + dz, \ + dy, \ + dx, \ + pz, \ + py, \ + px, \ + fz, \ + fy, \ + fx, \ + group); \ + } \ + extern "C" __global__ void \ + naive_conv_nonpacked_##direction##_##tensor_layout##_##src_data_t##_##acc_data_t##_##dst_data_t( \ + src_data_t* __restrict__ p_in, \ + src_data_t* __restrict__ p_wei, \ + dst_data_t* __restrict__ p_out, \ + Strides6D in_strides, \ + Strides6D wei_strides, \ + Strides6D out_strides, \ + int di, \ + int hi, \ + int wi, \ + int n, \ + int k_per_group, \ + int c_per_group, \ + int do_, \ + int ho, \ + int wo, \ + int sz, \ + int sy, \ + int sx, \ + int dz, \ + int dy, \ + int dx, \ + int pz, \ + int py, \ + int px, \ + int fz, \ + int fy, \ + int fx, \ + int group) \ + { \ + naive_conv_##direction##_##tensor_layout( \ + p_in, \ + p_wei, \ + p_out, \ + in_strides, \ + wei_strides, \ + out_strides, \ + di, \ + hi, \ + wi, \ + n, \ + k_per_group, \ + c_per_group, \ + do_, \ + ho, \ + wo, \ + sz, \ + sy, \ + sx, \ + dz, \ + dy, \ + dx, \ + pz, \ + py, \ + px, \ + fz, \ + fy, \ + fx, \ + group); \ } -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nchw, float, double, float) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nchw, half, double, half) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nchw, ushort, double, ushort) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nchw, int8_t, int32_t, int32_t) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nchw, int8_t, int32_t, float) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nhwc, float, double, float) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nhwc, half, double, half) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nhwc, ushort, double, ushort) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nhwc, int8_t, int32_t, int32_t) -DEFINE_2D_NAIVE_FWD_CONV_KERNEL(nhwc, int8_t, int32_t, float) - -DEFINE_2D_NAIVE_BWD_CONV_KERNEL(nchw, float, double, float) -DEFINE_2D_NAIVE_BWD_CONV_KERNEL(nchw, half, double, half) -DEFINE_2D_NAIVE_BWD_CONV_KERNEL(nchw, ushort, double, ushort) -DEFINE_2D_NAIVE_BWD_CONV_KERNEL(nhwc, float, double, float) -DEFINE_2D_NAIVE_BWD_CONV_KERNEL(nhwc, half, double, half) -DEFINE_2D_NAIVE_BWD_CONV_KERNEL(nhwc, ushort, double, ushort) - -DEFINE_2D_NAIVE_WRW_CONV_KERNEL(nchw, float, double, float) -DEFINE_2D_NAIVE_WRW_CONV_KERNEL(nchw, half, double, half) -DEFINE_2D_NAIVE_WRW_CONV_KERNEL(nchw, ushort, double, ushort) -DEFINE_2D_NAIVE_WRW_CONV_KERNEL(nhwc, float, double, float) -DEFINE_2D_NAIVE_WRW_CONV_KERNEL(nhwc, half, double, half) -DEFINE_2D_NAIVE_WRW_CONV_KERNEL(nhwc, ushort, double, ushort) - -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ncdhw, float, double, float) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ncdhw, half, double, half) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ncdhw, ushort, double, ushort) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ncdhw, 
int8_t, int32_t, int32_t) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ncdhw, int8_t, int32_t, float) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ndhwc, float, double, float) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ndhwc, half, double, half) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ndhwc, ushort, double, ushort) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ndhwc, int8_t, int32_t, int32_t) -DEFINE_3D_NAIVE_FWD_CONV_KERNEL(ndhwc, int8_t, int32_t, float) - -DEFINE_3D_NAIVE_BWD_CONV_KERNEL(ncdhw, float, double, float) -DEFINE_3D_NAIVE_BWD_CONV_KERNEL(ncdhw, half, double, half) -DEFINE_3D_NAIVE_BWD_CONV_KERNEL(ncdhw, ushort, double, ushort) -DEFINE_3D_NAIVE_BWD_CONV_KERNEL(ndhwc, float, double, float) -DEFINE_3D_NAIVE_BWD_CONV_KERNEL(ndhwc, half, double, half) -DEFINE_3D_NAIVE_BWD_CONV_KERNEL(ndhwc, ushort, double, ushort) - -DEFINE_3D_NAIVE_WRW_CONV_KERNEL(ncdhw, float, double, float) -DEFINE_3D_NAIVE_WRW_CONV_KERNEL(ncdhw, half, double, half) -DEFINE_3D_NAIVE_WRW_CONV_KERNEL(ncdhw, ushort, double, ushort) -DEFINE_3D_NAIVE_WRW_CONV_KERNEL(ndhwc, float, double, float) -DEFINE_3D_NAIVE_WRW_CONV_KERNEL(ndhwc, half, double, half) -DEFINE_3D_NAIVE_WRW_CONV_KERNEL(ndhwc, ushort, double, ushort) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, float, double, float) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, half, double, half) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, ushort, double, ushort) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, int8_t, int32_t, int32_t) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nchw, int8_t, int32_t, float) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, float, double, float) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, half, double, half) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, ushort, double, ushort) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, int8_t, int32_t, int32_t) +DEFINE_2D_NAIVE_CONV_KERNEL(fwd, nhwc, int8_t, int32_t, float) + +DEFINE_2D_NAIVE_CONV_KERNEL(bwd, nchw, float, double, float) +DEFINE_2D_NAIVE_CONV_KERNEL(bwd, nchw, half, double, half) +DEFINE_2D_NAIVE_CONV_KERNEL(bwd, nchw, ushort, double, ushort) +DEFINE_2D_NAIVE_CONV_KERNEL(bwd, nhwc, float, double, float) +DEFINE_2D_NAIVE_CONV_KERNEL(bwd, nhwc, half, double, half) +DEFINE_2D_NAIVE_CONV_KERNEL(bwd, nhwc, ushort, double, ushort) + +DEFINE_2D_NAIVE_CONV_KERNEL(wrw, nchw, float, double, float) +DEFINE_2D_NAIVE_CONV_KERNEL(wrw, nchw, half, double, half) +DEFINE_2D_NAIVE_CONV_KERNEL(wrw, nchw, ushort, double, ushort) +DEFINE_2D_NAIVE_CONV_KERNEL(wrw, nhwc, float, double, float) +DEFINE_2D_NAIVE_CONV_KERNEL(wrw, nhwc, half, double, half) +DEFINE_2D_NAIVE_CONV_KERNEL(wrw, nhwc, ushort, double, ushort) + +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ncdhw, float, double, float) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ncdhw, half, double, half) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ncdhw, ushort, double, ushort) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ncdhw, int8_t, int32_t, int32_t) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ncdhw, int8_t, int32_t, float) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ndhwc, float, double, float) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ndhwc, half, double, half) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ndhwc, ushort, double, ushort) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ndhwc, int8_t, int32_t, int32_t) +DEFINE_3D_NAIVE_CONV_KERNEL(fwd, ndhwc, int8_t, int32_t, float) + +DEFINE_3D_NAIVE_CONV_KERNEL(bwd, ncdhw, float, double, float) +DEFINE_3D_NAIVE_CONV_KERNEL(bwd, ncdhw, half, double, half) +DEFINE_3D_NAIVE_CONV_KERNEL(bwd, ncdhw, ushort, double, ushort) +DEFINE_3D_NAIVE_CONV_KERNEL(bwd, ndhwc, float, double, float) +DEFINE_3D_NAIVE_CONV_KERNEL(bwd, ndhwc, half, double, half) +DEFINE_3D_NAIVE_CONV_KERNEL(bwd, ndhwc, ushort, double, 
ushort) + +DEFINE_3D_NAIVE_CONV_KERNEL(wrw, ncdhw, float, double, float) +DEFINE_3D_NAIVE_CONV_KERNEL(wrw, ncdhw, half, double, half) +DEFINE_3D_NAIVE_CONV_KERNEL(wrw, ncdhw, ushort, double, ushort) +DEFINE_3D_NAIVE_CONV_KERNEL(wrw, ndhwc, float, double, float) +DEFINE_3D_NAIVE_CONV_KERNEL(wrw, ndhwc, half, double, half) +DEFINE_3D_NAIVE_CONV_KERNEL(wrw, ndhwc, ushort, double, ushort) + +/// \todo discuss whether we should split the kernels into separate files, or +/// figure out a mechanism to compile each kernel separately to reduce hipRTC +/// compilation times. --amberhassaan diff --git a/src/kernels/stride_array.hpp b/src/kernels/stride_array.hpp new file mode 100644 index 0000000000..32cb1f85b6 --- /dev/null +++ b/src/kernels/stride_array.hpp @@ -0,0 +1,86 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
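Each DEFINE_2D/3D_NAIVE_CONV_KERNEL instantiation above now emits two extern "C" entry points per (direction, layout, src, acc, dst) combination, one packed and one non-packed, and the host side chooses between them purely by name. The helper below is only an illustrative sketch of how such a name string is put together; nothing but the suffix order (which mirrors the macro token pasting above) is taken from the patch.

#include <sstream>
#include <string>

// Illustrative only: builds names of the form
// naive_conv_<packed|nonpacked>_<direction>_<layout>_<src>_<acc>_<dst>
std::string ComposeNaiveConvKernelName(bool packed,
                                       const std::string& direction, // "fwd", "bwd", "wrw"
                                       const std::string& layout,    // e.g. "nchw", "nhwc"
                                       const std::string& src,
                                       const std::string& acc,
                                       const std::string& dst)
{
    std::ostringstream name;
    name << (packed ? "naive_conv_packed_" : "naive_conv_nonpacked_") << direction << '_'
         << layout << '_' << src << '_' << acc << '_' << dst;
    return name.str();
}

// ComposeNaiveConvKernelName(false, "fwd", "nhwc", "float", "double", "float")
//   -> "naive_conv_nonpacked_fwd_nhwc_float_double_float"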
+ * + *******************************************************************************/ +#pragma once + +#ifdef __HIPCC_RTC__ +#ifndef WORKAROUND_ISSUE_HIPRTC_TRUE_TYPE +#include +#endif +#endif // __HIPCC_RTC__ + +/// \todo Uncomment when hip RTC accepts std::array -- amberhassaan +// #include +// using StrideIndexType = int; +// using Strides3D = std::array; +// using Strides4D = std::array; +// using Strides5D = std::array; +// using Strides6D = std::array; +template +class MyArray +{ + T data_[N] = {}; + +public: + constexpr static const unsigned SIZE = N; + __host__ __device__ constexpr unsigned size() const { return N; } + + __host__ __device__ const T& operator[](unsigned i) const { return data_[i]; } + + __host__ T& operator[](unsigned i) { return data_[i]; } + + __host__ __device__ MyArray() = default; + __host__ __device__ MyArray(const MyArray&) = default; + __host__ __device__ MyArray(MyArray&&) noexcept = default; + __host__ __device__ MyArray& operator=(const MyArray&) = default; + __host__ __device__ MyArray& operator=(MyArray&&) noexcept = default; + __host__ __device__ ~MyArray() = default; +}; + +using StrideIndexType = size_t; +using Strides5D = MyArray; +using Strides6D = MyArray; + +template +__host__ __device__ void printStrideArray(const char* name, const StrideArray& sarr) +{ + printf("%s = [", name); + for(int i = 0; i < StrideArray::SIZE; ++i) + { + printf("%zu,", sarr[i]); + } + printf("]\n"); +} + +template +__host__ __device__ void printStrideArrays(const StrideArray& in_strides, + const StrideArray& wei_strides, + const StrideArray& out_strides) +{ + + printStrideArray("in_strides", in_strides); + printStrideArray("wei_strides", wei_strides); + printStrideArray("out_strides", out_strides); +} diff --git a/src/solver/conv_direct_naive_conv.cpp b/src/solver/conv_direct_naive_conv.cpp index 5c468768fa..86a8a4161e 100644 --- a/src/solver/conv_direct_naive_conv.cpp +++ b/src/solver/conv_direct_naive_conv.cpp @@ -24,6 +24,7 @@ * *******************************************************************************/ +#include "miopen/env.hpp" #include #include #include @@ -105,10 +106,20 @@ bool IsOutputInt32(const ProblemDescription& problem) problem.GetOutDataType() == miopenInt32; } +MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_USE_PACKED_KERNELS); + std::string ConvDirectNaiveConvKernelName(const ProblemDescription& problem) { std::ostringstream kernel_name; - kernel_name << "naive_conv_"; + if(miopen::IsEnabled(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_USE_PACKED_KERNELS())) + { + kernel_name << "naive_conv_packed_"; + } + else + { + kernel_name << "naive_conv_nonpacked_"; + } + if(problem.direction.IsForward()) kernel_name << "fwd_"; else if(problem.direction.IsBackwardData()) @@ -244,5 +255,49 @@ bool ConvDirectNaiveConvIsApplicableByKernelType(const ExecutionContext& ctx, return true; } +/// Figure out the index of C (channel) stride so we can expand it into +/// (G, C_per_group). Return value G_stride_idx is the position of G stride +/// in the stride vector, such that the (G_stride_idx - 1) is the index that +/// contains C's stride as a multiplying factor +int conv_internal::GetGroupStrideIndex(const ProblemDescription& problem) +{ + int G_stride_idx = -1; + if(problem.IsLayoutDefault()) + { + G_stride_idx = 1; + } + else + { + assert(problem.IsLayoutNHWC()); + assert(problem.Is2d() || problem.Is3d()); + // + // G_stride_idx = problem.Is2d() ? 
3 : 4; + // For NHWC, MIOpen stores strides in NCHW order, so we are interested in 1 + W's + // stride as that will be the value of G_stride_idx; + G_stride_idx = problem.Is2d() ? 4 : 5; + } + assert(G_stride_idx != -1); + return G_stride_idx; +} + +void conv_internal::DebugPrintTensorStrides(const TensorDescriptor& inDesc, + const TensorDescriptor& wDesc, + const TensorDescriptor& outDesc) +{ + + auto printOneStrideVec = [](const char* name, const auto& vec) { + MIOPEN_LOG_I(name << " = ["); + for(const size_t v : vec) + { + MIOPEN_LOG_I(v << ","); + } + MIOPEN_LOG_I("]\n"); + }; + + printOneStrideVec("inDesc = ", inDesc.GetStrides()); + printOneStrideVec("wDesc = ", wDesc.GetStrides()); + printOneStrideVec("outDesc = ", outDesc.GetStrides()); +} + } // namespace solver } // namespace miopen diff --git a/src/solver/conv_direct_naive_conv_bwd.cpp b/src/solver/conv_direct_naive_conv_bwd.cpp index f8af0ec2d1..1a28f8aae6 100644 --- a/src/solver/conv_direct_naive_conv_bwd.cpp +++ b/src/solver/conv_direct_naive_conv_bwd.cpp @@ -142,14 +142,27 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, }(); kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem); + int G_stride_idx = conv_internal::GetGroupStrideIndex(problem); + if(problem.Is2d()) + { result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { decltype(auto) data_ctx = primitive_parameters.CastTo(); const auto& tensors = data_ctx.tensors; float elapsed = 0; + auto in_strides = conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( + group, tensors.inDesc.GetStrides(), G_stride_idx)); + // For weights, we split K to (G, K_per_group), which is always index 0 + auto wei_strides = conv_internal::MakeStrideArray<5>( + conv_internal::SplitWeiStrideKtoGK(k_per_group, tensors.wDesc.GetStrides())); + auto out_strides = + conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( + group, tensors.outDesc.GetStrides(), G_stride_idx)); + /// \ref backward_tensors_reversed_why if(is_f8) + { handle.Run(kern)(tensors.out, tensors.w, tensors.in, @@ -172,10 +185,15 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, problem.GetConv().attribute.fp8rounding_mode.Get() == miopenF8RoundingModeStochastic, problem.GetConv().attribute.fp8rounding_mode.GetSeed()); + } else + { handle.Run(kern)(tensors.out, tensors.w, tensors.in, + out_strides, + wei_strides, + in_strides, hi, wi, n, @@ -192,6 +210,7 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, fy, fx, group); + } if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); @@ -202,7 +221,9 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, } }; }; + } else + { result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { @@ -210,9 +231,26 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, const auto& tensors = data_ctx.tensors; float elapsed = 0; + auto in_strides = conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( + group, tensors.inDesc.GetStrides(), G_stride_idx)); + // For weights, we split K to (G, K_per_group), which is always index 0 + auto wei_strides = conv_internal::MakeStrideArray<6>( + conv_internal::SplitWeiStrideKtoGK(k_per_group, tensors.wDesc.GetStrides())); + auto out_strides 
= + conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( + group, tensors.outDesc.GetStrides(), G_stride_idx)); + + /// \anchor backward_tensors_reversed_why + /// \todo Someone made the silly decision of swapping in and + /// out pointers in ConvTensors for backward pass, so now I have to + /// pass out in place of in, out_strides in place of in_strides and + /// vice-versa --amberhassaan handle.Run(kern)(tensors.out, tensors.w, tensors.in, + out_strides, + wei_strides, + in_strides, di, hi, wi, @@ -245,6 +283,7 @@ ConvSolution ConvDirectNaiveConvBwd::GetSolution(const ExecutionContext& ctx, } }; }; + } result.construction_params.push_back(kernel); return result; } diff --git a/src/solver/conv_direct_naive_conv_fwd.cpp b/src/solver/conv_direct_naive_conv_fwd.cpp index 90d8feee31..a4656d929a 100644 --- a/src/solver/conv_direct_naive_conv_fwd.cpp +++ b/src/solver/conv_direct_naive_conv_fwd.cpp @@ -27,7 +27,6 @@ #include #include #include -#include MIOPEN_DECLARE_ENV_VAR(MIOPEN_DEBUG_CONV_DIRECT_NAIVE_CONV_FWD) @@ -142,13 +141,26 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, kernel.comp_options = ConvDirectNaiveConvCompileOption(ctx, problem); + int G_stride_idx = conv_internal::GetGroupStrideIndex(problem); + if(problem.Is2d()) + { result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { decltype(auto) data_ctx = primitive_parameters.CastTo(); const auto& tensors = data_ctx.tensors; float elapsed = 0; + + auto in_strides = conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( + group, tensors.inDesc.GetStrides(), G_stride_idx)); + // For weights, we split K to (G, K_per_group), which is always index 0 + auto wei_strides = conv_internal::MakeStrideArray<5>( + conv_internal::SplitWeiStrideKtoGK(k_per_group, tensors.wDesc.GetStrides())); + auto out_strides = + conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( + group, tensors.outDesc.GetStrides(), G_stride_idx)); + if(is_f8) { handle.Run(kern)(tensors.in, @@ -179,6 +191,9 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, handle.Run(kern)(tensors.in, tensors.w, tensors.out, + in_strides, + wei_strides, + out_strides, hi, wi, n, @@ -206,7 +221,9 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, } }; }; + } else + { result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { @@ -214,9 +231,20 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, const auto& tensors = data_ctx.tensors; float elapsed = 0; + auto in_strides = conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( + group, tensors.inDesc.GetStrides(), G_stride_idx)); + // For weights, we split K to (G, K_per_group), which is always index 0 + auto wei_strides = conv_internal::MakeStrideArray<6>( + conv_internal::SplitWeiStrideKtoGK(k_per_group, tensors.wDesc.GetStrides())); + auto out_strides = + conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( + group, tensors.outDesc.GetStrides(), G_stride_idx)); handle.Run(kern)(tensors.in, tensors.w, tensors.out, + in_strides, + wei_strides, + out_strides, di, hi, wi, @@ -249,6 +277,7 @@ ConvSolution ConvDirectNaiveConvFwd::GetSolution(const ExecutionContext& ctx, } }; }; + } result.construction_params.push_back(kernel); return result; 
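The SplitStrideCtoGC, SplitWeiStrideKtoGK and MakeStrideArray helpers used in these invokers are defined elsewhere in this change; the arithmetic they rely on is simply that viewing C as (G, C_per_group) requires one extra stride. A minimal standalone sketch of that idea for a packed NCHW tensor follows; the variable names and the packed-layout assumption are illustrative and not the helpers' actual implementation.

#include <cassert>
#include <cstddef>
#include <vector>

int main()
{
    const std::size_t C = 8, H = 4, W = 4, G = 2;
    const std::size_t c_per_group = C / G;

    // Packed NCHW strides for (N, C, H, W).
    const std::vector<std::size_t> nchw{C * H * W, H * W, W, 1};

    // (N, G, C_per_group, H, W) view of the same buffer:
    //   stride(G)           = C_per_group * stride(C)
    //   stride(C_per_group) = stride(C)
    std::vector<std::size_t> split = nchw;
    split.insert(split.begin() + 1, c_per_group * nchw[1]); // G stride goes right after N

    assert(split.size() == 5);
    assert(split[1] == c_per_group * H * W); // stride of G
    assert(split[2] == H * W);               // stride of C_per_group
    return 0;
}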
} diff --git a/src/solver/conv_direct_naive_conv_wrw.cpp b/src/solver/conv_direct_naive_conv_wrw.cpp index 6fcf2f71d0..dfe1c342b0 100644 --- a/src/solver/conv_direct_naive_conv_wrw.cpp +++ b/src/solver/conv_direct_naive_conv_wrw.cpp @@ -129,14 +129,28 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, return false; }(); + int G_stride_idx = conv_internal::GetGroupStrideIndex(problem); + if(problem.Is2d()) + { result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { decltype(auto) data_ctx = primitive_parameters.CastTo(); const auto& tensors = data_ctx.tensors; float elapsed = 0; + + auto in_strides = conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( + group, tensors.xDesc.GetStrides(), G_stride_idx)); + // For weights, we split K to (G, K_per_group), which is always index 0 + auto wei_strides = conv_internal::MakeStrideArray<5>( + conv_internal::SplitWeiStrideKtoGK(k_per_group, tensors.dwDesc.GetStrides())); + auto out_strides = + conv_internal::MakeStrideArray<5>(conv_internal::SplitStrideCtoGC( + group, tensors.dyDesc.GetStrides(), G_stride_idx)); + if(is_f8) + { handle.Run(kern)(tensors.x, tensors.dw, tensors.dy, @@ -159,10 +173,15 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, problem.GetConv().attribute.fp8rounding_mode.Get() == miopenF8RoundingModeStochastic, problem.GetConv().attribute.fp8rounding_mode.GetSeed()); + } else + { handle.Run(kern)(tensors.x, tensors.dw, tensors.dy, + in_strides, + wei_strides, + out_strides, hi, wi, n, @@ -179,6 +198,7 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, fy, fx, group); + } if(handle.IsProfilingEnabled()) elapsed += handle.GetKernelTime(); @@ -189,7 +209,9 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, } }; }; + } else + { result.invoker_factory = [=](const std::vector& kernels) { const auto kern = kernels[0]; return [=](const Handle& handle, const AnyInvokeParams& primitive_parameters) { @@ -197,9 +219,21 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, const auto& tensors = data_ctx.tensors; float elapsed = 0; + auto in_strides = conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( + group, tensors.xDesc.GetStrides(), G_stride_idx)); + // For weights, we split K to (G, K_per_group), which is always index 0 + auto wei_strides = conv_internal::MakeStrideArray<6>( + conv_internal::SplitWeiStrideKtoGK(k_per_group, tensors.dwDesc.GetStrides())); + auto out_strides = + conv_internal::MakeStrideArray<6>(conv_internal::SplitStrideCtoGC( + group, tensors.dyDesc.GetStrides(), G_stride_idx)); + handle.Run(kern)(tensors.x, tensors.dw, tensors.dy, + in_strides, + wei_strides, + out_strides, di, hi, wi, @@ -232,6 +266,7 @@ ConvSolution ConvDirectNaiveConvWrw::GetSolution(const ExecutionContext& ctx, } }; }; + } result.construction_params.push_back(kernel); return result; } diff --git a/test/gpu_reference_kernel.cpp b/test/gpu_reference_kernel.cpp index c3b26a80a9..aa3dda788d 100644 --- a/test/gpu_reference_kernel.cpp +++ b/test/gpu_reference_kernel.cpp @@ -95,7 +95,8 @@ struct gpu_reference_kernel_base static std::vector get_image_size() { return {9, 14}; } - static std::vector get_channel_size() { return {3, 8}; } + // Warning: Channel size must be multiple of group size + static std::vector get_channel_size() { return {4, 8}; } static std::vector 
get_filter_depth() { return {1, 3}; } diff --git a/test/gtest/conv3d_test_case.hpp b/test/gtest/conv3d_test_case.hpp new file mode 100644 index 0000000000..242615077f --- /dev/null +++ b/test/gtest/conv3d_test_case.hpp @@ -0,0 +1,112 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include + +#include "get_handle.hpp" +#include + +#include "../driver/tensor_driver.hpp" +#include "conv_common.hpp" + +template +miopenDataType_t GetDataType(); + +template <> +miopenDataType_t GetDataType() +{ + return miopenFloat; +} + +template <> +miopenDataType_t GetDataType() +{ + return miopenHalf; +} + +template <> +miopenDataType_t GetDataType() +{ + return miopenInt8; +} + +struct Conv3DTestCase +{ + size_t G; + size_t N; + size_t C; + size_t D; + size_t H; + size_t W; + size_t k; + size_t z; + size_t y; + size_t x; + size_t pad_x; + size_t pad_y; + size_t pad_z; + size_t stride_x; + size_t stride_y; + size_t stride_z; + size_t dilation_x; + size_t dilation_y; + size_t dilation_z; + miopenConvolutionMode_t conv_mode; + friend std::ostream& operator<<(std::ostream& os, const Conv3DTestCase& tc) + { + return os << " G:" << tc.G << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D + << " H:" << tc.H << " W:" << tc.W << " k:" << tc.k << " z:" << tc.z + << " y:" << tc.y << " x:" << tc.x << " pad_z:" << tc.pad_z + << " pad_y:" << tc.pad_y << " pad_x:" << tc.pad_x << " stride_z:" << tc.stride_z + << " stride_y:" << tc.stride_y << " stride_x:" << tc.stride_x + << " dilation_z:" << tc.dilation_z << " dilation_y:" << tc.dilation_y + << " dilation_x:" << tc.dilation_x << " conv_mode:" << tc.conv_mode; + } + + std::vector GetInput() { return {N, C, D, H, W}; } + std::vector GetWeights() + { + EXPECT_EQUAL(C % G, 0); + return {k, C / G, z, y, x}; + } + + miopen::ConvolutionDescriptor GetConv() + { + return miopen::ConvolutionDescriptor{ + 3, + miopenConvolution, + miopenPaddingDefault, + {static_cast(pad_z), static_cast(pad_y), static_cast(pad_x)}, + {static_cast(stride_z), static_cast(stride_y), static_cast(stride_x)}, + {static_cast(dilation_z), + static_cast(dilation_y), + static_cast(dilation_x)}, + {0, 0, 0}, + static_cast(G), + 1.0}; + } +}; diff --git a/test/gtest/group_conv3d_bwd.cpp b/test/gtest/group_conv3d_bwd.cpp index 
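Conv3DTestCase is consumed by the grouped 3-D solver tests that follow: a test takes one entry from ConvTestConfigs() and turns it into tensor shapes plus a grouped convolution descriptor. A rough sketch of that flow, assuming the header above is included; the concrete numbers and variable names are illustrative only.

// g  n   c   d  h   w   k   z  y  x  pad      stride   dilation  mode
Conv3DTestCase tc{2, 16, 16, 8, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution};

const auto in_lens  = tc.GetInput();   // {N, C, D, H, W} = {16, 16, 8, 28, 28}
const auto wei_lens = tc.GetWeights(); // {k, C / G, z, y, x} = {32, 8, 3, 3, 3}; requires C % G == 0
const auto conv     = tc.GetConv();    // 3-D grouped ConvolutionDescriptor with group_count = G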
e53a690021..a9bffceff1 100644 --- a/test/gtest/group_conv3d_bwd.cpp +++ b/test/gtest/group_conv3d_bwd.cpp @@ -44,7 +44,7 @@ void SolverBwd(const miopen::TensorDescriptor& inputDesc, const miopen::TensorDescriptor& outputDesc, ConstData_t output, const miopen::ConvolutionDescriptor& convDesc, - const ConvTestCase& conv_config, + const Conv3DTestCase& conv_config, bool& test_skipped) { auto&& handle = get_handle(); diff --git a/test/gtest/group_conv3d_bwd.hpp b/test/gtest/group_conv3d_bwd.hpp index 410d71e6d0..71702c5808 100644 --- a/test/gtest/group_conv3d_bwd.hpp +++ b/test/gtest/group_conv3d_bwd.hpp @@ -25,89 +25,9 @@ *******************************************************************************/ #pragma once -#include +#include "conv3d_test_case.hpp" -#include "get_handle.hpp" -#include - -#include "../driver/tensor_driver.hpp" -#include "conv_common.hpp" - -template -miopenDataType_t GetDataType(); - -template <> -miopenDataType_t GetDataType() -{ - return miopenFloat; -} - -template <> -miopenDataType_t GetDataType() -{ - return miopenHalf; -} - -template <> -miopenDataType_t GetDataType() -{ - return miopenInt8; -} - -struct ConvTestCase -{ - size_t G; - size_t N; - size_t C; - size_t D; - size_t H; - size_t W; - size_t k; - size_t z; - size_t y; - size_t x; - size_t pad_x; - size_t pad_y; - size_t pad_z; - size_t stride_x; - size_t stride_y; - size_t stride_z; - size_t dilation_x; - size_t dilation_y; - size_t dilation_z; - miopenConvolutionMode_t conv_mode; - friend std::ostream& operator<<(std::ostream& os, const ConvTestCase& tc) - { - return os << " G:" << tc.G << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D - << " H:" << tc.H << " W:" << tc.W << " k:" << tc.k << " z:" << tc.z - << " y:" << tc.y << " x:" << tc.x << " pad_z:" << tc.pad_z - << " pad_y:" << tc.pad_y << " pad_x:" << tc.pad_x << " stride_z:" << tc.stride_z - << " stride_y:" << tc.stride_y << " stride_x:" << tc.stride_x - << " dilation_z:" << tc.dilation_z << " dilation_y:" << tc.dilation_y - << " dilation_x:" << tc.dilation_x << " conv_mode:" << tc.conv_mode; - } - - std::vector GetInput() { return {N, C, D, H, W}; } - std::vector GetWeights() { return {k, C, z, y, x}; } - - miopen::ConvolutionDescriptor GetConv() - { - return miopen::ConvolutionDescriptor{ - 3, - miopenConvolution, - miopenPaddingDefault, - {static_cast(pad_z), static_cast(pad_y), static_cast(pad_x)}, - {static_cast(stride_z), static_cast(stride_y), static_cast(stride_x)}, - {static_cast(dilation_z), - static_cast(dilation_y), - static_cast(dilation_x)}, - {0, 0, 0}, - static_cast(G), - 1.0}; - } -}; - -std::vector ConvTestConfigs() +std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, @@ -133,7 +53,7 @@ inline int SetTensorLayout(miopen::TensorDescriptor& desc) template struct ConvBwdSolverTest : public ::testing::TestWithParam< - std::tuple> + std::tuple> { protected: void SetUp() override @@ -188,7 +108,7 @@ struct ConvBwdSolverTest EXPECT_TRUE(error < threshold) << "Error beyond tolerance Error:" << error << ", Threshold: " << threshold; } - ConvTestCase conv_config; + Conv3DTestCase conv_config; miopen::ConvolutionDescriptor conv_desc; tensor input; tensor weights; diff --git a/test/gtest/group_conv3d_fwd.cpp b/test/gtest/group_conv3d_fwd.cpp index 2b52a1b43a..18d54355e8 100644 --- 
a/test/gtest/group_conv3d_fwd.cpp +++ b/test/gtest/group_conv3d_fwd.cpp @@ -44,7 +44,7 @@ void SolverFwd(const miopen::TensorDescriptor& inputDesc, const miopen::TensorDescriptor& outputDesc, Data_t output, const miopen::ConvolutionDescriptor& convDesc, - const ConvTestCase& conv_config, + const Conv3DTestCase& conv_config, bool& test_skipped) { auto&& handle = get_handle(); diff --git a/test/gtest/group_conv3d_fwd.hpp b/test/gtest/group_conv3d_fwd.hpp index 983f897d78..c8767399a7 100644 --- a/test/gtest/group_conv3d_fwd.hpp +++ b/test/gtest/group_conv3d_fwd.hpp @@ -25,89 +25,9 @@ *******************************************************************************/ #pragma once -#include +#include "conv3d_test_case.hpp" -#include "get_handle.hpp" -#include - -#include "../driver/tensor_driver.hpp" -#include "conv_common.hpp" - -template -miopenDataType_t GetDataType(); - -template <> -miopenDataType_t GetDataType() -{ - return miopenFloat; -} - -template <> -miopenDataType_t GetDataType() -{ - return miopenHalf; -} - -template <> -miopenDataType_t GetDataType() -{ - return miopenInt8; -} - -struct ConvTestCase -{ - size_t G; - size_t N; - size_t C; - size_t D; - size_t H; - size_t W; - size_t k; - size_t z; - size_t y; - size_t x; - size_t pad_x; - size_t pad_y; - size_t pad_z; - size_t stride_x; - size_t stride_y; - size_t stride_z; - size_t dilation_x; - size_t dilation_y; - size_t dilation_z; - miopenConvolutionMode_t conv_mode; - friend std::ostream& operator<<(std::ostream& os, const ConvTestCase& tc) - { - return os << " G:" << tc.G << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D - << " H:" << tc.H << " W:" << tc.W << " k:" << tc.k << " z:" << tc.z - << " y:" << tc.y << " x:" << tc.x << " pad_z:" << tc.pad_z - << " pad_y:" << tc.pad_y << " pad_x:" << tc.pad_x << " stride_z:" << tc.stride_z - << " stride_y:" << tc.stride_y << " stride_x:" << tc.stride_x - << " dilation_z:" << tc.dilation_z << " dilation_y:" << tc.dilation_y - << " dilation_x:" << tc.dilation_x << " conv_mode:" << tc.conv_mode; - } - - std::vector GetInput() { return {N, C, D, H, W}; } - std::vector GetWeights() { return {k, C, z, y, x}; } - - miopen::ConvolutionDescriptor GetConv() - { - return miopen::ConvolutionDescriptor{ - 3, - miopenConvolution, - miopenPaddingDefault, - {static_cast(pad_z), static_cast(pad_y), static_cast(pad_x)}, - {static_cast(stride_z), static_cast(stride_y), static_cast(stride_x)}, - {static_cast(dilation_z), - static_cast(dilation_y), - static_cast(dilation_x)}, - {0, 0, 0}, - static_cast(G), - 1.0}; - } -}; - -std::vector ConvTestConfigs() +std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, @@ -139,7 +59,7 @@ inline int SetTensorLayout(miopen::TensorDescriptor& desc) template struct ConvFwdSolverTest : public ::testing::TestWithParam< - std::tuple> + std::tuple> { protected: void SetUp() override @@ -195,7 +115,7 @@ struct ConvFwdSolverTest EXPECT_TRUE(error < threshold) << "Error beyond tolerance Error:" << error << ", Threshold: " << threshold; } - ConvTestCase conv_config; + Conv3DTestCase conv_config; miopen::ConvolutionDescriptor conv_desc; tensor input; tensor weights; diff --git a/test/gtest/group_conv3d_wrw.cpp b/test/gtest/group_conv3d_wrw.cpp index 13e88da5ad..977a06220a 100644 --- a/test/gtest/group_conv3d_wrw.cpp +++ 
b/test/gtest/group_conv3d_wrw.cpp @@ -44,7 +44,7 @@ void SolverWrw(const miopen::TensorDescriptor& inputDesc, const miopen::TensorDescriptor& outputDesc, ConstData_t output, // dy const miopen::ConvolutionDescriptor& convDesc, - const ConvTestCase& conv_config, + const Conv3DTestCase& conv_config, bool& test_skipped) { diff --git a/test/gtest/group_conv3d_wrw.hpp b/test/gtest/group_conv3d_wrw.hpp index 76d8ae5d90..bf5824b4fa 100644 --- a/test/gtest/group_conv3d_wrw.hpp +++ b/test/gtest/group_conv3d_wrw.hpp @@ -25,89 +25,9 @@ *******************************************************************************/ #pragma once -#include +#include "conv3d_test_case.hpp" -#include "get_handle.hpp" -#include - -#include "../driver/tensor_driver.hpp" -#include "conv_common.hpp" - -template -miopenDataType_t GetDataType(); - -template <> -miopenDataType_t GetDataType() -{ - return miopenFloat; -} - -template <> -miopenDataType_t GetDataType() -{ - return miopenHalf; -} - -template <> -miopenDataType_t GetDataType() -{ - return miopenInt8; -} - -struct ConvTestCase -{ - size_t G; - size_t N; - size_t C; - size_t D; - size_t H; - size_t W; - size_t k; - size_t z; - size_t y; - size_t x; - size_t pad_x; - size_t pad_y; - size_t pad_z; - size_t stride_x; - size_t stride_y; - size_t stride_z; - size_t dilation_x; - size_t dilation_y; - size_t dilation_z; - miopenConvolutionMode_t conv_mode; - friend std::ostream& operator<<(std::ostream& os, const ConvTestCase& tc) - { - return os << " G:" << tc.G << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D - << " H:" << tc.H << " W:" << tc.W << " k:" << tc.k << " z:" << tc.z - << " y:" << tc.y << " x:" << tc.x << " pad_z:" << tc.pad_z - << " pad_y:" << tc.pad_y << " pad_x:" << tc.pad_x << " stride_z:" << tc.stride_z - << " stride_y:" << tc.stride_y << " stride_x:" << tc.stride_x - << " dilation_z:" << tc.dilation_z << " dilation_y:" << tc.dilation_y - << " dilation_x:" << tc.dilation_x << " conv_mode:" << tc.conv_mode; - } - - std::vector GetInput() { return {N, C, D, H, W}; } - std::vector GetWeights() { return {k, C, z, y, x}; } - - miopen::ConvolutionDescriptor GetConv() - { - return miopen::ConvolutionDescriptor{ - 3, - miopenConvolution, - miopenPaddingDefault, - {static_cast(pad_z), static_cast(pad_y), static_cast(pad_x)}, - {static_cast(stride_z), static_cast(stride_y), static_cast(stride_x)}, - {static_cast(dilation_z), - static_cast(dilation_y), - static_cast(dilation_x)}, - {0, 0, 0}, - static_cast(G), - 1.0}; - } -}; - -std::vector ConvTestConfigs() +std::vector ConvTestConfigs() { // g n c d h w k z y x pad_x pad_y pad_z stri_x stri_y stri_z dia_x dia_y dia_z return {{1, 128, 64, 14, 28, 28, 64, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, {1, 64, 32, 28, 28, 28, 32, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, miopenConvolution}, @@ -135,7 +55,7 @@ inline int SetTensorLayout(miopen::TensorDescriptor& desc) template struct ConvWrwSolverTest : public ::testing::TestWithParam< - std::tuple> + std::tuple> { protected: void SetUp() override @@ -191,7 +111,7 @@ struct ConvWrwSolverTest EXPECT_TRUE(error < threshold) << "Error beyond tolerance Error:" << error << ", Threshold: " << threshold; } - ConvTestCase conv_config; + Conv3DTestCase conv_config; miopen::ConvolutionDescriptor conv_desc; tensor input; tensor weights; diff --git a/test/gtest/group_solver.hpp b/test/gtest/group_solver.hpp index 6fe02e00da..3d9ebddca3 100644 --- a/test/gtest/group_solver.hpp +++ b/test/gtest/group_solver.hpp @@ -80,7 +80,11 @@ struct ConvTestCase } std::vector GetInput() { 
return {N, C, H, W}; }
-    std::vector<size_t> GetWeights() { return {k, C, y, x}; }
+    std::vector<size_t> GetWeights()
+    {
+        EXPECT_EQUAL(C % G, 0);
+        return {k, C / G, y, x};
+    }
 
     miopen::ConvolutionDescriptor GetConv()
     {
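As a concrete check of the grouped weight shape above (numbers chosen only for illustration): with G = 4, C = 32, k = 64 and a 3x3 filter, GetWeights() now returns {64, 8, 3, 3} rather than {64, 32, 3, 3}, and the EXPECT_EQUAL guard rejects any configuration where C is not divisible by G.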