Skip to content

Commit

Permalink
Add madd function to improve convolution throughput
Browse files Browse the repository at this point in the history
Add half-precision complex (matxFp16Complex) convolution benchmark
  • Loading branch information
luitjens committed Aug 29, 2022
1 parent beaa582 commit 7fc4ef0
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 7 deletions.
3 changes: 2 additions & 1 deletion bench/00_transform/conv.cu
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#include "matx.h"
#include <nvbench/nvbench.cuh>
#include "matx/core/half_complex.h"

using namespace matx;

using conv_types =
nvbench::type_list<cuda::std::complex<float>, cuda::std::complex<double>, float, double>;
nvbench::type_list<matxFp16Complex, cuda::std::complex<float>, cuda::std::complex<double>, float, double>;

/* FFT benchmarks */
template <typename ValueType>
Expand Down
64 changes: 64 additions & 0 deletions include/matx/core/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
/////////////////////////////////////////////////////////////////////////////////

#pragma once

#include <type_traits>
#include <cuda_fp16.h>

#include "matx/core/defines.h"
#include "matx/core/error.h"

Expand All @@ -57,5 +61,65 @@ __MATX_INLINE__ bool IsAmpereOrAbove() {
return GetComputeCapabilityMajor() >= AMPERE_CC;
}

/**
 * Multiply-add helper: computes x * y + z.
 *
 * The result type T4 is whatever `x * y + z` yields for the given operand
 * types. Three code paths are selected at compile time:
 *
 *  - T4 is a complex type that is not a half-precision complex type: the
 *    product is expanded by hand into four scalar multiply-accumulate steps
 *    on the real and imaginary parts. Real (non-complex) operands are treated
 *    as having a zero imaginary part. Note that this expansion does not go
 *    through the complex type's own operator*.
 *  - All three operands (and the result) are matxFp16Complex: in device code
 *    the operands are reinterpreted as __half2 and passed to __hcmadd.
 *  - Everything else: the plain expression x * y + z.
 *
 * @tparam T1 Type of the first multiplicand
 * @tparam T2 Type of the second multiplicand
 * @tparam T3 Type of the addend
 * @param x First multiplicand
 * @param y Second multiplicand
 * @param z Addend
 * @return x * y + z, with type decltype(x * y + z)
 */
template <typename T1, typename T2, typename T3>
__MATX_HOST__ __MATX_DEVICE__ __MATX_INLINE__ auto madd( const T1 &x, const T2 &y, const T3 &z) {
  using T4 = decltype(x*y+z);

  if constexpr (is_complex_v<T4> && !is_complex_half_v<T4>) {
    using value_type = typename T4::value_type;

    value_type xr, xi;
    value_type yr, yi;
    value_type zr, zi;

    // Split each operand into real/imaginary scalars; a real operand
    // contributes a zero imaginary part.
    if constexpr (is_complex_v<T1>) {
      xr = x.real();
      xi = x.imag();
    } else {
      xr = x;
      xi = value_type(0);
    }

    if constexpr (is_complex_v<T2>) {
      yr = y.real();
      yi = y.imag();
    } else {
      yr = y;
      yi = value_type(0);
    }

    if constexpr (is_complex_v<T3>) {
      zr = z.real();
      zi = z.imag();
    } else {
      zr = z;
      zi = value_type(0);
    }

    // Accumulate onto z one product at a time so that every step has the
    // shape "acc = acc +/- a*b". The order of operations is kept fixed:
    //   re = (zr + xr*yr) - xi*yi
    //   im = (zi + xi*yr) + xr*yi
    T4 Z(zr, zi);

    Z.real(Z.real() + xr*yr);
    Z.real(Z.real() - xi*yi);

    Z.imag(Z.imag() + xi*yr);
    Z.imag(Z.imag() + xr*yi);

    return Z;
  } else if constexpr (std::is_same_v<T4, matxFp16Complex> &&
                       std::is_same_v<T1, matxFp16Complex> &&
                       std::is_same_v<T2, matxFp16Complex> &&
                       std::is_same_v<T3, matxFp16Complex>) {
    // The __half2 fast path reinterprets every operand, so it is only valid
    // when all three operands really are matxFp16Complex. Checking the result
    // type alone is not enough: an operand of a different type would be read
    // through a mismatched pointer. Such mixed-type calls now take the
    // generic branch below.
#ifdef __CUDA_ARCH__
    // NOTE(review): assumes matxFp16Complex stores {real, imag} as two
    // contiguous halves with __half2-compatible size and alignment -- confirm
    // against its definition in matx/core/half_complex.h.
    const __half2 &X = *reinterpret_cast<const __half2*>(&x);
    const __half2 &Y = *reinterpret_cast<const __half2*>(&y);
    const __half2 &Z = *reinterpret_cast<const __half2*>(&z);

    // Complex half multiply-accumulate intrinsic (see the CUDA Math API for
    // its compute-capability requirement).
    auto v = __hcmadd(X,Y,Z);
    return T4(v.x, v.y);
#else
    // __hcmadd is not guaranteed to be callable from host code, so the host
    // compilation pass uses the ordinary operators instead.
    return x*y+z;
#endif
  } else {
    return x*y+z;
  }
}

};
};
17 changes: 11 additions & 6 deletions include/matx/kernels/conv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <vector>

#include "cuComplex.h"
#include "matx/core/utils.h"
#include "matx/core/type_utils.h"
#include "matx/core/tensor_utils.h"

Expand Down Expand Up @@ -144,7 +145,11 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
#else
s_data += threadIdx.x + filter_len - 1;
for (int32_t r = 0; r < filter_len; r++) {
#if 0
val = val + s_filter[0] * s_data[0];
#else
val = detail::madd(s_filter[0], s_data[0], val);
#endif
s_data--;
s_filter++;
}
Expand Down Expand Up @@ -242,19 +247,19 @@ __global__ void Conv2D(OutType d_out, InType d_in, FilterType d_filter,
}

if constexpr (d_in.Rank() == 4) {
val += s_filter[y * d_filter.Size(1) + x] *
val = detail::madd(s_filter[y * d_filter.Size(1) + x],
d_in(bdims[0], bdims[1], tid_y - d_filter.Size(0) + 1 + y,
tid_x - d_filter.Size(1) + 1 + x);
tid_x - d_filter.Size(1) + 1 + x), val);
}
else if constexpr (d_in.Rank() == 3) {
val += s_filter[y * d_filter.Size(1) + x] *
val = detail::madd(s_filter[y * d_filter.Size(1) + x],
d_in(blockIdx.z, tid_y - d_filter.Size(0) + 1 + y,
tid_x - d_filter.Size(1) + 1 + x);
tid_x - d_filter.Size(1) + 1 + x), val);
}
else if constexpr (d_in.Rank() == 2) {
val += s_filter[y * d_filter.Size(1) + x] *
val = detail::madd(s_filter[y * d_filter.Size(1) + x],
d_in(tid_y - d_filter.Size(0) + 1 + y,
tid_x - d_filter.Size(1) + 1 + x);
tid_x - d_filter.Size(1) + 1 + x), val);
}
}
}
Expand Down

0 comments on commit 7fc4ef0

Please sign in to comment.