Skip to content

Commit

Permalink
Use threads instead of processes in test_pocs.py (#468)
Browse files Browse the repository at this point in the history
Run pocs in a thread instead of in a separate process to make it easier
to share resources between the test thread and pocs. This was motivated
by PanMemoryDB, but I think it will help in general.

Add a Timeout class to pocs.utils and use in test_pocs to force the
ending of long running tests.

Shutdown pocs even if message not received; use a try-finally block
to ensure we send the shutdown message.

Use hardware module to get the list of simulator names, minus 'weather'.

Wait for pocs thread to stop even if an assertion has failed, to avoid
having two running at once (i.e. if the next test starts while the
previous test's thread is still running).
  • Loading branch information
jamessynge authored Feb 11, 2018
1 parent 82d3c9c commit 42d62a5
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 34 deletions.
83 changes: 50 additions & 33 deletions pocs/tests/test_pocs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,37 @@
import os
import pytest
import time

from multiprocessing import Process
import threading

from astropy import units as u

from pocs import hardware
from pocs.core import POCS
from pocs.observatory import Observatory
from pocs.utils import Timeout
from pocs.utils.messaging import PanMessaging


def wait_for_running(sub, max_duration=90):
"""Given a message subscriber, wait for a RUNNING message."""
timeout = Timeout(max_duration)
while not timeout.expired():
msg_type, msg_obj = sub.receive_message()
if msg_obj and 'RUNNING' == msg_obj.get('message'):
return True
return False


def wait_for_state(sub, state, max_duration=90):
"""Given a message subscriber, wait for the specified state."""
timeout = Timeout(max_duration)
while not timeout.expired():
msg_type, msg_obj = sub.receive_message()
if msg_type == 'STATUS' and msg_obj and msg_obj.get('state') == state:
return True
return False


@pytest.fixture(scope='function')
def observatory(config, db_type):
observatory = Observatory(
Expand Down Expand Up @@ -199,7 +219,9 @@ def test_run_wait_until_safe(observatory):
observatory.db.clear_current('weather')

def start_pocs():
observatory.config['simulator'] = ['camera', 'mount', 'night']
observatory.logger.info('start_pocs ENTER')
# Remove weather simulator, else it would always be safe.
observatory.config['simulator'] = hardware.get_all_names(without=['weather'])

pocs = POCS(observatory,
messaging=True, safe_delay=5)
Expand All @@ -220,32 +242,28 @@ def start_pocs():
pocs.run(run_once=True, exit_when_done=True)
assert pocs.is_weather_safe() is True
pocs.power_down()
observatory.logger.info('start_pocs EXIT')

pub = PanMessaging.create_publisher(6500)
sub = PanMessaging.create_subscriber(6511)

pocs_process = Process(target=start_pocs)
pocs_process.start()
pocs_thread = threading.Thread(target=start_pocs)
pocs_thread.start()

# Wait for the running message
while True:
msg_type, msg_obj = sub.receive_message()
if msg_obj is None:
continue
try:
# Wait for the RUNNING message,
assert wait_for_running(sub)

if msg_obj.get('message', '') == 'RUNNING':
time.sleep(2)
# Insert a dummy weather record to break wait
observatory.db.insert_current('weather', {'safe': True})
time.sleep(2)
# Insert a dummy weather record to break wait
observatory.db.insert_current('weather', {'safe': True})

if msg_type == 'STATUS':
current_state = msg_obj.get('state', {})
if current_state == 'pointing':
pub.send_message('POCS-CMD', 'shutdown')
break
assert wait_for_state(sub, 'scheduling')
finally:
pub.send_message('POCS-CMD', 'shutdown')
pocs_thread.join(timeout=30)

pocs_process.join()
assert pocs_process.is_alive() is False
assert pocs_thread.is_alive() is False


def test_unsafe_park(pocs):
Expand Down Expand Up @@ -336,6 +354,7 @@ def test_run_complete(pocs):

def test_run_power_down_interrupt(observatory):
def start_pocs():
observatory.logger.info('start_pocs ENTER')
pocs = POCS(observatory, messaging=True)
pocs.initialize()
pocs.observatory.scheduler.clear_available_observations()
Expand All @@ -349,23 +368,21 @@ def start_pocs():
pocs.logger.info('Starting observatory run')
pocs.run()
pocs.power_down()
observatory.logger.info('start_pocs EXIT')

pocs_process = Process(target=start_pocs)
pocs_process.start()
pocs_thread = threading.Thread(target=start_pocs)
pocs_thread.start()

pub = PanMessaging.create_publisher(6500)
sub = PanMessaging.create_subscriber(6511)

while True:
msg_type, msg_obj = sub.receive_message()
if msg_type == 'STATUS':
current_state = msg_obj.get('state', {})
if current_state == 'pointing':
pub.send_message('POCS-CMD', 'shutdown')
break

pocs_process.join()
assert pocs_process.is_alive() is False
try:
assert wait_for_state(sub, 'scheduling')
finally:
pub.send_message('POCS-CMD', 'shutdown')
pocs_thread.join(timeout=30)

assert pocs_thread.is_alive() is False


def test_pocs_park_to_ready_with_observations(pocs):
Expand Down
38 changes: 37 additions & 1 deletion pocs/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
import shutil
import subprocess

import time

from astropy import units as u
from astropy.coordinates import AltAz
Expand Down Expand Up @@ -48,6 +48,42 @@ def flatten_time(t):
return t.isot.replace('-', '').replace(':', '').split('.')[0]


# This is a streamlined variant of PySerial's serialutil.Timeout.
class Timeout(object):
"""Simple timer object for tracking whether a time duration has elapsed.
Attribute `is_non_blocking` is true IFF the duration is zero.
"""

def __init__(self, duration):
"""Initialize a timeout with given duration (seconds)."""
assert duration >= 0
self.is_non_blocking = (duration == 0)
self.duration = duration
self.restart()

def expired(self):
"""Return a boolean, telling if the timeout has expired."""
return self.time_left() <= 0

def time_left(self):
"""Return how many seconds are left until the timeout expires."""
if self.is_non_blocking:
return 0
else:
delta = self.target_time - time.monotonic()
if delta > self.duration:
# clock jumped, recalculate
self.restart()
return self.duration
else:
return max(0, delta)

def restart(self):
"""Restart the timed duration."""
self.target_time = time.monotonic() + self.duration


def listify(obj):
""" Given an object, return a list
Expand Down

0 comments on commit 42d62a5

Please sign in to comment.