diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index 998632a846..2e3bf99463 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -7,6 +7,7 @@ from ...connections.models.connection_record import ConnectionRecord from ...core.protocol_registry import ProtocolRegistry from ...messaging.agent_message import AgentMessage, AgentMessageSchema +from ...messaging.responder import MockResponder from ...messaging.util import datetime_now from ...protocols.problem_report.v1_0.message import ProblemReport @@ -301,3 +302,22 @@ async def test_create_outbound_send_webhook(self): result = await responder.create_outbound(message) assert json.loads(result.payload)["@type"] == StubAgentMessage.Meta.message_type await responder.send_webhook("topic", "payload") + + async def test_create_send_outbound(self): + message = StubAgentMessage() + responder = MockResponder() + outbound_message = await responder.create_outbound(message) + await responder.send_outbound(outbound_message) + assert len(responder.messages) == 1 + + async def test_create_enc_outbound(self): + context = make_context() + message = b"abc123xyz7890000" + responder = test_module.DispatcherResponder( + context, message, None, async_mock.CoroutineMock() + ) + with async_mock.patch.object( + responder, "send_outbound", async_mock.CoroutineMock() + ) as mock_send_outbound: + await responder.send(message) + assert mock_send_outbound.called_once() diff --git a/aries_cloudagent/messaging/agent_message.py b/aries_cloudagent/messaging/agent_message.py index 5d0225d820..b7b471d151 100644 --- a/aries_cloudagent/messaging/agent_message.py +++ b/aries_cloudagent/messaging/agent_message.py @@ -61,7 +61,7 @@ def __init__(self, _id: str = None, _decorators: BaseDecoratorSet = None): TypeError: If message type is missing on subclass Meta class """ - super(AgentMessage, self).__init__() + super().__init__() if _id: self._message_id = _id self._message_new_id = False @@ -414,13 +414,7 @@ def __init__(self, *args, **kwargs): TypeError: If Meta.model_class has not been set """ - super(AgentMessageSchema, self).__init__(*args, **kwargs) - if not self.Meta.model_class: - raise TypeError( - "Can't instantiate abstract class {} with no model_class".format( - self.__class__.__name__ - ) - ) + super().__init__(*args, **kwargs) self._decorators = DecoratorSet() self._decorators_dict = None self._signatures = {} diff --git a/aries_cloudagent/messaging/decorators/signature_decorator.py b/aries_cloudagent/messaging/decorators/signature_decorator.py index 374c9e9f49..e31a35b0eb 100644 --- a/aries_cloudagent/messaging/decorators/signature_decorator.py +++ b/aries_cloudagent/messaging/decorators/signature_decorator.py @@ -92,7 +92,7 @@ def decode(self) -> (object, int): """ msg_bin = b64_to_bytes(self.sig_data, urlsafe=True) (timestamp,) = struct.unpack_from("!Q", msg_bin, 0) - return json.loads(msg_bin[8:]), timestamp + return (json.loads(msg_bin[8:]), timestamp) async def verify(self, wallet: BaseWallet) -> bool: """ @@ -133,7 +133,9 @@ class Meta: data_key="@type", required=True, description="Signature type", - example="did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/signature/1.0/ed25519Sha512_single", + example=( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" "spec/signature/1.0/ed25519Sha512_single" + ), ) signature = fields.Str( required=True, diff --git a/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py b/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py index 8c72ca888b..3a78075933 100644 --- a/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py +++ b/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py @@ -330,8 +330,9 @@ def test_indy_dict(self): assert lynx_str == lynx_list assert lynx_str != links + assert links != DATA_LINKS # has sha256 - def test_indy_dict(self): + def test_from_aries_msg(self): deco_aries = AttachDecorator.from_aries_msg( message=INDY_CRED, ident=IDENT, description=DESCRIPTION, ) diff --git a/aries_cloudagent/messaging/decorators/tests/test_signature_decorator.py b/aries_cloudagent/messaging/decorators/tests/test_signature_decorator.py new file mode 100644 index 0000000000..3beead9fc5 --- /dev/null +++ b/aries_cloudagent/messaging/decorators/tests/test_signature_decorator.py @@ -0,0 +1,62 @@ +import pytest + +from asynctest import TestCase as AsyncTestCase, mock as async_mock + +from ....protocols.trustping.v1_0.messages.ping import Ping +from ....wallet.basic import BasicWallet +from .. import signature_decorator as test_module +from ..signature_decorator import SignatureDecorator + +TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + + +class TestSignatureDecorator(AsyncTestCase): + async def test_init(self): + decorator = SignatureDecorator() + assert decorator.signature_type is None + assert decorator.signature is None + assert decorator.sig_data is None + assert decorator.signer is None + assert "SignatureDecorator" in str(decorator) + + async def test_serialize_load(self): + TEST_SIG = "IkJvYiI=" + TEST_SIG_DATA = "MTIzNDU2Nzg5MCJCb2Ii" + + decorator = SignatureDecorator( + signature_type=SignatureDecorator.TYPE_ED25519SHA512, + signature=TEST_SIG, + sig_data=TEST_SIG_DATA, + signer=TEST_VERKEY, + ) + + dumped = decorator.serialize() + loaded = SignatureDecorator.deserialize(dumped) + + assert loaded.signature_type == SignatureDecorator.TYPE_ED25519SHA512 + assert loaded.signature == TEST_SIG + assert loaded.sig_data == TEST_SIG_DATA + assert loaded.signer == TEST_VERKEY + + async def test_create_decode_verify(self): + TEST_MESSAGE = "Hello world" + TEST_TIMESTAMP = 1234567890 + wallet = BasicWallet() + key_info = await wallet.create_signing_key() + + deco = await SignatureDecorator.create( + Ping(), key_info.verkey, wallet, timestamp=None + ) + assert deco + + deco = await SignatureDecorator.create( + TEST_MESSAGE, key_info.verkey, wallet, TEST_TIMESTAMP + ) + + (msg, timestamp) = deco.decode() + assert msg == TEST_MESSAGE + assert timestamp == TEST_TIMESTAMP + + await deco.verify(wallet) + deco.signature_type = "unsupported-sig-type" + assert not await deco.verify(wallet) diff --git a/aries_cloudagent/messaging/models/base.py b/aries_cloudagent/messaging/models/base.py index e516685841..daa5677397 100644 --- a/aries_cloudagent/messaging/models/base.py +++ b/aries_cloudagent/messaging/models/base.py @@ -221,7 +221,7 @@ def __init__(self, *args, **kwargs): TypeError: If model_class is not set on Meta """ - super(BaseModelSchema, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if not self.Meta.model_class: raise TypeError( "Can't instantiate abstract class {} with no model_class".format( diff --git a/aries_cloudagent/messaging/models/tests/test_base.py b/aries_cloudagent/messaging/models/tests/test_base.py index 98eaf35e84..01ac107596 100644 --- a/aries_cloudagent/messaging/models/tests/test_base.py +++ b/aries_cloudagent/messaging/models/tests/test_base.py @@ -11,7 +11,7 @@ from ...responder import BaseResponder, MockResponder from ...util import time_now -from ..base import BaseModel, BaseModelSchema +from ..base import BaseModel, BaseModelError, BaseModelSchema class ModelImpl(BaseModel): @@ -26,7 +26,7 @@ class SchemaImpl(BaseModelSchema): class Meta: model_class = ModelImpl - attr = fields.String() + attr = fields.String(required=True) @validates_schema def validate_fields(self, data, **kwargs): @@ -44,3 +44,21 @@ def test_model_validate_succeeds(self): model = ModelImpl(attr="succeeds") model = model.validate() assert model.attr == "succeeds" + + def test_ser_x(self): + model = ModelImpl(attr="hello world") + with async_mock.patch.object( + model, "_get_schema_class", async_mock.MagicMock() + ) as mock_get_schema_class: + mock_get_schema_class.return_value = async_mock.MagicMock( + return_value=async_mock.MagicMock( + dump=async_mock.MagicMock(side_effect=ValidationError("error")) + ) + ) + with self.assertRaises(BaseModelError): + model.serialize() + + def test_from_json_x(self): + data = "{}{}" + with self.assertRaises(BaseModelError): + ModelImpl.from_json(data) diff --git a/aries_cloudagent/messaging/models/tests/test_base_record.py b/aries_cloudagent/messaging/models/tests/test_base_record.py index 8592c510ad..d02719ba0d 100644 --- a/aries_cloudagent/messaging/models/tests/test_base_record.py +++ b/aries_cloudagent/messaging/models/tests/test_base_record.py @@ -1,10 +1,12 @@ import json from asynctest import TestCase as AsyncTestCase, mock as async_mock +from marshmallow import fields from ....cache.base import BaseCache from ....config.injection_context import InjectionContext -from ....storage.base import BaseStorage, StorageRecord +from ....storage.base import BaseStorage, StorageDuplicateError, StorageRecord +from ....storage.basic import BasicStorage from ...responder import BaseResponder, MockResponder from ...util import time_now @@ -25,6 +27,36 @@ class Meta: model_class = BaseRecordImpl +class ARecordImpl(BaseRecord): + class Meta: + schema_class = "ARecordImplSchema" + + RECORD_TYPE = "a-record" + CACHE_ENABLED = False + RECORD_ID_NAME = "ident" + TAG_NAMES = {"code"} + + def __init__(self, *, ident=None, a, b, code, **kwargs): + super().__init__(ident, **kwargs) + self.a = a + self.b = b + self.code = code + + @property + def record_value(self) -> dict: + return {"a": self.a, "b": self.b} + + +class ARecordImplSchema(BaseRecordSchema): + class Meta: + model_class = BaseRecordImpl + + ident = fields.Str(attribute="_id") + a = fields.Str() + b = fields.Str() + code = fields.Str() + + class UnencTestImpl(BaseRecord): TAG_NAMES = {"~a", "~b", "c"} @@ -42,6 +74,10 @@ def test_from_storage_values(self): assert inst._id == record_id assert inst.value == stored + stored[BaseRecordImpl.RECORD_ID_NAME] = inst._id + with self.assertRaises(ValueError): + BaseRecordImpl.from_storage(record_id, stored) + async def test_post_save_new(self): context = InjectionContext(enforce_typing=False) mock_storage = async_mock.MagicMock() @@ -74,12 +110,15 @@ async def test_post_save_exist(self): mock_storage.update_record_tags.assert_called_once() async def test_cache(self): + assert not await BaseRecordImpl.get_cached_key(None, None) + await BaseRecordImpl.set_cached_key(None, None, None) + await BaseRecordImpl.clear_cached_key(None, None) context = InjectionContext(enforce_typing=False) mock_cache = async_mock.MagicMock(BaseCache, autospec=True) context.injector.bind_instance(BaseCache, mock_cache) record = BaseRecordImpl() cache_key = "cache_key" - cache_result = await record.get_cached_key(context, cache_key) + cache_result = await BaseRecordImpl.get_cached_key(context, cache_key) mock_cache.get.assert_awaited_once_with(cache_key) assert cache_result is mock_cache.get.return_value @@ -109,6 +148,39 @@ async def test_retrieve_cached_id(self): assert result._id == record_id assert result.value == stored + async def test_retrieve_by_tag_filter_multi_x_delete(self): + context = InjectionContext(enforce_typing=False) + basic_storage = BasicStorage() + context.injector.bind_instance(BaseStorage, basic_storage) + records = [] + for i in range(3): + records.append(ARecordImpl(a="1", b=str(i), code="one")) + await records[i].save(context) + with self.assertRaises(StorageDuplicateError): + await ARecordImpl.retrieve_by_tag_filter( + context, {"code": "one"}, {"a": "1"} + ) + await records[0].delete_record(context) + + async def test_save_x(self): + context = InjectionContext(enforce_typing=False) + basic_storage = BasicStorage() + context.injector.bind_instance(BaseStorage, basic_storage) + rec = ARecordImpl(a="1", b="0", code="one") + with async_mock.patch.object( + context, "inject", async_mock.CoroutineMock() + ) as mock_inject: + mock_inject.return_value = async_mock.MagicMock( + add_record=async_mock.CoroutineMock(side_effect=ZeroDivisionError()) + ) + with self.assertRaises(ZeroDivisionError): + await rec.save(context) + + async def test_neq(self): + a_rec = ARecordImpl(a="1", b="0", code="one") + b_rec = BaseRecordImpl() + assert a_rec != b_rec + async def test_retrieve_uncached_id(self): context = InjectionContext(enforce_typing=False) mock_storage = async_mock.MagicMock(BaseStorage, autospec=True) @@ -163,7 +235,7 @@ def test_log_state(self, mock_print): BaseRecordImpl, "LOG_STATE_FLAG", test_param ) as cls: record = BaseRecordImpl() - record.log_state(context, "state") + record.log_state(context, msg="state", params={"a": "1", "b": "2"}) mock_print.assert_called_once() @async_mock.patch("builtins.print") @@ -180,6 +252,8 @@ async def test_webhook(self): record = BaseRecordImpl() payload = {"test": "payload"} topic = "topic" + await record.send_webhook(context, None, None) # cover short circuit + await record.send_webhook(context, "hello", None) # cover short circuit await record.send_webhook(context, payload, topic=topic) assert mock_responder.webhooks == [(topic, payload)] @@ -190,6 +264,11 @@ async def test_tag_prefix(self): tags = {"a": "x", "b": "y", "c": "z"} assert UnencTestImpl.prefix_tag_filter(tags) == {"~a": "x", "~b": "y", "c": "z"} + tags = {"$not": {"a": "x", "b": "y", "c": "z"}} + expect = {"$not": {"~a": "x", "~b": "y", "c": "z"}} + actual = UnencTestImpl.prefix_tag_filter(tags) + assert {**expect} == {**actual} + tags = {"$or": [{"a": "x"}, {"c": "z"}]} assert UnencTestImpl.prefix_tag_filter(tags) == { "$or": [{"~a": "x"}, {"c": "z"}] diff --git a/aries_cloudagent/messaging/request_context.py b/aries_cloudagent/messaging/request_context.py index 61e068525b..fd43f677dd 100644 --- a/aries_cloudagent/messaging/request_context.py +++ b/aries_cloudagent/messaging/request_context.py @@ -80,7 +80,7 @@ def default_endpoint(self) -> str: The default agent endpoint """ - return self.settings["default_endpoint"] + return self.settings.get("default_endpoint") @default_endpoint.setter def default_endpoint(self, endpoint: str): diff --git a/aries_cloudagent/messaging/tests/test_agent_message.py b/aries_cloudagent/messaging/tests/test_agent_message.py index 85536d7f75..ed54c47b76 100644 --- a/aries_cloudagent/messaging/tests/test_agent_message.py +++ b/aries_cloudagent/messaging/tests/test_agent_message.py @@ -2,10 +2,13 @@ from marshmallow import fields import json +from ...wallet.basic import BasicWallet +from ...wallet.util import bytes_to_b64 + from ..agent_message import AgentMessage, AgentMessageSchema from ..decorators.signature_decorator import SignatureDecorator from ..decorators.trace_decorator import TraceReport, TRACE_LOG_TARGET -from ...wallet.basic import BasicWallet +from ..models.base import BaseModelError class SignedAgentMessage(AgentMessage): @@ -19,7 +22,7 @@ class Meta: message_type = "signed-agent-message" def __init__(self, value: str = None, **kwargs): - super(SignedAgentMessage, self).__init__(**kwargs) + super().__init__(**kwargs) self.value = value @@ -46,18 +49,22 @@ class Meta: class TestAgentMessage(AsyncTestCase): """Tests agent message.""" - class BadImplementationClass(AgentMessage): - """Test utility class.""" - - pass - def test_init(self): """Tests init class""" - SignedAgentMessage() + + class BadImplementationClass(AgentMessage): + """Test utility class.""" + + message = SignedAgentMessage() + message._id = "12345" with self.assertRaises(TypeError) as context: - self.BadImplementationClass() # pylint: disable=E0110 + BadImplementationClass() # pylint: disable=E0110 + assert "Can't instantiate abstract" in str(context.exception) + BadImplementationClass.Meta.schema_class = "AgentMessageSchema" + with self.assertRaises(TypeError) as context: + BadImplementationClass() # pylint: disable=E0110 assert "Can't instantiate abstract" in str(context.exception) async def test_field_signature(self): @@ -65,7 +72,16 @@ async def test_field_signature(self): key_info = await wallet.create_signing_key() msg = SignedAgentMessage() + msg.value = None + with self.assertRaises(BaseModelError) as context: + await msg.sign_field("value", key_info.verkey, wallet) + assert "field has no value for signature" in str(context.exception) + msg.value = "Test value" + with self.assertRaises(BaseModelError) as context: + msg.serialize() + assert "Missing signature for field" in str(context.exception) + await msg.sign_field("value", key_info.verkey, wallet) sig = msg.get_signature("value") assert isinstance(sig, SignatureDecorator) @@ -74,9 +90,27 @@ async def test_field_signature(self): assert await msg.verify_signed_field("value", wallet) == key_info.verkey assert await msg.verify_signatures(wallet) + with self.assertRaises(BaseModelError) as context: + await msg.verify_signed_field("value", wallet, "bogus-verkey") + assert "Signer verkey of signature does not match" in str(context.exception) + serial = msg.serialize() assert "value~sig" in serial and "value" not in serial + (_, timestamp) = msg._decorators.field("value")["sig"].decode() + tamper_deco = await SignatureDecorator.create("tamper", key_info.verkey, wallet) + msg._decorators.field("value")["sig"].sig_data = tamper_deco.sig_data + with self.assertRaises(BaseModelError) as context: + await msg.verify_signed_field("value", wallet) + assert "Field signature verification failed" in str(context.exception) + assert not await msg.verify_signatures(wallet) + + msg.value = "Test value" + msg._decorators.field("value").pop("sig") + with self.assertRaises(BaseModelError) as context: + await msg.verify_signed_field("value", wallet) + assert "Missing field signature" in str(context.exception) + loaded = SignedAgentMessage.deserialize(serial) assert isinstance(loaded, SignedAgentMessage) assert await loaded.verify_signed_field("value", wallet) == key_info.verkey @@ -89,6 +123,9 @@ async def test_assign_thread(self): assert reply._thread_id == msg._thread_id assert reply._thread_id != reply._id + msg.assign_thread_id(None, None) + assert not msg._thread + async def test_add_tracing(self): msg = BasicAgentMessage() msg.add_trace_decorator() @@ -148,3 +185,85 @@ async def test_add_tracing(self): assert msg_trace_report.outcome == trace_report2.outcome print("tracer:", tracer.serialize()) + + msg3 = BasicAgentMessage() + msg.add_trace_decorator() + assert msg._trace + + +class TestAgentMessageSchema(AsyncTestCase): + """Tests agent message schema.""" + + def test_init_x(self): + """Tests init class""" + + class BadImplementationClass(AgentMessageSchema): + """Test utility class.""" + + with self.assertRaises(TypeError) as context: + BadImplementationClass() + assert "Can't instantiate abstract" in str(context.exception) + + def test_extract_decorators_x(self): + for serial in [ + { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "value": "Test value", + }, + { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "value": "Test value", + "value~sig": { + "@type": ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" + "spec/signature/1.0/ed25519Sha512_single" + ), + "signature": ( + "-OKdiRRQu-xbVGICg1J6KV_6nXLLzYRXr8BZSXzoXimytBl" + "O8ULY7Nl1lQPqahc-XQPHiBSVraLM8XN_sCzdCg==" + ), + "sig_data": "AAAAAF8bIV4iVGVzdCB2YWx1ZSI=", + "signer": "7VA3CaF9jaTuRN2SGmekANoja6Js4U51kfRSbpZAfdhy", + }, + }, + { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "superfluous~sig": { + "@type": ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" + "spec/signature/1.0/ed25519Sha512_single" + ), + "signature": ( + "-OKdiRRQu-xbVGICg1J6KV_6nXLLzYRXr8BZSXzoXimytBl" + "O8ULY7Nl1lQPqahc-XQPHiBSVraLM8XN_sCzdCg==" + ), + "sig_data": "AAAAAF8bIV4iVGVzdCB2YWx1ZSI=", + "signer": "7VA3CaF9jaTuRN2SGmekANoja6Js4U51kfRSbpZAfdhy", + }, + }, + ]: + with self.assertRaises(BaseModelError) as context: + SignedAgentMessage.deserialize(serial) + + def test_serde(self): + serial = { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "value~sig": { + "@type": ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" + "spec/signature/1.0/ed25519Sha512_single" + ), + "signature": ( + "-OKdiRRQu-xbVGICg1J6KV_6nXLLzYRXr8BZSXzoXimytBl" + "O8ULY7Nl1lQPqahc-XQPHiBSVraLM8XN_sCzdCg==" + ), + "sig_data": "AAAAAF8bIV4iVGVzdCB2YWx1ZSI=", + "signer": "7VA3CaF9jaTuRN2SGmekANoja6Js4U51kfRSbpZAfdhy", + }, + } + result = SignedAgentMessage.deserialize(serial) + result.serialize() diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py index ed235bc818..bb83673086 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py @@ -15,6 +15,8 @@ class TestConnectionRoutes(AsyncTestCase): async def test_connections_list(self): context = RequestContext(base_context=InjectionContext(enforce_typing=False)) + context.default_endpoint = "http://1.2.3.4:8081" # for coverage + assert context.default_endpoint == "http://1.2.3.4:8081" # for coverage mock_req = async_mock.MagicMock() mock_req.app = { "request_context": context, diff --git a/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py b/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py index 2ba4d02fc8..8859afac21 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py @@ -1,8 +1,23 @@ from unittest import TestCase as UnitTestCase +from ......messaging.models.base_record import BaseExchangeRecord, BaseExchangeSchema + from ..presentation_exchange import V10PresentationExchange +class BasexRecordImpl(BaseExchangeRecord): + class Meta: + schema_class = "BasexRecordImplSchema" + + RECORD_TYPE = "record" + CACHE_ENABLED = True + + +class BasexRecordImplSchema(BaseExchangeSchema): + class Meta: + model_class = BasexRecordImpl + + class TestRecord(UnitTestCase): def test_record(self): record = V10PresentationExchange( @@ -37,3 +52,6 @@ def test_record(self): "verified": False, "trace": False, } + + bx_record = BasexRecordImpl() + assert record != bx_record