diff --git a/src/cuda/cufinufft.cu b/src/cuda/cufinufft.cu index 534fa5358..fa4990285 100644 --- a/src/cuda/cufinufft.cu +++ b/src/cuda/cufinufft.cu @@ -6,7 +6,13 @@ #include #include -inline bool is_invalid_mode_array(int dim, const int64_t *modes64, int32_t modes32[3]) { +inline bool is_invalid_mode_array(int type, int dim, const int64_t *modes64, + int32_t modes32[3]) { + if (type == 3) { + modes32[0] = modes32[1] = modes32[2] = 1; + return false; + } + int64_t tot_size = 1; for (int i = 0; i < dim; ++i) { if (modes64[i] > std::numeric_limits::max()) return true; @@ -28,7 +34,9 @@ int cufinufftf_makeplan(int type, int dim, const int64_t *nmodes, int iflag, int } int nmodes32[3]; - if (is_invalid_mode_array(dim, nmodes, nmodes32)) return FINUFFT_ERR_NDATA_NOTVALID; + if (is_invalid_mode_array(type, dim, nmodes, nmodes32)) { + return FINUFFT_ERR_NDATA_NOTVALID; + } return cufinufft_makeplan_impl(type, dim, nmodes32, iflag, ntransf, tol, (cufinufft_plan_t **)d_plan_ptr, opts); @@ -42,7 +50,9 @@ int cufinufft_makeplan(int type, int dim, const int64_t *nmodes, int iflag, int } int nmodes32[3]; - if (is_invalid_mode_array(dim, nmodes, nmodes32)) return FINUFFT_ERR_NDATA_NOTVALID; + if (is_invalid_mode_array(type, dim, nmodes, nmodes32)) { + return FINUFFT_ERR_NDATA_NOTVALID; + } return cufinufft_makeplan_impl(type, dim, nmodes32, iflag, ntransf, tol, (cufinufft_plan_t **)d_plan_ptr, opts);