Skip to content
This repository has been archived by the owner on Sep 2, 2024. It is now read-only.

MXDGA-3724: Added the MVP for the fast grid scan motion #2

Merged
merged 7 commits into from
Jan 14, 2022
21 changes: 21 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
},
{
"name": "Debug Unit Test",
"type": "python",
"request": "test",
"justMyCode": false,
},
]
}
8 changes: 8 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"python.testing.pytestArgs": [
"."
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.formatting.provider": "black"
}
Empty file added src/__init__.py
Empty file.
Empty file added src/artemis/devices/__init__.py
Empty file.
144 changes: 144 additions & 0 deletions src/artemis/devices/fast_grid_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import threading
import time
from typing import List
from ophyd import Component, Device, EpicsSignal, EpicsSignalRO, EpicsSignalWithRBV
from ophyd.status import DeviceStatus, StatusBase, SubscriptionStatus


class GridScanCompleteStatus(DeviceStatus):
"""
A Status for the grid scan completion
A special status object that notifies watchers (progress bars)
based on comparing device.expected_images to device.position_counter.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.start_ts = time.time()

self.device.position_counter.subscribe(self._notify_watchers)
self.device.status.subscribe(self._running_changed)

self._name = self.device.name
self._target_count = self.device.expected_images

def _notify_watchers(self, value, *args, **kwargs):
if not self._watchers:
return
time_elapsed = time.time() - self.start_ts
try:
fraction = value / self._target_count
callumforrester marked this conversation as resolved.
Show resolved Hide resolved
except ZeroDivisionError:
fraction = 1
time_remaining = 0
except Exception:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is some unknown exception I think we should pass it upwards with self.set_exception().

fraction = None
time_remaining = None
else:
time_remaining = time_elapsed / fraction
for watcher in self._watchers:
watcher(
name=self._name,
current=value,
initial=0,
target=self._target_count,
unit="images",
precision=0,
fraction=fraction,
time_elapsed=time_elapsed,
time_remaining=time_remaining,
)

def _running_changed(self, value=None, old_value=None, **kwargs):
if (old_value == 1) and (value == 0):
# Stopped running
number_of_images = self.device.position_counter.get()
if number_of_images != self._target_count:
self.set_exception(
Exception(
f"Grid scan finished without collecting expected number of images. Expected {self._target_count} got {number_of_images}."
)
)
else:
self.set_finished()
self.clean_up()

def clean_up(self):
self.device.position_counter.clear_sub(self._notify_watchers)
self.device.status.clear_sub(self._running_changed)


class FastGridScan(Device):

x_steps: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "X_NUM_STEPS")
y_steps: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Y_NUM_STEPS")
z_steps: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Z_NUM_STEPS")

x_step_size: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "X_STEP_SIZE")
y_step_size: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Y_STEP_SIZE")
z_step_size: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Z_STEP_SIZE")

dwell_time: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "DWELL_TIME")

x_start: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "X_START")
y1_start: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Y_START")
y2_start: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Y2_START")
z1_start: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Z_START")
z2_start: EpicsSignalWithRBV = Component(EpicsSignalWithRBV, "Z2_START")

position_counter: EpicsSignal = Component(
EpicsSignal, "POS_COUNTER", write_pv="POS_COUNTER_WRITE"
)
x_counter: EpicsSignalRO = Component(EpicsSignalRO, "X_COUNTER")
y_counter: EpicsSignalRO = Component(EpicsSignalRO, "Y_COUNTER")
scan_invalid: EpicsSignalRO = Component(EpicsSignalRO, "SCAN_INVALID")

run_cmd: EpicsSignal = Component(EpicsSignal, "RUN.PROC")
stop_cmd: EpicsSignal = Component(EpicsSignal, "STOP.PROC")
status: EpicsSignalRO = Component(EpicsSignalRO, "SCAN_STATUS")

# Kickoff timeout in seconds
KICKOFF_TIMEOUT: float = 5.0

def set_program_data(self, nx, ny, width, height, exptime, startx, starty, startz):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kinda feel like more structured API here would help? set_program_data call in the test just looks like pushing a bunch of random numbers

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yh, I left this in as a bit of a placeholder until I had a better feel for how we would actually be using the Device upstream but I'll think about improving it now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I borrowed some of the work of @callumforrester for this, which also included some checking against motor limits.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't spot this chain, I think this needs updating to reflect this thread: #2 (comment)

self.x_steps.put(nx)
self.y_steps.put(ny)
self.x_step_size.put(float(width))
self.y_step_size.put(float(height))
self.dwell_time.put(float(exptime))
self.x_start.put(float(startx))
self.y1_start.put(float(starty))
self.z1_start.put(float(startz))
self.expected_images = nx * ny

def is_invalid(self):
callumforrester marked this conversation as resolved.
Show resolved Hide resolved
if "GONP" in self.scan_invalid.pvname:
return False
return self.scan_invalid.get()

def kickoff(self) -> StatusBase:
# Check running already here?
st = DeviceStatus(device=self, timeout=self.KICKOFF_TIMEOUT)

def check_valid_and_scan():
try:
self.log.info("Waiting on position counter reset and valid settings")
while self.is_invalid() or not self.position_counter.get() == 0:
time.sleep(0.1)
self.log.debug("Running scan")
running = SubscriptionStatus(self.status, lambda value: value == 1)
run_requested = self.run_cmd.set(1)
(run_requested and running).wait()
st.set_finished()
except Exception as e:
st.set_exception(e)

threading.Thread(target=check_valid_and_scan, daemon=True).start()
return st

def stage(self) -> List[object]:
self.position_counter.put(0)
return super().stage()

def complete(self) -> DeviceStatus:
return GridScanCompleteStatus(self)
Empty file.
16 changes: 16 additions & 0 deletions src/artemis/devices/system_tests/test_gridscan_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from src.artemis.devices.fast_grid_scan import FastGridScan


@pytest.fixture()
def fast_grid_scan():
fast_grid_scan = FastGridScan(name="fast_grid_scan", prefix="BL03S-MO-SGON-01:FGS:")
yield fast_grid_scan


@pytest.mark.s03
def test_set_program_data_and_kickoff(fast_grid_scan: FastGridScan):
fast_grid_scan.set_program_data(2, 2, 0.1, 0.1, 1, 0, 0, 0)
callumforrester marked this conversation as resolved.
Show resolved Hide resolved
kickoff_status = fast_grid_scan.kickoff()
kickoff_status.wait()
Empty file.
134 changes: 134 additions & 0 deletions src/artemis/devices/unit_tests/test_gridscan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from ophyd.sim import make_fake_device
from src.artemis.devices.fast_grid_scan import FastGridScan, time

from mockito import *
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I prefer to import what you use rather than from thing import * - so for example when I wonder what when() is below, I can see where it comes from

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, I wouldn't use import * in the actual production code so I shouldn't in the test code.

from mockito.matchers import *
import pytest


@pytest.fixture
def fast_grid_scan():
FakeFastGridScan = make_fake_device(FastGridScan)
fast_grid_scan: FastGridScan = FakeFastGridScan(name="test")
fast_grid_scan.scan_invalid.pvname = ""

# A bit of a hack to assume that if we are waiting on something then we will timeout
when(time).sleep(ANY).thenRaise(TimeoutError())
return fast_grid_scan


def test_given_invalid_scan_when_kickoff_then_timeout(fast_grid_scan: FastGridScan):
when(fast_grid_scan.scan_invalid).get().thenReturn(True)
when(fast_grid_scan.position_counter).get().thenReturn(0)

status = fast_grid_scan.kickoff()

with pytest.raises(TimeoutError):
status.wait()


def test_given_image_counter_not_reset_when_kickoff_then_timeout(
fast_grid_scan: FastGridScan,
):
when(fast_grid_scan.scan_invalid).get().thenReturn(False)
when(fast_grid_scan.position_counter).get().thenReturn(10)

status = fast_grid_scan.kickoff()

with pytest.raises(TimeoutError):
status.wait()


def test_given_settings_valid_when_kickoff_then_run_started(
fast_grid_scan: FastGridScan,
):
when(fast_grid_scan.scan_invalid).get().thenReturn(False)
when(fast_grid_scan.position_counter).get().thenReturn(0)

mock_run_set_status = mock()
when(fast_grid_scan.run_cmd).set(ANY).thenReturn(mock_run_set_status)
fast_grid_scan.status.subscribe = lambda func, **kwargs: func(1)

status = fast_grid_scan.kickoff()

status.wait()

verify(fast_grid_scan.run_cmd).set(1)
assert status.exception() == None


def run_test_on_complete_watcher(fast_grid_scan, num_pos_1d, put_value, expected_frac):
fast_grid_scan.set_program_data(
num_pos_1d, num_pos_1d, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
)

complete_status = fast_grid_scan.complete()
watcher = mock()
complete_status.watch(watcher)

fast_grid_scan.position_counter.sim_put(put_value)
verify(watcher).__call__(
*ARGS,
current=put_value,
target=num_pos_1d ** 2,
fraction=expected_frac,
**KWARGS
)


def test_when_new_image_then_complete_watcher_notified(fast_grid_scan: FastGridScan):
run_test_on_complete_watcher(fast_grid_scan, 2, 1, 1 / 4)


def test_given_0_expected_images_then_complete_watcher_correct(
fast_grid_scan: FastGridScan,
):
run_test_on_complete_watcher(fast_grid_scan, 0, 1, 1)


def test_given_invalid_image_number_then_complete_watcher_correct(
fast_grid_scan: FastGridScan,
):
run_test_on_complete_watcher(fast_grid_scan, 1, "BAD", None)


def test_running_finished_with_not_all_images_done_then_complete_status_in_error(
fast_grid_scan: FastGridScan,
):
num_pos_1d = 2
fast_grid_scan.set_program_data(
num_pos_1d, num_pos_1d, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
)

fast_grid_scan.status.sim_put(1)

complete_status = fast_grid_scan.complete()
assert not complete_status.done
fast_grid_scan.status.sim_put(0)

with pytest.raises(Exception):
complete_status.wait()

assert complete_status.done
assert complete_status.exception() != None


def test_running_finished_with_all_images_done_then_complete_status_finishes_not_in_error(
fast_grid_scan: FastGridScan,
):
num_pos_1d = 2
fast_grid_scan.set_program_data(
num_pos_1d, num_pos_1d, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
)

fast_grid_scan.status.sim_put(1)

complete_status = fast_grid_scan.complete()
assert not complete_status.done
fast_grid_scan.position_counter.sim_put(num_pos_1d ** 2)
fast_grid_scan.status.sim_put(0)

complete_status.wait()

assert complete_status.done
assert complete_status.exception() == None