Skip to content

Commit

Permalink
Merge pull request #57 from nebulabroadcast/chore/more_strict_type_va…
Browse files Browse the repository at this point in the history
…lidation

More strict type validation
  • Loading branch information
martastain authored Mar 24, 2024
2 parents fdb6a98 + bcf099f commit c47d8b1
Show file tree
Hide file tree
Showing 42 changed files with 291 additions and 250 deletions.
18 changes: 12 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
IMAGE_NAME=nebulabroadcast/nebula-server:dev
VERSION=$(shell cd backend && poetry run python -c 'import nebula' --version)

check: check_version
cd frontend && yarn format
check:
cd frontend && \
yarn format

cd backend && \
poetry version $(VERSION) && \
poetry run ruff format . && \
poetry run ruff check --fix . && \
poetry run mypy .

check_version:
cd backend && poetry version $(VERSION)

build:
build: check
docker build -t $(IMAGE_NAME) .

dist: build
docker push $(IMAGE_NAME)

setup-hooks:
@echo "Setting up Git hooks..."
@mkdir -p .git/hooks
@echo '#!/bin/sh\n\n# Navigate to the repository root directory\ncd "$$(git rev-parse --show-toplevel)"\n\n# Execute the linting command from the Makefile\nmake check\n\n# Check the return code of the make command\nif [ $$? -ne 0 ]; then\n echo "Linting failed. Commit aborted."\n exit 1\nfi\n\n# If everything is fine, allow the commit\nexit 0' > .git/hooks/pre-commit
@chmod +x .git/hooks/pre-commit
@echo "Git hooks set up successfully."
15 changes: 7 additions & 8 deletions backend/api/init/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal

import fastapi
from pydantic import Field
Expand All @@ -16,14 +16,14 @@


class InitResponseModel(ResponseModel):
installed: bool = Field(
installed: Literal[True] | None = Field(
True,
title="Installed",
description="Is Nebula installed?",
)

motd: str = Field(
"",
motd: str | None = Field(
None,
title="Message of the day",
description="Server welcome string (displayed on login page)",
)
Expand Down Expand Up @@ -51,7 +51,6 @@ class InitResponseModel(ResponseModel):
default_factory=list,
title="OAuth2 options",
)
something: str | None = Field(None)


class Request(APIRequest):
Expand Down Expand Up @@ -81,11 +80,11 @@ async def handle(
if not nebula.settings.installed:
await load_settings()
if not nebula.settings.installed:
return InitResponseModel(installed=False) # type: ignore
return InitResponseModel(installed=False)

# Not logged in. Only return motd and oauth2 options.
if user is None:
return InitResponseModel(motd=motd) # type: ignore
return InitResponseModel(motd=motd)

# TODO: get preferred user language
lang: LanguageCode = user.language
Expand All @@ -102,4 +101,4 @@ async def handle(
settings=client_settings,
frontend_plugins=plugins,
scoped_endpoints=server_context.scoped_endpoints,
) # type: ignore
)
3 changes: 1 addition & 2 deletions backend/api/init/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import nebula
from nebula.enum import ContentType
from nebula.filetypes import FileTypes
from nebula.settings.common import LanguageCode
from nebula.settings.common import LanguageCode, SettingsModel
from nebula.settings.models import (
BasePlayoutChannelSettings,
BaseSystemSettings,
CSItemRole,
FolderSettings,
SettingsModel,
ViewSettings,
)

Expand Down
4 changes: 2 additions & 2 deletions backend/api/playout/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ class PlayoutPluginSlot(ResponseModel):
value: Any = None

@property
def title(self):
self.name.replace("_", " ").title()
def title(self) -> str:
return self.name.replace("_", " ").title()


class PlayoutPluginManifest(ResponseModel):
Expand Down
2 changes: 1 addition & 1 deletion backend/api/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def get_bytes_range(file_name: str, start: int, end: int) -> bytes:


def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]:
def _invalid_range():
def _invalid_range() -> HTTPException:
return HTTPException(
status.HTTP_416_REQUESTED_RANGE_NOT_SATISFIABLE,
detail=f"Invalid request range (Range:{range_header!r})",
Expand Down
2 changes: 1 addition & 1 deletion backend/api/rundown/rundown.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def get_rundown(request: RundownRequestModel) -> RundownResponseModel:
id_bin=id_bin,
id_event=id_event,
meta=emeta,
) # type: ignore
)

ts_scheduled = row.scheduled_time
if last_event is None:
Expand Down
4 changes: 3 additions & 1 deletion backend/api/scheduler/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from pydantic import Field

from server.models import RequestModel, ResponseModel
Expand Down Expand Up @@ -95,7 +97,7 @@ class SchedulerResponseModel(ResponseModel):
examples=[[134, 135, 136]],
)

