Skip to content

Commit

Permalink
refactor(update-server): Type update-server more strictly (#10458)
Browse files Browse the repository at this point in the history
SyntaxColoring authored May 26, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 929ef3f commit faacde2
Showing 19 changed files with 262 additions and 166 deletions.
2 changes: 1 addition & 1 deletion update-server/Makefile
Original file line number Diff line number Diff line change
@@ -53,9 +53,9 @@ test:

.PHONY: lint
lint:
$(python) -m mypy otupdate
$(python) -m black --check ./otupdate ./tests
$(python) -m flake8 otupdate tests
$(python) -m mypy otupdate

.PHONY: format
format:
2 changes: 1 addition & 1 deletion update-server/Pipfile
Original file line number Diff line number Diff line change
@@ -22,7 +22,7 @@ coverage = "==5.1"
# https://github.com/pypa/pipenv/issues/4408#issuecomment-668324177
atomicwrites = {version="==1.4.0", sys_platform="== 'win32'"}
colorama = {version="==0.4.4", sys_platform="== 'win32'"}
mypy = "==0.910"
mypy = "==0.940"
black = "==22.3.0"

[requires]
175 changes: 85 additions & 90 deletions update-server/Pipfile.lock

Large diffs are not rendered by default.

53 changes: 52 additions & 1 deletion update-server/mypy.ini
Original file line number Diff line number Diff line change
@@ -1,3 +1,54 @@
[mypy]
ignore_missing_imports = True
strict = True
show_error_codes = True


# The dbus and systemd packages will not be installed in non-Linux dev environments.
# Permit mypy to find them missing wherever we try importing them.
[mypy-dbus.*]
ignore_missing_imports=True
[mypy-systemd.*]
ignore_missing_imports=True


# TODO(mm, 2022-05-25): Resolve the typing errors in these files
# and remove these overrides when able.

# ~6 errors
[mypy-otupdate.common.control]
disallow_untyped_defs= False
disallow_untyped_calls = False
warn_return_any = False

# ~8 errors
[mypy-otupdate.common.session]
disallow_untyped_defs= False
disallow_untyped_calls = False
warn_return_any = False

# ~28 errors
[mypy-otupdate.common.update]
disallow_untyped_defs= False
disallow_untyped_calls = False
warn_return_any = False

# ~ 17 errors
[mypy-otupdate.buildroot]
disallow_untyped_defs= False
disallow_untyped_calls = False
warn_return_any = False

# ~5 errors
[mypy-otupdate.buildroot.update_actions]
disallow_untyped_defs= False
disallow_untyped_calls = False

# ~16 errors
[mypy-otupdate.openembedded]
disallow_untyped_defs= False
disallow_untyped_calls = False

# ~5 errors
[mypy-otupdate.openembedded.updater]
disallow_untyped_defs= False
disallow_untyped_calls = False
12 changes: 6 additions & 6 deletions update-server/otupdate/buildroot/__init__.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import asyncio
import logging
import json
from typing import Mapping, Any
from typing import Any, Mapping, Optional
from aiohttp import web

from otupdate.common import (
@@ -38,11 +38,11 @@ def get_version(version_file: str) -> Mapping[str, str]:


def get_app(
system_version_file: str = None,
config_file_override: str = None,
name_override: str = None,
boot_id_override: str = None,
loop: asyncio.AbstractEventLoop = None,
system_version_file: Optional[str] = None,
config_file_override: Optional[str] = None,
name_override: Optional[str] = None,
boot_id_override: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> web.Application:
"""Build and return the aiohttp.web.Application that runs the server
2 changes: 1 addition & 1 deletion update-server/otupdate/buildroot/__main__.py
Original file line number Diff line number Diff line change
@@ -40,7 +40,7 @@ def main() -> None:
systemd.notify_up()

LOG.info(f"Starting buildroot update server on http://{args.host}:{args.port}")
web.run_app(app, host=args.host, port=args.port)
web.run_app(app, host=args.host, port=args.port) # type: ignore[no-untyped-call]


if __name__ == "__main__":
12 changes: 6 additions & 6 deletions update-server/otupdate/buildroot/update_actions.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
import re
import subprocess
import tempfile
from typing import Callable, Optional
from typing import Callable, Generator, Optional

from otupdate.common.file_actions import (
unzip_update,
@@ -95,7 +95,7 @@ def write_update(
rootfs_filepath: str,
progress_callback: Callable[[float], None],
chunk_size: int = 1024,
file_size: int = None,
file_size: Optional[int] = None,
) -> Partition:
"""
Write the new rootfs to the next root partition
@@ -121,7 +121,7 @@ def write_update(
return unused.value

@contextlib.contextmanager
def mount_update(self):
def mount_update(self) -> Generator[str, None, None]:
"""Mount the freshly-written partition r/w (to update machine-id).
Should be used as a context manager, and the yielded value is the path
@@ -150,7 +150,7 @@ def commit_update(self) -> None:
else:
LOG.info(f"commit_update: committed to booting {new}")

def write_machine_id(self, current_root: str, new_root: str):
def write_machine_id(self, current_root: str, new_root: str) -> None:
"""Update the machine id in target rootfs"""
mid = open(os.path.join(current_root, "etc", "machine-id")).read()
with open(os.path.join(new_root, "etc", "machine-id"), "w") as new_mid:
@@ -169,8 +169,8 @@ def write_file(
outfile: str,
progress_callback: Callable[[float], None],
chunk_size: int = 1024,
file_size: int = None,
):
file_size: Optional[int] = None,
) -> None:
"""Write a file to another file with progress callbacks.
:param infile: The input filepath
6 changes: 3 additions & 3 deletions update-server/otupdate/common/config.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import os
import logging
import json
from typing import Any, Dict, Mapping, NamedTuple, Optional, Tuple
from typing import Any, Dict, Mapping, NamedTuple, Optional, Tuple, cast

from aiohttp.web import Request

@@ -38,7 +38,7 @@ class Config(NamedTuple):


def config_from_request(req: Request) -> Config:
return req.app[CONFIG_VARNAME]
return cast(Config, req.app[CONFIG_VARNAME])


def _ensure_load(path: str) -> Optional[Mapping[str, Any]]:
@@ -116,7 +116,7 @@ def _get_path(args_path: Optional[str]) -> str:
return DEFAULT_PATH


def load(args_path: str = None) -> Config:
def load(args_path: Optional[str] = None) -> Config:
"""
Load the config file, selecting the appropriate path from many sources
"""
20 changes: 10 additions & 10 deletions update-server/otupdate/common/file_actions.py
Original file line number Diff line number Diff line change
@@ -16,38 +16,38 @@


class FileMissing(ValueError):
def __init__(self, message):
def __init__(self, message: str) -> None:
self.message = message
self.short = "File Missing"

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.message}>"

def __str__(self):
def __str__(self) -> str:
return self.message


class SignatureMismatch(ValueError):
def __init__(self, message):
def __init__(self, message: str) -> None:
self.message = message
self.short = "Signature Mismatch"

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.message}>"

def __str__(self):
def __str__(self) -> str:
return self.message


class HashMismatch(ValueError):
def __init__(self, message):
def __init__(self, message: str) -> None:
self.message = message
self.short = "Hash Mismatch"

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.__class__.__name__}: {self.message}>"

def __str__(self):
def __str__(self) -> str:
return self.message


@@ -144,7 +144,7 @@ def hash_file(
path: str,
progress_callback: Callable[[float], None],
chunk_size: int = 1024,
file_size: int = None,
file_size: Optional[int] = None,
algo: str = "sha256",
) -> bytes:
"""
12 changes: 12 additions & 0 deletions update-server/otupdate/common/handler_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing_extensions import Protocol
from aiohttp import web


class Handler(Protocol):
"""The type signature of an aiohttp request handler function.
Useful for typing function decorators that operate on aiohttp request handlers.
"""

async def __call__(self, request: web.Request) -> web.Response:
...
18 changes: 11 additions & 7 deletions update-server/otupdate/common/name_management.py
Original file line number Diff line number Diff line change
@@ -83,7 +83,7 @@ def __init__(

_BUS_STATE: Optional[DBusState] = None

def _set_avahi_service_name_sync(new_service_name: str):
def _set_avahi_service_name_sync(new_service_name: str) -> None:
"""The synchronous implementation of setting the Avahi service name.
The dbus module doesn't natively support async/await.
@@ -136,7 +136,7 @@ def _set_avahi_service_name_sync(new_service_name: str):
except ImportError:
LOG.exception("Couldn't import dbus, name setting will be nonfunctional")

def _set_avahi_service_name_sync(new_service_name: str):
def _set_avahi_service_name_sync(new_service_name: str) -> None:
LOG.warning("Not setting name, dbus could not be imported")


@@ -249,7 +249,7 @@ async def set_up_static_hostname() -> str:
return hostname


def _rewrite_machine_info(new_pretty_hostname: str):
def _rewrite_machine_info(new_pretty_hostname: str) -> None:
"""Write a new value for the pretty hostname.
:raises OSError: If the new value could not be written.
@@ -287,7 +287,7 @@ def _rewrite_machine_info_str(
return new_contents


def get_pretty_hostname(default: str = "no name set"):
def get_pretty_hostname(default: str = "no name set") -> str:
"""Get the currently-configured pretty hostname"""
try:
with open("/etc/machine-info") as emi:
@@ -356,7 +356,9 @@ async def set_name_endpoint(request: web.Request) -> web.Response:
"""

def build_400(msg: str) -> web.Response:
return web.json_response(data={"message": msg}, status=400)
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={"message": msg}, status=400
)

