Skip to content

Commit

Permalink
Finalize models, type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
creatorcary committed Jun 28, 2024
1 parent e411ddb commit c901452
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/middlewared/middlewared/api/base/excluded.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
__all__ = ["Excluded", "excluded_field"]


class ExcludedField(Any):
class ExcludedField:
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
Expand Down
51 changes: 27 additions & 24 deletions src/middlewared/middlewared/api/v25_04_0/api_key.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,48 @@
from datetime import datetime
from typing import Literal
from typing import Literal, TypeAlias
from typing_extensions import Annotated

from middlewared.api.base import BaseModel, NonEmptyString
from pydantic import ConfigDict, StringConstraints

__all__ = [
"ApiKeyCreateArgs",
"ApiKeyCreateResult",
"ApiKeyUpdateArgs",
"ApiKeyUpdateResult",
"ApiKeyDeleteArgs",
"ApiKeyDeleteResult",
]
from middlewared.api.base import BaseModel, Excluded, excluded_field, NonEmptyString, Private


HttpVerb: TypeAlias = Literal["GET", "POST", "PUT", "DELETE", "CALL", "SUBSCRIBE", "*"]


class AllowListItem(BaseModel):
method: Literal["GET", "POST", "PUT", "DELETE", "CALL", "SUBSCRIBE", "*"]
method: HttpVerb
resource: NonEmptyString


class ApiKeyCreate(BaseModel):
name: NonEmptyString
class ApiKeyEntry(BaseModel):
"""Represents a record in the account.api_key table."""
#: This allows the model to be created from an instance of plugins/api_key.APIKeyModel
model_config = ConfigDict(from_attributes=True)

id: int
name: Annotated[NonEmptyString, StringConstraints(max_length=200)]
key: Private[str]
created_at: datetime
allowlist: list[AllowListItem]


class ApiKeyCreate(ApiKeyEntry):
id: Excluded = excluded_field()
key: Excluded = excluded_field()
created_at: Excluded = excluded_field()


class ApiKeyCreateArgs(BaseModel):
api_key_create: ApiKeyCreate


class ApiKeyCreateResult(ApiKeyCreate):
"""Represents a record in the account.api_key table."""

id: int
key: str
created_at: datetime
class ApiKeyCreateResult(BaseModel):
result: ApiKeyEntry


class ApiKeyUpdate(ApiKeyCreate):
reset: bool
update: bool = True


class ApiKeyUpdateArgs(BaseModel):
Expand All @@ -46,14 +51,12 @@ class ApiKeyUpdateArgs(BaseModel):


class ApiKeyUpdateResult(BaseModel):
# Needs implemented
pass
result: ApiKeyEntry


class ApiKeyDeleteArgs(BaseModel):
id: int


class ApiKeyDeleteResult(BaseModel):
# Needs implemented
pass
result: Literal[True]
29 changes: 13 additions & 16 deletions src/middlewared/middlewared/plugins/api_key.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from datetime import datetime
import random
import string
from typing import Any, TYPE_CHECKING
from typing import Literal, TYPE_CHECKING

from passlib.hash import pbkdf2_sha256

from middlewared.api import api_method
from middlewared.api.current import (
ApiKeyCreateArgs, ApiKeyCreateResult, ApiKeyUpdateArgs,
ApiKeyUpdateResult, ApiKeyDeleteArgs, ApiKeyDeleteResult
ApiKeyUpdateResult, ApiKeyDeleteArgs, ApiKeyDeleteResult,
ApiKeyCreate, ApiKeyEntry, ApiKeyUpdate, HttpVerb,
NonEmptyString
)
from middlewared.service import CRUDService, private, ValidationErrors
import middlewared.sqlalchemy as sa
Expand All @@ -28,31 +30,26 @@ class APIKeyModel(sa.Model):


class ApiKey:
def __init__(self, api_key: dict[str, list[dict[str, str]]]):
def __init__(self, api_key: ApiKeyEntry):
self.api_key = api_key
self.allowlist = Allowlist(self.api_key["allowlist"])

def authorize(self, method: str, resource: str):
def authorize(self, method: HttpVerb, resource: NonEmptyString) -> bool:
return self.allowlist.authorize(method, resource)


