From 81ccc61ea27bdbcf01e010cbcd6b6269ce73892d Mon Sep 17 00:00:00 2001 From: Panu Lahtinen Date: Fri, 25 Oct 2024 11:03:55 +0300 Subject: [PATCH] Fix cira stretch upcasting the data --- satpy/enhancements/__init__.py | 5 +++-- satpy/tests/enhancement_tests/test_enhancements.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index a44ca590cf..86efe1ffba 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -219,11 +219,12 @@ def cira_stretch(img, **kwargs): @exclude_alpha def _cira_stretch(band_data): - log_root = np.log10(0.0223) + dtype = band_data.dtype + log_root = np.log10(0.0223, dtype=dtype) denom = (1.0 - log_root) * 0.75 band_data *= 0.01 band_data = band_data.clip(np.finfo(float).eps) - band_data = np.log10(band_data) + band_data = np.log10(band_data, dtype=dtype) band_data -= log_root band_data /= denom return band_data diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 96176fda34..fd471ae792 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -32,7 +32,7 @@ # - tmp_path -def run_and_check_enhancement(func, data, expected, **kwargs): +def run_and_check_enhancement(func, data, expected, match_dtype=False, **kwargs): """Perform basic checks that apply to multiple tests.""" from trollimage.xrimage import XRImage @@ -58,6 +58,9 @@ def run_and_check_enhancement(func, data, expected, **kwargs): res_data = res_data_arr.data.compute() # mimics what xrimage geotiff writing does assert not isinstance(res_data, da.Array) np.testing.assert_allclose(res_data, expected, atol=1.e-6, rtol=0) + if match_dtype: + assert res_data_arr.dtype == data.dtype + assert res_data.dtype == data.dtype def identical_decorator(func): @@ -109,14 +112,15 @@ def _calc_func(data): exp_data = exp_data[np.newaxis, :, :] run_and_check_enhancement(_enh_func, in_data, exp_data) - def test_cira_stretch(self): + @pytest.mark.parametrize("dtype", [np.float32, np.float64]) + def test_cira_stretch(self, dtype): """Test applying the cira_stretch.""" from satpy.enhancements import cira_stretch expected = np.array([[ [np.nan, -7.04045974, -7.04045974, 0.79630132, 0.95947296], - [1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]]) - run_and_check_enhancement(cira_stretch, self.ch1, expected) + [1.05181359, 1.11651012, 1.16635571, 1.20691137, 1.24110186]]], dtype=dtype) + run_and_check_enhancement(cira_stretch, self.ch1.astype(dtype), expected, match_dtype=True) def test_reinhard(self): """Test the reinhard algorithm."""