try:
body = await request.json()
@@ -376,7 +378,9 @@ def build_400(msg: str) -> web.Response:
new_name = await set_name(app=request.app, new_name=name_to_set)

request.app[DEVICE_NAME_VARNAME] = new_name
return web.json_response(data={"name": new_name}, status=200)
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={"name": new_name}, status=200
)


async def get_name_endpoint(request: web.Request) -> web.Response:
@@ -387,6 +391,6 @@ async def get_name_endpoint(request: web.Request) -> web.Response:
GET /server/name -> 200 OK, {'name': robot name}
"""
return web.json_response(
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={"name": request.app[DEVICE_NAME_VARNAME]}, status=200
)
8 changes: 4 additions & 4 deletions update-server/otupdate/common/session.py
Original file line number Diff line number Diff line change
@@ -47,23 +47,23 @@ def __init__(self, storage_path: str) -> None:
self._rootfs_file: Optional[str] = None
LOG.info(f"update session: created {self._token}")

def _setup_dl_area(self):
def _setup_dl_area(self) -> None:
if os.path.exists(self._storage_path):
shutil.rmtree(self._storage_path)
os.makedirs(self._storage_path, mode=0o700, exist_ok=True)

def __del__(self):
def __del__(self) -> None:
if hasattr(self, "_storage_path"):
shutil.rmtree(self._storage_path)
LOG.info(f"Update session: removed {getattr(self, '_token', '<unknown>')}")

def set_stage(self, stage: Stages):
def set_stage(self, stage: Stages) -> None:
"""Convenience method to set the stage and lookup message"""
assert stage in Stages
LOG.info(f"Update session: stage {self._stage.name}->{stage.name}")
self._stage = stage

def set_error(self, error_shortmsg: str, error_longmsg: str):
def set_error(self, error_shortmsg: str, error_longmsg: str) -> None:
"""Set the stage to error and add a message"""
LOG.error(
f"Update session: error in stage {self._stage.name}: "
38 changes: 26 additions & 12 deletions update-server/otupdate/common/ssh_key_management.py
Original file line number Diff line number Diff line change
@@ -7,15 +7,23 @@
import ipaddress
import logging
import os
from typing import List, Tuple
from typing import (
Any,
Generator,
IO,
List,
Tuple,
)

from aiohttp import web

from .handler_type import Handler


LOG = logging.getLogger(__name__)


def require_linklocal(handler):
def require_linklocal(handler: Handler) -> Handler:
"""Ensure the decorated is only called if the request is linklocal.
The host ip address should be in the X-Host-IP header (provided by nginx)
@@ -35,23 +43,27 @@ async def decorated(request: web.Request) -> web.Response:
),
}
if not ipaddr_str:
return web.json_response(data=invalid_req_data, status=403)
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data=invalid_req_data, status=403
)
try:
addr = ipaddress.ip_address(ipaddr_str)
except ValueError:
LOG.exception(f"Couldn't parse host ip address {ipaddr_str}")
raise

