Skip to content

Commit

Permalink
Support lazy-loaded EK80 broadband-complex data (#1311)
Browse files Browse the repository at this point in the history
* fiddles to get cal code work with dask array

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert accidentally pushed test change

* compute channels together

* add rechunk message

* remove drop beam

* fix comment

* fix comment

* incorporate wu-jung's suggestions

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ctuguinay <[email protected]>
  • Loading branch information
3 people authored Jun 11, 2024
1 parent 66fda23 commit 46b6296
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 45 deletions.
6 changes: 3 additions & 3 deletions echopype/calibrate/calibrate_ek.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _get_chan_dict(beam: xr.Dataset) -> Dict:
# assume transmit_type identical for all pings in a channel
first_ping_transmit_type = (
beam["transmit_type"].isel(ping_time=0).drop_vars("ping_time")
) # noqa
).compute() # noqa
return {
# For BB: Keep only non-CW channels (LFM or FMD) based on transmit_type
"BB": first_ping_transmit_type.where(
Expand Down Expand Up @@ -544,9 +544,9 @@ def _cal_complex_samples(self, cal_type: str) -> xr.Dataset:
ping_time=beam["ping_time"],
)
# Use pulse_duration in place of tau_effective for GPT channels
# below assumesthat all transmit parameters are identical
# TODO: below assumes that all transmit parameters are identical
# and needs to be changed when allowing transmit parameters to vary by ping
ch_GPT = vend["transceiver_type"] == "GPT"
ch_GPT = (vend["transceiver_type"] == "GPT").compute()
tau_effective[ch_GPT] = beam["transmit_duration_nominal"][ch_GPT].isel(ping_time=0)

# equivalent_beam_angle
Expand Down
106 changes: 68 additions & 38 deletions echopype/calibrate/ek80_complex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import defaultdict
from functools import partial
from typing import Dict, Literal, Optional, Union

import numpy as np
Expand Down Expand Up @@ -258,6 +259,37 @@ def get_transmit_signal(
return y_all, y_time_all


def _convolve_per_channel(backscatter_subset: np.ndarray, replica_dict: dict, channels: dict):
"""
Convolve `backscatter_subset` array along range sample dimension for each channel.
The `backscatter_subset` array is a numpy array and has implicit dimensions
`('range_sample', 'channel')`.
When the `backscatter_subset` array is all 0s, we return it since the resulting
convolution will be all 0s, irrespective of what the corresponding transmit
signal is.
When this function is used in `compress_pulse`, the array that is being sent
as backscatter subset corresponds to a specific `ping_time` and `beam`, from
the backscatter array.
"""
# Return if all 0s
if np.all(backscatter_subset == 0.0 + 0.0j):
return backscatter_subset
else:
# Create zeros like array from `backscatter_subset`
convolved = np.zeros_like(backscatter_subset, dtype=np.complex64)
# Iterate over channels
for ch_seq, channel in enumerate(channels):
# Extract replica values
replica = replica_dict[str(channel.values)]
# Convolve backscatter and chirp replica
convolved[:, ch_seq] = signal.convolve(
backscatter_subset[:, ch_seq], replica, mode="full"
)[replica.size - 1 :]
return convolved


def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray:
"""Perform pulse compression on the backscatter data.
Expand All @@ -273,47 +305,45 @@ def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray:
xr.DataArray
A data array containing pulse compression output.
"""
pc_all = []

for chan in backscatter["channel"]:
# Select channel `chan` and drop the specific beam dimension if all of the values are nan.
backscatter_chan = backscatter.sel(channel=chan).dropna(dim="beam", how="all")

# Create NaN mask
# If `backscatter_chan` is lazy loaded, then `nan_mask` too will be lazy loaded.
nan_mask = np.isnan(backscatter_chan)

# Zero out backscatter NaN values
# If `nan_mask` is lazy loaded, then resulting `backscatter_chan` will be lazy loaded.
backscatter_chan = xr.where(nan_mask, 0.0 + 0j, backscatter_chan)

# Extract transmit values
tx = chirp[str(chan.values)]

# Compute complex conjugate of transmit values and reverse order of elements
replica = np.flipud(np.conj(tx))

# Apply convolve on backscatter (along range sample dimension) and replica
pc = xr.apply_ufunc(
lambda m: (signal.convolve(m, replica, mode="full")[replica.size - 1 :]),
backscatter_chan,
input_core_dims=[["range_sample"]],
output_core_dims=[["range_sample"]],
dask="parallelized",
vectorize=True,
output_dtypes=[np.complex64],
).compute()

# Restore NaN values in the resulting array.
# Computing of `nan_mask` here is necessary in the case when `nan_mask` is lazy loaded
# or else the resulting `pc` will also be lazy loaded.
pc = xr.where(nan_mask.compute(), np.nan, pc)
# Calculate the transmit signal values from the chirp dictionary
replica_dict = {
# Compute conjugate and flip for each channel's transmit signal
str(channel.values): np.flipud(np.conj(chirp[str(channel.values)]))
for channel in backscatter["channel"]
}

# Zero out backscatter NaN values
nan_mask = np.isnan(backscatter)
backscatter_with_zeroed_nans = xr.where(nan_mask, 0.0 + 0.0j, backscatter)

# Create a partial function of the convolve function to pass in chirp and channels
_convolve_per_channel_partial = partial(
_convolve_per_channel,
replica_dict=replica_dict,
channels=backscatter_with_zeroed_nans["channel"],
)

pc_all.append(pc)
# Apply convolve on backscatter and replica (along range sample and channel dimension):
# To enable parallelized computation with `dask='parallelized'`, we rechunk to ensure that
# the data is chunked with only one chunk along the core dimensions.
if backscatter_with_zeroed_nans.chunks is not None:
backscatter_with_zeroed_nans = backscatter_with_zeroed_nans.chunk(
{"range_sample": -1, "channel": -1}
)
pc = xr.apply_ufunc(
_convolve_per_channel_partial,
backscatter_with_zeroed_nans,
input_core_dims=[["range_sample", "channel"]],
output_core_dims=[["range_sample", "channel"]],
dask="parallelized",
vectorize=True,
output_dtypes=[np.complex64],
)

pc_all = xr.concat(pc_all, dim="channel")
# Restore NaN values in the pulse compressed array
pc = xr.where(nan_mask, np.nan, pc)

return pc_all
return pc


def get_norm_fac(chirp: Dict) -> xr.DataArray:
Expand Down
8 changes: 4 additions & 4 deletions echopype/tests/calibrate/test_calibrate_ek80.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def test_ek80_BB_power_from_complex(
tx, _ = ep.calibrate.ek80_complex.get_transmit_signal(beam, filter_coeff, waveform_mode, fs)

# Get power from complex samples
prx = cal_obj._get_power_from_complex(beam=beam, chirp=tx, z_et=z_et, z_er=z_er)
prx = cal_obj._get_power_from_complex(beam=beam, chirp=tx, z_et=z_et, z_er=z_er).compute()

ch_sel = "WBT 714590-15 ES70-7C"

Expand Down Expand Up @@ -307,8 +307,8 @@ def test_ek80_BB_power_compute_Sv(
)
pyel_vals = pyel_BB_p_data["sv_data"]
if dask_array:
ep_vals = ds_Sv["Sv"].sel(channel=ch_sel).squeeze().data.compute()
else:
ep_vals = ds_Sv["Sv"].sel(channel=ch_sel).squeeze().data.compute()
else:
ep_vals = ds_Sv["Sv"].sel(channel=ch_sel).squeeze().data

assert pyel_vals.shape == ep_vals.shape
Expand Down Expand Up @@ -353,7 +353,7 @@ def test_ek80_BB_power_echoview(ek80_path):
pc = ep.calibrate.ek80_complex.compress_pulse(
backscatter=beam["backscatter_r"] + 1j * beam["backscatter_i"],
chirp=chirp,
)
).compute()
pc = pc / ep.calibrate.ek80_complex.get_norm_fac(chirp) # normalization for each channel
pc_mean = pc.sel(channel="WBT 549762-15 ES70-7C").mean(dim="beam").dropna("range_sample")

Expand Down

0 comments on commit 46b6296

Please sign in to comment.