diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index ac4b859cf..ed2a47f6b 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -42,7 +42,7 @@ jobs: run: | echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_ENV - name: Set up Python - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} - name: Upgrade pip diff --git a/.github/workflows/packit.yaml b/.github/workflows/packit.yaml index e6ed0b9c6..4f032690b 100644 --- a/.github/workflows/packit.yaml +++ b/.github/workflows/packit.yaml @@ -20,7 +20,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: 3.9 @@ -52,7 +52,7 @@ jobs: needs: build-artifact runs-on: ubuntu-20.04 steps: - - uses: actions/setup-python@v5.0.0 + - uses: actions/setup-python@v5.1.0 name: Install Python with: python-version: 3.9 diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 0e94524cf..868779862 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -35,7 +35,7 @@ jobs: with: fetch-depth: 0 # Fetch all history for all branches and tags. - name: Set up Python - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} - name: Upgrade pip diff --git a/.github/workflows/pypi.yaml b/.github/workflows/pypi.yaml index 6ebab0acf..feb69c488 100644 --- a/.github/workflows/pypi.yaml +++ b/.github/workflows/pypi.yaml @@ -24,7 +24,7 @@ jobs: fetch-depth: 0 - name: Set up Python - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: 3.9 @@ -56,7 +56,7 @@ jobs: needs: build-artifact runs-on: ubuntu-20.04 steps: - - uses: actions/setup-python@v5.0.0 + - uses: actions/setup-python@v5.1.0 name: Install Python with: python-version: 3.9 diff --git a/.github/workflows/windows.yaml b/.github/workflows/windows.yaml index b77576831..bc5e1bce9 100644 --- a/.github/workflows/windows.yaml +++ b/.github/workflows/windows.yaml @@ -46,7 +46,7 @@ jobs: # Check data endpoint curl http://localhost:8080/data/ - name: Setup Python - uses: actions/setup-python@v5.0.0 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} architecture: x64 diff --git a/echopype/mask/api.py b/echopype/mask/api.py index 71b618a1a..e8cf10fa6 100644 --- a/echopype/mask/api.py +++ b/echopype/mask/api.py @@ -105,10 +105,26 @@ def _validate_and_collect_mask_input( # the coordinate sequence matters, so fix the tuple form allowed_dims = [ ("ping_time", "range_sample"), + ("ping_time", "depth"), ("channel", "ping_time", "range_sample"), + ("channel", "ping_time", "depth"), ] if mask[mask_ind].dims not in allowed_dims: - raise ValueError("All masks must have dimensions ('ping_time', 'range_sample')!") + raise ValueError( + "Masks must have one of the following dimensions: " + "('ping_time', 'range_sample'), ('ping_time', 'depth'), " + "('channel', 'ping_time', 'range_sample'), " + "('channel', 'ping_time', 'depth')" + ) + + # Check for the channel dimension consistency + channel_dim_shapes = set() + for mask_indiv in mask: + if "channel" in mask_indiv.dims: + for mask_chan_ind in range(len(mask_indiv["channel"])): + channel_dim_shapes.add(mask_indiv.isel(channel=mask_chan_ind).shape) + if len(channel_dim_shapes) > 1: + raise ValueError("All masks must have the same shape in the 'channel' dimension.") else: if not 
isinstance(storage_options_mask, dict): @@ -128,7 +144,7 @@ def _validate_and_collect_mask_input( def _check_var_name_fill_value( - source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, np.ndarray, xr.DataArray] + source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, xr.DataArray] ) -> Union[int, float, np.ndarray, xr.DataArray]: """ Ensures that the inputs ``var_name`` and ``fill_value`` for the function @@ -140,12 +156,12 @@ def _check_var_name_fill_value( A Dataset that contains the variable ``var_name`` var_name: str The variable name in ``source_ds`` that the mask should be applied to - fill_value: int or float or np.ndarray or xr.DataArray + fill_value: int, float, or xr.DataArray Specifies the value(s) at false indices Returns ------- - fill_value: int or float or np.ndarray or xr.DataArray + fill_value: int, float, or xr.DataArray fill_value with sanitized dimensions Raises @@ -167,17 +183,12 @@ def _check_var_name_fill_value( raise ValueError("The Dataset source_ds does not contain the variable var_name!") # check the type of fill_value - if not isinstance(fill_value, (int, float, np.ndarray, xr.DataArray)): - raise TypeError( - "The input fill_value must be of type int or " "float or np.ndarray or xr.DataArray!" - ) + if not isinstance(fill_value, (int, float, xr.DataArray)): + raise TypeError("The input fill_value must be of type int, float, or xr.DataArray!") # make sure that fill_values is the same shape as var_name - if isinstance(fill_value, (np.ndarray, xr.DataArray)): - if isinstance(fill_value, xr.DataArray): - fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension - elif isinstance(fill_value, np.ndarray): - fill_value = fill_value.squeeze() # squeeze out length=1 channel dimension + if isinstance(fill_value, xr.DataArray): + fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension source_ds_shape = ( source_ds[var_name].isel(channel=0).shape @@ -248,7 +259,7 @@ def apply_mask( source_ds: Union[xr.Dataset, str, pathlib.Path], mask: Union[xr.DataArray, str, pathlib.Path, List[Union[xr.DataArray, str, pathlib.Path]]], var_name: str = "Sv", - fill_value: Union[int, float, np.ndarray, xr.DataArray] = np.nan, + fill_value: Union[int, float, xr.DataArray] = np.nan, storage_options_ds: dict = {}, storage_options_mask: Union[dict, List[dict]] = {}, ) -> xr.Dataset: @@ -256,29 +267,45 @@ def apply_mask( Applies the provided mask(s) to the Sv variable ``var_name`` in the provided Dataset ``source_ds``. + The code allows for these 3 cases of `source_ds` and `mask` dimensions: + + 1) No channel in both `source_ds` and `mask`, + but they have matching `ping_time` and + `depth` (or `range_sample`) dimensions. + 2) `source_ds` and `mask` both have matching `channel`, + `ping_time`, and `depth` (or `range_sample`) dimensions. + 3) `source_ds` has the channel dimension and `mask` doesn't, + but they have matching + `ping_time` and `depth` (or `range_sample`) dimensions. + + If a user only wants to apply masks to a subset of the channels in `source_ds`, + they could put 1s to allow all data entries in the other channels. + Parameters ---------- source_ds: xr.Dataset, str, or pathlib.Path Points to a Dataset that contains the variable the mask should be applied to mask: xr.DataArray, str, pathlib.Path, or a list of these datatypes The mask(s) to be applied. - Can be a single input or list that corresponds to a DataArray or a path. - Each entry in the list must have dimensions ``('ping_time', 'range_sample')``. 
- Multi-channel masks are not currently supported. + Can be an individual input or a list that corresponds to a DataArray or a path. + Each individual input or entry in the list must contain dimensions + ``('ping_time', 'range_sample')`` or dimensions ``('ping_time', 'depth')``. + The mask can also contain the dimension ``channel``. If a path is provided this should point to a zarr or netcdf file with only one data variable in it. If the input ``mask`` is a list, a logical AND will be used to produce the final mask that will be applied to ``var_name``. var_name: str, default="Sv" The Sv variable name in ``source_ds`` that the mask should be applied to. - This variable needs to have coordinates ``ping_time`` and ``range_sample``, - and can optionally also have coordinate ``channel``. + This variable needs to have coordinates ``('ping_time', 'range_sample')`` or + coordinates ``('ping_time', 'depth')``, and can optionally also have coordinate + ``channel``. In the case of a multi-channel Sv data variable, the ``mask`` will be broadcast to all channels. - fill_value: int, float, np.ndarray, or xr.DataArray, default=np.nan + fill_value: int, float, or xr.DataArray, default=np.nan Value(s) at masked indices. - If ``fill_value`` is of type ``np.ndarray`` or ``xr.DataArray``, - it must have the same shape as each entry of ``mask``. + If ``fill_value`` is of type ``xr.DataArray``, it must have the same shape as each + entry of ``mask``. storage_options_ds: dict, default={} Any additional parameters for the storage backend, corresponding to the path provided for ``source_ds`` @@ -305,43 +332,49 @@ def apply_mask( # Obtain final mask to be applied to var_name if isinstance(mask, list): - # perform a logical AND element-wise operation across the masks - final_mask = np.logical_and.reduce(mask) + # Broadcast all input masks together before combining them + broadcasted_masks = xr.broadcast(*mask) + + # Perform a logical AND element-wise operation across the masks + final_mask = np.logical_and.reduce(broadcasted_masks) # xr.where has issues with attrs when final_mask is an array, so we make it a DataArray - final_mask = xr.DataArray(final_mask, coords=mask[0].coords) + final_mask = xr.DataArray(final_mask, coords=broadcasted_masks[0].coords) else: final_mask = mask - # Sanity check: final_mask should be of the same shape as source_ds[var_name] - # along the ping_time and range_sample dimensions - def get_ch_shape(da): - return da.isel(channel=0).shape if "channel" in da.dims else da.shape - - # Below operate on the actual data array to be masked + # Operate on the actual data array to be masked source_da = source_ds[var_name] - source_da_shape = get_ch_shape(source_da) - final_mask_shape = get_ch_shape(final_mask) - - if final_mask_shape != source_da_shape: + # The final_mask should be of the same shape as source_ds[var_name] + # along the ping_time and range_sample dimensions. + source_da_chan_shape = ( + source_da.isel(channel=0).shape if "channel" in source_da.dims else source_da.shape ) + final_mask_chan_shape = ( + final_mask.isel(channel=0).shape if "channel" in final_mask.dims else final_mask.shape + ) + if final_mask_chan_shape != source_da_chan_shape: raise ValueError( f"The final constructed mask is not of the same shape as source_ds[{var_name}] " - "along the ping_time and range_sample dimensions!" + "along the ping_time and range_sample (or depth) dimensions!"
) - - # final_mask is always an xr.DataArray with at most length=1 channel dimension - if "channel" in final_mask.dims: - final_mask = final_mask.isel(channel=0) - - # Make sure fill_value and final_mask are expanded in dimensions - if "channel" in source_da.dims: - if isinstance(fill_value, np.ndarray): - fill_value = np.array([fill_value] * source_da["channel"].size) - final_mask = np.array([final_mask.data] * source_da["channel"].size) + # If final_mask has dim channel then source_da must have dim channel + if "channel" in final_mask.dims and "channel" not in source_da.dims: + raise ValueError( + "The final constructed mask has the channel dimension, " + f"so source_ds[{var_name}] must also have the channel dimension." + ) + # If final_mask and source_da both have channel dimension, then they must + # have the same number of channels. + elif "channel" in final_mask.dims and "channel" in source_da.dims: + if len(final_mask["channel"]) != len(source_da["channel"]): + raise ValueError( + f"If both the final constructed mask and source_ds[{var_name}] " + "have the channel dimension, that dimension should match between the two." + ) # Apply the mask to var_name - # Somehow keep_attrs=True errors out here, so will attach later var_name_masked = xr.where(final_mask, x=source_da, y=fill_value) # Obtain a shallow copy of source_ds @@ -356,12 +389,11 @@ def get_ch_shape(da): _variable_prov_attrs(output_ds[var_name], mask) ) + # Attribute handling process_type = "mask" prov_dict = echopype_prov_attrs(process_type=process_type) prov_dict[f"{process_type}_function"] = "mask.apply_mask" - output_ds = output_ds.assign_attrs(prov_dict) - output_ds = insert_input_processing_level(output_ds, input_ds=source_ds) return output_ds diff --git a/echopype/tests/mask/test_mask.py b/echopype/tests/mask/test_mask.py index d056cb3b3..82fe75932 100644 --- a/echopype/tests/mask/test_mask.py +++ b/echopype/tests/mask/test_mask.py @@ -134,6 +134,7 @@ def get_mock_source_ds_apply_mask(n: int, n_chan: int, is_delayed: bool) -> xr.D ------- xr.Dataset A Dataset containing data variables ``var1, var2`` with coordinates + ``('channel', 'ping_time', 'depth')`` and ``('channel', 'ping_time', 'range_sample')``. The variables are square matrices of ones for each ``channel``. 
""" @@ -148,24 +149,15 @@ def get_mock_source_ds_apply_mask(n: int, n_chan: int, is_delayed: bool) -> xr.D mock_var_data = [np.ones((n, n)) for i in range(n_chan)] # create mock var1 and var2 DataArrays - mock_var1_da = xr.DataArray( - data=np.stack(mock_var_data), - coords={ - "channel": ("channel", chan_vals, {"long_name": "channel name"}), - "ping_time": np.arange(n), - "range_sample": np.arange(n), - }, - attrs={"long_name": "variable 1"}, - ) - mock_var2_da = xr.DataArray( - data=np.stack(mock_var_data), - coords={ - "channel": ("channel", chan_vals, {"long_name": "channel name"}), - "ping_time": np.arange(n), - "range_sample": np.arange(n), - }, - attrs={"long_name": "variable 2"}, - ) + mock_var1_da = xr.DataArray(data=np.stack(mock_var_data), + coords={"channel": ("channel", chan_vals, {"long_name": "channel name"}), + "ping_time": np.arange(n), "depth": np.arange(n)}, + attrs={"long_name": "variable 1"}) + mock_var2_da = xr.DataArray(data=np.stack(mock_var_data), + coords={"channel": ("channel", chan_vals, {"long_name": "channel name"}), + "ping_time": np.arange(n), + "range_sample": np.arange(n)}, + attrs={"long_name": "variable 2"}) # create mock Dataset mock_ds = xr.Dataset(data_vars={"var1": mock_var1_da, "var2": mock_var2_da}) @@ -805,82 +797,79 @@ def test_validate_and_collect_mask_input( assert mask_out.identical(mask_da) +@pytest.mark.parametrize( + ("mask_list"), + [ + pytest.param( + [xr.DataArray([np.identity(4)], dims=['channel', 'ping_time', 'depth'], + coords={'channel': ['channel_0']})] + ), + pytest.param( + [xr.DataArray([np.identity(4), np.identity(4)], dims=['channel', 'ping_time', 'depth'], + coords={'channel': ['channel_0', 'channel_1']})] + ), + pytest.param( + [xr.DataArray([np.identity(4), np.identity(4)], dims=['channel', 'ping_time', 'depth'], + coords={'channel': ['channel_0', 'channel_1']}), + xr.DataArray([np.identity(4), np.identity(4)], dims=['channel', 'ping_time', 'depth'], + coords={'channel': ['channel_0', 'channel_1']})] + ), + pytest.param( + [xr.DataArray([np.identity(3), np.identity(3)], dims=['channel', 'ping_time', 'depth'], + coords={'channel': ['channel_0', 'channel_1']}), + xr.DataArray([np.identity(4), np.identity(4)], dims=['channel', 'ping_time', 'depth'], + coords={'channel': ['channel_0', 'channel_1']})], + marks=pytest.mark.xfail( + strict=True, + reason="This should fail because the channel dims are not uniform." + )) + ], + ids=["single_channel_mask", "double_channel", "double_channel_double_masks", + "inconsistent_channels_across_two_masks"] +) +def test_multi_mask_validate_and_collect_mask(mask_list: List[xr.DataArray]): + """ + Tests the allowable types and dimensions for multimask input. + + Parameters + ---------- + mask_list: List[xr.DataArray] + Multimask input to be tested in validate and collect mask input. + """ + + _validate_and_collect_mask_input(mask=mask_list, storage_options_mask={}) + + @pytest.mark.parametrize( ("n", "n_chan", "var_name", "fill_value"), [ - pytest.param( - 4, - 2, - 2.0, - np.nan, - marks=pytest.mark.xfail( - strict=True, reason="This should fail because the var_name is not a string." - ), - ), - pytest.param( - 4, - 2, - "var3", - np.nan, - marks=pytest.mark.xfail( - strict=True, - reason="This should fail because mock_ds will " "not have var_name=var3 in it.", - ), - ), - pytest.param( - 4, - 2, - "var1", - "1.0", - marks=pytest.mark.xfail( - strict=True, reason="This should fail because fill_value is an incorrect type." 
- ), - ), + pytest.param(4, 2, 2.0, np.nan, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because the var_name is not a string.")), + pytest.param(4, 2, "var3", np.nan, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because mock_ds will " + "not have var_name=var3 in it.")), + pytest.param(4, 2, "var2", "1.0", + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is an incorrect type.")), (4, 2, "var1", 1), (4, 2, "var1", 1.0), - (2, 1, "var1", np.identity(2)[None, :]), - ( - 2, - 1, - "var1", - xr.DataArray( - data=np.array([[[1.0, 0], [0, 1]]]), - coords={"channel": ["chan1"], "ping_time": [0, 1], "range_sample": [0, 1]}, - ), - ), - pytest.param( - 4, - 2, - "var1", - np.identity(2), - marks=pytest.mark.xfail( - strict=True, reason="This should fail because fill_value is not the right shape." - ), - ), - pytest.param( - 4, - 2, - "var1", - xr.DataArray( - data=np.array([[1.0, 0], [0, 1]]), - coords={"ping_time": [0, 1], "range_sample": [0, 1]}, - ), - marks=pytest.mark.xfail( - strict=True, reason="This should fail because fill_value is not the right shape." - ), - ), - ], - ids=[ - "wrong_var_name_type", - "no_var_name_ds", - "wrong_fill_value_type", - "fill_value_int", - "fill_value_float", - "fill_value_np_array", - "fill_value_DataArray", - "fill_value_np_array_wrong_shape", - "fill_value_DataArray_wrong_shape", + pytest.param(2, 1, "var1", np.identity(2)[None, :], + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is an incorrect type.")), + (2, 1, "var1", xr.DataArray(data=np.array([[[1.0, 0], [0, 1]]]), + coords={"channel": ["chan1"], "ping_time": [0, 1], "depth": [0, 1]}) + ), + pytest.param(4, 2, "var2", + xr.DataArray(data=np.array([[1.0, 0], [0, 1]]), + coords={"ping_time": [0, 1], "range_sample": [0, 1]}), + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is not the right shape.")), ], + ids=["wrong_var_name_type", "no_var_name_ds", "wrong_fill_value_type", "fill_value_int", + "fill_value_float", "fill_value_np_array", "fill_value_DataArray", + "fill_value_DataArray_wrong_shape"] ) def test_check_var_name_fill_value( n: int, n_chan: int, var_name: str, fill_value: Union[int, float, np.ndarray, xr.DataArray] @@ -947,17 +936,11 @@ def test_check_var_name_fill_value( # single_mask_float_fill (2, 1, "var1", np.identity(2), None, 2.0, False, np.array([[1, 2.0], [2.0, 1]]), False), # single_mask_np_array_fill - ( - 2, - 1, - "var1", - np.identity(2), - None, - np.array([[[np.nan, np.nan], [np.nan, np.nan]]]), - False, - np.array([[1, np.nan], [np.nan, 1]]), - False, - ), + pytest.param( + 2, 1, "var1", np.identity(2), None, np.array([[[np.nan, np.nan], [np.nan, np.nan]]]), + False, np.array([[1, np.nan], [np.nan, 1]]), False, + marks=pytest.mark.xfail(strict=True, + reason="This should fail because fill_value is an incorrect type.")), # single_mask_DataArray_fill ( 2, @@ -1078,7 +1061,7 @@ def test_apply_mask( mock_ds = get_mock_source_ds_apply_mask(n, n_chan, is_delayed) # create input mask and obtain temporary directory, if it was created - mask, temp_dir = create_input_mask(mask, mask_file, mock_ds.coords) + mask, temp_dir = create_input_mask(mask, mask_file, mock_ds[var_name].coords) # create DataArray form of the known truth value var_masked_truth = xr.DataArray( @@ -1115,61 +1098,146 @@ def test_apply_mask( temp_dir.cleanup() +@pytest.mark.integration @pytest.mark.parametrize( - ("source_has_ch", "mask_has_ch"), - [ - (True, True), - 
(False, True), - (True, False), - (False, False), + ("source_has_ch", "mask", "truth_da"), + [ + # source_with_ch_mask_list_with_ch + (True, [ + xr.DataArray( + np.array([np.identity(2), np.identity(2)]), + coords={"channel": ["chan1", "chan2"], "ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_with_channel"}, + ), + xr.DataArray( + np.array([np.zeros_like(np.identity(2))]), + coords={"channel": ["chan3"], "ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_with_channel"}, + ), + ], + xr.DataArray( + np.array([[[1, np.nan], [np.nan, 1]], + [[1, np.nan], [np.nan, 1]], + [[np.nan, np.nan], [np.nan, np.nan]]]), + coords={"channel": ["chan1", "chan2", "chan3"], + "ping_time": np.arange(2), "depth": np.arange(2)}, + )), + + # source_with_ch_mask_list_with_ch_fail_different_channel_lengths + (True, [ + xr.DataArray( + np.array([np.identity(2)]), + coords={"channel": ["chan1"], "ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_with_channel"}, + ), + xr.DataArray( + np.array([np.zeros_like(np.identity(2))]), + coords={"channel": ["chan3"], "ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_with_channel"}, + ), + ], + None), + + # source_with_ch_mask_with_ch + (True, + xr.DataArray( + np.array([np.identity(2), np.identity(2), np.ones_like(np.identity(2))]), + coords={"channel": ["chan1", "chan2", "chan3"], "ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_with_channel"}, + ), + xr.DataArray( + np.array([[[1, np.nan], [np.nan, 1]], + [[1, np.nan], [np.nan, 1]], + [[1, 1], [1, 1]]]), + coords={"channel": ["chan1", "chan2", "chan3"], + "ping_time": np.arange(2), "depth": np.arange(2)}, + )), + + # source_with_ch_mask_no_ch + (True, xr.DataArray( + np.identity(2), + coords={"ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_no_channel"}, + ), + xr.DataArray( + np.array([[[1, 1, 1], [np.nan, np.nan, np.nan]], + [[np.nan, np.nan, np.nan], [1, 1, 1]]]), + coords={"ping_time": np.arange(2), "depth": np.arange(2), + "channel": ["chan1", "chan2", "chan3"]} + )), + + # source_no_ch_mask_with_ch_fail + (False, xr.DataArray( + np.array([np.identity(2)]), + coords={"channel": ["chan1"], "ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_with_channel"}, + ), + None), + + # source_no_ch_mask_no_ch + (False, xr.DataArray( + np.identity(2), + coords={"ping_time": np.arange(2), "depth": np.arange(2)}, + attrs={"long_name": "mask_no_channel"}, + ), + xr.DataArray( + np.array([[1, np.nan], [np.nan, 1]]), + coords={"ping_time": np.arange(2), "depth": np.arange(2)} + )), + + # source_no_ch_mask_no_ch_fail_different_ping_time_depth_shape + (False, xr.DataArray( + np.zeros((3, 1)), + coords={"ping_time": np.arange(3), "depth": np.arange(1)}, + attrs={"long_name": "mask_no_channel"}, + ), + None), ], ids=[ + "source_with_ch_mask_list_with_ch", + "source_with_ch_mask_list_with_ch_fail_different_channel_lengths", "source_with_ch_mask_with_ch", - "source_no_ch_mask_with_ch", "source_with_ch_mask_no_ch", + "source_no_ch_mask_with_ch_fail", "source_no_ch_mask_no_ch", - ], + "source_no_ch_mask_no_ch_fail_different_ping_time_depth_shape", + ] ) -def test_apply_mask_channel_variation(source_has_ch, mask_has_ch): +def test_apply_mask_channel_variation(source_has_ch, mask, truth_da): + + # Create source dataset source_ds = get_mock_source_ds_apply_mask(2, 3, False) var_name = "var1" - if mask_has_ch: - mask = xr.DataArray( - 
np.array([np.identity(2)]), - coords={"channel": ["chA"], "ping_time": np.arange(2), "range_sample": np.arange(2)}, - attrs={"long_name": "mask_with_channel"}, - ) - else: - mask = xr.DataArray( - np.identity(2), - coords={"ping_time": np.arange(2), "range_sample": np.arange(2)}, - attrs={"long_name": "mask_no_channel"}, - ) - - if source_has_ch: - masked_ds = echopype.mask.apply_mask(source_ds, mask, var_name) - else: - source_ds[f"{var_name}_ch0"] = source_ds[var_name].isel(channel=0).squeeze() - var_name = f"{var_name}_ch0" - masked_ds = echopype.mask.apply_mask(source_ds, mask, var_name) - - # Output dimension will be the same as source - if source_has_ch: - truth_da = xr.DataArray( - np.array([[[1, np.nan], [np.nan, 1]]] * 3), - coords={ - "channel": ["chan1", "chan2", "chan3"], - "ping_time": np.arange(2), - "range_sample": np.arange(2), - }, - attrs=source_ds[var_name].attrs, - ) + if truth_da is None: + # Attempt to apply mask w/ 'bad' shapes and check for raised ValueError + with pytest.raises(ValueError): + if source_has_ch: + masked_ds = echopype.mask.apply_mask(source_ds, + mask, + var_name + ) + else: + source_ds[f"{var_name}_ch0"] = source_ds[var_name].isel(channel=0).squeeze() + var_name = f"{var_name}_ch0" + masked_ds = echopype.mask.apply_mask(source_ds, + mask, + var_name + ) else: - truth_da = xr.DataArray( - [[1, np.nan], [np.nan, 1]], - coords={"ping_time": np.arange(2), "range_sample": np.arange(2)}, - attrs=source_ds[var_name].attrs, - ) - - assert masked_ds[var_name].equals(truth_da) + # Apply mask and check matching truth_da + if source_has_ch: + masked_ds = echopype.mask.apply_mask(source_ds, + mask, + var_name + ) + else: + source_ds[f"{var_name}_ch0"] = source_ds[var_name].isel(channel=0).squeeze() + var_name = f"{var_name}_ch0" + masked_ds = echopype.mask.apply_mask(source_ds, + mask, + var_name + ) + + # Check mask to match truth + assert masked_ds[var_name].equals(truth_da)
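For reviewers, a minimal usage sketch of the multi-channel mask support added in this diff. The snippet is not part of the change set; the dataset, channel names, and values are invented for illustration and loosely mirror the shapes used in get_mock_source_ds_apply_mask and test_apply_mask_channel_variation.

import numpy as np
import xarray as xr
import echopype

# Mock Sv-like dataset with a channel dimension (2 channels of 2x2 ones)
source_ds = xr.Dataset(
    {
        "Sv": xr.DataArray(
            np.ones((2, 2, 2)),
            coords={"channel": ["chan1", "chan2"], "ping_time": np.arange(2), "depth": np.arange(2)},
        )
    }
)

# A mask that carries the channel dimension; before this change only
# ('ping_time', 'range_sample') masks were accepted
mask = xr.DataArray(
    np.array([np.identity(2), np.identity(2)]),
    coords={"channel": ["chan1", "chan2"], "ping_time": np.arange(2), "depth": np.arange(2)},
)

# False mask entries are replaced by fill_value (np.nan by default),
# so each channel of Sv comes back as [[1, nan], [nan, 1]]
masked_ds = echopype.mask.apply_mask(source_ds, mask, var_name="Sv")
print(masked_ds["Sv"].values)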
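Similarly, a sketch of the revised fill_value handling: a plain np.ndarray fill value is no longer accepted, but an xr.DataArray with the same shape as each mask entry still is. Again not part of the diff; the names and values are illustrative, following the fill_value_DataArray case in test_check_var_name_fill_value.

import numpy as np
import xarray as xr
import echopype

# Single-channel mock dataset and mask on the ('ping_time', 'range_sample') grid
source_ds = xr.Dataset(
    {
        "Sv": xr.DataArray(
            np.ones((1, 2, 2)),
            coords={"channel": ["chan1"], "ping_time": [0, 1], "range_sample": [0, 1]},
        )
    }
)
mask = xr.DataArray(
    np.identity(2),
    coords={"ping_time": [0, 1], "range_sample": [0, 1]},
)

# fill_value as an xr.DataArray matching the masked variable's per-channel shape;
# passing a np.ndarray here would now raise a TypeError
fill_value = xr.DataArray(
    np.array([[[2.0, 2.0], [2.0, 2.0]]]),
    coords={"channel": ["chan1"], "ping_time": [0, 1], "range_sample": [0, 1]},
)
masked_ds = echopype.mask.apply_mask(source_ds, mask, var_name="Sv", fill_value=fill_value)
# Sv is now [[1, 2], [2, 1]] for the single channel
print(masked_ds["Sv"].values)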