Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays. #58

Merged
merged 1 commit into from
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions jax_cfd/base/advection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _square_concentration(grid):

def _unit_velocity(grid, velocity_sign=1.):
ndim = grid.ndim
offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2.
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
return tuple(
grids.GridArray(velocity_sign * jnp.ones(grid.shape) if ax == 0
else jnp.zeros(grid.shape), tuple(offset), grid)
Expand All @@ -52,7 +52,7 @@ def _unit_velocity(grid, velocity_sign=1.):

def _cos_velocity(grid):
ndim = grid.ndim
offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2.
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
mesh = grid.mesh()
v = tuple(grids.GridArray(jnp.cos(mesh[i] * 2. * np.pi), tuple(offset), grid)
for i, offset in enumerate(offsets))
Expand Down
3 changes: 2 additions & 1 deletion jax_cfd/base/forcings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@
from jax_cfd.base import forcings
from jax_cfd.base import grids
from jax_cfd.base import test_util
import numpy as np


def _make_zero_velocity_field(grid):
ndim = grid.ndim
offsets = (jnp.eye(ndim) + jnp.ones([ndim, ndim])) / 2.
offsets = (np.eye(ndim) + np.ones([ndim, ndim])) / 2.
return tuple(
grids.GridArray(jnp.zeros(grid.shape), tuple(offset), grid)
for ax, offset in enumerate(offsets))
Expand Down