Describe the Bug
The pwelch reduction stage performs about 10x worse than a similar hand-rolled CUDA implementation. From looking at the pwelch_impl source, it is effectively 3 operations:
1. An FFT over the windowed, overlapped segments
2. mag_sq_X_with_overlaps = conj(X_with_overlaps) * X_with_overlaps
3. Pxx = sum(mag_sq_X_with_overlaps, {0}) * norm_factor
On my system using matx, step 2 takes about 40 us and step 3 takes about 380 us. A custom CUDA kernel covering steps 2 and 3 takes about 30 us.
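For reference, the reduction stages (steps 2 and 3) can be sketched on the CPU as follows. This is a minimal sketch, not MatX code: the function name `ReducePowerSpectrum`, the row-major `[num_ffts, num_bins]` layout, and the single `norm_factor` scale are assumptions based on the description above.

```cpp
#include <cmath>
#include <complex>
#include <vector>

// CPU sketch of steps 2 and 3: per-bin squared magnitude, summed over
// all FFTs and scaled by a normalization factor. The row-major
// [num_ffts, num_bins] buffer layout is an assumption.
std::vector<float> ReducePowerSpectrum(const std::vector<std::complex<float>>& ffts,
                                       int num_ffts, int num_bins, float norm_factor)
{
  std::vector<float> pxx(num_bins, 0.0f);
  for (int fft = 0; fft < num_ffts; fft++) {
    for (int bin = 0; bin < num_bins; bin++) {
      pxx[bin] += std::norm(ffts[fft * num_bins + bin]);  // step 2: conj(x) * x
    }
  }
  for (int bin = 0; bin < num_bins; bin++) {
    pxx[bin] *= norm_factor;  // step 3: scaled sum across FFTs
  }
  return pxx;
}
```

The point of the sketch is that both steps read each FFT element exactly once, which is why a single fused kernel can cover them in one pass over the data.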
To Reproduce
1. Allocate a complex<float> input tensor of size 164750
2. Create a Hamming window of size 500
3. Call matx::pwelch(input, window, 500, 250, 65536)
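As a sanity check on those parameters, the standard Welch segmentation convention implies the following segment count (a sketch; whether MatX counts a trailing partial segment the same way is an assumption):

```cpp
// Number of full overlapped segments in Welch's method, using the
// standard convention (full segments only; trailing samples that do
// not fill a segment are dropped). Hypothetical helper for illustration.
int NumSegments(int num_samples, int nperseg, int noverlap)
{
  const int hop = nperseg - noverlap;     // stride between segment starts
  return (num_samples - noverlap) / hop;  // count of full nperseg-sample segments
}
```

With the values from this issue (164750 samples, nperseg = 500, noverlap = 250), this gives 658 segments of 500 samples each, each transformed with a 65536-point FFT.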
Expected Behavior
Reduction performance comparable to the hand-rolled kernel: roughly 30 us for steps 2 and 3 combined, rather than the ~420 us observed with matx.
Code Snippets
The custom CUDA kernel in this case is:
template<typename ComplexT, typename RealT>
__global__ void ComputePowerSpectrumMultiFFT(const ComplexT* fft_buff, RealT* power_spectrum,
                                             int num_ffts, int num_bins)
{
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  // Grid-stride loop: each thread averages the power of one or more bins
  // across all FFTs.
  for (int bin = id; bin < num_bins; bin += num_threads)
  {
    RealT pwr = 0;
    for (int fft = 0; fft < num_ffts; fft++)
    {
      pwr += norm(fft_buff[fft * num_bins + bin]);
    }
    pwr /= num_ffts;
    // Our kernel outputs in dB, but that difference shouldn't matter.
    // Although, it would be nice if MatX allowed the `10 * log10` to be
    // fused with the rest of pwelch somehow.
    if (pwr <= 0)
      power_spectrum[bin] = 0;
    else
      power_spectrum[bin] = 10 * log10(pwr);
  }
}
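For clarity, here is a host-side mirror of what the kernel computes for one bin (a sketch for checking the math, not the device code path; `PowerSpectrumBinDb` is a hypothetical name):

```cpp
#include <cmath>
#include <complex>
#include <vector>

// Host mirror of one bin of ComputePowerSpectrumMultiFFT: average the
// squared magnitudes across FFTs, then convert to dB, clamping
// non-positive power to 0 as the kernel does.
float PowerSpectrumBinDb(const std::vector<std::complex<float>>& fft_buff,
                         int num_ffts, int num_bins, int bin)
{
  float pwr = 0.0f;
  for (int fft = 0; fft < num_ffts; fft++) {
    pwr += std::norm(fft_buff[fft * num_bins + bin]);
  }
  pwr /= num_ffts;
  return (pwr <= 0.0f) ? 0.0f : 10.0f * std::log10(pwr);
}
```

Note the one behavioral difference from steps 2 and 3 of pwelch: this averages (divides by num_ffts) and converts to dB, whereas pwelch sums and applies norm_factor.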
This kernel does not perform the FFT portion of pwelch, but I am not reporting any issue with the FFT portion, only with the reduction stages.
System Details (please complete the following information):
OS: Rocky 9
CUDA version: 12.3
g++ version: 11.4.1
Additional Context
magnitude squared calculation info from nsys profile: (profile screenshot omitted)
Reduction+normalization (using matx) info from nsys profile: (profile screenshot omitted)
Reduction+normalization (using custom kernel) info from nsys profile: (profile screenshot omitted)
hi @deanljohnson, just to ping the issue: we haven't forgotten about these. There's a bit of prep we are doing for GTC, so we haven't had time to look at them yet. We should be able to take a look soon.