Skip to content

Commit

Permalink
Merge pull request #304 from sot/improve-off-nom-roll-state-accuracy
Browse files Browse the repository at this point in the history
Update code and tests for use with ska_sun accurate position
  • Loading branch information
taldcroft authored Dec 5, 2023
2 parents 682d658 + 0a0ca2e commit e99d43e
Show file tree
Hide file tree
Showing 9 changed files with 6,295 additions and 50 deletions.
18 changes: 11 additions & 7 deletions kadi/commands/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from astropy.table import Column, Table
from chandra_time import DateTime, date2secs, secs2date
from cxotime import CxoTime
from Quaternion import Quat, quat_to_equatorial
from Quaternion import quat_to_equatorial

from . import commands
from kadi import commands

# Registry of Transition classes with state transition name as key. A state
# transition may be generated by several different transition classes, hence the
Expand Down Expand Up @@ -1085,9 +1085,9 @@ def set_transitions(cls, transitions_dict, cmds, start, stop):
@classmethod
def update_sun_vector_state(cls, date, transitions, state, idx):
"""
Transition callback method for ``pitch`` / ``off_nominal_roll`` states.
Transition callback method for ``pitch`` / ``off_nom_roll`` states.
This will potentially update the ``pitch`` and ``off_nominal`` states if
This will potentially update the ``pitch`` and ``off_nom_roll`` states if
pcad_mode is NPNT.
Parameters
Expand All @@ -1102,9 +1102,13 @@ def update_sun_vector_state(cls, date, transitions, state, idx):
current index into transitions
"""
if state["pcad_mode"] == "NPNT":
q_att = Quat([state[qc] for qc in QUAT_COMPS])
state["pitch"] = ska_sun.pitch(q_att.ra, q_att.dec, date)
state["off_nom_roll"] = ska_sun.off_nominal_roll(q_att, date)
ra, dec, roll = state["ra"], state["dec"], state["roll"]
time = date2secs(date)
sun_ra, sun_dec = ska_sun.position(time)
state["pitch"] = ska_sun.pitch(ra, dec, sun_ra=sun_ra, sun_dec=sun_dec)
state["off_nom_roll"] = ska_sun.off_nominal_roll(
[ra, dec, roll], time, sun_ra=sun_ra, sun_dec=sun_dec
)


class DitherEnableTransition(FixedTransition):
Expand Down
7 changes: 7 additions & 0 deletions kadi/commands/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import pytest
import ska_sun


