From fc40a21523ed18d21c347bd447b369d8bbd549e3 Mon Sep 17 00:00:00 2001 From: Max Marrone Date: Thu, 9 Jun 2022 15:02:56 -0400 Subject: [PATCH] fix(update-server): Keep name deconflicted with other devices on the network (#10559) Fixes #10126. --- update-server/Pipfile | 4 +- update-server/Pipfile.lock | 10 +- update-server/mypy.ini | 1 + update-server/otupdate/buildroot/__init__.py | 58 +-- update-server/otupdate/buildroot/__main__.py | 45 +- update-server/otupdate/buildroot/constants.py | 3 - update-server/otupdate/common/constants.py | 3 - update-server/otupdate/common/control.py | 5 +- .../common/name_management/__init__.py | 32 +- .../otupdate/common/name_management/avahi.py | 400 +++++++++++++----- .../name_management/name_synchronizer.py | 151 +++++++ .../common/name_management/pretty_hostname.py | 2 +- .../otupdate/common/run_application.py | 37 ++ .../otupdate/openembedded/__init__.py | 56 +-- .../otupdate/openembedded/__main__.py | 37 +- update-server/pytest.ini | 2 +- update-server/tests/buildroot/conftest.py | 11 +- update-server/tests/buildroot/test_control.py | 13 +- .../tests/buildroot/test_name_endpoints.py | 51 --- update-server/tests/common/conftest.py | 7 +- .../common/name_management/test_avahi.py | 64 +++ .../name_management/test_http_endpoints.py | 35 ++ .../name_management/test_name_synchronizer.py | 228 ++++++++++ update-server/tests/conftest.py | 10 + update-server/tests/openembedded/conftest.py | 7 +- .../tests/openembedded/test_control.py | 13 +- 26 files changed, 985 insertions(+), 300 deletions(-) create mode 100644 update-server/otupdate/common/name_management/name_synchronizer.py create mode 100644 update-server/otupdate/common/run_application.py delete mode 100644 update-server/tests/buildroot/test_name_endpoints.py create mode 100644 update-server/tests/common/name_management/test_avahi.py create mode 100644 update-server/tests/common/name_management/test_http_endpoints.py create mode 100644 update-server/tests/common/name_management/test_name_synchronizer.py diff --git a/update-server/Pipfile b/update-server/Pipfile index e7d01f0fc7c..50ff797de6c 100644 --- a/update-server/Pipfile +++ b/update-server/Pipfile @@ -18,12 +18,14 @@ pytest-watch = "~=4.2.0" pytest-cov = "==2.10.1" pytest-aiohttp = "==0.3.0" coverage = "==5.1" -# pytest dependencies on windows, spec'd here to force lockfile inclusion +# atomicwrites and colorama are pytest dependencies on windows, +# spec'd here to force lockfile inclusion # https://github.com/pypa/pipenv/issues/4408#issuecomment-668324177 atomicwrites = {version="==1.4.0", sys_platform="== 'win32'"} colorama = {version="==0.4.4", sys_platform="== 'win32'"} mypy = "==0.940" black = "==22.3.0" +decoy = "~=1.10" [requires] python_version = "3.7" diff --git a/update-server/Pipfile.lock b/update-server/Pipfile.lock index 26b52e105aa..cced41f3d19 100644 --- a/update-server/Pipfile.lock +++ b/update-server/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "c1f36f5bd77121139aa590708ed72923fd199df3b0c32569a9870e0d8c8162ca" + "sha256": "62e7113800829936d95ee7f0b54803dfb95e2040c6f1b6b84a4c0b3876d2d6c8" }, "pipfile-spec": 6, "requires": { @@ -328,6 +328,14 @@ "index": "pypi", "version": "==5.1" }, + "decoy": { + "hashes": [ + "sha256:6d2088d95a4c020cc546e36f551145aa765bf54c73779d818b9e9cc9e970723e", + "sha256:8c5a960f35976f1365aa859355214b2cfb7a7fbc2064c7d32a5156d051e8f2f1" + ], + "index": "pypi", + "version": "==1.11.0" + }, "docopt": { "hashes": [ "sha256:49b3a825280bd66b3aa83585ef59c4a8c82f2c8a522dbe754a8bc8d08c85c491" diff --git a/update-server/mypy.ini b/update-server/mypy.ini index 4dc4bd7f78e..33f0a4990f3 100644 --- a/update-server/mypy.ini +++ b/update-server/mypy.ini @@ -1,6 +1,7 @@ [mypy] strict = True show_error_codes = True +plugins = decoy.mypy # The dbus and systemd packages will not be installed in non-Linux dev environments. diff --git a/update-server/otupdate/buildroot/__init__.py b/update-server/otupdate/buildroot/__init__.py index 3e30681f4e7..b562d4e4ff0 100644 --- a/update-server/otupdate/buildroot/__init__.py +++ b/update-server/otupdate/buildroot/__init__.py @@ -3,14 +3,15 @@ import logging import json from typing import Any, Mapping, Optional + from aiohttp import web from otupdate.common import ( config, + constants, control, - ssh_key_management, name_management, - constants, + ssh_key_management, update, ) from . import update_actions @@ -38,11 +39,11 @@ def get_version(version_file: str) -> Mapping[str, str]: def get_app( + name_synchronizer: name_management.NameSynchronizer, system_version_file: Optional[str] = None, config_file_override: Optional[str] = None, name_override: Optional[str] = None, boot_id_override: Optional[str] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, ) -> web.Application: """Build and return the aiohttp.web.Application that runs the server @@ -52,40 +53,17 @@ def get_app( system_version_file = BR_BUILTIN_VERSION_FILE version = get_version(system_version_file) - name = name_override or name_management.get_pretty_hostname() boot_id = boot_id_override or control.get_boot_id() config_obj = config.load(config_file_override) - LOG.info( - "Setup: " - + "\n\t".join( - [ - f"Device name: {name}", - "Buildroot version: " - f'{version.get("buildroot_version", "unknown")}', - "\t(from git sha " f'{version.get("buildroot_sha", "unknown")}', - "API version: " - f'{version.get("opentrons_api_version", "unknown")}', - "\t(from git sha " - f'{version.get("opentrons_api_sha", "unknown")}', - "Update server version: " - f'{version.get("update_server_version", "unknown")}', - "\t(from git sha " - f'{version.get("update_server_sha", "unknown")}', - "Smoothie firmware version: TODO", - ] - ) - ) - - if not loop: - loop = asyncio.get_event_loop() - app = web.Application(middlewares=[log_error_middleware]) + app[config.CONFIG_VARNAME] = config_obj app[constants.RESTART_LOCK_NAME] = asyncio.Lock() app[constants.DEVICE_BOOT_ID_NAME] = boot_id - app[constants.DEVICE_NAME_VARNAME] = name update_actions.OT2UpdateActions.build_and_insert(app) + name_management.install_name_synchronizer(name_synchronizer, app) + app.router.add_routes( [ web.get( @@ -106,6 +84,28 @@ def get_app( web.get("/server/name", name_management.get_name_endpoint), ] ) + + LOG.info( + "Setup: " + + "\n\t".join( + [ + f"Device name: {name_synchronizer.get_name()}", + "Buildroot version: " + f'{version.get("buildroot_version", "unknown")}', + "\t(from git sha " f'{version.get("buildroot_sha", "unknown")}', + "API version: " + f'{version.get("opentrons_api_version", "unknown")}', + "\t(from git sha " + f'{version.get("opentrons_api_sha", "unknown")}', + "Update server version: " + f'{version.get("update_server_version", "unknown")}', + "\t(from git sha " + f'{version.get("update_server_sha", "unknown")}', + "Smoothie firmware version: TODO", + ] + ) + ) + return app diff --git a/update-server/otupdate/buildroot/__main__.py b/update-server/otupdate/buildroot/__main__.py index 1b4e2d6784b..8d763cf3538 100644 --- a/update-server/otupdate/buildroot/__main__.py +++ b/update-server/otupdate/buildroot/__main__.py @@ -3,45 +3,40 @@ """ import asyncio import logging +from typing import NoReturn -from . import get_app, constants +from . import get_app from otupdate.common import name_management, cli, systemd -from aiohttp import web +from otupdate.common.run_application import run_and_notify_up + LOG = logging.getLogger(__name__) -def main() -> None: +async def main() -> NoReturn: parser = cli.build_root_parser() args = parser.parse_args() - loop = asyncio.get_event_loop() + systemd.configure_logging(getattr(logging, args.log_level.upper())) + # Because this involves restarting Avahi, this must happen early, + # before the NameSynchronizer starts up and connects to Avahi. LOG.info("Setting static hostname") - hostname = loop.run_until_complete(name_management.set_up_static_hostname()) - LOG.info(f"Set static hostname to {hostname}") - - LOG.info("Building buildroot update server") - app = get_app(args.version_file, args.config_file) - - name = app[constants.DEVICE_NAME_VARNAME] # Set by get_app(). - - LOG.info(f"Setting Avahi service name to {name}") - try: - loop.run_until_complete(name_management.set_avahi_service_name(name)) - except Exception: - LOG.exception( - "Error setting the Avahi service name." - " mDNS + DNS-SD advertisement may not work." + static_hostname = await name_management.set_up_static_hostname() + LOG.info(f"Set static hostname to {static_hostname}") + + async with name_management.NameSynchronizer.start() as name_synchronizer: + LOG.info("Building buildroot update server") + app = get_app( + system_version_file=args.version_file, + config_file_override=args.config_file, + name_synchronizer=name_synchronizer, ) - LOG.info(f"Notifying {systemd.SOURCE} that service is up") - systemd.notify_up() - - LOG.info(f"Starting buildroot update server on http://{args.host}:{args.port}") - web.run_app(app, host=args.host, port=args.port) # type: ignore[no-untyped-call] + LOG.info(f"Starting buildroot update server on http://{args.host}:{args.port}") + await run_and_notify_up(app=app, host=args.host, port=args.port) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/update-server/otupdate/buildroot/constants.py b/update-server/otupdate/buildroot/constants.py index ab6bf6b6c88..33783d32769 100644 --- a/update-server/otupdate/buildroot/constants.py +++ b/update-server/otupdate/buildroot/constants.py @@ -6,9 +6,6 @@ RESTART_LOCK_NAME = APP_VARIABLE_PREFIX + "restartlock" #: Name for the asyncio lock in the application dictlike -DEVICE_NAME_VARNAME = APP_VARIABLE_PREFIX + "name" -#: Name for the device - DEVICE_BOOT_ID_NAME = APP_VARIABLE_PREFIX + "boot_id" #: A random string that changes every time the device boots. #: diff --git a/update-server/otupdate/common/constants.py b/update-server/otupdate/common/constants.py index cb5a27ee7d0..8718ae54f72 100644 --- a/update-server/otupdate/common/constants.py +++ b/update-server/otupdate/common/constants.py @@ -6,9 +6,6 @@ RESTART_LOCK_NAME = APP_VARIABLE_PREFIX + "restartlock" #: Name for the asyncio lock in the application dictlike -DEVICE_NAME_VARNAME = APP_VARIABLE_PREFIX + "name" -#: Name for the device - DEVICE_BOOT_ID_NAME = APP_VARIABLE_PREFIX + "boot_id" #: A random string that changes every time the device boots. #: diff --git a/update-server/otupdate/common/control.py b/update-server/otupdate/common/control.py index 29faf7fd8a1..5b31221e722 100644 --- a/update-server/otupdate/common/control.py +++ b/update-server/otupdate/common/control.py @@ -12,7 +12,8 @@ from aiohttp import web -from .constants import RESTART_LOCK_NAME, DEVICE_BOOT_ID_NAME, DEVICE_NAME_VARNAME +from .constants import RESTART_LOCK_NAME, DEVICE_BOOT_ID_NAME +from .name_management import get_name_synchronizer LOG = logging.getLogger(__name__) @@ -41,7 +42,7 @@ async def health(request: web.Request) -> web.Response: { **health_response, **{ - "name": request.app[DEVICE_NAME_VARNAME], + "name": get_name_synchronizer(request).get_name(), "serialNumber": get_serial(), "bootId": request.app[DEVICE_BOOT_ID_NAME], }, diff --git a/update-server/otupdate/common/name_management/__init__.py b/update-server/otupdate/common/name_management/__init__.py index 60e1d7b14bc..0daecc0d573 100644 --- a/update-server/otupdate/common/name_management/__init__.py +++ b/update-server/otupdate/common/name_management/__init__.py @@ -41,30 +41,20 @@ See `set_name_endpoint()`. """ +from __future__ import annotations import json from aiohttp import web -from ..constants import DEVICE_NAME_VARNAME - -from .avahi import set_avahi_service_name -from .pretty_hostname import ( - get_pretty_hostname, - persist_pretty_hostname, +from .name_synchronizer import ( + NameSynchronizer, + install_name_synchronizer, + get_name_synchronizer, ) from .static_hostname import set_up_static_hostname -async def set_name(app: web.Application, new_name: str) -> str: - """See `set_name_endpoint()`.""" - await set_avahi_service_name(new_name) - # Setting the Avahi service name can fail if Avahi doesn't like the new name. - # Persist only after it succeeds, so we don't persist something invalid. - persisted_pretty_hostname = await persist_pretty_hostname(new_name) - return persisted_pretty_hostname - - async def set_name_endpoint(request: web.Request) -> web.Response: """Set the robot's name. @@ -104,9 +94,9 @@ def build_400(msg: str) -> web.Response: if not isinstance(name_to_set, str): return build_400('"name" key is not a string"') - new_name = await set_name(app=request.app, new_name=name_to_set) + name_synchronizer = get_name_synchronizer(request) + new_name = await name_synchronizer.set_name(new_name=name_to_set) - request.app[DEVICE_NAME_VARNAME] = new_name return web.json_response( # type: ignore[no-untyped-call,no-any-return] data={"name": new_name}, status=200 ) @@ -120,15 +110,17 @@ async def get_name_endpoint(request: web.Request) -> web.Response: GET /server/name -> 200 OK, {'name': robot name} """ + name_synchronizer = get_name_synchronizer(request) return web.json_response( # type: ignore[no-untyped-call,no-any-return] - data={"name": request.app[DEVICE_NAME_VARNAME]}, status=200 + data={"name": name_synchronizer.get_name()}, status=200 ) __all__ = [ - "get_pretty_hostname", + "NameSynchronizer", + "install_name_synchronizer", + "get_name_synchronizer", "set_up_static_hostname", - "set_avahi_service_name", "get_name_endpoint", "set_name_endpoint", ] diff --git a/update-server/otupdate/common/name_management/avahi.py b/update-server/otupdate/common/name_management/avahi.py index 677c87ebc66..b2afdd8fd41 100644 --- a/update-server/otupdate/common/name_management/avahi.py +++ b/update-server/otupdate/common/name_management/avahi.py @@ -1,127 +1,40 @@ """Control the Avahi daemon.""" -import asyncio -from logging import getLogger -from typing import Optional - - -_log = getLogger(__name__) +from __future__ import annotations +import asyncio +import contextlib +import logging +import re +from typing import AsyncGenerator, Awaitable, Callable, cast try: import dbus - class DBusState: - """Bundle of state for dbus""" - - def __init__( - self, - bus: dbus.SystemBus, - server: dbus.Interface, - entrygroup: dbus.Interface, - ) -> None: - """ - Build the state bundle. - - :param bus: The system bus instance - :param server: An org.freedesktop.Avahi.Server interface - :param entrygroup: An org.freedesktop.Avahi.EntryGroup interface - """ - self.bus = bus - #: The system bus - self.entrygroup = entrygroup - #: The entry group interface - self.server = server - #: The avahi server interface - - _BUS_STATE: Optional[DBusState] = None - - def _set_avahi_service_name_sync(new_service_name: str) -> None: - """The synchronous implementation of setting the Avahi service name. - - The dbus module doesn't natively support async/await. - """ - # For semantics of the methods we're calling, see Avahi's API docs. - # For example: https://www.avahi.org/doxygen/html/index.html#good_publish - # It's mostly in terms of the C API, but the semantics should be the same. - # - # For exact method names and argument types, see Avahi's D-Bus bindings, - # which they specify across several machine-readable files. For example: - # https://github.com/lathiat/avahi/blob/v0.7/avahi-daemon/org.freedesktop.Avahi.EntryGroup.xml - global _BUS_STATE - if not _BUS_STATE: - bus = dbus.SystemBus() - server_obj = bus.get_object("org.freedesktop.Avahi", "/") - server_if = dbus.Interface(server_obj, "org.freedesktop.Avahi.Server") - entrygroup_path = server_if.EntryGroupNew() - entrygroup_obj = bus.get_object("org.freedesktop.Avahi", entrygroup_path) - entrygroup_if = dbus.Interface( - entrygroup_obj, "org.freedesktop.Avahi.EntryGroup" - ) - _BUS_STATE = DBusState(bus, server_if, entrygroup_if) - - _BUS_STATE.entrygroup.Reset() - - hostname = _BUS_STATE.server.GetHostName() - domainname = _BUS_STATE.server.GetDomainName() - - # TODO(mm, 2022-05-06): This isn't exception-safe. - # Since we've already reset the entrygroup, if this fails - # (for example because Avahi doesn't like the new name), - # we'll be left with no entrygroup and Avahi will stop advertising the machine. - _BUS_STATE.entrygroup.AddService( - dbus.Int32(-1), # avahi.IF_UNSPEC - dbus.Int32(-1), # avahi.PROTO_UNSPEC - dbus.UInt32(0), # flags - new_service_name, # sname - "_http._tcp", # stype - domainname, # sdomain (.local) - f"{hostname}.{domainname}", # shost (hostname.local) - dbus.UInt16(31950), # port - dbus.Array([], signature="ay"), - ) - _BUS_STATE.entrygroup.Commit() - - # TODO(mm, 2022-05-04): Recover from AVAHI_ENTRY_GROUP_COLLISION in case - # this name collides with another device on the network. - # https://github.com/Opentrons/opentrons/issues/10126 - + _DBUS_AVAILABLE = True except ImportError: - _log.exception("Couldn't import dbus, name setting will be nonfunctional") + _DBUS_AVAILABLE = False - def _set_avahi_service_name_sync(new_service_name: str) -> None: - _log.warning("Not setting name, dbus could not be imported") +_COLLISION_POLL_INTERVAL = 5 -_BUS_LOCK = asyncio.Lock() +SERVICE_NAME_MAXIMUM_OCTETS = 63 +"""The maximum length of an Avahi service name. +Measured in UTF-8 octets (bytes) -- not code points or characters! -async def set_avahi_service_name(new_service_name: str) -> None: - """Set the Avahi service name. +This comes from the DNS-SD specification on instance names, +which Avahi service names correspond to, under the hood. +https://datatracker.ietf.org/doc/html/rfc6763#section-4.1.1 +""" - See the `name_management` package docstring for background on the service name - and how it's distinct from other names on the machine. - The new service name will only apply to the current boot. - - Avahi requires a service name. - It will not advertise the system over mDNS + DNS-SD until one is set. - - Since the Avahi service name corresponds to the DNS-SD instance name, - it's a human-readable string of mostly arbitrary Unicode, - at most 63 octets (not 63 code points or 63 characters!) long. - (See: https://datatracker.ietf.org/doc/html/rfc6763#section-4.1.1) - Avahi will raise an exception through this function if it thinks - the new service name is invalid. - """ - async with _BUS_LOCK: - await asyncio.get_event_loop().run_in_executor( - None, _set_avahi_service_name_sync, new_service_name - ) +_log = logging.getLogger(__name__) async def restart_daemon() -> None: + """Restart the system's Avahi daemon and wait for it to come back up.""" proc = await asyncio.create_subprocess_exec( "systemctl", "restart", @@ -137,3 +50,280 @@ async def restart_daemon() -> None: f"stdout: {stdout!r} stderr: {stderr!r}" ) raise RuntimeError("Error restarting avahi") + + +class AvahiClient: + def __init__(self, sync_client: _SyncClient) -> None: + """For internal use by this class only. Use `connect()` instead.""" + self._sync_client = sync_client + self._lock = asyncio.Lock() + + @classmethod + async def connect(cls) -> AvahiClient: + sync_client = await asyncio.get_running_loop().run_in_executor( + executor=None, + func=_SyncClient.connect, + ) + return cls(sync_client=sync_client) + + async def start_advertising(self, service_name: str) -> None: + """Start advertising the machine over mDNS + DNS-SD. + + Since the Avahi service name corresponds to the DNS-SD instance name, + it's a human-readable string of mostly arbitrary Unicode, + at most SERVICE_NAME_MAXIMUM_OCTETS long. + + Avahi will raise an exception through this method if it thinks + the new service name is invalid. + + Avahi will stop advertising the machine when either of these happen: + + * This process dies. + * We have a name collision with another device on the network. + See `listen_for_collisions()`. + + It's safe to call this more than once. If we're already advertising, + the existing service name will be replaced with the new one. + """ + async with self._lock: + await asyncio.get_running_loop().run_in_executor( + None, self._sync_client.start_advertising, service_name + ) + + @contextlib.asynccontextmanager + async def listen_for_collisions( + self, callback: CollisionCallback + ) -> AsyncGenerator[None, None]: + """Be informed of name collisions. + + Background: + The Avahi service name and the static hostname that this client advertises + must be unique on the network. When Avahi detects a collision, it will stop + advertising until we fix the conflict by giving it a different name. + + While this context manager is entered, `callback()` will be called + whenever a collision happens. You should then call `start_advertising()` + with a new name to resume advertising. + + If `callback()` raises an exception, it's logged but otherwise ignored. + """ + # Ideally, instead of polling, we'd listen to the EntryGroup.StateChanged + # signal. But receiving signals through the dbus library requires a 3rd-party + # event loop like GLib's. + background_task = asyncio.create_task( + self._poll_infinitely_for_collision(callback=callback) + ) + + try: + yield + finally: + background_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await background_task + + async def _poll_infinitely_for_collision(self, callback: CollisionCallback) -> None: + while True: + is_collided = await self._is_collided() + + if is_collided: + try: + await callback() + except asyncio.CancelledError: + # If it raises CancelledError, this task is being cancelled. + # Let the CancelledError propagate and break the loop + # so we don't accidentally suppress the cancellation. + raise + except Exception: + # If it raises any other Exception, it's unexpected. + # Log it and keep polling. + _log.exception("Exception while handling Avahi name collision.") + + await asyncio.sleep(_COLLISION_POLL_INTERVAL) + + async def _is_collided(self) -> bool: + async with self._lock: + return await asyncio.get_running_loop().run_in_executor( + None, self._sync_client.is_collided + ) + + +CollisionCallback = Callable[[], Awaitable[None]] + + +# TODO(mm, 2022-06-08): Delegate to Avahi instead of implementing this ourselves +# when it becomes safe for our names to have spaces and number signs. +# See code history on https://github.com/Opentrons/opentrons/pull/10559. +def alternative_service_name(current_service_name: str) -> str: + """Return an alternative to the given Avahi service name. + + This is useful for fixing name collisions with other things on the network. + For example: + + alternative_service_name("Foo") == "FooNum2" + alternative_service_name("FooNum2") == "FooNum3" + + Appending incrementing integers like this, instead of using something like + the machine's serial number, follows a recommendation in the DNS-SD spec: + https://datatracker.ietf.org/doc/html/rfc6763#appendix-D + + We use "Num" as a separator because: + + * Having some separator is good to avoid accidentally "incrementing" names that are + serial numbers; we don't want OT2CEP20200827B10 to become OT2CEP20200827B11. + + * The Opentrons App is temporarily limiting user-input names to alphanumeric + characters while other server-side bugs are being resolved. + So, the names that we generate here should also be alphanumeric. + https://github.com/Opentrons/opentrons/issues/10214 + + * The number sign (#) character in particular breaks the Opentrons App, + so we especially can't use that. + https://github.com/Opentrons/opentrons/issues/10672 + """ + separator = "Num" + + pattern = f"(?P.*){separator}(?P[0-9]+)" + match = re.fullmatch(pattern=pattern, string=current_service_name, flags=re.DOTALL) + if match is None: + base_name = current_service_name + next_sequence_number = 2 + else: + base_name = match["base_name"] + next_sequence_number = int(match["sequence_number"]) + 1 + + new_suffix = f"{separator}{next_sequence_number}" + + octets_needed_for_suffix = len(new_suffix.encode("utf-8")) + octets_available_for_base_name = ( + SERVICE_NAME_MAXIMUM_OCTETS - octets_needed_for_suffix + ) + + if octets_available_for_base_name >= 0: + truncated_base_name = _truncate_by_utf8_octets( + s=base_name, maximum_utf8_octets=octets_available_for_base_name + ) + result = f"{truncated_base_name}{new_suffix}" + else: + # Edge case: The sequence number was already so high, incrementing it added a + # new digit that we don't have room for. Roll back to zero instead. + result = f"{separator}0" + + assert len(result.encode("utf-8")) <= SERVICE_NAME_MAXIMUM_OCTETS + return result + + +def _truncate_by_utf8_octets(s: str, maximum_utf8_octets: int) -> str: + """Truncate a string so it's no longer than the given number of octets when + encoded as UTF-8. + + This takes care to avoid truncating in the middle of a multi-byte code point. + So the output is always valid Unicode. + + However, it may split up multi-code-point sequences + that a human would inituitively see as a single character. + For example, it's possible to encode "é" as two code points: + + 1. U+0061 Latin Small Letter E + 2. U+0306 Combining Acute Accent + + And this may truncate in between them, leaving an unaccented "e" in the output. + (If we care to fix this, we should look into preserving grapheme clusters: + https://unicode.org/reports/tr29/) + """ + encoded = s.encode("utf-8")[:maximum_utf8_octets] + return encoded.decode( + "utf-8", + # Silently discard any trailing junk left over + # from truncating in the middle of a code point. + errors="ignore", + ) + + +class _SyncClient: + """A non-async wrapper around Avahi's D-Bus API. + + Since methods of this class do blocking I/O, they should be offloaded to a thread + and not called directly inside an async event loop. But they're not safe to call + concurrently from multiple threads, so the calls should be serialized with a lock. + """ + + # For general semantics of the methods we're calling on dbus proxies, + # see Avahi's API docs. For example: + # https://www.avahi.org/doxygen/html/index.html#good_publish + # It's mostly in terms of the C API, but the semantics should be the same. + + # For exact method names and argument types, see Avahi's D-Bus bindings, + # which they specify across several XML files. For example: + # https://github.com/lathiat/avahi/blob/v0.7/avahi-daemon/org.freedesktop.Avahi.EntryGroup.xml + + def __init__( + self, + bus: dbus.SystemBus, + server: dbus.Interface, + entry_group: dbus.Interface, + ) -> None: + """For internal use by this class only. Use `connect()` instead. + + Args: + bus: The system bus instance. + server: An org.freedesktop.Avahi.Server interface. + entry_group: An org.freedesktop.Avahi.EntryGroup interface. + """ + self._bus = bus + self._server = server + self._entry_group = entry_group + + @classmethod + def connect(cls) -> _SyncClient: + bus = dbus.SystemBus() + server_obj = bus.get_object("org.freedesktop.Avahi", "/") + server_if = dbus.Interface(server_obj, "org.freedesktop.Avahi.Server") + entry_group_path = server_if.EntryGroupNew() + entry_group_obj = bus.get_object("org.freedesktop.Avahi", entry_group_path) + entry_group_if = dbus.Interface( + entry_group_obj, "org.freedesktop.Avahi.EntryGroup" + ) + return cls( + bus=bus, + server=server_if, + entry_group=entry_group_if, + ) + + def start_advertising(self, service_name: str) -> None: + # TODO(mm, 2022-05-26): Can we leave these as the empty string? + # The avahi_entry_group_add_service() C API recommends passing NULL + # to let the daemon decide these values. + hostname = self._server.GetHostName() + domainname = self._server.GetDomainName() + + self._entry_group.Reset() + + # TODO(mm, 2022-05-06): Can this be made exception-safe? + # We've already reset the entry group, so if this fails + # (for example because Avahi doesn't like the new name), + # we'll be left with no entry group, + # and Avahi will stop advertising the machine. + self._entry_group.AddService( + dbus.Int32(-1), # avahi.IF_UNSPEC + dbus.Int32(-1), # avahi.PROTO_UNSPEC + dbus.UInt32(0), # flags + service_name, # sname + "_http._tcp", # stype + domainname, # sdomain (.local) + f"{hostname}.{domainname}", # shost (hostname.local) + dbus.UInt16(31950), # port + dbus.Array([], signature="ay"), + ) + + self._entry_group.Commit() + + def is_collided(self) -> bool: + state = self._entry_group.GetState() + + # "3" comes from AVAHI_ENTRY_GROUP_COLLISION being index 3 in this enum: + # https://github.com/lathiat/avahi/blob/v0.8/avahi-common/defs.h#L234 + avahi_entry_group_collision = dbus.Int32(3) + + # Cast for when type-checking is run on dev machines without dbus, + # where the type-checker will see this expression as `Any == Any`. + return cast(bool, state == avahi_entry_group_collision) diff --git a/update-server/otupdate/common/name_management/name_synchronizer.py b/update-server/otupdate/common/name_management/name_synchronizer.py new file mode 100644 index 00000000000..c2e061c4a40 --- /dev/null +++ b/update-server/otupdate/common/name_management/name_synchronizer.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager +from logging import getLogger +from typing import AsyncGenerator, Optional + +from aiohttp import web + +from otupdate.common.constants import APP_VARIABLE_PREFIX +from .avahi import AvahiClient, alternative_service_name +from .pretty_hostname import get_pretty_hostname, persist_pretty_hostname + + +_NAME_SYNCHRONIZER_VARNAME = APP_VARIABLE_PREFIX + "name_synchronizer" +_log = getLogger(__name__) + + +class NameSynchronizer: + """Keep the machine's human-readable names in sync with each other. + + This ties the pretty hostname and the Avahi service name together, + so they always have the same value. + + The `set_name()` and `get_name()` methods are intended for use by HTTP + endpoints, which makes for a total of three names tied together, + if you also count the name available over HTTP. + + See the `name_management` package docstring for an overview of these various names. + + We tie all of these names together because: + + * It's important to avoid confusing the client-side discovery client, + at least at the time of writing. + https://github.com/Opentrons/opentrons/issues/10199 + + * It helps maintain a conceptually simple interface. + There is one name accessible in three separate ways, + rather than three separate names. + + * It implements the DNS-SD spec's recommendation to make the DNS-SD instance name + configurable. https://datatracker.ietf.org/doc/html/rfc6763#section-4.1.1 + """ + + def __init__(self, avahi_client: AvahiClient) -> None: + """For internal use by this class only. Use `start()` instead.""" + self._avahi_client = avahi_client + + @classmethod + @asynccontextmanager + async def start( + cls, avahi_client: Optional[AvahiClient] = None + ) -> AsyncGenerator[NameSynchronizer, None]: + """Build a NameSynchronizer and keep it running in the background. + + Avahi advertisements will start as soon as this context manager is entered. + The pretty hostname will be used as the Avahi service name. + + While this context manager remains entered, Avahi will be monitored in the + background to see if this device's name ever collides with another device on + the network. If that ever happens, a new name will be chosen automatically, + which will be visible through `get_name()`. + + Collision monitoring will stop when this context manager exits. + + Args: + avahi_client: The interface for communicating with Avahi. + Changeable for testing this class; should normally be left as + the default. + """ + if avahi_client is None: + avahi_client = await AvahiClient.connect() + + name_synchronizer = cls(avahi_client) + async with avahi_client.listen_for_collisions( + callback=name_synchronizer._on_avahi_collision + ): + await avahi_client.start_advertising( + service_name=name_synchronizer.get_name() + ) + yield name_synchronizer + + async def set_name(self, new_name: str) -> str: + """Set the machine's human-readable name. + + This first sets the Avahi service name, and then persists it + as the pretty hostname. + + Returns the new name. This is normally the same as the requested name, + but it it might be different if it had to be truncated, sanitized, etc. + """ + await self._avahi_client.start_advertising(service_name=new_name) + # Setting the Avahi service name can fail if Avahi doesn't like the new name. + # Persist only after it succeeds, so we don't persist something invalid. + persisted_pretty_hostname = persist_pretty_hostname(new_name) + _log.info( + f"Changed name to {repr(new_name)}" + f" (persisted {repr(persisted_pretty_hostname)})." + ) + return persisted_pretty_hostname + + def get_name(self) -> str: + """Return the machine's current human-readable name. + + Note that this can change even if you haven't called `set_name()`, + if it was necessary to avoid conflicts with other devices on the network. + """ + return get_pretty_hostname() + + async def _on_avahi_collision(self) -> None: + current_name = self.get_name() + + # Assume that the service name was the thing that collided. + # Theoretically it also could have been the static hostname, + # but our static hostnames are unique in practice, so that's unlikely. + alternative_name = alternative_service_name(current_name) + _log.info( + f"Name collision detected by Avahi." + f" Changing name from {repr(current_name)} to {repr(alternative_name)}." + ) + + # Setting the new name includes persisting it for the next boot. + # + # Persisting the new name is recommended in the mDNS spec + # (https://datatracker.ietf.org/doc/html/rfc6762#section-9). + # It prevents two machines with the same name from flipping + # which one is #1 and which one is #2 every time they reboot. + await self.set_name(new_name=alternative_name) + + +def install_name_synchronizer( + name_synchronizer: NameSynchronizer, app: web.Application +) -> None: + """Install a NameSynchronizer on `app` for later retrieval + via get_name_synchronizer(). + + This should be done as part of server startup. + """ + app[_NAME_SYNCHRONIZER_VARNAME] = name_synchronizer + + +def get_name_synchronizer(request: web.Request) -> NameSynchronizer: + """Return the server's singleton NameSynchronizer from a request. + + The singleton NameSynchronizer is expected to have been installed on the + aiohttp.Application already via install_name_synchronizer(). + """ + name_synchronizer = request.app.get(_NAME_SYNCHRONIZER_VARNAME, None) + assert isinstance( + name_synchronizer, NameSynchronizer + ), f"Unexpected type {type(name_synchronizer)}. Incorrect Application setup?" + return name_synchronizer diff --git a/update-server/otupdate/common/name_management/pretty_hostname.py b/update-server/otupdate/common/name_management/pretty_hostname.py index 1be093af9c1..1a958a25e72 100644 --- a/update-server/otupdate/common/name_management/pretty_hostname.py +++ b/update-server/otupdate/common/name_management/pretty_hostname.py @@ -31,7 +31,7 @@ def get_pretty_hostname(default: str = "no name set") -> str: return default -async def persist_pretty_hostname(name: str) -> str: +def persist_pretty_hostname(name: str) -> str: """Change the robot's pretty hostname. Writes the new name to /etc/machine-info so it persists across reboots. diff --git a/update-server/otupdate/common/run_application.py b/update-server/otupdate/common/run_application.py new file mode 100644 index 00000000000..393372d6b84 --- /dev/null +++ b/update-server/otupdate/common/run_application.py @@ -0,0 +1,37 @@ +import asyncio +from typing import NoReturn + +from aiohttp import web + +from . import systemd + + +async def run_and_notify_up(app: web.Application, host: str, port: int) -> NoReturn: + """Run an aiohttp application. + + Once the application is up and running and serving requests, + notify systemd that this service has completed its startup. + + This method will only return once the task has been canceled, + such as if this process has been signaled to stop. + """ + runner = web.AppRunner(app) # type: ignore[no-untyped-call] + + await runner.setup() # type: ignore[no-untyped-call] + + try: + site = web.TCPSite( # type: ignore[no-untyped-call] + runner=runner, host=host, port=port + ) + await site.start() # type: ignore[no-untyped-call] + + systemd.notify_up() + + while True: + # It seems like there should be a better way of doing this, + # but this is what the docs recommend. + # https://docs.aiohttp.org/en/stable/web_advanced.html#application-runners + await asyncio.sleep(3600) + + finally: + await runner.cleanup() # type: ignore[no-untyped-call] diff --git a/update-server/otupdate/openembedded/__init__.py b/update-server/otupdate/openembedded/__init__.py index 7700783fa1b..89dcf0cacd9 100644 --- a/update-server/otupdate/openembedded/__init__.py +++ b/update-server/otupdate/openembedded/__init__.py @@ -9,8 +9,8 @@ config, constants, control, - ssh_key_management, name_management, + ssh_key_management, update, ) @@ -44,54 +44,33 @@ def get_version_dict(version_file: Optional[str]) -> Mapping[str, str]: def get_app( + name_synchronizer: name_management.NameSynchronizer, system_version_file: Optional[str] = None, config_file_override: Optional[str] = None, name_override: Optional[str] = None, boot_id_override: Optional[str] = None, - loop: Optional[asyncio.AbstractEventLoop] = None, ) -> web.Application: """Build and return the aiohttp.web.Application that runs the server""" if not system_version_file: system_version_file = OE_BUILTIN_VERSION_FILE version = get_version_dict(system_version_file) - - if not loop: - loop = asyncio.get_event_loop() - + boot_id = boot_id_override or control.get_boot_id() config_obj = config.load(config_file_override) app = web.Application(middlewares=[log_error_middleware]) - name = name_override or name_management.get_pretty_hostname() - boot_id = boot_id_override or control.get_boot_id() + app[config.CONFIG_VARNAME] = config_obj app[constants.RESTART_LOCK_NAME] = asyncio.Lock() app[constants.DEVICE_BOOT_ID_NAME] = boot_id - app[constants.DEVICE_NAME_VARNAME] = name - LOG.info( - "Setup: " - + "\n\t".join( - [ - f"Device name: {name}", - "Buildroot version: " - f'{version.get("buildroot_version", "unknown")}', - "\t(from git sha " f'{version.get("buildroot_sha", "unknown")}', - "API version: " - f'{version.get("opentrons_api_version", "unknown")}', - "\t(from git sha " - f'{version.get("opentrons_api_sha", "unknown")}', - "Update server version: " - f'{version.get("update_server_version", "unknown")}', - "\t(from git sha " - f'{version.get("update_server_sha", "unknown")}', - ] - ) - ) rfs = RootFSInterface() part_mgr = PartitionManager() updater = Updater(rfs, part_mgr) app[FILE_ACTIONS_VARNAME] = updater + + name_management.install_name_synchronizer(name_synchronizer, app) + app.router.add_routes( [ web.get( @@ -112,6 +91,27 @@ def get_app( web.get("/server/name", name_management.get_name_endpoint), ] ) + + LOG.info( + "Setup: " + + "\n\t".join( + [ + f"Device name: {name_synchronizer.get_name()}", + "Buildroot version: " + f'{version.get("buildroot_version", "unknown")}', + "\t(from git sha " f'{version.get("buildroot_sha", "unknown")}', + "API version: " + f'{version.get("opentrons_api_version", "unknown")}', + "\t(from git sha " + f'{version.get("opentrons_api_sha", "unknown")}', + "Update server version: " + f'{version.get("update_server_version", "unknown")}', + "\t(from git sha " + f'{version.get("update_server_sha", "unknown")}', + ] + ) + ) + return app diff --git a/update-server/otupdate/openembedded/__main__.py b/update-server/otupdate/openembedded/__main__.py index 3a34637c58c..1c912805b77 100644 --- a/update-server/otupdate/openembedded/__main__.py +++ b/update-server/otupdate/openembedded/__main__.py @@ -3,33 +3,42 @@ """ import asyncio import logging -import logging.config +from typing import NoReturn + from . import get_app + from otupdate.common import name_management, cli, systemd -from aiohttp import web +from otupdate.common.run_application import run_and_notify_up + LOG = logging.getLogger(__name__) -def main() -> None: +async def main() -> NoReturn: parser = cli.build_root_parser() args = parser.parse_args() - loop = asyncio.get_event_loop() systemd.configure_logging(getattr(logging, args.log_level.upper())) - LOG.info("Setting hostname") - hostname = loop.run_until_complete(name_management.set_up_static_hostname()) - LOG.info(f"Set hostname to {hostname}") - - LOG.info("Building openembedded update server") - app = get_app(args.version_file, args.config_file) + # Because this involves restarting Avahi, this must happen early, + # before the NameSynchronizer starts up and connects to Avahi. + LOG.info("Setting static hostname") + static_hostname = await name_management.set_up_static_hostname() + LOG.info(f"Set static hostname to {static_hostname}") - systemd.notify_up() + async with name_management.NameSynchronizer.start() as name_synchronizer: + LOG.info("Building openembedded update server") + app = get_app( + system_version_file=args.version_file, + config_file_override=args.config_file, + name_synchronizer=name_synchronizer, + ) - LOG.info(f"Starting openembedded update server on http://{args.host}:{args.port}") - web.run_app(app, host=args.host, port=args.port) # type: ignore[no-untyped-call] + LOG.info( + f"Starting openembedded update server on http://{args.host}:{args.port}" + ) + await run_and_notify_up(app=app, host=args.host, port=args.port) if __name__ == "__main__": - main() + asyncio.run(main()) diff --git a/update-server/pytest.ini b/update-server/pytest.ini index dd19983da9a..073cb34a9a4 100644 --- a/update-server/pytest.ini +++ b/update-server/pytest.ini @@ -11,4 +11,4 @@ markers = exclude_rootfs_ext4: check if theres no rootfs bad_cert_path: bad local cert for code signing no_cert_path: no local cert for code signing - bad_rootfs_ext4_hash: bad rootfs hash \ No newline at end of file + bad_rootfs_ext4_hash: bad rootfs hash diff --git a/update-server/tests/buildroot/conftest.py b/update-server/tests/buildroot/conftest.py index cfbbb0535b9..d31e699ef76 100644 --- a/update-server/tests/buildroot/conftest.py +++ b/update-server/tests/buildroot/conftest.py @@ -11,19 +11,22 @@ @pytest.fixture async def test_cli( - aiohttp_client, loop, otupdate_config, monkeypatch, version_file_path + aiohttp_client, + otupdate_config, + monkeypatch, + version_file_path, + mock_name_synchronizer, ): """ Build an app using dummy versions, then build a test client and return it """ app = buildroot.get_app( + name_synchronizer=mock_name_synchronizer, system_version_file=version_file_path, config_file_override=otupdate_config, - name_override="opentrons-test", boot_id_override="dummy-boot-id-abc123", - loop=loop, ) - client = await loop.create_task(aiohttp_client(app)) + client = await aiohttp_client(app) return client diff --git a/update-server/tests/buildroot/test_control.py b/update-server/tests/buildroot/test_control.py index b56c3149588..15e2afbf46b 100644 --- a/update-server/tests/buildroot/test_control.py +++ b/update-server/tests/buildroot/test_control.py @@ -2,14 +2,23 @@ from typing import Dict from aiohttp.test_utils import TestClient +from decoy import Decoy +from otupdate.common.name_management import NameSynchronizer -async def test_health(test_cli: TestClient, version_dict: Dict[str, str]): + +async def test_health( + test_cli: TestClient, + version_dict: Dict[str, str], + mock_name_synchronizer: NameSynchronizer, + decoy: Decoy, +): + decoy.when(mock_name_synchronizer.get_name()).then_return("test name") resp = await test_cli.get("/server/update/health") assert resp.status == 200 body = await resp.json() assert body == { - "name": "opentrons-test", + "name": "test name", "updateServerVersion": version_dict["update_server_version"], "apiServerVersion": version_dict["opentrons_api_version"], "smoothieVersion": "unimplemented", diff --git a/update-server/tests/buildroot/test_name_endpoints.py b/update-server/tests/buildroot/test_name_endpoints.py deleted file mode 100644 index f845ad3cd9a..00000000000 --- a/update-server/tests/buildroot/test_name_endpoints.py +++ /dev/null @@ -1,51 +0,0 @@ -from otupdate.common import name_management - - -async def test_name_endpoint(test_cli, monkeypatch): - async def fake_set_name(app, new_name): - return new_name + new_name - - monkeypatch.setattr(name_management, "set_name", fake_set_name) - - # Valid: - to_set = "check out this cool name" - resp = await test_cli.post("/server/name", json={"name": to_set}) - assert resp.status == 200 - body = await resp.json() - assert body["name"] == to_set + to_set - health = await test_cli.get("/server/update/health") - health_body = await health.json() - assert health_body["name"] == to_set + to_set - get_name = await test_cli.get("/server/name") - name_body = await get_name.json() - assert name_body["name"] == to_set + to_set - - # Invalid, field is not a str: - resp = await test_cli.post("/server/name", json={"name": 2}) - assert resp.status == 400 - body = await resp.json() - assert "message" in body - health = await test_cli.get("/server/update/health") - health_body = await health.json() - assert health_body["name"] == to_set + to_set - get_name = await test_cli.get("/server/name") - name_body = await get_name.json() - assert name_body["name"] == to_set + to_set - - # Invalid, field is missing: - resp = await test_cli.post("/server/name", json={}) - assert resp.status == 400 - body = await resp.json() - assert "message" in body - health = await test_cli.get("/server/update/health") - health_body = await health.json() - assert health_body["name"] == to_set + to_set - - # Invalid, body is not JSON: - resp = await test_cli.post("/server/name", data="bada bing bada boom") - assert resp.status == 400 - body = await resp.json() - assert "message" in body - health = await test_cli.get("/server/update/health") - health_body = await health.json() - assert health_body["name"] == to_set + to_set diff --git a/update-server/tests/common/conftest.py b/update-server/tests/common/conftest.py index 97a2edb8de0..1a3b1795cbd 100644 --- a/update-server/tests/common/conftest.py +++ b/update-server/tests/common/conftest.py @@ -23,20 +23,19 @@ @pytest.fixture(params=[openembedded, buildroot]) async def test_cli( - aiohttp_client, loop, otupdate_config, request, version_file_path + aiohttp_client, otupdate_config, request, version_file_path, mock_name_synchronizer ) -> Tuple[TestClient, str]: """ Build an app using dummy versions, then build a test client and return it """ cli_client_pkg = request.param app = cli_client_pkg.get_app( + name_synchronizer=mock_name_synchronizer, system_version_file=version_file_path, config_file_override=otupdate_config, - name_override="opentrons-test", boot_id_override="dummy-boot-id-abc123", - loop=loop, ) - client = await loop.create_task(aiohttp_client(app)) + client = await aiohttp_client(app) return client, cli_client_pkg.__name__ diff --git a/update-server/tests/common/name_management/test_avahi.py b/update-server/tests/common/name_management/test_avahi.py new file mode 100644 index 00000000000..c5f3fe16e6f --- /dev/null +++ b/update-server/tests/common/name_management/test_avahi.py @@ -0,0 +1,64 @@ +import pytest +from otupdate.common.name_management.avahi import ( + alternative_service_name, + SERVICE_NAME_MAXIMUM_OCTETS, +) + + +def test_alternative_service_name_sequencing() -> None: + """It should sequence through incrementing integers. + + "BaseName" -> "BaseNameNum2" -> "BaseNameNum3" etc. + """ + base_name = "BaseName" + current_name = base_name + for sequence_number in range(2, 150): + current_name = alternative_service_name(current_name) + assert current_name == f"{base_name}Num{sequence_number}" + + +@pytest.mark.parametrize( + ("current_name", "expected_alternative_name"), + [ + ( + # 63 ASCII bytes, exactly at the maximum. + "123456789012345678901234567890123456789012345678901234567890123", + # 59 bytes from the first 59 characters + # + 4 bytes for the new "Num2" suffix + # = 63 bytes total, exactly at the maximum. + "12345678901234567890123456789012345678901234567890123456789Num2", + ), + ( + # 21 kanji + # * 3 UTF-8 bytes per kanji + # = 63 UTF-8 bytes, exactly at the maximum. + "名名名名名名名名名名名名名名名名名名名名名", + # 57 bytes from the first 19 kanji + # + 4 bytes for the new "Num2" suffix + # = 61 bytes total, just below the maximum of 63. + # + # It's correct for it to be below the maximum because an additional 3-byte + # kanji code point would put us over the limit, and we shouldn't truncate + # in the middle of a code point. + "名名名名名名名名名名名名名名名名名名名Num2", + ), + ( + # 63 ASCII bytes, exactly at the maximum, + # and no room left to increment the sequence number. + # Pathological, but maybe someone "clever" set the name to this manually. + "Num999999999999999999999999999999999999999999999999999999999999", + # Rolled back to 0. + # This is an arbitrary solution and it can change according to whatever's + # easiest to implement. + "Num0", + ), + ], +) +def test_alternative_service_name_truncation( + current_name: str, + expected_alternative_name: str, +) -> None: + """When appending a sequence number, it should not let the name become too long.""" + alternative_name = alternative_service_name(current_name) + assert alternative_name == expected_alternative_name + assert len(alternative_name.encode("utf-8")) <= SERVICE_NAME_MAXIMUM_OCTETS diff --git a/update-server/tests/common/name_management/test_http_endpoints.py b/update-server/tests/common/name_management/test_http_endpoints.py new file mode 100644 index 00000000000..497078b5ecd --- /dev/null +++ b/update-server/tests/common/name_management/test_http_endpoints.py @@ -0,0 +1,35 @@ +async def test_get_name(test_cli, mock_name_synchronizer, decoy) -> None: + decoy.when(mock_name_synchronizer.get_name()).then_return("the returned name") + + response = await (test_cli[0].get("/server/name")) + assert response.status == 200 + + body = await response.json() + assert body["name"] == "the returned name" + + +async def test_set_name_valid(test_cli, mock_name_synchronizer, decoy) -> None: + decoy.when(await mock_name_synchronizer.set_name("the input name")).then_return( + "the returned name" + ) + + response = await test_cli[0].post("/server/name", json={"name": "the input name"}) + assert response.status == 200 + + body = await response.json() + assert body["name"] == "the returned name" + + +async def test_set_name_not_json(test_cli) -> None: + response = await test_cli[0].post("/server/name", data="bada bing bada boom") + assert response.status == 400 + + +async def test_set_name_field_missing(test_cli) -> None: + response = await test_cli[0].post("/server/name", json={}) + assert response.status == 400 + + +async def test_set_name_field_not_a_str(test_cli) -> None: + response = await test_cli[0].post("/server/name", json={"name": 2}) + assert response.status == 400 diff --git a/update-server/tests/common/name_management/test_name_synchronizer.py b/update-server/tests/common/name_management/test_name_synchronizer.py new file mode 100644 index 00000000000..ba185ca584d --- /dev/null +++ b/update-server/tests/common/name_management/test_name_synchronizer.py @@ -0,0 +1,228 @@ +import asyncio +from typing import Any, AsyncGenerator, Callable + +import pytest +from decoy import Decoy, matchers + +from otupdate.common.name_management import name_synchronizer +from otupdate.common.name_management.name_synchronizer import NameSynchronizer +from otupdate.common.name_management.avahi import ( + AvahiClient, + alternative_service_name as real_alternative_service_name, +) +from otupdate.common.name_management.pretty_hostname import ( + get_pretty_hostname as real_get_pretty_hostname, + persist_pretty_hostname as real_persist_pretty_hostname, +) + + +_GET_PRETTY_HOSTNAME_SIGNATURE = Callable[[], str] +_PERSIST_PRETTY_HOSTNAME_SIGNATURE = Callable[[str], str] +_ALTERNATIVE_SERVICE_NAME_SIGNATURE = Callable[[str], str] + + +@pytest.fixture +def mock_get_pretty_hostname(decoy: Decoy) -> _GET_PRETTY_HOSTNAME_SIGNATURE: + """Return a mock in the shape of the get_pretty_hostname() function.""" + return decoy.mock(func=real_get_pretty_hostname) + + +@pytest.fixture +def monkeypatch_get_pretty_hostname( + monkeypatch: Any, mock_get_pretty_hostname: _GET_PRETTY_HOSTNAME_SIGNATURE +) -> None: + """Replace the real get_pretty_hostname() with our mock.""" + monkeypatch.setattr( + name_synchronizer, "get_pretty_hostname", mock_get_pretty_hostname + ) + + +@pytest.fixture +def mock_persist_pretty_hostname(decoy: Decoy) -> _PERSIST_PRETTY_HOSTNAME_SIGNATURE: + """Return a mock in the shape of the persist_pretty_hostname() function.""" + return decoy.mock(func=real_persist_pretty_hostname) + + +@pytest.fixture +def monkeypatch_persist_pretty_hostname( + monkeypatch: Any, mock_persist_pretty_hostname: _PERSIST_PRETTY_HOSTNAME_SIGNATURE +) -> None: + """Replace the real persist_pretty_hostname with our mock.""" + monkeypatch.setattr( + name_synchronizer, "persist_pretty_hostname", mock_persist_pretty_hostname + ) + + +@pytest.fixture +def mock_alternative_service_name(decoy: Decoy) -> _ALTERNATIVE_SERVICE_NAME_SIGNATURE: + """Return a mock in the shape of the alternative_service_name() function.""" + return decoy.mock(func=real_alternative_service_name) + + +@pytest.fixture +def monkeypatch_alternative_service_name( + monkeypatch: Any, mock_alternative_service_name: _ALTERNATIVE_SERVICE_NAME_SIGNATURE +) -> None: + """Replace the real alternative_service_name() with our mock.""" + monkeypatch.setattr( + name_synchronizer, "alternative_service_name", mock_alternative_service_name + ) + + +@pytest.fixture +def mock_avahi_client(decoy: Decoy) -> AvahiClient: + """Return a mock in the shape of an AvahiClient.""" + return decoy.mock(cls=AvahiClient) + + +@pytest.fixture +async def started_up_subject( + monkeypatch_get_pretty_hostname: None, + monkeypatch_persist_pretty_hostname: None, + mock_avahi_client: AvahiClient, + decoy: Decoy, + loop: asyncio.AbstractEventLoop, # Required by aiohttp for async fixtures. +) -> AsyncGenerator[NameSynchronizer, None]: + """Return a subject NameSynchronizer that's set up with mock dependencies, + and that's already started up and running. + + Tests that need to operate in the subject's startup phase itself + should build the subject manually instead of using this fixture. + """ + # NameSynchronizer.start() will call mock_avahi_client.listen_for_collisions() + # and expect to be able to enter its result as a context manager. + decoy.when( + mock_avahi_client.listen_for_collisions(matchers.Anything()) + ).then_enter_with(None) + + async with NameSynchronizer.start(avahi_client=mock_avahi_client) as subject: + yield subject + + +async def test_set( + started_up_subject: NameSynchronizer, + mock_persist_pretty_hostname: _PERSIST_PRETTY_HOSTNAME_SIGNATURE, + mock_avahi_client: AvahiClient, + decoy: Decoy, +) -> None: + """It should set the new name as the Avahi service name and then as the pretty + hostname. + """ + await started_up_subject.set_name("new name") + + decoy.verify( + await mock_avahi_client.start_advertising("new name"), + mock_persist_pretty_hostname("new name"), + ) + + +async def test_set_does_not_persist_invalid_avahi_service_name( + started_up_subject: NameSynchronizer, + mock_persist_pretty_hostname: _PERSIST_PRETTY_HOSTNAME_SIGNATURE, + mock_avahi_client: AvahiClient, + decoy: Decoy, +) -> None: + """It should not persist any name that's not valid as an Avahi service name. + + Covers this bug: + https://github.com/Opentrons/opentrons/issues/9960 + """ + decoy.when(await mock_avahi_client.start_advertising("danger!")).then_raise( + Exception("oh the humanity") + ) + + with pytest.raises(Exception, match="oh the humanity"): + await started_up_subject.set_name("danger!") + + decoy.verify(mock_persist_pretty_hostname(matchers.Anything()), times=0) + + +async def test_get( + started_up_subject: NameSynchronizer, + mock_get_pretty_hostname: _GET_PRETTY_HOSTNAME_SIGNATURE, + decoy: Decoy, +) -> None: + decoy.when(mock_get_pretty_hostname()).then_return("the current name") + assert started_up_subject.get_name() == "the current name" + + +async def test_advertises_initial_name( + started_up_subject: NameSynchronizer, + mock_get_pretty_hostname: _GET_PRETTY_HOSTNAME_SIGNATURE, + monkeypatch_get_pretty_hostname: None, + mock_persist_pretty_hostname: _PERSIST_PRETTY_HOSTNAME_SIGNATURE, + monkeypatch_persist_pretty_hostname: None, + mock_avahi_client: AvahiClient, + decoy: Decoy, +) -> None: + """It should immediately start advertising the existing pretty hostname + as the Avahi service name, when it's started up. + """ + + decoy.when(mock_get_pretty_hostname()).then_return("initial name") + mock_collision_subscription_context_manager = decoy.mock() + decoy.when( + mock_avahi_client.listen_for_collisions(matchers.Anything()) + ).then_return(mock_collision_subscription_context_manager) + + async with NameSynchronizer.start(avahi_client=mock_avahi_client): + decoy.verify( + # It should only start advertising after subscribing to collisions. + await mock_collision_subscription_context_manager.__aenter__(), + await mock_avahi_client.start_advertising("initial name"), + ) + + +async def test_collision_handling( + mock_get_pretty_hostname: _GET_PRETTY_HOSTNAME_SIGNATURE, + monkeypatch_get_pretty_hostname: None, + mock_persist_pretty_hostname: _PERSIST_PRETTY_HOSTNAME_SIGNATURE, + monkeypatch_persist_pretty_hostname: None, + mock_alternative_service_name: _ALTERNATIVE_SERVICE_NAME_SIGNATURE, + monkeypatch_alternative_service_name: None, + mock_avahi_client: AvahiClient, + decoy: Decoy, +) -> None: + """It should resolve naming collisions reported by the AvahiClient. + + When notified of a collision, it should: + + 1. Get a new service name. + 2. Start advertising with it. + 3. Persist it as the pretty hostname. + """ + decoy.when(mock_get_pretty_hostname()).then_return("initial name") + decoy.when(mock_alternative_service_name("initial name")).then_return( + "alternative name" + ) + + # Expect the subject to do: + # + # with avahi_client.listen_for_collisions(some_callback_func): + # ... + # + # When it does, save the function that it provided as `some_callback_func` + # into `collision_callback_captor.value`. + mock_listen_context_manager = decoy.mock() + collision_callback_captor = matchers.Captor() + decoy.when( + mock_avahi_client.listen_for_collisions(collision_callback_captor) + ).then_return(mock_listen_context_manager) + + async with NameSynchronizer.start(avahi_client=mock_avahi_client): + captured_collision_callback = collision_callback_captor.value + # Prompt the subject to run its collision-handling logic. + await captured_collision_callback() + # Exit this context manager to wait for the subject to wrap up cleanly, + # ensuring its collision handling has run to completion before we assert stuff. + + decoy.verify( + await mock_listen_context_manager.__aenter__(), + await mock_avahi_client.start_advertising("initial name"), + # The subject should only persist the alternative name *after* + # the Avahi client accepts it for advertisement, + # just in case the alternative name turns out to be invalid in some way. + # https://github.com/Opentrons/opentrons/issues/9960 + await mock_avahi_client.start_advertising("alternative name"), + mock_persist_pretty_hostname("alternative name"), + ) diff --git a/update-server/tests/conftest.py b/update-server/tests/conftest.py index 9c1e76a982d..319d608b252 100644 --- a/update-server/tests/conftest.py +++ b/update-server/tests/conftest.py @@ -8,6 +8,10 @@ from typing import Dict import pytest +from decoy import Decoy + +from otupdate.common.name_management import NameSynchronizer + HERE = os.path.abspath(os.path.dirname(__file__)) @@ -143,3 +147,9 @@ def otupdate_config(request, tmpdir, testing_cert): conf.update({"update_cert_path": testing_cert}) json.dump(conf, open(path, "w")) return path + + +@pytest.fixture +def mock_name_synchronizer(decoy: Decoy) -> NameSynchronizer: + """Return a mock in the shape of a NameSynchronizer.""" + return decoy.mock(cls=NameSynchronizer) diff --git a/update-server/tests/openembedded/conftest.py b/update-server/tests/openembedded/conftest.py index 81bd12dbf5c..b2db3b9c23c 100644 --- a/update-server/tests/openembedded/conftest.py +++ b/update-server/tests/openembedded/conftest.py @@ -27,19 +27,18 @@ def mock_update_actions_interface( @pytest.fixture async def test_cli( - aiohttp_client, loop, otupdate_config, version_file_path + aiohttp_client, otupdate_config, version_file_path, mock_name_synchronizer ) -> TestClient: """ Build an app using dummy versions, then build a test client and return it """ app = openembedded.get_app( + name_synchronizer=mock_name_synchronizer, system_version_file=version_file_path, config_file_override=otupdate_config, - name_override="opentrons-test", boot_id_override="dummy-boot-id-abc123", - loop=loop, ) - client = await loop.create_task(aiohttp_client(app)) + client = await aiohttp_client(app) return client diff --git a/update-server/tests/openembedded/test_control.py b/update-server/tests/openembedded/test_control.py index e76afc29f49..4112371a7d6 100644 --- a/update-server/tests/openembedded/test_control.py +++ b/update-server/tests/openembedded/test_control.py @@ -3,14 +3,23 @@ from typing import Dict from aiohttp.test_utils import TestClient +from decoy import Decoy +from otupdate.common.name_management import NameSynchronizer -async def test_health(test_cli: TestClient, version_dict: Dict[str, str]): + +async def test_health( + test_cli: TestClient, + version_dict: Dict[str, str], + mock_name_synchronizer: NameSynchronizer, + decoy: Decoy, +): + decoy.when(mock_name_synchronizer.get_name()).then_return("test name") resp = await test_cli.get("/server/update/health") assert resp.status == 200 body = await resp.json() assert body == { - "name": "opentrons-test", + "name": "test name", "updateServerVersion": version_dict["update_server_version"], "apiServerVersion": version_dict["opentrons_api_version"], "systemVersion": version_dict["openembedded_version"],