Skip to content

Commit

Permalink
simplify get_bitinformation
Browse files Browse the repository at this point in the history
  • Loading branch information
observingClouds committed Feb 9, 2024
1 parent 6fd035b commit 1073c65
Showing 1 changed file with 97 additions and 77 deletions.
174 changes: 97 additions & 77 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,28 @@ def dict_to_dataset(info_per_bit):
return dsb


def get_bitinformation( # noqa: C901
def _check_bitinfo_kwargs(implementation=None, axis=None, dim=None, kwargs=None):
if kwargs is None:
kwargs = {}
# check keywords
if implementation == "julia" and not julia_installed:
raise ImportError('Please install julia or use implementation="python".')
if axis is not None and dim is not None:
raise ValueError("Please provide either `axis` or `dim` but not both.")
if axis:
if not isinstance(axis, int):
raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
if dim:
if not isinstance(dim, str):
raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.")
if "mask" in kwargs:
raise ValueError(
"`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
)
return


def get_bitinformation(
ds,
dim=None,
axis=None,
Expand Down Expand Up @@ -182,81 +203,48 @@ def get_bitinformation( # noqa: C901
xbitinfo_version: ...
BitInformation.jl_version: ...
"""
if implementation == "julia" and not julia_installed:
raise ImportError('Please install julia or use implementation="python".')
if dim is None and axis is None:
# gather bitinformation on all axis
return _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
)
if isinstance(dim, list) and axis is None:
# gather bitinformation on dims specified
return _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
)
if overwrite is False and label is not None:
try:
info_per_bit = load_bitinformation(label)
except FileNotFoundError:
logging.info(
f"No bitinformation could be found for {label}. Please set `overwrite=True` for recalculation..."
)
else:
# gather bitinformation along one axis
if overwrite is False and label is not None:
try:
info_per_bit = load_bitinformation(label)
return info_per_bit
except FileNotFoundError:
logging.info(
f"No bitinformation could be found for {label}. Recalculating..."
)

# check keywords
if axis is not None and dim is not None:
raise ValueError("Please provide either `axis` or `dim` but not both.")
if axis:
if not isinstance(axis, int):
raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.")
if dim:
if not isinstance(dim, str):
raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.")
if "mask" in kwargs:
raise ValueError(
"`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead."
_check_bitinfo_kwargs(implementation, axis, dim, kwargs)
if dim is None and axis is None:
# gather bitinformation on all axis
info_per_bit, label = _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
)
elif isinstance(dim, list) and axis is None:
# gather bitinformation on dims specified
info_per_bit, label = _get_bitinformation_along_dims(
ds,
dim=dim,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
)
else:
# gather bitinformation along one axis
info_per_bit = _get_bitinformation_along_axis(
ds, implementation, axis, dim, kwargs
)

info_per_bit = {}
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")
if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
elif implementation == "python":
info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
else:
raise ValueError(
f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
)
if label is not None:
with open(label + ".json", "w") as f:
logging.debug(f"Save bitinformation to {label + '.json'}")
json.dump(info_per_bit, f, cls=JsonCustomEncoder)
info_per_bit = dict_to_dataset(info_per_bit)
for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix
for a in ds[var].attrs.keys():
info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a]
out_fn = label + ".json"
if not os.path.exists(out_fn) or overwrite:
save_bitinformation(info_per_bit, out_fn)
info_per_bit = dict_to_dataset(info_per_bit)
for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix
for a in ds[var].attrs.keys():
info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a]
return info_per_bit


Expand Down Expand Up @@ -328,7 +316,6 @@ def _get_bitinformation_along_dims(
ds,
dim=None,
label=None,
overwrite=False,
implementation="julia",
**kwargs,
):
Expand All @@ -345,16 +332,41 @@ def _get_bitinformation_along_dims(
logging.info(f"Get bitinformation along dimension {d}")
if label is not None:
label = "_".join([label, d])
info_per_bit_per_dim[d] = get_bitinformation(
info_per_bit_per_dim[d] = _get_bitinformation_along_axis(
ds,
dim=d,
axis=None,
label=label,
overwrite=overwrite,
implementation=implementation,
**kwargs,
).expand_dims("dim", axis=0)
info_per_bit = xr.merge(info_per_bit_per_dim.values()).squeeze()
return info_per_bit, label


def _get_bitinformation_along_axis(ds, implementation, axis, dim, kwargs):
"""
Helper function for :py:func:`xbitinfo.xbitinfo.get_bitinformation` to handle analysis along one axis.
"""
info_per_bit = {}
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")
if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
elif implementation == "python":
info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
continue
else:
info_per_bit[var] = info_per_bit_var
else:
raise ValueError(
f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one."
)
return info_per_bit


Expand Down Expand Up @@ -385,6 +397,14 @@ def load_bitinformation(label):
raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}")


def save_bitinformation(info_per_bit, out_fn, overwrite=False):
"""Save bitinformation to JSON file"""
with open(out_fn, "w") as f:
logging.debug(f"Save bitinformation to {out_fn}")
json.dump(info_per_bit, f, cls=JsonCustomEncoder)
return


def get_keepbits(info_per_bit, inflevel=0.99):
"""Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.
Expand Down

0 comments on commit 1073c65

Please sign in to comment.