Skip to content

Commit

Permalink
Search for coord_names in separate_coords (#191)
Browse files Browse the repository at this point in the history
* find coord_names in vars

* resolve merge conflict

* add 2d coords test

* add kerchunk dep and add 1d coord test

---------

Co-authored-by: Tom Nicholas <[email protected]>
  • Loading branch information
ayushnag and TomNicholas authored Nov 5, 2024
1 parent ab23caa commit 3fa5cff
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
9 changes: 9 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,15 @@ def netcdf4_file(tmpdir):
return filepath


@pytest.fixture
def netcdf4_file_with_2d_coords(tmpdir):
ds = xr.tutorial.open_dataset("ROMS_example")
filepath = f"{tmpdir}/ROMS_example.nc"
ds.to_netcdf(filepath, format="NETCDF4")
ds.close()
return filepath


@pytest.fixture
def netcdf4_virtual_dataset(netcdf4_file):
from virtualizarr import open_virtual_dataset
Expand Down
7 changes: 6 additions & 1 deletion virtualizarr/readers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,13 @@ def separate_coords(
coord_vars: dict[
str, tuple[Hashable, Any, dict[Any, Any], dict[Any, Any]] | Variable
] = {}
found_coord_names: set[str] = set()
# Search through variable attributes for coordinate names
for var in vars.values():
if "coordinates" in var.attrs:
found_coord_names.update(var.attrs["coordinates"].split(" "))
for name, var in vars.items():
if name in coord_names or var.dims == (name,):
if name in coord_names or var.dims == (name,) or name in found_coord_names:
# use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263
if len(var.dims) == 1:
dim1d, *_ = var.dims
Expand Down
22 changes: 22 additions & 0 deletions virtualizarr/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,28 @@ def test_coordinate_variable_attrs_preserved(self, netcdf4_file):
}


@requires_kerchunk
class TestDetermineCoords:
def test_infer_one_dimensional_coords(self, netcdf4_file):
vds = open_virtual_dataset(netcdf4_file, indexes={})
assert set(vds.coords) == {"time", "lat", "lon"}

def test_var_attr_coords(self, netcdf4_file_with_2d_coords):
vds = open_virtual_dataset(netcdf4_file_with_2d_coords, indexes={})

expected_dimension_coords = ["ocean_time", "s_rho"]
expected_2d_coords = ["lon_rho", "lat_rho", "h"]
expected_1d_non_dimension_coords = ["Cs_r"]
expected_scalar_coords = ["hc", "Vtransform"]
expected_coords = (
expected_dimension_coords
+ expected_2d_coords
+ expected_1d_non_dimension_coords
+ expected_scalar_coords
)
assert set(vds.coords) == set(expected_coords)


@network
@requires_s3fs
class TestReadFromS3:
Expand Down

0 comments on commit 3fa5cff

Please sign in to comment.