Skip to content

Commit

Permalink
feat: Update regrid_Sv to handle complex samples
Browse files Browse the repository at this point in the history
  • Loading branch information
lsetiawan committed Apr 1, 2024
1 parent 8ada53e commit 9f31885
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
4 changes: 4 additions & 0 deletions echopype/commongrid/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def regrid_Sv(
ds_list = []
for chan in input_ds[CHANNEL]:
channel_Sv = input_ds.sel(channel=chan)
# Ensure no nans in range_sample
data_range = channel_Sv[range_var].dropna(RANGE_SAMPLE)[RANGE_SAMPLE]
channel_Sv = channel_Sv.sel({RANGE_SAMPLE: data_range})

original_dims = _get_iris_dims(channel_Sv, range_var)
regrid_ds = _regrid_data(channel_Sv[Sv_var].data, original_dims, target_dims)
ds_list.append(regrid_ds)
Expand Down
37 changes: 28 additions & 9 deletions echopype/tests/commongrid/test_commongrid_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import xarray as xr


@pytest.mark.integration
def test_regrid_Sv(test_data_samples):
"""
Expand All @@ -26,22 +27,40 @@ def test_regrid_Sv(test_data_samples):
if "azfp_cal_type" in range_kwargs:
range_kwargs.pop("azfp_cal_type")
Sv = ep.calibrate.compute_Sv(ed, **range_kwargs)

# Setup output grid
channel_Sv = Sv.isel(channel=0)
channel_Sv = Sv.isel(channel=1)
depth_data = channel_Sv.echo_range.isel(ping_time=0).data
time_data = channel_Sv.ping_time.data.astype('float64')
time_data = channel_Sv.ping_time.data.astype("float64")
# If there are NaNs in the depth data, remove them
if np.isnan(depth_data).any():
depth_data = depth_data[~np.isnan(depth_data)]

# Evenly spaced grid
target_grid = xr.Dataset({
"ping_time": (["ping_time"], np.linspace(time_data[0], time_data[-1], 300).astype('datetime64[ns]')),
"echo_range": (["echo_range"], np.linspace(depth_data[0], depth_data[-1], 300)),
})

target_grid = xr.Dataset(
{
"ping_time": (
["ping_time"],
np.linspace(np.min(time_data), np.max(time_data), 300).astype("datetime64[ns]"),
),
"echo_range": (
["echo_range"],
np.linspace(np.min(depth_data), np.max(depth_data), 300),
),
}
)

regridded_Sv = ep.commongrid.regrid_Sv(Sv, target_grid=target_grid)
assert regridded_Sv is not None

# Test to see if values average are close
for channel in regridded_Sv.channel:
original_vals = Sv.sel(channel=channel).Sv.values
regridded_vals = regridded_Sv.sel(channel=channel).Sv.values
assert np.allclose(np.nanmean(original_vals), np.nanmean(regridded_vals), atol=1.0, rtol=1.0)
assert np.allclose(
np.nanmean(original_vals),
np.nanmean(regridded_vals),
atol=1.0,
rtol=1.0,
equal_nan=True,
)

0 comments on commit 9f31885

Please sign in to comment.