From e0dcc5bbb4261783cc67d51b3634feeaa91f2d03 Mon Sep 17 00:00:00 2001 From: Anant Mittal Date: Thu, 7 Dec 2023 10:41:57 -0800 Subject: [PATCH] Add regrid sv functionality to commongrid --- echopype/commongrid/api.py | 28 +++++++++-- .../tests/commongrid/test_commongrid_api.py | 47 ++++++++++++++++--- 2 files changed, 66 insertions(+), 9 deletions(-) diff --git a/echopype/commongrid/api.py b/echopype/commongrid/api.py index 396a943be..532b2e5bc 100644 --- a/echopype/commongrid/api.py +++ b/echopype/commongrid/api.py @@ -2,7 +2,7 @@ Functions for enhancing the spatial and temporal coherence of data. """ import logging -from typing import Literal +from typing import Literal, Optional import numpy as np import pandas as pd @@ -405,5 +405,27 @@ def compute_NASC( return ds_NASC -def regrid(): - return 1 +def regrid( + ds_Sv: xr.Dataset, + range_wanted: xr.DataArray, + ping_time_wanted: Optional[np.ndarray] = None, +) -> xr.Dataset: + """ + Regrid Sv data to a regular grid based on range_wanted and ping_time_wanted. + """ + + if ping_time_wanted is None: + # https://tutorial.xarray.dev/advanced/apply_ufunc/automatic-vectorizing-numpy.html#try-nd-input + ds_Sv["Sv"] = xr.apply_ufunc( + np.interp, + range_wanted.data, + ds_Sv["Sv"].range_sample, + ds_Sv["Sv"], + input_core_dims=[["range_sample"], ["range_sample"], ["range_sample"]], + output_core_dims=[["range_sample"]], + # exclude_dims={"range_sample"}, + keep_attrs="identical", + vectorize=True, + ) + # TODO: Logic for regrid when ping_time_wanted is not None. + return ds_Sv diff --git a/echopype/tests/commongrid/test_commongrid_api.py b/echopype/tests/commongrid/test_commongrid_api.py index 6e3a84385..b4f303907 100644 --- a/echopype/tests/commongrid/test_commongrid_api.py +++ b/echopype/tests/commongrid/test_commongrid_api.py @@ -9,7 +9,7 @@ _parse_x_bin, _groupby_x_along_channels, get_distance_from_latlon, - compute_raw_NASC + compute_raw_NASC, ) from echopype.tests.commongrid.conftest import get_NASC_echoview @@ -45,9 +45,7 @@ def test__parse_x_bin(x_bin, x_label, expected_result): @pytest.mark.unit -@pytest.mark.parametrize( - ["range_var", "lat_lon"], [("depth", False), ("echo_range", False)] -) +@pytest.mark.parametrize(["range_var", "lat_lon"], [("depth", False), ("echo_range", False)]) def test__groupby_x_along_channels(request, range_var, lat_lon): """Testing the underlying function of compute_MVBS and compute_NASC""" range_bin = 20 @@ -74,7 +72,7 @@ def test__groupby_x_along_channels(request, range_var, lat_lon): .indexes["ping_time"] ) ping_interval = d_index.union([d_index[-1] + pd.Timedelta(ping_time_bin)]) - + sv_mean = _groupby_x_along_channels( ds_Sv, range_interval, @@ -82,7 +80,7 @@ def test__groupby_x_along_channels(request, range_var, lat_lon): x_var="ping_time", range_var=range_var, method=method, - **flox_kwargs + **flox_kwargs, ) # Check that the range_var is in the dimension @@ -446,3 +444,40 @@ def test_compute_NASC_values(request, er_type): assert np.allclose( ds_NASC.NASC.values, expected_nasc.values, atol=1e-10, rtol=1e-10, equal_nan=True ) + + +def test_regrid(request): + """Test regrid function on irregular Sv data.""" + mock_Sv_dataset_irregular = request.getfixturevalue("mock_Sv_dataset_irregular") + + # Confirm that the echo_range values for different ping_times are not equal. + assert not np.array_equal( + mock_Sv_dataset_irregular["echo_range"].isel(channel=0, ping_time=0).values, + mock_Sv_dataset_irregular["echo_range"].isel(channel=0, ping_time=1).values, + ) + + interpolated_Sv_data = np.interp( + mock_Sv_dataset_irregular["echo_range"].isel(channel=0, ping_time=0).data, + mock_Sv_dataset_irregular["Sv"].range_sample, + mock_Sv_dataset_irregular["Sv"].isel(channel=0, ping_time=0).data, + ) + + ds_Sv_out = ep.commongrid.api.regrid( + ds_Sv=mock_Sv_dataset_irregular, + range_wanted=mock_Sv_dataset_irregular["echo_range"].isel(channel=0, ping_time=0), + ping_time_wanted=None, + ) + + assert np.array_equal( + interpolated_Sv_data, ds_Sv_out["Sv"].isel(channel=0, ping_time=0).data, equal_nan=True + ) + + # assert nan values in range_wanted are nans in ds_Sv_out. + nan_indices = np.argwhere( + np.isnan(mock_Sv_dataset_irregular["echo_range"].isel(channel=0, ping_time=0).data) + )[0] + for c in ds_Sv_out["Sv"].channel.data: + for p in ds_Sv_out["Sv"].ping_time.data: + assert np.isnan( + np.take(ds_Sv_out["Sv"].sel(channel=c, ping_time=p).data, nan_indices) + ).all()