Skip to content

Commit

Permalink
deal with nans in bb complex multiplex
Browse files Browse the repository at this point in the history
  • Loading branch information
ctuguinay committed Apr 13, 2024
1 parent 46515f2 commit 43f606a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
53 changes: 47 additions & 6 deletions echopype/calibrate/ek80_complex.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import defaultdict
from functools import partial
from typing import Dict, Literal, Optional, Union

import numpy as np
Expand Down Expand Up @@ -226,11 +227,10 @@ def get_transmit_signal(
# but keeping this here for use as standalone function
if waveform_mode == "BB" and np.all(beam["transmit_type"] == "CW"):
raise TypeError("File does not contain BB mode complex samples!")

# Generate all transmit replica
y_all = {}
y_time_all = {}
# TODO: expand to deal with the case with varying tx param across ping_time
# TODO: expand to deal with the case with varying non-NaN tx param across ping_time
tx_param_names = [
"transmit_duration_nominal",
"slope",
Expand All @@ -240,23 +240,57 @@ def get_transmit_signal(
for ch in beam["channel"].values:
tx_params = {}
for p in tx_param_names:
tx_params[p] = np.unique(beam[p].sel(channel=ch))
# Extract beam values and filter out NaNs
beam_values = np.unique(beam[p].sel(channel=ch))
# Filter out NaN values
beam_values_without_nan = beam_values[~np.isnan(beam_values)]
tx_params[p] = beam_values_without_nan
if tx_params[p].size != 1:
raise TypeError("File contains changing %s!" % p)
fs_chan = fs.sel(channel=ch).data if isinstance(fs, xr.DataArray) else fs
tx_params["fs"] = fs_chan
y_ch, _ = tapered_chirp(**tx_params)

# Filter and decimate chirp template
y_ch, y_tmp_time = filter_decimate_chirp(coeff_ch=coeff[ch], y_ch=y_ch, fs=fs_chan)

# Fill into output dict
y_all[ch] = y_ch
y_time_all[ch] = y_tmp_time

return y_all, y_time_all


def _nan_check_convolve(m, replica):
"""
Convolve two arrays while handling NaN values efficiently.
Parameters
----------
m : array-like
Input array.
replica : array-like
Array for convolution.
Returns
-------
array-like
Convolved array.
Notes
-----
If all elements in `m` are NaN, the function returns `m`.
If any element in `m` is NaN, direct convolution is performed.
Otherwise, FFT convolution is used.
Direct convolution is slower than FFT but works when NaN values are present.
"""
if np.all(np.isnan(m)):
return m
elif np.any(np.isnan(m)):
return signal.convolve(m, replica, mode="full", method="direct")[replica.size - 1 :]
else:
return signal.convolve(m, replica, mode="full", method="fft")[replica.size - 1 :]


def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray:
"""Perform pulse compression on the backscatter data.
Expand All @@ -278,11 +312,18 @@ def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray:
# Select channel `chan` and drop the specific beam dimension if all of the values are nan.
backscatter_chan = backscatter.sel(channel=chan).dropna(dim="beam", how="all")

# Extract transmit values
tx = chirp[str(chan.values)]

# Compute complex conjugate of transmit values and reverse order of elements
replica = np.flipud(np.conj(tx))

# Create partial of nan_check_convolve with passed in replica
_nan_check_convolve_partial = partial(_nan_check_convolve, replica=replica)

# Apply convolve on backscatter (along range sample dimension) and replica
pc = xr.apply_ufunc(
lambda m: (signal.convolve(m, replica, mode="full")[replica.size - 1 :]),
_nan_check_convolve_partial,
backscatter_chan,
input_core_dims=[["range_sample"]],
output_core_dims=[["range_sample"]],
Expand Down
9 changes: 6 additions & 3 deletions echopype/calibrate/env_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ def harmonize_env_param_time(
if "time1" not in p.coords:
return p

# If there's only 1 time1 value,
# or if after dropping NaN there's only 1 time1 value
if p["time1"].size == 1 or p.dropna(dim="time1").size == 1:
# If there's only 1 time1 value:
if p["time1"].size == 1:
return p.squeeze(dim="time1").drop("time1")

# If after dropping NaN there's only 1 time1 value:
if p.dropna(dim="time1").size == 1:
return p.dropna(dim="time1").squeeze(dim="time1").drop("time1")

if ping_time is None:
Expand Down

0 comments on commit 43f606a

Please sign in to comment.