diff --git a/echopype/calibrate/ek80_complex.py b/echopype/calibrate/ek80_complex.py index 5c93118b1..8b2dc03dd 100644 --- a/echopype/calibrate/ek80_complex.py +++ b/echopype/calibrate/ek80_complex.py @@ -1,4 +1,5 @@ from collections import defaultdict +from functools import partial from typing import Dict, Literal, Optional, Union import numpy as np @@ -226,11 +227,10 @@ def get_transmit_signal( # but keeping this here for use as standalone function if waveform_mode == "BB" and np.all(beam["transmit_type"] == "CW"): raise TypeError("File does not contain BB mode complex samples!") - # Generate all transmit replica y_all = {} y_time_all = {} - # TODO: expand to deal with the case with varying tx param across ping_time + # TODO: expand to deal with the case with varying non-NaN tx param across ping_time tx_param_names = [ "transmit_duration_nominal", "slope", @@ -240,16 +240,18 @@ def get_transmit_signal( for ch in beam["channel"].values: tx_params = {} for p in tx_param_names: - tx_params[p] = np.unique(beam[p].sel(channel=ch)) + # Extract beam values and filter out NaNs + beam_values = np.unique(beam[p].sel(channel=ch)) + # Filter out NaN values + beam_values_without_nan = beam_values[~np.isnan(beam_values)] + tx_params[p] = beam_values_without_nan if tx_params[p].size != 1: raise TypeError("File contains changing %s!" % p) fs_chan = fs.sel(channel=ch).data if isinstance(fs, xr.DataArray) else fs tx_params["fs"] = fs_chan y_ch, _ = tapered_chirp(**tx_params) - # Filter and decimate chirp template y_ch, y_tmp_time = filter_decimate_chirp(coeff_ch=coeff[ch], y_ch=y_ch, fs=fs_chan) - # Fill into output dict y_all[ch] = y_ch y_time_all[ch] = y_tmp_time @@ -257,6 +259,38 @@ def get_transmit_signal( return y_all, y_time_all +def _nan_check_convolve(m, replica): + """ + Convolve two arrays while handling NaN values efficiently. + + Parameters + ---------- + m : array-like + Input array. + replica : array-like + Array for convolution. + + Returns + ------- + array-like + Convolved array. + + Notes + ----- + If all elements in `m` are NaN, the function returns `m`. + If any element in `m` is NaN, direct convolution is performed. + Otherwise, FFT convolution is used. + + Direct convolution is slower than FFT but works when NaN values are present. + """ + if np.all(np.isnan(m)): + return m + elif np.any(np.isnan(m)): + return signal.convolve(m, replica, mode="full", method="direct")[replica.size - 1 :] + else: + return signal.convolve(m, replica, mode="full", method="fft")[replica.size - 1 :] + + def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray: """Perform pulse compression on the backscatter data. @@ -278,11 +312,18 @@ def compress_pulse(backscatter: xr.DataArray, chirp: Dict) -> xr.DataArray: # Select channel `chan` and drop the specific beam dimension if all of the values are nan. backscatter_chan = backscatter.sel(channel=chan).dropna(dim="beam", how="all") + # Extract transmit values tx = chirp[str(chan.values)] + + # Compute complex conjugate of transmit values and reverse order of elements replica = np.flipud(np.conj(tx)) + # Create partial of nan_check_convolve with passed in replica + _nan_check_convolve_partial = partial(_nan_check_convolve, replica=replica) + + # Apply convolve on backscatter (along range sample dimension) and replica pc = xr.apply_ufunc( - lambda m: (signal.convolve(m, replica, mode="full")[replica.size - 1 :]), + _nan_check_convolve_partial, backscatter_chan, input_core_dims=[["range_sample"]], output_core_dims=[["range_sample"]], diff --git a/echopype/calibrate/env_params.py b/echopype/calibrate/env_params.py index 49be4d12c..1bab6cabb 100644 --- a/echopype/calibrate/env_params.py +++ b/echopype/calibrate/env_params.py @@ -50,9 +50,12 @@ def harmonize_env_param_time( if "time1" not in p.coords: return p - # If there's only 1 time1 value, - # or if after dropping NaN there's only 1 time1 value - if p["time1"].size == 1 or p.dropna(dim="time1").size == 1: + # If there's only 1 time1 value: + if p["time1"].size == 1: + return p.squeeze(dim="time1").drop("time1") + + # If after dropping NaN there's only 1 time1 value: + if p.dropna(dim="time1").size == 1: return p.dropna(dim="time1").squeeze(dim="time1").drop("time1") if ping_time is None: