diff --git a/xbitinfo/xbitinfo.py b/xbitinfo/xbitinfo.py index 0415147d..01573710 100644 --- a/xbitinfo/xbitinfo.py +++ b/xbitinfo/xbitinfo.py @@ -105,7 +105,28 @@ def dict_to_dataset(info_per_bit): return dsb -def get_bitinformation( # noqa: C901 +def _check_bitinfo_kwargs(implementation=None, axis=None, dim=None, kwargs=None): + if kwargs is None: + kwargs = {} + # check keywords + if implementation == "julia" and not julia_installed: + raise ImportError('Please install julia or use implementation="python".') + if axis is not None and dim is not None: + raise ValueError("Please provide either `axis` or `dim` but not both.") + if axis: + if not isinstance(axis, int): + raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.") + if dim: + if not isinstance(dim, str): + raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.") + if "mask" in kwargs: + raise ValueError( + "`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead." + ) + return + + +def get_bitinformation( ds, dim=None, axis=None, @@ -182,81 +203,48 @@ def get_bitinformation( # noqa: C901 xbitinfo_version: ... BitInformation.jl_version: ... """ - if implementation == "julia" and not julia_installed: - raise ImportError('Please install julia or use implementation="python".') - if dim is None and axis is None: - # gather bitinformation on all axis - return _get_bitinformation_along_dims( - ds, - dim=dim, - label=label, - overwrite=overwrite, - implementation=implementation, - **kwargs, - ) - if isinstance(dim, list) and axis is None: - # gather bitinformation on dims specified - return _get_bitinformation_along_dims( - ds, - dim=dim, - label=label, - overwrite=overwrite, - implementation=implementation, - **kwargs, - ) + if overwrite is False and label is not None: + try: + info_per_bit = load_bitinformation(label) + except FileNotFoundError: + logging.info( + f"No bitinformation could be found for {label}. Please set `overwrite=True` for recalculation..." + ) else: - # gather bitinformation along one axis - if overwrite is False and label is not None: - try: - info_per_bit = load_bitinformation(label) - return info_per_bit - except FileNotFoundError: - logging.info( - f"No bitinformation could be found for {label}. Recalculating..." - ) - - # check keywords - if axis is not None and dim is not None: - raise ValueError("Please provide either `axis` or `dim` but not both.") - if axis: - if not isinstance(axis, int): - raise ValueError(f"Please provide `axis` as `int`, found {type(axis)}.") - if dim: - if not isinstance(dim, str): - raise ValueError(f"Please provide `dim` as `str`, found {type(dim)}.") - if "mask" in kwargs: - raise ValueError( - "`xbitinfo` does not wrap the mask argument. Mask your xr.Dataset with NaNs instead." + _check_bitinfo_kwargs(implementation, axis, dim, kwargs) + if dim is None and axis is None: + # gather bitinformation on all axis + info_per_bit, label = _get_bitinformation_along_dims( + ds, + dim=dim, + label=label, + overwrite=overwrite, + implementation=implementation, + **kwargs, + ) + elif isinstance(dim, list) and axis is None: + # gather bitinformation on dims specified + info_per_bit, label = _get_bitinformation_along_dims( + ds, + dim=dim, + label=label, + overwrite=overwrite, + implementation=implementation, + **kwargs, + ) + else: + # gather bitinformation along one axis + info_per_bit = _get_bitinformation_along_axis( + ds, implementation, axis, dim, kwargs ) - - info_per_bit = {} - pbar = tqdm(ds.data_vars) - for var in pbar: - pbar.set_description(f"Processing var: {var} for dim: {dim}") - if implementation == "julia": - info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs) - if info_per_bit_var is None: - continue - else: - info_per_bit[var] = info_per_bit_var - elif implementation == "python": - info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs) - if info_per_bit_var is None: - continue - else: - info_per_bit[var] = info_per_bit_var - else: - raise ValueError( - f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one." - ) if label is not None: - with open(label + ".json", "w") as f: - logging.debug(f"Save bitinformation to {label + '.json'}") - json.dump(info_per_bit, f, cls=JsonCustomEncoder) - info_per_bit = dict_to_dataset(info_per_bit) - for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix - for a in ds[var].attrs.keys(): - info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a] + out_fn = label + ".json" + if not os.path.exists(out_fn) or overwrite: + save_bitinformation(info_per_bit, out_fn) + info_per_bit = dict_to_dataset(info_per_bit) + for var in info_per_bit.data_vars: # keep attrs from input with source_ prefix + for a in ds[var].attrs.keys(): + info_per_bit[var].attrs["source_" + a] = ds[var].attrs[a] return info_per_bit @@ -328,7 +316,6 @@ def _get_bitinformation_along_dims( ds, dim=None, label=None, - overwrite=False, implementation="julia", **kwargs, ): @@ -345,16 +332,41 @@ def _get_bitinformation_along_dims( logging.info(f"Get bitinformation along dimension {d}") if label is not None: label = "_".join([label, d]) - info_per_bit_per_dim[d] = get_bitinformation( + info_per_bit_per_dim[d] = _get_bitinformation_along_axis( ds, dim=d, axis=None, - label=label, - overwrite=overwrite, implementation=implementation, **kwargs, ).expand_dims("dim", axis=0) info_per_bit = xr.merge(info_per_bit_per_dim.values()).squeeze() + return info_per_bit, label + + +def _get_bitinformation_along_axis(ds, implementation, axis, dim, kwargs): + """ + Helper function for :py:func:`xbitinfo.xbitinfo.get_bitinformation` to handle analysis along one axis. + """ + info_per_bit = {} + pbar = tqdm(ds.data_vars) + for var in pbar: + pbar.set_description(f"Processing var: {var} for dim: {dim}") + if implementation == "julia": + info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs) + if info_per_bit_var is None: + continue + else: + info_per_bit[var] = info_per_bit_var + elif implementation == "python": + info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs) + if info_per_bit_var is None: + continue + else: + info_per_bit[var] = info_per_bit_var + else: + raise ValueError( + f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one." + ) return info_per_bit @@ -385,6 +397,14 @@ def load_bitinformation(label): raise FileNotFoundError(f"No bitinformation could be found at {label+'.json'}") +def save_bitinformation(info_per_bit, out_fn, overwrite=False): + """Save bitinformation to JSON file""" + with open(out_fn, "w") as f: + logging.debug(f"Save bitinformation to {out_fn}") + json.dump(info_per_bit, f, cls=JsonCustomEncoder) + return + + def get_keepbits(info_per_bit, inflevel=0.99): """Get the number of mantissa bits to keep. To be used in :py:func:`xbitinfo.bitround.xr_bitround` and :py:func:`xbitinfo.bitround.jl_bitround`.