if not addr.is_link_local:
return web.json_response(data=invalid_req_data, status=403)
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data=invalid_req_data, status=403
)

return await handler(request)

return decorated


@contextlib.contextmanager
def authorized_keys(mode="r"):
def authorized_keys(mode: str = "r") -> Generator[IO[Any], None, None]:
"""Open the authorized_keys file. Separate function for mocking.
:param mode: As :py:meth:`open`
@@ -74,7 +86,7 @@ def get_keys() -> List[Tuple[str, str]]:
]


def remove_by_hash(hashval: str):
def remove_by_hash(hashval: str) -> None:
"""Remove the key whose md5 sum matches hashval.
:raises: KeyError if the hashval wasn't found
@@ -106,7 +118,7 @@ async def list_keys(request: web.Request) -> web.Response:
(or 403 if not from the link-local connection)
"""
return web.json_response(
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
{
"public_keys": [
{"key_md5": details[0], "key": details[1]} for details in get_keys()
@@ -127,7 +139,9 @@ async def add(request: web.Request) -> web.Response:
"""

def key_error(error: str, message: str) -> web.Response:
return web.json_response(data={"error": error, "message": message}, status=400)
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={"error": error, "message": message}, status=400
)

body = await request.json()

@@ -162,7 +176,7 @@ def key_error(error: str, message: str) -> web.Response:
with authorized_keys("a") as ak:
ak.write(f"{pubkey}\n")