@pytest.fixture()
def fast_sun_position_method(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(ska_sun.conf, "sun_position_method_default", "fast")
6 changes: 3 additions & 3 deletions kadi/commands/tests/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def test_get_cmds_from_backstop_and_add_cmds(version_env):


@pytest.mark.skipif("not HAS_MPDIR")
def test_commands_create_archive_regress(tmpdir, version_env):
def test_commands_create_archive_regress(tmpdir, version_env, fast_sun_position_method):
"""Create cmds archive from scratch and test that it matches flight
This tests over an eventful month that includes IU reset/NSM, SCS-107
Expand Down Expand Up @@ -1134,7 +1134,7 @@ def test_get_cmds_from_event_all(idx):


@pytest.mark.skipif(not HAS_INTERNET, reason="No internet connection")
def test_scenario_with_rts(monkeypatch):
def test_scenario_with_rts(monkeypatch, fast_sun_position_method):
# Test a custom scenario with RTS. This is basically the same as the
# example in the documentation.
from kadi import paths
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def test_scenario_with_rts(monkeypatch):
2021:297:01:41:01.256 | COMMAND_SW | AONM2NPE | CMD_EVT | event=Maneuver, event_date=2021:297:01:41:01, msid=AONM2NPE,
2021:297:01:41:05.356 | MP_TARGQUAT | AOUPTARQ | CMD_EVT | event=Maneuver, event_date=2021:297:01:41:01, q1=7.05469070e
2021:297:01:41:11.250 | COMMAND_SW | AOMANUVR | CMD_EVT | event=Maneuver, event_date=2021:297:01:41:01, msid=AOMANUVR,
2021:297:02:05:11.042 | LOAD_EVENT | OBS | CMD_EVT | manvr_start=2021:297:01:41:11.250, prev_att=(0.2854059718219
2021:297:02:05:11.042 | LOAD_EVENT | OBS | CMD_EVT | manvr_start=2021:297:01:41:11.250, prev_att=(0.2854059718181
2021:297:02:12:42.886 | ORBPOINT | None | OCT1821A | event_type=EQF003M, scs=0
2021:297:03:40:42.886 | ORBPOINT | None | OCT1821A | event_type=EQF005M, scs=0
2021:297:03:40:42.886 | ORBPOINT | None | OCT1821A | event_type=EQF015M, scs=0
Expand Down
12 changes: 6 additions & 6 deletions kadi/commands/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def test_acis_raw_mode():
assert "TN_000B6" in kstates["si_mode"]


def test_states_2017():
def test_states_2017(fast_sun_position_method):
"""
Test for 200 days in 2017. Includes 2017:066, 068, 090 anomalies and
2017:250-254 SCS107 + 251 CTI.
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_states_2017():
assert np.all(np.abs(tk - tc) < 0.0015)


def test_pitch_2017():
def test_pitch_2017(fast_sun_position_method):
"""
Test pitch for 100 days in 2017. Includes 2017:066, 068, 090 anomalies. This is done
by interpolating states (at 200 second intervals) because the pitch generation differs
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_dither():
)


def test_get_continuity_regress():
def test_get_continuity_regress(fast_sun_position_method):
"""Regression test against values produced by get_continuity during development.
Correctness not validated for all values.
The particular time of 2018:001:12:00:00 happens during a maneuver, so this
Expand Down Expand Up @@ -670,7 +670,7 @@ def cmd_states_fetch_states(*args, **kwargs):
return cs


def test_reduce_states_cmd_states():
def test_reduce_states_cmd_states(fast_sun_position_method):
"""
Test that simple get_states() call with defaults gives the same results
as calling cmd_states.fetch_states().
Expand Down Expand Up @@ -1478,7 +1478,7 @@ def test_continuity_with_no_transitions_SPM(): # noqa: N802
}


def test_get_pitch_from_mid_maneuver():
def test_get_pitch_from_mid_maneuver(fast_sun_position_method):
"""Regression test for the fix for #125. Mostly the same as the test above, but for
the Maneuver transition class.
Expand Down Expand Up @@ -1551,7 +1551,7 @@ def test_get_states_start_between_aouptarg_aomanuvr_cmds():
assert cont["__dates__"]["q1"] == "2021:032:12:49:45.458"


def test_get_continuity_and_pitch_from_mid_maneuver():
def test_get_continuity_and_pitch_from_mid_maneuver(fast_sun_position_method):
"""Test for bug in continuity first noted at:
https://github.com/acisops/acis_thermal_check/pull/30#issuecomment-665240053
Expand Down
16 changes: 14 additions & 2 deletions kadi/commands/tests/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@

import numpy as np
import pytest
import ska_sun

from kadi.commands.utils import compress_time_series
from kadi.commands.validate import Validate
from kadi.commands.validate import Validate, ValidateRoll

# Regression testing for this 5-day period covering a safe mode with plenty of things
# happening. There are a number of violations in this period and a couple of excluded
Expand Down Expand Up @@ -108,7 +109,7 @@ def test_validate_subclasses():

@pytest.mark.parametrize("cls", Validate.subclasses)
@pytest.mark.parametrize("no_exclude", [False, True])
def test_validate_regression(cls, no_exclude):
def test_validate_regression(cls, no_exclude, fast_sun_position_method):
"""Test that validator data matches regression data
This is likely to be fragile. In the future we may need helper function to output
Expand All @@ -134,6 +135,17 @@ def test_validate_regression(cls, no_exclude):
assert np.all(data_obs["violations"] == data_exp["violations"])


def test_off_nominal_roll_violations():
"""Test off_nominal_roll violations over a time range with tail sun observations"""
# Default sun position method is "accurate".
off_nom_roll_val = ValidateRoll(stop="2023:327:00:00:00", days=1)
assert len(off_nom_roll_val.violations) == 0

with ska_sun.conf.set_temp("sun_position_method_default", "fast"):
off_nom_roll2 = ValidateRoll(stop="2023:327:00:00:00", days=1)
assert len(off_nom_roll2.violations) == 3


if __name__ == "__main__":
write_regression_data(REGRESSION_STOP, REGRESSION_DAYS, no_exclude=False)
write_regression_data(REGRESSION_STOP, REGRESSION_DAYS, no_exclude=True)
78 changes: 49 additions & 29 deletions kadi/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,23 @@

@dataclass
class PlotAttrs:
"""Plot attributes for a Validate subclass.
:param title: (str): Plot title.
:param ylabel: (str): Y-axis label.
:param range: (list): Y-axis range (optional).
:param max_delta_time: (float): Maximum time delta before new data point is plotted.
:param max_delta_val: (float): Maximum value delta before new data point is plotted.
:param max_gap_time: (float): Maximum gap in time before plot gap is inserted.
"""
Plot attributes for a Validate subclass.
Parameters
----------
title : str
Plot title.
ylabel : str
Y-axis label.
range : list, optional
Y-axis range.
max_delta_time : float, optional
Maximum time delta before a new data point is plotted.
max_delta_val : float, default 0
Maximum value delta before a new data point is plotted.
max_gap_time : float, default 300
Maximum gap in time before a plot gap is inserted.
"""

title: str
Expand All @@ -89,15 +98,36 @@ class PlotAttrs:
class Validate(ABC):
"""Validate kadi command states against telemetry base class.
:param state_name: (str): Name of state to validate.
:param stop: (CxoTime): Stop time.
:param days: (float): Number of days to validate.
:param state_keys_extra: (list): Extra state keys needed for validation.
:param plot_attrs: (PlotAttrs): Attributes for plot.
:param msids: (list): MSIDs to fetch for telemetry.
:param max_delta_val: (float): Maximum value delta to signal a violation.
:param max_gap: (float): Maximum gap in telemetry before breaking an interval (sec).
:param min_violation_duration: (float): Minimum duration of a violation (sec).
Class attributes are as follows:
state_name : str
Name of state to validate.
stop : CxoTime
Stop time.
days : float
Number of days to validate.
state_keys_extra : list, optional
Extra state keys needed for validation.
plot_attrs : PlotAttrs
Attributes for plot.
msids : list
MSIDs to fetch for telemetry.
max_delta_val : float
Maximum value delta to signal a violation.
max_gap : float
Maximum gap in telemetry before breaking an interval (sec).
min_violation_duration : float
Minimum duration of a violation (sec).
Parameters
----------
stop
stop time for validation
days
number of days for validation
no_exclude
if True then do not exclude any data (for testing)
"""

subclasses = []
Expand All @@ -114,17 +144,7 @@ class Validate(ABC):
min_violation_duration = 32.81

def __init__(self, stop=None, days: float = 14, no_exclude: bool = False):
"""Base class for validation.
Parameters
----------
stop
stop time for validation
days
number of days for validation
no_exclude
if True then do not exclude any data (for testing)
"""
"""Base class for validation"""
self.stop = CxoTime(stop)
self.days = days
self.start: CxoTime = self.stop - days * u.day
Expand Down Expand Up @@ -585,7 +605,7 @@ class ValidateRoll(ValidatePitchRollBase):
max_delta_val=0.5, # deg
)
max_delta_vals = {
"NPNT": 4, # deg
"NPNT": 2, # deg
"NMAN": 10.0, # deg
"NSUN": 4.0, # deg
}
Expand Down
6 changes: 3 additions & 3 deletions ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,10 @@ extend-ignore = [
"PYI056", # Calling `.append()` on `__all__` may not be supported by all type checkers
]

exclude = [
extend-exclude = [
"docs",
"utils",
"validate",
".eggs",
]

[pycodestyle]
Expand All @@ -67,4 +66,5 @@ max-line-length = 100 # E501 reports lines that exceed the length of 100.
"__init__.py" = ["E402", "F401", "F403"]
"command_sets.py" = ["ARG001"]
"**/tests/**" = ["D", "E501"]
"states.py" = ["N801", "ARG003"]
"states.py" = ["N801", "ARG003"]
"**/*.ipynb" = ["B018"]
18 changes: 18 additions & 0 deletions validate/performance_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import time

import kadi
from kadi.commands import get_cmds
from kadi.commands.states import get_states

print(f"{kadi.__version__=}")

start, stop = "2021:001", "2022:001"
cmds = get_cmds(start, stop, scenario="flight")

t0 = time.time()
states = get_states(start, stop, scenario="flight")
print(f"get_states took {time.time() - t0:.1f} sec")

t0 = time.time()
states = get_states(start, stop, scenario="flight")
print(f"2nd get_states took {time.time() - t0:.1f} sec")
Loading

0 comments on commit e99d43e

Please sign in to comment.