Merge pull request #28 from zmoon/lat-lon-attrs
Ensure lat/lon attrs in the `xr.Dataset`
zmoon authored Feb 12, 2024
2 parents 7332c02 + 60078ac commit 786d41c
Showing 3 changed files with 53 additions and 15 deletions.
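In effect, the coordinate metadata now survives the DataFrame-to-Dataset conversion. A minimal sketch of the behavior this commit guarantees, assuming the `uscrn.get_data` loader exercised in the tests below (the year and `which` values are illustrative):

```python
import uscrn

df = uscrn.get_data(2019, which="daily")  # pandas.DataFrame
ds = uscrn.to_xarray(df)  # xarray.Dataset

# The lat/lon coordinate variables now carry their metadata
assert {"long_name", "units"} <= ds.latitude.attrs.keys()
assert {"long_name", "units"} <= ds.longitude.attrs.keys()
```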
4 changes: 3 additions & 1 deletion README.md
@@ -23,7 +23,8 @@ ds = uscrn.to_xarray(df)  # xarray.Dataset, with soil depth dimension if applicable

 Both `df` (pandas) and `ds` (xarray) include dataset and variable metadata.
 For `df`, these are in `df.attrs` and can be preserved by
-writing to Parquet with the PyArrow engine with pandas v2.1+.
+writing to Parquet with the PyArrow engine[^d] with
+[pandas v2.1+](https://pandas.pydata.org/docs/whatsnew/v2.1.0.html#other-enhancements).
 
 ```python
 df.to_parquet("uscrn_2019_hourly.parquet", engine="pyarrow")
@@ -40,3 +41,4 @@ pip install --no-deps uscrn
 [^a]: Use `uscrn.load_meta()` to load the site metadata table.
 [^b]: Not counting the `import` statement...
 [^c]: `uscrn` is not yet on conda-forge.
+[^d]: Or the fastparquet engine with [fastparquet v2024.2.0+](https://github.com/dask/fastparquet/commit/9d7ee90e38103fef3dd1bd2f5eb0654b8bd3fdff).
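The metadata roundtrip this README paragraph describes, as a quick self-contained sketch (assumes pandas v2.1+ with PyArrow installed; the file name is the one from the README example):

```python
import pandas as pd
import uscrn

df = uscrn.get_data(2019, which="hourly")
df.to_parquet("uscrn_2019_hourly.parquet", engine="pyarrow")

df2 = pd.read_parquet("uscrn_2019_hourly.parquet", engine="pyarrow")
assert df2.attrs == df.attrs  # dataset and variable metadata preserved
```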
62 changes: 49 additions & 13 deletions tests/test_data.py
@@ -91,6 +91,9 @@ def test_read(which, url):

     ds = to_xarray(df)
 
+    assert {"long_name", "units"} <= ds.latitude.attrs.keys()
+    assert {"long_name", "units"} <= ds.longitude.attrs.keys()
+
     if which == "subhourly":
         assert "depth" not in ds.dims
     elif which == "hourly":
@@ -189,22 +192,55 @@ def test_to_xarray_no_which_attr():
     to_xarray(pd.DataFrame())
 
 
-@pytest.mark.parametrize("engine", ["pyarrow", "fastparquet"])
-def test_df_parquet_roundtrip(tmp_path, engine):
+def test_df_parquet_roundtrip(tmp_path):
+    import fastparquet
+
     df = get_data(2019, which="daily", n_jobs=N, cat=True)
     assert df.attrs != {}
 
-    fp = tmp_path / "test.parquet"
-    df.to_parquet(fp, index=False)
-    df2 = pd.read_parquet(fp, engine=engine)
+    # Write with both engines
+    p_pa = tmp_path / "test_pa.parquet"
+    df.to_parquet(p_pa, index=False, engine="pyarrow")
+    p_fp = tmp_path / "test_fp.parquet"
+    df.to_parquet(p_fp, index=False, engine="fastparquet")
+
+    # Read with both engines
+    cases = {
+        "pa-pa": pd.read_parquet(p_pa, engine="pyarrow"),
+        "pa-fp": pd.read_parquet(p_pa, engine="fastparquet"),
+        "fp-pa": pd.read_parquet(p_fp, engine="pyarrow"),
+        "fp-fp": pd.read_parquet(p_fp, engine="fastparquet"),
+    }
 
-    assert df.equals(df2), "data same"
+    # For all cases, the data should be the same
+    for case, df_ in cases.items():
+        assert df_.index.equals(df.index), f"index same for {case}"
+        assert df_.columns.equals(df.columns), f"columns same for {case}"
+        if case == "fp-pa":
+            # Categorical type not preserved, but data same
+            assert not df_.equals(df), f"equals fails for {case}"
+            diff = df.compare(df_)
+            assert diff.empty, f"data same for {case}"
+            assert not isinstance(
+                df_.sur_temp_daily_type.dtype, pd.CategoricalDtype
+            ), f"cat dtype not rt for {case}"
+        else:
+            assert df_.equals(df), f"data same for {case}"
+            assert isinstance(
+                df_.sur_temp_daily_type.dtype, pd.CategoricalDtype
+            ), f"cat dtype rt for {case}"
 
     if Version(pd.__version__) < Version("2.1"):
-        assert df2.attrs == {}, "no preservation before pandas 2.1"
-    else:
-        assert df.attrs is not df2.attrs
-        if engine == "fastparquet":
-            assert df2.attrs == {}
-        else:
-            assert df.attrs == df2.attrs
+        for case, df_ in cases.items():
+            assert df_.attrs == {}, f"no preservation before pandas 2.1, case {case}"
+    else:  # pandas 2.1+
+        for case, df_ in cases.items():
+            if case == "pa-pa":
+                assert df.attrs == df_.attrs, f"attrs roundtrip for {case}"
+            else:  # fastparquet involved
+                if Version(fastparquet.__version__) < Version("2024.2.0"):
+                    assert (
+                        df_.attrs == {}
+                    ), f"no attrs roundtrip before fastparquet 2024.2.0, case {case}"
+                else:
+                    assert df_.attrs == df.attrs, f"attrs roundtrip for {case}"
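The one asymmetry the new test matrix pins down: a categorical column written by fastparquet and read back with PyArrow loses its categorical dtype, so `DataFrame.equals` fails on dtype alone while `DataFrame.compare` confirms the values match. A toy sketch of that distinction (made-up data, not the USCRN frame):

```python
import pandas as pd

a = pd.DataFrame({"x": pd.Categorical(["R", "U", "R"])})
b = a.copy()
b["x"] = b["x"].astype(str)  # same values, categorical dtype lost

assert not a.equals(b)     # equals is dtype-sensitive
assert a.compare(b).empty  # compare checks values only
```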
2 changes: 1 addition & 1 deletion uscrn/data.py
@@ -588,7 +588,7 @@ def func(x):
             i = inds[0]
             return x[i]
 
-        return xr.apply_ufunc(func, da, input_core_dims=[["time"]], vectorize=True)
+        return xr.apply_ufunc(func, da, input_core_dims=[["time"]], vectorize=True, keep_attrs=True)
 
     lat0 = first(ds["latitude"])
     lon0 = first(ds["longitude"])
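The one-line `uscrn/data.py` fix works because `xr.apply_ufunc` drops attribute metadata by default; passing `keep_attrs=True` propagates the input's `attrs` to the result. A standalone illustration with toy data (not the package's actual variables):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(3.0),
    dims="time",
    attrs={"long_name": "station latitude", "units": "degrees_north"},
)

def first(x):
    return x[0]

# Default behavior: the result has empty attrs
out = xr.apply_ufunc(first, da, input_core_dims=[["time"]], vectorize=True)
assert out.attrs == {}

# With keep_attrs=True the metadata is carried through
out = xr.apply_ufunc(
    first, da, input_core_dims=[["time"]], vectorize=True, keep_attrs=True
)
assert out.attrs == da.attrs
```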
