Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error audit protocols #543

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aries_cloudagent/issuer/indy.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ async def create_credential(
f"Revocation registry {revoc_reg_id} is full: cannot create credential"
)
raise IssuerRevocationRegistryFullError(
f"Revocation registry {revoc_reg_id} full"
f"Revocation registry {revoc_reg_id} is full"
)
except IndyError as error:
raise IndyErrorHandler.wrap_error(
Expand Down
1 change: 1 addition & 0 deletions aries_cloudagent/messaging/jsonld/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def sign(request: web.BaseRequest):
try:
context = request.app["request_context"]
wallet: BaseWallet = await context.inject(BaseWallet)
wallet: BaseWallet = await context.inject(BaseWallet)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Accidental commit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch: I will fix it

if not wallet:
raise web.HTTPForbidden()

Expand Down
10 changes: 7 additions & 3 deletions aries_cloudagent/protocols/actionmenu/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aries_cloudagent.connections.models.connection_record import ConnectionRecord
from aries_cloudagent.messaging.models.base import BaseModelError
from aries_cloudagent.messaging.valid import UUIDFour
from aries_cloudagent.storage.error import StorageNotFoundError
from aries_cloudagent.storage.error import StorageError, StorageNotFoundError

from .messages.menu import Menu
from .messages.menu_request import MenuRequest
Expand Down Expand Up @@ -91,7 +91,11 @@ async def actionmenu_close(request: web.BaseRequest):
reason=f"No {MENU_RECORD_TYPE} record found for connection {connection_id}"
)

await save_connection_menu(None, connection_id, context)
try:
await save_connection_menu(None, connection_id, context)
except StorageError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({})


Expand Down Expand Up @@ -190,7 +194,7 @@ async def actionmenu_send(request: web.BaseRequest):
msg = Menu.deserialize(menu_json["menu"])
except BaseModelError as err:
LOGGER.exception("Exception deserializing menu: %s", err.roll_up)
raise
raise web.HTTPBadRequest(reason=err.roll_up) from err

try:
connection = await ConnectionRecord.retrieve_by_id(context, connection_id)
Expand Down
19 changes: 18 additions & 1 deletion aries_cloudagent/protocols/actionmenu/v1_0/tests/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,23 @@ async def test_actionmenu_close(self):
res = await test_module.actionmenu_close(mock_request)
mock_response.assert_called_once_with({})

async def test_actionmenu_close_x(self):
mock_request = async_mock.MagicMock()
mock_request.json = async_mock.CoroutineMock()

mock_request.app = {
"outbound_message_router": async_mock.CoroutineMock(),
"request_context": "context",
}

test_module.retrieve_connection_menu = async_mock.CoroutineMock()
test_module.save_connection_menu = async_mock.CoroutineMock(
side_effect=test_module.StorageError()
)

with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.actionmenu_close(mock_request)

async def test_actionmenu_close_not_found(self):
mock_request = async_mock.MagicMock()
mock_request.json = async_mock.CoroutineMock()
Expand Down Expand Up @@ -244,7 +261,7 @@ async def test_actionmenu_send_deserialize_x(self):
side_effect=test_module.BaseModelError("cannot deserialize")
)

with self.assertRaises(test_module.BaseModelError):
with self.assertRaises(test_module.web.HTTPBadRequest):
await test_module.actionmenu_send(mock_request)

async def test_actionmenu_send_no_conn_record(self):
Expand Down
155 changes: 97 additions & 58 deletions aries_cloudagent/protocols/connections/v1_0/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,21 @@

from marshmallow import fields, Schema, validate, validates_schema

