Skip to content

Commit

Permalink
Merge pull request #33 from martindurant/combine_api
Browse files Browse the repository at this point in the history
rewrite MultiZarrToZarr to always use xr.concat
  • Loading branch information
martindurant authored Jun 28, 2021
2 parents c2e95b3 + 1112247 commit 67ccf71
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 27 deletions.
63 changes: 38 additions & 25 deletions fsspec_reference_maker/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,25 @@
import numcodecs
import xarray as xr
import zarr
logging = logging.getLogger('reference-combine')
logger = logging.getLogger('reference-combine')


class MultiZarrToZarr:

def __init__(self, path, remote_protocol,
remote_options=None, xarray_kwargs=None, storage_options=None,
with_mf=True):
remote_options=None, xarray_open_kwargs=None, xarray_concat_args=None,
preprocess=None, storage_options=None):
"""
:param path: a URL containing multiple JSONs
:param xarray_kwargs:
:param storage_options:
"""
xarray_kwargs = xarray_kwargs or {}
self.path = path
self.with_mf = with_mf
self.xr_kwargs = xarray_kwargs
self.xr_kwargs = xarray_open_kwargs or {}
self.concat_kwargs = xarray_concat_args or {}
self.storage_options = storage_options or {}
self.preprocess = preprocess
self.remote_protocol = remote_protocol
self.remote_options = remote_options or {}

Expand All @@ -36,6 +36,7 @@ def translate(self, outpath):
self.output = self._consolidate(out)

self._write(self.output, outpath)
# TODO: return new zarr dataset?

@staticmethod
def _write(refs, outpath, filetype=None):
Expand Down Expand Up @@ -86,7 +87,7 @@ def _write(refs, outpath, filetype=None):
compression="ZSTD"
)

def _consolidate(self, mapping, inline_threashold=100, template_count=5):
def _consolidate(self, mapping, inline_threshold=100, template_count=5):
counts = Counter(v[0] for v in mapping.values() if isinstance(v, list))

def letter_sets():
Expand All @@ -104,7 +105,7 @@ def letter_sets():

out = {}
for k, v in mapping.items():
if isinstance(v, list) and v[2] < inline_threashold:
if isinstance(v, list) and v[2] < inline_threshold:
v = self.fs.cat_file(v[0], start=v[1], end=v[1] + v[2])
if isinstance(v, bytes):
try:
Expand Down Expand Up @@ -158,15 +159,17 @@ def _determine_dims(self):
self.fs = fss[0].fs
mappers = [fs.get_mapper("") for fs in fss]

if self.with_mf is True:
ds = xr.open_mfdataset(mappers, engine="zarr", chunks={}, **self.xr_kwargs)
ds0 = xr.open_mfdataset(mappers[:1], engine="zarr", chunks={}, **self.xr_kwargs)
else:
dss = [xr.open_dataset(m, engine="zarr", chunks={}, **self.xr_kwargs) for m in mappers]
ds = xr.concat(dss, dim=self.with_mf)
ds0 = dss[0]
dss = [xr.open_dataset(m, engine="zarr", chunks={}, **self.xr_kwargs)
for m in mappers]
if self.preprocess:
dss = [self.preprocess(d) for d in dss]
ds = xr.concat(dss, **self.concat_kwargs)
ds0 = dss[0]
self.extra_dims = set(ds.dims) - set(ds0.dims)
self.concat_dims = set(k for k, v in ds.dims.items() if k in ds0.dims and v / ds0.dims[k] == len(mappers))
self.concat_dims = set(
k for k, v in ds.dims.items()
if k in ds0.dims and v / ds0.dims[k] == len(mappers)
)
self.same_dims = set(ds.dims) - self.extra_dims - self.concat_dims
return ds, ds0, fss

Expand All @@ -180,19 +183,29 @@ def drop_coords(ds):
ds = ds.drop(['reference_time', 'crs'])
return ds.reset_coords(drop=True)

xarray_open_kwargs = {
"decode_cf": False,
"mask_and_scale": False,
"decode_times": False,
"decode_timedelta": False,
"use_cftime": False,
"decode_coords": False
}
concat_kwargs = {
"data_vars": "minimal",
"coords": "minimal",
"compat": "override",
"join": "override",
"combine_attrs": "override",
"dim": "time"
}
mzz = MultiZarrToZarr(
"zip://*.json::out.zip",
remote_protocol="s3",
remote_options={'anon': True},
xarray_kwargs={
"preprocess": drop_coords,
"decode_cf": False,
"mask_and_scale": False,
"decode_times": False,
"decode_timedelta": False,
"use_cftime": False,
"decode_coords": False
},
preprocess=drop_coords,
xarray_open_kwargs=xarray_open_kwargs,
xarray_concat_args=concat_kwargs
)
mzz.translate("output.zarr")

3 changes: 2 additions & 1 deletion fsspec_reference_maker/grib2.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,5 +197,6 @@ def example_multi(filter={'typeOfLevel': 'heightAboveGround', 'level': 2}):
# 'hrrr.t04z.wrfsfcf01.json',
# 'hrrr.t05z.wrfsfcf01.json',
# 'hrrr.t06z.wrfsfcf01.json']
# mzz = MultiZarrToZarr(files, remote_protocol="s3", remote_options={"anon": True}, with_mf='time')
# mzz = MultiZarrToZarr(files, remote_protocol="s3", remote_options={"anon": True}
# concat_kwargs={"dim": 'time'})
# mzz.translate("hrrr.total.json")
2 changes: 1 addition & 1 deletion fsspec_reference_maker/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def example_single():
)
fsspec.utils.setup_logging(logger=lggr)
with fsspec.open(url, **so) as f:
h5chunks = SingleHdf5ToZarr(f, url, xarray=True)
h5chunks = SingleHdf5ToZarr(f, url)
return h5chunks.translate()


Expand Down

0 comments on commit 67ccf71

Please sign in to comment.