Skip to content

Commit

Permalink
Avoid complex->real casts via jax.numpy.astype
Browse files Browse the repository at this point in the history
This currently issues a warning about implicitly discarding the imaginary part, and it will issue an error in the future.

PiperOrigin-RevId: 657780666
  • Loading branch information
Jake VanderPlas authored and JAX-CFD authors committed Jul 31, 2024
1 parent d5b6d75 commit 2d43ad9
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions jax_cfd/base/fast_diagonalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def apply(rhs: Array) -> Array:
return apply


def _cast(x, dtype):
if (np.issubdtype(x.dtype, np.complexfloating)
and not np.issubdtype(dtype, np.complexfloating)):
x = x.real
return x.astype(dtype)


def _circulant_fft_transform(
func: Callable[[Array], Array],
operators: Sequence[np.ndarray],
Expand All @@ -184,7 +191,7 @@ def _circulant_fft_transform(
def apply(rhs: Array) -> Array:
if rhs.shape != shape:
raise ValueError(f'rhs.shape={rhs.shape} does not match shape={shape}')
return jnp.fft.ifftn(diagonals * jnp.fft.fftn(rhs)).astype(dtype)
return _cast(jnp.fft.ifftn(diagonals * jnp.fft.fftn(rhs)), dtype)

return apply

Expand Down Expand Up @@ -213,7 +220,7 @@ def _circulant_rfft_transform(
def apply(rhs: Array) -> Array:
if rhs.dtype != dtype:
raise ValueError(f'rhs.dtype={rhs.dtype} does not match dtype={dtype}')
return jnp.fft.irfftn(diagonals * jnp.fft.rfftn(rhs)).astype(dtype)
return _cast(jnp.fft.irfftn(diagonals * jnp.fft.rfftn(rhs)), dtype)

return apply

Expand Down

0 comments on commit 2d43ad9

Please sign in to comment.