Skip to content

Commit

Permalink
Merge pull request #286 from JoelJaeschke/add-quantization-warning
Browse files Browse the repository at this point in the history
Add warning for quantized variables

fixes #202
  • Loading branch information
observingClouds authored Aug 15, 2024
2 parents 8d820af + b7c5790 commit 1cdac23
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CHANGELOG
X.X.X (unreleased)
------------------

* Add warning for quantized variables (:pr:`286`, :issue:`202`) `Joel Jaeschke`_.
* Update BitInformation.jl version to v0.6.3 (:pr:`292`) `Hauke Schulz`_.
* Improve test/docs environment separation (:pr:`275`, :issue:`267`) `Aryan Bakliwal`_.
* Set default masked value to None for integers (:pr:`289`) `Hauke Schulz`_.
Expand Down
15 changes: 15 additions & 0 deletions tests/test_get_bitinformation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for `xbitinfo` package."""

import os
import warnings

import numpy as np
import pytest
Expand Down Expand Up @@ -267,3 +268,17 @@ def test_implementations_agree(ds, dim, axis, request):
masked_value=None,
)
bitinfo_assert_allclose(bi_python, bi_julia, rtol=1e-4)


@pytest.mark.parametrize("implementation", ["python", "julia"])
@pytest.mark.parametrize("dataset_name", ["air_temperature", "eraint_uvz"])
def test_warn_on_quantized_variables(dataset_name, implementation):
    """Check that the quantization warning is raised only for scaled data.

    The same tutorial dataset is loaded twice: once with xarray's default
    decoding and once with ``mask_and_scale=False``. ``get_bitinformation``
    must warn about quantized variables for the former and must not emit
    any warning for the latter.
    """
    ds_quantized = xr.tutorial.load_dataset(dataset_name)
    ds_raw = xr.tutorial.load_dataset(dataset_name, mask_and_scale=False)

    # Match on the message text so that an unrelated UserWarning cannot
    # satisfy this check on its own.
    with pytest.warns(UserWarning, match="is quantized as"):
        xb.get_bitinformation(ds_quantized, implementation=implementation)

    # Escalate every warning to an exception: the raw (unscaled) dataset
    # must run through without warning.
    with warnings.catch_warnings():
        warnings.simplefilter("error")
        xb.get_bitinformation(ds_raw, implementation=implementation)
30 changes: 30 additions & 0 deletions xbitinfo/xbitinfo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import os
import warnings

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -233,6 +234,15 @@ def get_bitinformation( # noqa: C901
pbar = tqdm(ds.data_vars)
for var in pbar:
pbar.set_description(f"Processing var: {var} for dim: {dim}")

if _quantized_variable_is_scaled(ds, var):
loaded_dtype = ds[var].dtype
quantized_storage_dtype = ds[var].encoding["dtype"]
warnings.warn(
f"Variable {var} is quantized as {quantized_storage_dtype}, but loaded as {loaded_dtype}. Consider reopening using `mask_and_scale=False` to get sensible results",
category=UserWarning,
)

if implementation == "julia":
info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs)
if info_per_bit_var is None:
Expand Down Expand Up @@ -260,6 +270,26 @@ def get_bitinformation( # noqa: C901
return info_per_bit


def _quantized_variable_is_scaled(ds: "xr.Dataset", var: str) -> bool:
    """Tell whether a quantized variable was loaded with a different dtype.

    A variable is treated as quantized when its ``encoding`` contains
    ``add_offset`` or ``scale_factor``. It counts as "scaled" when the dtype
    it currently has differs from the storage dtype recorded in
    ``encoding["dtype"]``.

    Parameters
    ----------
    ds
        Container that is indexed by ``var``; the selected item must expose
        ``dtype`` and ``encoding`` attributes.
    var
        Name of the variable to inspect.

    Returns
    -------
    bool
        ``True`` if the variable has scale/offset encoding and its loaded
        dtype differs from its storage dtype, ``False`` otherwise.

    Raises
    ------
    ValueError
        If the variable has scale/offset encoding but no storage dtype is
        recorded in its encoding.
    """
    variable = ds[var]
    encoding = variable.encoding

    if "add_offset" not in encoding and "scale_factor" not in encoding:
        return False

    storage_dtype = encoding.get("dtype")
    if storage_dtype is None:
        # Raise explicitly instead of asserting: asserts vanish under
        # ``python -O`` and this is a property of user data, not an invariant.
        raise ValueError(
            f"Variable {var} is likely quantized, but does not have a storage dtype"
        )

    return bool(variable.dtype != storage_dtype)


def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}):
X = ds[var].values
Main.X = X
Expand Down

0 comments on commit 1cdac23

Please sign in to comment.