class ApiKeyService(CRUDService):

keys: dict[int, dict[str, list[dict[str, str]]]] = {}
keys: dict[int, ApiKeyEntry] = {}

class Config:
namespace = "api_key"
datastore = "account.api_key"
datastore_extend = "api_key.item_extend"
cli_namespace = "auth.api_key"

@private
async def item_extend(self, item: dict):
item.pop("key")
return item

@api_method(ApiKeyCreateArgs, ApiKeyCreateResult)
async def do_create(self, data: dict):
async def do_create(self, data: ApiKeyCreate) -> ApiKeyEntry:
"""
Creates API Key.
Expand All @@ -76,7 +73,7 @@ async def do_create(self, data: dict):
return self._serve(data, key)

@api_method(ApiKeyUpdateArgs, ApiKeyUpdateResult)
async def do_update(self, id_, data: dict):
async def do_update(self, id_: int, data: ApiKeyUpdate) -> ApiKeyEntry:
"""
Update API Key `id`.
Expand Down Expand Up @@ -108,7 +105,7 @@ async def do_update(self, id_, data: dict):
return self._serve(await self.get_instance(id_), key)

@api_method(ApiKeyDeleteArgs, ApiKeyDeleteResult)
async def do_delete(self, id_: int):
async def do_delete(self, id_: int) -> Literal[True]:
"""
Delete API Key `id`.
"""
Expand Down Expand Up @@ -139,7 +136,7 @@ async def load_key(self, id_: int):
)

@private
async def authenticate(self, key: str):
async def authenticate(self, key: str) -> ApiKey | None:
try:
key_id, key = key.split("-", 1)
key_id = int(key_id)
Expand All @@ -156,7 +153,7 @@ async def authenticate(self, key: str):

return ApiKey(db_key)

async def _validate(self, schema_name: str, data: dict, id_=None):
async def _validate(self, schema_name: str, data: ApiKeyEntry, id_: int=None):
verrors = ValidationErrors()

await self._ensure_unique(verrors, schema_name, "name", data["name"], id_)
Expand All @@ -166,7 +163,7 @@ async def _validate(self, schema_name: str, data: dict, id_=None):
def _generate(self):
return "".join([random.SystemRandom().choice(string.ascii_letters + string.digits) for _ in range(64)])

def _serve(self, data: dict, key: Any | None):
def _serve(self, data: ApiKeyEntry, key: str | None) -> ApiKeyEntry:
if key is None:
return data

Expand Down
14 changes: 8 additions & 6 deletions src/middlewared/middlewared/utils/allowlist.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import fnmatch
import re

ALLOW_LIST_FULL_ADMIN = {'method': '*', 'resource': '*'}
from middlewared.api.current import AllowListItem, HttpVerb, NonEmptyString

ALLOW_LIST_FULL_ADMIN: AllowListItem = {'method': '*', 'resource': '*'}


class Allowlist:
def __init__(self, allowlist: list[dict[str, str]]):
self.exact: dict[str, set[str]] = {}
def __init__(self, allowlist: list[AllowListItem]):
self.exact: dict[HttpVerb, set[NonEmptyString]] = {}
self.full_admin = ALLOW_LIST_FULL_ADMIN in allowlist
self.patterns: dict[str, list[re.Pattern]] = {}
self.patterns: dict[HttpVerb, list[re.Pattern]] = {}
for entry in allowlist:
method = entry["method"]
resource = entry["resource"]
Expand All @@ -19,10 +21,10 @@ def __init__(self, allowlist: list[dict[str, str]]):
self.exact.setdefault(method, set())
self.exact[method].add(resource)

def authorize(self, method: str, resource: str):
def authorize(self, method: HttpVerb, resource: NonEmptyString):
return self._authorize_internal("*", resource) or self._authorize_internal(method, resource)

def _authorize_internal(self, method: str, resource: str):
def _authorize_internal(self, method: HttpVerb, resource: NonEmptyString):
if (exact := self.exact.get(method)) and resource in exact:
return True

Expand Down

0 comments on commit c901452

Please sign in to comment.