Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up time interpolation+renaming code #309

Merged
merged 4 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 5 additions & 22 deletions src/xradio/measurement_set/_utils/_msv2/create_antenna_xds.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
table_exists,
)
from xradio._utils.schema import convert_generic_xds_to_xradio_schema
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import interpolate_to_time
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
rename_and_interpolate_to_time,
)

from xradio._utils.list_and_array import (
check_if_consistent,
Expand Down Expand Up @@ -509,27 +511,8 @@ def create_phase_calibration_xds(
phase_cal_xds.time_phase_cal.astype("float64").astype("float64") / 10**9
)

phase_cal_xds = interpolate_to_time(
phase_cal_xds,
phase_cal_interp_time,
"antenna_xds",
time_name="time_phase_cal",
phase_cal_xds = rename_and_interpolate_to_time(
phase_cal_xds, "time_phase_cal", phase_cal_interp_time, "phase_cal_xds"
)

time_coord_attrs = {
"type": "time",
"units": ["s"],
"scale": "utc",
"format": "unix",
}

# If we interpolate rename the time_phase_cal axis to time.
if phase_cal_interp_time is not None:
time_coord = {"time": ("time_phase_cal", phase_cal_interp_time.data)}
phase_cal_xds = phase_cal_xds.assign_coords(time_coord)
phase_cal_xds.coords["time"].attrs.update(time_coord_attrs)
phase_cal_xds = phase_cal_xds.swap_dims({"time_phase_cal": "time"}).drop_vars(
"time_phase_cal"
)

return phase_cal_xds
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import xarray as xr

import toolviper.utils.logger as logger
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import interpolate_to_time
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
rename_and_interpolate_to_time,
)
from xradio.measurement_set._utils._msv2.subtables import subt_rename_ids
from xradio.measurement_set._utils._msv2._tables.read import (
convert_casacore_time_to_mjd,
Expand Down Expand Up @@ -363,20 +365,13 @@ def extract_ephemeris_info(
}
temp_xds["time_ephemeris"].attrs.update(time_coord_attrs)

# Convert to si units and interpolate if ephemeris_interpolate=True:
# Convert to si units
temp_xds = convert_to_si_units(temp_xds)
temp_xds = interpolate_to_time(
temp_xds, interp_time, "field_and_source_xds", time_name="time_ephemeris"
)

# If we interpolate rename the time_ephemeris axis to time.
if interp_time is not None:
time_coord = {"time": ("time_ephemeris", interp_time.data)}
temp_xds = temp_xds.assign_coords(time_coord)
temp_xds.coords["time"].attrs.update(time_coord_attrs)
temp_xds = temp_xds.swap_dims({"time_ephemeris": "time"}).drop_vars(
"time_ephemeris"
)
# interpolate if ephemeris_interpolate/interp_time=True, and rename time_ephemeris=>time
temp_xds = rename_and_interpolate_to_time(
temp_xds, "time_ephemeris", interp_time, "field_and_source_xds"
)

xds = xr.merge([xds, temp_xds])

Expand Down
102 changes: 79 additions & 23 deletions src/xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,74 @@
)


standard_time_coord_attrs = {
"type": "time",
"units": ["s"],
"scale": "utc",
"format": "unix",
}


def rename_and_interpolate_to_time(
xds: xr.Dataset,
time_initial_name: str,
interp_time: Union[xr.DataArray, None],
message_prefix: str,
) -> xr.Dataset:
"""
This function interpolates the time dimension and renames it:

- interpolates a time_* dimension to values given in interp_time (presumably the time
axis of the main xds)
- rename/replace that time_* dimension to "time", where time_* is a (sub)xds specific
time axis
(for example "time_pointing", "time_ephemeris", "time_syscal", "time_phase_cal").

If interp_time is None this will simply return the input xds without modificaitons.
Uses interpolate_to_time() for interpolation.
...

Parameters:
----------
xds : xr.Dataset
Xarray dataset to interpolate (presumably a pointing_xds or an xds of
ephemeris variables)
time_initial_name: str = None
Name of time to be renamed+interpolated. Expected an existing time_* coordinate in the
dataset
interp_time:
Time axis to interpolate the dataset to (usually main MSv4 time)
message_prefix:
A prefix for info/debug/etc. messages about the specific xds being interpolated/
time-renamed

Returns:
-------
renamed_interpolated_xds : xr.Dataset
xarray dataset with time axis renamed to "time" (from time_name, for example
"time_ephemeris") and interpolated to interp_time.
"""
if interp_time is None:
return xds

interpolated_xds = interpolate_to_time(
xds,
interp_time,
message_prefix,
time_name=time_initial_name,
)

# rename the time_* axis to time.
time_coord = {"time": (time_initial_name, interp_time.data)}
renamed_time_xds = interpolated_xds.assign_coords(time_coord)
renamed_time_xds.coords["time"].attrs.update(standard_time_coord_attrs)
renamed_time_xds = renamed_time_xds.swap_dims({time_initial_name: "time"})
if time_initial_name != "time":
renamed_time_xds = renamed_time_xds.drop_vars(time_initial_name)

return renamed_time_xds


def interpolate_to_time(
xds: xr.Dataset,
interp_time: Union[xr.DataArray, None],
Expand Down Expand Up @@ -56,7 +124,9 @@ def interpolate_to_time(
method = "linear"
else:
method = "nearest"
xds = xds.interp({time_name: interp_time}, method=method, assume_sorted=True)
xds = xds.interp(
{time_name: interp_time.data}, method=method, assume_sorted=True
)
# scan_number sneaks in as a coordinate of the main time axis, drop it
if "scan_number" in xds.coords:
xds = xds.drop_vars("scan_number")
Expand Down Expand Up @@ -309,7 +379,7 @@ def create_pointing_xds(
elif size == 0:
generic_pointing_xds = generic_pointing_xds.drop_dims("n_polynomial")

time_ant_dims = ["time", "antenna_name"]
time_ant_dims = ["time_pointing", "antenna_name"]
time_ant_dir_dims = time_ant_dims + ["local_sky_dir_label"]
to_new_data_variables = {
"DIRECTION": ["POINTING_BEAM", time_ant_dir_dims],
Expand All @@ -318,7 +388,7 @@ def create_pointing_xds(
}

to_new_coords = {
"TIME": ["time", ["time"]],
"TIME": ["time_pointing", ["time_pointing"]],
"dim_2": ["local_sky_dir_label", ["local_sky_dir_label"]],
}

Expand All @@ -337,7 +407,9 @@ def create_pointing_xds(
generic_pointing_xds, pointing_xds, to_new_data_variables, to_new_coords
)

pointing_xds = interpolate_to_time(pointing_xds, interp_time, "pointing_xds")
pointing_xds = rename_and_interpolate_to_time(
pointing_xds, "time_pointing", interp_time, "pointing_xds"
)

logger.debug(f"create_pointing_xds() execution time {time.time() - start:0.2f} s")

Expand Down Expand Up @@ -522,25 +594,9 @@ def create_system_calibration_xds(
}
sys_cal_xds.coords["frequency_cal"].attrs.update(frequency_measure)

if sys_cal_interp_time is not None:
sys_cal_xds = interpolate_to_time(
sys_cal_xds,
sys_cal_interp_time,
"system_calibration_xds",
time_name="time_cal",
)

time_coord_attrs = {
"type": "time",
"units": ["s"],
"scale": "utc",
"format": "unix",
}
# If interpolating time, rename time_cal => time
time_coord = {"time": ("time_cal", sys_cal_interp_time.data)}
sys_cal_xds = sys_cal_xds.assign_coords(time_coord)
sys_cal_xds.coords["time"].attrs.update(time_coord_attrs)
sys_cal_xds = sys_cal_xds.swap_dims({"time_cal": "time"}).drop_vars("time_cal")
sys_cal_xds = rename_and_interpolate_to_time(
sys_cal_xds, "time_cal", sys_cal_interp_time, "system_calibration_xds"
)

# correct expected types
for data_var in sys_cal_xds:
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/measurement_set/_utils/_msv2/test_msv4_sub_xdss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,31 @@
import pytest


def test_rename_and_interpoalte_to_time_bogus(ddi_xds_min, main_xds_min):
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
rename_and_interpolate_to_time,
)

input_time = main_xds_min.time
with pytest.raises(KeyError, match="No variable named 'time_bogus'."):
out_xds = rename_and_interpolate_to_time(
ddi_xds_min, "time_bogus", input_time, message_prefix="test_call"
)


def test_rename_and_interpoalte_to_time_main(main_xds_min):
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import (
rename_and_interpolate_to_time,
)

input_time = main_xds_min.time
out_xds = rename_and_interpolate_to_time(
main_xds_min, "time", input_time, message_prefix="test_call"
)

xr.testing.assert_equal(out_xds.time, input_time)


def test_interpolate_to_time_bogus(ddi_xds_min, main_xds_min):
from xradio.measurement_set._utils._msv2.msv4_sub_xdss import interpolate_to_time

Expand Down
Loading