From 9f31885aa8c667e8e2e286452261dab626ad969b Mon Sep 17 00:00:00 2001 From: Landung 'Don' Setiawan Date: Mon, 1 Apr 2024 12:15:31 -0700 Subject: [PATCH] feat: Update regrid_Sv to handle complex samples --- echopype/commongrid/regrid.py | 4 ++ .../commongrid/test_commongrid_regrid.py | 37 ++++++++++++++----- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/echopype/commongrid/regrid.py b/echopype/commongrid/regrid.py index 7fdbc654d..35473df25 100644 --- a/echopype/commongrid/regrid.py +++ b/echopype/commongrid/regrid.py @@ -54,6 +54,10 @@ def regrid_Sv( ds_list = [] for chan in input_ds[CHANNEL]: channel_Sv = input_ds.sel(channel=chan) + # Ensure no nans in range_sample + data_range = channel_Sv[range_var].dropna(RANGE_SAMPLE)[RANGE_SAMPLE] + channel_Sv = channel_Sv.sel({RANGE_SAMPLE: data_range}) + original_dims = _get_iris_dims(channel_Sv, range_var) regrid_ds = _regrid_data(channel_Sv[Sv_var].data, original_dims, target_dims) ds_list.append(regrid_ds) diff --git a/echopype/tests/commongrid/test_commongrid_regrid.py b/echopype/tests/commongrid/test_commongrid_regrid.py index 48eca63dd..3cea2fb91 100644 --- a/echopype/tests/commongrid/test_commongrid_regrid.py +++ b/echopype/tests/commongrid/test_commongrid_regrid.py @@ -3,6 +3,7 @@ import numpy as np import xarray as xr + @pytest.mark.integration def test_regrid_Sv(test_data_samples): """ @@ -26,17 +27,29 @@ def test_regrid_Sv(test_data_samples): if "azfp_cal_type" in range_kwargs: range_kwargs.pop("azfp_cal_type") Sv = ep.calibrate.compute_Sv(ed, **range_kwargs) - + # Setup output grid - channel_Sv = Sv.isel(channel=0) + channel_Sv = Sv.isel(channel=1) depth_data = channel_Sv.echo_range.isel(ping_time=0).data - time_data = channel_Sv.ping_time.data.astype('float64') + time_data = channel_Sv.ping_time.data.astype("float64") + # If there are NaNs in the depth data, remove them + if np.isnan(depth_data).any(): + depth_data = depth_data[~np.isnan(depth_data)] + # Evenly spaced grid - target_grid = xr.Dataset({ - "ping_time": (["ping_time"], np.linspace(time_data[0], time_data[-1], 300).astype('datetime64[ns]')), - "echo_range": (["echo_range"], np.linspace(depth_data[0], depth_data[-1], 300)), - }) - + target_grid = xr.Dataset( + { + "ping_time": ( + ["ping_time"], + np.linspace(np.min(time_data), np.max(time_data), 300).astype("datetime64[ns]"), + ), + "echo_range": ( + ["echo_range"], + np.linspace(np.min(depth_data), np.max(depth_data), 300), + ), + } + ) + regridded_Sv = ep.commongrid.regrid_Sv(Sv, target_grid=target_grid) assert regridded_Sv is not None @@ -44,4 +57,10 @@ def test_regrid_Sv(test_data_samples): for channel in regridded_Sv.channel: original_vals = Sv.sel(channel=channel).Sv.values regridded_vals = regridded_Sv.sel(channel=channel).Sv.values - assert np.allclose(np.nanmean(original_vals), np.nanmean(regridded_vals), atol=1.0, rtol=1.0) + assert np.allclose( + np.nanmean(original_vals), + np.nanmean(regridded_vals), + atol=1.0, + rtol=1.0, + equal_nan=True, + )