diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a23d3c6ee..2e9e151515 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,21 @@ +# 0.5.3 + +## July 23, 2020 + +- Store endpoint on provisioned DID records [#610](https://github.com/hyperledger/aries-cloudagent-python/pull/610) +- More reliable delivery of outbound messages and webhooks [#615](https://github.com/hyperledger/aries-cloudagent-python/pull/615) +- Improvements for OpenShift pod handling [#614](https://github.com/hyperledger/aries-cloudagent-python/pull/614) +- Remove support for 'on-demand' revocation registries [#605](https://github.com/hyperledger/aries-cloudagent-python/pull/605) +- Sort tags in generated swagger JSON for better consistency [#602](https://github.com/hyperledger/aries-cloudagent-python/pull/602) +- Improve support for multi-credential proofs [#601](https://github.com/hyperledger/aries-cloudagent-python/pull/601) +- Adjust default settings for tracing and add documentation [#598](https://github.com/hyperledger/aries-cloudagent-python/pull/598), [#597](https://github.com/hyperledger/aries-cloudagent-python/pull/597) +- Fix reliance on local copy of revocation tails file [#590](https://github.com/hyperledger/aries-cloudagent-python/pull/590) +- Improved handling of problem reports [#595](https://github.com/hyperledger/aries-cloudagent-python/pull/595) +- Remove credential preview parameter from credential issue endpoint [#596](https://github.com/hyperledger/aries-cloudagent-python/pull/596) +- Looser format restrictions on dates [#586](https://github.com/hyperledger/aries-cloudagent-python/pull/586) +- Support `names` and attribute-value specifications in present-proof protocol [#587](https://github.com/hyperledger/aries-cloudagent-python/pull/587) +- Misc documentation updates and unit test coverage + # 0.5.2 ## June 26, 2020 diff --git a/aries_cloudagent/admin/server.py 
b/aries_cloudagent/admin/server.py index 24ef66b21c..f4556186e8 100644 --- a/aries_cloudagent/admin/server.py +++ b/aries_cloudagent/admin/server.py @@ -18,6 +18,7 @@ from ..config.injection_context import InjectionContext from ..core.plugin_registry import PluginRegistry +from ..ledger.error import LedgerConfigError, LedgerTransactionError from ..messaging.responder import BaseResponder from ..transport.queue.basic import BasicMessageQueue from ..transport.outbound.message import OutboundMessage @@ -51,7 +52,7 @@ class AdminStatusLivelinessSchema(Schema): class AdminStatusReadinessSchema(Schema): - """Schema for the liveliness endpoint.""" + """Schema for the readiness endpoint.""" ready = fields.Boolean(description="Readiness status", example=True) @@ -131,7 +132,18 @@ async def ready_middleware(request: web.BaseRequest, handler: Coroutine): "/status/live", "/status/ready", ) or request.app._state.get("ready"): - return await handler(request) + try: + return await handler(request) + except (LedgerConfigError, LedgerTransactionError) as e: + # fatal, signal server shutdown + LOGGER.error("Shutdown with %s", str(e)) + request.app._state["ready"] = False + request.app._state["alive"] = False + raise + except Exception as e: + # some other error? 
+ LOGGER.error("Handler error with exception: %s", str(e)) + raise e raise web.HTTPServiceUnavailable(reason="Shutdown in progress") @@ -295,6 +307,11 @@ async def collect_stats(request, handler): app=app, title=agent_label, version=version_string, swagger_path="/api/doc" ) app.on_startup.append(self.on_startup) + + # ensure we always have status values + app._state["ready"] = False + app._state["alive"] = False + return app async def start(self) -> None: @@ -329,6 +346,7 @@ async def start(self) -> None: try: await self.site.start() self.app._state["ready"] = True + self.app._state["alive"] = True except OSError: raise AdminSetupError( "Unable to start webserver with host " @@ -429,7 +447,11 @@ async def liveliness_handler(self, request: web.BaseRequest): The web response, always indicating True """ - return web.json_response({"alive": True}) + app_live = self.app._state["alive"] + if app_live: + return web.json_response({"alive": app_live}) + else: + raise web.HTTPServiceUnavailable(reason="Service not available") @docs(tags=["server"], summary="Readiness check") @response_schema(AdminStatusReadinessSchema(), 200) @@ -444,7 +466,11 @@ async def readiness_handler(self, request: web.BaseRequest): The web response, indicating readiness for further calls """ - return web.json_response({"ready": self.app._state["ready"]}) + app_ready = self.app._state["ready"] and self.app._state["alive"] + if app_ready: + return web.json_response({"ready": app_ready}) + else: + raise web.HTTPServiceUnavailable(reason="Service not ready") @docs(tags=["server"], summary="Shut down server") async def shutdown_handler(self, request: web.BaseRequest): @@ -464,6 +490,12 @@ async def shutdown_handler(self, request: web.BaseRequest): return web.json_response({}) + def notify_fatal_error(self): + """Set our readiness flags to force a restart (openshift).""" + LOGGER.error("Received shutdown request notify_fatal_error()") + self.app._state["ready"] = False + self.app._state["alive"] = False + 
async def websocket_handler(self, request): """Send notifications to admin client over websocket.""" diff --git a/aries_cloudagent/admin/tests/test_admin_server.py b/aries_cloudagent/admin/tests/test_admin_server.py index 449d839fc5..a65e2cdf43 100644 --- a/aries_cloudagent/admin/tests/test_admin_server.py +++ b/aries_cloudagent/admin/tests/test_admin_server.py @@ -177,6 +177,15 @@ async def test_import_routes(self): server = self.get_admin_server({"admin.admin_insecure_mode": True}, context) app = await server.make_application() + async def test_register_external_plugin_x(self): + context = InjectionContext() + context.injector.bind_instance(ProtocolRegistry, ProtocolRegistry()) + with self.assertRaises(ValueError): + builder = DefaultContextBuilder( + settings={"external_plugins": "aries_cloudagent.nosuchmodule"} + ) + await builder.load_plugins(context) + async def test_visit_insecure_mode(self): settings = {"admin.admin_insecure_mode": True, "task_queue": True} server = self.get_admin_server(settings) @@ -257,3 +266,36 @@ async def test_visit_shutting_down(self): ) as response: assert response.status == 200 await server.stop() + + async def test_server_health_state(self): + settings = { + "admin.admin_insecure_mode": True, + } + server = self.get_admin_server(settings) + await server.start() + + async with self.client_session.get( + f"http://127.0.0.1:{self.port}/status/live", headers={} + ) as response: + assert response.status == 200 + response_json = await response.json() + assert response_json["alive"] + + async with self.client_session.get( + f"http://127.0.0.1:{self.port}/status/ready", headers={} + ) as response: + assert response.status == 200 + response_json = await response.json() + assert response_json["ready"] + + server.notify_fatal_error() + async with self.client_session.get( + f"http://127.0.0.1:{self.port}/status/live", headers={} + ) as response: + assert response.status == 503 + + async with self.client_session.get( + 
f"http://127.0.0.1:{self.port}/status/ready", headers={} + ) as response: + assert response.status == 503 + await server.stop() diff --git a/aries_cloudagent/commands/__init__.py b/aries_cloudagent/commands/__init__.py index 57d6388aa3..36490aa705 100644 --- a/aries_cloudagent/commands/__init__.py +++ b/aries_cloudagent/commands/__init__.py @@ -20,8 +20,7 @@ def load_command(command: str): for cmd in available_commands(): if cmd["name"] == command: module = cmd["name"] - if "module" in cmd: - module_path = cmd["module"] + module_path = cmd.get("module") break if module and not module_path: module_path = f"{__package__}.{module}" diff --git a/aries_cloudagent/commands/help.py b/aries_cloudagent/commands/help.py index 537d8b86be..e818c46d9a 100644 --- a/aries_cloudagent/commands/help.py +++ b/aries_cloudagent/commands/help.py @@ -31,5 +31,10 @@ def execute(argv: Sequence[str] = None): parser.print_help() -if __name__ == "__main__": - execute() +def main(): + """Execute the main line.""" + if __name__ == "__main__": + execute() + + +main() diff --git a/aries_cloudagent/commands/provision.py b/aries_cloudagent/commands/provision.py index f4b54b44d8..327420b3dc 100644 --- a/aries_cloudagent/commands/provision.py +++ b/aries_cloudagent/commands/provision.py @@ -56,5 +56,10 @@ def execute(argv: Sequence[str] = None): loop.run_until_complete(provision(settings)) -if __name__ == "__main__": - execute() +def main(): + """Execute the main line.""" + if __name__ == "__main__": + execute() + + +main() diff --git a/aries_cloudagent/commands/tests/test_help.py b/aries_cloudagent/commands/tests/test_help.py index b96a5462f6..1fd3ed4d19 100644 --- a/aries_cloudagent/commands/tests/test_help.py +++ b/aries_cloudagent/commands/tests/test_help.py @@ -8,6 +8,20 @@ class TestHelp(AsyncTestCase): def test_exec_help(self): with async_mock.patch.object( command.ArgumentParser, "print_help" - ) as print_help: + ) as mock_print_help, async_mock.patch( + "builtins.print", 
async_mock.MagicMock() + ) as mock_print: command.execute([]) - print_help.assert_called_once() + mock_print_help.assert_called_once() + + command.execute(["-v"]) + mock_print.assert_called_once_with(command.__version__) + + def test_main(self): + with async_mock.patch.object( + command, "__name__", "__main__" + ) as mock_name, async_mock.patch.object( + command, "execute", async_mock.MagicMock() + ) as mock_execute: + command.main() + mock_execute.assert_called_once diff --git a/aries_cloudagent/commands/tests/test_init.py b/aries_cloudagent/commands/tests/test_init.py new file mode 100644 index 0000000000..f199b87649 --- /dev/null +++ b/aries_cloudagent/commands/tests/test_init.py @@ -0,0 +1,22 @@ +from asynctest import TestCase as AsyncTestCase +from asynctest import mock as async_mock + +from ... import commands as test_module + + +class TestInit(AsyncTestCase): + def test_available(self): + avail = test_module.available_commands() + assert len(avail) == 3 + + def test_run(self): + with async_mock.patch.object( + test_module, "load_command", async_mock.MagicMock() + ) as mock_load: + mock_module = async_mock.MagicMock() + mock_module.execute = async_mock.MagicMock() + mock_load.return_value = mock_module + + test_module.run_command("hello", ["world"]) + mock_load.assert_called_once() + mock_module.execute.assert_called_once() diff --git a/aries_cloudagent/commands/tests/test_provision.py b/aries_cloudagent/commands/tests/test_provision.py index 28c4a2b7ad..9f69979f00 100644 --- a/aries_cloudagent/commands/tests/test_provision.py +++ b/aries_cloudagent/commands/tests/test_provision.py @@ -17,3 +17,20 @@ def test_bad_calls(self): def test_provision_wallet(self): test_seed = "testseed000000000000000000000001" command.execute(["--wallet-type", "indy", "--seed", test_seed]) + + async def test_provision_ledger_configured(self): + with async_mock.patch.object( + command, "wallet_config", async_mock.CoroutineMock() + ) as mock_wallet_config, async_mock.patch.object( + 
command, "ledger_config", async_mock.CoroutineMock(return_value=True) + ) as mock_ledger_config: + await command.provision({}) + + def test_main(self): + with async_mock.patch.object( + command, "__name__", "__main__" + ) as mock_name, async_mock.patch.object( + command, "execute", async_mock.MagicMock() + ) as mock_execute: + command.main() + mock_execute.assert_called_once diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index a8cd752a8b..ad6ebc1a00 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -394,12 +394,24 @@ def add_arguments(self, parser: ArgumentParser): The endpoints are used in the formation of a connection\ with another agent.", ) + parser.add_argument( + "--profile-endpoint", + type=str, + metavar="", + help="Specifies the profile endpoint for the (public) DID.", + ) parser.add_argument( "--read-only-ledger", action="store_true", help="Sets ledger to read-only to prevent updates.\ Default: false.", ) + parser.add_argument( + "--tails-server-base-url", + type=str, + metavar="", + help="Sets the base url of the tails server in use.", + ) def get_settings(self, args: Namespace) -> dict: """Extract general settings.""" @@ -407,12 +419,16 @@ def get_settings(self, args: Namespace) -> dict: if args.external_plugins: settings["external_plugins"] = args.external_plugins if args.storage_type: - settings["storage.type"] = args.storage_type + settings["storage_type"] = args.storage_type if args.endpoint: settings["default_endpoint"] = args.endpoint[0] settings["additional_endpoints"] = args.endpoint[1:] + if args.profile_endpoint: + settings["profile_endpoint"] = args.profile_endpoint if args.read_only_ledger: settings["read_only_ledger"] = True + if args.tails_server_base_url: + settings["tails_server_base_url"] = args.tails_server_base_url return settings diff --git a/aries_cloudagent/config/base.py b/aries_cloudagent/config/base.py index 1f04378682..b21d02e412 100644 --- 
a/aries_cloudagent/config/base.py +++ b/aries_cloudagent/config/base.py @@ -73,12 +73,12 @@ def __iter__(self): def __getitem__(self, index): """Fetch as an array index.""" if not isinstance(index, str): - raise TypeError("Index must be a string") + raise TypeError(f"Index {index} must be a string") missing = object() result = self.get_value(index, default=missing) if result is missing: raise KeyError("Undefined index: {}".format(index)) - return self.get_value(index) + return result @abstractmethod def __len__(self): @@ -111,7 +111,7 @@ async def inject( base_cls: type, settings: Mapping[str, object] = None, *, - required: bool = True + required: bool = True, ) -> object: """ Get the provided instance of a given class identifier. diff --git a/aries_cloudagent/config/default_context.py b/aries_cloudagent/config/default_context.py index edfe1f8bbd..7fe3caf73f 100644 --- a/aries_cloudagent/config/default_context.py +++ b/aries_cloudagent/config/default_context.py @@ -13,6 +13,7 @@ from ..issuer.base import BaseIssuer from ..holder.base import BaseHolder from ..verifier.base import BaseVerifier +from ..tails.base import BaseTailsServer from ..protocols.actionmenu.v1_0.base_service import BaseMenuService from ..protocols.actionmenu.v1_0.driver_service import DriverMenuService @@ -119,6 +120,10 @@ async def bind_providers(self, context: InjectionContext): ClassProvider.Inject(BaseLedger), ), ) + context.injector.bind_provider( + BaseTailsServer, + ClassProvider("aries_cloudagent.tails.indy_tails_server.IndyTailsServer",), + ) # Register default pack format context.injector.bind_provider( diff --git a/aries_cloudagent/config/injection_context.py b/aries_cloudagent/config/injection_context.py index 89d1213978..0d6b5f340e 100644 --- a/aries_cloudagent/config/injection_context.py +++ b/aries_cloudagent/config/injection_context.py @@ -25,7 +25,7 @@ def __init__( ): """Initialize a `ServiceConfig`.""" self._injector = Injector(settings, enforce_typing=enforce_typing) - 
self._scope_name = self.ROOT_SCOPE + self._scope_name = InjectionContext.ROOT_SCOPE self._scopes = [] @property diff --git a/aries_cloudagent/config/ledger.py b/aries_cloudagent/config/ledger.py index 58a81002b1..827865ca4b 100644 --- a/aries_cloudagent/config/ledger.py +++ b/aries_cloudagent/config/ledger.py @@ -11,6 +11,8 @@ from prompt_toolkit.formatted_text import HTML from ..ledger.base import BaseLedger +from ..ledger.endpoint_type import EndpointType +from ..ledger.error import LedgerError from ..utils.http import fetch, FetchError from ..wallet.base import BaseWallet @@ -57,29 +59,40 @@ async def ledger_config( if not ledger: LOGGER.info("Ledger instance not provided") return False - elif ledger.LEDGER_TYPE != "indy": + elif ledger.type != "indy": LOGGER.info("Non-indy ledger provided") return False async with ledger: # Check transaction author agreement acceptance - taa_info = await ledger.get_txn_author_agreement() - if taa_info["taa_required"] and public_did: - taa_accepted = await ledger.get_latest_txn_author_acceptance() - if ( - not taa_accepted - or taa_info["taa_record"]["digest"] != taa_accepted["digest"] - ): - if not await accept_taa(ledger, taa_info, provision): - return False - - # Publish endpoint if necessary - skipped if TAA is required but not accepted + if not context.settings.get("read_only_ledger"): + taa_info = await ledger.get_txn_author_agreement() + if taa_info["taa_required"] and public_did: + taa_accepted = await ledger.get_latest_txn_author_acceptance() + if ( + not taa_accepted + or taa_info["taa_record"]["digest"] != taa_accepted["digest"] + ): + if not await accept_taa(ledger, taa_info, provision): + return False + + # Publish endpoints if necessary - skipped if TAA is required but not accepted endpoint = context.settings.get("default_endpoint") if public_did: wallet: BaseWallet = await context.inject(BaseWallet) - if wallet.WALLET_TYPE != "indy": + if wallet.type != "indy": raise ConfigError("Cannot provision a non-Indy wallet 
type") - await wallet.set_did_endpoint(public_did, endpoint, ledger) + try: + await wallet.set_did_endpoint(public_did, endpoint, ledger) + except LedgerError as x_ledger: + raise ConfigError(x_ledger.message) from x_ledger # e.g., read-only + + # Publish profile endpoint + profile_endpoint = context.settings.get("profile_endpoint") + if profile_endpoint: + await ledger.update_endpoint_for_did( + public_did, profile_endpoint, EndpointType.PROFILE + ) return True diff --git a/aries_cloudagent/config/tests/test_default_context.py b/aries_cloudagent/config/tests/test_default_context.py index dae1867311..f9211c1f15 100644 --- a/aries_cloudagent/config/tests/test_default_context.py +++ b/aries_cloudagent/config/tests/test_default_context.py @@ -1,3 +1,5 @@ +from tempfile import NamedTemporaryFile + from asynctest import TestCase as AsyncTestCase from ...core.protocol_registry import ProtocolRegistry @@ -24,3 +26,12 @@ async def test_build_context(self): BaseStorage, ): assert isinstance(await result.inject(cls), cls) + + builder = DefaultContextBuilder( + settings={ + "timing.enabled": True, + "timing.log.file": NamedTemporaryFile().name, + } + ) + result = await builder.build() + assert isinstance(result, InjectionContext) diff --git a/aries_cloudagent/config/tests/test_injection_context.py b/aries_cloudagent/config/tests/test_injection_context.py index f0cf490aac..0a44a6b97d 100644 --- a/aries_cloudagent/config/tests/test_injection_context.py +++ b/aries_cloudagent/config/tests/test_injection_context.py @@ -25,6 +25,11 @@ def test_simple_scope(self): self.test_instance.start_scope(None) with self.assertRaises(InjectionContextError): self.test_instance.start_scope(self.test_instance.ROOT_SCOPE) + + injector = self.test_instance.injector_for_scope(self.test_instance.ROOT_SCOPE) + assert injector == self.test_instance.injector + assert self.test_instance.injector_for_scope("no such scope") is None + context = self.test_instance.start_scope(self.test_scope) assert 
context.scope_name == self.test_scope with self.assertRaises(InjectionContextError): @@ -48,6 +53,9 @@ async def test_inject_simple(self): self.test_instance.injector.bind_instance(str, self.test_value) assert (await self.test_instance.inject(str)) is self.test_value + self.test_instance.injector = None + assert self.test_instance.injector is None + async def test_inject_scope(self): """Test a scoped injection.""" context = self.test_instance.start_scope(self.test_scope) diff --git a/aries_cloudagent/config/tests/test_injector.py b/aries_cloudagent/config/tests/test_injector.py index ded348a066..8dcce4d227 100644 --- a/aries_cloudagent/config/tests/test_injector.py +++ b/aries_cloudagent/config/tests/test_injector.py @@ -50,11 +50,20 @@ async def test_inject_simple(self): self.test_instance.bind_instance(str, self.test_value) assert (await self.test_instance.inject(str)) is self.test_value + async def test_inject_x(self): + """Test injection failure on null base class.""" + with self.assertRaises(InjectorError): + await self.test_instance.inject(None) + async def test_inject_provider(self): """Test a provider injection.""" mock_provider = MockProvider(self.test_value) + + with self.assertRaises(ValueError): + self.test_instance.bind_provider(str, None) self.test_instance.bind_provider(str, mock_provider) assert self.test_instance.get_provider(str) is mock_provider + override_settings = {self.test_key: "NEWVAL"} assert ( await self.test_instance.inject(str, override_settings) @@ -108,6 +117,12 @@ async def test_inject_class_dependency(self): assert instance.value is test_str assert instance.kwargs["param"] is test_int + self.test_instance.clear_binding(int) + self.test_instance.clear_binding(str) + self.test_instance.bind_instance(str, test_int) + with self.assertRaises(InjectorError): + await self.test_instance.inject(str) + async def test_inject_cached(self): """Test a provider class injection.""" with self.assertRaises(ValueError): diff --git 
a/aries_cloudagent/config/tests/test_ledger.py b/aries_cloudagent/config/tests/test_ledger.py new file mode 100644 index 0000000000..94b6df61a3 --- /dev/null +++ b/aries_cloudagent/config/tests/test_ledger.py @@ -0,0 +1,333 @@ +from os import remove +from tempfile import NamedTemporaryFile + +from asynctest import TestCase as AsyncTestCase, mock as async_mock + +from ...ledger.base import BaseLedger +from ...ledger.error import LedgerError +from ...wallet.base import BaseWallet + +from .. import ledger as test_module +from ..injection_context import InjectionContext + +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" + + +class TestLedger(AsyncTestCase): + async def test_fetch_genesis_transactions(self): + with async_mock.patch.object( + test_module, "fetch", async_mock.CoroutineMock() + ) as mock_fetch: + await test_module.fetch_genesis_transactions("http://1.2.3.4:9000/genesis") + + async def test_fetch_genesis_transactions_x(self): + with async_mock.patch.object( + test_module, "fetch", async_mock.CoroutineMock() + ) as mock_fetch: + mock_fetch.side_effect = test_module.FetchError("404 Not Found") + with self.assertRaises(test_module.ConfigError): + await test_module.fetch_genesis_transactions( + "http://1.2.3.4:9000/genesis" + ) + + async def test_ledger_config_genesis_url(self): + settings = { + "ledger.genesis_url": "00000000000000000000000000000000", + "default_endpoint": "http://1.2.3.4:8051", + "profile_endpoint": "http://agent.ca", + } + mock_ledger = async_mock.MagicMock( + type="indy", + get_txn_author_agreement=async_mock.CoroutineMock( + return_value={ + "taa_required": True, + "taa_record": { + "digest": b"ffffffffffffffffffffffffffffffffffffffff" + }, + } + ), + get_latest_txn_author_acceptance=async_mock.CoroutineMock( + return_value={"digest": b"1234567890123456789012345678901234567890"} + ), + update_endpoint_for_did=async_mock.CoroutineMock(), + ) + mock_wallet = async_mock.MagicMock( + type="indy", set_did_endpoint=async_mock.CoroutineMock() + ) + + context 
= InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseLedger, mock_ledger) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "fetch_genesis_transactions", async_mock.CoroutineMock() + ) as mock_fetch, async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa: + mock_accept_taa.return_value = True + await test_module.ledger_config(context, TEST_DID, provision=True) + + async def test_ledger_config_genesis_file(self): + settings = { + "ledger.genesis_file": "/tmp/genesis/path", + "default_endpoint": "http://1.2.3.4:8051", + } + mock_ledger = async_mock.MagicMock( + type="indy", + get_txn_author_agreement=async_mock.CoroutineMock( + return_value={ + "taa_required": True, + "taa_record": { + "digest": b"ffffffffffffffffffffffffffffffffffffffff" + }, + } + ), + get_latest_txn_author_acceptance=async_mock.CoroutineMock( + return_value={"digest": b"1234567890123456789012345678901234567890"} + ), + ) + mock_wallet = async_mock.MagicMock( + type="indy", set_did_endpoint=async_mock.CoroutineMock() + ) + + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseLedger, mock_ledger) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa, async_mock.patch( + "builtins.open", async_mock.MagicMock() + ) as mock_open: + mock_open.return_value = async_mock.MagicMock( + __enter__=async_mock.MagicMock( + return_value=async_mock.MagicMock( + read=async_mock.MagicMock( + return_value="... genesis transactions ..." 
+ ) + ) + ) + ) + mock_accept_taa.return_value = True + await test_module.ledger_config(context, TEST_DID, provision=True) + + async def test_ledger_config_genesis_file_io_x(self): + settings = { + "ledger.genesis_file": "/tmp/genesis/path", + "default_endpoint": "http://1.2.3.4:8051", + } + context = InjectionContext(settings=settings, enforce_typing=False) + + with async_mock.patch.object( + test_module, "fetch_genesis_transactions", async_mock.CoroutineMock() + ) as mock_fetch, async_mock.patch( + "builtins.open", async_mock.MagicMock() + ) as mock_open: + mock_open.side_effect = IOError("no read permission") + with self.assertRaises(test_module.ConfigError): + await test_module.ledger_config(context, TEST_DID, provision=True) + + async def test_ledger_config_genesis_url_no_ledger(self): + settings = { + "ledger.genesis_url": "00000000000000000000000000000000", + "default_endpoint": "http://1.2.3.4:8051", + } + + context = InjectionContext(settings=settings, enforce_typing=False) + + with async_mock.patch.object( + test_module, "fetch_genesis_transactions", async_mock.CoroutineMock() + ) as mock_fetch, async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa: + mock_accept_taa.return_value = True + assert not await test_module.ledger_config( + context, TEST_DID, provision=True + ) + + async def test_ledger_config_genesis_url_non_indy_ledger(self): + settings = { + "ledger.genesis_url": "00000000000000000000000000000000", + "default_endpoint": "http://1.2.3.4:8051", + } + mock_ledger = async_mock.MagicMock( + type="fabric", + get_txn_author_agreement=async_mock.CoroutineMock( + return_value={ + "taa_required": True, + "taa_record": { + "digest": b"ffffffffffffffffffffffffffffffffffffffff" + }, + } + ), + get_latest_txn_author_acceptance=async_mock.CoroutineMock( + return_value={"digest": b"1234567890123456789012345678901234567890"} + ), + ) + + context = InjectionContext(settings=settings, enforce_typing=False) + 
context.injector.bind_instance(BaseLedger, mock_ledger) + + with async_mock.patch.object( + test_module, "fetch_genesis_transactions", async_mock.CoroutineMock() + ) as mock_fetch, async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa: + mock_accept_taa.return_value = True + assert not await test_module.ledger_config( + context, TEST_DID, provision=True + ) + + async def test_ledger_config_genesis_url_no_taa_accept(self): + settings = { + "ledger.genesis_url": "00000000000000000000000000000000", + "default_endpoint": "http://1.2.3.4:8051", + } + mock_ledger = async_mock.MagicMock( + type="indy", + get_txn_author_agreement=async_mock.CoroutineMock( + return_value={ + "taa_required": True, + "taa_record": { + "digest": b"ffffffffffffffffffffffffffffffffffffffff" + }, + } + ), + get_latest_txn_author_acceptance=async_mock.CoroutineMock( + return_value={"digest": b"1234567890123456789012345678901234567890"} + ), + ) + + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseLedger, mock_ledger) + + with async_mock.patch.object( + test_module, "fetch_genesis_transactions", async_mock.CoroutineMock() + ) as mock_fetch, async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa: + mock_accept_taa.return_value = False + assert not await test_module.ledger_config( + context, TEST_DID, provision=True + ) + + async def test_ledger_config_read_only_skip_taa_accept(self): + settings = { + "ledger.genesis_url": "00000000000000000000000000000000", + "read_only_ledger": True, + } + mock_ledger = async_mock.MagicMock( + type="indy", + get_txn_author_agreement=async_mock.CoroutineMock(), + get_latest_txn_author_acceptance=async_mock.CoroutineMock(), + ) + mock_wallet = async_mock.MagicMock( + type="indy", + set_did_endpoint=async_mock.CoroutineMock( + side_effect=LedgerError( + "Error cannot update endpoint when ledger is in read only mode" 
+ ) + ), + ) + + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseLedger, mock_ledger) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "fetch_genesis_transactions", async_mock.CoroutineMock() + ) as mock_fetch, async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa: + with self.assertRaises(test_module.ConfigError) as x_context: + await test_module.ledger_config(context, TEST_DID, provision=True) + assert "ledger is in read only mode" in str(x_context.exception) + mock_ledger.get_txn_author_agreement.assert_not_called() + mock_ledger.get_latest_txn_author_acceptance.assert_not_called() + + async def test_ledger_config_genesis_file_non_indy_wallet(self): + settings = { + "ledger.genesis_file": "/tmp/genesis/path", + "default_endpoint": "http://1.2.3.4:8051", + } + mock_ledger = async_mock.MagicMock( + type="indy", + get_txn_author_agreement=async_mock.CoroutineMock( + return_value={ + "taa_required": True, + "taa_record": { + "digest": b"ffffffffffffffffffffffffffffffffffffffff" + }, + } + ), + get_latest_txn_author_acceptance=async_mock.CoroutineMock( + return_value={"digest": b"1234567890123456789012345678901234567890"} + ), + ) + mock_wallet = async_mock.MagicMock( + type="trifold", set_did_endpoint=async_mock.CoroutineMock() + ) + + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseLedger, mock_ledger) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "accept_taa", async_mock.CoroutineMock() + ) as mock_accept_taa, async_mock.patch( + "builtins.open", async_mock.MagicMock() + ) as mock_open: + mock_open.return_value = async_mock.MagicMock( + __enter__=async_mock.MagicMock( + return_value=async_mock.MagicMock( + read=async_mock.MagicMock( + return_value="... genesis transactions ..." 
+ ) + ) + ) + ) + mock_accept_taa.return_value = True + with self.assertRaises(test_module.ConfigError): + await test_module.ledger_config(context, TEST_DID, provision=True) + + @async_mock.patch("sys.stdout") + async def test_ledger_accept_taa_not_tty(self, mock_stdout): + mock_stdout.isatty = async_mock.MagicMock(return_value=False) + + assert not await test_module.accept_taa(None, None, provision=False) + + @async_mock.patch("sys.stdout") + async def test_ledger_accept_taa(self, mock_stdout): + mock_stdout.isatty = async_mock.MagicMock(return_value=True) + + taa_info = { + "taa_record": {"version": "1.0", "text": "Agreement"}, + "aml_record": {"aml": ["wallet_agreement", "on_file"]}, + } + + with async_mock.patch.object( + test_module, "use_asyncio_event_loop", async_mock.MagicMock() + ) as mock_use_aio_loop, async_mock.patch.object( + test_module.prompt_toolkit, "prompt", async_mock.CoroutineMock() + ) as mock_prompt: + mock_prompt.side_effect = EOFError() + assert not await test_module.accept_taa(None, taa_info, provision=False) + + with async_mock.patch.object( + test_module, "use_asyncio_event_loop", async_mock.MagicMock() + ) as mock_use_aio_loop, async_mock.patch.object( + test_module.prompt_toolkit, "prompt", async_mock.CoroutineMock() + ) as mock_prompt: + mock_prompt.return_value = "x" + assert not await test_module.accept_taa(None, taa_info, provision=False) + + with async_mock.patch.object( + test_module, "use_asyncio_event_loop", async_mock.MagicMock() + ) as mock_use_aio_loop, async_mock.patch.object( + test_module.prompt_toolkit, "prompt", async_mock.CoroutineMock() + ) as mock_prompt: + mock_ledger = async_mock.MagicMock( + accept_txn_author_agreement=async_mock.CoroutineMock() + ) + mock_prompt.return_value = "" + assert await test_module.accept_taa(mock_ledger, taa_info, provision=False) diff --git a/aries_cloudagent/config/tests/test_logging.py b/aries_cloudagent/config/tests/test_logging.py index 9aaacedf6e..82867d32e7 100644 --- 
a/aries_cloudagent/config/tests/test_logging.py +++ b/aries_cloudagent/config/tests/test_logging.py @@ -1,6 +1,8 @@ import contextlib + +from asynctest import mock as async_mock from io import StringIO -from asynctest import mock +from tempfile import NamedTemporaryFile from .. import logging as test_module @@ -11,8 +13,8 @@ class TestLoggingConfigurator: host_arg_value = "host" port_arg_value = "port" - @mock.patch.object(test_module, "load_resource", autospec=True) - @mock.patch.object(test_module, "fileConfig", autospec=True) + @async_mock.patch.object(test_module, "load_resource", autospec=True) + @async_mock.patch.object(test_module, "fileConfig", autospec=True) def test_configure_default(self, mock_file_config, mock_load_resource): test_module.LoggingConfigurator.configure() @@ -23,8 +25,25 @@ def test_configure_default(self, mock_file_config, mock_load_resource): mock_load_resource.return_value, disable_existing_loggers=False ) - @mock.patch.object(test_module, "load_resource", autospec=True) - @mock.patch.object(test_module, "fileConfig", autospec=True) + def test_configure_default_no_resource(self): + with async_mock.patch.object( + test_module, "load_resource", async_mock.MagicMock() + ) as mock_load: + mock_load.return_value = None + test_module.LoggingConfigurator.configure() + + def test_configure_default_file(self): + log_file = NamedTemporaryFile() + with async_mock.patch.object( + test_module, "load_resource", async_mock.MagicMock() + ) as mock_load: + mock_load.return_value = None + test_module.LoggingConfigurator.configure( + log_level="ERROR", log_file=log_file.name + ) + + @async_mock.patch.object(test_module, "load_resource", autospec=True) + @async_mock.patch.object(test_module, "fileConfig", autospec=True) def test_configure_path(self, mock_file_config, mock_load_resource): path = "a path" test_module.LoggingConfigurator.configure(path) @@ -36,9 +55,39 @@ def test_configure_path(self, mock_file_config, mock_load_resource): def 
test_banner(self): stdout = StringIO() + mock_http = async_mock.MagicMock(scheme="http", host="1.2.3.4", port=8081) + mock_https = async_mock.MagicMock(schemes=["https", "archie"]) + mock_admin_server = async_mock.MagicMock(host="1.2.3.4", port=8091) with contextlib.redirect_stdout(stdout): test_label = "Aries Cloud Agent" test_did = "55GkHamhTU1ZbTbV2ab9DE" - test_module.LoggingConfigurator.print_banner(test_label, {}, {}, test_did) + test_module.LoggingConfigurator.print_banner( + test_label, + {"in": mock_http}, + {"out": mock_https}, + test_did, + mock_admin_server, + ) + test_module.LoggingConfigurator.print_banner( + test_label, {"in": mock_http}, {"out": mock_https}, test_did + ) output = stdout.getvalue() assert test_did in output + + def test_load_resource(self): + with async_mock.patch("builtins.open", async_mock.MagicMock()) as mock_open: + test_module.load_resource("abc", encoding="utf-8") + mock_open.side_effect = IOError("insufficient privilege") + test_module.load_resource("abc", encoding="utf-8") + + with async_mock.patch.object( + test_module.pkg_resources, "resource_stream", async_mock.MagicMock() + ) as mock_res_stream, async_mock.patch.object( + test_module, "TextIOWrapper", async_mock.MagicMock() + ) as mock_text_io_wrapper: + test_module.load_resource("abc:def", encoding="utf-8") + + with async_mock.patch.object( + test_module.pkg_resources, "resource_stream", async_mock.MagicMock() + ) as mock_res_stream: + test_module.load_resource("abc:def", encoding=None) diff --git a/aries_cloudagent/config/tests/test_provider.py b/aries_cloudagent/config/tests/test_provider.py new file mode 100644 index 0000000000..593acba736 --- /dev/null +++ b/aries_cloudagent/config/tests/test_provider.py @@ -0,0 +1,36 @@ +from tempfile import NamedTemporaryFile + +from asynctest import TestCase as AsyncTestCase, mock as async_mock + +from ...storage.provider import StorageProvider +from ...utils.stats import Collector +from ...wallet.base import BaseWallet +from 
...wallet.basic import BasicWallet + +from ..injection_context import InjectionContext +from ..provider import StatsProvider +from ..settings import Settings + + +class TestProvider(AsyncTestCase): + async def test_stats_provider_init_x(self): + """Cover stats provider init error on no provider.""" + with self.assertRaises(ValueError): + StatsProvider(None, ["method"]) + + async def test_stats_provider_provide_collector(self): + """Cover call to provide with collector.""" + + timing_log = NamedTemporaryFile().name + settings = {"timing.enabled": True, "timing.log.file": timing_log} + stats_provider = StatsProvider( + StorageProvider(), ("add_record", "get_record", "search_records") + ) + collector = Collector(log_path=timing_log) + + wallet = BasicWallet() + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(Collector, collector) + context.injector.bind_instance(BaseWallet, wallet) + + await stats_provider.provide(Settings(settings), context.injector) diff --git a/aries_cloudagent/config/tests/test_settings.py b/aries_cloudagent/config/tests/test_settings.py index 4465a78316..583d8a42ef 100644 --- a/aries_cloudagent/config/tests/test_settings.py +++ b/aries_cloudagent/config/tests/test_settings.py @@ -1,6 +1,8 @@ +import pytest + from unittest import TestCase -from ..base import BaseSettings, SettingsError +from ..base import SettingsError from ..settings import Settings @@ -21,9 +23,14 @@ def test_settings_init(self): ) with self.assertRaises(KeyError): self.test_instance["MISSING"] + assert len(self.test_instance) == 1 + assert len(self.test_instance.copy()) == 1 def test_get_formats(self): """Test retrieval with formatting.""" + assert "Settings" in str(self.test_instance) + with pytest.raises(TypeError): + self.test_instance[0] # cover wrong type self.test_instance["BOOL"] = "true" assert self.test_instance.get_bool("BOOL") is True self.test_instance["BOOL"] = "false" diff --git 
a/aries_cloudagent/config/tests/test_wallet.py b/aries_cloudagent/config/tests/test_wallet.py new file mode 100644 index 0000000000..dcdf9dc416 --- /dev/null +++ b/aries_cloudagent/config/tests/test_wallet.py @@ -0,0 +1,133 @@ +from asynctest import TestCase as AsyncTestCase, mock as async_mock + +from ...wallet.base import BaseWallet + +from .. import wallet as test_module +from ..injection_context import InjectionContext + +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" +TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + + +class TestWallet(AsyncTestCase): + async def test_wallet_config_existing_replace(self): + settings = { + "wallet.seed": "00000000000000000000000000000000", + "wallet.replace_public_did": True, + "debug.enabled": True, + } + mock_wallet = async_mock.MagicMock( + type="indy", + name="Test Wallet", + created=False, + get_public_did=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(did=TEST_DID, verkey=TEST_VERKEY) + ), + set_public_did=async_mock.CoroutineMock(), + create_local_did=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(did=TEST_DID, verkey=TEST_VERKEY) + ), + ) + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "seed_to_did", async_mock.MagicMock() + ) as mock_seed_to_did: + mock_seed_to_did.return_value = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + + await test_module.wallet_config(context, provision=True) + + async def test_wallet_config_bad_seed_x(self): + settings = { + "wallet.seed": "00000000000000000000000000000000", + } + mock_wallet = async_mock.MagicMock( + type="indy", + name="Test Wallet", + created=False, + get_public_did=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(did=TEST_DID, verkey=TEST_VERKEY) + ), + ) + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with 
async_mock.patch.object( + test_module, "seed_to_did", async_mock.MagicMock() + ) as mock_seed_to_did: + mock_seed_to_did.return_value = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + + with self.assertRaises(test_module.ConfigError): + await test_module.wallet_config(context, provision=True) + + async def test_wallet_config_seed_local(self): + settings = { + "wallet.seed": "00000000000000000000000000000000", + "wallet.local_did": True, + } + mock_wallet = async_mock.MagicMock( + type="indy", + name="Test Wallet", + created=False, + get_public_did=async_mock.CoroutineMock(return_value=None), + set_public_did=async_mock.CoroutineMock(), + create_local_did=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(did=TEST_DID, verkey=TEST_VERKEY) + ), + ) + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "seed_to_did", async_mock.MagicMock() + ) as mock_seed_to_did: + mock_seed_to_did.return_value = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + + await test_module.wallet_config(context, provision=True) + + async def test_wallet_config_seed_public(self): + settings = { + "wallet.seed": "00000000000000000000000000000000", + } + mock_wallet = async_mock.MagicMock( + type="indy", + name="Test Wallet", + created=False, + get_public_did=async_mock.CoroutineMock(return_value=None), + set_public_did=async_mock.CoroutineMock(), + create_public_did=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(did=TEST_DID, verkey=TEST_VERKEY) + ), + ) + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "seed_to_did", async_mock.MagicMock() + ) as mock_seed_to_did: + mock_seed_to_did.return_value = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + + await test_module.wallet_config(context, provision=True) + + async def 
test_wallet_config_seed_no_public_did(self): + settings = {} + mock_wallet = async_mock.MagicMock( + type="indy", + name="Test Wallet", + created=False, + get_public_did=async_mock.CoroutineMock(return_value=None), + set_public_did=async_mock.CoroutineMock(), + create_public_did=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(did=TEST_DID, verkey=TEST_VERKEY) + ), + ) + context = InjectionContext(settings=settings, enforce_typing=False) + context.injector.bind_instance(BaseWallet, mock_wallet) + + with async_mock.patch.object( + test_module, "seed_to_did", async_mock.MagicMock() + ) as mock_seed_to_did: + mock_seed_to_did.return_value = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" + + await test_module.wallet_config(context, provision=True) diff --git a/aries_cloudagent/config/wallet.py b/aries_cloudagent/config/wallet.py index 0864f1a19e..466d2fd666 100644 --- a/aries_cloudagent/config/wallet.py +++ b/aries_cloudagent/config/wallet.py @@ -15,7 +15,7 @@ async def wallet_config(context: InjectionContext, provision: bool = False): """Initialize the wallet.""" wallet: BaseWallet = await context.inject(BaseWallet) if provision: - if wallet.WALLET_TYPE != "indy": + if wallet.type != "indy": raise ConfigError("Cannot provision a non-Indy wallet type") if wallet.created: print("Created new wallet") diff --git a/aries_cloudagent/connections/models/connection_target.py b/aries_cloudagent/connections/models/connection_target.py index 4ccf956691..ade3f47392 100644 --- a/aries_cloudagent/connections/models/connection_target.py +++ b/aries_cloudagent/connections/models/connection_target.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ...messaging.models.base import BaseModel, BaseModelSchema from ...messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY @@ -51,6 +51,7 @@ class Meta: """ConnectionTargetSchema metadata.""" model_class = ConnectionTarget + unknown = EXCLUDE did = 
fields.Str(required=False, description="", **INDY_DID) endpoint = fields.Str( diff --git a/aries_cloudagent/connections/models/tests/__init__.py b/aries_cloudagent/connections/models/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aries_cloudagent/connections/models/tests/test_connection_target.py b/aries_cloudagent/connections/models/tests/test_connection_target.py new file mode 100644 index 0000000000..093313924f --- /dev/null +++ b/aries_cloudagent/connections/models/tests/test_connection_target.py @@ -0,0 +1,29 @@ +from asynctest import TestCase as AsyncTestCase + +from ..connection_target import ConnectionTarget + +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" +TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" +TEST_ENDPOINT = "http://localhost" + + +class TestConnectionTarget(AsyncTestCase): + def test_deser(self): + target = ConnectionTarget( + did=TEST_DID, + endpoint=TEST_ENDPOINT, + label="a label", + recipient_keys=[TEST_VERKEY], + routing_keys=[TEST_VERKEY], + sender_key=TEST_VERKEY, + ) + serial = target.serialize() + serial["extra-stuff"] = "to exclude" + deser = ConnectionTarget.deserialize(serial) + + assert deser.did == target.did + assert deser.endpoint == target.endpoint + assert deser.label == target.label + assert deser.recipient_keys == target.recipient_keys + assert deser.routing_keys == target.routing_keys + assert deser.sender_key == target.sender_key diff --git a/aries_cloudagent/core/conductor.py b/aries_cloudagent/core/conductor.py index 7627a5e66d..3a8cc87e09 100644 --- a/aries_cloudagent/core/conductor.py +++ b/aries_cloudagent/core/conductor.py @@ -17,7 +17,8 @@ from ..config.injection_context import InjectionContext from ..config.ledger import ledger_config from ..config.logging import LoggingConfigurator -from ..config.wallet import wallet_config +from ..config.wallet import wallet_config, BaseWallet +from ..ledger.error import LedgerConfigError, LedgerTransactionError from ..messaging.responder 
import BaseResponder from ..protocols.connections.v1_0.manager import ( ConnectionManager, @@ -86,6 +87,13 @@ async def setup(self): ) await self.outbound_transport_manager.setup() + # Configure the wallet + public_did = await wallet_config(context) + + # Configure the ledger + if not await ledger_config(context, public_did): + LOGGER.warning("No ledger configured") + # Admin API if context.settings.get("admin.enabled"): try: @@ -141,13 +149,6 @@ async def start(self) -> None: context = self.context - # Configure the wallet - public_did = await wallet_config(context) - - # Configure the ledger - if not await ledger_config(context, public_did): - LOGGER.warning("No ledger configured") - # Start up transports try: await self.inbound_transport_manager.start() @@ -174,12 +175,16 @@ async def start(self) -> None: # Get agent label default_label = context.settings.get("default_label") + # Get public did + wallet: BaseWallet = await context.inject(BaseWallet) + public_did = await wallet.get_public_did() + # Show some details about the configuration to the user LoggingConfigurator.print_banner( default_label, self.inbound_transport_manager.registered_transports, self.outbound_transport_manager.registered_transports, - public_did, + public_did.did if public_did else None, self.admin_server, ) @@ -251,12 +256,18 @@ def inbound_message_router( # Note: at this point we could send the message to a shared queue # if this pod is too busy to process it - self.dispatcher.queue_message( - message, - self.outbound_message_router, - self.admin_server and self.admin_server.send_webhook, - lambda completed: self.dispatch_complete(message, completed), - ) + try: + self.dispatcher.queue_message( + message, + self.outbound_message_router, + self.admin_server and self.admin_server.send_webhook, + lambda completed: self.dispatch_complete(message, completed), + ) + except (LedgerConfigError, LedgerTransactionError) as e: + LOGGER.error("Shutdown on ledger error %s", str(e)) + if 
self.admin_server: + self.admin_server.notify_fatal_error() + raise def dispatch_complete(self, message: InboundMessage, completed: CompletedTask): """Handle completion of message dispatch.""" @@ -264,6 +275,22 @@ def dispatch_complete(self, message: InboundMessage, completed: CompletedTask): LOGGER.exception( "Exception in message handler:", exc_info=completed.exc_info ) + if isinstance(completed.exc_info[1], LedgerConfigError) or isinstance( + completed.exc_info[1], LedgerTransactionError + ): + LOGGER.error( + "%shutdown on ledger error %s", + "S" if self.admin_server else "No admin server to s", + str(completed.exc_info[1]), + ) + if self.admin_server: + self.admin_server.notify_fatal_error() + else: + LOGGER.error( + "DON'T shutdown on %s %s", + completed.exc_info[0].__name__, + str(completed.exc_info[1]), + ) self.inbound_transport_manager.dispatch_complete(message, completed) async def get_stats(self) -> dict: @@ -310,7 +337,13 @@ async def outbound_message_router( def handle_not_returned(self, context: InjectionContext, outbound: OutboundMessage): """Handle a message that failed delivery via an inbound session.""" - self.dispatcher.run_task(self.queue_outbound(context, outbound)) + try: + self.dispatcher.run_task(self.queue_outbound(context, outbound)) + except (LedgerConfigError, LedgerTransactionError) as e: + LOGGER.error("Shutdown on ledger error %s", str(e)) + if self.admin_server: + self.admin_server.notify_fatal_error() + raise async def queue_outbound( self, @@ -337,6 +370,11 @@ async def queue_outbound( except ConnectionManagerError: LOGGER.exception("Error preparing outbound message for transmission") return + except (LedgerConfigError, LedgerTransactionError) as e: + LOGGER.error("Shutdown on ledger error %s", str(e)) + if self.admin_server: + self.admin_server.notify_fatal_error() + raise try: self.outbound_transport_manager.enqueue_message(context, outbound) diff --git a/aries_cloudagent/core/tests/test_conductor.py 
b/aries_cloudagent/core/tests/test_conductor.py index 6f843cd6c6..821d1722f9 100644 --- a/aries_cloudagent/core/tests/test_conductor.py +++ b/aries_cloudagent/core/tests/test_conductor.py @@ -35,11 +35,14 @@ class Config: test_settings = {"admin.webhook_urls": ["http://sample.webhook.ca"]} + test_settings_admin = { + "admin.webhook_urls": ["http://sample.webhook.ca"], + "admin.enabled": True, + } test_settings_with_queue = {"queue.enable_undelivered_queue": True} class TestDIDs: - test_seed = "testseed000000000000000000000001" test_did = "55GkHamhTU1ZbTbV2ab9DE" test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" @@ -103,6 +106,9 @@ async def test_startup(self): await conductor.setup() + wallet = await conductor.context.inject(BaseWallet) + await wallet.create_public_did() + mock_inbound_mgr.return_value.setup.assert_awaited_once() mock_outbound_mgr.return_value.setup.assert_awaited_once() @@ -121,6 +127,39 @@ async def test_startup(self): mock_inbound_mgr.return_value.stop.assert_awaited_once_with() mock_outbound_mgr.return_value.stop.assert_awaited_once_with() + async def test_startup_no_public_did(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings) + conductor = test_module.Conductor(builder) + + with async_mock.patch.object( + test_module, "InboundTransportManager", autospec=True + ) as mock_inbound_mgr, async_mock.patch.object( + test_module, "OutboundTransportManager", autospec=True + ) as mock_outbound_mgr, async_mock.patch.object( + test_module, "LoggingConfigurator", autospec=True + ) as mock_logger: + + await conductor.setup() + + mock_inbound_mgr.return_value.setup.assert_awaited_once() + mock_outbound_mgr.return_value.setup.assert_awaited_once() + + mock_inbound_mgr.return_value.registered_transports = {} + mock_outbound_mgr.return_value.registered_transports = {} + + # Doesn't raise + await conductor.start() + + mock_inbound_mgr.return_value.start.assert_awaited_once_with() + 
mock_outbound_mgr.return_value.start.assert_awaited_once_with() + + mock_logger.print_banner.assert_called_once() + + await conductor.stop() + + mock_inbound_mgr.return_value.stop.assert_awaited_once_with() + mock_outbound_mgr.return_value.stop.assert_awaited_once_with() + async def test_stats(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) conductor = test_module.Conductor(builder) @@ -187,7 +226,7 @@ async def test_inbound_message_handler(self): with async_mock.patch.object( conductor.dispatcher, "queue_message", autospec=True - ) as mock_dispatch: + ) as mock_dispatch_q: message_body = "{}" receipt = MessageReceipt(direct_response_mode="snail mail") @@ -195,11 +234,34 @@ async def test_inbound_message_handler(self): conductor.inbound_message_router(message, can_respond=False) - mock_dispatch.assert_called_once() - assert mock_dispatch.call_args[0][0] is message - assert mock_dispatch.call_args[0][1] == conductor.outbound_message_router - assert mock_dispatch.call_args[0][2] is None # admin webhook router - assert callable(mock_dispatch.call_args[0][3]) + mock_dispatch_q.assert_called_once() + assert mock_dispatch_q.call_args[0][0] is message + assert mock_dispatch_q.call_args[0][1] == conductor.outbound_message_router + assert mock_dispatch_q.call_args[0][2] is None # admin webhook router + assert callable(mock_dispatch_q.call_args[0][3]) + + async def test_inbound_message_handler_ledger_x(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings_admin) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + with async_mock.patch.object( + conductor.dispatcher, "queue_message", autospec=True + ) as mock_dispatch_q, async_mock.patch.object( + conductor.admin_server, "notify_fatal_error", async_mock.MagicMock() + ) as mock_notify: + mock_dispatch_q.side_effect = test_module.LedgerConfigError("ledger down") + + message_body = "{}" + receipt = MessageReceipt(direct_response_mode="snail mail") + message = 
InboundMessage(message_body, receipt) + + with self.assertRaises(test_module.LedgerConfigError): + conductor.inbound_message_router(message, can_respond=False) + + mock_dispatch_q.assert_called_once() + mock_notify.assert_called_once() async def test_outbound_message_handler_return_route(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) @@ -332,7 +394,7 @@ async def test_handle_nots(self): with async_mock.patch.object( test_module, "ConnectionManager" ) as mock_conn_mgr, async_mock.patch.object( - conductor.dispatcher, "run_task", async_mock.CoroutineMock() + conductor.dispatcher, "run_task", async_mock.MagicMock() ) as mock_run_task: mock_conn_mgr.return_value.get_connection_targets = ( async_mock.CoroutineMock() @@ -348,6 +410,67 @@ async def test_handle_nots(self): await conductor.queue_outbound(conductor.context, message) mock_run_task.assert_called_once() + async def test_handle_not_returned_ledger_x(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings_admin) + conductor = test_module.Conductor(builder) + + await conductor.setup() + + with async_mock.patch.object( + conductor.dispatcher, "run_task", async_mock.MagicMock() + ) as mock_dispatch_run, async_mock.patch.object( + conductor, "queue_outbound", async_mock.CoroutineMock() + ) as mock_queue, async_mock.patch.object( + conductor.admin_server, "notify_fatal_error", async_mock.MagicMock() + ) as mock_notify: + mock_dispatch_run.side_effect = test_module.LedgerConfigError( + "No such ledger" + ) + + payload = "{}" + message = OutboundMessage( + payload=payload, + connection_id="dummy-conn-id", + reply_to_verkey=TestDIDs.test_verkey, + ) + + with self.assertRaises(test_module.LedgerConfigError): + conductor.handle_not_returned(conductor.context, message) + + mock_dispatch_run.assert_called_once() + mock_notify.assert_called_once() + + async def test_queue_outbound_ledger_x(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings_admin) + conductor = 
test_module.Conductor(builder) + + await conductor.setup() + + with async_mock.patch.object( + test_module, "ConnectionManager", autospec=True + ) as conn_mgr, async_mock.patch.object( + conductor.dispatcher, "run_task", async_mock.MagicMock() + ) as mock_dispatch_run, async_mock.patch.object( + conductor.admin_server, "notify_fatal_error", async_mock.MagicMock() + ) as mock_notify: + conn_mgr.return_value.get_connection_targets = async_mock.CoroutineMock() + mock_dispatch_run.side_effect = test_module.LedgerConfigError( + "No such ledger" + ) + + payload = "{}" + message = OutboundMessage( + payload=payload, + connection_id="dummy-conn-id", + reply_to_verkey=TestDIDs.test_verkey, + ) + + with self.assertRaises(test_module.LedgerConfigError): + await conductor.queue_outbound(conductor.context, message) + + mock_dispatch_run.assert_called_once() + mock_notify.assert_called_once() + async def test_admin(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) builder.update_settings({"admin.enabled": "1"}) @@ -357,6 +480,9 @@ async def test_admin(self): admin = await conductor.context.inject(BaseAdminServer) assert admin is conductor.admin_server + wallet = await conductor.context.inject(BaseWallet) + await wallet.create_public_did() + with async_mock.patch.object( admin, "start", autospec=True ) as admin_start, async_mock.patch.object( @@ -377,6 +503,9 @@ async def test_admin_startx(self): admin = await conductor.context.inject(BaseAdminServer) assert admin is conductor.admin_server + wallet = await conductor.context.inject(BaseWallet) + await wallet.create_public_did() + with async_mock.patch.object( admin, "start", async_mock.CoroutineMock() ) as admin_start, async_mock.patch.object( @@ -415,6 +544,10 @@ async def test_start_static(self): with async_mock.patch.object(test_module, "ConnectionManager") as mock_mgr: await conductor.setup() + + wallet = await conductor.context.inject(BaseWallet) + await wallet.create_public_did() + 
mock_mgr.return_value.create_static_connection = async_mock.CoroutineMock() await conductor.start() mock_mgr.return_value.create_static_connection.assert_awaited_once() @@ -457,8 +590,8 @@ async def test_start_x_out(self): with self.assertRaises(KeyError): await conductor.start() - async def test_dispatch_complete(self): - builder: ContextBuilder = StubContextBuilder(self.test_settings) + async def test_dispatch_complete_non_fatal_x(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings_admin) conductor = test_module.Conductor(builder) message_body = "{}" @@ -476,7 +609,42 @@ async def test_dispatch_complete(self): ) await conductor.setup() - conductor.dispatch_complete(message, mock_task) + + with async_mock.patch.object( + conductor.admin_server, "notify_fatal_error", async_mock.MagicMock() + ) as mock_notify: + conductor.dispatch_complete(message, mock_task) + mock_notify.assert_not_called() + + async def test_dispatch_complete_fatal_x(self): + builder: ContextBuilder = StubContextBuilder(self.test_settings_admin) + conductor = test_module.Conductor(builder) + + message_body = "{}" + receipt = MessageReceipt(direct_response_mode="snail mail") + message = InboundMessage(message_body, receipt) + mock_task = async_mock.MagicMock( + exc_info=( + test_module.LedgerTransactionError, + test_module.LedgerTransactionError("Ledger is wobbly"), + "...", + ), + ident="abc", + timing={ + "queued": 1234567890, + "unqueued": 1234567899, + "started": 1234567901, + "ended": 1234567999, + }, + ) + + await conductor.setup() + + with async_mock.patch.object( + conductor.admin_server, "notify_fatal_error", async_mock.MagicMock() + ) as mock_notify: + conductor.dispatch_complete(message, mock_task) + mock_notify.assert_called_once_with() async def test_print_invite(self): builder: ContextBuilder = StubContextBuilder(self.test_settings) @@ -487,6 +655,10 @@ async def test_print_invite(self): with async_mock.patch("sys.stdout", new=StringIO()) as captured: await 
conductor.setup() + + wallet = await conductor.context.inject(BaseWallet) + await wallet.create_public_did() + await conductor.start() await conductor.stop() assert "http://localhost?c_i=" in captured.getvalue() diff --git a/aries_cloudagent/core/tests/test_dispatcher.py b/aries_cloudagent/core/tests/test_dispatcher.py index 998632a846..4a33a495a7 100644 --- a/aries_cloudagent/core/tests/test_dispatcher.py +++ b/aries_cloudagent/core/tests/test_dispatcher.py @@ -3,10 +3,13 @@ from asynctest import TestCase as AsyncTestCase, mock as async_mock +from marshmallow import EXCLUDE + from ...config.injection_context import InjectionContext from ...connections.models.connection_record import ConnectionRecord from ...core.protocol_registry import ProtocolRegistry from ...messaging.agent_message import AgentMessage, AgentMessageSchema +from ...messaging.responder import MockResponder from ...messaging.util import datetime_now from ...protocols.problem_report.v1_0.message import ProblemReport @@ -55,6 +58,7 @@ class Meta: class StubAgentMessageSchema(AgentMessageSchema): class Meta: model_class = StubAgentMessage + unknown = EXCLUDE class StubAgentMessageHandler: @@ -72,6 +76,7 @@ class Meta: class StubV1_2AgentMessageSchema(AgentMessageSchema): class Meta: model_class = StubV1_2AgentMessage + unknonw = EXCLUDE class StubV1_2AgentMessageHandler: @@ -301,3 +306,22 @@ async def test_create_outbound_send_webhook(self): result = await responder.create_outbound(message) assert json.loads(result.payload)["@type"] == StubAgentMessage.Meta.message_type await responder.send_webhook("topic", "payload") + + async def test_create_send_outbound(self): + message = StubAgentMessage() + responder = MockResponder() + outbound_message = await responder.create_outbound(message) + await responder.send_outbound(outbound_message) + assert len(responder.messages) == 1 + + async def test_create_enc_outbound(self): + context = make_context() + message = b"abc123xyz7890000" + responder = 
test_module.DispatcherResponder( + context, message, None, async_mock.CoroutineMock() + ) + with async_mock.patch.object( + responder, "send_outbound", async_mock.CoroutineMock() + ) as mock_send_outbound: + await responder.send(message) + assert mock_send_outbound.called_once() diff --git a/aries_cloudagent/holder/routes.py b/aries_cloudagent/holder/routes.py index 03f1b0f3d1..06b342a102 100644 --- a/aries_cloudagent/holder/routes.py +++ b/aries_cloudagent/holder/routes.py @@ -4,9 +4,10 @@ from aiohttp import web from aiohttp_apispec import docs, match_info_schema, querystring_schema, response_schema -from marshmallow import fields, Schema +from marshmallow import fields from .base import BaseHolder, HolderError +from ..messaging.models.openapi import OpenAPISchema from ..messaging.valid import ( INDY_CRED_DEF_ID, INDY_REV_REG_ID, @@ -19,11 +20,11 @@ from ..wallet.error import WalletNotFoundError -class AttributeMimeTypesResultSchema(Schema): +class AttributeMimeTypesResultSchema(OpenAPISchema): """Result schema for credential attribute MIME type.""" -class RawEncCredAttrSchema(Schema): +class RawEncCredAttrSchema(OpenAPISchema): """Credential attribute schema.""" raw = fields.Str(description="Raw value", example="Alex") @@ -33,7 +34,7 @@ class RawEncCredAttrSchema(Schema): ) -class RevRegSchema(Schema): +class RevRegSchema(OpenAPISchema): """Revocation registry schema.""" accum = fields.Str( @@ -42,7 +43,7 @@ class RevRegSchema(Schema): ) -class WitnessSchema(Schema): +class WitnessSchema(OpenAPISchema): """Witness schema.""" omega = fields.Str( @@ -51,7 +52,7 @@ class WitnessSchema(Schema): ) -class CredentialSchema(Schema): +class CredentialSchema(OpenAPISchema): """Result schema for a credential query.""" schema_id = fields.Str(description="Schema identifier", **INDY_SCHEMA_ID) @@ -72,13 +73,13 @@ class CredentialSchema(Schema): witness = fields.Nested(WitnessSchema) -class CredentialsListSchema(Schema): +class CredentialsListSchema(OpenAPISchema): """Result 
schema for a credential query.""" results = fields.List(fields.Nested(CredentialSchema())) -class CredentialsListQueryStringSchema(Schema): +class CredentialsListQueryStringSchema(OpenAPISchema): """Parameters and validators for query string in credentials list query.""" start = fields.Int(description="Start index", required=False, **WHOLE_NUM,) @@ -88,7 +89,7 @@ class CredentialsListQueryStringSchema(Schema): wql = fields.Str(description="(JSON) WQL query", required=False, **INDY_WQL,) -class CredIdMatchInfoSchema(Schema): +class CredIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking credential id.""" credential_id = fields.Str( diff --git a/aries_cloudagent/indy/tests/test_indy.py b/aries_cloudagent/indy/tests/test_indy.py index 3138b72262..f988e1a703 100644 --- a/aries_cloudagent/indy/tests/test_indy.py +++ b/aries_cloudagent/indy/tests/test_indy.py @@ -1,3 +1,4 @@ +import pytest from os import makedirs from pathlib import Path from shutil import rmtree @@ -10,6 +11,7 @@ from .. 
import util as test_module_util +@pytest.mark.indy class TestIndyUtils(AsyncTestCase): TAILS_HASH = "8UW1Sz5cqoUnK9hqQk7nvtKK65t7Chu3ui866J23sFyJ" diff --git a/aries_cloudagent/issuer/indy.py b/aries_cloudagent/issuer/indy.py index a684bb76e7..3b94aa549b 100644 --- a/aries_cloudagent/issuer/indy.py +++ b/aries_cloudagent/issuer/indy.py @@ -220,7 +220,7 @@ async def create_credential( tails_reader_handle, ) except AnoncredsRevocationRegistryFullError: - self.logger.error( + self.logger.warning( f"Revocation registry {revoc_reg_id} is full: cannot create credential" ) raise IssuerRevocationRegistryFullError( diff --git a/aries_cloudagent/ledger/base.py b/aries_cloudagent/ledger/base.py index 3b111dd45e..62032b0417 100644 --- a/aries_cloudagent/ledger/base.py +++ b/aries_cloudagent/ledger/base.py @@ -1,11 +1,14 @@ """Ledger base class.""" -from abc import ABC, abstractmethod, ABCMeta import re + +from abc import ABC, abstractmethod, ABCMeta from typing import Tuple, Sequence from ..issuer.base import BaseIssuer +from .endpoint_type import EndpointType + class BaseLedger(ABC, metaclass=ABCMeta): """Base class for ledger.""" @@ -25,6 +28,11 @@ async def __aenter__(self) -> "BaseLedger": async def __aexit__(self, exc_type, exc, tb): """Context manager exit.""" + @property + @abstractmethod + def type(self) -> str: + """Accessor for the ledger type.""" + @abstractmethod async def get_key_for_did(self, did: str) -> str: """Fetch the verkey for a ledger DID. @@ -34,20 +42,29 @@ async def get_key_for_did(self, did: str) -> str: """ @abstractmethod - async def get_endpoint_for_did(self, did: str) -> str: + async def get_endpoint_for_did( + self, did: str, endpoint_type: EndpointType = EndpointType.ENDPOINT + ) -> str: """Fetch the endpoint for a ledger DID. 
Args: did: The DID to look up on the ledger or in the cache + endpoint_type: The type of the endpoint (default 'endpoint') """ @abstractmethod - async def update_endpoint_for_did(self, did: str, endpoint: str) -> bool: + async def update_endpoint_for_did( + self, + did: str, + endpoint: str, + endpoint_type: EndpointType = EndpointType.ENDPOINT, + ) -> bool: """Check and update the endpoint on the ledger. Args: did: The ledger DID endpoint: The endpoint address + endpoint_type: The type of the endpoint (default 'endpoint') """ @abstractmethod @@ -64,6 +81,15 @@ async def register_nym( role: For permissioned ledgers, what role should the new DID have. """ + @abstractmethod + async def get_nym_role(self, did: str): + """ + Return the role registered to input public DID on the ledger. + + Args: + did: DID to register on the ledger. + """ + @abstractmethod def nym_to_did(self, nym: str) -> str: """Format a nym with the ledger's DID prefix.""" @@ -144,7 +170,7 @@ async def create_and_send_credential_definition( signature_type: str = None, tag: str = None, support_revocation: bool = False, - ) -> Tuple[str, dict]: + ) -> Tuple[str, dict, bool]: """ Send credential definition to ledger and store relevant key matter in wallet. 
@@ -155,6 +181,9 @@ async def create_and_send_credential_definition( tag: Optional tag to distinguish multiple credential definitions support_revocation: Optional flag to enable revocation for this cred def + Returns: + Tuple with cred def id, cred def structure, and whether it's novel + """ @abstractmethod diff --git a/aries_cloudagent/ledger/endpoint_type.py b/aries_cloudagent/ledger/endpoint_type.py new file mode 100644 index 0000000000..2f5f7da592 --- /dev/null +++ b/aries_cloudagent/ledger/endpoint_type.py @@ -0,0 +1,36 @@ +"""Ledger utilities.""" + +from collections import namedtuple +from enum import Enum + +EndpointTypeName = namedtuple("EndpointTypeName", "w3c indy") + + +class EndpointType(Enum): + """Enum for endpoint/service types.""" + + ENDPOINT = EndpointTypeName("Endpoint", "endpoint") + PROFILE = EndpointTypeName("Profile", "profile") + LINKED_DOMAINS = EndpointTypeName("LinkedDomains", "linked_domains") + + @staticmethod + def get(name: str) -> "EndpointType": + """Return enum instance corresponding to input string.""" + if name is None: + return None + + for endpoint_type in EndpointType: + if name.replace("_", "").lower() == endpoint_type.w3c.lower(): + return endpoint_type + + return None + + @property + def w3c(self): + """W3C name of endpoint type: externally-facing.""" + return self.value.w3c + + @property + def indy(self): + """Indy name of endpoint type: internally-facing, on ledger and in wallet.""" + return self.value.indy diff --git a/aries_cloudagent/ledger/indy.py b/aries_cloudagent/ledger/indy.py index 72e640db22..f9327936a3 100644 --- a/aries_cloudagent/ledger/indy.py +++ b/aries_cloudagent/ledger/indy.py @@ -27,6 +27,7 @@ from ..wallet.base import BaseWallet, DIDInfo from .base import BaseLedger +from .endpoint_type import EndpointType from .error import ( BadLedgerRequestError, ClosedPoolError, @@ -36,7 +37,6 @@ ) from .util import TAA_ACCEPTED_RECORD_TYPE - GENESIS_TRANSACTION_PATH = tempfile.gettempdir() GENESIS_TRANSACTION_PATH 
= path.join( GENESIS_TRANSACTION_PATH, "indy_genesis_transactions.txt" @@ -79,7 +79,7 @@ def to_indy_num_str(self) -> str: """ Return (typically, numeric) string value that indy-sdk associates with role. - Recall that None signifies USER and "" signifies role in reset. + Recall that None signifies USER and "" signifies a role undergoing reset. """ return str(self.value[0]) if isinstance(self.value[0], int) else self.value[0] @@ -131,9 +131,14 @@ def __init__( self.taa_cache = None self.read_only = read_only - if wallet.WALLET_TYPE != "indy": + if wallet.type != "indy": raise LedgerConfigError("Wallet type is not 'indy'") + @property + def type(self) -> str: + """Accessor for the ledger type.""" + return IndyLedger.LEDGER_TYPE + async def create_pool_config( self, genesis_transactions: str, recreate: bool = False ): @@ -174,7 +179,9 @@ async def open(self): ): await indy.pool.set_protocol_version(2) - with IndyErrorHandler("Exception when opening pool ledger", LedgerConfigError): + with IndyErrorHandler( + f"Exception when opening pool ledger {self.pool_name}", LedgerConfigError + ): self.pool_handle = await indy.pool.open_pool_ledger(self.pool_name, "{}") self.opened = True @@ -405,17 +412,17 @@ async def create_and_send_schema( else: raise - schema_id_parts = schema_id.split(":") - schema_tags = { - "schema_id": schema_id, - "schema_issuer_did": public_info.did, - "schema_name": schema_id_parts[-2], - "schema_version": schema_id_parts[-1], - "epoch": str(int(time())), - } - record = StorageRecord(SCHEMA_SENT_RECORD_TYPE, schema_id, schema_tags) - storage = self.get_indy_storage() - await storage.add_record(record) + schema_id_parts = schema_id.split(":") + schema_tags = { + "schema_id": schema_id, + "schema_issuer_did": public_info.did, + "schema_name": schema_id_parts[-2], + "schema_version": schema_id_parts[-1], + "epoch": str(int(time())), + } + record = StorageRecord(SCHEMA_SENT_RECORD_TYPE, schema_id, schema_tags) + storage = self.get_indy_storage() + await 
storage.add_record(record) return schema_id, schema_def @@ -539,7 +546,7 @@ async def create_and_send_credential_definition( signature_type: str = None, tag: str = None, support_revocation: bool = False, - ) -> Tuple[str, dict]: + ) -> Tuple[str, dict, bool]: """ Send credential definition to ledger and store relevant key matter in wallet. @@ -550,6 +557,9 @@ async def create_and_send_credential_definition( tag: Optional tag to distinguish multiple credential definitions support_revocation: Optional flag to enable revocation for this cred def + Returns: + Tuple with cred def id, cred def structure, and whether it's novel + """ public_info = await self.wallet.get_public_did() if not public_info: @@ -561,6 +571,8 @@ async def create_and_send_credential_definition( if not schema: raise LedgerError(f"Ledger {self.pool_name} has no schema {schema_id}") + novel = False + # check if cred def is on ledger already for test_tag in [tag] if tag else ["tag", DEFAULT_CRED_DEF_TAG]: credential_definition_id = issuer.make_credential_definition_id( @@ -602,6 +614,7 @@ async def create_and_send_credential_definition( raise LedgerError(err.message) from err # Cred def is neither on ledger nor in wallet: create and send it + novel = True try: ( credential_definition_id, @@ -617,7 +630,6 @@ async def create_and_send_credential_definition( "Error cannot write cred def when ledger is in read only mode" ) - wallet_cred_def = json.loads(credential_definition_json) with IndyErrorHandler( "Exception when building cred def request", LedgerError ): @@ -625,19 +637,9 @@ async def create_and_send_credential_definition( public_info.did, credential_definition_json ) await self._submit(request_json, True, sign_did=public_info) - ledger_cred_def = await self.fetch_credential_definition( - credential_definition_id - ) - assert wallet_cred_def["value"] == ledger_cred_def["value"] - - # Add non-secrets records if not yet present - storage = self.get_indy_storage() - found = await 
storage.search_records( - type_filter=CRED_DEF_SENT_RECORD_TYPE, - tag_query={"cred_def_id": credential_definition_id}, - ).fetch_all() - if not found: + # Add non-secrets record + storage = self.get_indy_storage() schema_id_parts = schema_id.split(":") cred_def_tags = { "schema_id": schema_id, @@ -653,7 +655,7 @@ async def create_and_send_credential_definition( ) await storage.add_record(record) - return credential_definition_id, json.loads(credential_definition_json) + return (credential_definition_id, json.loads(credential_definition_json), novel) async def get_credential_definition(self, credential_definition_id: str) -> dict: """ @@ -746,12 +748,41 @@ async def get_key_for_did(self, did: str) -> str: data_json = (json.loads(response_json))["result"]["data"] return json.loads(data_json)["verkey"] if data_json else None - async def get_endpoint_for_did(self, did: str) -> str: + async def get_all_endpoints_for_did(self, did: str) -> dict: + """Fetch all endpoints for a ledger DID. + + Args: + did: The DID to look up on the ledger or in the cache + """ + nym = self.did_to_nym(did) + public_info = await self.wallet.get_public_did() + public_did = public_info.did if public_info else None + with IndyErrorHandler("Exception when building attribute request", LedgerError): + request_json = await indy.ledger.build_get_attrib_request( + public_did, nym, "endpoint", None, None + ) + response_json = await self._submit(request_json, sign_did=public_info) + data_json = json.loads(response_json)["result"]["data"] + + if data_json: + endpoints = json.loads(data_json).get("endpoint", None) + else: + endpoints = None + + return endpoints + + async def get_endpoint_for_did( + self, did: str, endpoint_type: EndpointType = None + ) -> str: """Fetch the endpoint for a ledger DID. Args: did: The DID to look up on the ledger or in the cache + endpoint_type: The type of the endpoint. 
If none given, returns all """ + + if not endpoint_type: + endpoint_type = EndpointType.ENDPOINT nym = self.did_to_nym(did) public_info = await self.wallet.get_public_did() public_did = public_info.did if public_info else None @@ -763,28 +794,46 @@ async def get_endpoint_for_did(self, did: str) -> str: data_json = json.loads(response_json)["result"]["data"] if data_json: endpoint = json.loads(data_json).get("endpoint", None) - address = endpoint.get("endpoint", None) if endpoint else None + address = endpoint.get(endpoint_type.indy, None) if endpoint else None else: address = None return address - async def update_endpoint_for_did(self, did: str, endpoint: str) -> bool: + async def update_endpoint_for_did( + self, did: str, endpoint: str, endpoint_type: EndpointType = None + ) -> bool: """Check and update the endpoint on the ledger. Args: did: The ledger DID endpoint: The endpoint address + endpoint_type: The type of the endpoint """ - exist_endpoint = await self.get_endpoint_for_did(did) - if exist_endpoint != endpoint: + if not endpoint_type: + endpoint_type = EndpointType.ENDPOINT + + all_exist_endpoints = await self.get_all_endpoints_for_did(did) + exist_endpoint_of_type = ( + all_exist_endpoints.get(endpoint_type.indy, None) + if all_exist_endpoints + else None + ) + + if exist_endpoint_of_type != endpoint: if self.read_only: raise LedgerError( "Error cannot update endpoint when ledger is in read only mode" ) nym = self.did_to_nym(did) - attr_json = json.dumps({"endpoint": {"endpoint": endpoint}}) + + if all_exist_endpoints: + all_exist_endpoints[endpoint_type.indy] = endpoint + attr_json = json.dumps({"endpoint": all_exist_endpoints}) + else: + attr_json = json.dumps({"endpoint": {endpoint_type.indy: endpoint}}) + with IndyErrorHandler( "Exception when building attribute request", LedgerError ): @@ -814,8 +863,33 @@ async def register_nym( public_info = await self.wallet.get_public_did() public_did = public_info.did if public_info else None - r = await 
indy.ledger.build_nym_request(public_did, did, verkey, alias, role) - await self._submit(r, True, True, sign_did=public_info) + with IndyErrorHandler("Exception when building nym request", LedgerError): + request_json = await indy.ledger.build_nym_request( + public_did, did, verkey, alias, role + ) + + await self._submit(request_json) + + async def get_nym_role(self, did: str) -> Role: + """ + Return the role of the input public DID's NYM on the ledger. + + Args: + did: DID to query for role on the ledger. + """ + public_info = await self.wallet.get_public_did() + public_did = public_info.did if public_info else None + + with IndyErrorHandler("Exception when building get-nym request", LedgerError): + request_json = await indy.ledger.build_get_nym_request(public_did, did) + + response_json = await self._submit(request_json) + response = json.loads(response_json) + nym_data = json.loads(response["result"]["data"]) + if not nym_data: + raise BadLedgerRequestError(f"DID {did} is not public") + + return Role.get(nym_data["role"]) def nym_to_did(self, nym: str) -> str: """Format a nym with the ledger's DID prefix.""" @@ -840,25 +914,29 @@ async def rotate_public_did_keypair(self, next_seed: str = None) -> None: nym = self.did_to_nym(public_did) with IndyErrorHandler("Exception when building nym request", LedgerError): request_json = await indy.ledger.build_get_nym_request(public_did, nym) - response_json = await self._submit(request_json) - data = json.loads((json.loads(response_json))["result"]["data"]) - if not data: - raise BadLedgerRequestError( - f"Ledger has no public DID for wallet {self.wallet.name}" - ) - seq_no = data["seqNo"] + + response_json = await self._submit(request_json) + data = json.loads((json.loads(response_json))["result"]["data"]) + if not data: + raise BadLedgerRequestError( + f"Ledger has no public DID for wallet {self.wallet.name}" + ) + seq_no = data["seqNo"] + + with IndyErrorHandler("Exception when building get-txn request", LedgerError): 
txn_req_json = await indy.ledger.build_get_txn_request(None, None, seq_no) - txn_resp_json = await self._submit(txn_req_json) - txn_resp = json.loads(txn_resp_json) - txn_resp_data = txn_resp["result"]["data"] - if not txn_resp_data: - raise BadLedgerRequestError( - f"Bad or missing ledger NYM transaction for DID {public_did}" - ) - txn_data_data = txn_resp_data["txn"]["data"] - role_token = Role.get(txn_data_data.get("role")).token() - alias = txn_data_data.get("alias") - await self.register_nym(public_did, verkey, role_token, alias) + + txn_resp_json = await self._submit(txn_req_json) + txn_resp = json.loads(txn_resp_json) + txn_resp_data = txn_resp["result"]["data"] + if not txn_resp_data: + raise BadLedgerRequestError( + f"Bad or missing ledger NYM transaction for DID {public_did}" + ) + txn_data_data = txn_resp_data["txn"]["data"] + role_token = Role.get(txn_data_data.get("role")).token() + alias = txn_data_data.get("alias") + await self.register_nym(public_did, verkey, role_token, alias) # update wallet await self.wallet.rotate_did_keypair_apply(public_did) diff --git a/aries_cloudagent/ledger/provider.py b/aries_cloudagent/ledger/provider.py index a730cb4392..895dbcb401 100644 --- a/aries_cloudagent/ledger/provider.py +++ b/aries_cloudagent/ledger/provider.py @@ -26,7 +26,7 @@ async def provide(self, settings: BaseSettings, injector: BaseInjector): wallet = await injector.inject(BaseWallet) ledger = None - if wallet.WALLET_TYPE == "indy": + if wallet.type == "indy": IndyLedger = ClassLoader.load_class(self.LEDGER_CLASSES["indy"]) cache = await injector.inject(BaseCache, required=False) ledger = IndyLedger( diff --git a/aries_cloudagent/ledger/routes.py b/aries_cloudagent/ledger/routes.py index bf6105865f..f171d249bc 100644 --- a/aries_cloudagent/ledger/routes.py +++ b/aries_cloudagent/ledger/routes.py @@ -3,17 +3,20 @@ from aiohttp import web from aiohttp_apispec import docs, querystring_schema, request_schema, response_schema -from marshmallow import 
fields, Schema, validate +from marshmallow import fields, validate -from ..messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY +from ..messaging.models.openapi import OpenAPISchema +from ..messaging.valid import ENDPOINT_TYPE, INDY_DID, INDY_RAW_PUBLIC_KEY from ..storage.error import StorageError from ..wallet.error import WalletError + from .base import BaseLedger -from .indy import Role +from .endpoint_type import EndpointType from .error import BadLedgerRequestError, LedgerError, LedgerTransactionError +from .indy import Role -class AMLRecordSchema(Schema): +class AMLRecordSchema(OpenAPISchema): """Ledger AML record.""" version = fields.Str() @@ -21,7 +24,7 @@ class AMLRecordSchema(Schema): amlContext = fields.Str() -class TAARecordSchema(Schema): +class TAARecordSchema(OpenAPISchema): """Ledger TAA record.""" version = fields.Str() @@ -29,14 +32,14 @@ class TAARecordSchema(Schema): digest = fields.Str() -class TAAAcceptanceSchema(Schema): +class TAAAcceptanceSchema(OpenAPISchema): """TAA acceptance record.""" mechanism = fields.Str() time = fields.Int() -class TAAInfoSchema(Schema): +class TAAInfoSchema(OpenAPISchema): """Transaction author agreement info.""" aml_record = fields.Nested(AMLRecordSchema()) @@ -45,13 +48,13 @@ class TAAInfoSchema(Schema): taa_accepted = fields.Nested(TAAAcceptanceSchema()) -class TAAResultSchema(Schema): +class TAAResultSchema(OpenAPISchema): """Result schema for a transaction author agreement.""" result = fields.Nested(TAAInfoSchema()) -class TAAAcceptSchema(Schema): +class TAAAcceptSchema(OpenAPISchema): """Input schema for accepting the TAA.""" version = fields.Str() @@ -59,7 +62,7 @@ class TAAAcceptSchema(Schema): mechanism = fields.Str() -class RegisterLedgerNymQueryStringSchema(Schema): +class RegisterLedgerNymQueryStringSchema(OpenAPISchema): """Query string parameters and validators for register ledger nym request.""" did = fields.Str(description="DID to register", required=True, **INDY_DID,) @@ -76,12 +79,25 @@ class 
RegisterLedgerNymQueryStringSchema(Schema): ) -class QueryStringDIDSchema(Schema): +class QueryStringDIDSchema(OpenAPISchema): """Parameters and validators for query string with DID only.""" did = fields.Str(description="DID of interest", required=True, **INDY_DID) +class QueryStringEndpointSchema(OpenAPISchema): + """Parameters and validators for query string with DID and endpoint type.""" + + did = fields.Str(description="DID of interest", required=True, **INDY_DID) + endpoint_type = fields.Str( + description=( + f"Endpoint type of interest (default '{EndpointType.ENDPOINT.w3c}')" + ), + required=False, + **ENDPOINT_TYPE, + ) + + @docs( tags=["ledger"], summary="Send a NYM registration to the ledger.", ) @@ -120,9 +136,47 @@ async def register_ledger_nym(request: web.BaseRequest): success = True except LedgerTransactionError as err: raise web.HTTPForbidden(reason=err.roll_up) + except LedgerError as err: + raise web.HTTPBadRequest(reason=err.roll_up) + return web.json_response({"success": success}) +@docs( + tags=["ledger"], summary="Get the role from the NYM registration of a public DID.", +) +@querystring_schema(QueryStringDIDSchema) +async def get_nym_role(request: web.BaseRequest): + """ + Request handler for getting the role from the NYM registration of a public DID. + + Args: + request: aiohttp request object + """ + context = request.app["request_context"] + ledger = await context.inject(BaseLedger, required=False) + if not ledger: + reason = "No ledger available" + if not context.settings.get_value("wallet.type"): + reason += ": missing wallet-type?" 
+ raise web.HTTPForbidden(reason=reason) + + did = request.query.get("did") + if not did: + raise web.HTTPBadRequest(reason="Request query must include DID") + + async with ledger: + try: + role = await ledger.get_nym_role(did) + except LedgerTransactionError as err: + raise web.HTTPForbidden(reason=err.roll_up) + except BadLedgerRequestError as err: + raise web.HTTPNotFound(reason=err.roll_up) + except LedgerError as err: + raise web.HTTPBadRequest(reason=err.roll_up) + return web.json_response({"role": role.name}) + + @docs(tags=["ledger"], summary="Rotate key pair for public DID.") async def rotate_public_did_keypair(request: web.BaseRequest): """ @@ -184,7 +238,7 @@ async def get_did_verkey(request: web.BaseRequest): @docs( tags=["ledger"], summary="Get the endpoint for a DID from the ledger.", ) -@querystring_schema(QueryStringDIDSchema()) +@querystring_schema(QueryStringEndpointSchema()) async def get_did_endpoint(request: web.BaseRequest): """ Request handler for getting a verkey for a DID from the ledger. @@ -201,12 +255,16 @@ async def get_did_endpoint(request: web.BaseRequest): raise web.HTTPForbidden(reason=reason) did = request.query.get("did") + endpoint_type = EndpointType.get( + request.query.get("endpoint_type", EndpointType.ENDPOINT.w3c) + ) + if not did: raise web.HTTPBadRequest(reason="Request query must include DID") async with ledger: try: - r = await ledger.get_endpoint_for_did(did) + r = await ledger.get_endpoint_for_did(did, endpoint_type) except LedgerError as err: raise web.HTTPBadRequest(reason=err.roll_up) from err @@ -228,7 +286,7 @@ async def ledger_get_taa(request: web.BaseRequest): """ context = request.app["request_context"] ledger: BaseLedger = await context.inject(BaseLedger, required=False) - if not ledger or ledger.LEDGER_TYPE != "indy": + if not ledger or ledger.type != "indy": reason = "No indy ledger available" if not context.settings.get_value("wallet.type"): reason += ": missing wallet-type?" 
@@ -267,7 +325,7 @@ async def ledger_accept_taa(request: web.BaseRequest): """ context = request.app["request_context"] ledger: BaseLedger = await context.inject(BaseLedger, required=False) - if not ledger or ledger.LEDGER_TYPE != "indy": + if not ledger or ledger.type != "indy": reason = "No indy ledger available" if not context.settings.get_value("wallet.type"): reason += ": missing wallet-type?" @@ -303,6 +361,7 @@ async def register(app: web.Application): app.add_routes( [ web.post("/ledger/register-nym", register_ledger_nym), + web.get("/ledger/get-nym-role", get_nym_role, allow_head=False), web.patch("/ledger/rotate-public-did-keypair", rotate_public_did_keypair), web.get("/ledger/did-verkey", get_did_verkey, allow_head=False), web.get("/ledger/did-endpoint", get_did_endpoint, allow_head=False), diff --git a/aries_cloudagent/ledger/tests/test_endpoint_type.py b/aries_cloudagent/ledger/tests/test_endpoint_type.py new file mode 100644 index 0000000000..7287cce995 --- /dev/null +++ b/aries_cloudagent/ledger/tests/test_endpoint_type.py @@ -0,0 +1,19 @@ +from asynctest import TestCase as AsyncTestCase + +from ..endpoint_type import EndpointType + + +class TestEndpointType(AsyncTestCase): + async def test_endpoint_type(self): + assert EndpointType.ENDPOINT == EndpointType.get("endpoint") + assert EndpointType.PROFILE == EndpointType.get("PROFILE") + assert EndpointType.LINKED_DOMAINS == EndpointType.get("linked_domains") + assert EndpointType.get("no-such-type") is None + assert EndpointType.get(None) is None + + assert EndpointType.PROFILE.w3c == "Profile" + assert EndpointType.PROFILE.indy == "profile" + assert EndpointType.ENDPOINT.w3c == "Endpoint" + assert EndpointType.ENDPOINT.indy == "endpoint" + assert EndpointType.LINKED_DOMAINS.w3c == "LinkedDomains" + assert EndpointType.LINKED_DOMAINS.indy == "linked_domains" diff --git a/aries_cloudagent/ledger/tests/test_indy.py b/aries_cloudagent/ledger/tests/test_indy.py index 859d40bc2a..a19373dac5 100644 --- 
a/aries_cloudagent/ledger/tests/test_indy.py +++ b/aries_cloudagent/ledger/tests/test_indy.py @@ -5,9 +5,10 @@ from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock -from aries_cloudagent.cache.basic import BasicCache -from aries_cloudagent.issuer.base import BaseIssuer, IssuerError -from aries_cloudagent.ledger.indy import ( +from ...cache.basic import BasicCache +from ...issuer.base import BaseIssuer, IssuerError +from ...ledger.endpoint_type import EndpointType +from ...ledger.indy import ( BadLedgerRequestError, ClosedPoolError, ErrorCode, @@ -21,9 +22,9 @@ Role, TAA_ACCEPTED_RECORD_TYPE, ) -from aries_cloudagent.storage.indy import IndyStorage -from aries_cloudagent.storage.record import StorageRecord -from aries_cloudagent.wallet.base import DIDInfo +from ...storage.indy import IndyStorage +from ...storage.record import StorageRecord +from ...wallet.base import DIDInfo class TestRole(AsyncTestCase): @@ -70,7 +71,7 @@ async def test_init(self, mock_open, mock_create_config): mock_open.return_value = async_mock.MagicMock() mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) assert ledger.pool_name == "name" @@ -94,8 +95,9 @@ async def test_init_do_not_recreate(self, mock_open, mock_list_pools): mock_list_pools.return_value = [{"pool": "name"}, {"pool": "another"}] mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) + assert ledger.type == "indy" assert ledger.pool_name == "name" assert ledger.wallet is mock_wallet @@ -117,7 +119,7 @@ async def test_init_recreate( mock_delete_config.return_value = None mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) assert ledger.pool_name == "name" @@ -133,7 +135,7 @@ async def test_init_recreate( async def 
test_init_non_indy(self): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "non-indy" + mock_wallet.type = "non-indy" with self.assertRaises(LedgerConfigError): IndyLedger("name", mock_wallet) @@ -144,7 +146,7 @@ async def test_aenter_aexit( self, mock_close_pool, mock_open_ledger, mock_set_proto ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) async with ledger as led: @@ -164,7 +166,7 @@ async def test_aenter_aexit_nested_keepalive( self, mock_close_pool, mock_open_ledger, mock_set_proto ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet, keepalive=1) async with ledger as led0: @@ -191,7 +193,7 @@ async def test_aenter_aexit_close_x( self, mock_close_pool, mock_open_ledger, mock_set_proto ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_close_pool.side_effect = IndyError(ErrorCode.PoolLedgerTimeout) ledger = IndyLedger("name", mock_wallet) @@ -210,7 +212,7 @@ async def test_submit_pool_closed( self, mock_close_pool, mock_open_ledger, mock_create_config, mock_set_proto ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) with self.assertRaises(ClosedPoolError) as context: @@ -234,7 +236,7 @@ async def test_submit_signed( mock_sign_submit.return_value = '{"op": "REPLY"}' mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) @@ -277,7 +279,7 @@ async def test_submit_signed_taa_accept( mock_sign_submit.return_value = '{"op": "REPLY"}' mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) ledger.get_latest_txn_author_acceptance = 
async_mock.CoroutineMock( @@ -332,7 +334,7 @@ async def test_submit_unsigned( mock_submit.return_value = '{"op": "REPLY"}' mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did.return_value = future ledger = IndyLedger("name", mock_wallet) @@ -366,7 +368,7 @@ async def test_submit_unsigned_ledger_transaction_error( mock_submit.return_value = '{"op": "NO-SUCH-OP"}' mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did.return_value = future ledger = IndyLedger("name", mock_wallet) @@ -401,7 +403,7 @@ async def test_submit_rejected( mock_submit.return_value = '{"op": "REQNACK", "reason": "a reason"}' mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did.return_value = future ledger = IndyLedger("name", mock_wallet) @@ -414,7 +416,7 @@ async def test_submit_rejected( mock_submit.return_value = '{"op": "REJECT", "reason": "another reason"}' mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did.return_value = future ledger = IndyLedger("name", mock_wallet) @@ -442,7 +444,7 @@ async def test_send_schema( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" issuer = async_mock.MagicMock(BaseIssuer) ledger = IndyLedger("name", mock_wallet) @@ -512,23 +514,36 @@ async def test_send_schema_already_exists( # mock_did = async_mock.CoroutineMock() mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_wallet.get_public_did.return_value.did = "abc" - fetch_schema_id = f"{mock_wallet.get_public_did.return_value.did}:{2}:schema_name:schema_version" + fetch_schema_id = ( + f"{mock_wallet.get_public_did.return_value.did}:2:" + 
"schema_name:schema_version" + ) mock_check_existing.return_value = (fetch_schema_id, {}) issuer = async_mock.MagicMock(BaseIssuer) issuer.create_and_store_schema.return_value = ("1", "{}") ledger = IndyLedger("name", mock_wallet) - async with ledger: - schema_id, schema_def = await ledger.create_and_send_schema( - issuer, "schema_name", "schema_version", [1, 2, 3] + with async_mock.patch.object( + ledger, "get_indy_storage", async_mock.MagicMock() + ) as mock_get_storage: + mock_add_record = async_mock.CoroutineMock() + mock_get_storage.return_value = async_mock.MagicMock( + add_record=mock_add_record ) - assert schema_id == fetch_schema_id - assert schema_def == {} + + async with ledger: + schema_id, schema_def = await ledger.create_and_send_schema( + issuer, "schema_name", "schema_version", [1, 2, 3] + ) + assert schema_id == fetch_schema_id + assert schema_def == {} + + mock_add_record.assert_not_called() @async_mock.patch("indy.pool.set_protocol_version") @async_mock.patch("indy.pool.create_pool_ledger_config") @@ -549,11 +564,14 @@ async def test_send_schema_ledger_transaction_error_already_exists( ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_wallet.get_public_did.return_value.did = "abc" - fetch_schema_id = f"{mock_wallet.get_public_did.return_value.did}:{2}:schema_name:schema_version" + fetch_schema_id = ( + f"{mock_wallet.get_public_did.return_value.did}:2:" + "schema_name:schema_version" + ) mock_check_existing.side_effect = [None, (fetch_schema_id, "{}")] issuer = async_mock.MagicMock(BaseIssuer) @@ -584,12 +602,12 @@ async def test_send_schema_ledger_read_only( ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_wallet.get_public_did.return_value.did = "abc" fetch_schema_id = ( - 
f"{mock_wallet.get_public_did.return_value.did}:{2}:" + f"{mock_wallet.get_public_did.return_value.did}:2:" "schema_name:schema_version" ) mock_check_existing.side_effect = [None, fetch_schema_id] @@ -620,12 +638,12 @@ async def test_send_schema_issuer_error( ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_wallet.get_public_did.return_value.did = "abc" fetch_schema_id = ( - f"{mock_wallet.get_public_did.return_value.did}:{2}:" + f"{mock_wallet.get_public_did.return_value.did}:2:" "schema_name:schema_version" ) mock_check_existing.side_effect = [None, fetch_schema_id] @@ -662,12 +680,12 @@ async def test_send_schema_ledger_transaction_error( ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_wallet.get_public_did.return_value.did = "abc" fetch_schema_id = ( - f"{mock_wallet.get_public_did.return_value.did}:{2}:" + f"{mock_wallet.get_public_did.return_value.did}:2:" "schema_name:schema_version" ) mock_check_existing.side_effect = [None, fetch_schema_id] @@ -703,7 +721,7 @@ async def test_send_schema_no_seq_no( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" issuer = async_mock.MagicMock(BaseIssuer) ledger = IndyLedger("name", mock_wallet) @@ -737,7 +755,7 @@ async def test_check_existing_schema( self, mock_fetch_schema_by_id, mock_close, mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value mock_did.did = self.test_did @@ -776,7 +794,7 @@ async def test_get_schema( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = 
async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value mock_did.did = self.test_did @@ -810,7 +828,7 @@ async def test_get_schema_not_found( self, mock_build_get_schema_req, mock_submit, mock_close, mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value mock_did.did = self.test_did @@ -846,7 +864,7 @@ async def test_get_schema_by_seq_no( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value mock_did.did = self.test_did @@ -913,7 +931,7 @@ async def test_get_schema_by_wrong_seq_no( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value mock_did.did = self.test_did @@ -956,7 +974,7 @@ async def test_send_credential_definition( mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_search_records.return_value.fetch_all = async_mock.CoroutineMock( return_value=[] @@ -1006,16 +1024,108 @@ async def test_send_credential_definition( ) mock_did = mock_wallet.get_public_did.return_value - result_id, result_def = await ledger.create_and_send_credential_definition( + ( + result_id, + result_def, + novel, + ) = await ledger.create_and_send_credential_definition( issuer, schema_id, None, tag ) assert result_id == cred_def_id + assert novel mock_wallet.get_public_did.assert_called_once_with() mock_get_schema.assert_called_once_with(schema_id) mock_build_cred_def.assert_called_once_with(mock_did.did, cred_def_json) + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger.get_schema") + 
@async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch( + "aries_cloudagent.ledger.indy.IndyLedger.fetch_credential_definition" + ) + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + @async_mock.patch("aries_cloudagent.storage.indy.IndyStorage.search_records") + @async_mock.patch("aries_cloudagent.storage.indy.IndyStorage.add_record") + @async_mock.patch("indy.ledger.build_cred_def_request") + async def test_send_credential_definition_exists_in_ledger_and_wallet( + self, + mock_build_cred_def, + mock_add_record, + mock_search_records, + mock_submit, + mock_fetch_cred_def, + mock_close, + mock_open, + mock_get_schema, + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + mock_search_records.return_value.fetch_all = async_mock.CoroutineMock( + return_value=[] + ) + + mock_get_schema.return_value = {"seqNo": 999} + cred_def_id = f"{self.test_did}:3:CL:999:default" + cred_def_value = { + "primary": {"n": "...", "s": "...", "r": "...", "revocation": None} + } + cred_def = { + "ver": "1.0", + "id": cred_def_id, + "schemaId": "999", + "type": "CL", + "tag": "default", + "value": cred_def_value, + } + cred_def_json = json.dumps(cred_def) + + mock_fetch_cred_def.return_value = {"mock": "cred-def"} + + issuer = async_mock.MagicMock(BaseIssuer) + issuer.make_credential_definition_id.return_value = cred_def_id + issuer.create_and_store_credential_definition.return_value = ( + cred_def_id, + cred_def_json, + ) + issuer.credential_definition_in_wallet.return_value = True + ledger = IndyLedger("name", mock_wallet) + + schema_id = "schema_issuer_did:name:1.0" + tag = "default" + + with async_mock.patch.object( + ledger, "get_indy_storage", async_mock.MagicMock() + ) as mock_get_storage: + mock_get_storage.return_value = async_mock.MagicMock( + add_record=async_mock.CoroutineMock() + ) + + async with ledger: + 
mock_wallet.get_public_did = async_mock.CoroutineMock() + mock_wallet.get_public_did.return_value = DIDInfo( + self.test_did, self.test_verkey, None + ) + mock_did = mock_wallet.get_public_did.return_value + + ( + result_id, + result_def, + novel, + ) = await ledger.create_and_send_credential_definition( + issuer, schema_id, None, tag + ) + assert result_id == cred_def_id + assert not novel + + mock_wallet.get_public_did.assert_called_once_with() + mock_get_schema.assert_called_once_with(schema_id) + + mock_build_cred_def.assert_not_called() + mock_get_storage.assert_not_called() + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger.get_schema") @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") @@ -1023,7 +1133,7 @@ async def test_send_credential_definition_no_such_schema( self, mock_close, mock_open, mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_get_schema.return_value = {} @@ -1063,7 +1173,7 @@ async def test_send_credential_definition_offer_exception( mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_search_records.return_value.fetch_all = async_mock.CoroutineMock( return_value=[] @@ -1098,7 +1208,7 @@ async def test_send_credential_definition_cred_def_in_wallet_not_ledger( self, mock_fetch_cred_def, mock_close, mock_open, mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_get_schema.return_value = {"seqNo": 999} cred_def_id = f"{self.test_did}:3:CL:999:default" @@ -1141,7 +1251,7 @@ async def test_send_credential_definition_cred_def_not_on_ledger_wallet_check_x( self, mock_fetch_cred_def, mock_close, mock_open, mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" 
mock_get_schema.return_value = {"seqNo": 999} cred_def_id = f"{self.test_did}:3:CL:999:default" @@ -1188,7 +1298,7 @@ async def test_send_credential_definition_cred_def_not_on_ledger_nor_wallet_send self, mock_fetch_cred_def, mock_close, mock_open, mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_get_schema.return_value = {"seqNo": 999} cred_def_id = f"{self.test_did}:3:CL:999:default" @@ -1238,7 +1348,7 @@ async def test_send_credential_definition_read_only( self, mock_fetch_cred_def, mock_close, mock_open, mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_get_schema.return_value = {"seqNo": 999} cred_def_id = f"{self.test_did}:3:CL:999:default" @@ -1288,7 +1398,7 @@ async def test_send_credential_definition_cred_def_on_ledger_not_in_wallet( self, mock_fetch_cred_def, mock_close, mock_open, mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_get_schema.return_value = {"seqNo": 999} cred_def_id = f"{self.test_did}:3:CL:999:default" @@ -1346,7 +1456,7 @@ async def test_send_credential_definition_on_ledger_in_wallet( mock_get_schema, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_search_records.return_value.fetch_all = async_mock.CoroutineMock( return_value=[] @@ -1395,7 +1505,11 @@ async def test_send_credential_definition_on_ledger_in_wallet( ) mock_did = mock_wallet.get_public_did.return_value - result_id, result_def = await ledger.create_and_send_credential_definition( + ( + result_id, + result_def, + novel, + ) = await ledger.create_and_send_credential_definition( issuer, schema_id, None, tag ) assert result_id == cred_def_id @@ -1427,7 +1541,7 @@ async def test_send_credential_definition_create_cred_def_exception( mock_get_schema, ): mock_wallet = async_mock.MagicMock() - 
mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_search_records.return_value.fetch_all = async_mock.CoroutineMock( return_value=[] @@ -1454,9 +1568,6 @@ async def test_send_credential_definition_create_cred_def_exception( issuer.create_and_store_credential_definition.side_effect = IssuerError( "invalid structure" ) - # issuer.credential_definition_in_wallet.side_effect = IndyError( - # error_code=ErrorCode.CommonInvalidStructure - # ) ledger = IndyLedger("name", mock_wallet) schema_id = "schema_issuer_did:name:1.0" @@ -1487,7 +1598,7 @@ async def test_get_credential_definition( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value @@ -1533,7 +1644,7 @@ async def test_get_credential_definition_ledger_not_found( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value @@ -1573,7 +1684,7 @@ async def test_fetch_credential_definition_ledger_x( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_wallet.get_public_did = async_mock.CoroutineMock() mock_did = mock_wallet.get_public_did.return_value @@ -1597,7 +1708,7 @@ async def test_get_key_for_did( self, mock_submit, mock_build_get_nym_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_submit.return_value = json.dumps( {"result": {"data": json.dumps({"verkey": self.test_verkey})}} @@ -1627,7 +1738,7 @@ async def test_get_endpoint_for_did( self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" endpoint = 
"http://aries.ca" mock_submit.return_value = json.dumps( @@ -1650,6 +1761,107 @@ async def test_get_endpoint_for_did( ) assert response == endpoint + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_attrib_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_get_endpoint_of_type_profile_for_did( + self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + endpoint = "http://company.com/masterdata" + endpoint_type = EndpointType.PROFILE + mock_submit.return_value = json.dumps( + { + "result": { + "data": json.dumps( + {"endpoint": {EndpointType.PROFILE.indy: endpoint}} + ) + } + } + ) + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + response = await ledger.get_endpoint_for_did(self.test_did, endpoint_type) + + assert mock_build_get_attrib_req.called_once_with( + self.test_did, ledger.did_to_nym(self.test_did), "endpoint", None, None + ) + assert mock_submit.called_once_with( + mock_build_get_attrib_req.return_value, + sign_did=mock_wallet.get_public_did.return_value, + ) + assert response == endpoint + + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_attrib_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_get_all_endpoints_for_did( + self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + profile_endpoint = "http://company.com/masterdata" + default_endpoint = "http://agent.company.com" + data_json = 
json.dumps( + {"endpoint": {"endpoint": default_endpoint, "profile": profile_endpoint}} + ) + mock_submit.return_value = json.dumps({"result": {"data": data_json}}) + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + response = await ledger.get_all_endpoints_for_did(self.test_did) + + assert mock_build_get_attrib_req.called_once_with( + self.test_did, ledger.did_to_nym(self.test_did), "endpoint", None, None + ) + assert mock_submit.called_once_with( + mock_build_get_attrib_req.return_value, + sign_did=mock_wallet.get_public_did.return_value, + ) + assert response == json.loads(data_json).get("endpoint") + + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_attrib_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_get_all_endpoints_for_did_none( + self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + profile_endpoint = "http://company.com/masterdata" + default_endpoint = "http://agent.company.com" + mock_submit.return_value = json.dumps({"result": {"data": None}}) + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + response = await ledger.get_all_endpoints_for_did(self.test_did) + + assert mock_build_get_attrib_req.called_once_with( + self.test_did, ledger.did_to_nym(self.test_did), "endpoint", None, None + ) + assert mock_submit.called_once_with( + mock_build_get_attrib_req.return_value, + sign_did=mock_wallet.get_public_did.return_value, + ) + assert response is None + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") 
@async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") @async_mock.patch("indy.ledger.build_get_attrib_request") @@ -1658,7 +1870,7 @@ async def test_get_endpoint_for_did_address_none( self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_submit.return_value = json.dumps( {"result": {"data": json.dumps({"endpoint": None})}} @@ -1688,7 +1900,7 @@ async def test_get_endpoint_for_did_no_endpoint( self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_submit.return_value = json.dumps({"result": {"data": None}}) ledger = IndyLedger("name", mock_wallet) @@ -1722,7 +1934,7 @@ async def test_update_endpoint_for_did( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" endpoint = ["http://old.aries.ca", "http://new.aries.ca"] mock_submit.side_effect = [ @@ -1757,6 +1969,101 @@ async def test_update_endpoint_for_did( ) assert response + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_attrib_request") + @async_mock.patch("indy.ledger.build_attrib_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_update_endpoint_for_did_no_prior_endpoints( + self, + mock_submit, + mock_build_attrib_req, + mock_build_get_attrib_req, + mock_close, + mock_open, + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + endpoint = "http://new.aries.ca" + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + with async_mock.patch.object( + ledger, "get_all_endpoints_for_did", async_mock.CoroutineMock() + ) as mock_get_all: + mock_get_all.return_value = None + 
mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + response = await ledger.update_endpoint_for_did(self.test_did, endpoint) + + assert mock_build_get_attrib_req.called_once_with( + self.test_did, + ledger.did_to_nym(self.test_did), + "endpoint", + None, + None, + ) + mock_submit.assert_has_calls( + [async_mock.call(mock_build_attrib_req.return_value, True, True),] + ) + assert response + + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_attrib_request") + @async_mock.patch("indy.ledger.build_attrib_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_update_endpoint_of_type_profile_for_did( + self, + mock_submit, + mock_build_attrib_req, + mock_build_get_attrib_req, + mock_close, + mock_open, + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + endpoint = ["http://company.com/oldProfile", "http://company.com/newProfile"] + endpoint_type = EndpointType.PROFILE + mock_submit.side_effect = [ + json.dumps( + { + "result": { + "data": json.dumps( + {"endpoint": {endpoint_type.indy: endpoint[i]}} + ) + } + } + ) + for i in range(len(endpoint)) + ] + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + response = await ledger.update_endpoint_for_did( + self.test_did, endpoint[1], endpoint_type + ) + + assert mock_build_get_attrib_req.called_once_with( + self.test_did, ledger.did_to_nym(self.test_did), "endpoint", None, None + ) + mock_submit.assert_has_calls( + [ + async_mock.call( + mock_build_get_attrib_req.return_value, + sign_did=mock_wallet.get_public_did.return_value, + ), + async_mock.call(mock_build_attrib_req.return_value, True, True), + ] + ) + assert response + 
@async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") @async_mock.patch("indy.ledger.build_get_attrib_request") @@ -1765,7 +2072,7 @@ async def test_update_endpoint_for_did_duplicate( self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" endpoint = "http://aries.ca" mock_submit.return_value = json.dumps( @@ -1796,7 +2103,7 @@ async def test_update_endpoint_for_did_read_only( self, mock_submit, mock_build_get_attrib_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" endpoint = "http://aries.ca" mock_submit.return_value = json.dumps( @@ -1820,7 +2127,7 @@ async def test_register_nym( self, mock_submit, mock_build_nym_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) @@ -1840,6 +2147,32 @@ async def test_register_nym( sign_did=mock_wallet.get_public_did.return_value, ) + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_nym_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_register_nym_ledger_x( + self, mock_submit, mock_build_nym_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + mock_build_nym_req.side_effect = IndyError( + error_code=ErrorCode.CommonInvalidParam1, + error_details={"message": "not today"}, + ) + + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + with self.assertRaises(LedgerError): + await 
ledger.register_nym( + self.test_did, self.test_verkey, "alias", None + ) + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") @async_mock.patch("indy.ledger.build_nym_request") @@ -1848,7 +2181,7 @@ async def test_register_nym_read_only( self, mock_submit, mock_build_nym_req, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -1862,6 +2195,139 @@ async def test_register_nym_read_only( ) assert "read only" in str(context.exception) + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_nym_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_get_nym_role( + self, mock_submit, mock_build_get_nym_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + mock_submit.return_value = json.dumps( + { + "result": { + "dest": "GjZWsBLgZCR18aL468JAT7w9CZRiBnpxUPPgyQxh4voa", + "txnTime": 1597858571, + "reqId": 1597858571783588400, + "state_proof": { + "root_hash": "7K26MUQt8E2X1vsRJUmc2298VtY8YC5BSDfT5CRJeUDi", + "proof_nodes": "+QHo...", + "multi_signature": { + "participants": ["Node4", "Node3", "Node2"], + "value": { + "state_root_hash": "7K2...", + "pool_state_root_hash": "GT8...", + "ledger_id": 1, + "txn_root_hash": "Hnr...", + "timestamp": 1597858571, + }, + "signature": "QuX...", + }, + }, + "data": json.dumps( + { + "dest": "GjZWsBLgZCR18aL468JAT7w9CZRiBnpxUPPgyQxh4voa", + "identifier": "V4SGRU86Z58d6TV7PBUe6f", + "role": 101, + "seqNo": 11, + "txnTime": 1597858571, + "verkey": "GjZWsBLgZCR18aL468JAT7w9CZRiBnpxUPPgyQxh4voa", + } + ), + "seqNo": 11, + "identifier": 
"GjZWsBLgZCR18aL468JAT7w9CZRiBnpxUPPgyQxh4voa", + "type": "105", + }, + "op": "REPLY", + } + ) + + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + assert await ledger.get_nym_role(self.test_did) == Role.ENDORSER + + assert mock_build_get_nym_req.called_once_with(self.test_did, self.test_did) + assert mock_submit.called_once_with(mock_build_get_nym_req.return_value) + + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_nym_request") + async def test_get_nym_role_indy_x( + self, mock_build_get_nym_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + mock_build_get_nym_req.side_effect = IndyError( + error_code=ErrorCode.CommonInvalidParam1, + error_details={"message": "not today"}, + ) + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + + with self.assertRaises(LedgerError) as context: + await ledger.get_nym_role(self.test_did) + assert "not today" in context.exception.message + + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") + @async_mock.patch("indy.ledger.build_get_nym_request") + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._submit") + async def test_get_nym_role_did_not_public_x( + self, mock_submit, mock_build_get_nym_req, mock_close, mock_open + ): + mock_wallet = async_mock.MagicMock() + mock_wallet.type = "indy" + + mock_submit.return_value = json.dumps( + { + "result": { + "dest": "GjZWsBLgZCR18aL468JAT7w9CZRiBnpxUPPgyQxh4voa", + "txnTime": 1597858571, + "reqId": 1597858571783588400, + "state_proof": { + "root_hash": 
"7K26MUQt8E2X1vsRJUmc2298VtY8YC5BSDfT5CRJeUDi", + "proof_nodes": "+QHo...", + "multi_signature": { + "participants": ["Node4", "Node3", "Node2"], + "value": { + "state_root_hash": "7K2...", + "pool_state_root_hash": "GT8...", + "ledger_id": 1, + "txn_root_hash": "Hnr...", + "timestamp": 1597858571, + }, + "signature": "QuX...", + }, + }, + "data": json.dumps(None), + "seqNo": 11, + "identifier": "GjZWsBLgZCR18aL468JAT7w9CZRiBnpxUPPgyQxh4voa", + "type": "105", + }, + "op": "REPLY", + } + ) + + ledger = IndyLedger("name", mock_wallet) + + async with ledger: + mock_wallet.get_public_did = async_mock.CoroutineMock( + return_value=self.test_did_info + ) + with self.assertRaises(BadLedgerRequestError): + await ledger.get_nym_role(self.test_did) + @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_open") @async_mock.patch("aries_cloudagent.ledger.indy.IndyLedger._context_close") @async_mock.patch("indy.ledger.build_get_nym_request") @@ -1884,7 +2350,7 @@ async def test_rotate_public_did_keypair( ), rotate_did_keypair_apply=async_mock.CoroutineMock(return_value=None), ) - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_submit.side_effect = [ json.dumps({"result": {"data": json.dumps({"seqNo": 1234})}}), json.dumps( @@ -1914,7 +2380,7 @@ async def test_rotate_public_did_keypair_no_nym( ), rotate_did_keypair_apply=async_mock.CoroutineMock(return_value=None), ) - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_submit.return_value = json.dumps({"result": {"data": json.dumps(None)}}) ledger = IndyLedger("name", mock_wallet) @@ -1944,7 +2410,7 @@ async def test_rotate_public_did_keypair_corrupt_nym_txn( ), rotate_did_keypair_apply=async_mock.CoroutineMock(return_value=None), ) - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_submit.side_effect = [ json.dumps({"result": {"data": json.dumps({"seqNo": 1234})}}), json.dumps({"result": {"data": None}}), @@ -1969,7 +2435,7 @@ async def test_get_revoc_reg_def( 
mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_parse_get_rrdef_resp.return_value = ("rr-id", '{"hello": "world"}') ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -1990,7 +2456,7 @@ async def test_get_revoc_reg_def_indy_x( self, mock_indy_build_get_rrdef_req, mock_submit, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_get_rrdef_req.side_effect = IndyError( error_code=ErrorCode.CommonInvalidParam1, error_details={"message": "not today"}, @@ -2021,7 +2487,7 @@ async def test_get_revoc_reg_entry( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_parse_get_rr_resp.return_value = ( "rr-id", '{"hello": "world"}', @@ -2052,7 +2518,7 @@ async def test_get_revoc_reg_delta( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_parse_get_rrd_resp.return_value = ( "rr-id", '{"hello": "world"}', @@ -2077,7 +2543,7 @@ async def test_send_revoc_reg_def_public_did( self, mock_indy_build_rrdef_req, mock_submit, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_rrdef_req.return_value = '{"hello": "world"}' ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -2101,7 +2567,7 @@ async def test_send_revoc_reg_def_local_did( self, mock_indy_build_rrdef_req, mock_submit, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_rrdef_req.return_value = '{"hello": "world"}' ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -2125,7 +2591,7 @@ async def test_send_revoc_reg_def_x_no_did( self, mock_indy_build_rrdef_req, mock_submit, mock_close, mock_open ): mock_wallet = 
async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_rrdef_req.return_value = '{"hello": "world"}' ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -2146,7 +2612,7 @@ async def test_send_revoc_reg_entry_public_did( self, mock_indy_build_rre_req, mock_submit, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_rre_req.return_value = '{"hello": "world"}' ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -2172,7 +2638,7 @@ async def test_send_revoc_reg_entry_local_did( self, mock_indy_build_rre_req, mock_submit, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_rre_req.return_value = '{"hello": "world"}' ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -2198,7 +2664,7 @@ async def test_send_revoc_reg_entry_x_no_did( self, mock_indy_build_rre_req, mock_submit, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" mock_indy_build_rre_req.return_value = '{"hello": "world"}' ledger = IndyLedger("name", mock_wallet, read_only=True) @@ -2219,7 +2685,7 @@ async def test_taa_digest_bad_value( self, mock_close_pool, mock_open_ledger, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet) @@ -2245,7 +2711,7 @@ async def test_get_txn_author_agreement( mock_open, ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" txn_result_data = {"text": "text", "version": "1.0"} mock_submit.side_effect = [ @@ -2295,7 +2761,7 @@ async def test_accept_and_get_latest_txn_author_agreement( self, mock_search_records, mock_add_record, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + 
mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet, cache=BasicCache()) @@ -2340,7 +2806,7 @@ async def test_get_latest_txn_author_agreement_none( self, mock_search_records, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" ledger = IndyLedger("name", mock_wallet, cache=BasicCache()) @@ -2360,7 +2826,7 @@ async def test_credential_definition_id2schema_id( self, mock_get_schema, mock_close, mock_open ): mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" S_ID = f"{TestIndyLedger.test_did}:2:favourite_drink:1.0" SEQ_NO = "9999" diff --git a/aries_cloudagent/ledger/tests/test_provider.py b/aries_cloudagent/ledger/tests/test_provider.py index ab5ec8b33b..2e1e7fb836 100644 --- a/aries_cloudagent/ledger/tests/test_provider.py +++ b/aries_cloudagent/ledger/tests/test_provider.py @@ -29,7 +29,7 @@ async def test_provide( context = InjectionContext(enforce_typing=False) mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" context.injector.bind_instance(BaseWallet, mock_wallet) result = await provider.provide( @@ -53,7 +53,7 @@ async def test_provide_no_pool_config(self, mock_open, mock_list_pools): context = InjectionContext(enforce_typing=False) mock_wallet = async_mock.MagicMock() - mock_wallet.WALLET_TYPE = "indy" + mock_wallet.type = "indy" context.injector.bind_instance(BaseWallet, mock_wallet) result = await provider.provide( diff --git a/aries_cloudagent/ledger/tests/test_routes.py b/aries_cloudagent/ledger/tests/test_routes.py index a040bcda86..a073216c54 100644 --- a/aries_cloudagent/ledger/tests/test_routes.py +++ b/aries_cloudagent/ledger/tests/test_routes.py @@ -4,8 +4,10 @@ from ...config.injection_context import InjectionContext from ...ledger.base import BaseLedger +from ...ledger.endpoint_type import EndpointType from .. 
import routes as test_module +from ..indy import Role class TestLedgerRoutes(AsyncTestCase): @@ -21,6 +23,8 @@ def setUp(self): self.test_did = "did" self.test_verkey = "verkey" self.test_endpoint = "http://localhost:8021" + self.test_endpoint_type = EndpointType.PROFILE + self.test_endpoint_type_profile = "http://company.com/profile" async def test_missing_ledger(self): request = async_mock.MagicMock(app=self.app,) @@ -29,6 +33,9 @@ async def test_missing_ledger(self): with self.assertRaises(test_module.web.HTTPForbidden): await test_module.register_ledger_nym(request) + with self.assertRaises(test_module.web.HTTPForbidden): + await test_module.get_nym_role(request) + with self.assertRaises(test_module.web.HTTPForbidden): await test_module.rotate_public_did_keypair(request) @@ -89,6 +96,25 @@ async def test_get_endpoint(self): ) assert result is json_response.return_value + async def test_get_endpoint_of_type_profile(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = { + "did": self.test_did, + "endpoint_type": self.test_endpoint_type.w3c, + } + with async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as json_response: + self.ledger.get_endpoint_for_did.return_value = ( + self.test_endpoint_type_profile + ) + result = await test_module.get_did_endpoint(request) + json_response.assert_called_once_with( + {"endpoint": self.ledger.get_endpoint_for_did.return_value} + ) + assert result is json_response.return_value + async def test_get_endpoint_no_did(self): request = async_mock.MagicMock() request.app = self.app @@ -136,6 +162,62 @@ async def test_register_nym_ledger_txn_error(self): with self.assertRaises(test_module.web.HTTPForbidden): await test_module.register_ledger_nym(request) + async def test_register_nym_ledger_error(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"did": self.test_did, "verkey": self.test_verkey} + self.ledger.register_nym.side_effect = 
test_module.LedgerError("Error") + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.register_ledger_nym(request) + + async def test_get_nym_role(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"did": self.test_did} + + with async_mock.patch.object( + test_module.web, "json_response", async_mock.Mock() + ) as json_response: + self.ledger.get_nym_role.return_value = Role.USER + result = await test_module.get_nym_role(request) + json_response.assert_called_once_with({"role": "USER"}) + assert result is json_response.return_value + + async def test_get_nym_role_bad_request(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"no": "did"} + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.get_nym_role(request) + + async def test_get_nym_role_ledger_txn_error(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"did": self.test_did} + self.ledger.get_nym_role.side_effect = test_module.LedgerTransactionError( + "Error in building get-nym request" + ) + with self.assertRaises(test_module.web.HTTPForbidden): + await test_module.get_nym_role(request) + + async def test_get_nym_role_bad_ledger_req(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"did": self.test_did} + self.ledger.get_nym_role.side_effect = test_module.BadLedgerRequestError( + "No such public DID" + ) + with self.assertRaises(test_module.web.HTTPNotFound): + await test_module.get_nym_role(request) + + async def test_get_nym_role_ledger_error(self): + request = async_mock.MagicMock() + request.app = self.app + request.query = {"did": self.test_did} + self.ledger.get_nym_role.side_effect = test_module.LedgerError("Error") + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.get_nym_role(request) + async def test_rotate_public_did_keypair(self): request = async_mock.MagicMock() request.app = 
self.app @@ -175,7 +257,7 @@ async def test_get_taa(self): with async_mock.patch.object( test_module.web, "json_response", async_mock.Mock() ) as json_response: - self.ledger.LEDGER_TYPE = "indy" + self.ledger.type = "indy" self.ledger.get_txn_author_agreement.return_value = {"taa_required": False} self.ledger.get_latest_txn_author_acceptance.return_value = None result = await test_module.ledger_get_taa(request) @@ -196,7 +278,7 @@ async def test_get_taa_required(self): with async_mock.patch.object( test_module.web, "json_response", async_mock.Mock() ) as json_response: - self.ledger.LEDGER_TYPE = "indy" + self.ledger.type = "indy" self.ledger.get_txn_author_agreement.return_value = taa_info self.ledger.get_latest_txn_author_acceptance.return_value = accepted result = await test_module.ledger_get_taa(request) @@ -208,7 +290,7 @@ async def test_get_taa_x(self): request = async_mock.MagicMock() request.app = self.app - self.ledger.LEDGER_TYPE = "indy" + self.ledger.type = "indy" self.ledger.get_txn_author_agreement.side_effect = test_module.LedgerError() with self.assertRaises(test_module.web.HTTPBadRequest): @@ -226,7 +308,7 @@ async def test_taa_accept_not_required(self): ) with self.assertRaises(test_module.web.HTTPBadRequest): - self.ledger.LEDGER_TYPE = "indy" + self.ledger.type = "indy" self.ledger.get_txn_author_agreement.return_value = {"taa_required": False} await test_module.ledger_accept_taa(request) @@ -244,7 +326,7 @@ async def test_accept_taa(self): with async_mock.patch.object( test_module.web, "json_response", async_mock.Mock() ) as json_response: - self.ledger.LEDGER_TYPE = "indy" + self.ledger.type = "indy" self.ledger.get_txn_author_agreement.return_value = {"taa_required": True} result = await test_module.ledger_accept_taa(request) json_response.assert_called_once_with({}) @@ -262,7 +344,7 @@ async def test_accept_taa_bad_ledger(self): request = async_mock.MagicMock() request.app = self.app - self.ledger.LEDGER_TYPE = "not-indy" + self.ledger.type 
= "not-indy" with self.assertRaises(test_module.web.HTTPForbidden): await test_module.ledger_accept_taa(request) @@ -276,7 +358,7 @@ async def test_accept_taa_x(self): "mechanism": "mechanism", } ) - self.ledger.LEDGER_TYPE = "indy" + self.ledger.type = "indy" self.ledger.get_txn_author_agreement.return_value = {"taa_required": True} self.ledger.accept_txn_author_agreement.side_effect = test_module.StorageError() with self.assertRaises(test_module.web.HTTPBadRequest): diff --git a/aries_cloudagent/messaging/ack/message.py b/aries_cloudagent/messaging/ack/message.py index 8f9e39bed5..201ff20953 100644 --- a/aries_cloudagent/messaging/ack/message.py +++ b/aries_cloudagent/messaging/ack/message.py @@ -1,6 +1,6 @@ """Represents an explicit ack message as per Aries RFC 15.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ..agent_message import AgentMessage, AgentMessageSchema @@ -36,6 +36,7 @@ class Meta: """Ack schema metadata.""" model_class = Ack + unknown = EXCLUDE status = fields.Constant( constant="OK", diff --git a/aries_cloudagent/messaging/agent_message.py b/aries_cloudagent/messaging/agent_message.py index b19505c19c..0a680a5de8 100644 --- a/aries_cloudagent/messaging/agent_message.py +++ b/aries_cloudagent/messaging/agent_message.py @@ -5,6 +5,7 @@ import uuid from marshmallow import ( + EXCLUDE, fields, pre_load, post_load, @@ -61,7 +62,7 @@ def __init__(self, _id: str = None, _decorators: BaseDecoratorSet = None): TypeError: If message type is missing on subclass Meta class """ - super(AgentMessage, self).__init__() + super().__init__() if _id: self._message_id = _id self._message_new_id = False @@ -265,7 +266,10 @@ def _thread(self, val: Union[ThreadDecorator, dict]): Args: val: ThreadDecorator or dict to set as the thread """ - self._decorators["thread"] = val + if val is None: + self._decorators.pop("thread", None) + else: + self._decorators["thread"] = val @property def _thread_id(self) -> str: @@ -319,7 +323,10 @@ def 
_trace(self, val: Union[TraceDecorator, dict]): Args: val: TraceDecorator or dict to set as the trace """ - self._decorators["trace"] = val + if val is None: + self._decorators.pop("trace", None) + else: + self._decorators["trace"] = val def assign_trace_from(self, msg: "AgentMessage"): """ @@ -384,6 +391,7 @@ class Meta: model_class = None signed_fields = None + unknown = EXCLUDE # Avoid clobbering keywords _type = fields.Str( @@ -408,13 +416,7 @@ def __init__(self, *args, **kwargs): TypeError: If Meta.model_class has not been set """ - super(AgentMessageSchema, self).__init__(*args, **kwargs) - if not self.Meta.model_class: - raise TypeError( - "Can't instantiate abstract class {} with no model_class".format( - self.__class__.__name__ - ) - ) + super().__init__(*args, **kwargs) self._decorators = DecoratorSet() self._decorators_dict = None self._signatures = {} diff --git a/aries_cloudagent/messaging/credential_definitions/routes.py b/aries_cloudagent/messaging/credential_definitions/routes.py index 288259b805..76691cd710 100644 --- a/aries_cloudagent/messaging/credential_definitions/routes.py +++ b/aries_cloudagent/messaging/credential_definitions/routes.py @@ -1,6 +1,6 @@ """Credential definition admin routes.""" -from asyncio import shield +from asyncio import ensure_future, shield from aiohttp import web from aiohttp_apispec import ( @@ -11,24 +11,32 @@ response_schema, ) -from marshmallow import fields, Schema +from marshmallow import fields from ...issuer.base import BaseIssuer from ...ledger.base import BaseLedger from ...storage.base import BaseStorage +from ...tails.base import BaseTailsServer +from ..models.openapi import OpenAPISchema from ..valid import INDY_CRED_DEF_ID, INDY_SCHEMA_ID, INDY_VERSION +from ...revocation.error import RevocationError, RevocationNotSupportedError +from ...revocation.indy import IndyRevocation + +from ...ledger.error import LedgerError + from .util import CredDefQueryStringSchema, CRED_DEF_TAGS, CRED_DEF_SENT_RECORD_TYPE 
-class CredentialDefinitionSendRequestSchema(Schema): +class CredentialDefinitionSendRequestSchema(OpenAPISchema): """Request schema for schema send request.""" schema_id = fields.Str(description="Schema identifier", **INDY_SCHEMA_ID) support_revocation = fields.Boolean( required=False, description="Revocation supported flag" ) + revocation_registry_size = fields.Int(required=False) tag = fields.Str( required=False, description="Credential definition identifier tag", @@ -37,7 +45,7 @@ class CredentialDefinitionSendRequestSchema(Schema): ) -class CredentialDefinitionSendResultsSchema(Schema): +class CredentialDefinitionSendResultsSchema(OpenAPISchema): """Results schema for schema send request.""" credential_definition_id = fields.Str( @@ -45,14 +53,14 @@ class CredentialDefinitionSendResultsSchema(Schema): ) -class CredentialDefinitionSchema(Schema): +class CredentialDefinitionSchema(OpenAPISchema): """Credential definition schema.""" ver = fields.Str(description="Node protocol version", **INDY_VERSION) ident = fields.Str( description="Credential definition identifier", data_key="id", - **INDY_CRED_DEF_ID + **INDY_CRED_DEF_ID, ) schemaId = fields.Str( description="Schema identifier within credential definition identifier", @@ -73,13 +81,13 @@ class CredentialDefinitionSchema(Schema): ) -class CredentialDefinitionGetResultsSchema(Schema): +class CredentialDefinitionGetResultsSchema(OpenAPISchema): """Results schema for schema get request.""" credential_definition = fields.Nested(CredentialDefinitionSchema) -class CredentialDefinitionsCreatedResultsSchema(Schema): +class CredentialDefinitionsCreatedResultsSchema(OpenAPISchema): """Results schema for cred-defs-created request.""" credential_definition_ids = fields.List( @@ -87,13 +95,13 @@ class CredentialDefinitionsCreatedResultsSchema(Schema): ) -class CredDefIdMatchInfoSchema(Schema): +class CredDefIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking cred def id.""" cred_def_id = 
fields.Str( description="Credential definition identifier", required=True, - **INDY_CRED_DEF_ID + **INDY_CRED_DEF_ID, ) @@ -121,6 +129,7 @@ async def credential_definitions_send_credential_definition(request: web.BaseReq schema_id = body.get("schema_id") support_revocation = bool(body.get("support_revocation")) tag = body.get("tag") + revocation_registry_size = body.get("revocation_registry_size") ledger: BaseLedger = await context.inject(BaseLedger, required=False) if not ledger: @@ -130,18 +139,65 @@ async def credential_definitions_send_credential_definition(request: web.BaseReq raise web.HTTPForbidden(reason=reason) issuer: BaseIssuer = await context.inject(BaseIssuer) - async with ledger: - credential_definition_id, credential_definition = await shield( - ledger.create_and_send_credential_definition( - issuer, - schema_id, - signature_type=None, - tag=tag, - support_revocation=support_revocation, + try: # even if in wallet, send it and raise if erroneously so + async with ledger: + (cred_def_id, cred_def, novel) = await shield( + ledger.create_and_send_credential_definition( + issuer, + schema_id, + signature_type=None, + tag=tag, + support_revocation=support_revocation, + ) + ) + except LedgerError as e: + raise web.HTTPBadRequest(reason=e.message) from e + + # If revocation is requested and cred def is novel, create revocation registry + if support_revocation and novel: + tails_base_url = context.settings.get("tails_server_base_url") + if not tails_base_url: + raise web.HTTPBadRequest(reason="tails_server_base_url not configured") + try: + # Create registry + issuer_did = cred_def_id.split(":")[0] + revoc = IndyRevocation(context) + registry_record = await revoc.init_issuer_registry( + cred_def_id, issuer_did, max_cred_num=revocation_registry_size, + ) + + except RevocationNotSupportedError as e: + raise web.HTTPBadRequest(reason=e.message) from e + await shield(registry_record.generate_registry(context)) + try: + await 
registry_record.set_tails_file_public_uri( + context, f"{tails_base_url}/{registry_record.revoc_reg_id}" ) - ) + await registry_record.publish_registry_definition(context) + await registry_record.publish_registry_entry(context) + + tails_server: BaseTailsServer = await context.inject(BaseTailsServer) + upload_success, reason = await tails_server.upload_tails_file( + context, registry_record.revoc_reg_id, registry_record.tails_local_path + ) + if not upload_success: + raise web.HTTPInternalServerError( + reason=f"Tails file failed to upload: {reason}" + ) + + pending_registry_record = await revoc.init_issuer_registry( + registry_record.cred_def_id, + registry_record.issuer_did, + max_cred_num=registry_record.max_cred_num, + ) + ensure_future( + pending_registry_record.stage_pending_registry_definition(context) + ) + + except RevocationError as e: + raise web.HTTPBadRequest(reason=e.message) from e - return web.json_response({"credential_definition_id": credential_definition_id}) + return web.json_response({"credential_definition_id": cred_def_id}) @docs( @@ -195,7 +251,7 @@ async def credential_definitions_get_credential_definition(request: web.BaseRequ """ context = request.app["request_context"] - credential_definition_id = request.match_info["cred_def_id"] + cred_def_id = request.match_info["cred_def_id"] ledger: BaseLedger = await context.inject(BaseLedger, required=False) if not ledger: @@ -205,11 +261,9 @@ async def credential_definitions_get_credential_definition(request: web.BaseRequ raise web.HTTPForbidden(reason=reason) async with ledger: - credential_definition = await ledger.get_credential_definition( - credential_definition_id - ) + cred_def = await ledger.get_credential_definition(cred_def_id) - return web.json_response({"credential_definition": credential_definition}) + return web.json_response({"credential_definition": cred_def}) async def register(app: web.Application): diff --git 
a/aries_cloudagent/messaging/credential_definitions/tests/test_routes.py b/aries_cloudagent/messaging/credential_definitions/tests/test_routes.py index 1bbb049537..6d2baba131 100644 --- a/aries_cloudagent/messaging/credential_definitions/tests/test_routes.py +++ b/aries_cloudagent/messaging/credential_definitions/tests/test_routes.py @@ -6,8 +6,9 @@ from ....config.injection_context import InjectionContext from ....issuer.base import BaseIssuer from ....ledger.base import BaseLedger -from ....storage.base import BaseStorage from ....messaging.request_context import RequestContext +from ....storage.base import BaseStorage +from ....tails.base import BaseTailsServer from .. import routes as test_module @@ -23,7 +24,7 @@ def setUp(self): self.ledger = async_mock.create_autospec(BaseLedger) self.ledger.__aenter__ = async_mock.CoroutineMock(return_value=self.ledger) self.ledger.create_and_send_credential_definition = async_mock.CoroutineMock( - return_value=(CRED_DEF_ID, {"cred": "def"}) + return_value=(CRED_DEF_ID, {"cred": "def"}, True) ) self.ledger.get_credential_definition = async_mock.CoroutineMock( return_value={"cred": "def"} @@ -68,6 +69,168 @@ async def test_send_credential_definition(self): {"credential_definition_id": CRED_DEF_ID} ) + async def test_send_credential_definition_revoc(self): + mock_request = async_mock.MagicMock( + app=self.app, + json=async_mock.CoroutineMock( + return_value={ + "schema_id": "WgWxqztrNooG92RXvxSTWv:2:schema_name:1.0", + "support_revocation": True, + "tag": "tag", + } + ), + ) + self.context.settings.set_value("tails_server_base_url", "http://1.2.3.4:8222") + + mock_tails_server = async_mock.MagicMock( + upload_tails_file=async_mock.CoroutineMock(return_value=(True, None)) + ) + self.context.injector.bind_instance(BaseTailsServer, mock_tails_server) + + with async_mock.patch.object( + test_module, "IndyRevocation", async_mock.MagicMock() + ) as test_indy_revoc, async_mock.patch.object( + test_module.web, "json_response", 
async_mock.MagicMock() + ) as mock_response: + test_indy_revoc.return_value = async_mock.MagicMock( + init_issuer_registry=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + set_tails_file_public_uri=async_mock.CoroutineMock(), + generate_registry=async_mock.CoroutineMock(), + publish_registry_definition=async_mock.CoroutineMock(), + publish_registry_entry=async_mock.CoroutineMock(), + stage_pending_registry_definition=async_mock.CoroutineMock(), + ) + ) + ) + + await test_module.credential_definitions_send_credential_definition( + mock_request + ) + mock_response.assert_called_once_with( + {"credential_definition_id": CRED_DEF_ID} + ) + + async def test_send_credential_definition_revoc_no_tails_server_x(self): + mock_request = async_mock.MagicMock( + app=self.app, + json=async_mock.CoroutineMock( + return_value={ + "schema_id": "WgWxqztrNooG92RXvxSTWv:2:schema_name:1.0", + "support_revocation": True, + "tag": "tag", + } + ), + ) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.credential_definitions_send_credential_definition( + mock_request + ) + + async def test_send_credential_definition_revoc_no_support_x(self): + mock_request = async_mock.MagicMock( + app=self.app, + json=async_mock.CoroutineMock( + return_value={ + "schema_id": "WgWxqztrNooG92RXvxSTWv:2:schema_name:1.0", + "support_revocation": True, + "tag": "tag", + } + ), + ) + self.context.settings.set_value("tails_server_base_url", "http://1.2.3.4:8222") + + with async_mock.patch.object( + test_module, "IndyRevocation", async_mock.MagicMock() + ) as test_indy_revoc: + test_indy_revoc.return_value = async_mock.MagicMock( + init_issuer_registry=async_mock.CoroutineMock( + side_effect=test_module.RevocationNotSupportedError("nope") + ) + ) + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.credential_definitions_send_credential_definition( + mock_request + ) + + async def test_send_credential_definition_revoc_upload_x(self): + 
mock_request = async_mock.MagicMock( + app=self.app, + json=async_mock.CoroutineMock( + return_value={ + "schema_id": "WgWxqztrNooG92RXvxSTWv:2:schema_name:1.0", + "support_revocation": True, + "tag": "tag", + } + ), + ) + self.context.settings.set_value("tails_server_base_url", "http://1.2.3.4:8222") + + mock_tails_server = async_mock.MagicMock( + upload_tails_file=async_mock.CoroutineMock( + return_value=(False, "Down for maintenance") + ) + ) + self.context.injector.bind_instance(BaseTailsServer, mock_tails_server) + + with async_mock.patch.object( + test_module, "IndyRevocation", async_mock.MagicMock() + ) as test_indy_revoc: + test_indy_revoc.return_value = async_mock.MagicMock( + init_issuer_registry=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + set_tails_file_public_uri=async_mock.CoroutineMock(), + generate_registry=async_mock.CoroutineMock(), + publish_registry_definition=async_mock.CoroutineMock(), + publish_registry_entry=async_mock.CoroutineMock(), + ) + ) + ) + with self.assertRaises(test_module.web.HTTPInternalServerError): + await test_module.credential_definitions_send_credential_definition( + mock_request + ) + + async def test_send_credential_definition_revoc_init_issuer_rev_reg_x(self): + mock_request = async_mock.MagicMock( + app=self.app, + json=async_mock.CoroutineMock( + return_value={ + "schema_id": "WgWxqztrNooG92RXvxSTWv:2:schema_name:1.0", + "support_revocation": True, + "tag": "tag", + } + ), + ) + self.context.settings.set_value("tails_server_base_url", "http://1.2.3.4:8222") + + mock_tails_server = async_mock.MagicMock( + upload_tails_file=async_mock.CoroutineMock(return_value=(True, None)) + ) + self.context.injector.bind_instance(BaseTailsServer, mock_tails_server) + + with async_mock.patch.object( + test_module, "IndyRevocation", async_mock.MagicMock() + ) as test_indy_revoc: + test_indy_revoc.return_value = async_mock.MagicMock( + init_issuer_registry=async_mock.CoroutineMock( + side_effect=[ + 
async_mock.MagicMock( + set_tails_file_public_uri=async_mock.CoroutineMock(), + generate_registry=async_mock.CoroutineMock(), + publish_registry_definition=async_mock.CoroutineMock(), + publish_registry_entry=async_mock.CoroutineMock(), + ), + test_module.RevocationError("Error on pending rev reg init"), + ] + ) + ) + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.credential_definitions_send_credential_definition( + mock_request + ) + async def test_send_credential_definition_no_ledger(self): mock_request = async_mock.MagicMock( app=self.app, @@ -86,6 +249,29 @@ async def test_send_credential_definition_no_ledger(self): mock_request ) + async def test_send_credential_definition_ledger_x(self): + mock_request = async_mock.MagicMock( + app=self.app, + json=async_mock.CoroutineMock( + return_value={ + "schema_id": "WgWxqztrNooG92RXvxSTWv:2:schema_name:1.0", + "support_revocation": False, + "tag": "tag", + } + ), + ) + + self.context.injector.clear_binding(BaseLedger) + self.ledger.__aenter__ = async_mock.CoroutineMock( + side_effect=test_module.LedgerError("oops") + ) + self.context.injector.bind_instance(BaseLedger, self.ledger) + + with self.assertRaises(test_module.web.HTTPBadRequest): + await test_module.credential_definitions_send_credential_definition( + mock_request + ) + async def test_created(self): mock_request = async_mock.MagicMock( app=self.app, match_info={"cred_def_id": CRED_DEF_ID}, diff --git a/aries_cloudagent/messaging/credential_definitions/util.py b/aries_cloudagent/messaging/credential_definitions/util.py index e9e7a133bc..8038f21fc2 100644 --- a/aries_cloudagent/messaging/credential_definitions/util.py +++ b/aries_cloudagent/messaging/credential_definitions/util.py @@ -1,7 +1,8 @@ """Credential definition utilities.""" -from marshmallow import fields, Schema +from marshmallow import fields +from ..models.openapi import OpenAPISchema from ..valid import ( INDY_CRED_DEF_ID, INDY_DID, @@ -13,7 +14,7 @@ 
CRED_DEF_SENT_RECORD_TYPE = "cred_def_sent" -class CredDefQueryStringSchema(Schema): +class CredDefQueryStringSchema(OpenAPISchema): """Query string parameters for credential definition searches.""" schema_id = fields.Str( diff --git a/aries_cloudagent/messaging/decorators/attach_decorator.py b/aries_cloudagent/messaging/decorators/attach_decorator.py index 214eb4d774..b8737d2270 100644 --- a/aries_cloudagent/messaging/decorators/attach_decorator.py +++ b/aries_cloudagent/messaging/decorators/attach_decorator.py @@ -10,7 +10,7 @@ from typing import Any, Mapping, Sequence, Union -from marshmallow import fields, pre_load +from marshmallow import EXCLUDE, fields, pre_load from ...wallet.base import BaseWallet from ...wallet.util import ( @@ -62,6 +62,7 @@ class Meta: """Attach decorator data schema metadata.""" model_class = AttachDecoratorDataJWSHeader + unknown = EXCLUDE kid = fields.Str( description="Key identifier, in W3C did:key or DID URL format", @@ -108,6 +109,7 @@ class Meta: """Single attach decorator data JWS schema metadata.""" model_class = AttachDecoratorData1JWS + unknown = EXCLUDE header = fields.Nested(AttachDecoratorDataJWSHeaderSchema, required=True) protected = fields.Str( @@ -152,6 +154,7 @@ class Meta: """Metadata for schema for detached JWS for inclusion in attach deco data.""" model_class = AttachDecoratorDataJWS + unknown = EXCLUDE @pre_load def validate_single_xor_multi_sig(self, data: Mapping, **kwargs): @@ -454,6 +457,7 @@ class Meta: """Attach decorator data schema metadata.""" model_class = AttachDecoratorData + unknown = EXCLUDE @pre_load def validate_data_spec(self, data: Mapping, **kwargs): @@ -632,6 +636,7 @@ class Meta: """AttachDecoratorSchema metadata.""" model_class = AttachDecorator + unknown = EXCLUDE ident = fields.Str( description="Attachment identifier", diff --git a/aries_cloudagent/messaging/decorators/localization_decorator.py b/aries_cloudagent/messaging/decorators/localization_decorator.py index f45b6ca17e..bbb731cb19 
100644 --- a/aries_cloudagent/messaging/decorators/localization_decorator.py +++ b/aries_cloudagent/messaging/decorators/localization_decorator.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ..models.base import BaseModel, BaseModelSchema @@ -44,6 +44,7 @@ class Meta: """LocalizationDecoratorSchema metadata.""" model_class = LocalizationDecorator + unknown = EXCLUDE locale = fields.Str(required=True, description="Locale specifier", example="en-CA",) localizable = fields.List( diff --git a/aries_cloudagent/messaging/decorators/please_ack_decorator.py b/aries_cloudagent/messaging/decorators/please_ack_decorator.py index b3380b8050..3c97a7f40d 100644 --- a/aries_cloudagent/messaging/decorators/please_ack_decorator.py +++ b/aries_cloudagent/messaging/decorators/please_ack_decorator.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ..models.base import BaseModel, BaseModelSchema from ..valid import UUIDFour @@ -39,6 +39,7 @@ class Meta: """PleaseAckDecoratorSchema metadata.""" model_class = PleaseAckDecorator + unknown = EXCLUDE message_id = fields.Str( description="Message identifier", diff --git a/aries_cloudagent/messaging/decorators/signature_decorator.py b/aries_cloudagent/messaging/decorators/signature_decorator.py index 374c9e9f49..fe141b1412 100644 --- a/aries_cloudagent/messaging/decorators/signature_decorator.py +++ b/aries_cloudagent/messaging/decorators/signature_decorator.py @@ -1,11 +1,10 @@ """Model and schema for working with field signatures within message bodies.""" - import json import struct import time -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ...wallet.base import BaseWallet from ...wallet.util import b64_to_bytes, bytes_to_b64 @@ -92,7 +91,7 @@ def decode(self) -> (object, int): """ msg_bin = b64_to_bytes(self.sig_data, urlsafe=True) (timestamp,) = 
struct.unpack_from("!Q", msg_bin, 0) - return json.loads(msg_bin[8:]), timestamp + return (json.loads(msg_bin[8:]), timestamp) async def verify(self, wallet: BaseWallet) -> bool: """ @@ -128,12 +127,15 @@ class Meta: """SignatureDecoratorSchema metadata.""" model_class = SignatureDecorator + unknown = EXCLUDE signature_type = fields.Str( data_key="@type", required=True, description="Signature type", - example="did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/signature/1.0/ed25519Sha512_single", + example=( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" "spec/signature/1.0/ed25519Sha512_single" + ), ) signature = fields.Str( required=True, diff --git a/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py b/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py index 8c72ca888b..3a78075933 100644 --- a/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py +++ b/aries_cloudagent/messaging/decorators/tests/test_attach_decorator.py @@ -330,8 +330,9 @@ def test_indy_dict(self): assert lynx_str == lynx_list assert lynx_str != links + assert links != DATA_LINKS # has sha256 - def test_indy_dict(self): + def test_from_aries_msg(self): deco_aries = AttachDecorator.from_aries_msg( message=INDY_CRED, ident=IDENT, description=DESCRIPTION, ) diff --git a/aries_cloudagent/messaging/decorators/tests/test_base.py b/aries_cloudagent/messaging/decorators/tests/test_base.py index f1d1662edd..2c59f08c47 100644 --- a/aries_cloudagent/messaging/decorators/tests/test_base.py +++ b/aries_cloudagent/messaging/decorators/tests/test_base.py @@ -7,7 +7,7 @@ from time import time from unittest import TestCase -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ....messaging.models.base import BaseModel, BaseModelSchema @@ -33,6 +33,7 @@ class SampleDecoratorSchema(BaseModelSchema): class Meta: model_class = SampleDecorator + unknown = EXCLUDE score = fields.Int(required=True) diff --git 
a/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py b/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py index 6acd7e8b4d..bed13f62c7 100644 --- a/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py +++ b/aries_cloudagent/messaging/decorators/tests/test_decorator_set.py @@ -1,6 +1,7 @@ -from marshmallow import fields from unittest import TestCase +from marshmallow import EXCLUDE, fields + from ...models.base import BaseModel, BaseModelSchema from ..base import BaseDecoratorSet @@ -20,6 +21,7 @@ def __init__(self, *, value: str = None, handled_decorator: str = None, **kwargs class SimpleModelSchema(BaseModelSchema): class Meta: model_class = SimpleModel + unknown = EXCLUDE value = fields.Str(required=True) handled_decorator = fields.Str(required=False, data_key="handled~decorator") diff --git a/aries_cloudagent/messaging/decorators/tests/test_signature_decorator.py b/aries_cloudagent/messaging/decorators/tests/test_signature_decorator.py new file mode 100644 index 0000000000..3beead9fc5 --- /dev/null +++ b/aries_cloudagent/messaging/decorators/tests/test_signature_decorator.py @@ -0,0 +1,62 @@ +import pytest + +from asynctest import TestCase as AsyncTestCase, mock as async_mock + +from ....protocols.trustping.v1_0.messages.ping import Ping +from ....wallet.basic import BasicWallet +from .. 
import signature_decorator as test_module +from ..signature_decorator import SignatureDecorator + +TEST_VERKEY = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + + +class TestSignatureDecorator(AsyncTestCase): + async def test_init(self): + decorator = SignatureDecorator() + assert decorator.signature_type is None + assert decorator.signature is None + assert decorator.sig_data is None + assert decorator.signer is None + assert "SignatureDecorator" in str(decorator) + + async def test_serialize_load(self): + TEST_SIG = "IkJvYiI=" + TEST_SIG_DATA = "MTIzNDU2Nzg5MCJCb2Ii" + + decorator = SignatureDecorator( + signature_type=SignatureDecorator.TYPE_ED25519SHA512, + signature=TEST_SIG, + sig_data=TEST_SIG_DATA, + signer=TEST_VERKEY, + ) + + dumped = decorator.serialize() + loaded = SignatureDecorator.deserialize(dumped) + + assert loaded.signature_type == SignatureDecorator.TYPE_ED25519SHA512 + assert loaded.signature == TEST_SIG + assert loaded.sig_data == TEST_SIG_DATA + assert loaded.signer == TEST_VERKEY + + async def test_create_decode_verify(self): + TEST_MESSAGE = "Hello world" + TEST_TIMESTAMP = 1234567890 + wallet = BasicWallet() + key_info = await wallet.create_signing_key() + + deco = await SignatureDecorator.create( + Ping(), key_info.verkey, wallet, timestamp=None + ) + assert deco + + deco = await SignatureDecorator.create( + TEST_MESSAGE, key_info.verkey, wallet, TEST_TIMESTAMP + ) + + (msg, timestamp) = deco.decode() + assert msg == TEST_MESSAGE + assert timestamp == TEST_TIMESTAMP + + await deco.verify(wallet) + deco.signature_type = "unsupported-sig-type" + assert not await deco.verify(wallet) diff --git a/aries_cloudagent/messaging/decorators/thread_decorator.py b/aries_cloudagent/messaging/decorators/thread_decorator.py index 07315fc477..3d447f6b85 100644 --- a/aries_cloudagent/messaging/decorators/thread_decorator.py +++ b/aries_cloudagent/messaging/decorators/thread_decorator.py @@ -7,7 +7,7 @@ from typing import Mapping -from marshmallow import 
fields +from marshmallow import EXCLUDE, fields from ..models.base import BaseModel, BaseModelSchema from ..valid import UUIDFour @@ -47,7 +47,7 @@ def __init__( as it provides an implicit ACK.) """ - super(ThreadDecorator, self).__init__() + super().__init__() self._thid = thid self._pthid = pthid self._sender_order = sender_order or None @@ -117,6 +117,7 @@ class Meta: """ThreadDecoratorSchema metadata.""" model_class = ThreadDecorator + unknown = EXCLUDE thid = fields.Str( required=False, diff --git a/aries_cloudagent/messaging/decorators/timing_decorator.py b/aries_cloudagent/messaging/decorators/timing_decorator.py index bcc052c373..25da632a2c 100644 --- a/aries_cloudagent/messaging/decorators/timing_decorator.py +++ b/aries_cloudagent/messaging/decorators/timing_decorator.py @@ -8,7 +8,7 @@ from datetime import datetime from typing import Union -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ..models.base import BaseModel, BaseModelSchema from ..util import datetime_to_str @@ -44,7 +44,7 @@ def __init__( delay_milli: The number of milliseconds to delay processing wait_until_time: The earliest time at which to perform processing """ - super(TimingDecorator, self).__init__() + super().__init__() self.in_time = datetime_to_str(in_time) self.out_time = datetime_to_str(out_time) self.stale_time = datetime_to_str(stale_time) @@ -60,6 +60,7 @@ class Meta: """TimingDecoratorSchema metadata.""" model_class = TimingDecorator + unknown = EXCLUDE in_time = fields.Str( required=False, description="Time of message receipt", **INDY_ISO8601_DATETIME diff --git a/aries_cloudagent/messaging/decorators/trace_decorator.py b/aries_cloudagent/messaging/decorators/trace_decorator.py index cbc4f46122..c140758fa3 100644 --- a/aries_cloudagent/messaging/decorators/trace_decorator.py +++ b/aries_cloudagent/messaging/decorators/trace_decorator.py @@ -7,7 +7,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, 
fields from ..models.base import BaseModel, BaseModelSchema from ..valid import UUIDFour @@ -50,7 +50,7 @@ def __init__( ellapsed_milli: ... outcome: ... """ - super(TraceReport, self).__init__() + super().__init__() self._msg_id = msg_id self._thread_id = thread_id self._traced_type = traced_type @@ -234,6 +234,7 @@ class Meta: """TraceReportSchema metadata.""" model_class = TraceReport + unknown = EXCLUDE msg_id = fields.Str( required=True, @@ -292,6 +293,7 @@ class Meta: """TraceDecoratorSchema metadata.""" model_class = TraceDecorator + unknown = EXCLUDE target = fields.Str( required=True, diff --git a/aries_cloudagent/messaging/decorators/transport_decorator.py b/aries_cloudagent/messaging/decorators/transport_decorator.py index 22cc0c1f26..1dc7c6718e 100644 --- a/aries_cloudagent/messaging/decorators/transport_decorator.py +++ b/aries_cloudagent/messaging/decorators/transport_decorator.py @@ -4,7 +4,7 @@ This decorator allows changes to agent response behaviour and queue status updates. 
""" -from marshmallow import fields, validate +from marshmallow import EXCLUDE, fields, validate from ..models.base import BaseModel, BaseModelSchema from ..valid import UUIDFour, WHOLE_NUM @@ -46,6 +46,7 @@ class Meta: """TransportDecoratorSchema metadata.""" model_class = TransportDecorator + unknown = EXCLUDE return_route = fields.Str( required=False, diff --git a/aries_cloudagent/messaging/jsonld/create_verify_data.py b/aries_cloudagent/messaging/jsonld/create_verify_data.py index eef4ab4371..69c9aafbc9 100644 --- a/aries_cloudagent/messaging/jsonld/create_verify_data.py +++ b/aries_cloudagent/messaging/jsonld/create_verify_data.py @@ -5,11 +5,10 @@ https://github.com/transmute-industries/Ed25519Signature2018/blob/master/src/createVerifyData/index.js """ - import datetime +import hashlib from pyld import jsonld -import hashlib def _canonize(data): diff --git a/aries_cloudagent/messaging/jsonld/credential.py b/aries_cloudagent/messaging/jsonld/credential.py index ddd42fb681..20bf334225 100644 --- a/aries_cloudagent/messaging/jsonld/credential.py +++ b/aries_cloudagent/messaging/jsonld/credential.py @@ -1,6 +1,7 @@ """Sign and verify functions for json-ld based credentials.""" import json + from ...wallet.util import ( b58_to_bytes, b64_to_bytes, diff --git a/aries_cloudagent/messaging/jsonld/routes.py b/aries_cloudagent/messaging/jsonld/routes.py index c6ee839c01..0640ef2c61 100644 --- a/aries_cloudagent/messaging/jsonld/routes.py +++ b/aries_cloudagent/messaging/jsonld/routes.py @@ -3,23 +3,21 @@ from aiohttp import web from aiohttp_apispec import docs, request_schema, response_schema -from ...messaging.jsonld.credential import ( - sign_credential, - verify_credential, -) -from ...wallet.base import BaseWallet +from marshmallow import fields -from marshmallow import fields, Schema +from ...wallet.base import BaseWallet +from ..models.openapi import OpenAPISchema +from .credential import sign_credential, verify_credential -class 
SignRequestSchema(Schema): +class SignRequestSchema(OpenAPISchema): """Request schema for signing a jsonld doc.""" verkey = fields.Str(required=True, description="verkey to use for signing") doc = fields.Dict(required=True, description="JSON-LD Doc to sign") -class SignResponseSchema(Schema): +class SignResponseSchema(OpenAPISchema): """Response schema for a signed jsonld doc.""" signed_doc = fields.Dict(required=True) @@ -60,14 +58,14 @@ async def sign(request: web.BaseRequest): return web.json_response(response) -class VerifyRequestSchema(Schema): +class VerifyRequestSchema(OpenAPISchema): """Request schema for signing a jsonld doc.""" verkey = fields.Str(required=True, description="verkey to use for doc verification") doc = fields.Dict(required=True, description="JSON-LD Doc to verify") -class VerifyResponseSchema(Schema): +class VerifyResponseSchema(OpenAPISchema): """Response schema for verification result.""" valid = fields.Bool(required=True) diff --git a/aries_cloudagent/messaging/models/base.py b/aries_cloudagent/messaging/models/base.py index e516685841..91b97ddeb9 100644 --- a/aries_cloudagent/messaging/models/base.py +++ b/aries_cloudagent/messaging/models/base.py @@ -221,7 +221,7 @@ def __init__(self, *args, **kwargs): TypeError: If model_class is not set on Meta """ - super(BaseModelSchema, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) if not self.Meta.model_class: raise TypeError( "Can't instantiate abstract class {} with no model_class".format( @@ -296,3 +296,12 @@ def remove_skipped_values(self, data, **kwargs): """ skip_vals = resolve_meta_property(self, "skip_values", []) return {key: value for key, value in data.items() if value not in skip_vals} + + +class OpenAPISchema(Schema): + """Schema for OpenAPI artifacts: excluding unknown fields, not raising exception.""" + + class Meta: + """BaseModelSchema metadata.""" + + unknown = EXCLUDE diff --git a/aries_cloudagent/messaging/models/base_record.py 
b/aries_cloudagent/messaging/models/base_record.py index 86c1aedd20..4d013f481e 100644 --- a/aries_cloudagent/messaging/models/base_record.py +++ b/aries_cloudagent/messaging/models/base_record.py @@ -475,6 +475,8 @@ class BaseRecordSchema(BaseModelSchema): class Meta: """BaseRecordSchema metadata.""" + model_class = None + state = fields.Str( required=False, description="Current record state", example="active" ) diff --git a/aries_cloudagent/messaging/models/openapi.py b/aries_cloudagent/messaging/models/openapi.py new file mode 100644 index 0000000000..2ddeb140ee --- /dev/null +++ b/aries_cloudagent/messaging/models/openapi.py @@ -0,0 +1,13 @@ +"""Base class for OpenAPI artifact schema.""" + +from marshmallow import Schema, EXCLUDE + + +class OpenAPISchema(Schema): + """Schema for OpenAPI artifacts: excluding unknown fields, not raising exception.""" + + class Meta: + """OpenAPISchema metadata.""" + + model_class = None + unknown = EXCLUDE diff --git a/aries_cloudagent/messaging/models/tests/test_base.py b/aries_cloudagent/messaging/models/tests/test_base.py index 98eaf35e84..62c327b7a1 100644 --- a/aries_cloudagent/messaging/models/tests/test_base.py +++ b/aries_cloudagent/messaging/models/tests/test_base.py @@ -2,7 +2,7 @@ from asynctest import TestCase as AsyncTestCase, mock as async_mock -from marshmallow import fields, validates_schema, ValidationError +from marshmallow import EXCLUDE, fields, validates_schema, ValidationError from ....cache.base import BaseCache from ....config.injection_context import InjectionContext @@ -11,7 +11,7 @@ from ...responder import BaseResponder, MockResponder from ...util import time_now -from ..base import BaseModel, BaseModelSchema +from ..base import BaseModel, BaseModelError, BaseModelSchema class ModelImpl(BaseModel): @@ -25,8 +25,9 @@ def __init__(self, *, attr=None): class SchemaImpl(BaseModelSchema): class Meta: model_class = ModelImpl + unknown = EXCLUDE - attr = fields.String() + attr = fields.String(required=True) 
@validates_schema def validate_fields(self, data, **kwargs): @@ -44,3 +45,21 @@ def test_model_validate_succeeds(self): model = ModelImpl(attr="succeeds") model = model.validate() assert model.attr == "succeeds" + + def test_ser_x(self): + model = ModelImpl(attr="hello world") + with async_mock.patch.object( + model, "_get_schema_class", async_mock.MagicMock() + ) as mock_get_schema_class: + mock_get_schema_class.return_value = async_mock.MagicMock( + return_value=async_mock.MagicMock( + dump=async_mock.MagicMock(side_effect=ValidationError("error")) + ) + ) + with self.assertRaises(BaseModelError): + model.serialize() + + def test_from_json_x(self): + data = "{}{}" + with self.assertRaises(BaseModelError): + ModelImpl.from_json(data) diff --git a/aries_cloudagent/messaging/models/tests/test_base_record.py b/aries_cloudagent/messaging/models/tests/test_base_record.py index 8592c510ad..8c3e9f6780 100644 --- a/aries_cloudagent/messaging/models/tests/test_base_record.py +++ b/aries_cloudagent/messaging/models/tests/test_base_record.py @@ -1,10 +1,12 @@ import json from asynctest import TestCase as AsyncTestCase, mock as async_mock +from marshmallow import EXCLUDE, fields from ....cache.base import BaseCache from ....config.injection_context import InjectionContext -from ....storage.base import BaseStorage, StorageRecord +from ....storage.base import BaseStorage, StorageDuplicateError, StorageRecord +from ....storage.basic import BasicStorage from ...responder import BaseResponder, MockResponder from ...util import time_now @@ -23,6 +25,38 @@ class Meta: class BaseRecordImplSchema(BaseRecordSchema): class Meta: model_class = BaseRecordImpl + unknown = EXCLUDE + + +class ARecordImpl(BaseRecord): + class Meta: + schema_class = "ARecordImplSchema" + + RECORD_TYPE = "a-record" + CACHE_ENABLED = False + RECORD_ID_NAME = "ident" + TAG_NAMES = {"code"} + + def __init__(self, *, ident=None, a, b, code, **kwargs): + super().__init__(ident, **kwargs) + self.a = a + self.b = b + 
self.code = code + + @property + def record_value(self) -> dict: + return {"a": self.a, "b": self.b} + + +class ARecordImplSchema(BaseRecordSchema): + class Meta: + model_class = BaseRecordImpl + unknown = EXCLUDE + + ident = fields.Str(attribute="_id") + a = fields.Str() + b = fields.Str() + code = fields.Str() class UnencTestImpl(BaseRecord): @@ -42,6 +76,10 @@ def test_from_storage_values(self): assert inst._id == record_id assert inst.value == stored + stored[BaseRecordImpl.RECORD_ID_NAME] = inst._id + with self.assertRaises(ValueError): + BaseRecordImpl.from_storage(record_id, stored) + async def test_post_save_new(self): context = InjectionContext(enforce_typing=False) mock_storage = async_mock.MagicMock() @@ -74,12 +112,15 @@ async def test_post_save_exist(self): mock_storage.update_record_tags.assert_called_once() async def test_cache(self): + assert not await BaseRecordImpl.get_cached_key(None, None) + await BaseRecordImpl.set_cached_key(None, None, None) + await BaseRecordImpl.clear_cached_key(None, None) context = InjectionContext(enforce_typing=False) mock_cache = async_mock.MagicMock(BaseCache, autospec=True) context.injector.bind_instance(BaseCache, mock_cache) record = BaseRecordImpl() cache_key = "cache_key" - cache_result = await record.get_cached_key(context, cache_key) + cache_result = await BaseRecordImpl.get_cached_key(context, cache_key) mock_cache.get.assert_awaited_once_with(cache_key) assert cache_result is mock_cache.get.return_value @@ -109,6 +150,39 @@ async def test_retrieve_cached_id(self): assert result._id == record_id assert result.value == stored + async def test_retrieve_by_tag_filter_multi_x_delete(self): + context = InjectionContext(enforce_typing=False) + basic_storage = BasicStorage() + context.injector.bind_instance(BaseStorage, basic_storage) + records = [] + for i in range(3): + records.append(ARecordImpl(a="1", b=str(i), code="one")) + await records[i].save(context) + with self.assertRaises(StorageDuplicateError): + await 
ARecordImpl.retrieve_by_tag_filter( + context, {"code": "one"}, {"a": "1"} + ) + await records[0].delete_record(context) + + async def test_save_x(self): + context = InjectionContext(enforce_typing=False) + basic_storage = BasicStorage() + context.injector.bind_instance(BaseStorage, basic_storage) + rec = ARecordImpl(a="1", b="0", code="one") + with async_mock.patch.object( + context, "inject", async_mock.CoroutineMock() + ) as mock_inject: + mock_inject.return_value = async_mock.MagicMock( + add_record=async_mock.CoroutineMock(side_effect=ZeroDivisionError()) + ) + with self.assertRaises(ZeroDivisionError): + await rec.save(context) + + async def test_neq(self): + a_rec = ARecordImpl(a="1", b="0", code="one") + b_rec = BaseRecordImpl() + assert a_rec != b_rec + async def test_retrieve_uncached_id(self): context = InjectionContext(enforce_typing=False) mock_storage = async_mock.MagicMock(BaseStorage, autospec=True) @@ -163,7 +237,7 @@ def test_log_state(self, mock_print): BaseRecordImpl, "LOG_STATE_FLAG", test_param ) as cls: record = BaseRecordImpl() - record.log_state(context, "state") + record.log_state(context, msg="state", params={"a": "1", "b": "2"}) mock_print.assert_called_once() @async_mock.patch("builtins.print") @@ -180,6 +254,8 @@ async def test_webhook(self): record = BaseRecordImpl() payload = {"test": "payload"} topic = "topic" + await record.send_webhook(context, None, None) # cover short circuit + await record.send_webhook(context, "hello", None) # cover short circuit await record.send_webhook(context, payload, topic=topic) assert mock_responder.webhooks == [(topic, payload)] @@ -190,6 +266,11 @@ async def test_tag_prefix(self): tags = {"a": "x", "b": "y", "c": "z"} assert UnencTestImpl.prefix_tag_filter(tags) == {"~a": "x", "~b": "y", "c": "z"} + tags = {"$not": {"a": "x", "b": "y", "c": "z"}} + expect = {"$not": {"~a": "x", "~b": "y", "c": "z"}} + actual = UnencTestImpl.prefix_tag_filter(tags) + assert {**expect} == {**actual} + tags = {"$or": 
[{"a": "x"}, {"c": "z"}]} assert UnencTestImpl.prefix_tag_filter(tags) == { "$or": [{"~a": "x"}, {"c": "z"}] diff --git a/aries_cloudagent/messaging/request_context.py b/aries_cloudagent/messaging/request_context.py index 61e068525b..fd43f677dd 100644 --- a/aries_cloudagent/messaging/request_context.py +++ b/aries_cloudagent/messaging/request_context.py @@ -80,7 +80,7 @@ def default_endpoint(self) -> str: The default agent endpoint """ - return self.settings["default_endpoint"] + return self.settings.get("default_endpoint") @default_endpoint.setter def default_endpoint(self, endpoint: str): diff --git a/aries_cloudagent/messaging/schemas/routes.py b/aries_cloudagent/messaging/schemas/routes.py index 5515570f29..bfae859cee 100644 --- a/aries_cloudagent/messaging/schemas/routes.py +++ b/aries_cloudagent/messaging/schemas/routes.py @@ -11,18 +11,19 @@ response_schema, ) -from marshmallow import fields, Schema +from marshmallow import fields from marshmallow.validate import Regexp from ...issuer.base import BaseIssuer, IssuerError from ...ledger.base import BaseLedger from ...ledger.error import LedgerError from ...storage.base import BaseStorage +from ..models.openapi import OpenAPISchema from ..valid import B58, NATURAL_NUM, INDY_SCHEMA_ID, INDY_VERSION from .util import SchemaQueryStringSchema, SCHEMA_SENT_RECORD_TYPE, SCHEMA_TAGS -class SchemaSendRequestSchema(Schema): +class SchemaSendRequestSchema(OpenAPISchema): """Request schema for schema send request.""" schema_name = fields.Str(required=True, description="Schema name", example="prefs",) @@ -36,7 +37,7 @@ class SchemaSendRequestSchema(Schema): ) -class SchemaSendResultsSchema(Schema): +class SchemaSendResultsSchema(OpenAPISchema): """Results schema for schema send request.""" schema_id = fields.Str( @@ -45,7 +46,7 @@ class SchemaSendResultsSchema(Schema): schema = fields.Dict(description="Schema result", required=True) -class SchemaSchema(Schema): +class SchemaSchema(OpenAPISchema): """Content for returned 
schema.""" ver = fields.Str(description="Node protocol version", **INDY_VERSION) @@ -62,13 +63,13 @@ class SchemaSchema(Schema): seqNo = fields.Int(description="Schema sequence number", **NATURAL_NUM) -class SchemaGetResultsSchema(Schema): +class SchemaGetResultsSchema(OpenAPISchema): """Results schema for schema get request.""" - schema_json = fields.Nested(SchemaSchema()) + schema = fields.Nested(SchemaSchema()) -class SchemasCreatedResultsSchema(Schema): +class SchemasCreatedResultsSchema(OpenAPISchema): """Results schema for a schemas-created request.""" schema_ids = fields.List( @@ -76,7 +77,7 @@ class SchemasCreatedResultsSchema(Schema): ) -class SchemaIdMatchInfoSchema(Schema): +class SchemaIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking schema id.""" schema_id = fields.Str( diff --git a/aries_cloudagent/messaging/schemas/util.py b/aries_cloudagent/messaging/schemas/util.py index 48d190cb6b..4c71781003 100644 --- a/aries_cloudagent/messaging/schemas/util.py +++ b/aries_cloudagent/messaging/schemas/util.py @@ -1,15 +1,12 @@ """Schema utilities.""" -from marshmallow import fields, Schema +from marshmallow import fields -from ..valid import ( - INDY_DID, - INDY_SCHEMA_ID, - INDY_VERSION, -) +from ..models.openapi import OpenAPISchema +from ..valid import INDY_DID, INDY_SCHEMA_ID, INDY_VERSION -class SchemaQueryStringSchema(Schema): +class SchemaQueryStringSchema(OpenAPISchema): """Query string parameters for schema searches.""" schema_id = fields.Str( diff --git a/aries_cloudagent/messaging/tests/test_agent_message.py b/aries_cloudagent/messaging/tests/test_agent_message.py index 85536d7f75..4e6a9c5452 100644 --- a/aries_cloudagent/messaging/tests/test_agent_message.py +++ b/aries_cloudagent/messaging/tests/test_agent_message.py @@ -1,11 +1,15 @@ -from asynctest import TestCase as AsyncTestCase -from marshmallow import fields import json +from asynctest import TestCase as AsyncTestCase +from marshmallow import EXCLUDE, 
fields + +from ...wallet.basic import BasicWallet +from ...wallet.util import bytes_to_b64 + from ..agent_message import AgentMessage, AgentMessageSchema from ..decorators.signature_decorator import SignatureDecorator from ..decorators.trace_decorator import TraceReport, TRACE_LOG_TARGET -from ...wallet.basic import BasicWallet +from ..models.base import BaseModelError class SignedAgentMessage(AgentMessage): @@ -19,7 +23,7 @@ class Meta: message_type = "signed-agent-message" def __init__(self, value: str = None, **kwargs): - super(SignedAgentMessage, self).__init__(**kwargs) + super().__init__(**kwargs) self.value = value @@ -29,6 +33,7 @@ class SignedAgentMessageSchema(AgentMessageSchema): class Meta: model_class = SignedAgentMessage signed_fields = ("value",) + unknown = EXCLUDE value = fields.Str(required=True) @@ -39,25 +44,29 @@ class BasicAgentMessage(AgentMessage): class Meta: """Meta data""" - schema_class = "AgentMessageSchema" + schema_class = AgentMessageSchema message_type = "basic-message" class TestAgentMessage(AsyncTestCase): """Tests agent message.""" - class BadImplementationClass(AgentMessage): - """Test utility class.""" - - pass - def test_init(self): """Tests init class""" - SignedAgentMessage() + + class BadImplementationClass(AgentMessage): + """Test utility class.""" + + message = SignedAgentMessage() + message._id = "12345" with self.assertRaises(TypeError) as context: - self.BadImplementationClass() # pylint: disable=E0110 + BadImplementationClass() # pylint: disable=E0110 + assert "Can't instantiate abstract" in str(context.exception) + BadImplementationClass.Meta.schema_class = "AgentMessageSchema" + with self.assertRaises(TypeError) as context: + BadImplementationClass() # pylint: disable=E0110 assert "Can't instantiate abstract" in str(context.exception) async def test_field_signature(self): @@ -65,7 +74,16 @@ async def test_field_signature(self): key_info = await wallet.create_signing_key() msg = SignedAgentMessage() + msg.value = 
None + with self.assertRaises(BaseModelError) as context: + await msg.sign_field("value", key_info.verkey, wallet) + assert "field has no value for signature" in str(context.exception) + msg.value = "Test value" + with self.assertRaises(BaseModelError) as context: + msg.serialize() + assert "Missing signature for field" in str(context.exception) + await msg.sign_field("value", key_info.verkey, wallet) sig = msg.get_signature("value") assert isinstance(sig, SignatureDecorator) @@ -74,9 +92,27 @@ async def test_field_signature(self): assert await msg.verify_signed_field("value", wallet) == key_info.verkey assert await msg.verify_signatures(wallet) + with self.assertRaises(BaseModelError) as context: + await msg.verify_signed_field("value", wallet, "bogus-verkey") + assert "Signer verkey of signature does not match" in str(context.exception) + serial = msg.serialize() assert "value~sig" in serial and "value" not in serial + (_, timestamp) = msg._decorators.field("value")["sig"].decode() + tamper_deco = await SignatureDecorator.create("tamper", key_info.verkey, wallet) + msg._decorators.field("value")["sig"].sig_data = tamper_deco.sig_data + with self.assertRaises(BaseModelError) as context: + await msg.verify_signed_field("value", wallet) + assert "Field signature verification failed" in str(context.exception) + assert not await msg.verify_signatures(wallet) + + msg.value = "Test value" + msg._decorators.field("value").pop("sig") + with self.assertRaises(BaseModelError) as context: + await msg.verify_signed_field("value", wallet) + assert "Missing field signature" in str(context.exception) + loaded = SignedAgentMessage.deserialize(serial) assert isinstance(loaded, SignedAgentMessage) assert await loaded.verify_signed_field("value", wallet) == key_info.verkey @@ -89,6 +125,9 @@ async def test_assign_thread(self): assert reply._thread_id == msg._thread_id assert reply._thread_id != reply._id + msg.assign_thread_id(None, None) + assert not msg._thread + async def 
test_add_tracing(self): msg = BasicAgentMessage() msg.add_trace_decorator() @@ -148,3 +187,85 @@ async def test_add_tracing(self): assert msg_trace_report.outcome == trace_report2.outcome print("tracer:", tracer.serialize()) + + msg3 = BasicAgentMessage() + msg.add_trace_decorator() + assert msg._trace + + +class TestAgentMessageSchema(AsyncTestCase): + """Tests agent message schema.""" + + def test_init_x(self): + """Tests init class""" + + class BadImplementationClass(AgentMessageSchema): + """Test utility class.""" + + with self.assertRaises(TypeError) as context: + BadImplementationClass() + assert "Can't instantiate abstract" in str(context.exception) + + def test_extract_decorators_x(self): + for serial in [ + { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "value": "Test value", + }, + { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "value": "Test value", + "value~sig": { + "@type": ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" + "spec/signature/1.0/ed25519Sha512_single" + ), + "signature": ( + "-OKdiRRQu-xbVGICg1J6KV_6nXLLzYRXr8BZSXzoXimytBl" + "O8ULY7Nl1lQPqahc-XQPHiBSVraLM8XN_sCzdCg==" + ), + "sig_data": "AAAAAF8bIV4iVGVzdCB2YWx1ZSI=", + "signer": "7VA3CaF9jaTuRN2SGmekANoja6Js4U51kfRSbpZAfdhy", + }, + }, + { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "superfluous~sig": { + "@type": ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" + "spec/signature/1.0/ed25519Sha512_single" + ), + "signature": ( + "-OKdiRRQu-xbVGICg1J6KV_6nXLLzYRXr8BZSXzoXimytBl" + "O8ULY7Nl1lQPqahc-XQPHiBSVraLM8XN_sCzdCg==" + ), + "sig_data": "AAAAAF8bIV4iVGVzdCB2YWx1ZSI=", + "signer": "7VA3CaF9jaTuRN2SGmekANoja6Js4U51kfRSbpZAfdhy", + }, + }, + ]: + with self.assertRaises(BaseModelError) as context: + SignedAgentMessage.deserialize(serial) + + def test_serde(self): + serial = { + "@type": "signed-agent-message", + "@id": "030ac9e6-0d60-49d3-a8c6-e7ce0be8df5a", + "value~sig": { + 
"@type": ( + "did:sov:BzCbsNYhMrjHiqZDTUASHg;" + "spec/signature/1.0/ed25519Sha512_single" + ), + "signature": ( + "-OKdiRRQu-xbVGICg1J6KV_6nXLLzYRXr8BZSXzoXimytBl" + "O8ULY7Nl1lQPqahc-XQPHiBSVraLM8XN_sCzdCg==" + ), + "sig_data": "AAAAAF8bIV4iVGVzdCB2YWx1ZSI=", + "signer": "7VA3CaF9jaTuRN2SGmekANoja6Js4U51kfRSbpZAfdhy", + }, + } + result = SignedAgentMessage.deserialize(serial) + result.serialize() diff --git a/aries_cloudagent/messaging/tests/test_valid.py b/aries_cloudagent/messaging/tests/test_valid.py index 5d592ee626..701f81358b 100644 --- a/aries_cloudagent/messaging/tests/test_valid.py +++ b/aries_cloudagent/messaging/tests/test_valid.py @@ -12,6 +12,7 @@ BASE64URL_NO_PAD, DID_KEY, ENDPOINT, + ENDPOINT_TYPE, INDY_CRED_DEF_ID, INDY_CRED_REV_ID, INDY_DID, @@ -426,3 +427,22 @@ def test_endpoint(self): ENDPOINT["validate"]("newproto://myhost.ca:8080/path") ENDPOINT["validate"]("ftp://10.10.100.90:8021") ENDPOINT["validate"]("zzzp://someplace.ca:9999/path") + + def test_endpoint_type(self): + non_endpoint_types = [ + "123", + "endpoint", + "end point", + "end-point", + "profile", + "linked_domains", + None, + ] + + for non_endpoint_type in non_endpoint_types: + with self.assertRaises(ValidationError): + ENDPOINT_TYPE["validate"](non_endpoint_type) + + ENDPOINT_TYPE["validate"]("Endpoint") + ENDPOINT_TYPE["validate"]("Profile") + ENDPOINT_TYPE["validate"]("LinkedDomains") diff --git a/aries_cloudagent/messaging/valid.py b/aries_cloudagent/messaging/valid.py index f5552d3410..3294205709 100644 --- a/aries_cloudagent/messaging/valid.py +++ b/aries_cloudagent/messaging/valid.py @@ -10,6 +10,8 @@ from .util import epoch_to_str +from ..ledger.endpoint_type import EndpointType as EndpointTypeEnum + B58 = alphabet if isinstance(alphabet, str) else alphabet.decode("ascii") @@ -426,6 +428,20 @@ def __init__(self): ) +class EndpointType(OneOf): + """Validate value against allowed endpoint/service types.""" + + EXAMPLE = EndpointTypeEnum.ENDPOINT.w3c + + def __init__(self): 
+ """Initializer.""" + + super().__init__( + choices=[e.w3c for e in EndpointTypeEnum], + error="Value {input} must be one of {choices}", + ) + + # Instances for marshmallow schema specification INT_EPOCH = {"validate": IntEpoch(), "example": IntEpoch.EXAMPLE} WHOLE_NUM = {"validate": WholeNumber(), "example": WholeNumber.EXAMPLE} @@ -460,3 +476,4 @@ def __init__(self): } UUID4 = {"validate": UUIDFour(), "example": UUIDFour.EXAMPLE} ENDPOINT = {"validate": Endpoint(), "example": Endpoint.EXAMPLE} +ENDPOINT_TYPE = {"validate": EndpointType(), "example": EndpointType.EXAMPLE} diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu.py b/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu.py index f94e074c94..e90620035d 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -40,7 +40,7 @@ def __init__( errormsg: An optional error message to display options: A sequence of menu options """ - super(Menu, self).__init__(**kwargs) + super().__init__(**kwargs) self.title = title self.description = description self.options = list(options) if options else [] @@ -53,6 +53,7 @@ class Meta: """Menu schema metadata.""" model_class = Menu + unknown = EXCLUDE title = fields.Str(required=False, description="Menu title", example="My Menu") description = fields.Str( diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu_request.py b/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu_request.py index e294d7ce5c..9aa81676d4 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu_request.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/messages/menu_request.py @@ -1,5 +1,7 @@ """Represents a request for an action menu.""" +from marshmallow import EXCLUDE + 
from .....messaging.agent_message import AgentMessage, AgentMessageSchema from ..message_types import MENU_REQUEST, PROTOCOL_PACKAGE @@ -19,7 +21,7 @@ class Meta: def __init__(self, **kwargs): """Initialize a menu request object.""" - super(MenuRequest, self).__init__(**kwargs) + super().__init__(**kwargs) class MenuRequestSchema(AgentMessageSchema): @@ -29,3 +31,4 @@ class Meta: """MenuRequest schema metadata.""" model_class = MenuRequest + unknown = EXCLUDE diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/messages/perform.py b/aries_cloudagent/protocols/actionmenu/v1_0/messages/perform.py index 1ef15e7788..8d30f66404 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/messages/perform.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/messages/perform.py @@ -2,7 +2,7 @@ from typing import Mapping -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -29,7 +29,7 @@ def __init__(self, *, name: str = None, params: Mapping[str, str] = None, **kwar name: The name of the menu option params: Input parameter values """ - super(Perform, self).__init__(**kwargs) + super().__init__(**kwargs) self.name = name self.params = params @@ -41,6 +41,7 @@ class Meta: """Perform schema metadata.""" model_class = Perform + unknown = EXCLUDE name = fields.Str(required=True, description="Menu option name", example="Query",) params = fields.Dict( diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py index 9193b82122..5e718500c5 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -47,6 +47,7 @@ class Meta: """MenuFormSchema 
metadata.""" model_class = MenuForm + unknown = EXCLUDE title = fields.Str( required=False, description="Menu form title", example="Preferences", diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form_param.py b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form_param.py index 7b2c77e70e..f265aba953 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form_param.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_form_param.py @@ -1,6 +1,6 @@ """Record used to represent a parameter in a menu form.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -50,6 +50,7 @@ class Meta: """MenuFormParamSchema metadata.""" model_class = MenuFormParam + unknown = EXCLUDE name = fields.Str( required=True, description="Menu parameter name", example="delay", diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_option.py b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_option.py index ac7dfd40e0..c7f36ff068 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_option.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/models/menu_option.py @@ -1,6 +1,6 @@ """Record used to represent individual menu options in an action menu.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -48,6 +48,7 @@ class Meta: """MenuOptionSchema metadata.""" model_class = MenuOption + unknown = EXCLUDE name = fields.Str( required=True, diff --git a/aries_cloudagent/protocols/actionmenu/v1_0/routes.py b/aries_cloudagent/protocols/actionmenu/v1_0/routes.py index 235cc6766b..e9520eacb9 100644 --- a/aries_cloudagent/protocols/actionmenu/v1_0/routes.py +++ b/aries_cloudagent/protocols/actionmenu/v1_0/routes.py @@ -5,10 +5,11 @@ from aiohttp import web from aiohttp_apispec import docs, match_info_schema, request_schema -from marshmallow 
import fields, Schema +from marshmallow import fields from ....connections.models.connection_record import ConnectionRecord from ....messaging.models.base import BaseModelError +from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import UUIDFour from ....storage.error import StorageError, StorageNotFoundError @@ -21,7 +22,7 @@ LOGGER = logging.getLogger(__name__) -class PerformRequestSchema(Schema): +class PerformRequestSchema(OpenAPISchema): """Request schema for performing a menu action.""" name = fields.Str(description="Menu option name", example="Query") @@ -33,7 +34,7 @@ class PerformRequestSchema(Schema): ) -class MenuJsonSchema(Schema): +class MenuJsonSchema(OpenAPISchema): """Matches MenuSchema but without the inherited AgentMessage properties.""" title = fields.Str(required=False, description="Menu title", example="My Menu",) @@ -54,7 +55,7 @@ class MenuJsonSchema(Schema): ) -class SendMenuSchema(Schema): +class SendMenuSchema(OpenAPISchema): """Request schema for sending a menu to a connection.""" menu = fields.Nested( @@ -62,7 +63,7 @@ class SendMenuSchema(Schema): ) -class ConnIdMatchInfoSchema(Schema): +class ConnIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking connection id.""" conn_id = fields.Str( diff --git a/aries_cloudagent/protocols/basicmessage/v1_0/messages/basicmessage.py b/aries_cloudagent/protocols/basicmessage/v1_0/messages/basicmessage.py index 11b2bd8720..addbbd8597 100644 --- a/aries_cloudagent/protocols/basicmessage/v1_0/messages/basicmessage.py +++ b/aries_cloudagent/protocols/basicmessage/v1_0/messages/basicmessage.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import Union -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.util import datetime_now, datetime_to_str @@ -57,6 +57,7 @@ class Meta: """Basic message schema metadata.""" 
model_class = BasicMessage + unknown = EXCLUDE sent_time = fields.Str( required=False, diff --git a/aries_cloudagent/protocols/basicmessage/v1_0/routes.py b/aries_cloudagent/protocols/basicmessage/v1_0/routes.py index 7d15063e0d..ffbac40f51 100644 --- a/aries_cloudagent/protocols/basicmessage/v1_0/routes.py +++ b/aries_cloudagent/protocols/basicmessage/v1_0/routes.py @@ -3,9 +3,10 @@ from aiohttp import web from aiohttp_apispec import docs, match_info_schema, request_schema -from marshmallow import fields, Schema +from marshmallow import fields from ....connections.models.connection_record import ConnectionRecord +from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import UUIDFour from ....storage.error import StorageNotFoundError @@ -13,13 +14,13 @@ from .messages.basicmessage import BasicMessage -class SendMessageSchema(Schema): +class SendMessageSchema(OpenAPISchema): """Request schema for sending a message.""" content = fields.Str(description="Message content", example="Hello") -class ConnIdMatchInfoSchema(Schema): +class ConnIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking connection id.""" conn_id = fields.Str( diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py index 46cd1d2eb5..f11d8dc597 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/connection_invitation.py @@ -3,7 +3,7 @@ from typing import Sequence from urllib.parse import parse_qs, urljoin, urlparse -from marshmallow import ValidationError, fields, validates_schema +from marshmallow import EXCLUDE, fields, validates_schema, ValidationError from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY @@ -50,7 +50,7 @@ def __init__( routing_keys: 
List of routing keys image_url: Optional image URL for connection invitation """ - super(ConnectionInvitation, self).__init__(**kwargs) + super().__init__(**kwargs) self.label = label self.did = did self.recipient_keys = list(recipient_keys) if recipient_keys else None @@ -97,6 +97,7 @@ class Meta: """Connection invitation schema metadata.""" model_class = ConnectionInvitation + unknown = EXCLUDE label = fields.Str( required=False, description="Optional label for connection", example="Bob" diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/connection_request.py b/aries_cloudagent/protocols/connections/v1_0/messages/connection_request.py index c7db4a054e..f008658c29 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/connection_request.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/connection_request.py @@ -1,6 +1,6 @@ """Represents a connection request message.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -39,7 +39,7 @@ def __init__( label: Label for this connection request image_url: Optional image URL for this connection request """ - super(ConnectionRequest, self).__init__(**kwargs) + super().__init__(**kwargs) self.connection = connection self.label = label @@ -51,6 +51,7 @@ class Meta: """Connection request schema metadata.""" model_class = ConnectionRequest + unknown = EXCLUDE connection = fields.Nested(ConnectionDetailSchema, required=True) label = fields.Str( diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/connection_response.py b/aries_cloudagent/protocols/connections/v1_0/messages/connection_response.py index 9addb7961d..5472df6439 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/connection_response.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/connection_response.py @@ -1,6 +1,6 @@ """Represents a connection response message.""" -from marshmallow import 
fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -31,7 +31,7 @@ def __init__(self, *, connection: ConnectionDetail = None, **kwargs): connection: Connection details object """ - super(ConnectionResponse, self).__init__(**kwargs) + super().__init__(**kwargs) self.connection = connection @@ -43,5 +43,6 @@ class Meta: model_class = ConnectionResponse signed_fields = ("connection",) + unknown = EXCLUDE connection = fields.Nested(ConnectionDetailSchema, required=True) diff --git a/aries_cloudagent/protocols/connections/v1_0/messages/problem_report.py b/aries_cloudagent/protocols/connections/v1_0/messages/problem_report.py index 872ccacf6e..832ddf85a2 100644 --- a/aries_cloudagent/protocols/connections/v1_0/messages/problem_report.py +++ b/aries_cloudagent/protocols/connections/v1_0/messages/problem_report.py @@ -1,7 +1,7 @@ """Represents a connection problem report message.""" from enum import Enum -from marshmallow import fields, validate +from marshmallow import EXCLUDE, fields, validate from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -50,6 +50,7 @@ class Meta: """Metadata for problem report schema.""" model_class = ProblemReport + unknown = EXCLUDE explain = fields.Str( required=False, diff --git a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py index 60787b0d66..cdbac51cd7 100644 --- a/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py +++ b/aries_cloudagent/protocols/connections/v1_0/models/connection_detail.py @@ -1,6 +1,6 @@ """An object for containing the connection request/response DID information.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....connections.models.diddoc import DIDDoc from .....messaging.models.base import BaseModel, BaseModelSchema @@ -87,7 +87,8 @@ class 
ConnectionDetailSchema(BaseModelSchema): class Meta: """ConnectionDetailSchema metadata.""" - model_class = "ConnectionDetail" + model_class = ConnectionDetail + unknown = EXCLUDE did = fields.Str( data_key="DID", diff --git a/aries_cloudagent/protocols/connections/v1_0/routes.py b/aries_cloudagent/protocols/connections/v1_0/routes.py index d1b9787e76..a80cf818eb 100644 --- a/aries_cloudagent/protocols/connections/v1_0/routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/routes.py @@ -11,13 +11,14 @@ response_schema, ) -from marshmallow import fields, Schema, validate, validates_schema +from marshmallow import fields, validate, validates_schema from ....connections.models.connection_record import ( ConnectionRecord, ConnectionRecordSchema, ) from ....messaging.models.base import BaseModelError +from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import ( ENDPOINT, INDY_DID, @@ -35,7 +36,7 @@ ) -class ConnectionListSchema(Schema): +class ConnectionListSchema(OpenAPISchema): """Result schema for connection list.""" results = fields.List( @@ -52,7 +53,7 @@ def validate_fields(self, data, **kwargs): """Bypass middleware field validation.""" -class InvitationResultSchema(Schema): +class InvitationResultSchema(OpenAPISchema): """Result schema for a new connection invitation.""" connection_id = fields.Str( @@ -65,7 +66,7 @@ class InvitationResultSchema(Schema): ) -class ConnectionStaticRequestSchema(Schema): +class ConnectionStaticRequestSchema(OpenAPISchema): """Request schema for a new static connection.""" my_seed = fields.Str(description="Seed to use for the local DID", required=False) @@ -89,7 +90,7 @@ class ConnectionStaticRequestSchema(Schema): alias = fields.Str(description="Alias to assign to this connection", required=False) -class ConnectionStaticResultSchema(Schema): +class ConnectionStaticResultSchema(OpenAPISchema): """Result schema for new static connection.""" my_did = fields.Str(description="Local DID", required=True, 
**INDY_DID) @@ -104,7 +105,7 @@ class ConnectionStaticResultSchema(Schema): record = fields.Nested(ConnectionRecordSchema, required=True) -class ConnectionsListQueryStringSchema(Schema): +class ConnectionsListQueryStringSchema(OpenAPISchema): """Parameters and validators for connections list request query string.""" alias = fields.Str(description="Alias", required=False, example="Barry",) @@ -136,7 +137,7 @@ class ConnectionsListQueryStringSchema(Schema): ) -class CreateInvitationQueryStringSchema(Schema): +class CreateInvitationQueryStringSchema(OpenAPISchema): """Parameters and validators for create invitation request query string.""" alias = fields.Str(description="Alias", required=False, example="Barry",) @@ -152,7 +153,7 @@ class CreateInvitationQueryStringSchema(Schema): ) -class ReceiveInvitationQueryStringSchema(Schema): +class ReceiveInvitationQueryStringSchema(OpenAPISchema): """Parameters and validators for receive invitation request query string.""" alias = fields.Str(description="Alias", required=False, example="Barry",) @@ -162,7 +163,7 @@ class ReceiveInvitationQueryStringSchema(Schema): ) -class AcceptInvitationQueryStringSchema(Schema): +class AcceptInvitationQueryStringSchema(OpenAPISchema): """Parameters and validators for accept invitation request query string.""" my_endpoint = fields.Str(description="My URL endpoint", required=False, **ENDPOINT) @@ -171,13 +172,13 @@ class AcceptInvitationQueryStringSchema(Schema): ) -class AcceptRequestQueryStringSchema(Schema): +class AcceptRequestQueryStringSchema(OpenAPISchema): """Parameters and validators for accept conn-request web-request query string.""" my_endpoint = fields.Str(description="My URL endpoint", required=False, **ENDPOINT) -class ConnIdMatchInfoSchema(Schema): +class ConnIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking connection id.""" conn_id = fields.Str( @@ -185,7 +186,7 @@ class ConnIdMatchInfoSchema(Schema): ) -class 
ConnIdRefIdMatchInfoSchema(Schema): +class ConnIdRefIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking connection and ref ids.""" conn_id = fields.Str( diff --git a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py index ed235bc818..bb83673086 100644 --- a/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/connections/v1_0/tests/test_routes.py @@ -15,6 +15,8 @@ class TestConnectionRoutes(AsyncTestCase): async def test_connections_list(self): context = RequestContext(base_context=InjectionContext(enforce_typing=False)) + context.default_endpoint = "http://1.2.3.4:8081" # for coverage + assert context.default_endpoint == "http://1.2.3.4:8081" # for coverage mock_req = async_mock.MagicMock() mock_req.app = { "request_context": context, diff --git a/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py b/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py index ac43982be6..963ecda892 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py +++ b/aries_cloudagent/protocols/discovery/v1_0/messages/disclose.py @@ -2,7 +2,7 @@ from typing import Mapping, Sequence -from marshmallow import fields, Schema, validate +from marshmallow import EXCLUDE, fields, Schema, validate from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -28,7 +28,7 @@ def __init__(self, *, protocols: Sequence[Mapping[str, Mapping]] = None, **kwarg Args: protocols: A mapping of protocol names to a dictionary of properties """ - super(Disclose, self).__init__(**kwargs) + super().__init__(**kwargs) self.protocols = list(protocols) if protocols else [] @@ -55,6 +55,7 @@ class Meta: """DiscloseSchema metadata.""" model_class = Disclose + unknown = EXCLUDE protocols = fields.List( fields.Nested(ProtocolDescriptorSchema()), diff --git 
a/aries_cloudagent/protocols/discovery/v1_0/messages/query.py b/aries_cloudagent/protocols/discovery/v1_0/messages/query.py index eba7a2c3d6..530ef99ae0 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/messages/query.py +++ b/aries_cloudagent/protocols/discovery/v1_0/messages/query.py @@ -1,10 +1,10 @@ """Represents a feature discovery query message.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema -from ..message_types import QUERY, PROTOCOL_PACKAGE +from ..message_types import PROTOCOL_PACKAGE, QUERY HANDLER_CLASS = f"{PROTOCOL_PACKAGE}.handlers.query_handler.QueryHandler" @@ -30,7 +30,7 @@ def __init__(self, *, query: str = None, comment: str = None, **kwargs): query: The query string to match against supported message types comment: An optional comment """ - super(Query, self).__init__(**kwargs) + super().__init__(**kwargs) self.query = query self.comment = comment @@ -42,6 +42,7 @@ class Meta: """QuerySchema metadata.""" model_class = Query + unknown = EXCLUDE query = fields.Str(required=True) - comment = fields.Str(required=False) + comment = fields.Str(required=False, allow_none=True) diff --git a/aries_cloudagent/protocols/discovery/v1_0/routes.py b/aries_cloudagent/protocols/discovery/v1_0/routes.py index 949cfe5567..b3c864e003 100644 --- a/aries_cloudagent/protocols/discovery/v1_0/routes.py +++ b/aries_cloudagent/protocols/discovery/v1_0/routes.py @@ -2,14 +2,15 @@ from aiohttp import web from aiohttp_apispec import docs, querystring_schema, response_schema -from marshmallow import fields, Schema +from marshmallow import fields from ....core.protocol_registry import ProtocolRegistry +from ....messaging.models.openapi import OpenAPISchema from .message_types import SPEC_URI -class QueryResultSchema(Schema): +class QueryResultSchema(OpenAPISchema): """Result schema for the protocol list.""" results = fields.Dict( @@ -19,7 +20,7 @@ class 
QueryResultSchema(Schema): ) -class QueryFeaturesQueryStringSchema(Schema): +class QueryFeaturesQueryStringSchema(OpenAPISchema): """Query string parameters for feature query.""" query = fields.Str(description="Query", required=False, example="did:sov:*") diff --git a/aries_cloudagent/protocols/introduction/v0_1/messages/forward_invitation.py b/aries_cloudagent/protocols/introduction/v0_1/messages/forward_invitation.py index 975582fe10..e5cabf832a 100644 --- a/aries_cloudagent/protocols/introduction/v0_1/messages/forward_invitation.py +++ b/aries_cloudagent/protocols/introduction/v0_1/messages/forward_invitation.py @@ -36,7 +36,7 @@ def __init__( invitation: The connection invitation message: Comments on the introduction """ - super(ForwardInvitation, self).__init__(**kwargs) + super().__init__(**kwargs) self.invitation = invitation self.message = message diff --git a/aries_cloudagent/protocols/introduction/v0_1/messages/invitation.py b/aries_cloudagent/protocols/introduction/v0_1/messages/invitation.py index f42555887e..38382b739e 100644 --- a/aries_cloudagent/protocols/introduction/v0_1/messages/invitation.py +++ b/aries_cloudagent/protocols/introduction/v0_1/messages/invitation.py @@ -1,6 +1,6 @@ """Represents an invitation returned to the introduction service.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....protocols.connections.v1_0.messages.connection_invitation import ( @@ -34,7 +34,7 @@ def __init__( invitation: The connection invitation message: Comments on the introduction """ - super(Invitation, self).__init__(**kwargs) + super().__init__(**kwargs) self.invitation = invitation self.message = message @@ -46,6 +46,7 @@ class Meta: """Invitation request schema metadata.""" model_class = Invitation + unknown = EXCLUDE invitation = fields.Nested(ConnectionInvitationSchema(), required=True) message = fields.Str( diff --git 
a/aries_cloudagent/protocols/introduction/v0_1/messages/invitation_request.py b/aries_cloudagent/protocols/introduction/v0_1/messages/invitation_request.py index bc0c88c829..60f0c9ffe7 100644 --- a/aries_cloudagent/protocols/introduction/v0_1/messages/invitation_request.py +++ b/aries_cloudagent/protocols/introduction/v0_1/messages/invitation_request.py @@ -1,6 +1,6 @@ """Represents an request for an invitation from the introduction service.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -29,7 +29,7 @@ def __init__(self, *, responder: str = None, message: str = None, **kwargs): responder: The name of the agent initiating the introduction message: Comments on the introduction """ - super(InvitationRequest, self).__init__(**kwargs) + super().__init__(**kwargs) self.responder = responder self.message = message @@ -41,6 +41,7 @@ class Meta: """Invitation request schema metadata.""" model_class = InvitationRequest + unknown = EXCLUDE responder = fields.Str( required=True, diff --git a/aries_cloudagent/protocols/introduction/v0_1/routes.py b/aries_cloudagent/protocols/introduction/v0_1/routes.py index 02d2bfeee9..6424b0551f 100644 --- a/aries_cloudagent/protocols/introduction/v0_1/routes.py +++ b/aries_cloudagent/protocols/introduction/v0_1/routes.py @@ -5,8 +5,9 @@ from aiohttp import web from aiohttp_apispec import docs, match_info_schema, querystring_schema -from marshmallow import fields, Schema +from marshmallow import fields +from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import UUIDFour from ....storage.error import StorageError @@ -15,7 +16,7 @@ LOGGER = logging.getLogger(__name__) -class IntroStartQueryStringSchema(Schema): +class IntroStartQueryStringSchema(OpenAPISchema): """Query string parameters for request to start introduction.""" target_connection_id = fields.Str( @@ -28,7 +29,7 @@ class 
IntroStartQueryStringSchema(Schema): ) -class ConnIdMatchInfoSchema(Schema): +class ConnIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking connection id.""" conn_id = fields.Str( diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/manager.py b/aries_cloudagent/protocols/issue_credential/v1_0/manager.py index 28de4f3b95..cf71ffa407 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/manager.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/manager.py @@ -1,5 +1,6 @@ """Classes to manage credentials.""" +import asyncio import json import logging from typing import Mapping, Sequence, Text, Tuple @@ -453,7 +454,11 @@ async def receive_request(self): return cred_ex_record async def issue_credential( - self, cred_ex_record: V10CredentialExchange, *, comment: str = None, + self, + cred_ex_record: V10CredentialExchange, + *, + comment: str = None, + retries: int = 5, ) -> Tuple[V10CredentialExchange, CredentialIssue]: """ Issue a credential. @@ -494,24 +499,40 @@ async def issue_credential( cred_ex_record.credential_definition_id ) + tails_path = None if credential_definition["value"].get("revocation"): - issuer_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( + staged_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( self.context, cred_ex_record.credential_definition_id, - state=IssuerRevRegRecord.STATE_ACTIVE, + state=IssuerRevRegRecord.STATE_STAGED, ) - if not issuer_rev_regs: - raise CredentialManagerError( - "Cred def id {} has no active revocation registry".format( - cred_ex_record.credential_definition_id - ) + + if staged_rev_regs and retries > 0: + # We know there is a staged registry that will be ready soon. + # So we wait and retry. 
+ await asyncio.sleep(1) + return await self.issue_credential( + cred_ex_record=cred_ex_record, + comment=comment, + retries=retries - 1, + ) + else: + active_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( + self.context, + cred_ex_record.credential_definition_id, + state=IssuerRevRegRecord.STATE_ACTIVE, ) + if not active_rev_regs: + raise CredentialManagerError( + "Cred def id {} has no active revocation registry".format( + cred_ex_record.credential_definition_id + ) + ) - registry = await issuer_rev_regs[0].get_registry() - cred_ex_record.revoc_reg_id = issuer_rev_regs[0].revoc_reg_id - tails_path = registry.tails_local_path - else: - tails_path = None + active_reg = active_rev_regs[0] + registry = await active_reg.get_registry() + cred_ex_record.revoc_reg_id = active_reg.revoc_reg_id + tails_path = registry.tails_local_path credential_values = CredentialProposal.deserialize( cred_ex_record.credential_proposal_dict @@ -529,14 +550,77 @@ async def issue_credential( cred_ex_record.revoc_reg_id, tails_path, ) + + # If the revocation registry is full if registry and registry.max_creds == int( cred_ex_record.revocation_id # monotonic "1"-based ): - await issuer_rev_regs[0].mark_full(self.context) + # Check to see if we have a registry record staged and waiting + pending_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( + self.context, + cred_ex_record.credential_definition_id, + state=IssuerRevRegRecord.STATE_PUBLISHED, + ) + if pending_rev_regs: + pending_rev_reg = pending_rev_regs[0] + pending_rev_reg.state = IssuerRevRegRecord.STATE_STAGED + await pending_rev_reg.save( + self.context, reason="revocation registry staged" + ) + + # Make it active + await pending_rev_reg.publish_registry_entry(self.context) + # Kick off a task to create and publish the next revocation + # registry in the background. 
It is assumed that the size of + # the registry is large enough so that this completes before + # the current registry is full + revoc = IndyRevocation(self.context) + pending_registry_record = await revoc.init_issuer_registry( + active_reg.cred_def_id, + active_reg.issuer_did, + max_cred_num=active_reg.max_cred_num, + ) + asyncio.ensure_future( + pending_registry_record.stage_pending_registry_definition( + self.context + ) + ) + + # Make the current registry full + await active_reg.mark_full(self.context) except IssuerRevocationRegistryFullError: - await issuer_rev_regs[0].mark_full(self.context) - raise + active_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( + self.context, + cred_ex_record.credential_definition_id, + state=IssuerRevRegRecord.STATE_ACTIVE, + ) + staged_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( + self.context, + cred_ex_record.credential_definition_id, + state=IssuerRevRegRecord.STATE_STAGED, + ) + published_rev_regs = await IssuerRevRegRecord.query_by_cred_def_id( + self.context, + cred_ex_record.credential_definition_id, + state=IssuerRevRegRecord.STATE_PUBLISHED, + ) + + if ( + staged_rev_regs or active_rev_regs or published_rev_regs + ) and retries > 0: + + # We know there is a staged registry that will be ready soon. + # So we wait and retry. 
+ await asyncio.sleep(1) + return await self.issue_credential( + cred_ex_record=cred_ex_record, + comment=comment, + retries=retries - 1, + ) + else: + await active_reg.mark_full(self.context) + raise cred_ex_record.credential = json.loads(credential_json) diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_ack.py b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_ack.py index 0eac4b2556..05e0af8792 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_ack.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_ack.py @@ -1,8 +1,10 @@ """A credential ack message.""" +from marshmallow import EXCLUDE + from .....messaging.ack.message import Ack, AckSchema -from ..message_types import CREDENTIAL_ACK, PROTOCOL_PACKAGE +from ..message_types import CREDENTIAL_ACK, PROTOCOL_PACKAGE HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.credential_ack_handler.CredentialAckHandler" @@ -31,3 +33,4 @@ class Meta: """Schema metadata.""" model_class = CredentialAck + unknown = EXCLUDE diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_issue.py b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_issue.py index 6a496c45c9..94338e53c0 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_issue.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_issue.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.decorators.attach_decorator import ( @@ -12,7 +12,6 @@ from ..message_types import ATTACH_DECO_IDS, CREDENTIAL_ISSUE, PROTOCOL_PACKAGE - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.credential_issue_handler.CredentialIssueHandler" ) @@ -74,8 +73,11 @@ class Meta: """Credential schema metadata.""" model_class = CredentialIssue + 
unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) credentials_attach = fields.Nested( AttachDecoratorSchema, required=True, many=True, data_key="credentials~attach" ) diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_offer.py b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_offer.py index 6d36513e7d..4cd36dfe94 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_offer.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_offer.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.decorators.attach_decorator import ( @@ -13,7 +13,6 @@ from ..message_types import ATTACH_DECO_IDS, CREDENTIAL_OFFER, PROTOCOL_PACKAGE from .inner.credential_preview import CredentialPreview, CredentialPreviewSchema - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.credential_offer_handler.CredentialOfferHandler" ) @@ -78,8 +77,11 @@ class Meta: """Credential offer schema metadata.""" model_class = CredentialOffer + unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) credential_preview = fields.Nested(CredentialPreviewSchema, required=False) offers_attach = fields.Nested( AttachDecoratorSchema, required=True, many=True, data_key="offers~attach" diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_proposal.py b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_proposal.py index 7decaeea10..1b5cc23e71 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_proposal.py +++ 
b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_proposal.py @@ -1,6 +1,6 @@ """A credential proposal content message.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.valid import ( @@ -14,7 +14,6 @@ from .inner.credential_preview import CredentialPreview, CredentialPreviewSchema - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers." "credential_proposal_handler.CredentialProposalHandler" @@ -76,8 +75,11 @@ class Meta: """Credential proposal schema metadata.""" model_class = CredentialProposal + unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) credential_proposal = fields.Nested( CredentialPreviewSchema, required=False, allow_none=False ) diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_request.py b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_request.py index a59e4964ad..03d7849371 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_request.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/messages/credential_request.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.decorators.attach_decorator import ( @@ -12,7 +12,6 @@ from ..message_types import ATTACH_DECO_IDS, CREDENTIAL_REQUEST, PROTOCOL_PACKAGE - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers." 
"credential_request_handler.CredentialRequestHandler" @@ -75,8 +74,11 @@ class Meta: """Credential request schema metadata.""" model_class = CredentialRequest + unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) requests_attach = fields.Nested( AttachDecoratorSchema, required=True, many=True, data_key="requests~attach" ) diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/messages/inner/credential_preview.py b/aries_cloudagent/protocols/issue_credential/v1_0/messages/inner/credential_preview.py index 84f1fcddac..893a570b21 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/messages/inner/credential_preview.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/messages/inner/credential_preview.py @@ -3,7 +3,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from ......messaging.models.base import BaseModel, BaseModelSchema from ......wallet.util import b64_to_str @@ -74,6 +74,7 @@ class Meta: """Attribute preview schema metadata.""" model_class = CredAttrSpec + unknown = EXCLUDE name = fields.Str( description="Attribute name", required=True, example="favourite_drink" @@ -163,6 +164,7 @@ class Meta: """Credential preview schema metadata.""" model_class = CredentialPreview + unknown = EXCLUDE _type = fields.Str( description="Message type identifier", diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/routes.py b/aries_cloudagent/protocols/issue_credential/v1_0/routes.py index 69a03dd6b1..55ae95ed1e 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/routes.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/routes.py @@ -11,13 +11,13 @@ response_schema, ) from json.decoder import JSONDecodeError -from marshmallow import fields, Schema, validate +from marshmallow import fields, validate from 
....connections.models.connection_record import ConnectionRecord from ....issuer.base import IssuerError from ....ledger.error import LedgerError from ....messaging.credential_definitions.util import CRED_DEF_TAGS -from ....messaging.models.base import BaseModelError +from ....messaging.models.base import BaseModelError, OpenAPISchema from ....messaging.valid import ( INDY_CRED_DEF_ID, INDY_CRED_REV_ID, @@ -34,6 +34,7 @@ from ....wallet.base import BaseWallet from ....wallet.error import WalletError from ....utils.outofband import serialize_outofband +from ....utils.tracing import trace_event, get_timer, AdminAPIMessageTracingSchema from ...problem_report.v1_0 import internal_error from ...problem_report.v1_0.message import ProblemReport @@ -51,10 +52,8 @@ V10CredentialExchangeSchema, ) -from ....utils.tracing import trace_event, get_timer, AdminAPIMessageTracingSchema - -class V10CredentialExchangeListQueryStringSchema(Schema): +class V10CredentialExchangeListQueryStringSchema(OpenAPISchema): """Parameters and validators for credential exchange list query.""" connection_id = fields.UUID( @@ -91,7 +90,7 @@ class V10CredentialExchangeListQueryStringSchema(Schema): ) -class V10CredentialExchangeListResultSchema(Schema): +class V10CredentialExchangeListResultSchema(OpenAPISchema): """Result schema for Aries#0036 v1.0 credential exchange query.""" results = fields.List( @@ -100,7 +99,7 @@ class V10CredentialExchangeListResultSchema(Schema): ) -class V10CredentialStoreRequestSchema(Schema): +class V10CredentialStoreRequestSchema(OpenAPISchema): """Request schema for sending a credential store admin message.""" credential_id = fields.Str(required=False) @@ -136,7 +135,9 @@ class V10CredentialCreateSchema(AdminAPIMessageTracingSchema): ), required=False, ) - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) trace = fields.Bool( description="Whether 
to trace event (default false)", required=False, @@ -180,7 +181,9 @@ class V10CredentialProposalRequestSchemaBase(AdminAPIMessageTracingSchema): ), required=False, ) - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) trace = fields.Bool( description="Whether to trace event (default false)", required=False, @@ -228,7 +231,9 @@ class V10CredentialOfferRequestSchema(AdminAPIMessageTracingSchema): required=False, default=True, ) - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) credential_preview = fields.Nested(CredentialPreviewSchema, required=True) trace = fields.Bool( description="Whether to trace event (default false)", @@ -237,19 +242,21 @@ class V10CredentialOfferRequestSchema(AdminAPIMessageTracingSchema): ) -class V10CredentialIssueRequestSchema(Schema): +class V10CredentialIssueRequestSchema(OpenAPISchema): """Request schema for sending credential issue admin message.""" - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) -class V10CredentialProblemReportRequestSchema(Schema): +class V10CredentialProblemReportRequestSchema(OpenAPISchema): """Request schema for sending problem report.""" explain_ltxt = fields.Str(required=True) -class V10PublishRevocationsSchema(Schema): +class V10PublishRevocationsSchema(OpenAPISchema): """Request and result schema for revocation publication API call.""" rrid2crid = fields.Dict( @@ -264,7 +271,7 @@ class V10PublishRevocationsSchema(Schema): ) -class V10ClearPendingRevocationsRequestSchema(Schema): +class V10ClearPendingRevocationsRequestSchema(OpenAPISchema): """Request schema for clear pending revocations API call.""" purge = fields.Dict( @@ -282,7 
+289,7 @@ class V10ClearPendingRevocationsRequestSchema(Schema): ) -class RevokeQueryStringSchema(Schema): +class RevokeQueryStringSchema(OpenAPISchema): """Parameters and validators for revocation request.""" rev_reg_id = fields.Str( @@ -300,7 +307,7 @@ class RevokeQueryStringSchema(Schema): ) -class CredIdMatchInfoSchema(Schema): +class CredIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking credential id.""" credential_id = fields.Str( @@ -308,7 +315,7 @@ class CredIdMatchInfoSchema(Schema): ) -class CredExIdMatchInfoSchema(Schema): +class CredExIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking credential exchange id.""" cred_ex_id = fields.Str( @@ -747,9 +754,7 @@ async def credential_exchange_create_free_offer(request: web.BaseRequest): perf_counter=r_time, ) - oob_url = serialize_outofband( - context, credential_offer_message, conn_did, endpoint - ) + oob_url = serialize_outofband(credential_offer_message, conn_did, endpoint) result = cred_ex_record.serialize() except (BaseModelError, CredentialManagerError, LedgerError) as err: await internal_error( diff --git a/aries_cloudagent/protocols/issue_credential/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/issue_credential/v1_0/tests/test_manager.py index 740731d2a6..141b16d9f4 100644 --- a/aries_cloudagent/protocols/issue_credential/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/issue_credential/v1_0/tests/test_manager.py @@ -1,3 +1,4 @@ +import asyncio import json from asynctest import TestCase as AsyncTestCase @@ -885,6 +886,10 @@ async def test_issue_credential(self): with async_mock.patch.object( test_module, "IssuerRevRegRecord", autospec=True ) as issuer_rr_rec, async_mock.patch.object( + test_module, "IndyRevocation", autospec=True + ) as revoc, async_mock.patch.object( + asyncio, "ensure_future", autospec=True + ) as asyncio_mock, async_mock.patch.object( V10CredentialExchange, "save", autospec=True ) as 
save_ex: issuer_rr_rec.query_by_cred_def_id = async_mock.CoroutineMock( @@ -897,11 +902,13 @@ async def test_issue_credential(self): ), mark_full=async_mock.CoroutineMock(), revoc_reg_id=REV_REG_ID, + save=async_mock.CoroutineMock(), + publish_registry_entry=async_mock.CoroutineMock(), ) ] ) (ret_exchange, ret_cred_issue) = await self.manager.issue_credential( - stored_exchange, comment=comment + stored_exchange, comment=comment, retries=1 ) save_ex.assert_called_once() @@ -921,7 +928,9 @@ async def test_issue_credential(self): ( ret_existing_exchange, ret_existing_cred, - ) = await self.manager.issue_credential(stored_exchange, comment=comment) + ) = await self.manager.issue_credential( + stored_exchange, comment=comment, retries=0 + ) assert ret_existing_exchange == ret_exchange assert ret_existing_cred._thread_id == thread_id @@ -975,7 +984,7 @@ async def test_issue_credential_non_revocable(self): V10CredentialExchange, "save", autospec=True ) as save_ex: (ret_exchange, ret_cred_issue) = await self.manager.issue_credential( - stored_exchange, comment=comment + stored_exchange, comment=comment, retries=0 ) save_ex.assert_called_once() @@ -1053,7 +1062,9 @@ async def test_issue_credential_no_active_rr(self): return_value=[] ) with self.assertRaises(CredentialManagerError) as x_cred_mgr: - await self.manager.issue_credential(stored_exchange, comment=comment) + await self.manager.issue_credential( + stored_exchange, comment=comment, retries=0 + ) assert "has no active revocation registry" in x_cred_mgr.message async def test_issue_credential_rr_full(self): @@ -1108,7 +1119,109 @@ async def test_issue_credential_rr_full(self): ) with self.assertRaises(test_module.IssuerRevocationRegistryFullError): - await self.manager.issue_credential(stored_exchange, comment=comment) + await self.manager.issue_credential( + stored_exchange, comment=comment, retries=0 + ) + + async def test_issue_credential_rr_full_rr_staged_retry(self): + connection_id = "test_conn_id" + comment = 
"comment" + cred_values = {"attr": "value"} + indy_offer = {"schema_id": SCHEMA_ID, "cred_def_id": CRED_DEF_ID, "nonce": "0"} + indy_cred_req = {"schema_id": SCHEMA_ID, "cred_def_id": CRED_DEF_ID} + thread_id = "thread-id" + revocation_id = 1 + + stored_exchange = V10CredentialExchange( + credential_exchange_id="dummy-cxid", + connection_id=connection_id, + credential_definition_id=CRED_DEF_ID, + credential_offer=indy_offer, + credential_request=indy_cred_req, + credential_proposal_dict=CredentialProposal( + credential_proposal=CredentialPreview.deserialize( + {"attributes": [{"name": "attr", "value": "value"}]} + ), + cred_def_id=CRED_DEF_ID, + schema_id=SCHEMA_ID, + ).serialize(), + initiator=V10CredentialExchange.INITIATOR_SELF, + role=V10CredentialExchange.ROLE_ISSUER, + state=V10CredentialExchange.STATE_REQUEST_RECEIVED, + thread_id=thread_id, + revocation_id=revocation_id, + ) + + stored_exchange.save = async_mock.CoroutineMock() + + issuer = async_mock.MagicMock() + cred = {"indy": "credential"} + cred_rev_id = "1" + issuer.create_credential = async_mock.CoroutineMock( + side_effect=[ + test_module.IssuerRevocationRegistryFullError("it's full"), + (json.dumps({"good": "credential"}), str(revocation_id)), + ] + ) + self.context.injector.bind_instance(BaseIssuer, issuer) + + active_full_reg = [ + async_mock.MagicMock( + get_registry=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + tails_local_path="dummy-path", max_creds=revocation_id + ) + ), + revoc_reg_id=REV_REG_ID, + mark_full=async_mock.CoroutineMock(), + ) + ] + + active_non_full_reg = [ + async_mock.MagicMock( + get_registry=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + tails_local_path="dummy-path", max_creds=1000 + ) + ), + revoc_reg_id=REV_REG_ID, + mark_full=async_mock.CoroutineMock(), + ) + ] + + pending_reg = [ + async_mock.MagicMock( + get_registry=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(tails_local_path="dummy-path") + ), + 
revoc_reg_id=REV_REG_ID, + mark_full=async_mock.CoroutineMock(), + state="published", + save=async_mock.CoroutineMock(), + publish_registry_entry=async_mock.CoroutineMock(), + ) + ] + + with async_mock.patch.object( + test_module, "IssuerRevRegRecord", autospec=True + ) as issuer_rr_rec, async_mock.patch.object( + test_module, "IndyRevocation", autospec=True + ) as revoc, async_mock.patch.object( + asyncio, "ensure_future", autospec=True + ) as asyncio_mock: + issuer_rr_rec.query_by_cred_def_id.side_effect = [ + [], # First call checking for staged registries + active_full_reg, # Get active full registry + [], # Get active full registry after rev-reg-full + [], # Get staged + pending_reg, # Get published + [], # Get staged, on retry pass + active_non_full_reg, # Get active + ] + + await self.manager.issue_credential( + stored_exchange, comment=comment, retries=1 + ) async def test_receive_credential(self): connection_id = "test_conn_id" diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py index 532d7fcd0c..aa3cb14435 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/manager.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/manager.py @@ -6,6 +6,7 @@ from ....config.injection_context import InjectionContext from ....core.error import BaseError from ....ledger.base import BaseLedger +from ....wallet.util import did_key_to_naked, naked_to_did_key from ....protocols.connections.v1_0.manager import ConnectionManager from ....protocols.connections.v1_0.messages.connection_invitation import ( ConnectionInvitation, @@ -142,8 +143,14 @@ async def create_invitation( service = ServiceMessage( _id="#inline", _type="did-communication", - recipient_keys=connection_invitation.recipient_keys, - routing_keys=connection_invitation.routing_keys, + recipient_keys=[ + naked_to_did_key(key) + for key in connection_invitation.recipient_keys or [] + ], + routing_keys=[ + naked_to_did_key(key) + for key in 
connection_invitation.routing_keys or [] + ], service_endpoint=connection_invitation.endpoint, ).validate() @@ -196,17 +203,19 @@ async def receive_invitation( # Get the single service item if invitation_message.service_blocks: service = invitation_message.service_blocks[0] + else: # If it's in the did format, we need to convert to a full service block service_did = invitation_message.service_dids[0] async with ledger: verkey = await ledger.get_key_for_did(service_did) + did_key = naked_to_did_key(verkey) endpoint = await ledger.get_endpoint_for_did(service_did) service = ServiceMessage.deserialize( { "id": "#inline", "type": "did-communication", - "recipientKeys": [verkey], + "recipientKeys": [did_key], "routingKeys": [], "serviceEndpoint": endpoint, } @@ -224,6 +233,14 @@ async def receive_invitation( "request block must be empty for invitation message type." ) + # Transform back to 'naked' verkey + service.recipient_keys = [ + did_key_to_naked(key) for key in service.recipient_keys or [] + ] + service.routing_keys = [ + did_key_to_naked(key) for key in service.routing_keys + ] or [] + # Convert to the old message format connection_invitation = ConnectionInvitation.deserialize( { diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py index 780f55bfef..396c7d7405 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/invitation.py @@ -2,7 +2,14 @@ from typing import Sequence, Text, Union -from marshmallow import fields, validates_schema, ValidationError, pre_load, post_dump +from marshmallow import ( + EXCLUDE, + fields, + post_dump, + pre_load, + validates_schema, + ValidationError, +) from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.decorators.attach_decorator import ( @@ -83,6 +90,7 @@ class Meta: """Invitation schema metadata.""" model_class 
= Invitation + unknown = EXCLUDE label = fields.Str(required=False, description="Optional label", example="Bob") handshake_protocols = fields.List(fields.String, required=False, many=True) diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py b/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py index 1393a43609..f10f53ee29 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/messages/service.py @@ -2,10 +2,10 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema -from .....messaging.valid import INDY_DID, INDY_RAW_PUBLIC_KEY +from .....messaging.valid import INDY_DID, DID_KEY class Service(BaseModel): @@ -52,20 +52,21 @@ class Meta: """ServiceSchema metadata.""" model_class = Service + unknown = EXCLUDE _id = fields.Str(required=True, description="", data_key="id") _type = fields.Str(required=True, description="", data_key="type") did = fields.Str(required=False, description="", **INDY_DID) recipient_keys = fields.List( - fields.Str(description="Recipient public key", **INDY_RAW_PUBLIC_KEY), + fields.Str(description="Recipient public key", **DID_KEY), data_key="recipientKeys", required=False, description="List of recipient keys", ) routing_keys = fields.List( - fields.Str(description="Routing key", **INDY_RAW_PUBLIC_KEY), + fields.Str(description="Routing key", **DID_KEY), data_key="routingKeys", required=False, description="List of routing keys", diff --git a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py index 3ee436bbee..a185414913 100644 --- a/aries_cloudagent/protocols/out_of_band/v1_0/routes.py +++ b/aries_cloudagent/protocols/out_of_band/v1_0/routes.py @@ -5,9 +5,10 @@ from aiohttp import web from aiohttp_apispec import docs, request_schema -from marshmallow import fields, 
Schema +from marshmallow import fields from marshmallow.exceptions import ValidationError +from ....messaging.models.openapi import OpenAPISchema from ....storage.error import StorageNotFoundError from .manager import OutOfBandManager, OutOfBandManagerError @@ -17,10 +18,10 @@ LOGGER = logging.getLogger(__name__) -class InvitationCreateRequestSchema(Schema): +class InvitationCreateRequestSchema(OpenAPISchema): """Invitation create request Schema.""" - class AttachmentDefSchema(Schema): + class AttachmentDefSchema(OpenAPISchema): """Attachment Schema.""" _id = fields.String(data_key="id") @@ -31,7 +32,7 @@ class AttachmentDefSchema(Schema): use_public_did = fields.Boolean(default=False) -class InvitationSchema(InvitationSchema): +class InvitationReceiveRequestSchema(InvitationSchema): """Invitation Schema.""" service = fields.Field() @@ -78,7 +79,7 @@ async def invitation_create(request: web.BaseRequest): @docs( tags=["out-of-band"], summary="Create a new connection invitation", ) -@request_schema(InvitationSchema()) +@request_schema(InvitationReceiveRequestSchema()) async def invitation_receive(request: web.BaseRequest): """ Request handler for creating a new connection invitation. 
diff --git a/aries_cloudagent/protocols/present_proof/v1_0/manager.py b/aries_cloudagent/protocols/present_proof/v1_0/manager.py index a0878bf80b..a4741ccf7a 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/manager.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/manager.py @@ -326,11 +326,11 @@ async def create_presentation( # Get delta with non-revocation interval defined in "non_revoked" # of the presentation request or attributes - current_timestamp = int(time.time()) + epoch_now = int(time.time()) - non_revoc_interval = {"from": 0, "to": current_timestamp} + non_revoc_interval = {"from": 0, "to": epoch_now} non_revoc_interval.update( - presentation_exchange_record.presentation_request.get("non_revoked", {}) + presentation_exchange_record.presentation_request.get("non_revoked") or {} ) revoc_reg_deltas = {} @@ -348,14 +348,14 @@ async def create_presentation( if referent_non_revoc_interval: key = ( - f"{rev_reg_id}_{non_revoc_interval['from']}_" - f"{non_revoc_interval['to']}" + f"{rev_reg_id}_{referent_non_revoc_interval.get('from', 0)}_" + f"{referent_non_revoc_interval.get('to', epoch_now)}" ) if key not in revoc_reg_deltas: (delta, delta_timestamp) = await ledger.get_revoc_reg_delta( rev_reg_id, - non_revoc_interval["from"], - non_revoc_interval["to"], + referent_non_revoc_interval.get("from", 0), + referent_non_revoc_interval.get("to", epoch_now), ) revoc_reg_deltas[key] = ( rev_reg_id, diff --git a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation.py b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation.py index dae1b0d6f9..f1b1aa58f7 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation.py @@ -3,7 +3,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from 
.....messaging.decorators.attach_decorator import ( @@ -13,7 +13,6 @@ from ..message_types import PRESENTATION, PROTOCOL_PACKAGE - HANDLER_CLASS = f"{PROTOCOL_PACKAGE}.handlers.presentation_handler.PresentationHandler" @@ -68,8 +67,11 @@ class Meta: """Presentation schema metadata.""" model_class = Presentation + unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) presentations_attach = fields.Nested( AttachDecoratorSchema, required=True, many=True, data_key="presentations~attach" ) diff --git a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_ack.py b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_ack.py index 5d72bcd93f..998ac80d18 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_ack.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_ack.py @@ -1,9 +1,10 @@ """Represents an explicit RFC 15 ack message, adopted into present-proof protocol.""" +from marshmallow import EXCLUDE + from .....messaging.ack.message import Ack, AckSchema from ..message_types import PRESENTATION_ACK, PROTOCOL_PACKAGE - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers.presentation_ack_handler.PresentationAckHandler" ) @@ -37,3 +38,4 @@ class Meta: """PresentationAck schema metadata.""" model_class = PresentationAck + unknown = EXCLUDE diff --git a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_proposal.py b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_proposal.py index 70da18fd9c..a3e0e1373e 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_proposal.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_proposal.py @@ -1,6 +1,6 @@ """A presentation proposal content message.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from 
.....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -8,7 +8,6 @@ from .inner.presentation_preview import PresentationPreview, PresentationPreviewSchema - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers." "presentation_proposal_handler.PresentationProposalHandler" @@ -52,6 +51,9 @@ class Meta: """Presentation proposal schema metadata.""" model_class = PresentationProposal + unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) presentation_proposal = fields.Nested(PresentationPreviewSchema, required=True) diff --git a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_request.py b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_request.py index 0cdf7f589d..d8249f3646 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_request.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/messages/presentation_request.py @@ -3,7 +3,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema from .....messaging.decorators.attach_decorator import ( @@ -13,7 +13,6 @@ from ..message_types import PRESENTATION_REQUEST, PROTOCOL_PACKAGE - HANDLER_CLASS = ( f"{PROTOCOL_PACKAGE}.handlers." 
"presentation_request_handler.PresentationRequestHandler" @@ -71,8 +70,11 @@ class Meta: """Presentation request schema metadata.""" model_class = PresentationRequest + unknown = EXCLUDE - comment = fields.Str(description="Human-readable comment", required=False) + comment = fields.Str( + description="Human-readable comment", required=False, allow_none=True + ) request_presentations_attach = fields.Nested( AttachDecoratorSchema, required=True, diff --git a/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py b/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py index 2ba4d02fc8..8859afac21 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/models/tests/test_record.py @@ -1,8 +1,23 @@ from unittest import TestCase as UnitTestCase +from ......messaging.models.base_record import BaseExchangeRecord, BaseExchangeSchema + from ..presentation_exchange import V10PresentationExchange +class BasexRecordImpl(BaseExchangeRecord): + class Meta: + schema_class = "BasexRecordImplSchema" + + RECORD_TYPE = "record" + CACHE_ENABLED = True + + +class BasexRecordImplSchema(BaseExchangeSchema): + class Meta: + model_class = BasexRecordImpl + + class TestRecord(UnitTestCase): def test_record(self): record = V10PresentationExchange( @@ -37,3 +52,6 @@ def test_record(self): "verified": False, "trace": False, } + + bx_record = BasexRecordImpl() + assert record != bx_record diff --git a/aries_cloudagent/protocols/present_proof/v1_0/routes.py b/aries_cloudagent/protocols/present_proof/v1_0/routes.py index 883b2ef14f..c8a85f9bb1 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/routes.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/routes.py @@ -10,7 +10,7 @@ request_schema, response_schema, ) -from marshmallow import fields, Schema, validate, validates_schema +from marshmallow import fields, validate, validates_schema from marshmallow.exceptions 
import ValidationError from ....connections.models.connection_record import ConnectionRecord @@ -18,6 +18,7 @@ from ....ledger.error import LedgerError from ....messaging.decorators.attach_decorator import AttachDecorator from ....messaging.models.base import BaseModelError +from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import ( INDY_CRED_DEF_ID, INDY_DID, @@ -53,7 +54,7 @@ from ....utils.tracing import trace_event, get_timer, AdminAPIMessageTracingSchema -class V10PresentationExchangeListQueryStringSchema(Schema): +class V10PresentationExchangeListQueryStringSchema(OpenAPISchema): """Parameters and validators for presentation exchange list query.""" connection_id = fields.UUID( @@ -90,7 +91,7 @@ class V10PresentationExchangeListQueryStringSchema(Schema): ) -class V10PresentationExchangeListSchema(Schema): +class V10PresentationExchangeListSchema(OpenAPISchema): """Result schema for an Aries RFC 37 v1.0 presentation exchange query.""" results = fields.List( @@ -106,7 +107,7 @@ class V10PresentationProposalRequestSchema(AdminAPIMessageTracingSchema): description="Connection identifier", required=True, example=UUIDFour.EXAMPLE ) comment = fields.Str( - description="Human-readable comment", required=False, default="" + description="Human-readable comment", required=False, allow_none=True ) presentation_proposal = fields.Nested(PresentationPreviewSchema(), required=True) auto_present = fields.Boolean( @@ -124,7 +125,7 @@ class V10PresentationProposalRequestSchema(AdminAPIMessageTracingSchema): ) -class IndyProofReqPredSpecRestrictionsSchema(Schema): +class IndyProofReqPredSpecRestrictionsSchema(OpenAPISchema): """Schema for restrictions in attr or pred specifier indy proof request.""" schema_id = fields.String( @@ -149,7 +150,7 @@ class IndyProofReqPredSpecRestrictionsSchema(Schema): ) -class IndyProofReqNonRevokedSchema(Schema): +class IndyProofReqNonRevokedSchema(OpenAPISchema): """Non-revocation times specification in indy proof 
request.""" fro = fields.Int( @@ -182,7 +183,7 @@ def validate_fields(self, data, **kwargs): ) -class IndyProofReqAttrSpecSchema(Schema): +class IndyProofReqAttrSpecSchema(OpenAPISchema): """Schema for attribute specification in indy proof request.""" name = fields.String( @@ -245,7 +246,7 @@ def validate_fields(self, data, **kwargs): ) -class IndyProofReqPredSpecSchema(Schema): +class IndyProofReqPredSpecSchema(OpenAPISchema): """Schema for predicate specification in indy proof request.""" name = fields.String(example="index", description="Attribute name", required=True) @@ -263,7 +264,7 @@ class IndyProofReqPredSpecSchema(Schema): non_revoked = fields.Nested(IndyProofReqNonRevokedSchema(), required=False) -class IndyProofRequestSchema(Schema): +class IndyProofRequestSchema(OpenAPISchema): """Schema for indy proof request.""" nonce = fields.String(description="Nonce", required=False, example="1234567890") @@ -298,7 +299,7 @@ class V10PresentationCreateRequestRequestSchema(AdminAPIMessageTracingSchema): """Request schema for creating a proof request free of any connection.""" proof_request = fields.Nested(IndyProofRequestSchema(), required=True) - comment = fields.Str(required=False) + comment = fields.Str(required=False, allow_none=True) trace = fields.Bool( description="Whether to trace event (default false)", required=False, @@ -316,7 +317,7 @@ class V10PresentationSendRequestRequestSchema( ) -class IndyRequestedCredsRequestedAttrSchema(Schema): +class IndyRequestedCredsRequestedAttrSchema(OpenAPISchema): """Schema for requested attributes within indy requested credentials structure.""" cred_id = fields.Str( @@ -336,7 +337,7 @@ class IndyRequestedCredsRequestedAttrSchema(Schema): ) -class IndyRequestedCredsRequestedPredSchema(Schema): +class IndyRequestedCredsRequestedPredSchema(OpenAPISchema): """Schema for requested predicates within indy requested credentials structure.""" cred_id = fields.Str( @@ -393,7 +394,7 @@ class 
V10PresentationRequestSchema(AdminAPIMessageTracingSchema): ) -class CredentialsFetchQueryStringSchema(Schema): +class CredentialsFetchQueryStringSchema(OpenAPISchema): """Parameters and validators for credentials fetch request query string.""" referent = fields.Str( @@ -412,7 +413,7 @@ class CredentialsFetchQueryStringSchema(Schema): ) -class PresExIdMatchInfoSchema(Schema): +class PresExIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking presentation exchange id.""" pres_ex_id = fields.Str( diff --git a/aries_cloudagent/protocols/present_proof/v1_0/tests/test_manager.py b/aries_cloudagent/protocols/present_proof/v1_0/tests/test_manager.py index 6bccdb98b6..46a4c153e9 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/tests/test_manager.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/tests/test_manager.py @@ -332,6 +332,56 @@ async def test_create_presentation(self): save_ex.assert_called_once() assert exchange_out.state == V10PresentationExchange.STATE_PRESENTATION_SENT + async def test_create_presentation_proof_req_non_revoc_interval_none(self): + self.context.connection_record = async_mock.MagicMock() + self.context.connection_record.connection_id = CONN_ID + + exchange_in = V10PresentationExchange() + indy_proof_req = await PRES_PREVIEW.indy_proof_request( + name=PROOF_REQ_NAME, + version=PROOF_REQ_VERSION, + nonce=PROOF_REQ_NONCE, + ledger=await self.context.inject(BaseLedger, required=False), + ) + indy_proof_req["non_revoked"] = None # simulate interop with indy-vcx + + exchange_in.presentation_request = indy_proof_req + request = async_mock.MagicMock() + request.indy_proof_request = async_mock.MagicMock() + request._thread_id = "dummy" + self.context.message = request + + more_magic_rr = async_mock.MagicMock( + get_or_fetch_local_tails_path=async_mock.CoroutineMock( + return_value="/tmp/sample/tails/path" + ) + ) + with async_mock.patch.object( + V10PresentationExchange, "save", autospec=True + ) as 
save_ex, async_mock.patch.object( + test_module, "AttachDecorator", autospec=True + ) as mock_attach_decorator, async_mock.patch.object( + test_module, "RevocationRegistry", autospec=True + ) as mock_rr: + mock_rr.from_definition = async_mock.MagicMock(return_value=more_magic_rr) + + mock_attach_decorator.from_indy_dict = async_mock.MagicMock( + return_value=mock_attach_decorator + ) + + req_creds = await indy_proof_req_preview2indy_requested_creds( + indy_proof_req, holder=self.holder + ) + assert not req_creds["self_attested_attributes"] + assert len(req_creds["requested_attributes"]) == 2 + assert len(req_creds["requested_predicates"]) == 1 + + (exchange_out, pres_msg) = await self.manager.create_presentation( + exchange_in, req_creds + ) + save_ex.assert_called_once() + assert exchange_out.state == V10PresentationExchange.STATE_PRESENTATION_SENT + async def test_create_presentation_self_asserted(self): self.context.connection_record = async_mock.MagicMock() self.context.connection_record.connection_id = CONN_ID diff --git a/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py b/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py index e500f15fbf..b58b4aaebf 100644 --- a/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py +++ b/aries_cloudagent/protocols/present_proof/v1_0/tests/test_routes.py @@ -1,3 +1,5 @@ +import importlib + from aiohttp import web as aio_web from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock @@ -58,10 +60,14 @@ async def test_presentation_exchange_list(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.query = 
async_mock.CoroutineMock() mock_presentation_exchange.query.return_value = [mock_presentation_exchange] mock_presentation_exchange.serialize = async_mock.MagicMock() @@ -90,10 +96,14 @@ async def test_presentation_exchange_list_x(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.query = async_mock.CoroutineMock( side_effect=test_module.StorageError() ) @@ -109,13 +119,18 @@ async def test_presentation_exchange_credentials_list_not_found(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True - ) as mock_pres_ex: - mock_pres_ex.retrieve_by_id = async_mock.CoroutineMock() + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, + ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock() # Emulate storage not found (bad presentation exchange id) - mock_pres_ex.retrieve_by_id.side_effect = StorageNotFoundError + mock_presentation_exchange.retrieve_by_id.side_effect = StorageNotFoundError with self.assertRaises(test_module.web.HTTPNotFound): await test_module.presentation_exchange_credentials_list(mock) @@ -140,9 +155,14 @@ async def test_presentation_exchange_credentials_x(self): ), } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # 
Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.return_value.retrieve_by_id.return_value = ( mock_presentation_exchange ) @@ -169,9 +189,14 @@ async def test_presentation_exchange_credentials_list_single_referent(self): ), } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.return_value.retrieve_by_id.return_value = ( mock_presentation_exchange ) @@ -204,9 +229,14 @@ async def test_presentation_exchange_credentials_list_multiple_referents(self): ), } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.return_value.retrieve_by_id.return_value = ( mock_presentation_exchange ) @@ -225,9 +255,14 @@ async def test_presentation_exchange_retrieve(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_pres_ex: + + # Since we are mocking import + importlib.reload(test_module) + mock_pres_ex.retrieve_by_id = async_mock.CoroutineMock() mock_pres_ex.retrieve_by_id.return_value = mock_pres_ex mock_pres_ex.serialize = async_mock.MagicMock() @@ -249,9 +284,14 @@ async def test_presentation_exchange_retrieve_not_found(self): "request_context": "context", } - with async_mock.patch.object( - test_module, 
"V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_pres_ex: + + # Since we are mocking import + importlib.reload(test_module) + mock_pres_ex.retrieve_by_id = async_mock.CoroutineMock() # Emulate storage not found (bad presentation exchange id) @@ -271,9 +311,14 @@ async def test_presentation_exchange_retrieve_ser_x(self): mock_pres_ex_rec = async_mock.MagicMock( connection_id="abc123", thread_id="thid123" ) - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_pres_ex: + + # Since we are mocking import + importlib.reload(test_module) + mock_pres_ex.retrieve_by_id = async_mock.CoroutineMock( return_value=mock_pres_ex_rec ) @@ -292,21 +337,26 @@ async def test_presentation_exchange_send_proposal(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "PresentationPreview", autospec=True - ) as mock_presentation_proposal: + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as mock_preview: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange_record = 
async_mock.MagicMock() mock_presentation_manager.return_value.create_exchange_for_proposal = async_mock.CoroutineMock( return_value=mock_presentation_exchange_record ) - mock_presentation_proposal.return_value.deserialize.return_value = ( - async_mock.MagicMock() - ) + mock_preview.return_value.deserialize.return_value = async_mock.MagicMock() with async_mock.patch.object( test_module.web, "json_response" @@ -324,10 +374,14 @@ async def test_presentation_exchange_send_proposal_no_conn_record(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, ) as mock_connection_record: + # Since we are mocking import + importlib.reload(test_module) + # Emulate storage not found (bad connection id) mock_connection_record.retrieve_by_id = async_mock.CoroutineMock( side_effect=StorageNotFoundError @@ -344,14 +398,19 @@ async def test_presentation_exchange_send_proposal_not_ready(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationPreview", autospec=True - ) as mock_preview, async_mock.patch.object( - test_module, "PresentationProposal", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as mock_preview, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.presentation_proposal.PresentationProposal", + autospec=True, ) as mock_proposal: - mock_preview.deserialize = async_mock.CoroutineMock() + + # Since we are mocking import + importlib.reload(test_module) 
mock_connection_record.retrieve_by_id = async_mock.CoroutineMock() mock_connection_record.retrieve_by_id.return_value.is_ready = False @@ -367,13 +426,20 @@ async def test_presentation_exchange_send_proposal_x(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "PresentationPreview", autospec=True - ) as mock_presentation_proposal: + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as mock_preview: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange_record = async_mock.MagicMock() mock_presentation_manager.return_value.create_exchange_for_proposal = async_mock.CoroutineMock( side_effect=test_module.StorageError() @@ -392,17 +458,29 @@ async def test_presentation_exchange_create_request(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "PresentationPreview", autospec=True - ) as mock_presentation_proposal, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as 
mock_preview, async_mock.patch.object( test_module, "PresentationRequest", autospec=True - ) as mock_presentation_request, async_mock.patch.object( - test_module, "AttachDecorator", autospec=True - ) as mock_attach_decorator, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True - ) as mock_presentation_exchange: + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, + ) as mock_presentation_exchange, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce: + + # Since we are mocking import + importlib.reload(test_module) + + mock_generate_nonce = async_mock.CoroutineMock() + mock_attach_decorator.from_indy_dict = async_mock.MagicMock( return_value=mock_attach_decorator ) @@ -435,20 +513,27 @@ async def test_presentation_exchange_create_request_x(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "PresentationPreview", autospec=True - ) as mock_presentation_proposal, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as mock_preview, async_mock.patch.object( test_module, "PresentationRequest", autospec=True - ) as mock_presentation_request, async_mock.patch.object( - test_module, "AttachDecorator", autospec=True - ) as mock_attach_decorator, async_mock.patch.object( - test_module, "V10PresentationExchange", 
autospec=True - ) as mock_presentation_exchange: - mock_attach_decorator.from_indy_dict = async_mock.MagicMock( - return_value=mock_attach_decorator - ) + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, + ) as mock_presentation_exchange, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.serialize = async_mock.MagicMock() mock_presentation_exchange.serialize.return_value = { "thread_id": "sample-thread-id" @@ -477,19 +562,30 @@ async def test_presentation_exchange_send_free_request(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "PresentationPreview", autospec=True - ) as mock_presentation_proposal, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as mock_preview, async_mock.patch.object( test_module, "PresentationRequest", autospec=True - ) as mock_presentation_request, 
async_mock.patch.object( - test_module, "AttachDecorator", autospec=True - ) as mock_attach_decorator, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_connection_record.retrieve_by_id = async_mock.CoroutineMock( return_value=mock_connection_record ) @@ -524,9 +620,14 @@ async def test_presentation_exchange_send_free_request_not_found(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, ) as mock_connection_record: + + # Since we are mocking import + importlib.reload(test_module) + mock_connection_record.retrieve_by_id = async_mock.CoroutineMock() mock_connection_record.retrieve_by_id.side_effect = StorageNotFoundError @@ -543,9 +644,14 @@ async def test_presentation_exchange_send_free_request_not_ready(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, ) as mock_connection_record: + + # Since we are mocking import + importlib.reload(test_module) + mock_connection_record.is_ready = False mock_connection_record.retrieve_by_id = async_mock.CoroutineMock( return_value=mock_connection_record @@ -568,19 +674,31 @@ async def test_presentation_exchange_send_free_request_x(self): "request_context": self.mock_context, } - with 
async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( test_module, "PresentationPreview", autospec=True ) as mock_presentation_proposal, async_mock.patch.object( test_module, "PresentationRequest", autospec=True - ) as mock_presentation_request, async_mock.patch.object( - test_module, "AttachDecorator", autospec=True - ) as mock_attach_decorator, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + + mock_generate_nonce = async_mock.CoroutineMock() + mock_connection_record.retrieve_by_id = async_mock.CoroutineMock( return_value=mock_connection_record ) @@ -614,22 +732,40 @@ async def test_presentation_exchange_send_bound_request(self): mock.match_info = {"pres_ex_id": "dummy"} mock.app = { "outbound_message_router": async_mock.CoroutineMock(), - "request_context": self.mock_context, + "request_context": async_mock.CoroutineMock( + inject=async_mock.CoroutineMock( + return_value=async_mock.CoroutineMock( + 
__aenter__=async_mock.CoroutineMock(), + __aexit__=async_mock.CoroutineMock(), + verify_presentation=async_mock.CoroutineMock(), + ) + ) + ), } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( test_module, "PresentationPreview", autospec=True ) as mock_presentation_proposal, async_mock.patch.object( test_module, "PresentationRequest", autospec=True - ) as mock_presentation_request, async_mock.patch.object( - test_module, "AttachDecorator", autospec=True - ) as mock_attach_decorator, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.state = ( test_module.V10PresentationExchange.STATE_PROPOSAL_RECEIVED ) @@ -675,12 +811,29 @@ async def test_presentation_exchange_send_bound_request_not_found(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - 
test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( + test_module, "PresentationPreview", autospec=True + ) as mock_presentation_proposal, async_mock.patch.object( + test_module, "PresentationRequest", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.state = ( test_module.V10PresentationExchange.STATE_PROPOSAL_RECEIVED ) @@ -709,12 +862,29 @@ async def test_presentation_exchange_send_bound_request_not_ready(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( + test_module, "PresentationPreview", autospec=True + ) as 
mock_presentation_proposal, async_mock.patch.object( + test_module, "PresentationRequest", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.state = ( test_module.V10PresentationExchange.STATE_PROPOSAL_RECEIVED ) @@ -745,9 +915,14 @@ async def test_presentation_exchange_send_bound_request_bad_state(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_ACKED @@ -771,19 +946,29 @@ async def test_presentation_exchange_send_bound_request_x(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", 
autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( test_module, "PresentationPreview", autospec=True ) as mock_presentation_proposal, async_mock.patch.object( test_module, "PresentationRequest", autospec=True - ) as mock_presentation_request, async_mock.patch.object( - test_module, "AttachDecorator", autospec=True - ) as mock_attach_decorator, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.state = ( test_module.V10PresentationExchange.STATE_PROPOSAL_RECEIVED ) @@ -823,18 +1008,33 @@ async def test_presentation_exchange_send_presentation(self): mock.match_info = {"pres_ex_id": "dummy"} mock.app = { "outbound_message_router": async_mock.CoroutineMock(), - "request_context": self.mock_context, + "request_context": async_mock.CoroutineMock( + inject=async_mock.CoroutineMock( + return_value=async_mock.CoroutineMock( + __aenter__=async_mock.CoroutineMock(), + __aexit__=async_mock.CoroutineMock(), + verify_presentation=async_mock.CoroutineMock(), + ) + ) + ), } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, ) as mock_presentation_manager, async_mock.patch.object( test_module, 
"PresentationPreview", autospec=True - ) as mock_presentation_proposal, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + ) as mock_presentation_proposal, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.state = ( test_module.V10PresentationExchange.STATE_REQUEST_RECEIVED ) @@ -875,11 +1075,29 @@ async def test_presentation_exchange_send_presentation_not_found(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( + test_module, "PresentationPreview", autospec=True + ) as mock_presentation_proposal, async_mock.patch.object( + test_module, "PresentationRequest", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( 
state=mock_presentation_exchange.STATE_REQUEST_RECEIVED, @@ -903,11 +1121,29 @@ async def test_presentation_exchange_send_presentation_not_ready(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( + test_module, "PresentationPreview", autospec=True + ) as mock_presentation_proposal, async_mock.patch.object( + test_module, "PresentationRequest", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_REQUEST_RECEIVED, @@ -932,9 +1168,14 @@ async def test_presentation_exchange_send_presentation_bad_state(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + 
importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_ACKED @@ -959,15 +1200,29 @@ async def test_presentation_exchange_send_presentation_x(self): "request_context": self.mock_context, } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch.object( test_module, "PresentationPreview", autospec=True ) as mock_presentation_proposal, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + test_module, "PresentationRequest", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.state = ( test_module.V10PresentationExchange.STATE_REQUEST_RECEIVED ) @@ -1002,13 +1257,30 @@ async def test_presentation_exchange_verify_presentation(self): "request_context": async_mock.MagicMock(settings={}), } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, 
async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.indy.util.generate_pr_nonce", autospec=True, + ) as mock_generate_nonce, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.messages.inner.presentation_preview.PresentationPreview", + autospec=True, + ) as mock_preview, async_mock.patch.object( + test_module, "PresentationRequest", autospec=True + ) as mock_presentation_request, async_mock.patch( + "aries_cloudagent.messaging.decorators.attach_decorator.AttachDecorator", + autospec=True, + ) as mock_attach_decorator, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_RECEIVED, @@ -1041,15 +1313,31 @@ async def test_presentation_exchange_verify_presentation_not_found(self): mock.match_info = {"pres_ex_id": "dummy"} mock.app = { "outbound_message_router": async_mock.CoroutineMock(), - "request_context": "context", + "request_context": async_mock.CoroutineMock( + inject=async_mock.CoroutineMock( + return_value=async_mock.CoroutineMock( + __aenter__=async_mock.CoroutineMock(), + __aexit__=async_mock.CoroutineMock(), + verify_presentation=async_mock.CoroutineMock(), + ) + ) + ), } - with async_mock.patch.object( - test_module, "ConnectionRecord", 
autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_RECEIVED, @@ -1072,11 +1360,20 @@ async def test_presentation_exchange_verify_presentation_not_ready(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_RECEIVED, @@ -1101,9 +1398,14 @@ async def test_presentation_exchange_verify_presentation_bad_state(self): "request_context": self.mock_context, } - with 
async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_ACKED @@ -1117,16 +1419,31 @@ async def test_presentation_exchange_verify_presentation_x(self): mock.match_info = {"pres_ex_id": "dummy"} mock.app = { "outbound_message_router": async_mock.CoroutineMock(), - "request_context": "context", + "request_context": async_mock.CoroutineMock( + inject=async_mock.CoroutineMock( + return_value=async_mock.CoroutineMock( + __aenter__=async_mock.CoroutineMock(), + __aexit__=async_mock.CoroutineMock(), + verify_presentation=async_mock.CoroutineMock(), + ) + ) + ), } - with async_mock.patch.object( - test_module, "ConnectionRecord", autospec=True - ) as mock_connection_record, async_mock.patch.object( - test_module, "PresentationManager", autospec=True - ) as mock_presentation_manager, async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.connections.models.connection_record.ConnectionRecord", + autospec=True, + ) as mock_connection_record, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.manager.PresentationManager", + autospec=True, + ) as mock_presentation_manager, async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_PRESENTATION_RECEIVED, 
@@ -1137,6 +1454,7 @@ async def test_presentation_exchange_verify_presentation_x(self): ), ) ) + mock_connection_record.is_ready = True mock_connection_record.retrieve_by_id = async_mock.CoroutineMock( return_value=mock_connection_record @@ -1144,7 +1462,7 @@ async def test_presentation_exchange_verify_presentation_x(self): mock_mgr = async_mock.MagicMock( verify_presentation=async_mock.CoroutineMock( side_effect=test_module.LedgerError() - ) + ), ) mock_presentation_manager.return_value = mock_mgr @@ -1159,9 +1477,14 @@ async def test_presentation_exchange_remove(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_VERIFIED, @@ -1185,9 +1508,14 @@ async def test_presentation_exchange_remove_not_found(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + "aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + # Emulate storage not found (bad pres ex id) mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( side_effect=StorageNotFoundError @@ -1204,9 +1532,14 @@ async def test_presentation_exchange_remove_x(self): "request_context": "context", } - with async_mock.patch.object( - test_module, "V10PresentationExchange", autospec=True + with async_mock.patch( + 
"aries_cloudagent.protocols.present_proof.v1_0.models.presentation_exchange.V10PresentationExchange", + autospec=True, ) as mock_presentation_exchange: + + # Since we are mocking import + importlib.reload(test_module) + mock_presentation_exchange.retrieve_by_id = async_mock.CoroutineMock( return_value=async_mock.MagicMock( state=mock_presentation_exchange.STATE_VERIFIED, diff --git a/aries_cloudagent/protocols/problem_report/v1_0/message.py b/aries_cloudagent/protocols/problem_report/v1_0/message.py index a111d64993..bd6093d4c7 100644 --- a/aries_cloudagent/protocols/problem_report/v1_0/message.py +++ b/aries_cloudagent/protocols/problem_report/v1_0/message.py @@ -2,7 +2,7 @@ from typing import Mapping, Sequence -from marshmallow import fields, validate +from marshmallow import EXCLUDE, fields, validate from ....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -55,7 +55,7 @@ def __init__( tracking_uri: URI for tracking the problem escalation_uri: URI for escalating the problem """ - super(ProblemReport, self).__init__(**kwargs) + super().__init__(**kwargs) self.msg_catalog = msg_catalog self.locale = locale self.explain_ltxt = explain_ltxt @@ -77,6 +77,7 @@ class Meta: """Problem report schema metadata.""" model_class = ProblemReport + unknown = EXCLUDE msg_catalog = fields.Str( data_key="@msg_catalog", diff --git a/aries_cloudagent/protocols/routing/v1_0/messages/forward.py b/aries_cloudagent/protocols/routing/v1_0/messages/forward.py index b8dea6ec8c..bd2c827044 100644 --- a/aries_cloudagent/protocols/routing/v1_0/messages/forward.py +++ b/aries_cloudagent/protocols/routing/v1_0/messages/forward.py @@ -4,7 +4,7 @@ from typing import Union -from marshmallow import fields, pre_load +from marshmallow import EXCLUDE, fields, pre_load from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -31,7 +31,7 @@ def __init__(self, *, to: str = None, msg: Union[dict, str] = None, **kwargs): to (str): Recipient DID msg (str): Message 
content """ - super(Forward, self).__init__(**kwargs) + super().__init__(**kwargs) self.to = to if isinstance(msg, str): msg = json.loads(msg) @@ -45,6 +45,7 @@ class Meta: """ForwardSchema metadata.""" model_class = Forward + unknown = EXCLUDE @pre_load def handle_str_message(self, data, **kwargs): diff --git a/aries_cloudagent/protocols/routing/v1_0/messages/route_query_request.py b/aries_cloudagent/protocols/routing/v1_0/messages/route_query_request.py index e5354696d5..0926ed28ce 100644 --- a/aries_cloudagent/protocols/routing/v1_0/messages/route_query_request.py +++ b/aries_cloudagent/protocols/routing/v1_0/messages/route_query_request.py @@ -1,6 +1,6 @@ """Query existing forwarding routes.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -30,7 +30,7 @@ def __init__(self, *, filter: dict = None, paginate: Paginate = None, **kwargs): filter: Filter results according to specific field values """ - super(RouteQueryRequest, self).__init__(**kwargs) + super().__init__(**kwargs) self.filter = filter self.paginate = paginate @@ -42,6 +42,7 @@ class Meta: """RouteQueryRequestSchema metadata.""" model_class = RouteQueryRequest + unknown = EXCLUDE filter = fields.Dict( keys=fields.Str(description="field"), diff --git a/aries_cloudagent/protocols/routing/v1_0/messages/route_query_response.py b/aries_cloudagent/protocols/routing/v1_0/messages/route_query_response.py index 279d69c75b..9dd7002f23 100644 --- a/aries_cloudagent/protocols/routing/v1_0/messages/route_query_response.py +++ b/aries_cloudagent/protocols/routing/v1_0/messages/route_query_response.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -40,7 +40,7 @@ def __init__( filter: Filter results according to specific field values """ - super(RouteQueryResponse, 
self).__init__(**kwargs) + super().__init__(**kwargs) self.routes = routes or [] self.paginated = paginated @@ -52,6 +52,7 @@ class Meta: """RouteQueryResponseSchema metadata.""" model_class = RouteQueryResponse + unknown = EXCLUDE routes = fields.List(fields.Nested(RouteQueryResultSchema()), required=True) paginated = fields.Nested(PaginatedSchema(), required=False) diff --git a/aries_cloudagent/protocols/routing/v1_0/messages/route_update_request.py b/aries_cloudagent/protocols/routing/v1_0/messages/route_update_request.py index 68fc4ceedf..d1a8c98357 100644 --- a/aries_cloudagent/protocols/routing/v1_0/messages/route_update_request.py +++ b/aries_cloudagent/protocols/routing/v1_0/messages/route_update_request.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -33,7 +33,7 @@ def __init__(self, *, updates: Sequence[RouteUpdate] = None, **kwargs): updates: A list of route updates """ - super(RouteUpdateRequest, self).__init__(**kwargs) + super().__init__(**kwargs) self.updates = updates or [] @@ -44,5 +44,6 @@ class Meta: """RouteUpdateRequestSchema metadata.""" model_class = RouteUpdateRequest + unknown = EXCLUDE updates = fields.List(fields.Nested(RouteUpdateSchema()), required=True) diff --git a/aries_cloudagent/protocols/routing/v1_0/messages/route_update_response.py b/aries_cloudagent/protocols/routing/v1_0/messages/route_update_response.py index 0b2fdca557..be0ba028a6 100644 --- a/aries_cloudagent/protocols/routing/v1_0/messages/route_update_response.py +++ b/aries_cloudagent/protocols/routing/v1_0/messages/route_update_response.py @@ -2,7 +2,7 @@ from typing import Sequence -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -33,7 +33,7 @@ def __init__(self, *, updated: Sequence[RouteUpdated] = None, **kwargs): 
updated: A list of route updates """ - super(RouteUpdateResponse, self).__init__(**kwargs) + super().__init__(**kwargs) self.updated = updated or [] @@ -44,5 +44,6 @@ class Meta: """RouteUpdateResponseSchema metadata.""" model_class = RouteUpdateResponse + unknown = EXCLUDE updated = fields.List(fields.Nested(RouteUpdatedSchema()), required=True) diff --git a/aries_cloudagent/protocols/routing/v1_0/models/paginate.py b/aries_cloudagent/protocols/routing/v1_0/models/paginate.py index ba2a478301..eea7357a9d 100644 --- a/aries_cloudagent/protocols/routing/v1_0/models/paginate.py +++ b/aries_cloudagent/protocols/routing/v1_0/models/paginate.py @@ -1,6 +1,6 @@ """An object for containing the request pagination information.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -22,7 +22,7 @@ def __init__(self, *, limit: int = None, offset: int = None, **kwargs): offset: Set the offset of the first requested result """ - super(Paginate, self).__init__(**kwargs) + super().__init__(**kwargs) self.limit = limit self.offset = offset @@ -33,7 +33,8 @@ class PaginateSchema(BaseModelSchema): class Meta: """PaginateSchema metadata.""" - model_class = "Paginate" + model_class = Paginate + unknown = EXCLUDE limit = fields.Int(required=False) offset = fields.Int(required=False) diff --git a/aries_cloudagent/protocols/routing/v1_0/models/paginated.py b/aries_cloudagent/protocols/routing/v1_0/models/paginated.py index d0f22111e3..80be99cc1a 100644 --- a/aries_cloudagent/protocols/routing/v1_0/models/paginated.py +++ b/aries_cloudagent/protocols/routing/v1_0/models/paginated.py @@ -1,6 +1,6 @@ """An object for containing the response pagination information.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -32,7 +32,7 @@ def __init__( total: Total number of records available """ - super(Paginated, 
self).__init__(**kwargs) + super().__init__(**kwargs) self.start = start self.end = end self.limit = limit @@ -45,7 +45,8 @@ class PaginatedSchema(BaseModelSchema): class Meta: """PaginatedSchema metadata.""" - model_class = "Paginated" + model_class = Paginated + unknown = EXCLUDE start = fields.Int(required=False) end = fields.Int(required=False) diff --git a/aries_cloudagent/protocols/routing/v1_0/models/route_query_result.py b/aries_cloudagent/protocols/routing/v1_0/models/route_query_result.py index a5b2b45f81..cc7ad92c78 100644 --- a/aries_cloudagent/protocols/routing/v1_0/models/route_query_result.py +++ b/aries_cloudagent/protocols/routing/v1_0/models/route_query_result.py @@ -1,6 +1,6 @@ """An object for containing returned route information.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -21,7 +21,7 @@ def __init__(self, *, recipient_key: str = None, **kwargs): recipient_key: The recipient verkey of the route """ - super(RouteQueryResult, self).__init__(**kwargs) + super().__init__(**kwargs) self.recipient_key = recipient_key @@ -31,6 +31,7 @@ class RouteQueryResultSchema(BaseModelSchema): class Meta: """RouteQueryResultSchema metadata.""" - model_class = "RouteQueryResult" + model_class = RouteQueryResult + unknown = EXCLUDE recipient_key = fields.Str(required=True) diff --git a/aries_cloudagent/protocols/routing/v1_0/models/route_record.py b/aries_cloudagent/protocols/routing/v1_0/models/route_record.py index 78173735a5..93f7ec839f 100644 --- a/aries_cloudagent/protocols/routing/v1_0/models/route_record.py +++ b/aries_cloudagent/protocols/routing/v1_0/models/route_record.py @@ -1,6 +1,6 @@ """An object for containing information on an individual route.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -30,7 +30,7 @@ def __init__( recipient_key: The recipient 
verkey of the route """ - super(RouteRecord, self).__init__(**kwargs) + super().__init__(**kwargs) self.record_id = record_id self.connection_id = connection_id self.recipient_key = recipient_key @@ -44,7 +44,8 @@ class RouteRecordSchema(BaseModelSchema): class Meta: """RouteRecordSchema metadata.""" - model_class = "RouteRecord" + model_class = RouteRecord + unknown = EXCLUDE record_id = fields.Str(required=False) connection_id = fields.Str(required=True) diff --git a/aries_cloudagent/protocols/routing/v1_0/models/route_update.py b/aries_cloudagent/protocols/routing/v1_0/models/route_update.py index 39f1954c9b..ecd2d14c25 100644 --- a/aries_cloudagent/protocols/routing/v1_0/models/route_update.py +++ b/aries_cloudagent/protocols/routing/v1_0/models/route_update.py @@ -1,6 +1,6 @@ """An object for containing route information to be updated.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.models.base import BaseModel, BaseModelSchema @@ -25,7 +25,7 @@ def __init__(self, *, recipient_key: str = None, action: str = None, **kwargs): action: The action to perform """ - super(RouteUpdate, self).__init__(**kwargs) + super().__init__(**kwargs) self.recipient_key = recipient_key self.action = action @@ -36,7 +36,8 @@ class RouteUpdateSchema(BaseModelSchema): class Meta: """RouteUpdateSchema metadata.""" - model_class = "RouteUpdate" + model_class = RouteUpdate + unknown = EXCLUDE recipient_key = fields.Str(required=True) action = fields.Str(required=True) diff --git a/aries_cloudagent/protocols/routing/v1_0/models/route_updated.py b/aries_cloudagent/protocols/routing/v1_0/models/route_updated.py index ae120bfac7..c86168bf68 100644 --- a/aries_cloudagent/protocols/routing/v1_0/models/route_updated.py +++ b/aries_cloudagent/protocols/routing/v1_0/models/route_updated.py @@ -1,6 +1,6 @@ """An object for containing updated route information.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from 
.....messaging.models.base import BaseModel, BaseModelSchema @@ -35,7 +35,7 @@ def __init__( result: The result of the requested action """ - super(RouteUpdated, self).__init__(**kwargs) + super().__init__(**kwargs) self.recipient_key = recipient_key self.action = action self.result = result @@ -47,7 +47,8 @@ class RouteUpdatedSchema(BaseModelSchema): class Meta: """RouteUpdatedSchema metadata.""" - model_class = "RouteUpdated" + model_class = RouteUpdated + unknown = EXCLUDE recipient_key = fields.Str(required=True) action = fields.Str(required=True) diff --git a/aries_cloudagent/protocols/trustping/v1_0/messages/ping.py b/aries_cloudagent/protocols/trustping/v1_0/messages/ping.py index 64110fb63b..58442f6e45 100644 --- a/aries_cloudagent/protocols/trustping/v1_0/messages/ping.py +++ b/aries_cloudagent/protocols/trustping/v1_0/messages/ping.py @@ -1,6 +1,6 @@ """Represents a trust ping message.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -43,16 +43,17 @@ class Meta: """PingSchema metadata.""" model_class = Ping + unknown = EXCLUDE response_requested = fields.Bool( + description="Whether response is requested (default True)", default=True, required=False, - description="Whether response is requested (default True)", example=True, ) comment = fields.Str( - required=False, description="Optional comment to include", - example="Hello", + required=False, allow_none=True, + example="Hello", ) diff --git a/aries_cloudagent/protocols/trustping/v1_0/messages/ping_response.py b/aries_cloudagent/protocols/trustping/v1_0/messages/ping_response.py index a5b1f2a5df..d67b425c28 100644 --- a/aries_cloudagent/protocols/trustping/v1_0/messages/ping_response.py +++ b/aries_cloudagent/protocols/trustping/v1_0/messages/ping_response.py @@ -1,6 +1,6 @@ """Represents an response to a trust ping message.""" -from marshmallow import fields +from marshmallow import EXCLUDE, fields 
from .....messaging.agent_message import AgentMessage, AgentMessageSchema @@ -38,6 +38,7 @@ class Meta: """PingResponseSchema metadata.""" model_class = PingResponse + unknown = EXCLUDE comment = fields.Str( required=False, diff --git a/aries_cloudagent/protocols/trustping/v1_0/routes.py b/aries_cloudagent/protocols/trustping/v1_0/routes.py index 59caab84f4..00de2df80f 100644 --- a/aries_cloudagent/protocols/trustping/v1_0/routes.py +++ b/aries_cloudagent/protocols/trustping/v1_0/routes.py @@ -3,9 +3,10 @@ from aiohttp import web from aiohttp_apispec import docs, match_info_schema, request_schema, response_schema -from marshmallow import fields, Schema +from marshmallow import fields from ....connections.models.connection_record import ConnectionRecord +from ....messaging.models.openapi import OpenAPISchema from ....messaging.valid import UUIDFour from ....storage.error import StorageNotFoundError @@ -13,19 +14,21 @@ from .messages.ping import Ping -class PingRequestSchema(Schema): +class PingRequestSchema(OpenAPISchema): """Request schema for performing a ping.""" - comment = fields.Str(required=False, description="Comment for the ping message") + comment = fields.Str( + description="Comment for the ping message", required=False, allow_none=True + ) -class PingRequestResponseSchema(Schema): +class PingRequestResponseSchema(OpenAPISchema): """Request schema for performing a ping.""" thread_id = fields.Str(required=False, description="Thread ID of the ping message") -class ConnIdMatchInfoSchema(Schema): +class ConnIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking connection id.""" conn_id = fields.Str( diff --git a/aries_cloudagent/revocation/models/issuer_rev_reg_record.py b/aries_cloudagent/revocation/models/issuer_rev_reg_record.py index 5fd30de753..ec83a9373a 100644 --- a/aries_cloudagent/revocation/models/issuer_rev_reg_record.py +++ b/aries_cloudagent/revocation/models/issuer_rev_reg_record.py @@ -4,6 +4,7 @@ import logging 
import uuid +from asyncio import shield from os.path import join from shutil import move from typing import Any, Sequence @@ -11,6 +12,7 @@ from marshmallow import fields, validate +from ...tails.base import BaseTailsServer from ...config.injection_context import InjectionContext from ...indy.util import indy_client_dir from ...issuer.base import BaseIssuer, IssuerError @@ -43,6 +45,7 @@ class Meta: RECORD_ID_NAME = "record_id" RECORD_TYPE = "issuer_rev_reg" + WEBHOOK_TOPIC = "revocation_registry" LOG_STATE_FLAG = "debug.revocation" CACHE_ENABLED = False TAG_NAMES = { @@ -58,6 +61,7 @@ class Meta: STATE_INIT = "init" STATE_GENERATED = "generated" STATE_PUBLISHED = "published" # definition published + STATE_STAGED = "staged" STATE_ACTIVE = "active" # first entry published STATE_FULL = "full" @@ -191,6 +195,22 @@ async def set_tails_file_public_uri( self.revoc_reg_def["value"]["tailsLocation"] = tails_file_uri await self.save(context, reason="Set tails file public URI") + async def stage_pending_registry_definition( + self, context: InjectionContext, + ): + """Prepare registry definition for future use.""" + await shield(self.generate_registry(context)) + tails_base_url = context.settings.get("tails_server_base_url") + await self.set_tails_file_public_uri( + context, f"{tails_base_url}/{self.revoc_reg_id}", + ) + await self.publish_registry_definition(context) + + tails_server: BaseTailsServer = await context.inject(BaseTailsServer) + await tails_server.upload_tails_file( + context, self.revoc_reg_id, self.tails_local_path, + ) + async def publish_registry_definition(self, context: InjectionContext): """Send the revocation registry definition to the ledger.""" if not (self.revoc_reg_def and self.issuer_did): @@ -198,7 +218,10 @@ async def publish_registry_definition(self, context: InjectionContext): self._check_url(self.tails_public_uri) - if self.state != IssuerRevRegRecord.STATE_GENERATED: + if self.state not in ( + IssuerRevRegRecord.STATE_GENERATED, + 
IssuerRevRegRecord.STATE_STAGED, + ): raise RevocationError( "Revocation registry {} in state {}: cannot publish definition".format( self.revoc_reg_id, self.state @@ -226,6 +249,7 @@ async def publish_registry_entry(self, context: InjectionContext): if self.state not in ( IssuerRevRegRecord.STATE_PUBLISHED, IssuerRevRegRecord.STATE_ACTIVE, + IssuerRevRegRecord.STATE_STAGED, IssuerRevRegRecord.STATE_FULL, # can still publish revocation deltas ): raise RevocationError( @@ -242,7 +266,10 @@ async def publish_registry_entry(self, context: InjectionContext): self.revoc_reg_entry, self.issuer_did, ) - if self.state == IssuerRevRegRecord.STATE_PUBLISHED: # initial entry activates + if self.state in ( + IssuerRevRegRecord.STATE_PUBLISHED, + IssuerRevRegRecord.STATE_STAGED, + ): # initial entry activates self.state = IssuerRevRegRecord.STATE_ACTIVE await self.save( context, reason="Published initial revocation registry entry" diff --git a/aries_cloudagent/revocation/models/tests/test_issuer_rev_reg_record.py b/aries_cloudagent/revocation/models/tests/test_issuer_rev_reg_record.py index 7c4bb06a4e..ca7e3e8633 100644 --- a/aries_cloudagent/revocation/models/tests/test_issuer_rev_reg_record.py +++ b/aries_cloudagent/revocation/models/tests/test_issuer_rev_reg_record.py @@ -16,23 +16,29 @@ from ....ledger.base import BaseLedger from ....storage.base import BaseStorage from ....storage.basic import BasicStorage +from ....tails.base import BaseTailsServer from ....wallet.base import BaseWallet, DIDInfo from ...error import RevocationError +from .. import issuer_rev_reg_record as test_module from ..issuer_rev_reg_record import IssuerRevRegRecord from ..revocation_registry import RevocationRegistry -from .. 
import issuer_rev_reg_record as test_module +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" +CRED_DEF_ID = f"{TEST_DID}:3:CL:1234:default" +REV_REG_ID = f"{TEST_DID}:4:{CRED_DEF_ID}:CL_ACCUM:0" -class TestRecord(AsyncTestCase): - test_did = "55GkHamhTU1ZbTbV2ab9DE" +class TestRecord(AsyncTestCase): def setUp(self): - self.context = InjectionContext(enforce_typing=False) + self.context = InjectionContext( + settings={"tails_server_base_url": "http://1.2.3.4:8088"}, + enforce_typing=False, + ) self.wallet = async_mock.MagicMock() - self.wallet.WALLET_TYPE = "indy" + self.wallet.type = "indy" self.context.injector.bind_instance(BaseWallet, self.wallet) Ledger = async_mock.MagicMock(BaseLedger, autospec=True) @@ -41,17 +47,17 @@ def setUp(self): self.ledger.send_revoc_reg_entry = async_mock.CoroutineMock() self.context.injector.bind_instance(BaseLedger, self.ledger) + TailsServer = async_mock.MagicMock(BaseTailsServer, autospec=True) + self.tails_server = TailsServer() + self.tails_server.upload_tails_file = async_mock.CoroutineMock() + self.context.injector.bind_instance(BaseTailsServer, self.tails_server) + self.storage = BasicStorage() self.context.injector.bind_instance(BaseStorage, self.storage) async def test_generate_registry_etc(self): - CRED_DEF_ID = f"{TestRecord.test_did}:3:CL:1234:default" - REV_REG_ID = f"{TestRecord.test_did}:4:{CRED_DEF_ID}:CL_ACCUM:0" - rec = IssuerRevRegRecord( - issuer_did=TestRecord.test_did, - cred_def_id=CRED_DEF_ID, - revoc_reg_id=REV_REG_ID, + issuer_did=TEST_DID, cred_def_id=CRED_DEF_ID, revoc_reg_id=REV_REG_ID, ) issuer = async_mock.MagicMock(BaseIssuer) self.context.injector.bind_instance(BaseIssuer, issuer) @@ -128,11 +134,8 @@ async def test_generate_registry_etc(self): assert model_instance == rec async def test_operate_on_full_record(self): - CRED_DEF_ID = f"{TestRecord.test_did}:3:CL:1234:default" - REV_REG_ID = f"{TestRecord.test_did}:4:{CRED_DEF_ID}:CL_ACCUM:0" - rec_full = IssuerRevRegRecord( - issuer_did=TestRecord.test_did, + 
issuer_did=TEST_DID, revoc_reg_id=REV_REG_ID, revoc_reg_def={"sample": "rr-def"}, revoc_def_type="CL_ACCUM", @@ -185,6 +188,29 @@ async def test_set_tails_file_public_uri_rev_reg_undef(self): with self.assertRaises(RevocationError): await rec.set_tails_file_public_uri(self.context, "dummy") + async def test_stage_pending_registry_definition(self): + issuer = async_mock.MagicMock(BaseIssuer) + issuer.create_and_store_revocation_registry = async_mock.CoroutineMock( + return_value=( + REV_REG_ID, + json.dumps( + { + "value": { + "tailsHash": "abcd1234", + "tailsLocation": "/tmp/location", + } + } + ), + json.dumps({}), + ) + ) + self.context.injector.bind_instance(BaseIssuer, issuer) + rec = IssuerRevRegRecord(issuer_did=TEST_DID, revoc_reg_id=REV_REG_ID) + with async_mock.patch.object( + test_module, "move", async_mock.MagicMock() + ) as mock_move: + await rec.stage_pending_registry_definition(self.context) + async def test_publish_rev_reg_undef(self): rec = IssuerRevRegRecord() with self.assertRaises(RevocationError): diff --git a/aries_cloudagent/revocation/routes.py b/aries_cloudagent/revocation/routes.py index 6aeac9ca8f..b6b6b0e219 100644 --- a/aries_cloudagent/revocation/routes.py +++ b/aries_cloudagent/revocation/routes.py @@ -13,9 +13,10 @@ response_schema, ) -from marshmallow import fields, Schema, validate +from marshmallow import fields, validate from ..messaging.credential_definitions.util import CRED_DEF_SENT_RECORD_TYPE +from ..messaging.models.openapi import OpenAPISchema from ..messaging.valid import INDY_CRED_DEF_ID, INDY_REV_REG_ID from ..storage.base import BaseStorage, StorageNotFoundError @@ -23,11 +24,10 @@ from .indy import IndyRevocation from .models.issuer_rev_reg_record import IssuerRevRegRecord, IssuerRevRegRecordSchema - LOGGER = logging.getLogger(__name__) -class RevRegCreateRequestSchema(Schema): +class RevRegCreateRequestSchema(OpenAPISchema): """Request schema for revocation registry creation request.""" credential_definition_id = 
fields.Str( @@ -38,13 +38,13 @@ class RevRegCreateRequestSchema(Schema): ) -class RevRegCreateResultSchema(Schema): +class RevRegCreateResultSchema(OpenAPISchema): """Result schema for revocation registry creation request.""" result = IssuerRevRegRecordSchema() -class RevRegsCreatedSchema(Schema): +class RevRegsCreatedSchema(OpenAPISchema): """Result schema for request for revocation registries created.""" rev_reg_ids = fields.List( @@ -52,7 +52,7 @@ class RevRegsCreatedSchema(Schema): ) -class RevRegUpdateTailsFileUriSchema(Schema): +class RevRegUpdateTailsFileUriSchema(OpenAPISchema): """Request schema for updating tails file URI.""" tails_public_uri = fields.Url( @@ -65,7 +65,7 @@ class RevRegUpdateTailsFileUriSchema(Schema): ) -class RevRegsCreatedQueryStringSchema(Schema): +class RevRegsCreatedQueryStringSchema(OpenAPISchema): """Query string parameters and validators for rev regs created request.""" cred_def_id = fields.Str( @@ -86,7 +86,7 @@ class RevRegsCreatedQueryStringSchema(Schema): ) -class RevRegIdMatchInfoSchema(Schema): +class RevRegIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking rev reg id.""" rev_reg_id = fields.Str( @@ -94,7 +94,7 @@ class RevRegIdMatchInfoSchema(Schema): ) -class CredDefIdMatchInfoSchema(Schema): +class CredDefIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking cred def id.""" cred_def_id = fields.Str( diff --git a/aries_cloudagent/storage/basic.py b/aries_cloudagent/storage/basic.py index 01990c0034..2d5e6cdcaf 100644 --- a/aries_cloudagent/storage/basic.py +++ b/aries_cloudagent/storage/basic.py @@ -260,9 +260,7 @@ def __init__( options: Dictionary of backend-specific options """ - super(BasicStorageRecordSearch, self).__init__( - store, type_filter, tag_query, page_size, options - ) + super().__init__(store, type_filter, tag_query, page_size, options) self._cache = None self._iter = None diff --git a/aries_cloudagent/storage/provider.py 
b/aries_cloudagent/storage/provider.py index 6786617a4c..cf0f192e2d 100644 --- a/aries_cloudagent/storage/provider.py +++ b/aries_cloudagent/storage/provider.py @@ -25,7 +25,7 @@ async def provide(self, settings: BaseSettings, injector: BaseInjector): wallet_type = settings.get_value("wallet.type", default="basic").lower() storage_default_type = "indy" if wallet_type == "indy" else "basic" storage_type = settings.get_value( - "storage.type", default=storage_default_type + "storage_type", default=storage_default_type ).lower() storage_class = self.STORAGE_TYPES.get(storage_type, storage_type) storage = ClassLoader.load_class(storage_class)(wallet) diff --git a/aries_cloudagent/tails/__init__.py b/aries_cloudagent/tails/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aries_cloudagent/tails/base.py b/aries_cloudagent/tails/base.py new file mode 100644 index 0000000000..16e4555988 --- /dev/null +++ b/aries_cloudagent/tails/base.py @@ -0,0 +1,20 @@ +"""Tails server interface base class.""" + +from abc import ABC, abstractmethod, ABCMeta + +from ..config.injection_context import InjectionContext + + +class BaseTailsServer(ABC, metaclass=ABCMeta): + """Base class for tails server interface.""" + + @abstractmethod + async def upload_tails_file( + self, context: InjectionContext, rev_reg_id: str, tails_file_path: str + ) -> (bool, str): + """Upload tails file to tails server. 
+ + Args: + rev_reg_id: The revocation registry identifier + tails_file: The path to the tails file to upload + """ diff --git a/aries_cloudagent/tails/error.py b/aries_cloudagent/tails/error.py new file mode 100644 index 0000000000..26bf55288c --- /dev/null +++ b/aries_cloudagent/tails/error.py @@ -0,0 +1,7 @@ +"""Tails server related errors.""" + +from ..core.error import BaseError + + +class TailsServerNotConfiguredError(BaseError): + """Error indicating the tails server plugin hasn't been configured.""" diff --git a/aries_cloudagent/tails/indy_tails_server.py b/aries_cloudagent/tails/indy_tails_server.py new file mode 100644 index 0000000000..d4a6715a14 --- /dev/null +++ b/aries_cloudagent/tails/indy_tails_server.py @@ -0,0 +1,39 @@ +"""Indy tails server interface class.""" + +import aiohttp + +from .base import BaseTailsServer +from .error import TailsServerNotConfiguredError + + +class IndyTailsServer(BaseTailsServer): + """Indy tails server interface.""" + + async def upload_tails_file( + self, context, rev_reg_id: str, tails_file_path: str + ) -> (bool, str): + """Upload tails file to tails server. 
+ + Args: + rev_reg_id: The revocation registry identifier + tails_file: The path to the tails file to upload + """ + + genesis_transactions = context.settings.get("ledger.genesis_transactions") + tails_server_base_url = context.settings.get("tails_server_base_url") + + if not tails_server_base_url: + raise TailsServerNotConfiguredError( + "tails_server_base_url setting is not set" + ) + + with open(tails_file_path, "rb") as tails_file: + async with aiohttp.ClientSession() as session: + async with session.put( + f"{tails_server_base_url}/{rev_reg_id}", + data={"genesis": genesis_transactions, "tails": tails_file}, + ) as resp: + if resp.status == 200: + return True, None + else: + return False, resp.reason diff --git a/aries_cloudagent/tails/tests/__init__.py b/aries_cloudagent/tails/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/aries_cloudagent/tails/tests/test_indy.py b/aries_cloudagent/tails/tests/test_indy.py new file mode 100644 index 0000000000..5e80811b5a --- /dev/null +++ b/aries_cloudagent/tails/tests/test_indy.py @@ -0,0 +1,84 @@ +from asynctest import TestCase as AsyncTestCase +from asynctest import mock as async_mock + +from ...config.injection_context import InjectionContext + +from .. 
import indy_tails_server as test_module + +TEST_DID = "55GkHamhTU1ZbTbV2ab9DE" +CRED_DEF_ID = f"{TEST_DID}:3:CL:1234:default" +REV_REG_ID = f"{TEST_DID}:4:{CRED_DEF_ID}:CL_ACCUM:0" + + +class TestIndyTailsServer(AsyncTestCase): + async def test_upload_no_tails_base_url_x(self): + context = InjectionContext(settings={"ledger.genesis_transactions": "dummy"}) + indy_tails = test_module.IndyTailsServer() + + with self.assertRaises(test_module.TailsServerNotConfiguredError): + await indy_tails.upload_tails_file(context, REV_REG_ID, "/tmp/dummy/path") + + async def test_upload(self): + context = InjectionContext( + settings={ + "ledger.genesis_transactions": "dummy", + "tails_server_base_url": "http://1.2.3.4:8088", + } + ) + indy_tails = test_module.IndyTailsServer() + + with async_mock.patch( + "builtins.open", async_mock.MagicMock() + ) as mock_open, async_mock.patch.object( + test_module.aiohttp, "ClientSession", async_mock.MagicMock() + ) as mock_cli_session: + mock_open.return_value = async_mock.MagicMock( + __enter__=async_mock.MagicMock() + ) + mock_cli_session.return_value = async_mock.MagicMock( + __aenter__=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + put=async_mock.MagicMock( + return_value=async_mock.MagicMock( + __aenter__=async_mock.CoroutineMock( + return_value=async_mock.MagicMock(status=200) + ) + ) + ) + ) + ) + ) + (ok, reason) = await indy_tails.upload_tails_file( + context, REV_REG_ID, "/tmp/dummy/path" + ) + assert ok + assert reason is None + + with async_mock.patch( + "builtins.open", async_mock.MagicMock() + ) as mock_open, async_mock.patch.object( + test_module.aiohttp, "ClientSession", async_mock.MagicMock() + ) as mock_cli_session: + mock_open.return_value = async_mock.MagicMock( + __enter__=async_mock.MagicMock() + ) + mock_cli_session.return_value = async_mock.MagicMock( + __aenter__=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + put=async_mock.MagicMock( + return_value=async_mock.MagicMock( + 
__aenter__=async_mock.CoroutineMock( + return_value=async_mock.MagicMock( + status=403, reason="Unauthorized" + ) + ) + ) + ) + ) + ) + ) + (ok, reason) = await indy_tails.upload_tails_file( + context, REV_REG_ID, "/tmp/dummy/path" + ) + assert not ok + assert reason == "Unauthorized" diff --git a/aries_cloudagent/transport/inbound/tests/test_http_transport.py b/aries_cloudagent/transport/inbound/tests/test_http_transport.py index 18b304e5ea..a64977011a 100644 --- a/aries_cloudagent/transport/inbound/tests/test_http_transport.py +++ b/aries_cloudagent/transport/inbound/tests/test_http_transport.py @@ -26,7 +26,7 @@ def setUp(self): self.transport.wire_format = JsonWireFormat() self.result_event = None self.response_message = None - super(TestHttpTransport, self).setUp() + super().setUp() def create_session( self, diff --git a/aries_cloudagent/transport/inbound/tests/test_ws_transport.py b/aries_cloudagent/transport/inbound/tests/test_ws_transport.py index 7a917dbf5d..82ec959531 100644 --- a/aries_cloudagent/transport/inbound/tests/test_ws_transport.py +++ b/aries_cloudagent/transport/inbound/tests/test_ws_transport.py @@ -23,7 +23,7 @@ def setUp(self): self.transport = WsTransport("0.0.0.0", self.port, self.create_session) self.transport.wire_format = JsonWireFormat() self.result_event = None - super(TestWsTransport, self).setUp() + super().setUp() def create_session( self, diff --git a/aries_cloudagent/transport/outbound/http.py b/aries_cloudagent/transport/outbound/http.py index 2bbc70f4ce..788085e220 100644 --- a/aries_cloudagent/transport/outbound/http.py +++ b/aries_cloudagent/transport/outbound/http.py @@ -19,7 +19,7 @@ class HttpTransport(BaseOutboundTransport): def __init__(self) -> None: """Initialize an `HttpTransport` instance.""" - super(HttpTransport, self).__init__() + super().__init__() self.client_session: ClientSession = None self.connector: TCPConnector = None self.logger = logging.getLogger(__name__) diff --git 
a/aries_cloudagent/transport/outbound/manager.py b/aries_cloudagent/transport/outbound/manager.py index 31a8ac501e..977817a7a6 100644 --- a/aries_cloudagent/transport/outbound/manager.py +++ b/aries_cloudagent/transport/outbound/manager.py @@ -392,6 +392,9 @@ async def _process_loop(self): if self.outbound_buffer: if (not new_pending) and (not retry_count): await self.outbound_event.wait() + elif retry_count: + # only retries - yield here so we don't hog resources + await asyncio.sleep(0.05) else: break diff --git a/aries_cloudagent/transport/outbound/tests/test_manager.py b/aries_cloudagent/transport/outbound/tests/test_manager.py index 805b041641..25c6b2ebac 100644 --- a/aries_cloudagent/transport/outbound/tests/test_manager.py +++ b/aries_cloudagent/transport/outbound/tests/test_manager.py @@ -6,6 +6,7 @@ from ....config.injection_context import InjectionContext from ....connections.models.connection_target import ConnectionTarget +from .. import manager as test_module from ..manager import ( OutboundDeliveryError, OutboundTransportManager, @@ -83,6 +84,7 @@ async def test_send_message(self): send_context = InjectionContext() mgr.enqueue_message(send_context, message) await mgr.flush() + transport.wire_format.encode_message.assert_awaited_once_with( send_context, message.payload, @@ -182,9 +184,96 @@ async def test_process_finished_x(self): mgr.finished_deliver(mock_queued, mock_task) mgr.finished_deliver(mock_queued, mock_task) + async def test_process_loop_retry_now(self): + mock_queued = async_mock.MagicMock( + state=QueuedOutboundMessage.STATE_RETRY, + retry_at=test_module.get_timer() - 1, + ) + + context = InjectionContext() + mock_handle_not_delivered = async_mock.MagicMock() + mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr.outbound_buffer.append(mock_queued) + + with async_mock.patch.object( + test_module, "trace_event", async_mock.MagicMock() + ) as mock_trace: + mock_trace.side_effect = KeyError() + with 
self.assertRaises(KeyError): # cover retry logic and bail + await mgr._process_loop() + assert mock_queued.retry_at is None + + async def test_process_loop_retry_later(self): + mock_queued = async_mock.MagicMock( + state=QueuedOutboundMessage.STATE_RETRY, + retry_at=test_module.get_timer() + 3600, + ) + + context = InjectionContext() + mock_handle_not_delivered = async_mock.MagicMock() + mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr.outbound_buffer.append(mock_queued) + + with async_mock.patch.object( + test_module.asyncio, "sleep", async_mock.CoroutineMock() + ) as mock_sleep_x: + mock_sleep_x.side_effect = KeyError() + with self.assertRaises(KeyError): # cover retry logic and bail + await mgr._process_loop() + assert mock_queued.retry_at is not None + + async def test_process_loop_new(self): + context = InjectionContext() + mock_handle_not_delivered = async_mock.MagicMock() + mgr = OutboundTransportManager(context, mock_handle_not_delivered) + + mgr.outbound_new = [ + async_mock.MagicMock( + state=test_module.QueuedOutboundMessage.STATE_NEW, + message=async_mock.MagicMock(enc_payload=b"encr"), + ) + ] + with async_mock.patch.object( + mgr, "deliver_queued_message", async_mock.MagicMock() + ) as mock_deliver, async_mock.patch.object( + mgr.outbound_event, "wait", async_mock.CoroutineMock() + ) as mock_wait, async_mock.patch.object( + test_module, "trace_event", async_mock.MagicMock() + ) as mock_trace: + mock_wait.side_effect = KeyError() # cover state=NEW logic and bail + + with self.assertRaises(KeyError): + await mgr._process_loop() + + async def test_process_loop_new_deliver(self): + context = InjectionContext() + mock_handle_not_delivered = async_mock.MagicMock() + mgr = OutboundTransportManager(context, mock_handle_not_delivered) + + mgr.outbound_new = [ + async_mock.MagicMock( + state=test_module.QueuedOutboundMessage.STATE_DELIVER, + message=async_mock.MagicMock(enc_payload=b"encr"), + ) + ] + with async_mock.patch.object( + mgr, 
"deliver_queued_message", async_mock.MagicMock() + ) as mock_deliver, async_mock.patch.object( + mgr.outbound_event, "wait", async_mock.CoroutineMock() + ) as mock_wait, async_mock.patch.object( + test_module, "trace_event", async_mock.MagicMock() + ) as mock_trace: + mock_wait.side_effect = KeyError() # cover state=DELIVER logic and bail + + with self.assertRaises(KeyError): + await mgr._process_loop() + async def test_process_loop_x(self): mock_queued = async_mock.MagicMock( - state=QueuedOutboundMessage.STATE_DONE, error=KeyError() + state=QueuedOutboundMessage.STATE_DONE, + error=KeyError(), + endpoint="http://1.2.3.4:8081", + payload="Hello world", ) context = InjectionContext() @@ -193,3 +282,25 @@ async def test_process_loop_x(self): mgr.outbound_buffer.append(mock_queued) await mgr._process_loop() + + async def test_finished_deliver_x_log_debug(self): + mock_queued = async_mock.MagicMock( + state=QueuedOutboundMessage.STATE_DONE, retries=1 + ) + mock_completed_x = async_mock.MagicMock(exc_info=KeyError("an error occurred")) + + context = InjectionContext() + mock_handle_not_delivered = async_mock.MagicMock() + mgr = OutboundTransportManager(context, mock_handle_not_delivered) + mgr.outbound_buffer.append(mock_queued) + with async_mock.patch.object( + test_module.LOGGER, "exception", async_mock.MagicMock() + ) as mock_logger_exception, async_mock.patch.object( + test_module.LOGGER, "error", async_mock.MagicMock() + ) as mock_logger_error, async_mock.patch.object( + test_module.LOGGER, "isEnabledFor", async_mock.MagicMock() + ) as mock_logger_enabled, async_mock.patch.object( + mgr, "process_queued", async_mock.MagicMock() + ) as mock_process: + mock_logger_enabled.return_value = True # cover debug logging + mgr.finished_deliver(mock_queued, mock_completed_x) diff --git a/aries_cloudagent/transport/outbound/ws.py b/aries_cloudagent/transport/outbound/ws.py index 9d532bd5ee..e39ec8b46b 100644 --- a/aries_cloudagent/transport/outbound/ws.py +++ 
b/aries_cloudagent/transport/outbound/ws.py @@ -17,7 +17,7 @@ class WsTransport(BaseOutboundTransport): def __init__(self) -> None: """Initialize an `WsTransport` instance.""" - super(WsTransport, self).__init__() + super().__init__() self.logger = logging.getLogger(__name__) async def start(self): diff --git a/aries_cloudagent/transport/queue/tests/test_basic_queue.py b/aries_cloudagent/transport/queue/tests/test_basic_queue.py index fd078bf71c..53bcf7353b 100644 --- a/aries_cloudagent/transport/queue/tests/test_basic_queue.py +++ b/aries_cloudagent/transport/queue/tests/test_basic_queue.py @@ -1,7 +1,8 @@ import asyncio -from asynctest import TestCase as AsyncTestCase +from asynctest import mock as async_mock, TestCase as AsyncTestCase +from .. import basic as test_module from ..basic import BasicMessageQueue @@ -30,6 +31,70 @@ async def test_enqueue_dequeue(self): queue.task_done() await queue.join() + async def test_dequeue_x(self): + queue = BasicMessageQueue() + test_value = "test value" + await queue.enqueue(test_value) + + with async_mock.patch.object( + test_module.asyncio, "get_event_loop", async_mock.MagicMock() + ) as mock_get_event_loop, async_mock.patch.object( + test_module.asyncio, "wait", async_mock.CoroutineMock() + ) as mock_wait: + mock_wait.return_value = ( + async_mock.MagicMock(), + [ + async_mock.MagicMock( + done=async_mock.MagicMock(), cancel=async_mock.MagicMock() + ) + ], + ) + mock_get_event_loop.return_value = async_mock.MagicMock( + create_task=async_mock.MagicMock( + side_effect=[ + async_mock.MagicMock(), # stopped + async_mock.MagicMock( # dequeued + done=async_mock.MagicMock(return_value=True), + exception=async_mock.MagicMock(return_value=KeyError()), + ), + ] + ) + ) + with self.assertRaises(KeyError): + await queue.dequeue(timeout=0) + + async def test_dequeue_none(self): + queue = BasicMessageQueue() + test_value = "test value" + await queue.enqueue(test_value) + + with async_mock.patch.object( + test_module.asyncio, 
"get_event_loop", async_mock.MagicMock() + ) as mock_get_event_loop, async_mock.patch.object( + test_module.asyncio, "wait", async_mock.CoroutineMock() + ) as mock_wait: + mock_wait.return_value = ( + async_mock.MagicMock(), + [ + async_mock.MagicMock( + done=async_mock.MagicMock(), cancel=async_mock.MagicMock() + ) + ], + ) + mock_get_event_loop.return_value = async_mock.MagicMock( + create_task=async_mock.MagicMock( + side_effect=[ + async_mock.MagicMock( # stopped + done=async_mock.MagicMock(return_value=True) + ), + async_mock.MagicMock( # dequeued + done=async_mock.MagicMock(return_value=False) + ), + ] + ) + ) + assert await queue.dequeue(timeout=0) is None + async def test_async_iter(self): queue = BasicMessageQueue() diff --git a/aries_cloudagent/transport/tests/test_stats.py b/aries_cloudagent/transport/tests/test_stats.py new file mode 100644 index 0000000000..6676092d5a --- /dev/null +++ b/aries_cloudagent/transport/tests/test_stats.py @@ -0,0 +1,22 @@ +from asynctest import TestCase as AsyncTestCase, mock as async_mock + +from ...config.injection_context import InjectionContext + +from .. 
import stats as test_module + + +class TestStatsTracer(AsyncTestCase): + def setUp(self): + self.context = async_mock.MagicMock( + socket_timer=async_mock.MagicMock( + stop=async_mock.MagicMock(side_effect=AttributeError("wrong")) + ) + ) + self.tracer = test_module.StatsTracer(test_module.Collector(), "test") + + async def test_queued_start_stop(self): + await self.tracer.connection_queued_start(None, self.context, None) + await self.tracer.connection_queued_end(None, self.context, None) + + async def test_connection_ready_error_pass(self): + await self.tracer.connection_ready(None, self.context, None) diff --git a/aries_cloudagent/utils/outofband.py b/aries_cloudagent/utils/outofband.py index 1ee815ef65..03eb38bcac 100644 --- a/aries_cloudagent/utils/outofband.py +++ b/aries_cloudagent/utils/outofband.py @@ -4,15 +4,12 @@ from urllib.parse import quote, urljoin -from ..config.injection_context import InjectionContext from ..messaging.agent_message import AgentMessage from ..wallet.base import DIDInfo -from ..wallet.util import bytes_to_b64 +from ..wallet.util import str_to_b64 -def serialize_outofband( - context: InjectionContext, message: AgentMessage, did: DIDInfo, endpoint: str -) -> str: +def serialize_outofband(message: AgentMessage, did: DIDInfo, endpoint: str) -> str: """ Serialize the agent message as an out-of-band message. @@ -27,6 +24,5 @@ def serialize_outofband( "routingKeys": [], "serviceEndpoint": endpoint, } - d_m = quote(bytes_to_b64(json.dumps(body).encode("ascii"))) - result = urljoin(endpoint, "?d_m={}".format(d_m)) - return result + d_m = quote(str_to_b64(json.dumps(body))) + return urljoin(endpoint, "?d_m={}".format(d_m)) diff --git a/aries_cloudagent/utils/task_queue.py b/aries_cloudagent/utils/task_queue.py index 04256ed3f8..b51c03a6ad 100644 --- a/aries_cloudagent/utils/task_queue.py +++ b/aries_cloudagent/utils/task_queue.py @@ -186,7 +186,8 @@ def __bool__(self) -> bool: """ Support for the bool() builtin. 
- Otherwise, evaluates as false when there are no tasks. + Return: + True - the task queue exists even if there are no tasks """ return True diff --git a/aries_cloudagent/utils/temp.py b/aries_cloudagent/utils/temp.py deleted file mode 100644 index 35ffeb8d79..0000000000 --- a/aries_cloudagent/utils/temp.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Temp file utilities.""" - -import tempfile - -TEMP_DIRS = {} - - -def get_temp_dir(category: str) -> str: - """Accessor for the temp directory.""" - if category not in TEMP_DIRS: - TEMP_DIRS[category] = tempfile.TemporaryDirectory(category) - return TEMP_DIRS[category].name diff --git a/aries_cloudagent/utils/tests/test_http.py b/aries_cloudagent/utils/tests/test_http.py index a9b486fb5d..4f5eaf3611 100644 --- a/aries_cloudagent/utils/tests/test_http.py +++ b/aries_cloudagent/utils/tests/test_http.py @@ -1,7 +1,7 @@ from aiohttp import web from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop -from ..http import fetch, FetchError +from ..http import fetch, fetch_stream, FetchError class TestTransportUtils(AioHTTPTestCase): @@ -25,6 +25,36 @@ async def succeed_route(self, request): ret = web.json_response([True]) return ret + @unittest_run_loop + async def test_fetch_stream(self): + server_addr = f"http://localhost:{self.server.port}" + stream = await fetch_stream( + f"{server_addr}/succeed", session=self.client.session + ) + result = await stream.read() + assert result == b"[true]" + assert self.succeed_calls == 1 + + @unittest_run_loop + async def test_fetch_stream_default_client(self): + server_addr = f"http://localhost:{self.server.port}" + stream = await fetch_stream(f"{server_addr}/succeed") + result = await stream.read() + assert result == b"[true]" + assert self.succeed_calls == 1 + + @unittest_run_loop + async def test_fetch_stream_fail(self): + server_addr = f"http://localhost:{self.server.port}" + with self.assertRaises(FetchError): + await fetch_stream( + f"{server_addr}/fail", + max_attempts=2, + 
interval=0, + session=self.client.session, + ) + assert self.fail_calls == 2 + @unittest_run_loop async def test_fetch(self): server_addr = f"http://localhost:{self.server.port}" @@ -34,6 +64,13 @@ async def test_fetch(self): assert result == [1] assert self.succeed_calls == 1 + @unittest_run_loop + async def test_fetch_default_client(self): + server_addr = f"http://localhost:{self.server.port}" + result = await fetch(f"{server_addr}/succeed", json=True) + assert result == [1] + assert self.succeed_calls == 1 + @unittest_run_loop async def test_fetch_fail(self): server_addr = f"http://localhost:{self.server.port}" diff --git a/aries_cloudagent/utils/tests/test_outofband.py b/aries_cloudagent/utils/tests/test_outofband.py new file mode 100644 index 0000000000..f013b81591 --- /dev/null +++ b/aries_cloudagent/utils/tests/test_outofband.py @@ -0,0 +1,23 @@ +from asynctest import mock, TestCase + +from ...messaging.agent_message import AgentMessage +from ...protocols.out_of_band.v1_0.messages.invitation import Invitation +from ...wallet.base import DIDInfo + +from .. 
import outofband as test_module + + +class TestOutOfBand(TestCase): + test_did = "55GkHamhTU1ZbTbV2ab9DE" + test_verkey = "3Dn1SJNPaCXcvvJvSbsFWP2xaCjMom3can8CQNhWrTRx" + test_did_info = DIDInfo(test_did, test_verkey, None) + + def test_serialize_oob(self): + invi = Invitation( + comment="my sister", label=u"ma sœur", service=[TestOutOfBand.test_did] + ) + + result = test_module.serialize_outofband( + invi, TestOutOfBand.test_did_info, "http://1.2.3.4:8081" + ) + assert "?d_m=" in result diff --git a/aries_cloudagent/utils/tests/test_repeat.py b/aries_cloudagent/utils/tests/test_repeat.py index a245b029f1..63ac160867 100644 --- a/aries_cloudagent/utils/tests/test_repeat.py +++ b/aries_cloudagent/utils/tests/test_repeat.py @@ -9,6 +9,13 @@ def test_iter(self): seq = test_module.RepeatSequence(5, interval=5.0, backoff=0.25) assert [round(attempt.next_interval) for attempt in seq] == expect + seq = test_module.RepeatSequence(2, interval=5.0, backoff=0.25) + attempt = seq.start() + attempt = attempt.next() + attempt.timeout(interval=0.01) + with self.assertRaises(StopIteration): + attempt.next() + async def test_aiter(self): seq = test_module.RepeatSequence(5, interval=5.0, backoff=0.25) sleeps = [0] diff --git a/aries_cloudagent/utils/tests/test_stats.py b/aries_cloudagent/utils/tests/test_stats.py index 8e56e5a61c..014b77b729 100644 --- a/aries_cloudagent/utils/tests/test_stats.py +++ b/aries_cloudagent/utils/tests/test_stats.py @@ -1,3 +1,5 @@ +from tempfile import NamedTemporaryFile + from asynctest import TestCase as AsyncTestCase from asynctest import mock as async_mock @@ -39,9 +41,12 @@ def test_mark(self): def test_wrap(self): pass + def test_wrap_again(self): + pass + instance = TestClass() - stats.wrap(instance, "test_wrap") + stats.wrap(instance, ["test_wrap", "test_wrap_again"]) instance.test() await instance.test_async() instance.test_mark() @@ -75,7 +80,8 @@ async def test_disable(self): assert stats.results["avg"] == {"test": 1.0} async def 
test_extract(self): - stats = Collector() + tmp_file = NamedTemporaryFile() + stats = Collector(log_path=tmp_file.name) stats.log("test", 1.0) stats.log("test", 2.0) diff --git a/aries_cloudagent/utils/tests/test_task_queue.py b/aries_cloudagent/utils/tests/test_task_queue.py index fc3fad77a3..4ac409eefa 100644 --- a/aries_cloudagent/utils/tests/test_task_queue.py +++ b/aries_cloudagent/utils/tests/test_task_queue.py @@ -1,5 +1,5 @@ import asyncio -from asynctest import TestCase +from asynctest import mock as async_mock, TestCase as AsyncTestCase from ..task_queue import CompletedTask, PendingTask, TaskQueue, task_exc_info @@ -10,7 +10,7 @@ async def retval(val, *, delay=0): return val -class TestTaskQueue(TestCase): +class TestTaskQueue(AsyncTestCase): async def test_run(self): queue = TaskQueue() task = None @@ -70,6 +70,7 @@ def done(complete: CompletedTask): async def test_pending(self): coro = retval(1, delay=1) pend = PendingTask(coro, None) + assert str(pend).startswith(" 0.0 + + def test_tracing_enabled(self): + invi = Invitation( + comment="no comment", label="cable guy", service=[TestTracing.test_did] + ) + assert not test_module.tracing_enabled({}, invi) + invi._trace = TraceDecorator(target="message") + assert test_module.tracing_enabled({}, invi) + + cred_ex_rec = V10CredentialExchange() + assert not test_module.tracing_enabled({}, cred_ex_rec) + cred_ex_rec = V10CredentialExchange(trace=True) + assert test_module.tracing_enabled({}, cred_ex_rec) + + dict_message = {"no": "trace"} + assert not test_module.tracing_enabled({}, dict_message) + dict_message["trace"] = True + assert test_module.tracing_enabled({}, dict_message) + dict_message["~trace"] = dict_message.pop("trace") + assert test_module.tracing_enabled({}, dict_message) + + str_message = json.dumps({"item": "I can't draw but I can trace"}) + assert not test_module.tracing_enabled({}, str_message) + str_message = json.dumps( + "Finding a ~trace as a false positive represents an outlier" + ) + 
assert test_module.tracing_enabled({}, str_message) + str_message = json.dumps({"trace": False, "content": "sample"}) + assert not test_module.tracing_enabled({}, str_message) + str_message = json.dumps({"trace": True, "content": "sample"}) + assert test_module.tracing_enabled({}, str_message) + + invi._trace = None + outbound_message = OutboundMessage(payload=invi) + assert not test_module.tracing_enabled({}, outbound_message) + invi._trace = TraceDecorator(target="message") + assert test_module.tracing_enabled({}, outbound_message) + dict_message = {"no": "trace"} + outbound_message = OutboundMessage(payload=dict_message) + assert not test_module.tracing_enabled({}, outbound_message) + dict_message["trace"] = True + assert test_module.tracing_enabled({}, outbound_message) + dict_message["~trace"] = dict_message.pop("trace") + assert test_module.tracing_enabled({}, outbound_message) + outbound_message = OutboundMessage(payload="This text does not have the T word") + assert not test_module.tracing_enabled({}, outbound_message) + outbound_message = OutboundMessage(payload=json.dumps({"trace": True})) + assert test_module.tracing_enabled({}, outbound_message) + + def test_decode_inbound_message(self): + invi = Invitation( + comment="no comment", label="cable guy", service=[TestTracing.test_did] + ) + message = OutboundMessage(payload=invi) + assert invi == test_module.decode_inbound_message(message) + + dict_message = {"a": 1, "b": 2} + message = OutboundMessage(payload=dict_message) + assert dict_message == test_module.decode_inbound_message(message) + assert dict_message == test_module.decode_inbound_message(dict_message) + + str_message = json.dumps(dict_message) + message = OutboundMessage(payload=str_message) + assert dict_message == test_module.decode_inbound_message(message) + assert dict_message == test_module.decode_inbound_message(str_message) + + x_message = "'bad json'" + message = OutboundMessage(payload=x_message) + assert message == 
test_module.decode_inbound_message(message) + assert x_message == test_module.decode_inbound_message(x_message) + def test_log_event(self): - message = Ping() - message._thread = {"thid": "dummy_thread_id_12345"} + ping = Ping() + ping._thread = {"thid": "dummy_thread_id_12345"} context = { "trace.enabled": True, "trace.target": "log", @@ -23,18 +103,33 @@ def test_log_event(self): } ret = test_module.trace_event( context, - message, + ping, handler="message_handler", perf_counter=None, outcome="processed Start", ) + test_module.trace_event( + context, ping, perf_counter=ret, outcome="processed OK", + ) + context["trace.label"] = "trace-label" + test_module.trace_event(context, ping) + ping = Ping() + test_module.trace_event(context, ping) test_module.trace_event( context, - message, - handler="message_handler", - perf_counter=ret, - outcome="processed OK", + InboundMessage(session_id="1234", payload="Hello world", receipt=None), + ) + test_module.trace_event( + context, OutboundMessage(reply_thread_id="5678", payload="Hello world") + ) + test_module.trace_event( + context, {"@type": "sample-type", "~thread": {"thid": "abcd"}} ) + test_module.trace_event(context, {"~thread": {"thid": "1234"}}) + test_module.trace_event(context, {"thread_id": "1234"}) + test_module.trace_event(context, {"@id": "12345"}) + test_module.trace_event(context, V10CredentialExchange()) + test_module.trace_event(context, []) async def test_post_event(self): message = Ping() diff --git a/aries_cloudagent/utils/tracing.py b/aries_cloudagent/utils/tracing.py index 38a94bcbbc..07463a5d06 100644 --- a/aries_cloudagent/utils/tracing.py +++ b/aries_cloudagent/utils/tracing.py @@ -6,7 +6,7 @@ import datetime import requests -from marshmallow import fields, Schema +from marshmallow import fields from ..transport.inbound.message import InboundMessage from ..transport.outbound.message import OutboundMessage @@ -17,13 +17,14 @@ TRACE_LOG_TARGET, ) from ..messaging.models.base_record import 
BaseExchangeRecord +from ..messaging.models.openapi import OpenAPISchema LOGGER = logging.getLogger(__name__) DT_FMT = "%Y-%m-%d %H:%M:%S.%f%z" -class AdminAPIMessageTracingSchema(Schema): +class AdminAPIMessageTracingSchema(OpenAPISchema): """ Request/result schema including agent message tracing. @@ -77,7 +78,7 @@ def tracing_enabled(context, message) -> bool: if message.payload.get("~trace") or message.payload.get("trace"): return True elif message.payload and isinstance(message.payload, str): - if "~trace" in message.payload or "trace" in message.payload: + if "trace" in message.payload: # includes "~trace" in message.payload return True # default off @@ -126,7 +127,7 @@ def trace_event( ("log", "message" or an http endpoint) context["trace.tag"]: Tag to be included in trace output message: the current message, can be an AgentMessage, - InboundMessage, OutboundMessage or Exchange record + InboundMessage, OutboundMessage or Exchange record event: Dict that will be converted to json and posted to the target """ diff --git a/aries_cloudagent/version.py b/aries_cloudagent/version.py index e7eafcf35c..2e82f5b4cf 100644 --- a/aries_cloudagent/version.py +++ b/aries_cloudagent/version.py @@ -1,3 +1,3 @@ """Library version information.""" -__version__ = "0.5.2" +__version__ = "0.5.3" diff --git a/aries_cloudagent/wallet/base.py b/aries_cloudagent/wallet/base.py index ac23724bba..8ee77e4f26 100644 --- a/aries_cloudagent/wallet/base.py +++ b/aries_cloudagent/wallet/base.py @@ -5,10 +5,9 @@ from typing import Sequence from ..ledger.base import BaseLedger - +from ..ledger.endpoint_type import EndpointType KeyInfo = namedtuple("KeyInfo", "verkey metadata") - DIDInfo = namedtuple("DIDInfo", "did verkey metadata") @@ -271,21 +270,30 @@ async def replace_local_did_metadata(self, did: str, metadata: dict): """ - async def set_did_endpoint(self, did: str, endpoint: str, ledger: BaseLedger): + async def set_did_endpoint( + self, + did: str, + endpoint: str, + ledger: BaseLedger, 
+ endpoint_type: EndpointType = None, + ): """ Update the endpoint for a DID in the wallet, send to ledger if public. Args: did: DID for which to set endpoint endpoint: the endpoint to set, None to clear - ledger: the ledger to which to send endpoint update - if DID is public - specify None for basic wallet - + ledger: the ledger to which to send endpoint update if DID is public + endpoint_type: the type of the endpoint/service. Only endpoint_type + 'endpoint' affects local wallet """ did_info = await self.get_local_did(did) metadata = {**did_info.metadata} - metadata.pop("endpoint", None) - metadata["endpoint"] = endpoint + if not endpoint_type: + endpoint_type = EndpointType.ENDPOINT + if endpoint_type == EndpointType.ENDPOINT: + metadata.pop("endpoint", None) + metadata["endpoint"] = endpoint await self.replace_local_did_metadata(did, metadata) diff --git a/aries_cloudagent/wallet/basic.py b/aries_cloudagent/wallet/basic.py index f2f59f4219..16db12d776 100644 --- a/aries_cloudagent/wallet/basic.py +++ b/aries_cloudagent/wallet/basic.py @@ -32,7 +32,7 @@ def __init__(self, config: dict = None): """ if not config: config = {} - super(BasicWallet, self).__init__(config) + super().__init__(config) self._name = config.get("name") self._keys = {} self._local_dids = {} diff --git a/aries_cloudagent/wallet/crypto.py b/aries_cloudagent/wallet/crypto.py index 35bd2f75a6..bdf1a49271 100644 --- a/aries_cloudagent/wallet/crypto.py +++ b/aries_cloudagent/wallet/crypto.py @@ -1,14 +1,16 @@ """Cryptography functions used by BasicWallet.""" -from collections import OrderedDict import json + +from collections import OrderedDict from typing import Callable, Optional, Sequence, Tuple -from marshmallow import fields, Schema, ValidationError import nacl.bindings import nacl.exceptions import nacl.utils +from marshmallow import fields, Schema, ValidationError + from .error import WalletError from .util import bytes_to_b58, bytes_to_b64, b64_to_bytes, b58_to_bytes diff --git 
a/aries_cloudagent/wallet/indy.py b/aries_cloudagent/wallet/indy.py index 0edc5c281b..2ba209f6fc 100644 --- a/aries_cloudagent/wallet/indy.py +++ b/aries_cloudagent/wallet/indy.py @@ -2,16 +2,19 @@ import json import logging + from typing import Sequence import indy.anoncreds import indy.did import indy.crypto import indy.wallet + from indy.error import IndyError, ErrorCode from ..indy.error import IndyErrorHandler from ..ledger.base import BaseLedger +from ..ledger.endpoint_type import EndpointType from ..ledger.error import LedgerConfigError from .base import BaseWallet, KeyInfo, DIDInfo @@ -549,7 +552,13 @@ async def replace_local_did_metadata(self, did: str, metadata: dict): await self.get_local_did(did) # throw exception if undefined await indy.did.set_did_metadata(self.handle, did, meta_json) - async def set_did_endpoint(self, did: str, endpoint: str, ledger: BaseLedger): + async def set_did_endpoint( + self, + did: str, + endpoint: str, + ledger: BaseLedger, + endpoint_type: EndpointType = None, + ): """ Update the endpoint for a DID in the wallet, send to ledger if public. @@ -557,12 +566,15 @@ async def set_did_endpoint(self, did: str, endpoint: str, ledger: BaseLedger): did: DID for which to set endpoint endpoint: the endpoint to set, None to clear ledger: the ledger to which to send endpoint update if DID is public - + endpoint_type: the type of the endpoint/service. 
Only endpoint_type + 'endpoint' affects local wallet """ did_info = await self.get_local_did(did) metadata = {**did_info.metadata} - metadata.pop("endpoint", None) - metadata["endpoint"] = endpoint + if not endpoint_type: + endpoint_type = EndpointType.ENDPOINT + if endpoint_type == EndpointType.ENDPOINT: + metadata[endpoint_type.indy] = endpoint wallet_public_didinfo = await self.get_public_did() if wallet_public_didinfo and wallet_public_didinfo.did == did: @@ -572,7 +584,7 @@ async def set_did_endpoint(self, did: str, endpoint: str, ledger: BaseLedger): f"No ledger available but DID {did} is public: missing wallet-type?" ) async with ledger: - await ledger.update_endpoint_for_did(did, endpoint) + await ledger.update_endpoint_for_did(did, endpoint, endpoint_type) await self.replace_local_did_metadata(did, metadata) @@ -700,54 +712,6 @@ async def unpack_message(self, enc_message: bytes) -> (str, str, str): from_verkey = unpacked.get("sender_verkey", None) return message, from_verkey, to_verkey - ''' - async def get_credential_definition_tag_policy(self, credential_definition_id: str): - """Return the tag policy for a given credential definition ID.""" - try: - policy_json = await indy.anoncreds.prover_get_credential_attr_tag_policy( - self.handle, credential_definition_id - ) - except IndyError as x_indy: - raise IndyErrorHandler.wrap_error( - x_indy, "Wallet {} error".format(self.name), WalletError - ) from x_indy - - return json.loads(policy_json) if policy_json else None - - async def set_credential_definition_tag_policy( - self, - credential_definition_id: str, - taggables: Sequence[str] = None, - retroactive: bool = True, - ): - """ - Set the tag policy for a given credential definition ID. 
- - Args: - credential_definition_id: The ID of the credential definition - taggables: A sequence of string values representing attribute names; - empty array for none, None for all - retroactive: Whether to apply the policy to previously-stored credentials - """ - - self.logger.info( - "%s tagging policy: %s", - "Clear" if taggables is None else "Set", - credential_definition_id, - ) - try: - await indy.anoncreds.prover_set_credential_attr_tag_policy( - self.handle, - credential_definition_id, - json.dumps(taggables), - retroactive, - ) - except IndyError as x_indy: - raise IndyErrorHandler.wrap_error( - x_indy, "Wallet {} error".format(self.name), WalletError - ) from x_indy - ''' - @classmethod async def generate_wallet_key(self, seed: str = None) -> str: """Generate a raw Indy wallet key.""" diff --git a/aries_cloudagent/wallet/routes.py b/aries_cloudagent/wallet/routes.py index 135b16b747..19bd1b8a5b 100644 --- a/aries_cloudagent/wallet/routes.py +++ b/aries_cloudagent/wallet/routes.py @@ -11,17 +11,25 @@ response_schema, ) -from marshmallow import fields, Schema +from marshmallow import fields from ..ledger.base import BaseLedger +from ..ledger.endpoint_type import EndpointType from ..ledger.error import LedgerConfigError, LedgerError -from ..messaging.valid import ENDPOINT, INDY_CRED_DEF_ID, INDY_DID, INDY_RAW_PUBLIC_KEY +from ..messaging.models.openapi import OpenAPISchema +from ..messaging.valid import ( + ENDPOINT, + ENDPOINT_TYPE, + INDY_CRED_DEF_ID, + INDY_DID, + INDY_RAW_PUBLIC_KEY, +) from .base import DIDInfo, BaseWallet from .error import WalletError, WalletNotFoundError -class DIDSchema(Schema): +class DIDSchema(OpenAPISchema): """Result schema for a DID.""" did = fields.Str(description="DID of interest", **INDY_DID) @@ -29,19 +37,36 @@ class DIDSchema(Schema): public = fields.Boolean(description="Whether DID is public", example=False) -class DIDResultSchema(Schema): +class DIDResultSchema(OpenAPISchema): """Result schema for a DID.""" result = 
fields.Nested(DIDSchema()) -class DIDListSchema(Schema): +class DIDListSchema(OpenAPISchema): """Result schema for connection list.""" results = fields.List(fields.Nested(DIDSchema()), description="DID list") -class DIDEndpointSchema(Schema): +class DIDEndpointWithTypeSchema(OpenAPISchema): + """Request schema to set DID endpoint of particular type.""" + + did = fields.Str(description="DID of interest", required=True, **INDY_DID) + endpoint = fields.Str( + description="Endpoint to set (omit to delete)", required=False, **ENDPOINT + ) + endpoint_type = fields.Str( + description=( + f"Endpoint type to set (default '{EndpointType.ENDPOINT.w3c}'); " + "affects only public DIDs" + ), + required=False, + **ENDPOINT_TYPE, + ) + + +class DIDEndpointSchema(OpenAPISchema): """Request schema to set DID endpoint; response schema to get DID endpoint.""" did = fields.Str(description="DID of interest", required=True, **INDY_DID) @@ -50,7 +75,7 @@ class DIDEndpointSchema(Schema): ) -class DIDListQueryStringSchema(Schema): +class DIDListQueryStringSchema(OpenAPISchema): """Parameters and validators for DID list request query string.""" did = fields.Str(description="DID of interest", required=False, **INDY_DID) @@ -59,16 +84,16 @@ class DIDListQueryStringSchema(Schema): required=False, **INDY_RAW_PUBLIC_KEY, ) - public = fields.Boolean(description="Whether DID is on the ledger", required=False) + public = fields.Boolean(description="Whether DID is public", required=False) -class DIDQueryStringSchema(Schema): +class DIDQueryStringSchema(OpenAPISchema): """Parameters and validators for set public DID request query string.""" did = fields.Str(description="DID of interest", required=True, **INDY_DID) -class CredDefIdMatchInfoSchema(Schema): +class CredDefIdMatchInfoSchema(OpenAPISchema): """Path parameters and validators for request taking credential definition id.""" cred_def_id = fields.Str( @@ -82,7 +107,7 @@ def format_did_info(info: DIDInfo): return { "did": info.did, "verkey": 
info.verkey, - "public": json.dumps(bool(info.metadata.get("public"))), + "public": bool(info.metadata.get("public")), } @@ -251,7 +276,7 @@ async def wallet_set_public_did(request: web.BaseRequest): @docs(tags=["wallet"], summary="Update endpoint in wallet and, if public, on ledger") -@request_schema(DIDEndpointSchema) +@request_schema(DIDEndpointWithTypeSchema) async def wallet_set_did_endpoint(request: web.BaseRequest): """ Request handler for setting an endpoint for a public or local DID. @@ -267,10 +292,13 @@ async def wallet_set_did_endpoint(request: web.BaseRequest): body = await request.json() did = body["did"] endpoint = body.get("endpoint") + endpoint_type = EndpointType.get( + body.get("endpoint_type", EndpointType.ENDPOINT.w3c) + ) try: ledger: BaseLedger = await context.inject(BaseLedger, required=False) - await wallet.set_did_endpoint(did, endpoint, ledger) + await wallet.set_did_endpoint(did, endpoint, ledger, endpoint_type) except WalletNotFoundError as err: raise web.HTTPNotFound(reason=err.roll_up) from err except LedgerConfigError as err: diff --git a/aries_cloudagent/wallet/tests/test_provider.py b/aries_cloudagent/wallet/tests/test_provider.py index 8067b1adb0..850e916637 100644 --- a/aries_cloudagent/wallet/tests/test_provider.py +++ b/aries_cloudagent/wallet/tests/test_provider.py @@ -25,6 +25,7 @@ async def test_provide_basic(self): assert wallet.name == "name" await wallet.close() + @pytest.mark.indy async def test_provide_indy(self): provider = test_module.WalletProvider() settings = Settings( diff --git a/aries_cloudagent/wallet/tests/test_routes.py b/aries_cloudagent/wallet/tests/test_routes.py index abbbc88d22..4192a22e8f 100644 --- a/aries_cloudagent/wallet/tests/test_routes.py +++ b/aries_cloudagent/wallet/tests/test_routes.py @@ -52,11 +52,11 @@ def test_format_did_info(self): assert ( result["did"] == self.test_did and result["verkey"] == self.test_verkey - and result["public"] == "false" + and result["public"] is False ) did_info = 
DIDInfo(self.test_did, self.test_verkey, {"public": True}) result = test_module.format_did_info(did_info) - assert result["public"] == "true" + assert result["public"] is True async def test_create_did(self): request = async_mock.MagicMock() diff --git a/aries_cloudagent/wallet/tests/test_util.py b/aries_cloudagent/wallet/tests/test_util.py index 827cef78e5..a032a44671 100644 --- a/aries_cloudagent/wallet/tests/test_util.py +++ b/aries_cloudagent/wallet/tests/test_util.py @@ -11,6 +11,8 @@ str_to_b64, set_urlsafe_b64, unpad, + naked_to_did_key, + did_key_to_naked, ) @@ -62,3 +64,15 @@ def test_pad(self): def test_b58(self): b58 = bytes_to_b58(BYTES) assert b58_to_bytes(b58) == BYTES + + def test_naked_to_did_key(self): + assert ( + naked_to_did_key("8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K") + == "did:key:z6MkmjY8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th" + ) + + def test_did_key_to_naked(self): + assert ( + did_key_to_naked("did:key:z6MkmjY8GnV5i9YTDtPETC2uUAW6ejw3nk5mXF5yci5ab7th") + == "8HH5gYEeNc3z7PYXmd54d4x6qAfCNrqQqEB3nS7Zfu7K" + ) diff --git a/aries_cloudagent/wallet/util.py b/aries_cloudagent/wallet/util.py index 808391172c..7987bbaf51 100644 --- a/aries_cloudagent/wallet/util.py +++ b/aries_cloudagent/wallet/util.py @@ -3,6 +3,8 @@ import base58 import base64 +from multicodec import add_prefix, remove_prefix + def pad(val: str) -> str: """Pad base64 values if need be: JWT calls to omit trailing padding.""" @@ -57,3 +59,19 @@ def b58_to_bytes(val: str) -> bytes: def bytes_to_b58(val: bytes) -> str: """Convert a byte string to base 58.""" return base58.b58encode(val).decode("ascii") + + +def naked_to_did_key(key: str) -> str: + """Convert a naked ed25519 verkey to did:key format.""" + key_bytes = b58_to_bytes(key) + prefixed_key_bytes = add_prefix("ed25519-pub", key_bytes) + did_key = f"did:key:z{bytes_to_b58(prefixed_key_bytes)}" + return did_key + + +def did_key_to_naked(did_key: str) -> str: + """Convert a did:key to naked ed25519 verkey format.""" + 
stripped_key = did_key.split("did:key:z").pop() + stripped_key_bytes = b58_to_bytes(stripped_key) + naked_key_bytes = remove_prefix(stripped_key_bytes) + return bytes_to_b58(naked_key_bytes) diff --git a/demo/AliceGetsAPhone.md b/demo/AliceGetsAPhone.md index ae7e29bb4b..10dea1a1b3 100644 --- a/demo/AliceGetsAPhone.md +++ b/demo/AliceGetsAPhone.md @@ -21,19 +21,16 @@ This demo also introduces revocation of credentials. - [Revoke the Credential and Send Another Proof Request](#revoke-the-credential-and-send-another-proof-request) - [Conclusion](#conclusion) - ## Getting Started -This demo can be run on your local machine or on Play with Docker (PWD), and will demonstrate credential exchange and proof exchange as well as revocation with a mobile agent. Both approaches (running locally and on PWD) will be described, for the most part the commands are the same, but there are a couple of different parameters you need to provide when starting up. +This demo can be run on your local machine or on Play with Docker (PWD), and will demonstrate credential exchange and proof exchange as well as revocation with a mobile agent. Both approaches (running locally and on PWD) will be described, for the most part the commands are the same, but there are a couple of different parameters you need to provide when starting up. If you are not familiar with how revocation is currently implemented in Hyperledger Indy, [this article](https://github.com/hyperledger/indy-hipe/tree/master/text/0011-cred-revocation) provides a good background on the technique. A challenge with revocation as it is currently implemented in Hyperledger Indy is the need for the prover (the agent creating the proof) to download tails files associated with the credentials it holds. - ### Get a mobile agent Of course for this, you need to have a mobile agent. 
To find, install and setup a compatible mobile agent, follow the instructions [here](https://github.com/bcgov/identity-kit-poc/blob/master/docs/GettingApp.md). - ### Running Locally in Docker Open a new bash shell and in a project directory run the following: @@ -47,7 +44,6 @@ We'll come back to this in a minute, when we start the `faber` agent! There are a couple of extra steps you need to take to prepare to run the Faber agent locally: - #### Install ngrok and jq [ngrok](https://ngrok.com/) is used to expose public endpoints for services running locally on your computer. @@ -58,10 +54,9 @@ You can install ngrok from [here](https://ngrok.com/) You can download jq releases [here](https://github.com/stedolan/jq/releases) - #### Expose services publicly using ngrok -Note that this is *only required when running docker on your local machine*. When you run on PWD a public endpoint for your agent is exposed automatically. +Note that this is _only required when running docker on your local machine_. When you run on PWD a public endpoint for your agent is exposed automatically. Since the mobile agent will need some way to communicate with the agent running on your local machine in docker, we will need to create a publicly accesible url for some services on your machine. The easiest way to do this is with [ngrok](https://ngrok.com/). Once ngrok is installed, create a tunnel to your local machine: @@ -78,13 +73,12 @@ Forwarding http://abc123.ngrok.io -> http://localhost:8020 Forwarding https://abc123.ngrok.io -> http://localhost:8020 ``` -This creates a public url for ports 8020 on your local machine. +This creates a public url for ports 8020 on your local machine. Note that an ngrok process is created automatically for your tails server. Keep this process running as we'll come back to it in a moment. 
- ### Running in Play With Docker To run the necessary terminal sessions in your browser, go to the Docker playground service [Play with Docker](https://labs.play-with-docker.com/). Don't know about Play with Docker? Check [this out](https://github.com/cloudcompass/ToIPLabs/blob/master/docs/LFS173x/RunningLabs.md#running-on-play-with-docker) to learn more. @@ -98,7 +92,6 @@ cd aries-cloudagent-python/demo We'll come back to this in a minute, when we start the `faber` agent! - ### Run an instance of indy-tails-server For revocation to function, we need another component running that is used to store what are called tails files. @@ -111,24 +104,23 @@ Open a new bash shell, and in a project directory, run: git clone https://github.com/bcgov/indy-tails-server.git cd indy-tails-server/docker ./manage build -GENESIS_URL=http://test.bcovrin.vonx.io/genesis ./manage start +./manage start ``` This will run the required components for the tails server to function and make a tails server available on port 6543. -This will also automatically start an ngrok server that will expose a public url for your tails server - this is required to support mobile agents. The docker output will look something like this: +This will also automatically start an ngrok server that will expose a public url for your tails server - this is required to support mobile agents. The docker output will look something like this: ```bash ngrok-tails-server_1 | t=2020-05-13T22:51:14+0000 lvl=info msg="started tunnel" obj=tunnels name="command_line (http)" addr=http://tails-server:6543 url=http://c5789aa0.ngrok.io ngrok-tails-server_1 | t=2020-05-13T22:51:14+0000 lvl=info msg="started tunnel" obj=tunnels name=command_line addr=http://tails-server:6543 url=https://c5789aa0.ngrok.io ``` -Note the server name in the `url=https://c5789aa0.ngrok.io` parameter (`https://c5789aa0.ngrok.io`) - this is the external url for your tails server. Make sure you use the `https` url! 
- +Note the server name in the `url=https://c5789aa0.ngrok.io` parameter (`https://c5789aa0.ngrok.io`) - this is the external url for your tails server. Make sure you use the `https` url! ### Run `faber` With Extra Parameters -If you are running in a *local bash shell*, navigate to [The demo direcory](/demo) and run: +If you are running in a _local bash shell_, navigate to [The demo direcory](/demo) and run: ```bash TAILS_NETWORK=docker_tails-server LEDGER_URL=http://test.bcovrin.vonx.io ./run_demo faber --revocation --events @@ -136,13 +128,13 @@ TAILS_NETWORK=docker_tails-server LEDGER_URL=http://test.bcovrin.vonx.io ./run_d The `TAILS_NETWORK` parameter lets the demo script know how to connect to the tails server (which should be running in a separate shell on the same machine). -If you are running in *Play with Docker*, navigate to [The demo direcory](/demo) and run: +If you are running in _Play with Docker_, navigate to [The demo direcory](/demo) and run: ```bash PUBLIC_TAILS_URL=https://def456.ngrok.io LEDGER_URL=http://test.bcovrin.vonx.io ./run_demo faber --revocation --events ``` -The `PUBLIC_TAILS_URL` parameter lets the demo script know how to connect to the tails server. This can be running in another PWD session, or even on your local machine - the ngrok endpoint is public and will map to the correct location. +The `PUBLIC_TAILS_URL` parameter lets the demo script know how to connect to the tails server. This can be running in another PWD session, or even on your local machine - the ngrok endpoint is public and will map to the correct location. Note that you _must_ use the `https` url for the tails server endpoint. @@ -162,10 +154,9 @@ As part of its startup process, the agent will publish a revocation registry to Ledger - ## Accept the Invitation -When the Faber agent starts up it automatically creates an invitation and generates a QR code on the screen. 
On your mobile app, select "SCAN CODE" (or equivalent) and point your camera at the generated QR code. The mobile agent should automatically capture the code and ask you to confirm the connection. Confirm it. +When the Faber agent starts up it automatically creates an invitation and generates a QR code on the screen. On your mobile app, select "SCAN CODE" (or equivalent) and point your camera at the generated QR code. The mobile agent should automatically capture the code and ask you to confirm the connection. Confirm it.
Click here to view screenshot @@ -185,7 +176,7 @@ The mobile agent will give you feedback on the connection process, something lik Switch your browser back to Play with Docker. You should see that the connection has been established, and there is a prompt for what actions you want to take, e.g. "Issue Credential", "Send Proof Request" and so on. -Tip: If your screen is too small to display the QR code (this can happen in Play With Docker because the shell is only given a small portion of the browser) you can copy the invitation url to a site like https://www.the-qrcode-generator.com/ to convert the invitation url into a QR code that you can scan. Make sure you select the `URL` option, and copy the `invitation_url`, which will look something like: +Tip: If your screen is too small to display the QR code (this can happen in Play With Docker because the shell is only given a small portion of the browser) you can copy the invitation url to a site like https://www.the-qrcode-generator.com/ to convert the invitation url into a QR code that you can scan. Make sure you select the `URL` option, and copy the `invitation_url`, which will look something like: ```bash https://abfde260.ngrok.io?c_i=eyJAdHlwZSI6ICJkaWQ6c292OkJ6Q2JzTlloTXJqSGlxWkRUVUFTSGc7c3BlYy9jb25uZWN0aW9ucy8xLjAvaW52aXRhdGlvbiIsICJAaWQiOiAiZjI2ZjA2YTItNWU1Mi00YTA5LWEwMDctOTNkODBiZTYyNGJlIiwgInJlY2lwaWVudEtleXMiOiBbIjlQRFE2alNXMWZwZkM5UllRWGhCc3ZBaVJrQmVKRlVhVmI0QnRQSFdWbTFXIl0sICJsYWJlbCI6ICJGYWJlci5BZ2VudCIsICJzZXJ2aWNlRW5kcG9pbnQiOiAiaHR0cHM6Ly9hYmZkZTI2MC5uZ3Jvay5pbyJ9 @@ -199,7 +190,6 @@ http://ip10-0-121-4-bquqo816b480a4bfn3kg-8020.direct.play-with-von.vonx.io?c_i=e Note that this will use the ngrok endpoint if you are running locally, or your PWD endpoint if you are running on PWD. - ## Issue a Credential We will use the Faber console to issue a credential. This could be done using the Swagger API as we have done in the connection process. We'll leave that as an exercise to the user. 
diff --git a/demo/ngrok-wait.sh b/demo/ngrok-wait.sh index 027e9de9e8..cc9100daba 100755 --- a/demo/ngrok-wait.sh +++ b/demo/ngrok-wait.sh @@ -6,17 +6,21 @@ if ! [ -z "$TAILS_NGROK_NAME" ]; then echo "ngrok tails service name [$TAILS_NGROK_NAME]" NGROK_ENDPOINT=null - while [ -z "$NGROK_ENDPOINT" ] || [ "$NGROK_ENDPOINT" = "null" ] - do - echo "Fetching endpoint from ngrok service" - NGROK_ENDPOINT=$(curl --silent $TAILS_NGROK_NAME:4040/api/tunnels | ./jq -r '.tunnels[0].public_url') - - if [ -z "$NGROK_ENDPOINT" ] || [ "$NGROK_ENDPOINT" = "null" ]; then - echo "ngrok not ready, sleeping 5 seconds...." - sleep 5 - fi - done + JQ=${JQ:-`which jq`} + if [ -x "$JQ" ]; then + while [ -z "$NGROK_ENDPOINT" ] || [ "$NGROK_ENDPOINT" = "null" ] + do + echo "Fetching endpoint from ngrok service" + NGROK_ENDPOINT=$(curl --silent $TAILS_NGROK_NAME:4040/api/tunnels | $JQ -r '.tunnels[0].public_url') + if [ -z "$NGROK_ENDPOINT" ] || [ "$NGROK_ENDPOINT" = "null" ]; then + echo "ngrok not ready, sleeping 5 seconds...." 
+ sleep 5 + fi + done + else + echo "jq not found" + fi export PUBLIC_TAILS_URL=$NGROK_ENDPOINT echo "Fetched ngrok tails server endpoint [$PUBLIC_TAILS_URL]" fi diff --git a/demo/runners/faber.py b/demo/runners/faber.py index 291ac5943e..9a3e1bd558 100644 --- a/demo/runners/faber.py +++ b/demo/runners/faber.py @@ -29,18 +29,24 @@ LOGGER = logging.getLogger(__name__) -TAILS_FILE_COUNT = int(os.getenv("TAILS_FILE_COUNT", 20)) +TAILS_FILE_COUNT = int(os.getenv("TAILS_FILE_COUNT", 100)) class FaberAgent(DemoAgent): def __init__( - self, http_port: int, admin_port: int, no_auto: bool = False, **kwargs + self, + http_port: int, + admin_port: int, + no_auto: bool = False, + tails_server_base_url: str = None, + **kwargs, ): super().__init__( "Faber.Agent", http_port, admin_port, prefix="Faber", + tails_server_base_url=tails_server_base_url, extra_args=[] if no_auto else ["--auto-accept-invites", "--auto-accept-requests"], @@ -138,6 +144,7 @@ async def main( start_port: int, no_auto: bool = False, revocation: bool = False, + tails_server_base_url: str = None, show_timing: bool = False, ): @@ -155,6 +162,7 @@ async def main( start_port + 1, genesis_data=genesis, no_auto=no_auto, + tails_server_base_url=tails_server_base_url, timing=show_timing, ) await agent.listen_webhooks(start_port + 2) @@ -184,17 +192,9 @@ async def main( version, ["name", "date", "degree", "age", "timestamp"], support_revocation=revocation, + revocation_registry_size=TAILS_FILE_COUNT, ) - if revocation: - with log_timer("Publish revocation registry duration:"): - log_status( - "#5/6 Create and publish the revocation registry on the ledger" - ) - await agent.create_and_publish_revocation_registry( - credential_definition_id, TAILS_FILE_COUNT - ) - # TODO add an additional credential for Student ID with log_timer("Generate invitation duration:"): @@ -227,11 +227,7 @@ async def main( " (3) Send Message\n" ) if revocation: - options += ( - " (4) Revoke Credential\n" - " (5) Publish Revocations\n" - " (6) Add 
Revocation Registry\n" - ) + options += " (4) Revoke Credential\n" " (5) Publish Revocations\n" options += " (T) Toggle tracing on credential/proof exchange\n" options += " (X) Exit?\n[1/2/3/{}T/X] ".format( "4/5/6/" if revocation else "" ) @@ -321,18 +317,7 @@ async def main( for req_pred in req_preds }, } - # test with an attribute group with attribute value restrictions - # indy_proof_request["requested_attributes"] = { - # "n_group_attrs": { - # "names": ["name", "degree", "timestamp", "date"], - # "restrictions": [ - # { - # "issuer_did": agent.did, - # "attr::name::value": "Alice Smith" - # } - # ] - # } - # } + if revocation: indy_proof_request["non_revoked"] = {"to": int(time.time())} proof_request_web_request = { @@ -379,11 +364,6 @@ async def main( ) except ClientError: pass - elif option == "6" and revocation: - log_status("#19 Add another revocation registry") - await agent.create_and_publish_revocation_registry( - credential_definition_id, TAILS_FILE_COUNT - ) if show_timing: timing = await agent.fetch_timing() @@ -422,6 +402,14 @@ async def main( parser.add_argument( "--revocation", action="store_true", help="Enable credential revocation" ) + + parser.add_argument( + "--tails-server-base-url", + type=str, + metavar=(""), + help="Tails server base url", + ) + parser.add_argument( "--timing", action="store_true", help="Enable timing information" ) @@ -457,9 +445,22 @@ async def main( require_indy() + tails_server_base_url = args.tails_server_base_url or os.getenv("PUBLIC_TAILS_URL") + + if args.revocation and not tails_server_base_url: + raise Exception( + "If revocation is enabled, --tails-server-base-url must be provided" + ) + try: asyncio.get_event_loop().run_until_complete( - main(args.port, args.no_auto, args.revocation, args.timing) + main( + args.port, + args.no_auto, + args.revocation, + tails_server_base_url, + args.timing, + ) ) except KeyboardInterrupt: os._exit(1) diff --git a/demo/runners/performance.py b/demo/runners/performance.py index 
4b4a5a2cdb..131e3670c9 100644 --- a/demo/runners/performance.py +++ b/demo/runners/performance.py @@ -14,6 +14,9 @@ LOGGER = logging.getLogger(__name__) +TAILS_FILE_COUNT = int(os.getenv("TAILS_FILE_COUNT", 100)) + + class BaseAgent(DemoAgent): def __init__( self, @@ -190,6 +193,7 @@ async def publish_defs(self, support_revocation: bool = False): credential_definition_body = { "schema_id": self.schema_id, "support_revocation": support_revocation, + "revocation_registry_size": TAILS_FILE_COUNT, } credential_definition_response = await self.admin_POST( "/credential-definitions", credential_definition_body @@ -200,15 +204,15 @@ async def publish_defs(self, support_revocation: bool = False): self.log(f"Credential Definition ID: {self.credential_definition_id}") # create revocation registry - if support_revocation: - revoc_body = { - "credential_definition_id": self.credential_definition_id, - } - revoc_response = await self.admin_POST( - "/revocation/create-registry", revoc_body - ) - self.revocation_registry_id = revoc_response["result"]["revoc_reg_id"] - self.log(f"Revocation Registry ID: {self.revocation_registry_id}") + # if support_revocation: + # revoc_body = { + # "credential_definition_id": self.credential_definition_id, + # } + # revoc_response = await self.admin_POST( + # "/revocation/create-registry", revoc_body + # ) + # self.revocation_registry_id = revoc_response["result"]["revoc_reg_id"] + # self.log(f"Revocation Registry ID: {self.revocation_registry_id}") async def send_credential( self, cred_attrs: dict, comment: str = None, auto_remove: bool = True @@ -245,8 +249,9 @@ async def main( ping_only: bool = False, show_timing: bool = False, routing: bool = False, + revocation: bool = False, + tails_server_base_url: str = None, issue_count: int = 300, - revoc: bool = False, ): genesis = await default_genesis_txns() @@ -264,7 +269,13 @@ async def main( alice = AliceAgent(start_port, genesis_data=genesis, timing=show_timing) await 
alice.listen_webhooks(start_port + 2) - faber = FaberAgent(start_port + 3, genesis_data=genesis, timing=show_timing) + faber = FaberAgent( + start_port + 3, + genesis_data=genesis, + timing=show_timing, + tails_server_base_url=tails_server_base_url, + ) + await faber.listen_webhooks(start_port + 5) await faber.register_did() @@ -283,7 +294,7 @@ async def main( if not ping_only: with log_timer("Publish duration:"): - await faber.publish_defs(revoc) + await faber.publish_defs(revocation) # await alice.set_tag_policy(faber.credential_definition_id, ["name"]) with log_timer("Connect duration:"): @@ -328,7 +339,7 @@ async def send_credential(index: int): "age": "24", } asyncio.ensure_future( - faber.send_credential(attributes, comment, not revoc) + faber.send_credential(attributes, comment, not revocation) ).add_done_callback(done_send) async def check_received_creds(agent, issue_count, pb): @@ -440,7 +451,7 @@ async def check_received_pings(agent, issue_count, pb): for line in faber.format_postgres_stats(): faber.log(line) - if revoc and faber.revocations: + if revocation and faber.revocations: (rev_reg_id, cred_rev_id) = next(iter(faber.revocations)) print( "Revoking and publishing cred rev id {cred_rev_id} " @@ -518,6 +529,17 @@ async def check_received_pings(agent, issue_count, pb): default=False, help="Only send ping messages between the agents", ) + parser.add_argument( + "--revocation", action="store_true", help="Enable credential revocation" + ) + + parser.add_argument( + "--tails-server-base-url", + type=str, + metavar=(""), + help="Tails server base url", + ) + parser.add_argument( "--routing", action="store_true", help="Enable inbound routing demonstration" ) @@ -543,6 +565,8 @@ async def check_received_pings(agent, issue_count, pb): args.ping, args.timing, args.routing, + args.revocation, + args.tails_server_base_url, args.count, ) ) diff --git a/demo/runners/support/agent.py b/demo/runners/support/agent.py index 76f400ed50..41586fb6e0 100644 --- 
a/demo/runners/support/agent.py +++ b/demo/runners/support/agent.py @@ -6,8 +6,6 @@ import os import random import subprocess -import hashlib -import base58 from timeit import default_timer from aiohttp import ( @@ -111,9 +109,11 @@ def __init__( label: str = None, color: str = None, prefix: str = None, + tails_server_base_url: str = None, timing: bool = False, timing_log: str = None, postgres: bool = None, + revocation: bool = False, extra_args=None, **params, ): @@ -129,6 +129,7 @@ def __init__( self.timing = timing self.timing_log = timing_log self.postgres = DEFAULT_POSTGRES if postgres is None else postgres + self.tails_server_base_url = tails_server_base_url self.extra_args = extra_args self.trace_enabled = TRACE_ENABLED self.trace_target = TRACE_TARGET @@ -143,14 +144,7 @@ def __init__( ) else: self.endpoint = f"http://{self.external_host}:{http_port}" - if os.getenv("PUBLIC_TAILS_URL"): - self.public_tails_url = os.getenv("PUBLIC_TAILS_URL") - elif RUN_MODE == "pwd": - self.public_tails_url = f"http://{self.external_host}".replace( - "{PORT}", str(admin_port) - ) - else: - self.public_tails_url = self.admin_url + self.webhook_port = None self.webhook_url = None self.webhook_site = None @@ -174,7 +168,12 @@ def __init__( self.wallet_stats = [] async def register_schema_and_creddef( - self, schema_name, version, schema_attrs, support_revocation: bool = False + self, + schema_name, + version, + schema_attrs, + support_revocation: bool = False, + revocation_registry_size: int = None, ): # Create a schema schema_body = { @@ -191,6 +190,7 @@ async def register_schema_and_creddef( credential_definition_body = { "schema_id": schema_id, "support_revocation": support_revocation, + "revocation_registry_size": revocation_registry_size, } credential_definition_response = await self.admin_POST( "/credential-definitions", credential_definition_body @@ -201,71 +201,6 @@ async def register_schema_and_creddef( log_msg("Cred def ID:", credential_definition_id) return 
schema_id, credential_definition_id - async def create_and_publish_revocation_registry( - self, credential_def_id, max_cred_num - ): - revoc_response = await self.admin_POST( - "/revocation/create-registry", - { - "credential_definition_id": credential_def_id, - "max_cred_num": max_cred_num, - }, - ) - revocation_registry_id = revoc_response["result"]["revoc_reg_id"] - tails_hash = revoc_response["result"]["tails_hash"] - - # get the tails file from "GET /revocation/registry/{id}/tails-file" - tails_file = await self.admin_GET_FILE( - f"/revocation/registry/{revocation_registry_id}/tails-file" - ) - hasher = hashlib.sha256() - hasher.update(tails_file) - my_tails_hash = base58.b58encode(hasher.digest()).decode("utf-8") - log_msg(f"Revocation Registry ID: {revocation_registry_id}") - assert tails_hash == my_tails_hash - - tails_file_url = ( - f"{self.public_tails_url}/revocation/registry/" - f"{revocation_registry_id}/tails-file" - ) - if os.getenv("PUBLIC_TAILS_URL"): - tails_file_url = f"{self.public_tails_url}/{revocation_registry_id}" - tails_file_external_url = ( - f"{self.public_tails_url}/{revocation_registry_id}" - ) - elif RUN_MODE == "pwd": - tails_file_external_url = f"http://{self.external_host}".replace( - "{PORT}", str(self.admin_port) - ) - else: - tails_file_external_url = f"http://127.0.0.1:{self.admin_port}" - tails_file_external_url += ( - f"/revocation/registry/{revocation_registry_id}/tails-file" - ) - - revoc_updated_response = await self.admin_PATCH( - f"/revocation/registry/{revocation_registry_id}", - {"tails_public_uri": tails_file_url}, - ) - tails_public_uri = revoc_updated_response["result"]["tails_public_uri"] - assert tails_public_uri == tails_file_url - - revoc_publish_response = await self.admin_POST( - f"/revocation/registry/{revocation_registry_id}/publish" - ) - - # if PUBLIC_TAILS_URL is specified, upload tails file to tails server - if os.getenv("PUBLIC_TAILS_URL"): - tails_server_hash = await self.admin_PUT_FILE( - {"genesis": 
await default_genesis_txns(), "tails": tails_file}, - tails_file_url, - params=None, - ) - assert my_tails_hash == tails_server_hash.decode("utf-8") - log_msg(f"Public tails file URL: {tails_file_url}") - - return revoc_publish_response["result"]["revoc_reg_id"] - def get_agent_args(self): result = [ ("--endpoint", self.endpoint), @@ -310,6 +245,9 @@ def get_agent_args(self): ("--trace-label", self.label + ".trace"), ] ) + + if self.tails_server_base_url: + result.append(("--tails-server-base-url", self.tails_server_base_url)) else: # set the tracing parameters but don't enable tracing result.extend( @@ -482,6 +420,11 @@ async def handle_problem_report(self, message): f"Received problem report: {message['explain-ltxt']}\n", source="stderr" ) + async def handle_revocation_registry(self, message): + self.log( + f"Revocation registry: {message['record_id']} state: {message['state']}" + ) + async def admin_request( self, method, path, data=None, text=False, params=None ) -> ClientResponse: diff --git a/docs/GettingStartedAriesDev/CredentialRevocation.md b/docs/GettingStartedAriesDev/CredentialRevocation.md index 86d923f259..44cabbb04e 100644 --- a/docs/GettingStartedAriesDev/CredentialRevocation.md +++ b/docs/GettingStartedAriesDev/CredentialRevocation.md @@ -1,11 +1,23 @@ These are the ACA-py steps and APIs involved to support credential revocation. -0. Publish credential definition +Run ACA-Py with tails server support enabled. You will need to have the URL of a running instance of https://github.com/bcgov/indy-tails-server. + +Include the command line parameter `--tails-server-base-url ` +0. Publish credential definition + + Credential definition is created. All required revocation collateral is also created + and managed including revocation registry definition, entry, and tails file. + ``` POST /credential-definitions { "schema_id": schema_id, - "support_revocation": true + "support_revocation": true, + # Only needed if support_revocation is true. 
Defaults to 100 + "revocation_registry_size": size_int, + "tag": cred_def_tag # Optional + } Response: { @@ -13,44 +25,11 @@ These are the ACA-py steps and APIs involved to support credential revocation. } ``` -0. Create (but not publish yet) Revocation registry - ``` - POST /revocation/create-registry, - { - "credential_definition_id": "credential_definition_id", - "max_cred_num": size_of_revocation_registry - } - Response: - { - "revoc_reg_id": "revocation_registry_id", - "tails_hash": hash_of_tails_file, - "cred_def_id": "credential_definition_id", - ... - } - ``` +1. Issue credential -0. Get the tail file from agent - ``` - Get /revocation/registry/{revocation_registry_id}/tails-file - - Response: stream down a binary file: - content-type: application/octet-stream - ... - ``` -0. Upload the tails file to a publicly accessible location -0. Update the tails file public URI to agent - ``` - PATCH /revocation/registry/{revocation_registry_id} - { - "tails_public_uri": - } - ``` -0. Publish the revocation registry and first entry to the ledger - ``` - POST /revocation/registry/{revocation_registry_id}/publish - ``` + This endpoint manages revocation data. If new revocation registry data is required, + it is automatically managed in the background. -0. Issue credential ``` POST /issue-credential/send-offer { @@ -62,51 +41,55 @@ These are the ACA-py steps and APIs involved to support credential revocation. Response { "credential_exchange_id": credential_exchange_id - } - ``` -0. Revoking credential + } + ``` + +2. Revoking credential + ``` POST /issue-credential/revoke?rev_reg_id= &cred_rev_id=&publish= ``` -0. When asking for proof, specify the timespan when the credential is NOT revoked - ``` - POST /present-proof/send-request - { - "connection_id": ..., - "proof_request": { - "requested_attributes": [ - { - "name": ... - "restrictions": ..., - ... 
- "non_revoked": # Optional, override the global one when specified - { - "from": # Optional, default is 0 - "to": - } - }, - ... - ], - "requested_predicates": [ - { - "name": ... - ... - "non_revoked": # Optional, override the global one when specified - { - "from": # Optional, default is 0 - "to": - } - }, - ... - ], - "non_revoked": # Optional, only check revocation if specified - { - "from": # Optional, default is 0 - "to": - } - } - } + If publish=false, you must use `​/issue-credential​/publish-revocations` to publish + pending revocations in batches. Revocation are not written to ledger until this is called. + +3. When asking for proof, specify the timespan when the credential is NOT revoked + ``` + POST /present-proof/send-request + { + "connection_id": ..., + "proof_request": { + "requested_attributes": [ + { + "name": ... + "restrictions": ..., + ... + "non_revoked": # Optional, override the global one when specified + { + "from": # Optional, default is 0 + "to": + } + }, + ... + ], + "requested_predicates": [ + { + "name": ... + ... + "non_revoked": # Optional, override the global one when specified + { + "from": # Optional, default is 0 + "to": + } + }, + ... + ], + "non_revoked": # Optional, only check revocation if specified + { + "from": # Optional, default is 0 + "to": + } + } + } ``` - \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b5f5c6e37d..ea89d1c2fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ msgpack~=0.6.1 prompt_toolkit~=2.0.9 pynacl~=1.3.0 requests~=2.23.0 -pyld==2.0.1 \ No newline at end of file +pyld==2.0.1 +py_multicodec==0.2.1 \ No newline at end of file