diff --git a/mini_kms.py b/mini_kms.py index 379d0fd..b865b91 100644 --- a/mini_kms.py +++ b/mini_kms.py @@ -3,9 +3,10 @@ from contextlib import asynccontextmanager import json import logging -from typing import Any, List, Optional, Tuple, cast +from typing import Any, List, Mapping, Optional, Sequence, Set, Tuple, cast +from uuid import uuid4 -from aries_askar import AskarError, AskarErrorCode, Key, KeyAlg, Store +from aries_askar import AskarError, AskarErrorCode, Entry, Key, KeyAlg, Store import base58 from fastapi import Depends, FastAPI, Header, Request, status from fastapi.responses import JSONResponse @@ -397,3 +398,242 @@ async def sign( sig = key.sign_message(req.data) return SigResp(sig=Base64UrlEncoder.encode(sig)) + + +class VCRecord(BaseModel): + """Credential storage request.""" + + contexts: Set[str] + expanded_types: Set[str] + issuer_id: str + subject_ids: Set[str] + schema_ids: Set[str] + proof_types: Set[str] + cred_value: Mapping + given_id: Optional[str] = None + cred_tags: Optional[Mapping] = None + record_id: Optional[str] = None + + +class CredStoreResult(BaseModel): + """Result of credential storage.""" + + record_id: str + + +VC_HOLDER_CAT = "vc-holder" + + +@app.post( + "/vc-holder/store", tags=["vc-holder"], response_description="Stored credential id" +) +async def store_credential( + cred: VCRecord, + profile: str = Header(default=DEFAULT_PROFILE, alias=PROFILE_HEADER), + store: Store = Depends(store), +): + """Store a credential.""" + tags = { + attr: value + for attr in ( + "contexts", + "expanded_types", + "schema_ids", + "subject_ids", + "proof_types", + "issuer_id", + "given_id", + ) + if (value := getattr(cred, attr)) + } + for tagname, tagval in (cred.cred_tags or {}).items(): + tags[f"cstm:{tagname}"] = tagval + + record_id = cred.record_id or str(uuid4()) + async with store.session(profile=profile) as txn: + await txn.insert( + category=VC_HOLDER_CAT, name=record_id, tags=tags, value_json=cred.cred_value + ) + return CredStoreResult(record_id=record_id) + + +def entry_to_vc_record(entry: Entry) -> VCRecord: + """Convert an Askar stored entry into a VC record.""" + tags = cast(dict, entry.tags) + cred_tags = {name[5:]: value for name, value in tags if name.startswith("cstm:")} + contexts = tags.get("contexts", set()) + types = tags.get("expanded_types", set()) + schema_ids = tags.get("schema_ids", set()) + subject_ids = tags.get("subject_ids", set()) + proof_types = tags.get("proof_types", set()) + issuer_id = tags.get("issuer_id") + if not isinstance(issuer_id, str): + raise ValueError("issuer_id must be str") + given_id = tags.get("given_id") + return VCRecord( + contexts=contexts, + expanded_types=types, + schema_ids=schema_ids, + issuer_id=issuer_id, + subject_ids=subject_ids, + proof_types=proof_types, + cred_value=json.loads(entry.value), + given_id=given_id, + cred_tags=cred_tags, + record_id=cast(str, entry.name), + ) + + +@app.get( + "/vc-holder/credential/record/{record_id}", + tags=["vc-holder"], + response_description="Retrieved credential", +) +async def retrieve_credential_by_id( + record_id: str, + profile: str = Header(default=DEFAULT_PROFILE, alias=PROFILE_HEADER), + store: Store = Depends(store), +) -> VCRecord: + """Retrieve a credential by id.""" + async with store.session(profile=profile) as txn: + entry = await txn.fetch(VC_HOLDER_CAT, record_id) + if not entry: + raise ProblemDetailsException.NotFound( + f"No credential record found for id {record_id}" + ) + + return entry_to_vc_record(entry) + + +@app.get( + "/vc-holder/credential/given/{record_id}", + tags=["vc-holder"], + response_description="Retrieved credential", +) +async def retrieve_credential_by_given_id( + given_id: str, + profile: str = Header(default=DEFAULT_PROFILE, alias=PROFILE_HEADER), + store: Store = Depends(store), +) -> VCRecord: + """Retrieve a credential by id.""" + async with store.session(profile=profile) as txn: + entries = await txn.fetch_all(VC_HOLDER_CAT, {"given_id": given_id}, limit=2) + if not entries: + raise ProblemDetailsException.NotFound( + f"No credential record found for given id {given_id}" + ) + + if len(entries) > 1: + raise ProblemDetailsException.BadRequest( + f"Duplicate record found for given id {given_id}" + ) + + return entry_to_vc_record(entries[0]) + + +@app.delete( + "/vc-holder/credential/record/{record_id}", + tags=["vc-holder"], + response_description="Retrieved credential", +) +async def delete_credential( + record_id: str, + profile: str = Header(default=DEFAULT_PROFILE, alias=PROFILE_HEADER), + store: Store = Depends(store), +) -> None: + """Delete a credential.""" + async with store.session(profile=profile) as txn: + # TODO error handling + await txn.remove(VC_HOLDER_CAT, record_id) + + +class VCRecords(BaseModel): + """Records from a search.""" + + records: List[VCRecord] + + +def build_type_or_schema_query(uri_list: Sequence[str]) -> dict: + """Build and return indy-specific type_or_schema_query.""" + type_or_schema_query: dict[str, Any] = {} + for uri in uri_list: + q = {"$or": [{"type": uri}, {"schema": uri}]} + if type_or_schema_query: + if "$and" not in type_or_schema_query: + type_or_schema_query = {"$and": [type_or_schema_query]} + type_or_schema_query["$and"].append(q) + else: + type_or_schema_query = q + return type_or_schema_query + + +class CredSearchReq(BaseModel): + """Credential search request body.""" + + contexts: Optional[List[str]] = None + types: Optional[List[str]] = None + schema_ids: Optional[List[str]] = None + issuer_id: Optional[str] = None + subject_ids: Optional[str] = None + proof_types: Optional[List[str]] = None + given_id: Optional[str] = None + tag_query: Optional[Mapping] = None + pd_uri_list: Optional[List[str]] = None + offset: int = 0 + limit: int = 10 + + +@app.post( + "/vc-holder/credentials", + tags=["vc-holder"], + response_description="Retrieved credentials", +) +async def search_credentials( # noqa: C901 + req: CredSearchReq, + profile: str = Header(default=DEFAULT_PROFILE, alias=PROFILE_HEADER), + store: Store = Depends(store), +) -> VCRecords: + """Search for credentials.""" + offset = req.offset or 0 + offset = 0 if offset < 0 else offset + limit = req.limit or 10 + limit = 50 if limit > 50 else limit + + def _match_any(query: list, k, vals): + if vals is None: + pass + elif len(vals) > 1: + query.append({"$or": [{k: v for v in vals}]}) + else: + query.append({k: vals[0]}) + + def _make_custom_query(query): + result = {} + for k, v in query.items(): + if isinstance(v, (list, set)) and k != "$exist": + result[k] = [_make_custom_query(cl) for cl in v] + elif k.startswith("$"): + result[k] = v + else: + result[f"cstm:{k}"] = v + return result + + query = [] + _match_any(query, "contexts", req.contexts) + _match_any(query, "expanded_types", req.types) + _match_any(query, "schema_ids", req.schema_ids) + _match_any(query, "subject_ids", req.subject_ids) + _match_any(query, "proof_types", req.proof_types) + if req.issuer_id: + query.append({"issuer_id": req.issuer_id}) + if req.given_id: + query.append({"given_id": req.given_id}) + if req.tag_query: + query.append(_make_custom_query(req.tag_query)) + if req.pd_uri_list: + query.append(build_type_or_schema_query(req.pd_uri_list)) + + query = {"$and": query} if query else {} + scan = store.scan(VC_HOLDER_CAT, query, offset=offset, limit=limit, profile=profile) + entries = await scan.fetch_all() + return VCRecords(records=[entry_to_vc_record(entry) for entry in entries])