From f021d5ce6c2ceee4877ca8edc4b8985bfc154f70 Mon Sep 17 00:00:00 2001 From: Justin Luitjens Date: Mon, 29 Aug 2022 13:09:52 -0600 Subject: [PATCH] adding madd function to improve convolution throughput (#258) added half benchmark Co-authored-by: jluitjens --- bench/00_transform/conv.cu | 3 +- include/matx/core/utils.h | 64 +++++++++++++++++++++++++++++++++++ include/matx/kernels/conv.cuh | 17 ++++++---- 3 files changed, 77 insertions(+), 7 deletions(-) diff --git a/bench/00_transform/conv.cu b/bench/00_transform/conv.cu index fcf6a42b..0a2984f8 100644 --- a/bench/00_transform/conv.cu +++ b/bench/00_transform/conv.cu @@ -1,10 +1,11 @@ #include "matx.h" #include +#include "matx/core/half_complex.h" using namespace matx; using conv_types = - nvbench::type_list, cuda::std::complex, float, double>; + nvbench::type_list, cuda::std::complex, float, double>; /* FFT benchmarks */ template diff --git a/include/matx/core/utils.h b/include/matx/core/utils.h index d433f730..f6352fe4 100644 --- a/include/matx/core/utils.h +++ b/include/matx/core/utils.h @@ -31,6 +31,10 @@ ///////////////////////////////////////////////////////////////////////////////// #pragma once + +#include +#include + #include "matx/core/defines.h" #include "matx/core/error.h" @@ -57,5 +61,65 @@ __MATX_INLINE__ bool IsAmpereOrAbove() { return GetComputeCapabilityMajor() >= AMPERE_CC; } +template +__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ auto madd( const T1 &x, const T2 &y, const T3 &z) { + using T4 = decltype(x*y+z); + if constexpr (is_complex_v && !is_complex_half_v) { + + using value_type = typename T4::value_type; + + value_type xr, xi; + value_type yr, yi; + value_type zr, zi; + + if constexpr (is_complex_v) { + xr = x.real(); + xi = x.imag(); + } else { + xr = x; + xi = value_type(0); + } + + if constexpr (is_complex_v) { + yr = y.real(); + yi = y.imag(); + } else { + yr = y; + yi = value_type(0); + } + + if constexpr (is_complex_v) { + zr = z.real(); + zi = z.imag(); + } else { + zr = z; + zi = value_type(0); + } + + T4 Z(zr,zi); + + Z.real(Z.real() + xr*yr); + Z.real(Z.real() - xi*yi); + + Z.imag(Z.imag() + xi*yr); + Z.imag(Z.imag() + xr*yi); + + return Z; + } else if constexpr (std::is_same_v) { + //__half2 X = make_half2(x.real(), x.imag()); + //__half2 Y = make_half2(y.real(), y.imag()); + //__half2 Z = make_half2(z.real(), z.imag()); + + const __half2 &X = *reinterpret_cast(&x); + const __half2 &Y = *reinterpret_cast(&y); + const __half2 &Z = *reinterpret_cast(&z); + + auto v = __hcmadd(X,Y,Z); + return T4(v.x, v.y); + } else { + return x*y+z; + } +} + }; }; diff --git a/include/matx/kernels/conv.cuh b/include/matx/kernels/conv.cuh index e1e1a19b..c1a3495c 100644 --- a/include/matx/kernels/conv.cuh +++ b/include/matx/kernels/conv.cuh @@ -9,6 +9,7 @@ #include #include "cuComplex.h" +#include "matx/core/utils.h" #include "matx/core/type_utils.h" #include "matx/core/tensor_utils.h" @@ -144,7 +145,11 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter, #else s_data += threadIdx.x + filter_len - 1; for (int32_t r = 0; r < filter_len; r++) { +#if 0 val = val + s_filter[0] * s_data[0]; +#else + val = detail::madd(s_filter[0], s_data[0], val); +#endif s_data--; s_filter++; } @@ -242,19 +247,19 @@ __global__ void Conv2D(OutType d_out, InType d_in, FilterType d_filter, } if constexpr (d_in.Rank() == 4) { - val += s_filter[y * d_filter.Size(1) + x] * + val = detail::madd(s_filter[y * d_filter.Size(1) + x], d_in(bdims[0], bdims[1], tid_y - d_filter.Size(0) + 1 + y, - tid_x - d_filter.Size(1) + 1 + x); + tid_x - d_filter.Size(1) + 1 + x), val); } else if constexpr (d_in.Rank() == 3) { - val += s_filter[y * d_filter.Size(1) + x] * + val = detail::madd(s_filter[y * d_filter.Size(1) + x], d_in(blockIdx.z, tid_y - d_filter.Size(0) + 1 + y, - tid_x - d_filter.Size(1) + 1 + x); + tid_x - d_filter.Size(1) + 1 + x), val); } else if constexpr (d_in.Rank() == 2) { - val += s_filter[y * d_filter.Size(1) + x] * + val = detail::madd(s_filter[y * d_filter.Size(1) + x], d_in(tid_y - d_filter.Size(0) + 1 + y, - tid_x - d_filter.Size(1) + 1 + x); + tid_x - d_filter.Size(1) + 1 + x), val); } } }