Skip to content

Commit

Permalink
catch common exceptions in routes and raise web errors instead of 500…
Browse files Browse the repository at this point in the history
…: server got itself in trouble

Signed-off-by: sklump <[email protected]>
  • Loading branch information
sklump committed Jun 2, 2020
1 parent d95c9de commit 76a06f8
Show file tree
Hide file tree
Showing 23 changed files with 657 additions and 251 deletions.
24 changes: 14 additions & 10 deletions aries_cloudagent/holder/indy.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,15 +243,17 @@ async def get_credential(self, credential_id: str) -> str:
credential_json = await indy.anoncreds.prover_get_credential(
self.wallet.handle, credential_id
)
except IndyError as e:
if e.error_code == ErrorCode.WalletItemNotFound:
except IndyError as err:
if err.error_code == ErrorCode.WalletItemNotFound:
raise WalletNotFoundError(
"Credential not found in the wallet: {}".format(credential_id)
"Credential {} not found in wallet {}".format(
credential_id, self.wallet.name
)
)
else:
raise IndyErrorHandler.wrap_error(
e, "Error when fetching credential", HolderError
) from e
err, f"Error when fetching credential {credential_id}", HolderError
) from err

return credential_json

Expand All @@ -277,15 +279,17 @@ async def delete_credential(self, credential_id: str):
await indy.anoncreds.prover_delete_credential(
self.wallet.handle, credential_id
)
except IndyError as e:
if e.error_code == ErrorCode.WalletItemNotFound:
except IndyError as err:
if err.error_code == ErrorCode.WalletItemNotFound:
raise WalletNotFoundError(
"Credential not found in the wallet: {}".format(credential_id)
"Credential {} not found in wallet {}".format(
credential_id, self.wallet.name
)
)
else:
raise IndyErrorHandler.wrap_error(
e, "Error when deleting credential", HolderError
) from e
err, "Error when deleting credential", HolderError
) from err

async def get_mime_type(
self, credential_id: str, attr: str = None
Expand Down
12 changes: 6 additions & 6 deletions aries_cloudagent/holder/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ async def credentials_get(request: web.BaseRequest):
holder: BaseHolder = await context.inject(BaseHolder)
try:
credential = await holder.get_credential(credential_id)
except WalletNotFoundError:
raise web.HTTPNotFound()
except WalletNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err

credential_json = json.loads(credential)
return web.json_response(credential_json)
Expand All @@ -140,8 +140,8 @@ async def credentials_remove(request: web.BaseRequest):
holder: BaseHolder = await context.inject(BaseHolder)
try:
await holder.delete_credential(credential_id)
except WalletNotFoundError:
raise web.HTTPNotFound()
except WalletNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err

return web.json_response({})

Expand Down Expand Up @@ -178,8 +178,8 @@ async def credentials_list(request: web.BaseRequest):
holder: BaseHolder = await context.inject(BaseHolder)
try:
credentials = await holder.get_credentials(start, count, wql)
except HolderError as x_holder:
raise web.HTTPBadRequest(reason=x_holder.message)
except HolderError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({"results": credentials})

Expand Down
54 changes: 38 additions & 16 deletions aries_cloudagent/ledger/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from marshmallow import fields, Schema, validate

from ..messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY
from ..storage.error import StorageError
from ..wallet.error import WalletError
from .base import BaseLedger
from .error import BadLedgerRequestError, LedgerTransactionError
from .indy import Role
from .error import BadLedgerRequestError, LedgerError, LedgerTransactionError


class AMLRecordSchema(Schema):
Expand Down Expand Up @@ -69,7 +71,7 @@ class RegisterLedgerNymQueryStringSchema(Schema):
description="Role",
required=False,
validate=validate.OneOf(
["TRUSTEE", "STEWARD", "ENDORSER", "NETWORK_MONITOR", "reset"]
[r.name for r in Role if isinstance(r.value[0], int)] + ["reset"]
),
)

Expand Down Expand Up @@ -102,7 +104,9 @@ async def register_ledger_nym(request: web.BaseRequest):
did = request.query.get("did")
verkey = request.query.get("verkey")
if not did or not verkey:
raise web.HTTPBadRequest()
raise web.HTTPBadRequest(
reason="Request query must include both did and verkey"
)

alias = request.query.get("alias")
role = request.query.get("role")
Expand All @@ -114,8 +118,8 @@ async def register_ledger_nym(request: web.BaseRequest):
try:
await ledger.register_nym(did, verkey, alias, role)
success = True
except LedgerTransactionError as e:
raise web.HTTPForbidden(text=e.message)
except LedgerTransactionError as err:
raise web.HTTPForbidden(reason=err.roll_up)
return web.json_response({"success": success})


