From e9c459c6096b1dd4c26f3248b4b61501b6c0d403 Mon Sep 17 00:00:00 2001
From: Thomas Benson
Date: Thu, 2 Nov 2023 10:14:02 -0700
Subject: [PATCH 1/3] Optimize resample_poly with a multi-kernel approach

Add additional resample_poly kernels that are faster in specific cases
(e.g., a pure downsampler). Currently, there are three kernels, but only
two are in use. The third is fastest in some cases, but a pattern for
when that holds is not yet clear.
---
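Note for reviewers (this sits between the changelog and the diffstat, so
git am ignores it): all of the kernels below share the same polyphase
bookkeeping. A minimal host-side sketch of that arithmetic, using one of
the benchmark cases added in this patch; this is a standalone illustration
only, nothing in it is a MatX API:

  #include <cstdint>
  #include <cstdio>

  int main() {
    // Factors taken from the benchmark table added below.
    const int64_t input_len = 256000, up = 384, down = 3125;
    // Output length, matching the benchmark's
    // up_len / down + ((up_len % down) ? 1 : 0).
    const int64_t up_len = input_len * up;
    const int64_t output_len = up_len / down + ((up_len % down) ? 1 : 0);
    printf("output_len = %lld\n", static_cast<long long>(output_len));
    // Output sample out_ind reads the virtual upsampled signal at
    // up_ind = out_ind * down; its filter phase is up_ind % up and its
    // rightmost aligned input sample is up_ind / up, exactly as computed
    // in the kernels in this patch.
    for (int64_t out_ind = 0; out_ind < 4; out_ind++) {
      const int64_t up_ind = out_ind * down;
      printf("out %lld -> phase %lld, input %lld\n",
             static_cast<long long>(out_ind),
             static_cast<long long>(up_ind % up),
             static_cast<long long>(up_ind / up));
    }
    return 0;
  }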
 examples/resample_poly_bench.cu         |  52 ++++-
 include/matx/kernels/resample_poly.cuh  | 290 ++++++++++++++++++++++--
 include/matx/transforms/resample_poly.h | 127 +++++++++--
 test/00_transform/ResamplePoly.cu       |   9 +-
 4 files changed, 431 insertions(+), 47 deletions(-)

diff --git a/examples/resample_poly_bench.cu b/examples/resample_poly_bench.cu
index 8301f4c3..124bc412 100644
--- a/examples/resample_poly_bench.cu
+++ b/examples/resample_poly_bench.cu
@@ -46,7 +46,7 @@ using namespace matx;
 // length, input signal length, up/down factors) will be adjusted to a range of interest
 // and the benchmark will be run with and without the proposed kernel changes.
 
-constexpr int NUM_WARMUP_ITERATIONS = 2;
+constexpr int NUM_WARMUP_ITERATIONS = 10;
 
 // Number of iterations per timed test. Iteration times are averaged in the report.
 constexpr int NUM_ITERATIONS = 20;
@@ -61,29 +61,57 @@ void ResamplePolyBench()
     matx::index_t down;
   } test_cases[] = {
     { 1, 256, 384, 3125 },
+    { 1, 256, 384, 175 },
     { 1, 256, 4, 5 },
     { 1, 256, 1, 4 },
     { 1, 256, 1, 16 },
     { 1, 3000, 384, 3125 },
+    { 1, 3000, 384, 175 },
     { 1, 3000, 4, 5 },
     { 1, 3000, 1, 4 },
     { 1, 3000, 1, 16 },
     { 1, 31000, 384, 3125 },
+    { 1, 31000, 384, 175 },
     { 1, 31000, 4, 5 },
     { 1, 31000, 1, 4 },
     { 1, 31000, 1, 16 },
     { 1, 256000, 384, 3125 },
+    { 1, 256000, 384, 175 },
+    { 1, 256000, 64, 37 },
+    { 1, 256000, 2, 3 },
     { 1, 256000, 4, 5 },
+    { 1, 256000, 7, 64 },
+    { 1, 256000, 7, 128 },
     { 1, 256000, 1, 4 },
+    { 1, 256000, 1, 8 },
     { 1, 256000, 1, 16 },
+    { 1, 256000, 1, 64 },
+    { 1, 256000, 4, 1 },
+    { 1, 256000, 8, 1 },
+    { 1, 256000, 16, 1 },
+    { 1, 256000, 64, 1 },
     { 42, 256000, 384, 3125 },
+    { 42, 256000, 384, 175 },
+    { 42, 256000, 2, 3 },
     { 42, 256000, 4, 5 },
     { 42, 256000, 1, 4 },
+    { 42, 256000, 1, 8 },
     { 42, 256000, 1, 16 },
-    { 128, 256000, 384, 3125 },
-    { 128, 256000, 4, 5 },
-    { 128, 256000, 1, 4 },
-    { 128, 256000, 1, 16 },
+    { 42, 256000, 1, 64 },
+    { 1, 100000000, 384, 3125 },
+    { 1, 100000000, 384, 175 },
+    { 1, 100000000, 2, 3 },
+    { 1, 100000000, 4, 5 },
+    { 1, 100000000, 7, 64 },
+    { 1, 100000000, 7, 128 },
+    { 1, 100000000, 1, 2 },
+    { 1, 100000000, 1, 4 },
+    { 1, 100000000, 1, 8 },
+    { 1, 100000000, 1, 16 },
+    { 1, 100000000, 1, 192 },
+    { 1, 100000000, 4, 1 },
+    { 1, 100000000, 8, 1 },
+    { 1, 100000000, 16, 1 },
   };
 
   cudaStream_t stream;
@@ -104,9 +132,14 @@ void ResamplePolyBench()
     const index_t up_len = input_len * up;
     const index_t output_len = up_len / down + ((up_len % down) ? 1 : 0);
 
-    auto input = matx::make_tensor<float>({num_batches, input_len});
-    auto filter = matx::make_tensor<float>({filter_len});
-    auto output = matx::make_tensor<float>({num_batches, output_len});
+    auto input = matx::make_tensor<float>({num_batches, input_len}, MATX_DEVICE_MEMORY);
+    auto filter = matx::make_tensor<float>({filter_len}, MATX_DEVICE_MEMORY);
+    auto output = matx::make_tensor<float>({num_batches, output_len}, MATX_DEVICE_MEMORY);
+
+    (input = static_cast<float>(1.0)).run(stream);
+    (filter = static_cast<float>(1.0)).run(stream);
+
+    cudaStreamSynchronize(stream);
 
     for (int k = 0; k < NUM_WARMUP_ITERATIONS; k++) {
       (output = matx::resample_poly(input, filter, up, down)).run(stream);
@@ -126,7 +159,8 @@ void ResamplePolyBench()
     const double gflops = static_cast<double>(num_batches*(2*filter_len_per_phase-1)*output_len) / 1.0e9;
     const double avg_elapsed_us = (static_cast<double>(elapsed_ms)/NUM_ITERATIONS)*1.0e3;
 
-    printf("Batches: %5lld FilterLen: %5lld InputLen: %7lld OutputLen: %7lld Up/Down: %4lld/%4lld Elapsed Usecs: %12.1f GFLOPS: %10.3f\n",
+    printf("Batches: %5" INDEX_T_FMT " FilterLen: %5" INDEX_T_FMT " InputLen: %9" INDEX_T_FMT " OutputLen: %8" INDEX_T_FMT
+           " Up/Down: %4" INDEX_T_FMT "/%4" INDEX_T_FMT " Elapsed Usecs: %12.1f GFLOPS: %10.3f\n",
       num_batches, filter_len, input_len, output_len, up, down, avg_elapsed_us, gflops/(avg_elapsed_us/1.0e6));
   }
 
diff --git a/include/matx/kernels/resample_poly.cuh b/include/matx/kernels/resample_poly.cuh
index 326e8b97..69bedfd4 100644
--- a/include/matx/kernels/resample_poly.cuh
+++ b/include/matx/kernels/resample_poly.cuh
@@ -40,6 +40,11 @@
 #include <cstdint>
 #include <cstdio>
 
+#include <cooperative_groups.h>
+#include <cooperative_groups/reduce.h>
+
+namespace cg = cooperative_groups;
+
 #include "cuComplex.h"
 #include "matx/core/utils.h"
 #include "matx/core/type_utils.h"
@@ -47,17 +52,26 @@
 namespace matx {
 
+// Used for __launch_bounds__ to allow the compiler to tune register usage
+static constexpr int MATX_RESAMPLE_POLY_MAX_NUM_THREADS = 256;
+
+// We use a static 11 KiB buffer to potentially store the filter. If it fits, we will load it
+// into smem. If not, then we will load it from global memory at the time of use. We choose
+// 11 KiB so that we can definitely fit four blocks in 48 KiB, leaving 1 KiB per block
+// for the driver.
+static constexpr size_t MATX_RESAMPLE_POLY_MAX_SMEM_BYTES = 11*1024;
+
 #ifdef __CUDACC__
 
-template <int THREADS, typename OutType, typename InType, typename FilterType>
-__launch_bounds__(THREADS)
-__global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
+template <int THREADS, typename OutType, typename InType, typename FilterType, typename index_t>
+__launch_bounds__(MATX_RESAMPLE_POLY_MAX_NUM_THREADS)
+__global__ void ResamplePoly1D_PhaseBlock(OutType output, InType input, FilterType filter,
                               index_t up, index_t down, index_t elems_per_thread)
 {
   using output_t = typename OutType::scalar_type;
   using input_t = typename InType::scalar_type;
   using filter_t = typename FilterType::scalar_type;
 
-  extern __shared__ __align__(alignof(double4)) uint8_t smem_filter[];
+  extern __shared__ uint8_t smem_filter[];
   filter_t *s_filter = reinterpret_cast<filter_t *>(smem_filter);
 
   constexpr int Rank = OutType::Rank();
@@ -159,9 +173,6 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
     }
   }
 
-
-  __syncthreads();
-
   // left_h_ind is the index in s_filter that contains the filter tap that will be applied to the
   // last input signal value not to the right of the output index in the virtual upsampled array.
   // If the filter has odd length and a given output value aligns with an input value, then
@@ -171,19 +182,24 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
   const index_t left_h_ind = (last_filter_ind - left_filter_ind)/up;
   const index_t max_h_epilogue = this_phase_len - left_h_ind - 1;
 
-  const index_t max_input_ind = static_cast<index_t>(input_len) - 1;
+  const index_t max_input_ind = input_len - 1;
   const index_t start_ind = phase_ind + up * (tid + elem_block * elems_per_thread * THREADS);
   const index_t last_ind = std::min(output_len - 1, start_ind + elems_per_thread * THREADS * up);
+  __syncthreads();
   for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS * up) {
-    // out_ind is the index in the output array and up_ind is the corresponding
-    // index in the upsampled array
+    // out_ind is the index in the output array and up_ind = out_ind * down is the
+    // corresponding index in the upsampled array
     const index_t up_ind = out_ind * down;
     // input_ind is the largest index in the input array that is not greater than
-    // (to the right of, in the figure shown earlier) up_ind.
+    // (to the right of, in the figure shown earlier) up_ind. This is equivalent
+    // to up_ind / up where up_ind = out_ind * down, but we could increment rather
+    // than divide to avoid integer divisions. out_ind increments by THREADS * up
+    // each iteration, so adding THREADS * up * down and dividing by up is
+    // equivalent to adding THREADS * down:
+    // input_ind += THREADS * down;
     const index_t input_ind = up_ind / up;
-
     // We want x_ind and h_ind to be the first aligned input and filter samples
     // of the convolution and n to be the number of taps. prologue is the number
     // of valid samples before input_ind. In the case that the filter is not
@@ -217,6 +233,254 @@ __global__ void ResamplePoly1D(OutType output, InType input, FilterType filter,
   }
 }
 
+template <int THREADS, typename FilterType>
+__device__ inline void ResamplePoly1D_LoadFilter(typename FilterType::scalar_type *s_filter, const FilterType &filter)
+{
+  const index_t filter_len = filter.Size(0);
+  const int tid = threadIdx.x;
+  if (filter_len % 2 == 0) {
+    // Prepend a zero so that the filter stored in smem is always odd-length
+    for (int t = tid; t < filter_len; t += THREADS) {
+      s_filter[t+1] = filter.operator()(t);
+    }
+    if (tid == 0) {
+      s_filter[0] = static_cast<typename FilterType::scalar_type>(0);
+    }
+  } else {
+    for (int t = tid; t < filter_len; t += THREADS) {
+      s_filter[t] = filter.operator()(t);
+    }
+  }
+  __syncthreads();
+}
+
+template <int THREADS, typename OutType, typename InType, typename FilterType, typename index_t>
+__launch_bounds__(MATX_RESAMPLE_POLY_MAX_NUM_THREADS)
+__global__ void ResamplePoly1D_ElemBlock(OutType output, InType input, FilterType filter,
+                              index_t up, index_t down, index_t elems_per_thread)
+{
+  using output_t = typename OutType::scalar_type;
+  using input_t = typename InType::scalar_type;
+  using filter_t = typename FilterType::scalar_type;
+
+  extern __shared__ uint8_t smem_filter[];
+  filter_t *s_filter = reinterpret_cast<filter_t *>(smem_filter);
+
+  constexpr int Rank = OutType::Rank();
+  const index_t output_len = output.Size(Rank-1);
+  index_t filter_len = filter.Size(0);
+  const index_t input_len = input.Size(Rank-1);
+
+  const size_t filter_sz_bytes = (filter_len % 2 == 0) ?
+    sizeof(filter_t)*(filter_len+1) : sizeof(filter_t)*filter_len;
+  const bool load_filter_to_smem = (filter_sz_bytes <= MATX_RESAMPLE_POLY_MAX_SMEM_BYTES);
+
+  const int elem_block = blockIdx.z;
+  const int tid = threadIdx.x;
+
+  if (load_filter_to_smem) {
+    ResamplePoly1D_LoadFilter<THREADS>(s_filter, filter);
+    if (filter_len % 2 == 0) {
+      filter_len++;
+    }
+  }
+
+  // All but the last dim are batch indices
+  const int batch_idx = blockIdx.x;
+  auto bdims = BlockToIdx(output, batch_idx, 1);
+
+  // Scale the filter coefficients by up to match scipy's convention
+  const filter_t scale = static_cast<filter_t>(up);
+  const index_t max_input_ind = input_len - 1;
+
+  const index_t filter_len_half = filter_len/2;
+  const index_t filter_central_tap = (filter_len-1)/2;
+  const index_t start_ind = elem_block * elems_per_thread * THREADS + tid;
+  const index_t last_ind = std::min(output_len - 1, start_ind + (elems_per_thread-1) * THREADS);
+  if (load_filter_to_smem) {
+    for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS) {
+      const index_t up_ind = out_ind * down;
+      const index_t up_start = std::max(static_cast<index_t>(0), up_ind - filter_len_half);
+      const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half);
+      const index_t x_start = (up_start + up - 1) / up;
+      index_t x_end = up_end / up;
+      // Since the filter is in shared memory, we can narrow the index type to 32 bits
+      int h_ind = static_cast<int>(filter_central_tap + (up_ind - up*x_start));
+
+      output_t accum {};
+      input_t in_val;
+      for (index_t i = x_start; i <= x_end; i++) {
+        bdims[Rank - 1] = i;
+        detail::mapply([&in_val, &input](auto &&...args) {
+          in_val = input.operator()(args...);
+        }, bdims);
+        accum += in_val * s_filter[h_ind];
+        h_ind -= up;
+      }
+
+      accum *= scale;
+      bdims[Rank - 1] = out_ind;
+      detail::mapply([&accum, &output](auto &&...args) {
+        output.operator()(args...) = accum;
+      }, bdims);
+    }
+  } else {
+    for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS) {
+      const index_t up_ind = out_ind * down;
+      const index_t up_start = std::max(static_cast<index_t>(0), up_ind - filter_len_half);
+      const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half);
+      const index_t x_start = (up_start + up - 1) / up;
+      index_t x_end = up_end / up;
+      index_t h_ind = filter_central_tap + (up_ind - up*x_start);
+      // For an even-length filter read from global memory, the last tap would fall
+      // off the front of the filter (it would hit the prepended zero in the smem
+      // path), so drop it.
+      if (h_ind - up*(x_end-x_start) < 0) {
+        x_end--;
+      }
+
+      output_t accum {};
+      input_t in_val;
+      for (index_t i = x_start; i <= x_end; i++) {
+        bdims[Rank - 1] = i;
+        detail::mapply([&in_val, &input](auto &&...args) {
+          in_val = input.operator()(args...);
+        }, bdims);
+        accum += in_val * filter.operator()(h_ind);
+        h_ind -= up;
+      }
+
+      accum *= scale;
+      bdims[Rank - 1] = out_ind;
+      detail::mapply([&accum, &output](auto &&...args) {
+        output.operator()(args...) = accum;
+      }, bdims);
+    }
+  }
+}
+
+template <int THREADS, typename OutType, typename InType, typename FilterType, typename index_t>
+__launch_bounds__(MATX_RESAMPLE_POLY_MAX_NUM_THREADS)
+__global__ void ResamplePoly1D_WarpCentric(OutType output, InType input, FilterType filter,
+                              index_t up, index_t down, index_t elems_per_warp)
+{
+  using output_t = typename OutType::scalar_type;
+  using input_t = typename InType::scalar_type;
+  using filter_t = typename FilterType::scalar_type;
+
+  auto block = cg::this_thread_block();
+  auto tile = cg::tiled_partition<WARP_SIZE>(block);
+  const int warp_id = tile.meta_group_rank();
+  const int NUM_WARPS = THREADS / WARP_SIZE;
+  const int lane_id = tile.thread_rank();
+
+  extern __shared__ uint8_t smem_filter[];
+  filter_t *s_filter = reinterpret_cast<filter_t *>(smem_filter);
+
+  constexpr int Rank = OutType::Rank();
+  const index_t output_len = output.Size(Rank-1);
+  index_t filter_len = filter.Size(0);
+  const index_t input_len = input.Size(Rank-1);
+
+  const size_t filter_sz_bytes = (filter_len % 2 == 0) ? sizeof(filter_t)*(filter_len+1) : sizeof(filter_t)*filter_len;
+  const bool load_filter_to_smem = (filter_sz_bytes <= MATX_RESAMPLE_POLY_MAX_SMEM_BYTES);
+
+  const int elem_block = blockIdx.z;
+
+  if (load_filter_to_smem) {
+    ResamplePoly1D_LoadFilter<THREADS>(s_filter, filter);
+    if (filter_len % 2 == 0) {
+      filter_len++;
+    }
+  }
+
+  // All but the last dim are batch indices
+  const int batch_idx = blockIdx.x;
+  auto bdims = BlockToIdx(output, batch_idx, 1);
+
+  // Scale the filter coefficients by up to match scipy's convention
+  const filter_t scale = static_cast<filter_t>(up);
+  const index_t max_input_ind = input_len - 1;
+
+  const index_t filter_len_half = filter_len/2;
+  const index_t filter_central_tap = (filter_len-1)/2;
+  const index_t start_ind = elem_block * elems_per_warp * NUM_WARPS;
+  const index_t last_ind = std::min(output_len - 1, start_ind + elems_per_warp * NUM_WARPS - 1);
+  if (load_filter_to_smem) {
+    for (index_t out_ind = start_ind+warp_id; out_ind <= last_ind; out_ind += NUM_WARPS) {
+      const index_t up_ind = out_ind * down;
+      const index_t up_start = std::max(static_cast<index_t>(0), up_ind - filter_len_half);
+      const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half);
+      const index_t x_start = (up_start + up - 1) / up;
+      index_t x_end = up_end / up;
+      // Since the filter is in shared memory, we can narrow the index type to 32 bits
+      int h_ind = static_cast<int>(filter_central_tap + (up_ind - up*x_start)) - lane_id*up;
+
+      output_t accum {};
+      input_t in_val;
+      for (index_t i = x_start+lane_id; i <= x_end; i += WARP_SIZE) {
+        bdims[Rank - 1] = i;
+        detail::mapply([&in_val, &input](auto &&...args) {
+          in_val = input.operator()(args...);
+        }, bdims);
+        accum += in_val * s_filter[h_ind];
+        h_ind -= up * WARP_SIZE;
+      }
+
+      accum *= scale;
+      if constexpr (is_complex_v<output_t>) {
+        using inner_type = typename inner_op_type_t<output_t>::type;
+        accum.real(cg::reduce(tile, accum.real(), cg::plus<inner_type>()));
+        accum.imag(cg::reduce(tile, accum.imag(), cg::plus<inner_type>()));
+      } else {
+        accum = cg::reduce(tile, accum, cg::plus<output_t>());
+      }
+      if (lane_id == 0) {
+        bdims[Rank - 1] = out_ind;
+        detail::mapply([&accum, &output](auto &&...args) {
+          output.operator()(args...) = accum;
+        }, bdims);
+      }
+    }
+  } else {
+    for (index_t out_ind = start_ind+warp_id; out_ind <= last_ind; out_ind += NUM_WARPS) {
+      const index_t up_ind = out_ind * down;
+      const index_t up_start = std::max(static_cast<index_t>(0), up_ind - filter_len_half);
+      const index_t up_end = std::min(max_input_ind * up, up_ind + filter_len_half);
+      const index_t x_start = (up_start + up - 1) / up;
+      index_t x_end = up_end / up;
+      index_t h_ind = filter_central_tap + (up_ind - up*x_start);
+      if (h_ind - up*(x_end-x_start) < 0) {
+        x_end--;
+      }
+      h_ind -= lane_id*up;
+
+      output_t accum {};
+      input_t in_val;
+      for (index_t i = x_start+lane_id; i <= x_end; i += WARP_SIZE) {
+        bdims[Rank - 1] = i;
+        detail::mapply([&in_val, &input](auto &&...args) {
+          in_val = input.operator()(args...);
+        }, bdims);
+        accum += in_val * filter.operator()(h_ind);
+        h_ind -= up * WARP_SIZE;
+      }
+
+      accum *= scale;
+      if constexpr (is_complex_v<output_t>) {
+        using inner_type = typename inner_op_type_t<output_t>::type;
+        accum.real(cg::reduce(tile, accum.real(), cg::plus<inner_type>()));
+        accum.imag(cg::reduce(tile, accum.imag(), cg::plus<inner_type>()));
+      } else {
+        accum = cg::reduce(tile, accum, cg::plus<output_t>());
+      }
+      if (lane_id == 0) {
+        bdims[Rank - 1] = out_ind;
+        detail::mapply([&accum, &output](auto &&...args) {
+          output.operator()(args...) = accum;
+        }, bdims);
+      }
+    }
+  }
+}
+
 #endif // __CUDACC__
 
-}; // namespace matx
\ No newline at end of file
+}; // namespace matx
diff --git a/include/matx/transforms/resample_poly.h b/include/matx/transforms/resample_poly.h
index 5e63c26b..0d6dc44a 100644
--- a/include/matx/transforms/resample_poly.h
+++ b/include/matx/transforms/resample_poly.h
@@ -55,36 +55,115 @@ inline void matxResamplePoly1DInternal(OutType &o, const InType &i,
   using input_t = typename InType::scalar_type;
   using filter_t = typename FilterType::scalar_type;
+  using output_t = typename OutType::scalar_type;
   using shape_type = typename OutType::shape_type;
-
-  shape_type filter_len = filter.Size(FilterType::Rank()-1);
   // Even-length filters will be prepended with a single 0 to make them odd-length
-  const int max_phase_len = (filter_len % 2 == 0) ?
-    static_cast<int>((filter_len + 1 + up - 1) / up) :
-    static_cast<int>((filter_len + up - 1) / up);
-  const size_t filter_shm = sizeof(filter_t) * max_phase_len;
+  const shape_type filter_len = filter.Size(FilterType::Rank()-1);
+  const index_t max_phase_len = (filter_len % 2 == 0) ?
+    ((filter_len + 1 + up - 1) / up) :
+    ((filter_len + up - 1) / up);
+
+  auto downcast_to_32b_index = [&i, filter_len, up, down]() -> bool {
+    if constexpr (sizeof(index_t) == 4) {
+      // The index type is already 32 bits
+      return false;
+    } else {
+      return
+        // + 1 because we may include a zero padded after the last input element
+        (i.Size(i.Rank() - 1)+1) * up <= std::numeric_limits<int32_t>::max() &&
+        (filter_len+1) <= std::numeric_limits<int32_t>::max() &&
+        down <= std::numeric_limits<int32_t>::max();
+    }
+  };
 
   const index_t output_len = o.Size(OutType::Rank()-1);
-  const index_t max_output_len_per_phase = (output_len + up - 1) / up;
-  const int num_phases = static_cast<int>(up);
+
+  // We default to the ElemBlock kernel as it tends to work well for general problems.
+  enum class ResampleKernel {
+    PhaseBlock,
+    ElemBlock,
+    WarpCentric,
+  } kernel = ResampleKernel::ElemBlock;
+
+  // The WarpCentric kernel currently uses cg::reduce(), which requires trivially copyable types.
+  if constexpr (std::is_trivially_copyable_v<output_t>) {
+    // There are a couple of cases where a warp-centric resampler tends to be faster:
+    // 1. When we have a small number of output points, handling one or a few points per warp is an
+    //    effective way to achieve higher occupancy.
+    // 2. When we have many filter taps per output point, each thread in the warp will be able to
+    //    read multiple elements and the warp will tend to achieve coalesced reads. This helps to
+    //    prevent loop overhead and barrier stalls from dominating.
+    if (output_len <= 2048 || max_phase_len > 256) {
+      kernel = ResampleKernel::WarpCentric;
+    }
+  }
+
+  // Currently, we select only ElemBlock or WarpCentric to keep things simpler. However, there are some
+  // cases where PhaseBlock is the fastest kernel. If there are specific parameter sets of interest, then
+  // we can benchmark the PhaseBlock method and, if it proves fastest, use that method in those cases.
+
+  // Desired number of blocks to reach high occupancy
+  constexpr index_t DESIRED_MIN_GRID_SIZE = 8192;
 
   const int num_batches = static_cast<int>(TotalSize(i)/i.Size(i.Rank() - 1));
-  dim3 grid(num_batches, num_phases);
-  constexpr int THREADS = 128;
-  constexpr index_t DESIRED_MIN_GRID_SIZE = 512;
-  // If we do not have enough batches and phases to create a large grid, then
-  // we try to reduce the number of output elements generated per thread to
-  // yield a large-enough grid to saturate the GPU. However, since the filter
-  // taps are stored in shared memory, we do not want to process fewer elements
-  // per thread than is necessary to saturate the GPU.
-  if (num_batches * num_phases < DESIRED_MIN_GRID_SIZE) {
-    const index_t desired_elem_blocks = (DESIRED_MIN_GRID_SIZE + num_batches * num_phases - 1) /
-      (num_batches * num_phases);
-    const index_t max_output_len_per_thread = (max_output_len_per_phase + THREADS - 1) / THREADS;
-    grid.z = static_cast<int>(std::min(desired_elem_blocks, max_output_len_per_thread));
+  dim3 grid(num_batches, 1, 1);
+  // A comp_unit is either a thread or a warp, depending on the kernel; it is the computational
+  // unit that collectively computes a single output value.
+  auto compute_elems_per_comp_unit = [&grid, DESIRED_MIN_GRID_SIZE](index_t max_outlen_per_cta, int cta_comp_unit_count) -> index_t {
+    const int start_batch_size = grid.x * grid.y;
+    const index_t desired_extra_batches = (DESIRED_MIN_GRID_SIZE + start_batch_size - 1) /
+      start_batch_size;
+    const index_t max_outlen_per_comp_unit = (max_outlen_per_cta + cta_comp_unit_count - 1) /
+      cta_comp_unit_count;
+    grid.z = static_cast<int>(std::min(desired_extra_batches, max_outlen_per_comp_unit));
+    return (max_outlen_per_cta + cta_comp_unit_count * grid.z - 1) / (cta_comp_unit_count * grid.z);
+  };
+
+  constexpr int THREADS = MATX_RESAMPLE_POLY_MAX_NUM_THREADS;
+  if (kernel == ResampleKernel::PhaseBlock) {
+    const size_t smemBytes = (sizeof(filter_t) * max_phase_len <= MATX_RESAMPLE_POLY_MAX_SMEM_BYTES) ?
+      sizeof(filter_t) * max_phase_len : 0;
+    const index_t max_output_len_per_phase = (output_len + up - 1) / up;
+    grid.y = static_cast<int>(up);
+    const index_t elems_per_thread = compute_elems_per_comp_unit(max_output_len_per_phase, THREADS);
+    if (downcast_to_32b_index()) {
+      ResamplePoly1D_PhaseBlock<THREADS><<<grid, THREADS, smemBytes, stream>>>(
+          o, i, filter, static_cast<int32_t>(up), static_cast<int32_t>(down),
+          static_cast<int32_t>(elems_per_thread));
+    } else {
+      ResamplePoly1D_PhaseBlock<THREADS><<<grid, THREADS, smemBytes, stream>>>(
+          o, i, filter, up, down, elems_per_thread);
+    }
+  } else if (kernel == ResampleKernel::ElemBlock) {
+    const size_t filter_sz_bytes = (filter_len % 2 == 0) ?
+      sizeof(filter_t)*(filter_len+1) : sizeof(filter_t)*filter_len;
+    const size_t smemBytes = (filter_sz_bytes <= MATX_RESAMPLE_POLY_MAX_SMEM_BYTES) ? filter_sz_bytes : 0;
+    const index_t elems_per_thread = compute_elems_per_comp_unit(output_len, THREADS);
+    if (downcast_to_32b_index()) {
+      ResamplePoly1D_ElemBlock<THREADS><<<grid, THREADS, smemBytes, stream>>>(
+          o, i, filter, static_cast<int32_t>(up), static_cast<int32_t>(down),
+          static_cast<int32_t>(elems_per_thread));
+    } else {
+      ResamplePoly1D_ElemBlock<THREADS><<<grid, THREADS, smemBytes, stream>>>(
+          o, i, filter, up, down, elems_per_thread);
+    }
+  } else {
+    // We only select the WarpCentric kernel for trivially copyable types, but we need this
+    // constexpr if to avoid instantiating the kernel with inappropriate types.
+    if constexpr (std::is_trivially_copyable_v<output_t>) {
+      const size_t filter_sz_bytes = (filter_len % 2 == 0) ? sizeof(filter_t)*(filter_len+1) : sizeof(filter_t)*filter_len;
+      const size_t smemBytes = (filter_sz_bytes <= MATX_RESAMPLE_POLY_MAX_SMEM_BYTES) ? filter_sz_bytes : 0;
+      static_assert(THREADS % WARP_SIZE == 0);
+      const index_t elems_per_warp = compute_elems_per_comp_unit(output_len, THREADS/WARP_SIZE);
+      if (downcast_to_32b_index()) {
+        ResamplePoly1D_WarpCentric<THREADS><<<grid, THREADS, smemBytes, stream>>>(
+            o, i, filter, static_cast<int32_t>(up), static_cast<int32_t>(down),
+            static_cast<int32_t>(elems_per_warp));
+      } else {
+        ResamplePoly1D_WarpCentric<THREADS><<<grid, THREADS, smemBytes, stream>>>(
+            o, i, filter, up, down, elems_per_warp);
+      }
+    }
   }
-  const index_t elems_per_thread = (max_output_len_per_phase + THREADS * grid.z - 1) / (THREADS * grid.z);
-  ResamplePoly1D<THREADS><<<grid, THREADS, filter_shm, stream>>>(
-      o, i, filter, up, down, elems_per_thread);
 #endif
 }
 
diff --git a/test/00_transform/ResamplePoly.cu b/test/00_transform/ResamplePoly.cu
index 72317ecb..84beeb0a 100644
--- a/test/00_transform/ResamplePoly.cu
+++ b/test/00_transform/ResamplePoly.cu
@@ -268,6 +268,9 @@ TYPED_TEST(ResamplePolyTestFloatTypes, DefaultFilter)
     { 350, 1, 7 },
     { 351, 7, 1 },
     { 351, 1, 7 },
+    { 1000000, 5, 1 },
+    { 1000000, 1, 5 },
+    { 1000000, 2, 3 },
   };
 
   for (size_t i = 0; i < sizeof(test_cases)/sizeof(test_cases[0]); i++) {
@@ -509,7 +512,11 @@ TYPED_TEST(ResamplePolyTestNonHalfFloatTypes, Upsample)
       aj = (j % up == 0) ? a(j/up) : 0;
       bj = b(j);
     }
-    ASSERT_NEAR(aj, bj, 1.0e-16);
+    if (j % up == 0) {
+      ASSERT_NEAR(aj, bj, this->thresh);
+    } else {
+      ASSERT_EQ(bj, 0.0);
+    }
   }
 }
 
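Reviewer note on the WarpCentric kernel in this patch: the core pattern is
"each lane accumulates a strided slice of the tap loop, then the warp reduces
the partial sums". A standalone sketch of just that pattern follows; warp_dot
and its arguments are hypothetical and not MatX code, and it assumes CUDA 11+
for cg::reduce():

  #include <cooperative_groups.h>
  #include <cooperative_groups/reduce.h>

  namespace cg = cooperative_groups;

  // Launched as, e.g., warp_dot<<<1, 32>>>(x, h, n, out).
  __global__ void warp_dot(const float *x, const float *h, int n, float *out) {
    auto tile = cg::tiled_partition<32>(cg::this_thread_block());
    float accum = 0.0f;
    // Lane k handles taps k, k+32, k+64, ..., so consecutive lanes read
    // consecutive elements and the loads coalesce for long filters.
    for (int i = tile.thread_rank(); i < n; i += 32) {
      accum += x[i] * h[i];
    }
    // Warp-wide sum of the partial products; every lane receives the result
    // and lane 0 writes it out.
    accum = cg::reduce(tile, accum, cg::plus<float>());
    if (tile.thread_rank() == 0) {
      *out = accum;
    }
  }

The cg::reduce() call at the end is also why the kernel selection above is
restricted to trivially copyable output types.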
From c27b45c74a6d9340edb8f78d3e0210b697dfc38a Mon Sep 17 00:00:00 2001
From: Thomas Benson
Date: Thu, 2 Nov 2023 10:21:38 -0700
Subject: [PATCH 2/3] Remove outdated comment
---
 include/matx/kernels/resample_poly.cuh | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/include/matx/kernels/resample_poly.cuh b/include/matx/kernels/resample_poly.cuh
index 69bedfd4..1d75b0a2 100644
--- a/include/matx/kernels/resample_poly.cuh
+++ b/include/matx/kernels/resample_poly.cuh
@@ -173,6 +173,8 @@ __global__ void ResamplePoly1D_PhaseBlock(OutType output, InType input, FilterTy
     }
   }
 
+  __syncthreads();
+
   // left_h_ind is the index in s_filter that contains the filter tap that will be applied to the
   // last input signal value not to the right of the output index in the virtual upsampled array.
   // If the filter has odd length and a given output value aligns with an input value, then
@@ -186,19 +188,13 @@ __global__ void ResamplePoly1D_PhaseBlock(OutType output, InType input, FilterTy
   const index_t start_ind = phase_ind + up * (tid + elem_block * elems_per_thread * THREADS);
   const index_t last_ind = std::min(output_len - 1, start_ind + elems_per_thread * THREADS * up);
-  __syncthreads();
   for (index_t out_ind = start_ind; out_ind <= last_ind; out_ind += THREADS * up) {
     // out_ind is the index in the output array and up_ind = out_ind * down is the
    // corresponding index in the upsampled array
     const index_t up_ind = out_ind * down;
     // input_ind is the largest index in the input array that is not greater than
-    // (to the right of, in the figure shown earlier) up_ind. This is equivalent
-    // to up_ind / up where up_ind = out_ind * down, but we could increment rather
-    // than divide to avoid integer divisions. out_ind increments by THREADS * up
-    // each iteration, so adding THREADS * up * down and dividing by up is
-    // equivalent to adding THREADS * down:
-    // input_ind += THREADS * down;
+    // (to the right of, in the figure shown earlier) up_ind.
     const index_t input_ind = up_ind / up;
     // We want x_ind and h_ind to be the first aligned input and filter samples
     // of the convolution and n to be the number of taps. prologue is the number

From 0bd065862f8645322a63b3c1b370c884e56391a6 Mon Sep 17 00:00:00 2001
From: Thomas Benson
Date: Thu, 2 Nov 2023 12:15:42 -0700
Subject: [PATCH 3/3] Add comment clarifying filter index calculations
---
 include/matx/kernels/resample_poly.cuh | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/include/matx/kernels/resample_poly.cuh b/include/matx/kernels/resample_poly.cuh
index 1d75b0a2..6fe00578 100644
--- a/include/matx/kernels/resample_poly.cuh
+++ b/include/matx/kernels/resample_poly.cuh
@@ -289,6 +289,15 @@ __global__ void ResamplePoly1D_ElemBlock(OutType output, InType input, FilterTyp
   const index_t max_input_ind = input_len - 1;
 
   const index_t filter_len_half = filter_len/2;
+  // The loops below assume odd-length filters with a central tap. When an even-length
+  // filter is stored to smem, a zero is prepended to the filter (prior to flipping for
+  // convolution) so that the stored filter is always odd-length. Thus, for a stored
+  // filter, both filter_len/2 and (filter_len-1)/2 reference the central tap. For an
+  // originally even-length filter, the index of the central tap in the filter tensor
+  // is filter_len/2 - 1. When the filter is not stored to smem, we want the same
+  // central tap, so we compute the index as (filter_len-1)/2. This yields the same
+  // result for natively odd-length filters and, for even-length filters, references
+  // the same coefficient whether or not the filter has been loaded to shared memory.
   const index_t filter_central_tap = (filter_len-1)/2;
   const index_t start_ind = elem_block * elems_per_thread * THREADS + tid;
   const index_t last_ind = std::min(output_len - 1, start_ind + (elems_per_thread-1) * THREADS);
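Reviewer note: a small worked example of the central-tap rule described in the
comment added by PATCH 3, for an even-length filter (standalone illustration,
not part of the patch):

  #include <cstdio>

  int main() {
    const long long filter_len = 6;              // original (even) length
    const long long stored_len = filter_len + 1; // zero prepended -> odd
    // Stored (smem) filter s: s[0] = 0 and s[t+1] = h[t], so the central tap
    // s[(stored_len-1)/2] = s[3] holds coefficient h[2].
    printf("stored central tap index: %lld\n", (stored_len - 1) / 2);
    // Unstored (global-memory) filter: (filter_len-1)/2 = 2 picks h[2]
    // directly, matching filter_len/2 - 1 for even lengths, so both paths
    // apply the same coefficient.
    printf("global central tap index: %lld\n", (filter_len - 1) / 2);
    return 0;
  }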