Skip to content

Commit

Permalink
Apply Mask Changes: Multichannel, Allow Depth, Simplify Fill Value (#…
Browse files Browse the repository at this point in the history
…1230)

* allow depth dimension into coordinates for apply mask and incorporate this change into tests

* add channel dimension consistency check

* modification so that fill value cannot be of type np ndarray

* fix fill value incorrect tests

* initial logic for broadcast then np logical and reduce; commented out test that currently needs to be refactored

* modify apply mask logic

* add keep_unmasked_channel logic and tests

* make more strict as to input dimensions

* use 'a list' instead of 'list'

* add 3 cases and user mask subset

* add wu-jung's suggestion for large docstring

Co-authored-by: Wu-Jung Lee <[email protected]>

* add wu-jung's suggestion for channel and dimension wording

Co-authored-by: Wu-Jung Lee <[email protected]>

* add wu-jung's suggestion for making matching channel dimension wording more precise

Co-authored-by: Wu-Jung Lee <[email protected]>

---------

Co-authored-by: ctuguinay <[email protected]>
Co-authored-by: Wu-Jung Lee <[email protected]>
  • Loading branch information
3 people authored Apr 8, 2024
1 parent 63267dd commit 646759e
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 104 deletions.
130 changes: 81 additions & 49 deletions echopype/mask/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,26 @@ def _validate_and_collect_mask_input(
# the coordinate sequence matters, so fix the tuple form
allowed_dims = [
("ping_time", "range_sample"),
("ping_time", "depth"),
("channel", "ping_time", "range_sample"),
("channel", "ping_time", "depth"),
]
if mask[mask_ind].dims not in allowed_dims:
raise ValueError("All masks must have dimensions ('ping_time', 'range_sample')!")
raise ValueError(
"Masks must have one of the following dimensions: "
"('ping_time', 'range_sample'), ('ping_time', 'depth'), "
"('channel', 'ping_time', 'range_sample'), "
"('channel', 'ping_time', 'depth')"
)

# Check for the channel dimension consistency
channel_dim_shapes = set()
for mask_indiv in mask:
if "channel" in mask_indiv.dims:
for mask_chan_ind in range(len(mask_indiv["channel"])):
channel_dim_shapes.add(mask_indiv.isel(channel=mask_chan_ind).shape)
if len(channel_dim_shapes) > 1:
raise ValueError("All masks must have the same shape in the 'channel' dimension.")

else:
if not isinstance(storage_options_mask, dict):
Expand All @@ -126,7 +142,7 @@ def _validate_and_collect_mask_input(


def _check_var_name_fill_value(
source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, np.ndarray, xr.DataArray]
source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, xr.DataArray]
) -> Union[int, float, np.ndarray, xr.DataArray]:
"""
Ensures that the inputs ``var_name`` and ``fill_value`` for the function
Expand All @@ -138,12 +154,12 @@ def _check_var_name_fill_value(
A Dataset that contains the variable ``var_name``
var_name: str
The variable name in ``source_ds`` that the mask should be applied to
fill_value: int or float or np.ndarray or xr.DataArray
fill_value: int, float, or xr.DataArray
Specifies the value(s) at false indices
Returns
-------
fill_value: int or float or np.ndarray or xr.DataArray
fill_value: int, float, or xr.DataArray
fill_value with sanitized dimensions
Raises
Expand All @@ -165,17 +181,12 @@ def _check_var_name_fill_value(
raise ValueError("The Dataset source_ds does not contain the variable var_name!")

# check the type of fill_value
if not isinstance(fill_value, (int, float, np.ndarray, xr.DataArray)):
raise TypeError(
"The input fill_value must be of type int or " "float or np.ndarray or xr.DataArray!"
)
if not isinstance(fill_value, (int, float, xr.DataArray)):
raise TypeError("The input fill_value must be of type int, float, or xr.DataArray!")

# make sure that fill_values is the same shape as var_name
if isinstance(fill_value, (np.ndarray, xr.DataArray)):
if isinstance(fill_value, xr.DataArray):
fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension
elif isinstance(fill_value, np.ndarray):
fill_value = fill_value.squeeze() # squeeze out length=1 channel dimension
if isinstance(fill_value, xr.DataArray):
fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension

source_ds_shape = (
source_ds[var_name].isel(channel=0).shape
Expand Down Expand Up @@ -246,37 +257,53 @@ def apply_mask(
source_ds: Union[xr.Dataset, str, pathlib.Path],
mask: Union[xr.DataArray, str, pathlib.Path, List[Union[xr.DataArray, str, pathlib.Path]]],
var_name: str = "Sv",
fill_value: Union[int, float, np.ndarray, xr.DataArray] = np.nan,
fill_value: Union[int, float, xr.DataArray] = np.nan,
storage_options_ds: dict = {},
storage_options_mask: Union[dict, List[dict]] = {},
) -> xr.Dataset:
"""
Applies the provided mask(s) to the Sv variable ``var_name``
in the provided Dataset ``source_ds``.
The code allows for these 3 cases of `source_ds` and `mask` dimensions:
1) No channel in both `source_ds` and `mask`,
but they have matching `ping_time` and
`depth` (or `range_sample`) dimensions.
2) `source_ds` and `mask` both have matching `channel`,
`ping_time`, and `depth` (or `range_sample`) dimensions.
3) `source_ds` has the channel dimension and `mask` doesn't,
but they have matching
`ping_time` and `depth` (or `range_sample`) dimensions.
If a user only wants to apply masks to a subset of the channels in `source_ds`,
they could put 1s to allow all data entries in the other channels.
Parameters
----------
source_ds: xr.Dataset, str, or pathlib.Path
Points to a Dataset that contains the variable the mask should be applied to
mask: xr.DataArray, str, pathlib.Path, or a list of these datatypes
The mask(s) to be applied.
Can be a single input or list that corresponds to a DataArray or a path.
Each entry in the list must have dimensions ``('ping_time', 'range_sample')``.
Multi-channel masks are not currently supported.
Can be a individual input or a list that corresponds to a DataArray or a path.
Each individual input or entry in the list must contain dimensions
``('ping_time', 'range_sample')`` or dimensions ``('ping_time', 'depth')``.
The mask can also contain the dimension ``channel``.
If a path is provided this should point to a zarr or netcdf file with only
one data variable in it.
If the input ``mask`` is a list, a logical AND will be used to produce the final
mask that will be applied to ``var_name``.
var_name: str, default="Sv"
The Sv variable name in ``source_ds`` that the mask should be applied to.
This variable needs to have coordinates ``ping_time`` and ``range_sample``,
and can optionally also have coordinate ``channel``.
This variable needs to have coordinates ``('ping_time', 'range_sample')`` or
coordinates ``('ping_time', 'depth')``, and can optionally also have coordinate
``channel``.
In the case of a multi-channel Sv data variable, the ``mask`` will be broadcast
to all channels.
fill_value: int, float, np.ndarray, or xr.DataArray, default=np.nan
fill_value: int, float, or xr.DataArray, default=np.nan
Value(s) at masked indices.
If ``fill_value`` is of type ``np.ndarray`` or ``xr.DataArray``,
it must have the same shape as each entry of ``mask``.
If ``fill_value`` is of type ``xr.DataArray`` it must have the same shape as each
entry of ``mask``.
storage_options_ds: dict, default={}
Any additional parameters for the storage backend, corresponding to the
path provided for ``source_ds``
Expand All @@ -303,43 +330,49 @@ def apply_mask(

# Obtain final mask to be applied to var_name
if isinstance(mask, list):
# perform a logical AND element-wise operation across the masks
final_mask = np.logical_and.reduce(mask)
# Broadcast all input masks together before combining them
broadcasted_masks = xr.broadcast(*mask)

# Perform a logical AND element-wise operation across the masks
final_mask = np.logical_and.reduce(broadcasted_masks)

# xr.where has issues with attrs when final_mask is an array, so we make it a DataArray
final_mask = xr.DataArray(final_mask, coords=mask[0].coords)
final_mask = xr.DataArray(final_mask, coords=broadcasted_masks[0].coords)
else:
final_mask = mask

# Sanity check: final_mask should be of the same shape as source_ds[var_name]
# along the ping_time and range_sample dimensions
def get_ch_shape(da):
return da.isel(channel=0).shape if "channel" in da.dims else da.shape

# Below operate on the actual data array to be masked
# Operate on the actual data array to be masked
source_da = source_ds[var_name]

source_da_shape = get_ch_shape(source_da)
final_mask_shape = get_ch_shape(final_mask)

if final_mask_shape != source_da_shape:
# The final_mask should be of the same shape as source_ds[var_name]
# along the ping_time and range_sample dimensions.
source_da_chan_shape = (
source_da.isel(channel=0).shape if "channel" in source_da.dims else source_da.shape
)
final_mask_chan_shape = (
final_mask.isel(channel=0).shape if "channel" in final_mask.dims else final_mask.shape
)
if final_mask_chan_shape != source_da_chan_shape:
raise ValueError(
f"The final constructed mask is not of the same shape as source_ds[{var_name}] "
"along the ping_time and range_sample dimensions!"
"along the ping_time, and range_sample dimensions!"
)

# final_mask is always an xr.DataArray with at most length=1 channel dimension
if "channel" in final_mask.dims:
final_mask = final_mask.isel(channel=0)

# Make sure fill_value and final_mask are expanded in dimensions
if "channel" in source_da.dims:
if isinstance(fill_value, np.ndarray):
fill_value = np.array([fill_value] * source_da["channel"].size)
final_mask = np.array([final_mask.data] * source_da["channel"].size)
# If final_mask has dim channel then source_da must have dim channel
if "channel" in final_mask.dims and "channel" not in source_da.dims:
raise ValueError(
"The final constructed mask has the channel dimension, "
f"so source_ds[{var_name}] must also have the channel dimension."
)
# If final_mask and source_da both have channel dimension, then they must
# have the same number of channels.
elif "channel" in final_mask.dims and "channel" in source_da.dims:
if len(final_mask["channel"]) != len(source_da["channel"]):
raise ValueError(
f"If both the final constructed mask and source_ds[{var_name}] "
"have the channel dimension, that dimension should match between the two."
)

# Apply the mask to var_name
# Somehow keep_attrs=True errors out here, so will attach later
var_name_masked = xr.where(final_mask, x=source_da, y=fill_value)

# Obtain a shallow copy of source_ds
Expand All @@ -354,12 +387,11 @@ def get_ch_shape(da):
_variable_prov_attrs(output_ds[var_name], mask)
)

# Attribute handling
process_type = "mask"
prov_dict = echopype_prov_attrs(process_type=process_type)
prov_dict[f"{process_type}_function"] = "mask.apply_mask"

output_ds = output_ds.assign_attrs(prov_dict)

output_ds = insert_input_processing_level(output_ds, input_ds=source_ds)

return output_ds
Expand Down
Loading

0 comments on commit 646759e

Please sign in to comment.