Skip to content

Commit

Permalink
cu+py: access private attributes through properties
Browse files Browse the repository at this point in the history
  • Loading branch information
janden committed Jan 27, 2025
1 parent e6cfc03 commit 67e11ee
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions python/cufinufft/cufinufft/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,14 @@ def _init_plan(self):
# We extend the mode tuple to 3D as needed,
# and reorder from C/python ndarray.shape style input (nZ, nY, nX)
# to the (F) order expected by the low level library (nX, nY, nZ).
_n_modes = self._n_modes[::-1] + (1,) * (3 - self._dim)
_n_modes = self.n_modes[::-1] + (1,) * (3 - self.dim)
_n_modes = (c_int64 * 3)(*_n_modes)

ier = self._make_plan(self._type,
self._dim,
ier = self._make_plan(self.type,
self.dim,
_n_modes,
self._isign,
self._n_trans,
self.n_trans,
self._eps,
byref(self._plan),
self._opts)
Expand Down Expand Up @@ -245,16 +245,16 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
_y = _ensure_array_type(y, "y", self._real_dtype)
_z = _ensure_array_type(z, "z", self._real_dtype)

_x, _y, _z = _ensure_valid_pts(_x, _y, _z, self._dim)
_x, _y, _z = _ensure_valid_pts(_x, _y, _z, self.dim)

M = _compat.get_array_size(_x)

if self._type == 3:
if self.type == 3:
_s = _ensure_array_type(s, "s", self._real_dtype)
_t = _ensure_array_type(t, "t", self._real_dtype)
_u = _ensure_array_type(u, "u", self._real_dtype)

_s, _t, _u = _ensure_valid_pts(_s, _t, _u, self._dim)
_s, _t, _u = _ensure_valid_pts(_s, _t, _u, self.dim)

N = _compat.get_array_size(_s)
else:
Expand All @@ -274,22 +274,22 @@ def setpts(self, x, y=None, z=None, s=None, t=None, u=None):
# We will also store references to these arrays.
# This keeps python from prematurely cleaning them up.
self._references.append(_x)
if self._dim >= 2:
if self.dim >= 2:
fpts_axes.insert(0, _compat.get_array_ptr(_y))
self._references.append(_y)
if self._dim >= 3:
if self.dim >= 3:
fpts_axes.insert(0, _compat.get_array_ptr(_z))
self._references.append(_z)

# Do the same for type 3
if self._type == 3:
if self.type == 3:
fpts_axes_t3 = [_compat.get_array_ptr(_s), None, None]
self._references.append(_s)
if self._dim >= 2:
if self.dim >= 2:
fpts_axes_t3.insert(0, _compat.get_array_ptr(_t))
self._references.append(_t)

if self._dim >= 3:
if self.dim >= 3:
fpts_axes_t3.insert(0, _compat.get_array_ptr(_u))
self._references.append(_u)
else:
Expand Down Expand Up @@ -329,37 +329,37 @@ def execute(self, data, out=None):
The output array of the transform(s).
"""

_data = _ensure_array_type(data, "data", self._dtype)
_out = _ensure_array_type(out, "out", self._dtype, output=True)
_data = _ensure_array_type(data, "data", self.dtype)
_out = _ensure_array_type(out, "out", self.dtype, output=True)

if self._type == 1:
req_data_shape = (self._n_trans, self._nj)
req_out_shape = self._n_modes
elif self._type == 2:
req_data_shape = (self._n_trans, *self._n_modes)
if self.type == 1:
req_data_shape = (self.n_trans, self._nj)
req_out_shape = self.n_modes
elif self.type == 2:
req_data_shape = (self.n_trans, *self.n_modes)
req_out_shape = (self._nj,)
elif self._type == 3:
req_data_shape = (self._n_trans, self._nj)
elif self.type == 3:
req_data_shape = (self.n_trans, self._nj)
req_out_shape = (self._nk,)

_data, data_shape = _ensure_array_shape(_data, "data", req_data_shape,
allow_reshape=True)
if self._type == 1:
if self.type == 1:
batch_shape = data_shape[:-1]
else:
batch_shape = data_shape[:-self._dim]
batch_shape = data_shape[:-self.dim]

req_out_shape = batch_shape + req_out_shape

if out is None:
_out = _compat.array_empty_like(_data, req_out_shape, dtype=self._dtype)
_out = _compat.array_empty_like(_data, req_out_shape, dtype=self.dtype)
else:
_out = _ensure_array_shape(_out, "out", req_out_shape)

if self._type in [1, 3]:
if self.type in [1, 3]:
ier = self._exec_plan(self._plan, _compat.get_array_ptr(_data),
_compat.get_array_ptr(_out))
elif self._type == 2:
elif self.type == 2:
ier = self._exec_plan(self._plan, _compat.get_array_ptr(_out),
_compat.get_array_ptr(_data))

Expand Down

0 comments on commit 67e11ee

Please sign in to comment.