diff --git a/jax_cfd/base/advection_test.py b/jax_cfd/base/advection_test.py index d27c7f1..3fb1dd5 100644 --- a/jax_cfd/base/advection_test.py +++ b/jax_cfd/base/advection_test.py @@ -43,7 +43,7 @@ def _square_concentration(grid): def _unit_velocity(grid, velocity_sign=1.): ndim = grid.ndim - offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2. + offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2. return tuple( grids.GridArray(velocity_sign * jnp.ones(grid.shape) if ax == 0 else jnp.zeros(grid.shape), tuple(offset), grid) @@ -52,7 +52,7 @@ def _unit_velocity(grid, velocity_sign=1.): def _cos_velocity(grid): ndim = grid.ndim - offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2. + offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2. mesh = grid.mesh() v = tuple(grids.GridArray(jnp.cos(mesh[i] * 2. * np.pi), tuple(offset), grid) for i, offset in enumerate(offsets)) diff --git a/jax_cfd/base/forcings_test.py b/jax_cfd/base/forcings_test.py index 476bd8d..df688c0 100644 --- a/jax_cfd/base/forcings_test.py +++ b/jax_cfd/base/forcings_test.py @@ -22,11 +22,12 @@ from jax_cfd.base import forcings from jax_cfd.base import grids from jax_cfd.base import test_util +import numpy as np def _make_zero_velocity_field(grid): ndim = grid.ndim - offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2. + offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2. return tuple( grids.GridArray(jnp.zeros(grid.shape), tuple(offset), grid) for ax, offset in enumerate(offsets))