diff --git a/echopype/calibrate/ek80_complex.py b/echopype/calibrate/ek80_complex.py index beaa71e78..2a8f5df23 100644 --- a/echopype/calibrate/ek80_complex.py +++ b/echopype/calibrate/ek80_complex.py @@ -1,4 +1,5 @@ from collections import defaultdict +from functools import partial from typing import Dict, Literal, Optional, Union import numpy as np @@ -258,6 +259,17 @@ def get_transmit_signal( return y_all, y_time_all +def _convolve_per_channel(m, replica_dict, channels): + convolved = np.zeros_like(m, dtype=np.complex64) + # Iterate over channels + for i, channel in enumerate(channels): + # Extract replica values + replica = replica_dict[str(channel.values)] + # Convolve backscatter and chirp replica + convolved[:, i] = signal.convolve(m[:, i], replica, mode="full")[replica.size - 1 :] + return convolved + + def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray: """Perform pulse compression on the backscatter data. @@ -273,54 +285,53 @@ def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray: xr.DataArray A data array containing pulse compression output. """ - pc_all = [] + # Select channel `chan` and drop the specific beam dimension if all of the values are nan. + # Additionally, in the same for loop, compute the replica dictionary values from the chirp. + backscatter_NaN_beam_drop_all = [] + replica_dict = {} + for channel in backscatter["channel"]: + # TODO: Once `dropna` allows for dropping along multiple dimensions, put this outside of the + # loop and remove the concatenate. + backscatter_NaN_beam_drop = backscatter.sel(channel=channel).dropna(dim="beam", how="all") + backscatter_NaN_beam_drop_all.append(backscatter_NaN_beam_drop) + + # Extract tx + tx = chirp[str(channel.values)] + # Compute complex conjugate of transmit values and reverse order of elements + replica_dict[str(channel.values)] = np.flipud(np.conj(tx)) + # Concatenate backscatter channels with dropped NaN beam dimensions. + backscatter_NaN_beam_drop_all = xr.concat(backscatter_NaN_beam_drop_all, dim="channel") - for chan in backscatter["channel"]: - # Select channel `chan` and drop the specific beam dimension if all of the values are nan. - backscatter_chan = backscatter.sel(channel=chan).dropna(dim="beam", how="all") + # Create NaN mask + nan_mask = np.isnan(backscatter_NaN_beam_drop_all) - # Create NaN mask - # If `backscatter_chan` is lazy loaded, then `nan_mask` too will be lazy loaded. - nan_mask = np.isnan(backscatter_chan) + # Zero out backscatter NaN values + backscatter_NaN_beam_drop_all = xr.where(nan_mask, 0.0 + 0j, backscatter_NaN_beam_drop_all) - # Zero out backscatter NaN values - # If `nan_mask` is lazy loaded, then resulting `backscatter_chan` will be lazy loaded. - backscatter_chan = xr.where(nan_mask, 0.0 + 0j, backscatter_chan) + # Create a partial function of the convolve function to pass in chirp and channels + _convolve_per_channel_partial = partial( + _convolve_per_channel, + replica_dict=replica_dict, + channels=backscatter_NaN_beam_drop_all["channel"], + ) - # Extract transmit values - tx = chirp[str(chan.values)] + # Apply convolve on backscatter and replica (along range sample and channel dimension): + # To enable parallelized computation with `dask='parallelized'`, we rechunk to ensure that + # the data is chunked with only one chunk along the core dimensions. + pc = xr.apply_ufunc( + _convolve_per_channel_partial, + backscatter_NaN_beam_drop_all.chunk({"range_sample": -1, "channel": -1}), + input_core_dims=[["range_sample", "channel"]], + output_core_dims=[["range_sample", "channel"]], + dask="parallelized", + vectorize=True, + output_dtypes=[np.complex64], + ) - # Compute complex conjugate of transmit values and reverse order of elements - replica = np.flipud(np.conj(tx)) - - # Apply convolve on backscatter (along range sample dimension) and replica - # Rechunking backscatter_chan is needed to avoid the following ValueError: - # ValueError: dimension range_sample on 0th function argument to apply_ufunc - # with dask='parallelized' consists of multiple chunks, but is also a core dimension. - # To fix, either rechunk into a single array chunk along this dimension, - # i.e., ``.chunk(dict(range_sample=-1))``, - # or pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` - # but beware that this may significantly increase memory usage. - pc = xr.apply_ufunc( - lambda m: (signal.convolve(m, replica, mode="full")[replica.size - 1 :]), - backscatter_chan.chunk({"range_sample": -1}), - input_core_dims=[["range_sample"]], - output_core_dims=[["range_sample"]], - dask="parallelized", - vectorize=True, - output_dtypes=[np.complex64], - ) # .compute() - - # Restore NaN values in the resulting array. - # Computing of `nan_mask` here is necessary in the case when `nan_mask` is lazy loaded - # or else the resulting `pc` will also be lazy loaded. - pc = xr.where(nan_mask, np.nan, pc) - - pc_all.append(pc) - - pc_all = xr.concat(pc_all, dim="channel") - - return pc_all + # Restore NaN values in the resulting array. + pc = xr.where(nan_mask, np.nan, pc) + + return pc def get_norm_fac(chirp: Dict) -> xr.DataArray: diff --git a/echopype/tests/calibrate/test_calibrate_ek80.py b/echopype/tests/calibrate/test_calibrate_ek80.py index a4b4011d7..ec09af476 100644 --- a/echopype/tests/calibrate/test_calibrate_ek80.py +++ b/echopype/tests/calibrate/test_calibrate_ek80.py @@ -234,7 +234,7 @@ def test_ek80_BB_power_from_complex( tx, _ = ep.calibrate.ek80_complex.get_transmit_signal(beam, filter_coeff, waveform_mode, fs) # Get power from complex samples - prx = cal_obj._get_power_from_complex(beam=beam, chirp=tx, z_et=z_et, z_er=z_er) + prx = cal_obj._get_power_from_complex(beam=beam, chirp=tx, z_et=z_et, z_er=z_er).compute() ch_sel = "WBT 714590-15 ES70-7C" @@ -306,10 +306,7 @@ def test_ek80_BB_power_compute_Sv( encode_mode=encode_mode, ) pyel_vals = pyel_BB_p_data["sv_data"] - if dask_array: - ep_vals = ds_Sv["Sv"].sel(channel=ch_sel).squeeze().data.compute() - else: - ep_vals = ds_Sv["Sv"].sel(channel=ch_sel).squeeze().data + ep_vals = ds_Sv["Sv"].sel(channel=ch_sel).squeeze().data.compute() assert pyel_vals.shape == ep_vals.shape idx_to_cmp = ~( @@ -353,7 +350,7 @@ def test_ek80_BB_power_echoview(ek80_path): pc = ep.calibrate.ek80_complex.compress_pulse( backscatter=beam["backscatter_r"] + 1j * beam["backscatter_i"], chirp=chirp, - ) + ).compute() pc = pc / ep.calibrate.ek80_complex.get_norm_fac(chirp) # normalization for each channel pc_mean = pc.sel(channel="WBT 549762-15 ES70-7C").mean(dim="beam").dropna("range_sample")