diff --git a/include/finufft/finufft_core.h b/include/finufft/finufft_core.h index 92237b707..f3cb14146 100644 --- a/include/finufft/finufft_core.h +++ b/include/finufft/finufft_core.h @@ -132,9 +132,7 @@ static inline void MY_OMP_SET_NUM_THREADS [[maybe_unused]] (int) {} // group together a bunch of type 3 rescaling/centering/phasing parameters: template struct type3params { - T X1, C1, D1, h1, gam1; // x dim: X=halfwid C=center D=freqcen h,gam=rescale - T X2, C2, D2, h2, gam2; // y - T X3, C3, D3, h3, gam3; // z + std::array X, C, D, h, gam; // x dim: X=halfwid C=center D=freqcen h,gam=rescale }; template struct FINUFFT_PLAN_T { // the main plan class, fully C++ @@ -151,30 +149,26 @@ template struct FINUFFT_PLAN_T { // the main plan class, fully C++ FINUFFT_PLAN_T &operator=(const FINUFFT_PLAN_T &) = delete; ~FINUFFT_PLAN_T(); - int type; // transform type (Rokhlin naming): 1,2 or 3 - int dim; // overall dimension: 1,2 or 3 - int ntrans; // how many transforms to do at once (vector or "many" mode) - BIGINT nj; // num of NU pts in type 1,2 (for type 3, num input x pts) - BIGINT nk; // number of NU freq pts (type 3 only) - TF tol; // relative user tolerance - int batchSize; // # strength vectors to group together for FFTW, etc - int nbatch; // how many batches done to cover all ntrans vectors + int type; // transform type (Rokhlin naming): 1,2 or 3 + int dim; // overall dimension: 1,2 or 3 + int ntrans; // how many transforms to do at once (vector or "many" mode) + BIGINT nj; // num of NU pts in type 1,2 (for type 3, num input x pts) + BIGINT nk; // number of NU freq pts (type 3 only) + TF tol; // relative user tolerance + int batchSize; // # strength vectors to group together for FFTW, etc + int nbatch; // how many batches done to cover all ntrans vectors - BIGINT ms; // number of modes in x (1) dir (historical CMCL name) = N1 - BIGINT mt; // number of modes in y (2) direction = N2 - BIGINT mu; // number of modes in z (3) direction = N3 - BIGINT N; // total # modes (prod of above three) + std::array mstu; // number of modes in x/y/z dir (historical CMCL name) = + // N1/N2/N3 + UBIGINT N; // total # modes (prod of above three) - BIGINT nf1 = 1; // size of internal fine grid in x (1) direction - BIGINT nf2 = 1; // " y (2) - BIGINT nf3 = 1; // " z (3) - BIGINT nf = 1; // total # fine grid points (product of the above three) + std::array nf123 = {1, 1, 1}; // size of internal fine grid in x/y/z + // direction + UBIGINT nf = 1; // total # fine grid points (product of the above three) - int fftSign; // sign in exponential for NUFFT defn, guaranteed to be +-1 + int fftSign; // sign in exponential for NUFFT defn, guaranteed to be +-1 - std::vector phiHat1; // FT of kernel in t1,2, on x-axis mode grid - std::vector phiHat2; // " y-axis. - std::vector phiHat3; // " z-axis. + std::array, 3> phiHat; // FT of kernel in t1,2, on x/y/z-axis mode grid // fwBatch: (batches of) fine working grid(s) for the FFT to plan & act on. // Usually the largest internal array. Its allocator is 64-byte (cache-line) aligned: @@ -185,17 +179,17 @@ template struct FINUFFT_PLAN_T { // the main plan class, fully C++ // for t1,2: ptr to user-supplied NU pts (no new allocs). // for t3: will become ptr to internally allocated "primed" (scaled) Xp, Yp, Zp vecs. - TF *X = nullptr, *Y = nullptr, *Z = nullptr; + std::array XYZ = {nullptr, nullptr, nullptr}; // type 3 specific - TF *S = nullptr, *T = nullptr, *U = nullptr; // ptrs to user's target NU-point arrays - // (no new allocs) - std::vector prephase; // pre-phase, for all input NU pts - std::vector deconv; // reciprocal of kernel FT, phase, all output NU pts - std::vector CpBatch; // working array of prephased strengths - std::vector Xp, Yp, Zp; // internal primed NU points (x'_j, etc) - std::vector Sp, Tp, Up; // internal primed targs (s'_k, etc) - type3params t3P; // groups together type 3 shift, scale, phase, parameters + std::array STU = {nullptr, nullptr, nullptr}; // ptrs to user's target NU-point + // arrays (no new allocs) + std::vector prephase; // pre-phase, for all input NU pts + std::vector deconv; // reciprocal of kernel FT, phase, all output NU pts + std::vector CpBatch; // working array of prephased strengths + std::array, 3> XYZp; // internal primed NU points (x'_j, etc) + std::array, 3> STUp; // internal primed targs (s'_k, etc) + type3params t3P; // groups together type 3 shift, scale, phase, parameters std::unique_ptr> innerT2plan; // ptr used for type 2 in step 2 of // type 3 diff --git a/perftest/guru_timing_test.cpp b/perftest/guru_timing_test.cpp index 97b8da053..a291a269b 100644 --- a/perftest/guru_timing_test.cpp +++ b/perftest/guru_timing_test.cpp @@ -271,8 +271,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 1: timer.restart(); - ier = FINUFFT1D1(plan->nj, x, cStart, plan->fftSign, plan->tol, plan->ms, fStart, - popts); + ier = FINUFFT1D1(plan->nj, x, cStart, plan->fftSign, plan->tol, plan->mstu[0], + fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -281,8 +281,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 2: timer.restart(); - ier = FINUFFT1D2(plan->nj, x, cStart, plan->fftSign, plan->tol, plan->ms, fStart, - popts); + ier = FINUFFT1D2(plan->nj, x, cStart, plan->fftSign, plan->tol, plan->mstu[0], + fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -291,8 +291,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 3: timer.restart(); - ier = FINUFFT1D3(plan->nj, x, cStart, plan->fftSign, plan->tol, plan->nk, plan->S, - fStart, popts); + ier = FINUFFT1D3(plan->nj, x, cStart, plan->fftSign, plan->tol, plan->nk, + plan->STU[0], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -308,8 +308,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 1: timer.restart(); - ier = FINUFFT2D1(plan->nj, x, y, cStart, plan->fftSign, plan->tol, plan->ms, - plan->mt, fStart, popts); + ier = FINUFFT2D1(plan->nj, x, y, cStart, plan->fftSign, plan->tol, plan->mstu[0], + plan->mstu[1], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -318,8 +318,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 2: timer.restart(); - ier = FINUFFT2D2(plan->nj, x, y, cStart, plan->fftSign, plan->tol, plan->ms, - plan->mt, fStart, popts); + ier = FINUFFT2D2(plan->nj, x, y, cStart, plan->fftSign, plan->tol, plan->mstu[0], + plan->mstu[1], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -329,7 +329,7 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 3: timer.restart(); ier = FINUFFT2D3(plan->nj, x, y, cStart, plan->fftSign, plan->tol, plan->nk, - plan->S, plan->T, fStart, popts); + plan->STU[0], plan->STU[1], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -345,8 +345,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 1: timer.restart(); - ier = FINUFFT3D1(plan->nj, x, y, z, cStart, plan->fftSign, plan->tol, plan->ms, - plan->mt, plan->mu, fStart, popts); + ier = FINUFFT3D1(plan->nj, x, y, z, cStart, plan->fftSign, plan->tol, plan->mstu[0], + plan->mstu[1], plan->mstu[2], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -355,8 +355,8 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 2: timer.restart(); - ier = FINUFFT3D2(plan->nj, x, y, z, cStart, plan->fftSign, plan->tol, plan->ms, - plan->mt, plan->mu, fStart, popts); + ier = FINUFFT3D2(plan->nj, x, y, z, cStart, plan->fftSign, plan->tol, plan->mstu[0], + plan->mstu[1], plan->mstu[2], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -366,7 +366,7 @@ double finufftFunnel(CPX *cStart, CPX *fStart, FLT *x, FLT *y, FLT *z, FINUFFT_P case 3: timer.restart(); ier = FINUFFT3D3(plan->nj, x, y, z, cStart, plan->fftSign, plan->tol, plan->nk, - plan->S, plan->T, plan->U, fStart, popts); + plan->STU[0], plan->STU[1], plan->STU[2], fStart, popts); t = timer.elapsedsec(); if (ier) return fail; @@ -399,7 +399,7 @@ double many_simple_calls(CPX *c, CPX *F, FLT *x, FLT *y, FLT *z, FINUFFT_PLAN pl for (int k = 0; k < plan->ntrans; k++) { cStart = c + plan->nj * k; - fStart = F + plan->ms * plan->mt * plan->mu * k; + fStart = F + plan->mstu[0] * plan->mstu[1] * plan->mstu[2] * k; // printf("k=%d, debug=%d.................\n",k, plan->opts.debug); if (k != 0) { // prevent massive debug output diff --git a/src/fft.cpp b/src/fft.cpp index 958d62953..a1138d346 100644 --- a/src/fft.cpp +++ b/src/fft.cpp @@ -10,10 +10,10 @@ using namespace std; template std::vector gridsize_for_fft(FINUFFT_PLAN_T *p) { // local helper func returns a new int array of length dim, extracted from // the finufft plan, that fftw_plan_many_dft needs as its 2nd argument. - if (p->dim == 1) return {(int)p->nf1}; - if (p->dim == 2) return {(int)p->nf2, (int)p->nf1}; + if (p->dim == 1) return {(int)p->nf123[0]}; + if (p->dim == 2) return {(int)p->nf123[1], (int)p->nf123[0]}; // if (p->dim == 3) - return {(int)p->nf3, (int)p->nf2, (int)p->nf1}; + return {(int)p->nf123[2], (int)p->nf123[1], (int)p->nf123[0]}; } template std::vector gridsize_for_fft(FINUFFT_PLAN_T *p); template std::vector gridsize_for_fft(FINUFFT_PLAN_T *p); @@ -49,11 +49,11 @@ template void do_fft(FINUFFT_PLAN_T *p) { if (p->dim == 1) // 1D: no chance for FFT shortcuts ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads); else if (p->dim == 2) { // 2D: do partial FFTs - if (p->ms < 2) // something is weird, do standard FFT + if (p->mstu[0] < 2) // something is weird, do standard FFT ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads); else { - size_t y_lo = size_t((p->ms + 1) / 2); - size_t y_hi = size_t(ns[1] - p->ms / 2); + size_t y_lo = size_t((p->mstu[0] + 1) / 2); + size_t y_hi = size_t(ns[1] - p->mstu[0] / 2); // the next line is analogous to the Python statement "sub1 = data[:, :, :y_lo]" auto sub1 = ducc0::subarray(data, {{}, {}, {0, y_lo}}); // the next line is analogous to the Python statement "sub2 = data[:, :, y_hi:]" @@ -68,14 +68,14 @@ template void do_fft(FINUFFT_PLAN_T *p) { // do axis 2 in full ducc0::c2c(data, data, {2}, p->fftSign < 0, TF(1), nthreads); } - } else { // 3D - if ((p->ms < 2) || (p->mt < 2)) // something is weird, do standard FFT + } else { // 3D + if ((p->mstu[0] < 2) || (p->mstu[1] < 2)) // something is weird, do standard FFT ducc0::c2c(data, data, axes, p->fftSign < 0, TF(1), nthreads); else { - size_t z_lo = size_t((p->ms + 1) / 2); - size_t z_hi = size_t(ns[2] - p->ms / 2); - size_t y_lo = size_t((p->mt + 1) / 2); - size_t y_hi = size_t(ns[1] - p->mt / 2); + size_t z_lo = size_t((p->mstu[0] + 1) / 2); + size_t z_hi = size_t(ns[2] - p->mstu[0] / 2); + size_t y_lo = size_t((p->mstu[1] + 1) / 2); + size_t y_hi = size_t(ns[1] - p->mstu[1] / 2); auto sub1 = ducc0::subarray(data, {{}, {}, {}, {0, z_lo}}); auto sub2 = ducc0::subarray(data, {{}, {}, {}, {z_hi, ducc0::MAXIDX}}); auto sub3 = ducc0::subarray(sub1, {{}, {}, {0, y_lo}, {}}); diff --git a/src/finufft_core.cpp b/src/finufft_core.cpp index 9cc45e7e4..d6522a974 100644 --- a/src/finufft_core.cpp +++ b/src/finufft_core.cpp @@ -77,16 +77,17 @@ Design notes for guru interface implementation: namespace finufft { namespace common { -static int set_nf_type12(BIGINT ms, const finufft_opts &opts, - const finufft_spread_opts &spopts, BIGINT *nf) +static int set_nf_type12(UBIGINT ms, const finufft_opts &opts, + const finufft_spread_opts &spopts, UBIGINT *nf) // Type 1 & 2 recipe for how to set 1d size of upsampled array, nf, given opts // and requested number of Fourier modes ms. Returns 0 if success, else an // error code if nf was unreasonably big (& tell the world). { *nf = BIGINT(opts.upsampfac * double(ms)); // manner of rounding not crucial - if (*nf < 2 * spopts.nspread) *nf = 2 * spopts.nspread; // otherwise spread fails + if (*nf < UBIGINT(2 * spopts.nspread)) + *nf = UBIGINT(2 * spopts.nspread); // otherwise spread fails if (*nf < MAX_NF) { - *nf = next235even(*nf); // expensive at huge nf + *nf = next235even(*nf); // expensive at huge nf return 0; } else { fprintf(stderr, @@ -125,7 +126,7 @@ static int setup_spreader_for_nufft(finufft_spread_opts &spopts, T eps, template static void set_nhg_type3(T S, T X, const finufft_opts &opts, - const finufft_spread_opts &spopts, BIGINT *nf, T *h, T *gam) + const finufft_spread_opts &spopts, UBIGINT *nf, T *h, T *gam) /* sets nf, h (upsampled grid spacing), and gamma (x_j rescaling factor), for type 3 only. Inputs: @@ -155,7 +156,7 @@ static void set_nhg_type3(T S, T X, const finufft_opts &opts, *nf = (BIGINT)nfd; // printf("initial nf=%lld, ns=%d\n",*nf,spopts.nspread); // catch too small nf, and nan or +-inf, otherwise spread fails... - if (*nf < 2 * spopts.nspread) *nf = 2 * spopts.nspread; + if (*nf < UBIGINT(2 * spopts.nspread)) *nf = UBIGINT(2 * spopts.nspread); if (*nf < MAX_NF) // otherwise will fail anyway *nf = next235even(*nf); // expensive at huge nf *h = T(2.0 * PI / *nf); // upsampled grid spacing @@ -450,8 +451,9 @@ static int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN_T *p, std::complex *fwi = p->fwBatch.data() + i * p->nf; // start of i'th fw array in // wkspace std::complex *ci = cBatch + i * p->nj; // start of i'th c array in cBatch - spreadinterpSorted(p->sortIndices, p->nf1, p->nf2, p->nf3, (T *)fwi, p->nj, p->X, - p->Y, p->Z, (T *)ci, p->spopts, p->didSort); + spreadinterpSorted(p->sortIndices, p->nf123[0], p->nf123[1], p->nf123[2], (T *)fwi, + p->nj, p->XYZ[0], p->XYZ[1], p->XYZ[2], (T *)ci, p->spopts, + p->didSort); } return 0; } @@ -478,15 +480,16 @@ static int deconvolveBatch(int batchSize, FINUFFT_PLAN_T *p, std::complex // Call routine from common.cpp for the dim; prefactors hardcoded to 1.0... if (p->dim == 1) - deconvolveshuffle1d(p->spopts.spread_direction, T(1), p->phiHat1, p->ms, (T *)fki, - p->nf1, fwi, p->opts.modeord); + deconvolveshuffle1d(p->spopts.spread_direction, T(1), p->phiHat[0], p->mstu[0], + (T *)fki, p->nf123[0], fwi, p->opts.modeord); else if (p->dim == 2) - deconvolveshuffle2d(p->spopts.spread_direction, T(1), p->phiHat1, p->phiHat2, p->ms, - p->mt, (T *)fki, p->nf1, p->nf2, fwi, p->opts.modeord); + deconvolveshuffle2d(p->spopts.spread_direction, T(1), p->phiHat[0], p->phiHat[1], + p->mstu[0], p->mstu[1], (T *)fki, p->nf123[0], p->nf123[1], fwi, + p->opts.modeord); else - deconvolveshuffle3d(p->spopts.spread_direction, T(1), p->phiHat1, p->phiHat2, - p->phiHat3, p->ms, p->mt, p->mu, (T *)fki, p->nf1, p->nf2, - p->nf3, fwi, p->opts.modeord); + deconvolveshuffle3d(p->spopts.spread_direction, T(1), p->phiHat[0], p->phiHat[1], + p->phiHat[2], p->mstu[0], p->mstu[1], p->mstu[2], (T *)fki, + p->nf123[0], p->nf123[1], p->nf123[2], fwi, p->opts.modeord); } return 0; } @@ -621,11 +624,12 @@ FINUFFT_PLAN_T::FINUFFT_PLAN_T(int type_, int dim_, const BIGINT *n_modes, i throw int(FINUFFT_ERR_SPREAD_THREAD_NOTVALID); } - if (type != 3) { // read in user Fourier mode array sizes... - ms = n_modes[0]; - mt = (dim > 1) ? n_modes[1] : 1; // leave as 1 for unused dims - mu = (dim > 2) ? n_modes[2] : 1; - N = ms * mt * mu; // N = total # modes + if (type != 3) { // read in user Fourier mode array sizes... + N = 1; + for (int idim = 0; idim < 3; ++idim) { + mstu[idim] = (idim < dim) ? n_modes[idim] : 1; + N *= mstu[idim]; + } } // heuristic to choose default upsampfac... (currently two poss) @@ -657,38 +661,26 @@ FINUFFT_PLAN_T::FINUFFT_PLAN_T(int type_, int dim_, const BIGINT *n_modes, i constexpr TF EPSILON = std::numeric_limits::epsilon(); if (opts.showwarn) { // user warn round-off error... - if (EPSILON * ms > 1.0) - fprintf(stderr, "%s warning: rounding err predicted eps_mach*N1 = %.3g > 1 !\n", - __func__, (double)(EPSILON * ms)); - if (EPSILON * mt > 1.0) - fprintf(stderr, "%s warning: rounding err predicted eps_mach*N2 = %.3g > 1 !\n", - __func__, (double)(EPSILON * mt)); - if (EPSILON * mu > 1.0) - fprintf(stderr, "%s warning: rounding err predicted eps_mach*N3 = %.3g > 1 !\n", - __func__, (double)(EPSILON * mu)); + for (int idim = 0; idim < dim; ++idim) + if (EPSILON * mstu[idim] > 1.0) + fprintf(stderr, "%s warning: rounding err predicted eps_mach*N1 = %.3g > 1 !\n", + __func__, (double)(EPSILON * mstu[idim])); } // determine fine grid sizes, sanity check.. - int nfier = set_nf_type12(ms, opts, spopts, &nf1); - if (nfier) throw nfier; // nf too big; we're done - phiHat1.resize(nf1 / 2 + 1); - if (dim > 1) { - nfier = set_nf_type12(mt, opts, spopts, &nf2); - if (nfier) throw nfier; - phiHat2.resize(nf2 / 2 + 1); - } - if (dim > 2) { - nfier = set_nf_type12(mu, opts, spopts, &nf3); - if (nfier) throw nfier; - phiHat3.resize(nf3 / 2 + 1); + for (int idim = 0; idim < dim; ++idim) { + int nfier = set_nf_type12(mstu[idim], opts, spopts, &nf123[idim]); + if (nfier) throw nfier; // nf too big; we're done + phiHat[idim].resize(nf123[idim] / 2 + 1); } if (opts.debug) { // "long long" here is to avoid warnings with printf... printf("[%s] %dd%d: (ms,mt,mu)=(%lld,%lld,%lld) " "(nf1,nf2,nf3)=(%lld,%lld,%lld)\n ntrans=%d nthr=%d " "batchSize=%d ", - __func__, dim, type, (long long)ms, (long long)mt, (long long)mu, - (long long)nf1, (long long)nf2, (long long)nf3, ntrans, nthr, batchSize); + __func__, dim, type, (long long)mstu[0], (long long)mstu[1], + (long long)mstu[2], (long long)nf123[0], (long long)nf123[1], + (long long)nf123[2], ntrans, nthr, batchSize); if (batchSize == 1) // spread_thread has no effect in this case printf("\n"); else @@ -698,14 +690,13 @@ FINUFFT_PLAN_T::FINUFFT_PLAN_T(int type_, int dim_, const BIGINT *n_modes, i // STEP 0: get Fourier coeffs of spreading kernel along each fine grid dim CNTime timer; timer.start(); - onedim_fseries_kernel(nf1, phiHat1, spopts); - if (dim > 1) onedim_fseries_kernel(nf2, phiHat2, spopts); - if (dim > 2) onedim_fseries_kernel(nf3, phiHat3, spopts); + for (int idim = 0; idim < dim; ++idim) + onedim_fseries_kernel(nf123[idim], phiHat[idim], spopts); if (opts.debug) printf("[%s] kernel fser (ns=%d):\t\t%.3g s\n", __func__, spopts.nspread, timer.elapsedsec()); - nf = nf1 * nf2 * nf3; // fine grid total number of points + nf = nf123[0] * nf123[1] * nf123[2]; // fine grid total number of points if (nf * batchSize > MAX_NF) { fprintf( stderr, @@ -785,12 +776,10 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF return FINUFFT_ERR_NUM_NU_PTS_INVALID; } - if (type != 3) { // ------------------ TYPE 1,2 SETPTS ------------------- - // (all we can do is check and maybe bin-sort the NU pts) - X = xj; // plan must keep pointers to user's fixed NU pts - Y = yj; - Z = zj; - int ier = spreadcheck(nf1, nf2, nf3, nj, xj, yj, zj, spopts); + if (type != 3) { // ------------------ TYPE 1,2 SETPTS ------------------- + // (all we can do is check and maybe bin-sort the NU pts) + XYZ = {xj, yj, zj}; // plan must keep pointers to user's fixed NU pts + int ier = spreadcheck(nf123[0], nf123[1], nf123[2], nj, xj, yj, zj, spopts); if (opts.debug > 1) printf("[%s] spreadcheck (%d):\t%.3g s\n", __func__, spopts.chkbnds, timer.elapsedsec()); @@ -798,7 +787,8 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF return ier; timer.restart(); sortIndices.resize(nj); - didSort = indexSort(sortIndices, nf1, nf2, nf3, nj, xj, yj, zj, spopts); + didSort = + indexSort(sortIndices, nf123[0], nf123[1], nf123[2], nj, xj, yj, zj, spopts); if (opts.debug) printf("[%s] sort (didSort=%d):\t\t%.3g s\n", __func__, didSort, timer.elapsedsec()); @@ -806,6 +796,8 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF } else { // ------------------------- TYPE 3 SETPTS ----------------------- // (here we can precompute pre/post-phase factors and plan the t2) + std::array XYZ_in{xj, yj, zj}; + std::array STU_in{s, t, u}; if (nk < 0) { fprintf(stderr, "[%s] nk (%lld) cannot be negative!\n", __func__, (long long)nk); return FINUFFT_ERR_NUM_NU_PTS_INVALID; @@ -814,43 +806,26 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF return FINUFFT_ERR_NUM_NU_PTS_INVALID; } this->nk = nk; // user set # targ freq pts - S = s; // keep pointers to user's input target pts - T = t; - U = u; + STU = {s, t, u}; // pick x, s intervals & shifts & # fine grid pts (nf) in each dim... - TF S1 = 0, S2 = 0, S3 = 0; // get half-width X, center C, which contains {x_j}... - arraywidcen(nj, xj, &(t3P.X1), &(t3P.C1)); - arraywidcen(nk, s, &S1, &(t3P.D1)); // same D, S, but for {s_k} - set_nhg_type3(S1, t3P.X1, opts, spopts, &(nf1), &(t3P.h1), - &(t3P.gam1)); // applies twist i) - t3P.C2 = 0.0; // their defaults if dim 2 unused, etc - t3P.D2 = 0.0; - if (d > 1) { - arraywidcen(nj, yj, &(t3P.X2), &(t3P.C2)); // {y_j} - arraywidcen(nk, t, &S2, &(t3P.D2)); // {t_k} - set_nhg_type3(S2, t3P.X2, opts, spopts, &(nf2), &(t3P.h2), &(t3P.gam2)); - } - t3P.C3 = 0.0; - t3P.D3 = 0.0; - if (d > 2) { - arraywidcen(nj, zj, &(t3P.X3), &(t3P.C3)); // {z_j} - arraywidcen(nk, u, &S3, &(t3P.D3)); // {u_k} - set_nhg_type3(S3, t3P.X3, opts, spopts, &(nf3), &(t3P.h3), &(t3P.gam3)); + std::array S = {0, 0, 0}; + for (int idim = 0; idim < dim; ++idim) { + arraywidcen(nj, XYZ_in[idim], &(t3P.X[idim]), &(t3P.C[idim])); + arraywidcen(nk, STU_in[idim], &S[idim], &(t3P.D[idim])); // same D, S, but for {s_k} + set_nhg_type3(S[idim], t3P.X[idim], opts, spopts, &(nf123[idim]), &(t3P.h[idim]), + &(t3P.gam[idim])); // applies twist i) + if (opts.debug) { // report on choices of shifts, centers, etc... + printf("\tM=%lld N=%lld\n", (long long)nj, (long long)nk); + printf("\tX1=%.3g C1=%.3g S1=%.3g D1=%.3g gam1=%g nf1=%lld h1=%.3g\t\n", + t3P.X[idim], t3P.C[idim], S[idim], t3P.D[idim], t3P.gam[idim], + (long long)nf123[idim], t3P.h[idim]); + } } + for (int idim = dim; idim < 3; ++idim) + t3P.C[idim] = t3P.D[idim] = 0.0; // their defaults if dim 2 unused, etc - if (opts.debug) { // report on choices of shifts, centers, etc... - printf("\tM=%lld N=%lld\n", (long long)nj, (long long)nk); - printf("\tX1=%.3g C1=%.3g S1=%.3g D1=%.3g gam1=%g nf1=%lld h1=%.3g\t\n", t3P.X1, - t3P.C1, S1, t3P.D1, t3P.gam1, (long long)nf1, t3P.h1); - if (d > 1) - printf("\tX2=%.3g C2=%.3g S2=%.3g D2=%.3g gam2=%g nf2=%lld h2=%.3g\n", t3P.X2, - t3P.C2, S2, t3P.D2, t3P.gam2, (long long)nf2, t3P.h2); - if (d > 2) - printf("\tX3=%.3g C3=%.3g S3=%.3g D3=%.3g gam3=%g nf3=%lld h3=%.3g\n", t3P.X3, - t3P.C3, S3, t3P.D3, t3P.gam3, (long long)nf3, t3P.h3); - } - nf = nf1 * nf2 * nf3; // fine grid total number of points + nf = nf123[0] * nf123[1] * nf123[2]; // fine grid total number of points if (nf * batchSize > MAX_NF) { fprintf(stderr, "[%s t3] fwBatch would be bigger than MAX_NF, not attempting memory " @@ -871,44 +846,30 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF // alloc rescaled NU src pts x'_j (in X etc), rescaled NU targ pts s'_k ... // We do this by resizing Xp, Yp, and Zp, and pointing X, Y, Z to their data; // this avoids any need for explicit cleanup. - Xp.resize(nj); - X = Xp.data(); - Sp.resize(nk); - if (d > 1) { - Yp.resize(nj); - Y = Yp.data(); - Tp.resize(nk); - } - if (d > 2) { - Zp.resize(nj); - Z = Zp.data(); - Up.resize(nk); + for (int idim = 0; idim < dim; ++idim) { + XYZp[idim].resize(nj); + XYZ[idim] = XYZp[idim].data(); + STUp[idim].resize(nk); } // always shift as use gam to rescale x_j to x'_j, etc (twist iii)... - TF ig1 = 1.0 / t3P.gam1, ig2 = 0.0, ig3 = 0.0; // "reciprocal-math" optim - if (d > 1) ig2 = 1.0 / t3P.gam2; - if (d > 2) ig3 = 1.0 / t3P.gam3; + std::array ig = {0, 0, 0}; + for (int idim = 0; idim < dim; ++idim) ig[idim] = 1.0 / t3P.gam[idim]; #pragma omp parallel for num_threads(opts.nthreads) schedule(static) for (BIGINT j = 0; j < nj; ++j) { - X[j] = (xj[j] - t3P.C1) * ig1; // rescale x_j - if (d > 1) // (ok to do inside loop because of branch predict) - Y[j] = (yj[j] - t3P.C2) * ig2; // rescale y_j - if (d > 2) Z[j] = (zj[j] - t3P.C3) * ig3; // rescale z_j + for (int idim = 0; idim < dim; ++idim) + XYZ[idim][j] = (XYZ_in[idim][j] - t3P.C[idim]) * ig[idim]; // rescale x_j } // set up prephase array... - std::complex imasign = - (fftSign >= 0) ? std::complex(0, 1) : std::complex(0, -1); // +-i + TF isign = (fftSign >= 0) ? 1 : -1; prephase.resize(nj); - if (t3P.D1 != 0.0 || t3P.D2 != 0.0 || t3P.D3 != 0.0) { + if (t3P.D[0] != 0.0 || t3P.D[1] != 0.0 || t3P.D[2] != 0.0) { #pragma omp parallel for num_threads(opts.nthreads) schedule(static) for (BIGINT j = 0; j < nj; ++j) { // ... loop over src NU locs - TF phase = t3P.D1 * xj[j]; - if (d > 1) phase += t3P.D2 * yj[j]; - if (d > 2) phase += t3P.D3 * zj[j]; - prephase[j] = std::cos(phase) + imasign * std::sin(phase); // Euler - // e^{+-i.phase} + TF phase = 0; + for (int idim = 0; idim < dim; ++idim) phase += t3P.D[idim] * XYZ_in[idim][j]; + prephase[j] = std::polar(TF(1), isign * phase); // Euler } } else for (BIGINT j = 0; j < nj; ++j) @@ -917,42 +878,33 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF // rescale the target s_k etc to s'_k etc... #pragma omp parallel for num_threads(opts.nthreads) schedule(static) for (BIGINT k = 0; k < nk; ++k) { - Sp[k] = t3P.h1 * t3P.gam1 * (s[k] - t3P.D1); // so |s'_k| < pi/R - if (d > 1) - Tp[k] = t3P.h2 * t3P.gam2 * (t[k] - t3P.D2); // so |t'_k| < - // pi/R - if (d > 2) - Up[k] = t3P.h3 * t3P.gam3 * (u[k] - t3P.D3); // so |u'_k| < - // pi/R + for (int idim = 0; idim < dim; ++idim) + STUp[idim][k] = + t3P.h[idim] * t3P.gam[idim] * (STU_in[idim][k] - t3P.D[idim]); // so |s'_k| < + // pi/R } // (old STEP 3a) Compute deconvolution post-factors array (per targ pt)... // (exploits that FT separates because kernel is prod of 1D funcs) deconv.resize(nk); - std::vector phiHatk1(nk); // don't confuse w/ phiHat - onedim_nuft_kernel(nk, Sp, phiHatk1, spopts); // fill phiHat1 - std::vector phiHatk2, phiHatk3; - if (d > 1) { - phiHatk2.resize(nk); - onedim_nuft_kernel(nk, Tp, phiHatk2, spopts); // fill phiHat2 - } - if (d > 2) { - phiHatk3.resize(nk); - onedim_nuft_kernel(nk, Up, phiHatk3, spopts); // fill phiHat3 + std::array, 3> phiHatk; + for (int idim = 0; idim < dim; ++idim) { + phiHatk[idim].resize(nk); + onedim_nuft_kernel(nk, STUp[idim], phiHatk[idim], spopts); // fill phiHat1 } // C can be nan or inf if M=0, no input NU pts - int Cfinite = std::isfinite(t3P.C1) && std::isfinite(t3P.C2) && std::isfinite(t3P.C3); - int Cnonzero = t3P.C1 != 0.0 || t3P.C2 != 0.0 || t3P.C3 != 0.0; // cen + int Cfinite = + std::isfinite(t3P.C[0]) && std::isfinite(t3P.C[1]) && std::isfinite(t3P.C[2]); + int Cnonzero = t3P.C[0] != 0.0 || t3P.C[1] != 0.0 || t3P.C[2] != 0.0; // cen #pragma omp parallel for num_threads(opts.nthreads) schedule(static) for (BIGINT k = 0; k < nk; ++k) { // .... loop over NU targ freqs - TF phiHat = phiHatk1[k]; - if (d > 1) phiHat *= phiHatk2[k]; - if (d > 2) phiHat *= phiHatk3[k]; + TF phiHat = 1; + for (int idim = 0; idim < dim; ++idim) phiHat *= phiHatk[idim][k]; deconv[k] = (std::complex)(1.0 / phiHat); if (Cfinite && Cnonzero) { - TF phase = (s[k] - t3P.D1) * t3P.C1; - if (d > 1) phase += (t[k] - t3P.D2) * t3P.C2; - if (d > 2) phase += (u[k] - t3P.D3) * t3P.C3; - deconv[k] *= std::cos(phase) + imasign * std::sin(phase); // Euler e^{+-i.phase} + TF phase = 0; + for (int idim = 0; idim < dim; ++idim) + phase += (STU_in[idim][k] - t3P.D[idim]) * t3P.C[idim]; + deconv[k] *= std::polar(TF(1), isign * phase); // Euler e^{+-i.phase} } } if (opts.debug) @@ -961,14 +913,20 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF // Set up sort for spreading Cp (from primed NU src pts X, Y, Z) to fw... timer.restart(); sortIndices.resize(nj); - didSort = indexSort(sortIndices, nf1, nf2, nf3, nj, X, Y, Z, spopts); + didSort = indexSort(sortIndices, nf123[0], nf123[1], nf123[2], nj, XYZ[0], XYZ[1], + XYZ[2], spopts); if (opts.debug) printf("[%s t3] sort (didSort=%d):\t\t%.3g s\n", __func__, didSort, timer.elapsedsec()); // Plan and setpts once, for the (repeated) inner type 2 finufft call... timer.restart(); - BIGINT t2nmodes[] = {nf1, nf2, nf3}; // t2 input is actually fw + BIGINT t2nmodes[] = {BIGINT(nf123[0]), BIGINT(nf123[1]), + BIGINT(nf123[2])}; // t2 + // input + // is + // actually + // fw finufft_opts t2opts = opts; // deep copy, since not ptrs t2opts.modeord = 0; // needed for correct t3! t2opts.debug = std::max(0, opts.debug - 1); // don't print as much detail @@ -985,7 +943,8 @@ int FINUFFT_PLAN_T::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF __func__, ier); return ier; } - ier = innerT2plan->setpts(nk, Sp.data(), Tp.data(), Up.data(), 0, nullptr, nullptr, + ier = innerT2plan->setpts(nk, STUp[0].data(), STUp[1].data(), STUp[2].data(), 0, + nullptr, nullptr, nullptr); // note nk = # output points (not nj) if (ier > 1) { fprintf(stderr, "[%s t3]: inner type 2 setpts failed, ier=%d!\n", __func__, ier);