Skip to content

Commit

Permalink
Optimize get vend cal params power (#1285)
Browse files Browse the repository at this point in the history
* Optimize get vend cal params power (#2)

* Refactor function to avoid expand_dims and sortby if required

* test: increase co-ordinates to 1000 for ping_time dimension

* style: upper-casing variables

* style: upper-casing variables

* style: revert

---------

Co-authored-by: Anant Mittal <[email protected]>

* test: update test_get_interp_da() to test on 200 time_coordinates (#3)

* test: refactor test data (#4)

* test: refactor test data (#5)

* test: refactor test data

* test: refactor test data

---------

Co-authored-by: Anant Mittal <[email protected]>
  • Loading branch information
anujsinha3 and anantmittal authored Mar 26, 2024
1 parent 6dc196a commit 8600a59
Show file tree
Hide file tree
Showing 2 changed files with 398 additions and 168 deletions.
14 changes: 7 additions & 7 deletions echopype/calibrate/cal_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def get_vend_cal_params_power(beam: xr.Dataset, vend: xr.Dataset, param: str) ->
raise ValueError(f"{param} does not exist in the Vendor_specific group!")

# Find idx to select the corresponding param value
# by matching beam["transmit_duration_nominal"] with ds_vend["pulse_length"]
# by matching beam["transmit_duration_nominal"] with vend["pulse_length"]
transmit_isnull = beam["transmit_duration_nominal"].isnull()
idxmin = np.abs(beam["transmit_duration_nominal"] - vend["pulse_length"]).idxmin(
dim="pulse_length_bin"
Expand All @@ -297,13 +297,13 @@ def get_vend_cal_params_power(beam: xr.Dataset, vend: xr.Dataset, param: str) ->
idxmin = idxmin.where(~transmit_isnull, 0).astype(int)

# Get param dataarray into correct shape
da_param = (
vend[param]
.expand_dims(dim={"ping_time": idxmin["ping_time"]}) # expand dims for direct indexing
.sortby(idxmin.channel) # sortby in case channel sequence differs in vend and beam
)
da_param = vend[param].transpose("pulse_length_bin", "channel")

if not np.array_equal(da_param.channel.data, idxmin.channel.data):
da_param = da_param.sortby(
da_param.channel, ascending=False
) # sortby because channel sequence differs in vend and beam

# Select corresponding index and clean up the original nan elements
da_param = da_param.sel(pulse_length_bin=idxmin, drop=True)

# Set the nan elements back to nan.
Expand Down
Loading

0 comments on commit 8600a59

Please sign in to comment.