From 56e60d89026346409b697ed3f20250c7fa56306d Mon Sep 17 00:00:00 2001 From: Tom Aldcroft Date: Tue, 5 Nov 2024 11:40:13 -0500 Subject: [PATCH] Improve command states with better transition infrastructure (#340) * Improved transitions * Add default_value for HRC transitions * Change Transition constraint to callable function And update some docs / comments * Add notebook for comprehensive ~5 year states regression test --- kadi/commands/states.py | 395 +++++++++++------- kadi/commands/tests/test_commands.py | 63 +++ kadi/commands/tests/test_states.py | 4 +- .../states-regression/states-regression.ipynb | 211 ++++++++++ 4 files changed, 530 insertions(+), 143 deletions(-) create mode 100644 validate/states-regression/states-regression.ipynb diff --git a/kadi/commands/states.py b/kadi/commands/states.py index 6d308990..1d92e75f 100644 --- a/kadi/commands/states.py +++ b/kadi/commands/states.py @@ -6,10 +6,11 @@ import collections import contextlib -import functools +import dataclasses import inspect import itertools import re +from typing import Callable import astropy.units as u import chandra_maneuver @@ -22,15 +23,23 @@ from kadi import commands -# Registry of Transition classes with state transition name as key. A state +# Registry of Transition classes that can updated a state key. A given state key # transition may be generated by several different transition classes, hence the -# dict value is a list +# dict value is a list. Note that some state keys are coupled. For example for OTG +# states ["letg", "hetg", "grating"] are always all updated together, so if you ask for +# "letg" the code computes all three. PCAD_STATE_KEYS are another example. +# +# >>> kadi.commands.states.TRANSITIONS["hetg"] +# [kadi.commands.states.HETG_INSR_Transition, # directly impacts hetg +# kadi.commands.states.HETG_RETR_Transition, +# kadi.commands.states.LETG_INSR_Transition, # coupled state +# kadi.commands.states.LETG_RETR_Transition] TRANSITIONS = collections.defaultdict(list) # Set of all Transition classes TRANSITION_CLASSES = set() -# Ordered list of all state keys +# Ordered list of all available state keys STATE_KEYS = [] # Quaternion componenent names @@ -78,6 +87,7 @@ "vid_board", ) +# {TLMSID: SI_MODE} mapping for SI modes that have special handling in state transitions NIL_SIMODES = { "WT00DAA014": "H2C_0002", "WT00D96014": "H2C_0001", @@ -102,6 +112,44 @@ def disable_grating_move_duration(): MechMove.apply_move_duration = apply_move_duration +class Transition(dict): + """ + Dict of transitions at a given date. + + This is a dict of {state_key: state_val} where state_val is either a value + or a TransitionCallback object. + + When creating a TransitionCallback object, you can optionally supply two additional + positional arguments: the date (str) and the constraint (function). + + The constraint function must take a single argument, the state dict, and return + True if the transition is allowed and False otherwise. For example:: + + def constraint(state): + return state["pcad_mode"] == "NMAN" + + transition = Transition("2010:001:00:00:00", constraint=constraint) + """ + + date: str | None = None + constraint: Callable | None = None + + def __init__(self, *args, **kwargs): + if args and isinstance(args[0], str): + self.date = args[0] + args = args[1:] + if args and callable(args[0]): + self.constraint = args[0] + args = args[1:] + super().__init__(*args, **kwargs) + + def __repr__(self): + return f"{self.date}: {super().__repr__()}" + + def __eq__(self, other): + return self.date == other.date and super().__eq__(other) + + class NoTransitionsError(ValueError): """No transitions found within commands""" @@ -136,22 +184,29 @@ def copy(self): return StateDict(self) +@dataclasses.dataclass +class TransitionCallback: + """ + Callback function for a transition. + + This is used to store the callback function and any additional keyword arguments. + """ + + callback: callable + kwargs: dict = dataclasses.field(default_factory=dict) + + ################################################################### # Transition base classes ################################################################### -class TransitionMeta(type): +class BaseTransition: """ - Metaclass that adds the class to the TRANSITIONS dict. - - This is keyed by state_keys from the TRANSITIONS_CLASSES set, and makes the - complete list of STATE_KEYS. + Base transition class from which all actual transition classes are derived. """ - def __new__(mcls, name, bases, members): # noqa: N804 - cls = super().__new__(mcls, name, bases, members) - + def __init_subclass__(cls) -> None: # Register transition classes that have a `state_keys` (base classes do # not have this attribute set). if hasattr(cls, "state_keys"): @@ -164,14 +219,6 @@ def __new__(mcls, name, bases, members): # noqa: N804 cls._auto_update_docstring() - return cls - - -class BaseTransition(metaclass=TransitionMeta): - """ - Base transition class from which all actual transition classes are derived. - """ - @classmethod def get_state_changing_commands(cls, cmds): """ @@ -290,14 +337,14 @@ class FixedTransition(BaseTransition): """ @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -319,9 +366,7 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): attrs = [attrs] for cmd in state_cmds: - date = cmd["date"] - for val, attr in zip(vals, attrs): - transitions_dict[date][attr] = val + transitions_list.append(Transition(cmd["date"], zip(attrs, vals))) class ParamTransition(BaseTransition): @@ -337,14 +382,14 @@ class ParamTransition(BaseTransition): """ @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -375,8 +420,10 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): else: params = dict(rev_pars_dict[cmd["idx"]]) - for name, param_key in zip(names, param_keys): - transitions_dict[date][name] = params[param_key] + transition = Transition( + date, zip(names, [params[key] for key in param_keys]) + ) + transitions_list.append(Transition(date, transition)) ################################################################### @@ -502,14 +549,14 @@ class MechMove(FixedTransition): apply_move_duration = True @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -534,17 +581,25 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): for cmd in state_cmds: date_start = CxoTime(cmd["date"]) date_stop = date_start + move_duration + transition_start = Transition(date_start.date) + transition_stop = Transition(date_stop.date) + for val, attr in zip(vals, attrs): if attr == "grating": - transitions_dict[date_start.date][attr] = val + transition_start[attr] = val else: # noqa: PLR5501 # 'letg' or 'hetg' insert/retract status, include the move # interval here if cls.apply_move_duration: - transitions_dict[date_start.date][attr] = val + "_MOVE" - transitions_dict[date_stop.date][attr] = val + transition_start[attr] = val + "_MOVE" + transition_stop[attr] = val else: - transitions_dict[date_start.date][attr] = val + transition_start[attr] = val + + if transition_start: + transitions_list.append(transition_start) + if transition_stop: + transitions_list.append(transition_stop) class HETG_INSR_Transition(MechMove): @@ -626,6 +681,7 @@ class Hrc15vOn_SCS134_Transition(FixedTransition): state_keys = ["hrc_15v"] transition_key = "hrc_15v" transition_val = "ON" + default_value = "ON" class Hrc15vOff_Transition(FixedTransition): @@ -644,6 +700,7 @@ class Hrc24vOn_Transition(FixedTransition): state_keys = ["hrc_24v"] transition_key = "hrc_24v" transition_val = "ON" + default_value = "ON" class Hrc24vOff_Transition(FixedTransition): @@ -673,6 +730,7 @@ class HrcIOff_Transition(FixedTransition): state_keys = ["hrc_i"] transition_key = "hrc_i" transition_val = "OFF" + default_value = "OFF" class HrcSOn_Transition(FixedTransition): @@ -693,6 +751,7 @@ class HrcSOff_Transition(FixedTransition): state_keys = ["hrc_s"] transition_key = "hrc_s" transition_val = "OFF" + default_value = "OFF" ################################################################### @@ -796,14 +855,14 @@ class SPMEclipseEnableTransition(BaseTransition): default_value = False @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -819,15 +878,16 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transitions_dict[cmd["date"]]["sun_pos_mon"] = cls.callback + transitions_list.append( + Transition(cmd["date"], sun_pos_mon=TransitionCallback(cls.callback)) + ) @classmethod def callback(cls, date, transitions, state, idx): if state["eclipse_enable_spm"]: - transition = { - "date": secs2date(date2secs(date) + 11 * 60), - "sun_pos_mon": "ENAB", - } + transition = Transition( + secs2date(date2secs(date) + 11 * 60), {"sun_pos_mon": "ENAB"} + ) add_transition(transitions, idx, transition) @@ -850,14 +910,14 @@ class EclipseEnableSPM(BaseTransition): BATTERY_CONNECT_MAX_DT = 135 # seconds @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -873,7 +933,11 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transitions_dict[cmd["date"]]["eclipse_enable_spm"] = cls.callback + transitions_list.append( + Transition( + cmd["date"], eclipse_enable_spm=TransitionCallback(cls.callback) + ) + ) @classmethod def callback(cls, date, transitions, state, idx): @@ -891,7 +955,7 @@ def callback(cls, date, transitions, state, idx): enable_spm = ( eclipse_entry_time - battery_connect_time < cls.BATTERY_CONNECT_MAX_DT ) - transition = {"date": date, "eclipse_enable_spm": enable_spm} + transition = Transition(date, eclipse_enable_spm=enable_spm) add_transition(transitions, idx, transition) @@ -904,14 +968,14 @@ class BatteryConnect(BaseTransition): default_value = "1999:001:00:00:00.000" @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -926,7 +990,9 @@ def set_transitions(cls, transitions, cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transitions[cmd["date"]]["battery_connect"] = cmd["date"] + transitions_list.append( + Transition(cmd["date"], battery_connect=cmd["date"]) + ) class SCS84EnableTransition(FixedTransition): @@ -1033,14 +1099,14 @@ class EphemerisUpdateTransition(BaseTransition): state_keys = ["ephem_update"] @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1056,7 +1122,7 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): for cmd in state_cmds: date = cmd["date"] - transitions_dict[date]["ephem_update"] = date[:8] + transitions_list.append(Transition(date, ephem_update=date[:8])) ################################################################### @@ -1074,14 +1140,14 @@ class SunVectorTransition(BaseTransition): state_keys = PCAD_STATE_KEYS @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1105,7 +1171,12 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): # Now with the dates, finally make all the transition dicts which will # call `update_pitch_state` during state processing. for date in dates: - transitions_dict[date]["update_sun_vector"] = cls.update_sun_vector_state + transitions_list.append( + Transition( + date, + update_sun_vector=TransitionCallback(cls.update_sun_vector_state), + ) + ) @classmethod def update_sun_vector_state(cls, date, transitions, state, idx): @@ -1171,14 +1242,14 @@ class DitherParamsTransition(BaseTransition): ] @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1201,7 +1272,7 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): "dither_period_pitch": 2 * np.pi / cmd["ratep"], "dither_period_yaw": 2 * np.pi / cmd["ratey"], } - transitions_dict[cmd["date"]].update(dither) + transitions_list.append(Transition(cmd["date"], dither)) class NMM_Transition(FixedTransition): @@ -1247,14 +1318,14 @@ class TargQuatTransition(BaseTransition): state_keys = PCAD_STATE_KEYS @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1269,9 +1340,10 @@ def set_transitions(cls, transitions, cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transition = transitions[cmd["date"]] - for qc in ("q1", "q2", "q3", "q4"): - transition["targ_" + qc] = cmd[qc] + transition = Transition( + cmd["date"], {f"targ_{qc}": cmd[qc] for qc in QUAT_COMPS} + ) + transitions_list.append(transition) class ManeuverTransition(BaseTransition): @@ -1286,13 +1358,18 @@ class ManeuverTransition(BaseTransition): command_attributes = {"tlmsid": "AOMANUVR"} state_keys = PCAD_STATE_KEYS + pcad_mode = "NMAN" @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transitions[cmd["date"]]["maneuver_transition"] = cls.callback + transitions_list.append( + Transition( + cmd["date"], maneuver_transition=TransitionCallback(cls.callback) + ) + ) @classmethod def callback(cls, date, transitions, state, idx): @@ -1309,10 +1386,17 @@ def callback(cls, date, transitions, state, idx): if end_manvr_date is None: return - # If auto-transition to NPM after manvr is enabled (this is - # normally the case) then back to NPNT at end of maneuver + # If auto-transition to NPM after manvr is enabled (this is normally the case) + # then back to NPNT at end of maneuver if state["auto_npnt"] == "ENAB": - transition = {"date": end_manvr_date, "pcad_mode": "NPNT"} + # Auto-transition to NPNT only works from NMAN. The lambda constraint will + # prevent transition to NPNT if NSM or Safe mode transition occurs during + # maneuver. + transition = Transition( + end_manvr_date, + lambda state_: state_["pcad_mode"] == "NMAN", + pcad_mode="NPNT", + ) add_transition(transitions, idx, transition) @classmethod @@ -1366,7 +1450,12 @@ def add_manvr_transitions(cls, date, transitions, state, idx): for att, date_att, pitch, off_nom_roll in zip( atts, dates, pitches, off_nom_rolls ): - transition = {"date": date_att} + # Check pcad_mode at time of each maneuver transition is same as at start. + # This cuts off maneuver for a mode change like NSM or Safe mode. Note that + # `state_` is the future state and `state` is the current state. + transition = Transition( + date_att, lambda state_: state_["pcad_mode"] == state["pcad_mode"] + ) att_q = np.array([att[x] for x in QUAT_COMPS]) for qc, q_i in zip(QUAT_COMPS, att_q): transition[qc] = q_i @@ -1396,7 +1485,7 @@ class NormalSunTransition(ManeuverTransition): state_keys = PCAD_STATE_KEYS @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) # Set the maneuver transition for each state-changing command. The AONSMSAF or @@ -1404,17 +1493,19 @@ def set_transitions(cls, transitions, cmds, start, stop): # value of the params for the NSM or Safe Mode command events. If not provided # then use 90 degrees. - # The functools.partial is a little odd in that it passes the class directly. - # Ideally we would use functools.partial of a @classmethod but that doesn't seem - # to work. for cmd in state_cmds: pitch = cmd["params"].get("pitch", 90) - transitions[cmd["date"]]["maneuver_transition"] = functools.partial( - cls.callback, cls, pitch + transitions_list.append( + Transition( + cmd["date"], + maneuver_transition=TransitionCallback( + cls.callback, {"pitch": pitch} + ), + ) ) - @staticmethod - def callback(cls, pitch, date, transitions, state, idx): # noqa: PLW0211 + @classmethod + def callback(cls, date, transitions, state, idx, *, pitch): """ This is a transition function callback. @@ -1466,16 +1557,21 @@ class ManeuverSunPitchTransition(ManeuverTransition): state_keys = PCAD_STATE_KEYS @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transitions[cmd["date"]]["maneuver_transition"] = functools.partial( - cls.callback, cmd["params"]["pitch"] + transitions_list.append( + Transition( + cmd["date"], + maneuver_transition=TransitionCallback( + cls.callback, {"pitch": cmd["params"]["pitch"]} + ), + ) ) - @staticmethod - def callback(pitch, date, transitions, state, idx): + @classmethod + def callback(cls, date, transitions, state, idx, *, pitch): # Setup for maneuver to sun-pointed attitude from current att curr_att = [state[qc] for qc in QUAT_COMPS] @@ -1511,16 +1607,21 @@ class ManeuverSunRaslTransition(ManeuverTransition): state_keys = PCAD_STATE_KEYS @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): state_cmds = cls.get_state_changing_commands(cmds) for cmd in state_cmds: - transitions[cmd["date"]]["maneuver_transition"] = functools.partial( - cls.callback, cmd["params"]["rasl"] + transitions_list.append( + Transition( + cmd["date"], + maneuver_transition=TransitionCallback( + cls.callback, {"rasl": cmd["params"]["rasl"]} + ), + ) ) - @staticmethod - def callback(rasl, date, transitions, state, idx): + @classmethod + def callback(cls, date, transitions, state, idx, *, rasl): # Setup for maneuver to sun-pointed attitude from current att curr_att = [state[qc] for qc in QUAT_COMPS] @@ -1625,14 +1726,14 @@ class ACISTransition(BaseTransition): ] @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1648,10 +1749,11 @@ def set_transitions(cls, transitions, cmds, start, stop): for cmd in state_cmds: tlmsid = cmd["tlmsid"] date = cmd["date"] + transition = Transition(date) if tlmsid.startswith("WSPOW"): pwr = decode_power(tlmsid) - transitions[date].update( + transition.update( fep_count=pwr["fep_count"], ccd_count=pwr["ccd_count"], vid_board=pwr["vid_board"], @@ -1660,24 +1762,26 @@ def set_transitions(cls, transitions, cmds, start, stop): ) elif tlmsid in ("XCZ0000005", "XTZ0000005"): - transitions[date].update(clocking=1, power_cmd=tlmsid) + transition.update(clocking=1, power_cmd=tlmsid) elif tlmsid == "WSVIDALLDN": - transitions[date].update(vid_board=0, ccd_count=0, power_cmd=tlmsid) + transition.update(vid_board=0, ccd_count=0, power_cmd=tlmsid) elif tlmsid == "AA00000000": - transitions[date].update(clocking=0, power_cmd=tlmsid) + transition.update(clocking=0, power_cmd=tlmsid) elif tlmsid == "WSFEPALLUP": - transitions[date].update(fep_count=6, power_cmd=tlmsid) + transition.update(fep_count=6, power_cmd=tlmsid) elif tlmsid[:2] in ("WT", "WC"): - transitions[cmd["date"]]["si_mode"] = functools.partial( - ACISTransition.simode_callback, tlmsid + transition["si_mode"] = TransitionCallback( + cls.simode_callback, {"tlmsid": tlmsid} ) - @staticmethod - def simode_callback(tlmsid, date, transitions, state, idx): + transitions_list.append(transition) + + @classmethod + def simode_callback(cls, date, transitions, state, idx, *, tlmsid): # Other SIMODEs than the ones caught here exist in the # ACIS tables, and may be used in execptional circumstances # such as anomalies or special tests. The ones that are @@ -1743,14 +1847,14 @@ class ACISFP_SetPointTransition(BaseTransition): default_value = -121.0 @classmethod - def set_transitions(cls, transitions, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1771,7 +1875,9 @@ def set_transitions(cls, transitions, cmds, start, stop): match = re.search(r"(\d+)$", tlmsid) if not match: raise ValueError(f"unable to parse command {tlmsid}") - transitions[date].update(acisfp_setpoint=-float(match.group(1))) + transitions_list.append( + Transition(date, acisfp_setpoint=-float(match.group(1))) + ) class FidsTransition(BaseTransition): @@ -1832,14 +1938,14 @@ class FidsTransition(BaseTransition): state_keys = ["fids"] @classmethod - def set_transitions(cls, transitions_dict, cmds, start, stop): + def set_transitions(cls, transitions_list: list[Transition], cmds, start, stop): """ Set transitions for a Table of commands ``cmds``. Parameters ---------- - transitions_dict - global dict of transitions (updated in-place) + transitions_list + list of transitions (updated in-place) cmds commands (CmdList) start @@ -1856,12 +1962,15 @@ def set_transitions(cls, transitions_dict, cmds, start, stop): for cmd in state_cmds: msid = cmd["msid"] if msid == "AFLCRSET" or re.match(r"AFLC \d\d D \d $", msid, re.VERBOSE): - transitions_dict[cmd["date"]]["fids"] = functools.partial( - FidsTransition.fids_callback, msid + transitions_list.append( + Transition( + cmd["date"], + fids=TransitionCallback(cls.fids_callback, {"msid": msid}), + ) ) - @staticmethod - def fids_callback(msid, date, transitions, state, idx): + @classmethod + def fids_callback(cls, date, transitions, state, idx, *, msid): """Update ``state`` for the given fid light command ``msid``.""" if msid == "AFLCRSET": state["fids"] = set() @@ -1901,7 +2010,9 @@ def get_transition_classes(state_keys=None): return trans_classes -def get_transitions_list(cmds, state_keys, start, stop, continuity=None): +def get_transitions_list( + cmds, state_keys, start, stop, continuity=None +) -> list[Transition]: """ Get transitions for given set of ``cmds`` and ``state_keys``. @@ -1922,11 +2033,8 @@ def get_transitions_list(cmds, state_keys, start, stop, continuity=None): ------- list of dict (transitions), set of transition classes """ - # To start, collect transitions in a dict keyed by date. This auto-initializes - # a dict whenever a new date is used, allowing (e.g.) a single step of:: - # - # transitions_dict['2017:002:01:02:03.456']['obsid'] = 23456. - transitions_dict = collections.defaultdict(dict) + # To start, collect a list transitions of dated transitions. + transitions_list: list[Transition] = [] # If an initial list of transitions is provided in the continuity dict # then apply those. These would be transitions that occur after the the @@ -1934,29 +2042,25 @@ def get_transitions_list(cmds, state_keys, start, stop, continuity=None): # where we need the remaining attitude and pcad_mode transitions. if continuity is not None and "__transitions__" in continuity: for transition in continuity["__transitions__"]: - transitions_dict[transition["date"]].update(transition) + transitions_list.append(transition) # Iterate through Transition classes which depend on or affect ``state_keys`` - # and ask each one to update ``transitions_dict`` in-place to include + # and ask each one to update ``transitions_list`` in-place to include # transitions from that class. for transition_class in get_transition_classes(state_keys): - transition_class.set_transitions(transitions_dict, cmds, start, stop) + transition_class.set_transitions(transitions_list, cmds, start, stop) # Convert the dict of transitions (keyed by date) into an ordered list of # transitions sorted by date. A *list* of transitions is needed to allow a # transition to dynamically generate additional (later) transitions, e.g. in # the case of a maneuver. - transitions_list = [] - for date in sorted(transitions_dict): - transition = transitions_dict[date] - transition["date"] = date - transitions_list.append(transition) + transitions_list = sorted(transitions_list, key=lambda x: x.date) # In the rest of this module ``transitions`` is always this *list* of transitions. return transitions_list -def add_transition(transitions, idx, transition): +def add_transition(transitions: list[Transition], idx: int, transition: Transition): """ Add ``transition`` to the ``transitions`` list. @@ -1981,8 +2085,8 @@ def add_transition(transitions, idx, transition): """ # Prevent adding command before current command since the command # interpreter is a one-pass process. - date = transition["date"] - if date < transitions[idx]["date"]: + date = transition.date + if date < transitions[idx].date: raise ValueError("cannot insert transition prior to current command") # Insert transition at first place where new transition date is strictly @@ -1990,7 +2094,7 @@ def add_transition(transitions, idx, transition): # could be improved, though in practice transitions are often inserted # close to the original. for ii in range(idx + 1, len(transitions)): - if date < transitions[ii]["date"]: + if date < transitions[ii].date: transitions.insert(ii, transition) break else: @@ -2132,7 +2236,13 @@ def get_states( continuity_transitions = [] for idx, transition in enumerate(transitions): - date = transition["date"] + # If there are state constraints then check that the state satisfies the + # constraints. If not then skip this transition. Canonical example is + # ManeuverTransition which requires that the state be in NMAN. + if transition.constraint is not None and not transition.constraint(state): + continue + + date = transition.date # Some transition classes (e.g. Maneuver) might put in transitions that # extend past the stop time. Add to a list for potential use in continuity. @@ -2151,13 +2261,14 @@ def get_states( # Process the transition. for key, value in transition.items(): - if callable(value): + if isinstance(value, TransitionCallback): # Special case of a functional transition that calls a function # instead of directly updating the state. The function might itself # update the state or it might generate downstream transitions. - value(date, transitions, state, idx) - elif key != "date": + value.callback(date, transitions, state, idx, **value.kwargs) + else: # Normal case of just updating current state + state[key] = value # Make into an astropy Table and set up datestart/stop columns diff --git a/kadi/commands/tests/test_commands.py b/kadi/commands/tests/test_commands.py index 6a424de7..21000290 100644 --- a/kadi/commands/tests/test_commands.py +++ b/kadi/commands/tests/test_commands.py @@ -332,6 +332,69 @@ def stop_date_fixture(monkeypatch): # 2021:297 0300z just after recovery maneuver following 2021:296 NSM stop_date_2021_10_24 = stop_date_fixture_factory("2021-10-24T03:00:00") stop_date_2020_12_03 = stop_date_fixture_factory("2020-12-03") +stop_date_2023_203 = stop_date_fixture_factory("2023:203") + + +@pytest.mark.skipif(not HAS_INTERNET, reason="No internet connection") +def test_nsm_safe_mode_pitch_offsets_state_constraints(stop_date_2023_203): + """Test NSM, Safe mode transitions with pitch offsets along with state constraints. + + State constraints testing means that an NMAN maneuver with auto-NPNT transition + is properly interrupted by a NSM or Safe Mode. In this case the NMAN is interrupted + and NPNT never happens. + """ + scenario = "test-nsm-safe-mode" + cmd_events_path = kadi.paths.CMD_EVENTS_PATH(scenario) + cmd_events_path.parent.mkdir(parents=True, exist_ok=True) + # Maneuver attitude: ska_sun.get_att_for_sun_pitch_yaw(pitch=170, yaw=0, time="2023:199") + text = """ + State,Date,Event,Params,Author,Reviewer,Comment + Definitive,2023:199:00:00:00.000,Safe mode,120,,, + Definitive,2023:199:02:00:00.000,NSM,,,, + Definitive,2023:199:03:00:00.000,Maneuver,-0.84752928 0.52176697 0.08279618 0.05097206,,, + Definitive,2023:199:03:17:00.000,NSM,50,,, + Definitive,2023:199:03:30:00.000,Safe mode,,,, + Definitive,2023:199:04:30:00.000,NSM,100,,, + """ + cmd_events_path.write_text(text) + states = kcs.get_states( + "2023:198:23:00:00", + "2023:202:00:00:00", + state_keys=["pitch", "pcad_mode"], + scenario=scenario, + ) + states["pitch"].info.format = ".1f" + out = states["datestart", "pitch", "pcad_mode"].pformat_all() + exp = [ + " datestart pitch pcad_mode", + "--------------------- ----- ---------", + "2023:198:23:00:00.000 144.6 NPNT", + "2023:199:00:00:00.000 142.1 STBY", # Safe mode to 120 + "2023:199:00:05:16.679 132.3 STBY", + "2023:199:00:10:33.359 122.6 STBY", + "2023:199:00:15:50.038 120.0 STBY", + "2023:199:02:00:00.000 116.9 NSUN", # NSM to default 90 + "2023:199:02:05:47.530 105.0 NSUN", + "2023:199:02:11:35.060 93.2 NSUN", + "2023:199:02:17:22.590 90.0 NSUN", + "2023:199:03:00:00.000 90.0 NMAN", # Maneuver to 170 + "2023:199:03:00:10.250 90.9 NMAN", + "2023:199:03:05:20.245 95.4 NMAN", + "2023:199:03:10:30.240 105.2 NMAN", + "2023:199:03:15:40.235 118.7 NMAN", + "2023:199:03:17:00.000 109.2 NSUN", # NMAN interrupted by NSM to 50 + "2023:199:03:22:05.282 99.2 NSUN", + "2023:199:03:27:10.564 80.6 NSUN", + "2023:199:03:30:00.000 91.3 STBY", # NSM to 50 interrupted by Safe mode + "2023:199:03:34:18.814 90.7 STBY", # to default 90 + "2023:199:03:38:37.628 90.2 STBY", + "2023:199:03:42:56.441 90.0 STBY", + "2023:199:04:30:00.000 92.5 NSUN", # NSM to 100 + "2023:199:04:35:13.882 97.5 NSUN", + "2023:199:04:40:27.763 100.0 NSUN", + ] + + assert out == exp @pytest.mark.skipif(not HAS_INTERNET, reason="No internet connection") diff --git a/kadi/commands/tests/test_states.py b/kadi/commands/tests/test_states.py index 593d6aa3..761059a3 100644 --- a/kadi/commands/tests/test_states.py +++ b/kadi/commands/tests/test_states.py @@ -1499,7 +1499,9 @@ def test_continuity_with_transitions_SPM(): # noqa: N802 cont = states.get_continuity(start, state_keys=["sun_pos_mon"]) assert cont == { "__dates__": {"sun_pos_mon": "2017:087:07:44:55.838"}, - "__transitions__": [{"date": "2017:087:08:21:35.838", "sun_pos_mon": "ENAB"}], + "__transitions__": [ + states.Transition("2017:087:08:21:35.838", {"sun_pos_mon": "ENAB"}) + ], "sun_pos_mon": "DISA", } diff --git a/validate/states-regression/states-regression.ipynb b/validate/states-regression/states-regression.ipynb new file mode 100644 index 00000000..e36f8c31 --- /dev/null +++ b/validate/states-regression/states-regression.ipynb @@ -0,0 +1,211 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Commanded States regression testing 2020-2024.\n", + "\n", + "This notebook generates a regression states file in a branch (e.g. master or a test\n", + "branch) for each of the available state keys.\n", + "\n", + "The intended usage is to run this notebook with `master` checked out, then with the \n", + "`test-branch` checked out, and then compare.\n", + "\n", + "Outputs are written to `validation/states-regression/`. \n", + "\n", + "You can compare all outputs using:\n", + "```\n", + "diff -r validation/states-regression/{master,test-branch}\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.home() / \"git\" / \"kadi\"))\n", + "\n", + "# Prior to ska3 2024.11.\n", + "sys.path.insert(0, str(Path.home() / \"git\" / \"ska_sun\"))\n", + "sys.path.insert(0, str(Path.home() / \"git\" / \"ska_helpers\"))\n", + "\n", + "import subprocess\n", + "\n", + "import numpy as np\n", + "\n", + "import kadi.commands.states as kcs" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "git_branch = subprocess.getoutput([\"git branch --show-current\"])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "START = \"2020:001:00:00:00\"\n", + "STOP = \"2024:300:00:00:00\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# REMOVE this after PR #340 is merged. This PR fixes an issue with these state keys\n", + "# prior to the new HRC ops con implemented in 2023.\n", + "continuity = {\"hrc_24v\": \"ON\", \"hrc_15v\": \"ON\"}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "states = kcs.get_states(START, STOP, state_keys=kcs.STATE_KEYS, continuity=continuity)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "outdir = Path(git_branch)\n", + "outdir.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Writing format\n", + "Writing subformat\n", + "Writing letg\n", + "Writing hetg\n", + "Writing grating\n", + "Writing simpos\n", + "Writing simfa_pos\n", + "Writing hrc_15v\n", + "Writing hrc_24v\n", + "Writing hrc_i\n", + "Writing hrc_s\n", + "Writing obsid\n", + "Writing eclipse_timer\n", + "Writing eclipse\n", + "Writing sun_pos_mon\n", + "Writing battery_connect\n", + "Writing eclipse_enable_spm\n", + "Writing scs84\n", + "Writing scs98\n", + "Writing radmon\n", + "Writing orbit_point\n", + "Writing aoephem1\n", + "Writing aoephem2\n", + "Writing aoratio\n", + "Writing aoargper\n", + "Writing aoeccent\n", + "Writing ao1minus\n", + "Writing ao1plus\n", + "Writing aomotion\n", + "Writing aoiterat\n", + "Writing aoorbang\n", + "Writing aoperige\n", + "Writing aoascend\n", + "Writing aosini\n", + "Writing aoslr\n", + "Writing aosqrtmu\n", + "Writing ephem_update\n", + "Writing q1\n", + "Writing q2\n", + "Writing q3\n", + "Writing q4\n", + "Writing targ_q1\n", + "Writing targ_q2\n", + "Writing targ_q3\n", + "Writing targ_q4\n", + "Writing ra\n", + "Writing dec\n", + "Writing roll\n", + "Writing auto_npnt\n", + "Writing pcad_mode\n", + "Writing pitch\n", + "Writing off_nom_roll\n", + "Writing dither\n", + "Writing dither_phase_pitch\n", + "Writing dither_phase_yaw\n", + "Writing dither_ampl_pitch\n", + "Writing dither_ampl_yaw\n", + "Writing dither_period_pitch\n", + "Writing dither_period_yaw\n", + "Writing clocking\n", + "Writing power_cmd\n", + "Writing vid_board\n", + "Writing fep_count\n", + "Writing si_mode\n", + "Writing ccd_count\n", + "Writing acisfp_setpoint\n", + "Writing fids\n" + ] + } + ], + "source": [ + "for state_key in kcs.STATE_KEYS:\n", + " print(f\"Writing {state_key}\")\n", + " states_for_key = kcs.reduce_states(\n", + " states, state_keys=[state_key], merge_identical=True\n", + " )\n", + " cols = [\"datestart\", state_key]\n", + " if states_for_key[state_key].dtype.kind == \"O\":\n", + " states_for_key[state_key] = [str(value) for value in states_for_key[state_key]]\n", + " if states_for_key[state_key].dtype.kind == \"f\":\n", + " states_for_key[state_key] = np.round(states_for_key[state_key], 6)\n", + " states_for_key[cols].write(\n", + " outdir / f\"{state_key}.dat\", format=\"ascii\", overwrite=True\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}