events: list[dict] = Field(
events: list[dict[str, Any]] = Field(
default_factory=list,
title="Events",
description="List of events",
Expand Down
4 changes: 2 additions & 2 deletions backend/api/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ async def create_new_event(
channel: PlayoutChannelSettings,
event_data: EventData,
user: nebula.User | None = None,
):
) -> None:
"""Create a new event from the given data."""

username = user.name if user else None
Expand Down Expand Up @@ -152,7 +152,7 @@ async def scheduler(
break
else:
# no primary asset found, so append it
new_item = nebula.Item(usename=username)
new_item = nebula.Item(username=username)
new_item["id_asset"] = event_data.id_asset
new_item["id_bin"] = ex_bin.id
new_item["position"] = len(ex_bin.items)
Expand Down
10 changes: 8 additions & 2 deletions backend/api/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from nebula.common import import_module
from nebula.enum import ObjectType
from nebula.helpers.scheduling import bin_refresh
from nebula.objects.base import BaseObject
from nebula.objects.utils import get_object_class_by_name
from nebula.settings import load_settings
from server.dependencies import CurrentUser
Expand Down Expand Up @@ -113,9 +114,14 @@ class OperationsResponseModel(ResponseModel):
#


async def can_modify_object(obj, user: nebula.User):
async def can_modify_object(obj: BaseObject, user: nebula.User) -> None:
"""Check if user can modify an object.
Raises ForbiddenException if user is not allowed to modify the object.
"""

if user.is_admin:
return True
return

if isinstance(obj, nebula.Asset):
acl = user.get("can/asset_edit", False)
Expand Down
2 changes: 1 addition & 1 deletion backend/api/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ async def handle(
request: Request,
asset: AssetInPath,
user: CurrentUser,
):
) -> None:
assert asset["media_type"] == MediaType.FILE, "Only file assets can be uploaded"
extension = request.headers.get("X-nebula-extension")
assert extension, "Missing X-nebula-extension header"
Expand Down
2 changes: 1 addition & 1 deletion backend/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from nebula.plugins.library import plugin_library


def main():
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("plugin")
parser.add_argument("args", nargs=argparse.REMAINDER)
Expand Down
63 changes: 0 additions & 63 deletions backend/mypy.ini

This file was deleted.

5 changes: 3 additions & 2 deletions backend/nebula/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"ValidationException",
# Plugins
"CLIPlugin",
"__version__",
]

import sys
Expand Down Expand Up @@ -69,15 +70,15 @@
log.level = LogLevel[config.log_level.upper()]


def run(entrypoint):
def run(entrypoint) -> None: # type: ignore
"""Run a coroutine in the event loop.
This function is used to run the main entrypoint of CLI scripts.
It loads the settings and starts the event loop and runs a
given entrypoint coroutine.
"""

async def run_async():
async def run_async() -> None:
await load_settings()
await entrypoint

Expand Down
2 changes: 1 addition & 1 deletion backend/nebula/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

T = TypeVar("T", bound=type)

SerializableValue = int | float | str | bool | dict | list | None
SerializableValue = int | float | str | bool | dict[str, Any] | list[Any] | None


def json_loads(data: str | bytes) -> Any:
Expand Down
26 changes: 16 additions & 10 deletions backend/nebula/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import AsyncGenerator
__all__ = ["db", "DB", "DatabaseConnection"]

from typing import Any, AsyncGenerator

import asyncpg
import asyncpg.pool
Expand All @@ -9,53 +11,55 @@


class DB:
_pool: asyncpg.pool.Pool | None = None
_pool: asyncpg.pool.Pool | None = None # type: ignore

async def init_connection(self, conn):
async def init_connection(self, conn) -> None: # type: ignore
await conn.set_type_codec(
"jsonb",
encoder=json_dumps,
decoder=json_loads,
schema="pg_catalog",
)

async def connect(self):
async def connect(self) -> None:
"""Create a Postgres connection pool."""
self._pool = await asyncpg.create_pool(
config.postgres,
init=self.init_connection,
)
assert self._pool is not None

async def pool(self) -> asyncpg.pool.Pool:
async def pool(self) -> asyncpg.pool.Pool: # type: ignore
"""Return the Postgres connection pool. If it doesn't exist, create it."""
if self._pool is None:
await self.connect()
if self._pool is None:
raise NebulaException("Unable to connect to database")
return self._pool

async def execute(self, query: str, *args) -> str:
async def execute(self, query: str, *args: Any) -> str:
"""Execute a query and return the status."""
pool = await self.pool()
return await pool.execute(query, *args)

async def executemany(self, query: str, *args) -> None:
async def executemany(self, query: str, *args: Any) -> None:
"""Execute a query multiple times and return the result."""
pool = await self.pool()
await pool.executemany(query, *args)

async def fetch(self, query: str, *args) -> list[asyncpg.Record]:
async def fetch(self, query: str, *args: Any) -> list[asyncpg.Record]:
"""Fetch a query and return the result."""
pool = await self.pool()
return await pool.fetch(query, *args)

async def fetchrow(self, query: str, *args) -> asyncpg.Record | None:
async def fetchrow(self, query: str, *args: Any) -> asyncpg.Record | None:
"""Fetch a query and return the first result."""
pool = await self.pool()
return await pool.fetchrow(query, *args)

async def iterate(self, query: str, *args) -> AsyncGenerator[asyncpg.Record, None]:
async def iterate(
self, query: str, *args: Any
) -> AsyncGenerator[asyncpg.Record, None]:
"""Iterate over a query and yield the result."""
pool = await self.pool()
async with pool.acquire() as conn, conn.transaction():
Expand All @@ -64,4 +68,6 @@ async def iterate(self, query: str, *args) -> AsyncGenerator[asyncpg.Record, Non
yield record


DatabaseConnection = asyncpg.pool.PoolConnectionProxy | DB # type: ignore

db = DB()
4 changes: 3 additions & 1 deletion backend/nebula/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from nebula.log import log as logger


Expand All @@ -17,7 +19,7 @@ def __init__(
detail: str | None = None,
log: bool | str = False,
user_name: str | None = None,
**kwargs,
**kwargs: Any,
) -> None:
self.kwargs = kwargs

Expand Down
Loading

0 comments on commit c47d8b1

Please sign in to comment.