diff --git a/pocs/tests/test_pocs.py b/pocs/tests/test_pocs.py index 81f66376e..d603eeda0 100644 --- a/pocs/tests/test_pocs.py +++ b/pocs/tests/test_pocs.py @@ -1,17 +1,37 @@ import os import pytest import time - -from multiprocessing import Process +import threading from astropy import units as u from pocs import hardware from pocs.core import POCS from pocs.observatory import Observatory +from pocs.utils import Timeout from pocs.utils.messaging import PanMessaging +def wait_for_running(sub, max_duration=90): + """Given a message subscriber, wait for a RUNNING message.""" + timeout = Timeout(max_duration) + while not timeout.expired(): + msg_type, msg_obj = sub.receive_message() + if msg_obj and 'RUNNING' == msg_obj.get('message'): + return True + return False + + +def wait_for_state(sub, state, max_duration=90): + """Given a message subscriber, wait for the specified state.""" + timeout = Timeout(max_duration) + while not timeout.expired(): + msg_type, msg_obj = sub.receive_message() + if msg_type == 'STATUS' and msg_obj and msg_obj.get('state') == state: + return True + return False + + @pytest.fixture(scope='function') def observatory(config, db_type): observatory = Observatory( @@ -199,7 +219,9 @@ def test_run_wait_until_safe(observatory): observatory.db.clear_current('weather') def start_pocs(): - observatory.config['simulator'] = ['camera', 'mount', 'night'] + observatory.logger.info('start_pocs ENTER') + # Remove weather simulator, else it would always be safe. + observatory.config['simulator'] = hardware.get_all_names(without=['weather']) pocs = POCS(observatory, messaging=True, safe_delay=5) @@ -220,32 +242,28 @@ def start_pocs(): pocs.run(run_once=True, exit_when_done=True) assert pocs.is_weather_safe() is True pocs.power_down() + observatory.logger.info('start_pocs EXIT') pub = PanMessaging.create_publisher(6500) sub = PanMessaging.create_subscriber(6511) - pocs_process = Process(target=start_pocs) - pocs_process.start() + pocs_thread = threading.Thread(target=start_pocs) + pocs_thread.start() - # Wait for the running message - while True: - msg_type, msg_obj = sub.receive_message() - if msg_obj is None: - continue + try: + # Wait for the RUNNING message, + assert wait_for_running(sub) - if msg_obj.get('message', '') == 'RUNNING': - time.sleep(2) - # Insert a dummy weather record to break wait - observatory.db.insert_current('weather', {'safe': True}) + time.sleep(2) + # Insert a dummy weather record to break wait + observatory.db.insert_current('weather', {'safe': True}) - if msg_type == 'STATUS': - current_state = msg_obj.get('state', {}) - if current_state == 'pointing': - pub.send_message('POCS-CMD', 'shutdown') - break + assert wait_for_state(sub, 'scheduling') + finally: + pub.send_message('POCS-CMD', 'shutdown') + pocs_thread.join(timeout=30) - pocs_process.join() - assert pocs_process.is_alive() is False + assert pocs_thread.is_alive() is False def test_unsafe_park(pocs): @@ -336,6 +354,7 @@ def test_run_complete(pocs): def test_run_power_down_interrupt(observatory): def start_pocs(): + observatory.logger.info('start_pocs ENTER') pocs = POCS(observatory, messaging=True) pocs.initialize() pocs.observatory.scheduler.clear_available_observations() @@ -349,23 +368,21 @@ def start_pocs(): pocs.logger.info('Starting observatory run') pocs.run() pocs.power_down() + observatory.logger.info('start_pocs EXIT') - pocs_process = Process(target=start_pocs) - pocs_process.start() + pocs_thread = threading.Thread(target=start_pocs) + pocs_thread.start() pub = PanMessaging.create_publisher(6500) sub = PanMessaging.create_subscriber(6511) - while True: - msg_type, msg_obj = sub.receive_message() - if msg_type == 'STATUS': - current_state = msg_obj.get('state', {}) - if current_state == 'pointing': - pub.send_message('POCS-CMD', 'shutdown') - break - - pocs_process.join() - assert pocs_process.is_alive() is False + try: + assert wait_for_state(sub, 'scheduling') + finally: + pub.send_message('POCS-CMD', 'shutdown') + pocs_thread.join(timeout=30) + + assert pocs_thread.is_alive() is False def test_pocs_park_to_ready_with_observations(pocs): diff --git a/pocs/utils/__init__.py b/pocs/utils/__init__.py index 4e91b6f40..ae79757cc 100644 --- a/pocs/utils/__init__.py +++ b/pocs/utils/__init__.py @@ -2,7 +2,7 @@ import re import shutil import subprocess - +import time from astropy import units as u from astropy.coordinates import AltAz @@ -48,6 +48,42 @@ def flatten_time(t): return t.isot.replace('-', '').replace(':', '').split('.')[0] +# This is a streamlined variant of PySerial's serialutil.Timeout. +class Timeout(object): + """Simple timer object for tracking whether a time duration has elapsed. + + Attribute `is_non_blocking` is true IFF the duration is zero. + """ + + def __init__(self, duration): + """Initialize a timeout with given duration (seconds).""" + assert duration >= 0 + self.is_non_blocking = (duration == 0) + self.duration = duration + self.restart() + + def expired(self): + """Return a boolean, telling if the timeout has expired.""" + return self.time_left() <= 0 + + def time_left(self): + """Return how many seconds are left until the timeout expires.""" + if self.is_non_blocking: + return 0 + else: + delta = self.target_time - time.monotonic() + if delta > self.duration: + # clock jumped, recalculate + self.restart() + return self.duration + else: + return max(0, delta) + + def restart(self): + """Restart the timed duration.""" + self.target_time = time.monotonic() + self.duration + + def listify(obj): """ Given an object, return a list