expose func arg and skip_na to compute_MVBS users #1269

Merged
21 changes: 9 additions & 12 deletions echopype/commongrid/api.py
@@ -31,9 +31,8 @@ def compute_MVBS(
ds_Sv: xr.Dataset,
range_var: Literal["echo_range", "depth"] = "echo_range",
range_bin: str = "20m",
ping_time_bin: str = "20S",
ping_time_bin: str = "20s",
method="map-reduce",
func="nanmean",
skipna=True,
closed: Literal["left", "right"] = "left",
**flox_kwargs,
@@ -58,21 +57,15 @@ def compute_MVBS(
``depth`` as a data variable.
range_bin : str, default '20m'
bin size along ``echo_range`` or ``depth`` in meters.
ping_time_bin : str, default '20S'
ping_time_bin : str, default '20s'
bin size along ``ping_time``
method: str, default 'map-reduce'
The flox strategy for reduction of dask arrays only.
See flox `documentation <https://flox.readthedocs.io/en/latest/implementation.html>`_
for more details.
func: str, default 'nanmean'
The flox aggregation function used for reducing the data array.
By default, 'nanmean' is used. Other options can be found in the flox `documentation
<https://flox.readthedocs.io/en/latest/generated/flox.xarray.xarray_reduce.html>`_.
skipna: bool, default True
If true, aggregation function skips NaN values.
Else, aggregation function includes NaN values.
Note that if ``func`` is set to 'mean' and ``skipna`` is set to True, then aggregation
will have the same behavior as if func is set to 'nanmean'.
If true, mean function skips NaN values.
Else, mean function includes NaN values.
closed: {'left', 'right'}, default 'left'
Which side of bin interval is closed.
**flox_kwargs
@@ -116,7 +109,6 @@ def compute_MVBS(
ping_interval,
range_var=range_var,
method=method,
func=func,
skipna=skipna,
**flox_kwargs,
)
@@ -281,6 +273,7 @@ def compute_NASC(
range_bin: str = "10m",
dist_bin: str = "0.5nmi",
method: str = "map-reduce",
skipna=True,
closed: Literal["left", "right"] = "left",
**flox_kwargs,
) -> xr.Dataset:
@@ -300,6 +293,9 @@
The flox strategy for reduction of dask arrays only.
See flox `documentation <https://flox.readthedocs.io/en/latest/implementation.html>`_
for more details.
skipna: bool, default True
If true, mean function skips NaN values.
Else, mean function includes NaN values.
closed: {'left', 'right'}, default 'left'
Which side of bin interval is closed.
**flox_kwargs
@@ -370,6 +366,7 @@ def compute_NASC(
range_interval,
dist_interval,
method=method,
skipna=skipna,
**flox_kwargs,
)

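A minimal usage sketch of the updated public API (not part of the PR itself; the raw file name is a placeholder and the EK60 workflow is just one way of producing ds_Sv):

import echopype as ep

ed = ep.open_raw("example.raw", sonar_model="EK60")  # placeholder raw file
ds_Sv = ep.calibrate.compute_Sv(ed)

# Default: NaN samples are skipped (reduced internally with "nanmean")
ds_MVBS = ep.commongrid.compute_MVBS(ds_Sv, range_bin="20m", ping_time_bin="20s")

# Propagate NaNs into the binned output instead (reduced internally with "mean")
ds_MVBS_nan = ep.commongrid.compute_MVBS(
    ds_Sv, range_bin="20m", ping_time_bin="20s", skipna=False
)

# compute_NASC exposes the same flag; it additionally needs latitude/longitude on ds_Sv
ds_Sv_loc = ep.consolidate.add_location(ds_Sv, ed)
ds_NASC = ep.commongrid.compute_NASC(ds_Sv_loc, range_bin="10m", dist_bin="0.5nmi", skipna=False)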
26 changes: 14 additions & 12 deletions echopype/commongrid/utils.py
@@ -20,7 +20,6 @@ def compute_raw_MVBS(
ping_interval: Union[pd.IntervalIndex, np.ndarray],
range_var: Literal["echo_range", "depth"] = "echo_range",
method="map-reduce",
func="nanmean",
skipna=True,
**flox_kwargs,
):
@@ -47,15 +46,9 @@
The flox strategy for reduction of dask arrays only.
See flox `documentation <https://flox.readthedocs.io/en/latest/implementation.html>`_
for more details.
func: str, default 'nanmean'
The flox aggregation function used for reducing the data array.
By default, 'nanmean' is used. Other options can be found in the flox `documentation
<https://flox.readthedocs.io/en/latest/generated/flox.xarray.xarray_reduce.html>`_.
skipna: bool, default True
If true, aggregation function skips NaN values.
Else, aggregation function includes NaN values.
Note that if ``func`` is set to 'mean' and ``skipna`` is set to True, then aggregation
will have the same behavior as if func is set to 'nanmean'.
If true, mean function skips NaN values.
Else, mean function includes NaN values.
**flox_kwargs
Additional keyword arguments to be passed
to flox reduction function.
@@ -76,7 +69,7 @@
x_var=x_var,
range_var=range_var,
method=method,
func=func,
func="nanmean" if skipna else "mean",
skipna=skipna,
**flox_kwargs,
)
@@ -93,6 +86,7 @@ def compute_raw_NASC(
range_interval: Union[pd.IntervalIndex, np.ndarray],
dist_interval: Union[pd.IntervalIndex, np.ndarray],
method="map-reduce",
skipna=True,
**flox_kwargs,
):
"""
@@ -113,6 +107,9 @@
The flox strategy for reduction of dask arrays only.
See flox `documentation <https://flox.readthedocs.io/en/latest/implementation.html>`_
for more details.
skipna: bool, default True
If true, mean function skips NaN values.
Else, mean function includes NaN values.
**flox_kwargs
Additional keyword arguments to be passed
to flox reduction function.
@@ -139,6 +136,8 @@ def compute_raw_NASC(
x_var=x_var,
range_var=range_var,
method=method,
func="nanmean" if skipna else "mean",
skipna=skipna,
**flox_kwargs,
)

@@ -149,6 +148,7 @@ def compute_raw_NASC(
ds_Sv["ping_time"],
ds_Sv[x_var],
func="nanmean",
skipna=True,
expected_groups=(dist_interval),
isbin=True,
method=method,
@@ -167,7 +167,8 @@ def compute_raw_NASC(
h_mean_denom = xarray_reduce(
da_denom,
ds_Sv[x_var],
func="sum",
func="nansum",
skipna=True,
expected_groups=(dist_interval),
isbin=[True],
method=method,
Expand All @@ -178,7 +179,8 @@ def compute_raw_NASC(
ds_Sv["channel"],
ds_Sv[x_var],
ds_Sv[range_var].isel(**{range_dim: slice(0, -1)}),
func="sum",
func="nansum",
skipna=True,
expected_groups=(None, dist_interval, range_interval),
isbin=[False, True, True],
method=method,
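For context on the change above: the user-facing skipna flag is mapped onto the flox aggregation name, i.e. "nanmean"/"nansum" when NaNs are skipped and "mean"/"sum" when they are propagated. The resulting NaN semantics can be illustrated with a toy groupby; plain xarray is used here instead of flox only to keep the sketch short, and the data are made up:

import numpy as np
import xarray as xr

da = xr.DataArray(
    [1.0, np.nan, 3.0, 4.0],
    dims="ping_time",
    coords={"group": ("ping_time", [0, 0, 1, 1])},
)

# skipna=True behaves like "nanmean": the NaN in group 0 is ignored
print(da.groupby("group").mean(skipna=True).values)   # [1.  3.5]

# skipna=False behaves like "mean": the NaN propagates into group 0's result
print(da.groupby("group").mean(skipna=False).values)  # [nan 3.5]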
112 changes: 55 additions & 57 deletions echopype/tests/commongrid/test_commongrid_api.py
@@ -3,7 +3,6 @@
import numpy as np
import pandas as pd
from flox.xarray import xarray_reduce
import xarray as xr
import echopype as ep
from echopype.consolidate import add_location, add_depth
from echopype.commongrid.utils import (
@@ -342,23 +341,14 @@ def test_compute_MVBS_range_output(request, er_type):

@pytest.mark.integration
@pytest.mark.parametrize(
("er_type", "func", "add_nan"),
("er_type"),
[
("regular", "nanmean", "no_add_nan"),
("irregular", "nanmean", "no_add_nan"),
("regular", "mean", "no_add_nan"),
("irregular", "mean", "no_add_nan"),
("regular", "nanmean", "add_nan"),
("regular", "mean", "add_nan"),
("irregular", "nanmean", "add_nan"),
("irregular", "mean", "add_nan"),
("regular"),
("irregular"),
],
)
def test_compute_MVBS_values(request, er_type, func, add_nan):
"""
Tests for the values of compute_MVBS on regular vsirregular, func as nanmean vs mean,
and added NaN vs no added Nan data.
"""
def test_compute_MVBS_values(request, er_type):
"""Tests for the values of compute_MVBS on regular and irregular data."""

def _parse_nans(mvbs, ds_Sv) -> np.ndarray:
"""Go through and figure out nan values in result"""
@@ -411,48 +401,18 @@ def _parse_nans(mvbs, ds_Sv) -> np.ndarray:
ds_Sv = request.getfixturevalue("mock_Sv_dataset_irregular")
expected_mvbs = request.getfixturevalue("mock_mvbs_array_irregular")

# Check to see if MVBS matches request fixture arrays
if add_nan == "no_add_nan":
# Compute MVBS
ds_MVBS = ep.commongrid.compute_MVBS(
ds_Sv,
range_bin=range_bin,
ping_time_bin=ping_time_bin,
func=func,
skipna=False
)

# Compute expected outputs
expected_outputs = _parse_nans(ds_MVBS, ds_Sv)

assert ds_MVBS.Sv.shape == expected_mvbs.shape
# Floating digits need to check with all close not equal
# Compare the values of the MVBS array with the expected values
assert np.allclose(ds_MVBS.Sv.values, expected_mvbs, atol=1e-10, rtol=1e-10, equal_nan=True)

# Ensures that the computation of MVBS takes doesn't take into account NaN values
# that are sporadically placed in the echo_range values
assert np.array_equal(np.isnan(ds_MVBS.Sv.values), expected_outputs)

# Check appropriate aggregate function behavior when NaNs are added to first channel
elif add_nan == "add_nan":
# Add 5 NaN values to ds_Sv and compute MVBS
ds_Sv["Sv"][0, 0, 0:5] = np.nan
ds_MVBS = ep.commongrid.compute_MVBS(
ds_Sv,
range_bin=range_bin,
ping_time_bin=ping_time_bin,
func=func,
skipna=False
)
if func == "mean":
# Ensure that the 5 NaN Sv values, now projected onto the regridded dataset, turn into 2 NaN values in the case
# where func is mean.
assert np.array_equal(ds_MVBS["Sv"][0, 0, 0:2].data, np.array([np.nan, np.nan]), equal_nan=True)
assert np.sum(np.isnan(ds_MVBS["Sv"][0, 0, :].data)) == 2
elif func == "nanmean":
# Ensure that all values in regridded are non-NaN when func is nanmean
assert not np.isnan(ds_MVBS["Sv"][0, 0, :].data).any()
ds_MVBS = ep.commongrid.compute_MVBS(ds_Sv, range_bin=range_bin, ping_time_bin=ping_time_bin)

expected_outputs = _parse_nans(ds_MVBS, ds_Sv)

assert ds_MVBS.Sv.shape == expected_mvbs.shape
# Floating point values need to be compared with allclose, not exact equality
# Compare the values of the MVBS array with the expected values
assert np.allclose(ds_MVBS.Sv.values, expected_mvbs, atol=1e-10, rtol=1e-10, equal_nan=True)

# Ensures that the computation of MVBS doesn't take into account NaN values
# that are sporadically placed in the echo_range values
assert np.array_equal(np.isnan(ds_MVBS.Sv.values), expected_outputs)


@pytest.mark.integration
@@ -486,3 +446,41 @@ def test_compute_NASC_values(request, er_type):
assert np.allclose(
ds_NASC.NASC.values, expected_nasc.values, atol=1e-10, rtol=1e-10, equal_nan=True
)


@pytest.mark.integration
@pytest.mark.parametrize(
("operation","skipna"),
[
("MVBS", True),
("NASC", False),
],
)
def test_compute_MVBS_NASC_skipna_nan_and_non_nan_values(request, operation, skipna):
# Create subset dataset with 2 channels, 2 ping times, 2 depth values:

# Get fixture for irregular Sv
ds_Sv = request.getfixturevalue("mock_Sv_dataset_irregular")
# Already has 2 channels, so subset for only ping time and range sample
subset_ds_Sv = ds_Sv.isel(ping_time=slice(0,2), range_sample=slice(0,2))
Review comment (Member):
I think having only 2 depth values makes the mock dataset too small for the test to actually check whether the number and location of NaN elements at the output of NASC or MVBS computation is correct. I think you can select 2 of the pings, with one of them having some NaN elements and the other not, but retain all elements along echo_range. This dataarray size should be manageable but would be a lot more realistic.

Reply (Collaborator, author):
I agree. I'll make these changes.

# Set the 1 NaN echo range value to non-NaN value
subset_ds_Sv["echo_range"][0, 1, 1] = subset_ds_Sv["echo_range"][0, 0, 1].data
# Set one of the first channel values to NaN
subset_ds_Sv["Sv"][0, 1, 1] = np.nan

# Compute MVBS / Compute NASC
if operation == "MVBS":
da = ep.commongrid.compute_MVBS(subset_ds_Sv, skipna=skipna)["Sv"]
else:
da = ep.commongrid.compute_NASC(subset_ds_Sv, skipna=skipna)["NASC"]

# Check that da is 2 values: 1 value for each channel
assert da.shape == (2, 1, 1)

# Check for appropriate NaN/non-NaN values
if skipna:
assert not np.isnan(da.isel(channel=0).squeeze().data)
assert not np.isnan(da.isel(channel=1).squeeze().data)
else:
assert np.isnan(da.isel(channel=0).squeeze().data)
assert not np.isnan(da.isel(channel=1).squeeze().data)
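
Regarding the review thread above about the mock dataset being too small: a sketch of what the suggested subsetting could look like, using a made-up stand-in for the Sv fixture (the (channel, ping_time, range_sample) dimension order follows the indexing used in the test):

import numpy as np
import xarray as xr

# Toy stand-in for the mock Sv fixture
sv = xr.DataArray(
    np.random.default_rng(0).normal(size=(2, 5, 20)),
    dims=("channel", "ping_time", "range_sample"),
    name="Sv",
)

# Keep two pings but retain all range samples,
# then give only one ping (in one channel) NaN elements
subset = sv.isel(ping_time=slice(0, 2)).copy()
subset[0, 1, :5] = np.nan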