return web.json_response(
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={"message": f"Added key {hashval}", "key_md5": hashval}, status=201
)

@@ -179,7 +193,7 @@ async def clear(request: web.Request) -> web.Response:
with authorized_keys("w") as ak:
ak.write("\n".join([]) + "\n")

return web.json_response(
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={
"message": "Keys cleared. " "Restart robot to take effect",
"restart_url": "/server/restart",
@@ -206,7 +220,7 @@ async def remove(request: web.Request) -> web.Response:
new_keys.append(key)

if not found:
return web.json_response(
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={
"error": "invalid-key-hash",
"message": f"No such key md5 {requested_hash}",
@@ -217,7 +231,7 @@ async def remove(request: web.Request) -> web.Response:
with authorized_keys("w") as ak:
ak.write("\n".join(new_keys) + "\n")

return web.json_response(
return web.json_response( # type: ignore[no-untyped-call,no-any-return]
data={
"message": f"Key {requested_hash} deleted. " "Restart robot to take effect",
"restart_url": "/server/restart",
11 changes: 6 additions & 5 deletions update-server/otupdate/common/systemd.py
Original file line number Diff line number Diff line change
@@ -3,13 +3,14 @@
"""

import logging.config
from typing import Dict, Union

try:
# systemd journal is available, we can use its handler
import systemd.journal
import systemd.daemon

def log_handler(topic_name: str, log_level: int):
def log_handler(topic_name: str, log_level: int) -> Dict[str, Union[int, str]]:
return {
"class": "systemd.journal.JournalHandler",
"formatter": "message_only",
@@ -23,28 +24,28 @@ def log_handler(topic_name: str, log_level: int):
# dependent services until we actually say we're ready. By calling this
# after we change the hostname, we make anything with an After= on us
# be guaranteed to see the correct hostname
def notify_up():
def notify_up() -> None:
systemd.daemon.notify("READY=1")

SOURCE: str = "systemd"

except ImportError:
# systemd journal isn't available, probably running tests

def log_handler(topic_name: str, log_level: int):
def log_handler(topic_name: str, log_level: int) -> Dict[str, Union[int, str]]:
return {
"class": "logging.StreamHandler",
"formatter": "basic",
"level": log_level,
}

def notify_up():
def notify_up() -> None:
pass

SOURCE = "dummy"


def configure_logging(level: int):
def configure_logging(level: int) -> None:
config = {
"version": 1,
"formatters": {
24 changes: 19 additions & 5 deletions update-server/otupdate/common/update.py
Original file line number Diff line number Diff line change
@@ -12,11 +12,13 @@
from subprocess import CalledProcessError

from typing import Optional
from typing_extensions import Protocol

from aiohttp import web, BodyPartReader

from .constants import APP_VARIABLE_PREFIX, RESTART_LOCK_NAME
from . import config, update_actions
from .constants import APP_VARIABLE_PREFIX, RESTART_LOCK_NAME
from .handler_type import Handler
from .session import UpdateSession, Stages

from otupdate.openembedded.updater import UPDATE_PKG
@@ -25,11 +27,23 @@
LOG = logging.getLogger(__name__)


class _HandlerWithSession(Protocol):
"""The type signature of an aiohttp request handler that also has a session arg.
See require_session().
"""

async def __call__(
self, request: web.Request, session: UpdateSession
) -> web.Response:
...


def session_from_request(request: web.Request) -> Optional[UpdateSession]:
return request.app.get(SESSION_VARNAME, None)


def require_session(handler):
def require_session(handler: _HandlerWithSession) -> Handler:
"""Decorator to ensure a session is properly in the request"""

@functools.wraps(handler)
@@ -77,7 +91,7 @@ async def status(request: web.Request, session: UpdateSession) -> web.Response:
return web.json_response(data=session.state, status=200)


async def _save_file(part: BodyPartReader, path: str):
async def _save_file(part: BodyPartReader, path: str) -> None:
# making sure directory exists first
Path(path).mkdir(parents=True, exist_ok=True)
with open(os.path.join(path, part.name), "wb") as write:
@@ -97,7 +111,7 @@ def _begin_write(
loop: asyncio.AbstractEventLoop,
rootfs_file_path: str,
actions: update_actions.UpdateActionsInterface,
):
) -> None:
"""Start the write process."""
session.set_progress(0)
session.set_stage(Stages.WRITING)
@@ -126,7 +140,7 @@ def _begin_validation(
loop: asyncio.AbstractEventLoop,
downloaded_update_path: str,
actions: update_actions.UpdateActionsInterface,
) -> asyncio.futures.Future:
) -> "asyncio.futures.Future[Optional[str]]":
"""Start the validation process."""
session.set_stage(Stages.VALIDATING)
cert_path = config.update_cert_path if config.signature_required else None
15 changes: 10 additions & 5 deletions update-server/otupdate/common/update_actions.py
Original file line number Diff line number Diff line change
@@ -3,9 +3,11 @@
update actions
"""

from __future__ import annotations

import abc
import contextlib
from typing import NamedTuple, Optional, Callable, Iterator
from typing import Callable, Generator, NamedTuple, Optional, cast
from aiohttp import web

from .constants import APP_VARIABLE_PREFIX
@@ -27,12 +29,15 @@ class Partition(NamedTuple):

class UpdateActionsInterface:
@staticmethod
def from_request(request: web.Request) -> Optional["UpdateActionsInterface"]:
def from_request(request: web.Request) -> Optional[UpdateActionsInterface]:
"""Get the update object from the aiohttp app store"""
return request.app.get(FILE_ACTIONS_VARNAME, None)
try:
return cast(UpdateActionsInterface, request.app[FILE_ACTIONS_VARNAME])
except KeyError:
return None

@classmethod
def build_and_insert(cls, app: web.Application):
def build_and_insert(cls, app: web.Application) -> None:
"""Build the object and put it in the app store"""
app[FILE_ACTIONS_VARNAME] = cls()

@@ -74,7 +79,7 @@ def write_update(

@abc.abstractmethod
@contextlib.contextmanager
def mount_update(self) -> Iterator:
def mount_update(self) -> Generator[str, None, None]:
"""
Mount the fs to overwrite with the update
"""
10 changes: 5 additions & 5 deletions update-server/otupdate/openembedded/__init__.py
Original file line number Diff line number Diff line change
@@ -44,11 +44,11 @@ def get_version_dict(version_file: Optional[str]) -> Mapping[str, str]:


def get_app(
system_version_file: str = None,
config_file_override: str = None,
name_override: str = None,
boot_id_override: str = None,
loop: asyncio.AbstractEventLoop = None,
system_version_file: Optional[str] = None,
config_file_override: Optional[str] = None,
name_override: Optional[str] = None,
boot_id_override: Optional[str] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> web.Application:
"""Build and return the aiohttp.web.Application that runs the server"""
if not system_version_file:
6 changes: 3 additions & 3 deletions update-server/otupdate/openembedded/__main__.py
Original file line number Diff line number Diff line change
@@ -11,15 +11,15 @@
LOG = logging.getLogger(__name__)


def main():
def main() -> None:
parser = cli.build_root_parser()
args = parser.parse_args()
loop = asyncio.get_event_loop()

systemd.configure_logging(getattr(logging, args.log_level.upper()))

LOG.info("Setting hostname")
hostname = loop.run_until_complete(name_management.setup_hostname())
hostname = loop.run_until_complete(name_management.set_up_static_hostname())
LOG.info(f"Set hostname to {hostname}")

LOG.info("Building openembedded update server")
@@ -28,7 +28,7 @@ def main():
systemd.notify_up()

LOG.info(f"Starting openembedded update server on http://{args.host}:{args.port}")
web.run_app(app, host=args.host, port=args.port)
web.run_app(app, host=args.host, port=args.port) # type: ignore[no-untyped-call]


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion update-server/otupdate/openembedded/updater.py
Original file line number Diff line number Diff line change
@@ -233,7 +233,7 @@ def write_update(
rootfs_filepath: str,
progress_callback: Callable[[float], None],
chunk_size: int = 1024,
file_size: int = None,
file_size: Optional[int] = None,
) -> Partition:
self.decomp_and_write(rootfs_filepath, progress_callback)
unused_partition = self.part_mngr.find_unused_partition(

0 comments on commit faacde2

Please sign in to comment.