Expand All @@ -130,12 +134,15 @@ async def rotate_public_did_keypair(request: web.BaseRequest):
context = request.app["request_context"]
ledger = await context.inject(BaseLedger, required=False)
if not ledger:
raise web.HTTPForbidden()
reason = "No ledger available"
if not context.settings.get_value("wallet.type"):
reason += ": missing wallet-type?"
raise web.HTTPForbidden(reason=reason)
async with ledger:
try:
await ledger.rotate_public_did_keypair() # do not take seed over the wire
except (WalletError, BadLedgerRequestError):
raise web.HTTPBadRequest()
except (WalletError, BadLedgerRequestError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({})

Expand All @@ -161,11 +168,15 @@ async def get_did_verkey(request: web.BaseRequest):

did = request.query.get("did")
if not did:
raise web.HTTPBadRequest()
raise web.HTTPBadRequest(reason="Request query must include DID")

async with ledger:
r = await ledger.get_key_for_did(did)
return web.json_response({"verkey": r})
try:
result = await ledger.get_key_for_did(did)
except LedgerError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({"verkey": result})


@docs(
Expand All @@ -189,10 +200,14 @@ async def get_did_endpoint(request: web.BaseRequest):

did = request.query.get("did")
if not did:
raise web.HTTPBadRequest()
raise web.HTTPBadRequest(reason="Request query must include DID")

async with ledger:
r = await ledger.get_endpoint_for_did(did)
try:
r = await ledger.get_endpoint_for_did(did)
except LedgerError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({"endpoint": r})


Expand Down Expand Up @@ -247,18 +262,25 @@ async def ledger_accept_taa(request: web.BaseRequest):
context = request.app["request_context"]
ledger: BaseLedger = await context.inject(BaseLedger, required=False)
if not ledger or ledger.LEDGER_TYPE != "indy":
raise web.HTTPForbidden()
reason = "No indy ledger available"
if not context.settings.get_value("wallet.type"):
reason += ": missing wallet-type?"
raise web.HTTPForbidden(reason=reason)

accept_input = await request.json()
taa_info = await ledger.get_txn_author_agreement()
if not taa_info["taa_required"]:
raise web.HTTPBadRequest()
raise web.HTTPBadRequest(reason=f"Ledger {ledger.pool_name} TAA not available")
taa_record = {
"version": accept_input["version"],
"text": accept_input["text"],
"digest": ledger.taa_digest(accept_input["version"], accept_input["text"]),
}
await ledger.accept_txn_author_agreement(taa_record, accept_input["mechanism"])
try:
await ledger.accept_txn_author_agreement(taa_record, accept_input["mechanism"])
except StorageError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({})


Expand Down
57 changes: 44 additions & 13 deletions aries_cloudagent/ledger/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
from asynctest import mock as async_mock
import pytest

from aiohttp.web import HTTPBadRequest, HTTPForbidden

from ...config.injection_context import InjectionContext
from ...ledger.base import BaseLedger

Expand All @@ -14,6 +12,7 @@ class TestLedgerRoutes(AsyncTestCase):
def setUp(self):
self.context = InjectionContext(enforce_typing=False)
self.ledger = async_mock.create_autospec(BaseLedger)
self.ledger.pool_name = "pool.0"
self.context.injector.bind_instance(BaseLedger, self.ledger)
self.app = {
"outbound_message_router": async_mock.CoroutineMock(),
Expand All @@ -27,16 +26,16 @@ async def test_missing_ledger(self):
request = async_mock.MagicMock(app=self.app,)
self.context.injector.clear_binding(BaseLedger)

with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.register_ledger_nym(request)

with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.rotate_public_did_keypair(request)

with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.get_did_verkey(request)

with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.get_did_endpoint(request)

async def test_get_verkey(self):
Expand All @@ -57,7 +56,15 @@ async def test_get_verkey_no_did(self):
request = async_mock.MagicMock()
request.app = self.app
request.query = {"no": "did"}
with self.assertRaises(HTTPBadRequest):
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.get_did_verkey(request)

async def test_get_verkey_x(self):
request = async_mock.MagicMock()
request.app = self.app
request.query = {"did": self.test_did}
self.ledger.get_key_for_did.side_effect = test_module.LedgerError()
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.get_did_verkey(request)

async def test_get_endpoint(self):
Expand All @@ -78,9 +85,17 @@ async def test_get_endpoint_no_did(self):
request = async_mock.MagicMock()
request.app = self.app
request.query = {"no": "did"}
with self.assertRaises(HTTPBadRequest):
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.get_did_endpoint(request)

async def test_get_endpoint_x(self):
request = async_mock.MagicMock()
request.app = self.app
request.query = {"did": self.test_did}
self.ledger.get_endpoint_for_did.side_effect = test_module.LedgerError()
with self.assertRaises(test_module.web.HTTPBadRequest):
result = await test_module.get_did_endpoint(request)

async def test_register_nym(self):
request = async_mock.MagicMock(
app=self.app,
Expand All @@ -100,7 +115,7 @@ async def test_register_nym_bad_request(self):
request = async_mock.MagicMock()
request.app = self.app
request.query = {"no": "did"}
with self.assertRaises(HTTPBadRequest):
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.register_ledger_nym(request)

async def test_register_nym_ledger_txn_error(self):
Expand All @@ -110,7 +125,7 @@ async def test_register_nym_ledger_txn_error(self):
self.ledger.register_nym.side_effect = test_module.LedgerTransactionError(
"Error"
)
with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.register_ledger_nym(request)

async def test_rotate_public_did_keypair(self):
Expand Down Expand Up @@ -143,7 +158,7 @@ async def test_taa_forbidden(self):
request = async_mock.MagicMock()
request.app = self.app

with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.ledger_get_taa(request)

async def test_get_taa(self):
Expand Down Expand Up @@ -192,7 +207,7 @@ async def test_taa_accept_not_required(self):
}
)

with self.assertRaises(HTTPBadRequest):
with self.assertRaises(test_module.web.HTTPBadRequest):
self.ledger.LEDGER_TYPE = "indy"
self.ledger.get_txn_author_agreement.return_value = {"taa_required": False}
await test_module.ledger_accept_taa(request)
Expand Down Expand Up @@ -230,7 +245,23 @@ async def test_accept_taa_bad_ledger(self):
request.app = self.app

self.ledger.LEDGER_TYPE = "not-indy"
with self.assertRaises(HTTPForbidden):
with self.assertRaises(test_module.web.HTTPForbidden):
await test_module.ledger_accept_taa(request)

async def test_accept_taa_x(self):
request = async_mock.MagicMock()
request.app = self.app
request.json = async_mock.CoroutineMock(
return_value={
"version": "version",
"text": "text",
"mechanism": "mechanism",
}
)
self.ledger.LEDGER_TYPE = "indy"
self.ledger.get_txn_author_agreement.return_value = {"taa_required": True}
self.ledger.accept_txn_author_agreement.side_effect = test_module.StorageError()
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.ledger_accept_taa(request)

async def test_register(self):
Expand Down
16 changes: 13 additions & 3 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(
"""Initialize a new BaseRecord."""
if not self.RECORD_TYPE:
raise TypeError(
"Can't instantiate abstract class {} with no RECORD_TYPE".format(
"Cannot instantiate abstract class {} with no RECORD_TYPE".format(
self.__class__.__name__
)
)
Expand Down Expand Up @@ -241,10 +241,20 @@ async def retrieve_by_tag_filter(
vals = json.loads(record.value)
if match_post_filter(vals, post_filter):
if found:
raise StorageDuplicateError("Multiple records located")
raise StorageDuplicateError(
"Multiple {} records located for {}{}".format(
cls.__name__,
tag_filter,
f", {post_filter}" if post_filter else "",
)
)
found = cls.from_storage(record.id, vals)
if not found:
raise StorageNotFoundError("Record not found")
raise StorageNotFoundError(
"{} record not found for {}{}".format(
cls.__name__, tag_filter, f", {post_filter}" if post_filter else ""
)
)
return found

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions aries_cloudagent/messaging/schemas/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def schemas_send_schema(request: web.BaseRequest):
)
)
except (IssuerError, LedgerError) as err:
raise web.HTTPBadRequest(reason=err.roll_up)
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({"schema_id": schema_id, "schema": schema_def})

Expand Down Expand Up @@ -182,7 +182,7 @@ async def schemas_get_schema(request: web.BaseRequest):
try:
schema = await ledger.get_schema(schema_id)
except LedgerError as err:
raise web.HTTPBadRequest(reason=err.roll_up)
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({"schema": schema})

Expand Down
Loading

0 comments on commit 76a06f8

Please sign in to comment.