Skip to content

Commit

Permalink
Fixes, update API
Browse files Browse the repository at this point in the history
  • Loading branch information
chaithyagr committed Dec 9, 2024
1 parent d0d60fe commit 305482b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
1 change: 1 addition & 0 deletions include/finufft/finufft_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include <finufft_errors.h>
#include <memory>
#include <xsimd/xsimd.hpp>
#include <span>

// All indexing in the library that can potentially exceed 2^31 uses 64-bit signed
// integers. This includes all calling arguments (e.g. M, N) that could be huge someday.
Expand Down
20 changes: 5 additions & 15 deletions src/finufft_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ static void deconvolveshuffle3d(int dir, T prefac, std::vector<T> &ker1,
// --------- batch helper functions for t1,2 exec: ---------------------------

template<typename T>
static int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN_T<T> *p,
std::complex<T> *cBatch)
static int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN_T<T> *p, std::complex<T> *fwBatch, std::complex<T> *cBatch)
/*
Spreads (or interpolates) a batch of batchSize strength vectors in cBatch
to (or from) the batch of fine working grids fwBatch (supplied by the caller,
e.g. p->fwBatch.data()), using the same set of
Expand All @@ -447,7 +446,7 @@ static int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN_T<T> *p,
#endif
#pragma omp parallel for num_threads(nthr_outer)
for (int i = 0; i < batchSize; i++) {
std::complex<T> *fwi = p->fwBatch.data() + i * p->nf; // start of i'th fw array in
std::complex<T> *fwi = fwBatch + i * p->nf; // start of i'th fw array in
// wkspace
std::complex<T> *ci = cBatch + i * p->nj; // start of i'th c array in cBatch
spreadinterpSorted(p->sortIndices, p->nf1, p->nf2, p->nf3, (T *)fwi, p->nj, p->X,
Expand Down Expand Up @@ -1049,16 +1048,11 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk) {
// STEP 1: (varies by type)
timer.restart();
if (type == 1) { // type 1: spread NU pts X, weights cj, to fw grid
if (opts.spreadinterponly)
wrapArrayInVector(fkb, thisBatchSize*N, this->fwBatch);
spreadinterpSortedBatch<TF>(thisBatchSize, this, cjb);
spreadinterpSortedBatch<TF>(thisBatchSize, this, opts.spreadinterponly? fkb: this->fwBatch.data(), cjb);
t_sprint += timer.elapsedsec();
// Stop here if it is spread interp only.
if (opts.spreadinterponly)
{
releaseVectorWrapper(this->fwBatch);
continue;
}
} else if(!opts.spreadinterponly) { // type 2: amplify Fourier coeffs fk into 0-padded fw, but dont do it if it is spread interp only.
deconvolveBatch<TF>(thisBatchSize, this, fkb);
t_deconv += timer.elapsedsec();
Expand All @@ -1071,20 +1065,16 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk) {
t_fft += timer.elapsedsec();
if (opts.debug > 1) printf("\tFFT exec:\t\t%.3g s\n", timer.elapsedsec());
}
else
wrapArrayInVector(fkb, thisBatchSize*N, this->fwBatch);
// STEP 3: (varies by type)
timer.restart();
if (type == 1) { // type 1: deconvolve (amplify) fw and shuffle to fk
deconvolveBatch<TF>(thisBatchSize, this, fkb);
t_deconv += timer.elapsedsec();
} else { // type 2: interpolate unif fw grid to NU target pts
spreadinterpSortedBatch<TF>(thisBatchSize, this, cjb);
spreadinterpSortedBatch<TF>(thisBatchSize, this, opts.spreadinterponly? fkb: this->fwBatch.data(), cjb);
t_sprint += timer.elapsedsec();
}
// Release the fwBatch vector to prevent double freeing of memory.
if(opts.spreadinterponly)
releaseVectorWrapper(this->fwBatch);
} // ........end b loop

if (opts.debug) { // report total times in their natural order...
Expand Down Expand Up @@ -1135,7 +1125,7 @@ int FINUFFT_PLAN_T<TF>::execute(std::complex<TF> *cj, std::complex<TF> *fk) {
// STEP 1: spread c'_j batch (x'_j NU pts) into fw batch grid...
timer.restart();
spopts.spread_direction = 1; // spread
spreadinterpSortedBatch<TF>(thisBatchSize, this, CpBatch.data()); // X are primed
spreadinterpSortedBatch<TF>(thisBatchSize, this, this->fwBatch.data(), CpBatch.data()); // X are primed
t_spr += timer.elapsedsec();

// STEP 2: type 2 NUFFT from fw batch to user output fk array batch...
Expand Down

0 comments on commit 305482b

Please sign in to comment.