Skip to content

Commit

Permalink
Add method for checking consistency between GridVariable objects, and…
Browse files Browse the repository at this point in the history
… convenience method for generating the contents of GridVariable on the fly.

This is CL 2 of ... in the GridVariable refactor.

PiperOrigin-RevId: 399961866
  • Loading branch information
pnorgaard authored and JAX-CFD authors committed Sep 30, 2021
1 parent a4713d4 commit 9f815ed
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 5 deletions.
38 changes: 34 additions & 4 deletions jax_cfd/base/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,21 @@ def __post_init__(self):
'Incompatible dimension between grid and bc, array dimension = '
f'{self.array.grid.ndim}, bc dimension = {len(self.bc.boundaries)}')

@classmethod
def create(
cls,
data: Array,
offset: Tuple[float, ...],
grid: Grid,
boundaries: Union[str, Tuple[str, ...]],
) -> GridVariable:
"""Create the enclosed GridArray and BoundaryConditions on the fly."""
array = GridArray(data, offset, grid)
if isinstance(boundaries, str):
boundaries = (boundaries,) * grid.ndim
bc = BoundaryConditions(boundaries)
return cls(array, bc)

def tree_flatten(self):
"""Returns flattening recipe for GridVariable JAX pytree."""
children = (self.array,)
Expand Down Expand Up @@ -388,8 +403,9 @@ class InconsistentOffsetError(Exception):
"""Raised for cases of inconsistent offset in GridArrays."""


def consistent_offset(*arrays: GridArray) -> Tuple[float, ...]:
"""Returns the single unique offset, or raises InconsistentOffsetError."""
def consistent_offset(
*arrays: Union[GridArray, GridVariable]) -> Tuple[float, ...]:
"""Returns the unique offset, or raises InconsistentOffsetError."""
offsets = {array.offset for array in arrays}
if len(offsets) != 1:
raise InconsistentOffsetError(
Expand All @@ -402,15 +418,29 @@ class InconsistentGridError(Exception):
"""Raised for cases of inconsistent grids between GridArrays."""


def consistent_grid(*arrays: GridArray) -> Grid:
"""Returns the single unique grid, or raises InconsistentGridError."""
def consistent_grid(*arrays: Union[GridArray, GridVariable]) -> Grid:
"""Returns the unique grid, or raises InconsistentGridError."""
grids = {array.grid for array in arrays}
if len(grids) != 1:
raise InconsistentGridError(f'arrays do not have a unique grid: {grids}')
grid, = grids
return grid


class InconsistentBoundaryConditionError(Exception):
"""Raised for cases of inconsistent bc between GridVariables."""


def consistent_boundary_conditions(*arrays: GridVariable) -> BoundaryConditions:
"""Returns the unique BCs, or raises InconsistentBoundaryConditionError."""
bcs = {array.bc for array in arrays}
if len(bcs) != 1:
raise InconsistentBoundaryConditionError(
f'arrays do not have a unique bc: {bcs}')
bc, = bcs
return bc


@dataclasses.dataclass(init=False, frozen=True)
class Grid:
"""Describes the size, shape and boundary conditions for an Arakawa C-Grid.
Expand Down
24 changes: 24 additions & 0 deletions jax_cfd/base/grids_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,30 @@ def test_shift_pad_trim(self, shape, boundaries, padding, axis):
self.assertArrayEqual(
u.trim(padding, axis), grid.trim(array, padding, axis))

with self.subTest('raises exception'):
with self.assertRaisesRegex(
ValueError, 'Incompatible dimension between grid and bc'):
grid = grids.Grid((10,))
data = np.zeros((10,))
array = grids.GridArray(data, offset=(0.5,), grid=grid) # 1D
bc = grids.BoundaryConditions((grids.PERIODIC, grids.PERIODIC)) # 2D
grids.GridVariable(array, bc)

def test_construction_with_create(self):
grid = grids.Grid((10, 10))
data = np.zeros((10, 10))
offset = (0.5, 0.5)
array = grids.GridArray(data, offset, grid)
boundaries = (grids.PERIODIC, grids.PERIODIC)
bc = grids.BoundaryConditions(boundaries)
variable_1 = grids.GridVariable(array, bc)
variable_2 = grids.GridVariable.create(data, offset, grid, boundaries)
self.assertArrayEqual(variable_1, variable_2)

with self.subTest('str boundaries arg'):
variable_3 = grids.GridVariable.create(data, offset, grid, 'periodic')
self.assertArrayEqual(variable_1, variable_3)


class TensorTest(test_util.TestCase):

Expand Down
34 changes: 33 additions & 1 deletion jax_cfd/base/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,20 @@ class TestCase(parameterized.TestCase):
"""TestCase with assertions for arrays and grids.AlignedArray."""

def _check_and_remove_alignment_and_grid(self, *arrays):
"""If arrays are aligned, verify their offsets and grids match."""
"""Check that array-like data values and other attributes match.
If args type is GridArray, verify their offsets and grids match.
If args type is GridVariable, verify their offsets, grids, and bc match.
Args:
*arrays: one or more Array, GridArray or GridVariable, but they all be the
same type.
Returns:
The data-only arrays, with other attributes removed.
"""
is_gridarray = [isinstance(array, grids.GridArray) for array in arrays]
# GridArray
if any(is_gridarray):
self.assertTrue(
all(is_gridarray), msg=f'arrays have mixed types: {arrays}')
Expand All @@ -41,6 +53,26 @@ def _check_and_remove_alignment_and_grid(self, *arrays):
except grids.InconsistentGridError as e:
raise AssertionError(str(e)) from None
arrays = tuple(array.data for array in arrays)
# GridVariable
is_gridvariable = [
isinstance(array, grids.GridVariable) for array in arrays
]
if any(is_gridvariable):
self.assertTrue(
all(is_gridvariable), msg=f'arrays have mixed types: {arrays}')
try:
grids.consistent_offset(*arrays)
except grids.InconsistentOffsetError as e:
raise AssertionError(str(e)) from None
try:
grids.consistent_grid(*arrays)
except grids.InconsistentGridError as e:
raise AssertionError(str(e)) from None
try:
grids.consistent_boundary_conditions(*arrays)
except grids.InconsistentBoundaryConditionError as e:
raise AssertionError(str(e)) from None
arrays = tuple(array.array.data for array in arrays)
return arrays

# pylint: disable=unbalanced-tuple-unpacking
Expand Down

0 comments on commit 9f815ed

Please sign in to comment.