from aries_cloudagent.connections.models.connection_record import (
from ....connections.models.connection_record import (
ConnectionRecord,
ConnectionRecordSchema,
)
from aries_cloudagent.messaging.models.base import BaseModelError
from aries_cloudagent.messaging.valid import (
from ....messaging.models.base import BaseModelError
from ....messaging.valid import (
ENDPOINT,
INDY_DID,
INDY_RAW_PUBLIC_KEY,
UUIDFour,
)
from aries_cloudagent.storage.error import StorageNotFoundError
from ....storage.error import StorageError, StorageNotFoundError
from ....wallet.error import WalletError

from .manager import ConnectionManager
from .manager import ConnectionManager, ConnectionManagerError
from .message_types import SPEC_URI
from .messages.connection_invitation import (
ConnectionInvitation,
Expand Down Expand Up @@ -242,9 +243,12 @@ async def connections_list(request: web.BaseRequest):
):
if param_name in request.query and request.query[param_name] != "":
post_filter[param_name] = request.query[param_name]
records = await ConnectionRecord.query(context, tag_filter, post_filter)
results = [record.serialize() for record in records]
results.sort(key=connection_sort_key)
try:
records = await ConnectionRecord.query(context, tag_filter, post_filter)
results = [record.serialize() for record in records]
results.sort(key=connection_sort_key)
except (StorageError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err
return web.json_response({"results": results})


Expand All @@ -264,11 +268,16 @@ async def connections_retrieve(request: web.BaseRequest):
"""
context = request.app["request_context"]
connection_id = request.match_info["conn_id"]

try:
record = await ConnectionRecord.retrieve_by_id(context, connection_id)
result = record.serialize()
except StorageNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err
return web.json_response(record.serialize())
except BaseModelError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response(result)


@docs(
Expand Down Expand Up @@ -300,14 +309,18 @@ async def connections_create_invitation(request: web.BaseRequest):
base_url = context.settings.get("invite_base_url")

connection_mgr = ConnectionManager(context)
(connection, invitation) = await connection_mgr.create_invitation(
auto_accept=auto_accept, public=public, multi_use=multi_use, alias=alias
)
result = {
"connection_id": connection and connection.connection_id,
"invitation": invitation.serialize(),
"invitation_url": invitation.to_url(base_url),
}
try:
(connection, invitation) = await connection_mgr.create_invitation(
auto_accept=auto_accept, public=public, multi_use=multi_use, alias=alias
)

result = {
"connection_id": connection and connection.connection_id,
"invitation": invitation.serialize(),
"invitation_url": invitation.to_url(base_url),
}
except (ConnectionManagerError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

if connection and connection.alias:
result["alias"] = connection.alias
Expand Down Expand Up @@ -339,18 +352,19 @@ async def connections_receive_invitation(request: web.BaseRequest):
)
connection_mgr = ConnectionManager(context)
invitation_json = await request.json()

try:
invitation = ConnectionInvitation.deserialize(invitation_json)
except BaseModelError as err:
auto_accept = json.loads(request.query.get("auto_accept", "null"))
alias = request.query.get("alias")
connection = await connection_mgr.receive_invitation(
invitation, auto_accept=auto_accept, alias=alias
)
result = connection.serialize()
except (ConnectionManagerError, StorageError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

auto_accept = json.loads(request.query.get("auto_accept", "null"))
alias = request.query.get("alias")

connection = await connection_mgr.receive_invitation(
invitation, auto_accept=auto_accept, alias=alias
)
return web.json_response(connection.serialize())
return web.json_response(result)


@docs(
Expand All @@ -373,16 +387,21 @@ async def connections_accept_invitation(request: web.BaseRequest):
context = request.app["request_context"]
outbound_handler = request.app["outbound_message_router"]
connection_id = request.match_info["conn_id"]

try:
connection = await ConnectionRecord.retrieve_by_id(context, connection_id)
connection_mgr = ConnectionManager(context)
my_label = request.query.get("my_label") or None
my_endpoint = request.query.get("my_endpoint") or None
request = await connection_mgr.create_request(connection, my_label, my_endpoint)
result = connection.serialize()
except StorageNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err
connection_mgr = ConnectionManager(context)
my_label = request.query.get("my_label") or None
my_endpoint = request.query.get("my_endpoint") or None
request = await connection_mgr.create_request(connection, my_label, my_endpoint)
except (StorageError, WalletError, ConnectionManagerError, BaseModelError) as err:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should it just catch BaseError?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's OK to be specific here: catching an omnibus error will probably make LGTM-bot squawk.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to see what can happen in particular.

raise web.HTTPBadRequest(reason=err.roll_up) from err

await outbound_handler(request, connection_id=connection.connection_id)
return web.json_response(connection.serialize())
return web.json_response(result)


@docs(
Expand All @@ -405,15 +424,20 @@ async def connections_accept_request(request: web.BaseRequest):
context = request.app["request_context"]
outbound_handler = request.app["outbound_message_router"]
connection_id = request.match_info["conn_id"]

try:
connection = await ConnectionRecord.retrieve_by_id(context, connection_id)
connection_mgr = ConnectionManager(context)
my_endpoint = request.query.get("my_endpoint") or None
response = await connection_mgr.create_response(connection, my_endpoint)
result = connection.serialize()
except StorageNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err
connection_mgr = ConnectionManager(context)
my_endpoint = request.query.get("my_endpoint") or None
response = await connection_mgr.create_response(connection, my_endpoint)
except (StorageError, WalletError, ConnectionManagerError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

await outbound_handler(response, connection_id=connection.connection_id)
return web.json_response(connection.serialize())
return web.json_response(result)


@docs(
Expand All @@ -431,14 +455,18 @@ async def connections_establish_inbound(request: web.BaseRequest):
connection_id = request.match_info["conn_id"]
outbound_handler = request.app["outbound_message_router"]
inbound_connection_id = request.match_info["ref_id"]

try:
connection = await ConnectionRecord.retrieve_by_id(context, connection_id)
connection_mgr = ConnectionManager(context)
await connection_mgr.establish_inbound(
connection, inbound_connection_id, outbound_handler
)
except StorageNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err
connection_mgr = ConnectionManager(context)
await connection_mgr.establish_inbound(
connection, inbound_connection_id, outbound_handler
)
except (StorageError, WalletError, ConnectionManagerError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({})


Expand All @@ -453,11 +481,15 @@ async def connections_remove(request: web.BaseRequest):
"""
context = request.app["request_context"]
connection_id = request.match_info["conn_id"]

try:
connection = await ConnectionRecord.retrieve_by_id(context, connection_id)
await connection.delete_record(context)
except StorageNotFoundError as err:
raise web.HTTPNotFound(reason=err.roll_up) from err
await connection.delete_record(context)
except StorageError as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response({})


Expand All @@ -479,25 +511,32 @@ async def connections_create_static(request: web.BaseRequest):
body = await request.json()

connection_mgr = ConnectionManager(context)
(my_info, their_info, connection) = await connection_mgr.create_static_connection(
my_seed=body.get("my_seed") or None,
my_did=body.get("my_did") or None,
their_seed=body.get("their_seed") or None,
their_did=body.get("their_did") or None,
their_verkey=body.get("their_verkey") or None,
their_endpoint=body.get("their_endpoint") or None,
their_role=body.get("their_role") or None,
their_label=body.get("their_label") or None,
alias=body.get("alias") or None,
)
response = {
"my_did": my_info.did,
"my_verkey": my_info.verkey,
"my_endpoint": context.settings.get("default_endpoint"),
"their_did": their_info.did,
"their_verkey": their_info.verkey,
"record": connection.serialize(),
}
try:
(
my_info,
their_info,
connection,
) = await connection_mgr.create_static_connection(
my_seed=body.get("my_seed") or None,
my_did=body.get("my_did") or None,
their_seed=body.get("their_seed") or None,
their_did=body.get("their_did") or None,
their_verkey=body.get("their_verkey") or None,
their_endpoint=body.get("their_endpoint") or None,
their_role=body.get("their_role") or None,
their_label=body.get("their_label") or None,
alias=body.get("alias") or None,
)
response = {
"my_did": my_info.did,
"my_verkey": my_info.verkey,
"my_endpoint": context.settings.get("default_endpoint"),
"their_did": their_info.did,
"their_verkey": their_info.verkey,
"record": connection.serialize(),
}
except (WalletError, StorageError, BaseModelError) as err:
raise web.HTTPBadRequest(reason=err.roll_up) from err

return web.json_response(response)

Expand Down
Loading