diff --git a/scico/linop/xray/_xray.py b/scico/linop/xray/_xray.py index ff17e58c..0656a43b 100644 --- a/scico/linop/xray/_xray.py +++ b/scico/linop/xray/_xray.py @@ -142,23 +142,22 @@ def fbp(self, y: ArrayLike) -> snp.Array: """ N = y.shape[1] - nvec = snp.arange(N) - (N - 1) // 2 - dx = snp.sqrt(self.dx[0] * self.dx[1]) # type: ignore - h = XRayTransform2D._ramp_filter(nvec, 1.0 / dx) + nvec = jnp.arange(N) - (N - 1) // 2 + h = XRayTransform2D._ramp_filter(nvec, 1.0).reshape(1, -1) # Apply ramp filter in the frequency domain, padding to avoid # boundary effects - hf = snp.fft.fft(h.reshape(1, -1), n=2 * N - 1, axis=1) - yf = snp.fft.fft(y, n=2 * N - 1, axis=1) - hy = snp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ + hf = jnp.fft.fft(h, n=2 * N - 1, axis=1) + yf = jnp.fft.fft(y, n=2 * N - 1, axis=1) + hy = jnp.fft.ifft(hf * yf, n=2 * N - 1, axis=1)[ :, (N - 1) // 2 : -(N - 1) // 2 - ].real.astype(snp.float32) + ].real.astype(jnp.float32) - x = (snp.pi / y.shape[0]) * self.back_project(hy) + x = (jnp.pi * self.dx[0] * self.dx[1] / y.shape[0]) * self.back_project(hy) # Mask out the invalid region of the reconstruction - gi, gj = snp.mgrid[: x.shape[0], : x.shape[1]] - x = snp.where( - snp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, + gi, gj = jnp.mgrid[: x.shape[0], : x.shape[1]] + x = jnp.where( + jnp.sqrt((gi - x.shape[0] / 2) ** 2 + (gj - x.shape[1] / 2) ** 2) < min(x.shape) / 2, x, 0.0, ) @@ -182,10 +181,10 @@ def _ramp_filter(x: ArrayLike, tau: float) -> snp.Array: # is included to avoid division by zero warnings when x == 1 # since np.where evaluates all values for both True and False # branches. - return snp.where( + return jnp.where( x == 0, 1.0 / (4.0 * tau**2), - snp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), + jnp.where(x % 2, -1.0 / (x**2 * np.pi**2 * tau**2 + (x == 0)), 0), ) @staticmethod