Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply Mask Changes: Multichannel, Allow Depth, Simplify Fill Value #1230

Merged
merged 13 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 89 additions & 48 deletions echopype/mask/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,26 @@ def _validate_and_collect_mask_input(
# the coordinate sequence matters, so fix the tuple form
allowed_dims = [
("ping_time", "range_sample"),
("ping_time", "depth"),
("channel", "ping_time", "range_sample"),
("channel", "ping_time", "depth"),
]
if mask[mask_ind].dims not in allowed_dims:
raise ValueError("All masks must have dimensions ('ping_time', 'range_sample')!")
raise ValueError(
"Masks must have one of the following dimensions: "
"('ping_time', 'range_sample'), ('ping_time', 'depth'), "
"('channel', 'ping_time', 'range_sample'), "
"('channel', 'ping_time', 'depth')"
)

# Check for the channel dimension consistency
channel_dim_shapes = set()
for mask_indiv in mask:
if "channel" in mask_indiv.dims:
for mask_chan_ind in range(len(mask_indiv["channel"])):
channel_dim_shapes.add(mask_indiv.isel(channel=mask_chan_ind).shape)
if len(channel_dim_shapes) > 1:
raise ValueError("All masks must have the same shape in the 'channel' dimension.")

else:
if not isinstance(storage_options_mask, dict):
Expand All @@ -128,7 +144,7 @@ def _validate_and_collect_mask_input(


def _check_var_name_fill_value(
source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, np.ndarray, xr.DataArray]
source_ds: xr.Dataset, var_name: str, fill_value: Union[int, float, xr.DataArray]
) -> Union[int, float, np.ndarray, xr.DataArray]:
"""
Ensures that the inputs ``var_name`` and ``fill_value`` for the function
Expand All @@ -140,12 +156,12 @@ def _check_var_name_fill_value(
A Dataset that contains the variable ``var_name``
var_name: str
The variable name in ``source_ds`` that the mask should be applied to
fill_value: int or float or np.ndarray or xr.DataArray
fill_value: int, float, or xr.DataArray
Specifies the value(s) at false indices

Returns
-------
fill_value: int or float or np.ndarray or xr.DataArray
fill_value: int, float, or xr.DataArray
fill_value with sanitized dimensions

Raises
Expand All @@ -167,17 +183,12 @@ def _check_var_name_fill_value(
raise ValueError("The Dataset source_ds does not contain the variable var_name!")

# check the type of fill_value
if not isinstance(fill_value, (int, float, np.ndarray, xr.DataArray)):
raise TypeError(
"The input fill_value must be of type int or " "float or np.ndarray or xr.DataArray!"
)
if not isinstance(fill_value, (int, float, xr.DataArray)):
raise TypeError("The input fill_value must be of type int, float, or xr.DataArray!")

# make sure that fill_values is the same shape as var_name
if isinstance(fill_value, (np.ndarray, xr.DataArray)):
if isinstance(fill_value, xr.DataArray):
fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension
elif isinstance(fill_value, np.ndarray):
fill_value = fill_value.squeeze() # squeeze out length=1 channel dimension
if isinstance(fill_value, xr.DataArray):
fill_value = fill_value.data.squeeze() # squeeze out length=1 channel dimension

source_ds_shape = (
source_ds[var_name].isel(channel=0).shape
Expand Down Expand Up @@ -248,7 +259,8 @@ def apply_mask(
source_ds: Union[xr.Dataset, str, pathlib.Path],
mask: Union[xr.DataArray, str, pathlib.Path, List[Union[xr.DataArray, str, pathlib.Path]]],
var_name: str = "Sv",
fill_value: Union[int, float, np.ndarray, xr.DataArray] = np.nan,
fill_value: Union[int, float, xr.DataArray] = np.nan,
keep_unmasked_channel: bool = True,
storage_options_ds: dict = {},
storage_options_mask: Union[dict, List[dict]] = {},
) -> xr.Dataset:
Expand All @@ -262,23 +274,28 @@ def apply_mask(
Points to a Dataset that contains the variable the mask should be applied to
mask: xr.DataArray, str, pathlib.Path, or a list of these datatypes
The mask(s) to be applied.
Can be a single input or list that corresponds to a DataArray or a path.
Each entry in the list must have dimensions ``('ping_time', 'range_sample')``.
Multi-channel masks are not currently supported.
Can be an individual input or list that corresponds to a DataArray or a path.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small fix

Suggested change
Can be an individual input or list that corresponds to a DataArray or a path.
Can be an individual input or a list that corresponds to a DataArray or a path.

Each individual input or entry in the list must contain dimensions
``('ping_time', 'range_sample')`` or dimensions ``('ping_time', 'depth')``.
The mask can also contain the dimension ``channel``.
If a path is provided this should point to a zarr or netcdf file with only
one data variable in it.
If the input ``mask`` is a list, a logical AND will be used to produce the final
mask that will be applied to ``var_name``.
var_name: str, default="Sv"
The Sv variable name in ``source_ds`` that the mask should be applied to.
This variable needs to have coordinates ``ping_time`` and ``range_sample``,
and can optionally also have coordinate ``channel``.
This variable needs to have coordinates ``('ping_time', 'range_sample')`` or
coordinates ``('ping_time', 'depth')``, and can optionally also have coordinate
``channel``.
In the case of a multi-channel Sv data variable, the ``mask`` will be broadcast
to all channels.
leewujung marked this conversation as resolved.
Show resolved Hide resolved
fill_value: int, float, np.ndarray, or xr.DataArray, default=np.nan
fill_value: int, float, or xr.DataArray, default=np.nan
Value(s) at masked indices.
If ``fill_value`` is of type ``np.ndarray`` or ``xr.DataArray``,
it must have the same shape as each entry of ``mask``.
If ``fill_value`` is of type ``xr.DataArray`` it must have the same shape as each
entry of ``mask``.
keep_unmasked_channel: bool, default=True
When True: Channels that are not in mask will be left as is.
When False: Channels that are not in mask will be masked.
storage_options_ds: dict, default={}
Any additional parameters for the storage backend, corresponding to the
path provided for ``source_ds``
Expand All @@ -305,44 +322,69 @@ def apply_mask(

# Obtain final mask to be applied to var_name
if isinstance(mask, list):
# perform a logical AND element-wise operation across the masks
final_mask = np.logical_and.reduce(mask)
# Broadcast all input masks together before combining them
broadcasted_masks = xr.broadcast(*mask)

# Perform a logical AND element-wise operation across the masks
final_mask = np.logical_and.reduce(broadcasted_masks)

# xr.where has issues with attrs when final_mask is an array, so we make it a DataArray
final_mask = xr.DataArray(final_mask, coords=mask[0].coords)
final_mask = xr.DataArray(final_mask, coords=broadcasted_masks[0].coords)
else:
final_mask = mask

# Sanity check: final_mask should be of the same shape as source_ds[var_name]
# along the ping_time and range_sample dimensions
def get_ch_shape(da):
return da.isel(channel=0).shape if "channel" in da.dims else da.shape

# Below operate on the actual data array to be masked
# Operate on the actual data array to be masked
source_da = source_ds[var_name]

source_da_shape = get_ch_shape(source_da)
final_mask_shape = get_ch_shape(final_mask)

if final_mask_shape != source_da_shape:
# Sanity check: final_mask should be of the same shape as source_ds[var_name]
# along the ping_time and range_sample dimensions.
source_da_chan_shape = (
source_da.isel(channel=0).shape if "channel" in source_da.dims else source_da.shape
)
final_mask_chan_shape = (
final_mask.isel(channel=0).shape if "channel" in final_mask.dims else final_mask.shape
)
if final_mask_chan_shape != source_da_chan_shape:
raise ValueError(
f"The final constructed mask is not of the same shape as source_ds[{var_name}] "
"along the ping_time and range_sample dimensions!"
)

# final_mask is always an xr.DataArray with at most length=1 channel dimension
if "channel" in final_mask.dims:
final_mask = final_mask.isel(channel=0)
# Apply the mask to var_name
if "channel" in final_mask.dims and "channel" in source_da.dims:
# Identify common channels
common_channels = set(final_mask.coords["channel"].values).intersection(
set(source_da.coords["channel"].values)
)

# Make sure fill_value and final_mask are expanded in dimensions
if "channel" in source_da.dims:
if isinstance(fill_value, np.ndarray):
fill_value = np.array([fill_value] * source_da["channel"].size)
final_mask = np.array([final_mask.data] * source_da["channel"].size)
# Convert common channels back to a sorted list to maintain order
common_channels = sorted(list(common_channels))

# Apply the mask to var_name
# Somehow keep_attrs=True errors out here, so will attach later
var_name_masked = xr.where(final_mask, x=source_da, y=fill_value)
# Select common channels for operation
final_mask_common = final_mask.sel(channel=common_channels)
source_da_common = source_da.sel(channel=common_channels)

# Perform operation on common channels
masked_common = xr.where(final_mask_common, x=source_da_common, y=fill_value)

# Identify remaining channels
all_channels = set(source_da.coords["channel"].values)
remaining_channels = sorted(list(all_channels - set(common_channels)))

# Select remaining channels
source_da_remaining = source_da.sel(channel=remaining_channels)

if not keep_unmasked_channel:
# Replace unmasked channel values with fill value
source_da_remaining = xr.full_like(source_da_remaining, fill_value=fill_value)

# Combine modified common channels with remaining channels
var_name_masked = xr.concat([masked_common, source_da_remaining], dim="channel")
else:
if "channel" in final_mask.dims and "channel" not in source_da.dims:
# Select first channel if final mask has channel dim and source da does not
final_mask = final_mask.isel(channel=0)
var_name_masked = xr.where(final_mask, x=source_da, y=fill_value)

# Obtain a shallow copy of source_ds
output_ds = source_ds.copy(deep=False)
Expand All @@ -356,12 +398,11 @@ def get_ch_shape(da):
_variable_prov_attrs(output_ds[var_name], mask)
)

# Attribute handling
process_type = "mask"
prov_dict = echopype_prov_attrs(process_type=process_type)
prov_dict[f"{process_type}_function"] = "mask.apply_mask"

output_ds = output_ds.assign_attrs(prov_dict)

output_ds = insert_input_processing_level(output_ds, input_ds=source_ds)

return output_ds
Expand Down
Loading