diff --git a/src/xradio/measurement_set/_utils/_msv2/create_antenna_xds.py b/src/xradio/measurement_set/_utils/_msv2/create_antenna_xds.py index 7571b5c7..974c577d 100644 --- a/src/xradio/measurement_set/_utils/_msv2/create_antenna_xds.py +++ b/src/xradio/measurement_set/_utils/_msv2/create_antenna_xds.py @@ -15,7 +15,9 @@ table_exists, ) from xradio._utils.schema import convert_generic_xds_to_xradio_schema -from xradio.measurement_set._utils._msv2.msv4_sub_xdss import interpolate_to_time +from xradio.measurement_set._utils._msv2.msv4_sub_xdss import ( + rename_and_interpolate_to_time, +) from xradio._utils.list_and_array import ( check_if_consistent, @@ -509,27 +511,8 @@ def create_phase_calibration_xds( phase_cal_xds.time_phase_cal.astype("float64").astype("float64") / 10**9 ) - phase_cal_xds = interpolate_to_time( - phase_cal_xds, - phase_cal_interp_time, - "antenna_xds", - time_name="time_phase_cal", + phase_cal_xds = rename_and_interpolate_to_time( + phase_cal_xds, "time_phase_cal", phase_cal_interp_time, "phase_cal_xds" ) - time_coord_attrs = { - "type": "time", - "units": ["s"], - "scale": "utc", - "format": "unix", - } - - # If we interpolate rename the time_phase_cal axis to time. - if phase_cal_interp_time is not None: - time_coord = {"time": ("time_phase_cal", phase_cal_interp_time.data)} - phase_cal_xds = phase_cal_xds.assign_coords(time_coord) - phase_cal_xds.coords["time"].attrs.update(time_coord_attrs) - phase_cal_xds = phase_cal_xds.swap_dims({"time_phase_cal": "time"}).drop_vars( - "time_phase_cal" - ) - return phase_cal_xds diff --git a/src/xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py b/src/xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py index 3408cd7b..98ddf2b7 100644 --- a/src/xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py +++ b/src/xradio/measurement_set/_utils/_msv2/create_field_and_source_xds.py @@ -6,7 +6,9 @@ import xarray as xr import toolviper.utils.logger as logger -from xradio.measurement_set._utils._msv2.msv4_sub_xdss import interpolate_to_time +from xradio.measurement_set._utils._msv2.msv4_sub_xdss import ( + rename_and_interpolate_to_time, +) from xradio.measurement_set._utils._msv2.subtables import subt_rename_ids from xradio.measurement_set._utils._msv2._tables.read import ( convert_casacore_time_to_mjd, @@ -363,20 +365,13 @@ def extract_ephemeris_info( } temp_xds["time_ephemeris"].attrs.update(time_coord_attrs) - # Convert to si units and interpolate if ephemeris_interpolate=True: + # Convert to si units temp_xds = convert_to_si_units(temp_xds) - temp_xds = interpolate_to_time( - temp_xds, interp_time, "field_and_source_xds", time_name="time_ephemeris" - ) - # If we interpolate rename the time_ephemeris axis to time. - if interp_time is not None: - time_coord = {"time": ("time_ephemeris", interp_time.data)} - temp_xds = temp_xds.assign_coords(time_coord) - temp_xds.coords["time"].attrs.update(time_coord_attrs) - temp_xds = temp_xds.swap_dims({"time_ephemeris": "time"}).drop_vars( - "time_ephemeris" - ) + # interpolate if ephemeris_interpolate/interp_time=True, and rename time_ephemeris=>time + temp_xds = rename_and_interpolate_to_time( + temp_xds, "time_ephemeris", interp_time, "field_and_source_xds" + ) xds = xr.merge([xds, temp_xds]) diff --git a/src/xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py b/src/xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py index 961fac03..1ca2c534 100644 --- a/src/xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py +++ b/src/xradio/measurement_set/_utils/_msv2/msv4_sub_xdss.py @@ -20,6 +20,74 @@ ) +standard_time_coord_attrs = { + "type": "time", + "units": ["s"], + "scale": "utc", + "format": "unix", +} + + +def rename_and_interpolate_to_time( + xds: xr.Dataset, + time_initial_name: str, + interp_time: Union[xr.DataArray, None], + message_prefix: str, +) -> xr.Dataset: + """ + This function interpolates the time dimension and renames it: + + - interpolates a time_* dimension to values given in interp_time (presumably the time + axis of the main xds) + - rename/replace that time_* dimension to "time", where time_* is a (sub)xds specific + time axis + (for example "time_pointing", "time_ephemeris", "time_syscal", "time_phase_cal"). + + If interp_time is None this will simply return the input xds without modificaitons. + Uses interpolate_to_time() for interpolation. + ... + + Parameters: + ---------- + xds : xr.Dataset + Xarray dataset to interpolate (presumably a pointing_xds or an xds of + ephemeris variables) + time_initial_name: str = None + Name of time to be renamed+interpolated. Expected an existing time_* coordinate in the + dataset + interp_time: + Time axis to interpolate the dataset to (usually main MSv4 time) + message_prefix: + A prefix for info/debug/etc. messages about the specific xds being interpolated/ + time-renamed + + Returns: + ------- + renamed_interpolated_xds : xr.Dataset + xarray dataset with time axis renamed to "time" (from time_name, for example + "time_ephemeris") and interpolated to interp_time. + """ + if interp_time is None: + return xds + + interpolated_xds = interpolate_to_time( + xds, + interp_time, + message_prefix, + time_name=time_initial_name, + ) + + # rename the time_* axis to time. + time_coord = {"time": (time_initial_name, interp_time.data)} + renamed_time_xds = interpolated_xds.assign_coords(time_coord) + renamed_time_xds.coords["time"].attrs.update(standard_time_coord_attrs) + renamed_time_xds = renamed_time_xds.swap_dims({time_initial_name: "time"}) + if time_initial_name != "time": + renamed_time_xds = renamed_time_xds.drop_vars(time_initial_name) + + return renamed_time_xds + + def interpolate_to_time( xds: xr.Dataset, interp_time: Union[xr.DataArray, None], @@ -56,7 +124,9 @@ def interpolate_to_time( method = "linear" else: method = "nearest" - xds = xds.interp({time_name: interp_time}, method=method, assume_sorted=True) + xds = xds.interp( + {time_name: interp_time.data}, method=method, assume_sorted=True + ) # scan_number sneaks in as a coordinate of the main time axis, drop it if "scan_number" in xds.coords: xds = xds.drop_vars("scan_number") @@ -309,7 +379,7 @@ def create_pointing_xds( elif size == 0: generic_pointing_xds = generic_pointing_xds.drop_dims("n_polynomial") - time_ant_dims = ["time", "antenna_name"] + time_ant_dims = ["time_pointing", "antenna_name"] time_ant_dir_dims = time_ant_dims + ["local_sky_dir_label"] to_new_data_variables = { "DIRECTION": ["POINTING_BEAM", time_ant_dir_dims], @@ -318,7 +388,7 @@ def create_pointing_xds( } to_new_coords = { - "TIME": ["time", ["time"]], + "TIME": ["time_pointing", ["time_pointing"]], "dim_2": ["local_sky_dir_label", ["local_sky_dir_label"]], } @@ -337,7 +407,9 @@ def create_pointing_xds( generic_pointing_xds, pointing_xds, to_new_data_variables, to_new_coords ) - pointing_xds = interpolate_to_time(pointing_xds, interp_time, "pointing_xds") + pointing_xds = rename_and_interpolate_to_time( + pointing_xds, "time_pointing", interp_time, "pointing_xds" + ) logger.debug(f"create_pointing_xds() execution time {time.time() - start:0.2f} s") @@ -522,25 +594,9 @@ def create_system_calibration_xds( } sys_cal_xds.coords["frequency_cal"].attrs.update(frequency_measure) - if sys_cal_interp_time is not None: - sys_cal_xds = interpolate_to_time( - sys_cal_xds, - sys_cal_interp_time, - "system_calibration_xds", - time_name="time_cal", - ) - - time_coord_attrs = { - "type": "time", - "units": ["s"], - "scale": "utc", - "format": "unix", - } - # If interpolating time, rename time_cal => time - time_coord = {"time": ("time_cal", sys_cal_interp_time.data)} - sys_cal_xds = sys_cal_xds.assign_coords(time_coord) - sys_cal_xds.coords["time"].attrs.update(time_coord_attrs) - sys_cal_xds = sys_cal_xds.swap_dims({"time_cal": "time"}).drop_vars("time_cal") + sys_cal_xds = rename_and_interpolate_to_time( + sys_cal_xds, "time_cal", sys_cal_interp_time, "system_calibration_xds" + ) # correct expected types for data_var in sys_cal_xds: diff --git a/tests/unit/measurement_set/_utils/_msv2/test_msv4_sub_xdss.py b/tests/unit/measurement_set/_utils/_msv2/test_msv4_sub_xdss.py index 5c1729a9..365a4ddc 100644 --- a/tests/unit/measurement_set/_utils/_msv2/test_msv4_sub_xdss.py +++ b/tests/unit/measurement_set/_utils/_msv2/test_msv4_sub_xdss.py @@ -2,6 +2,31 @@ import pytest +def test_rename_and_interpoalte_to_time_bogus(ddi_xds_min, main_xds_min): + from xradio.measurement_set._utils._msv2.msv4_sub_xdss import ( + rename_and_interpolate_to_time, + ) + + input_time = main_xds_min.time + with pytest.raises(KeyError, match="No variable named 'time_bogus'."): + out_xds = rename_and_interpolate_to_time( + ddi_xds_min, "time_bogus", input_time, message_prefix="test_call" + ) + + +def test_rename_and_interpoalte_to_time_main(main_xds_min): + from xradio.measurement_set._utils._msv2.msv4_sub_xdss import ( + rename_and_interpolate_to_time, + ) + + input_time = main_xds_min.time + out_xds = rename_and_interpolate_to_time( + main_xds_min, "time", input_time, message_prefix="test_call" + ) + + xr.testing.assert_equal(out_xds.time, input_time) + + def test_interpolate_to_time_bogus(ddi_xds_min, main_xds_min): from xradio.measurement_set._utils._msv2.msv4_sub_xdss import interpolate_to_time