diff --git a/aries_cloudagent/config/argparse.py b/aries_cloudagent/config/argparse.py index 8132ad51e2..1f676f9dd2 100644 --- a/aries_cloudagent/config/argparse.py +++ b/aries_cloudagent/config/argparse.py @@ -7,6 +7,7 @@ from typing import Type from .error import ArgsParseError +from .util import ByteSize CAT_PROVISION = "general" CAT_START = "start" @@ -574,7 +575,6 @@ def add_arguments(self, parser: ArgumentParser): be specified multiple times to create multiple interfaces.\ Supported inbound transport types are 'http' and 'ws'.", ) - parser.add_argument( "-ot", "--outbound-transport", @@ -588,7 +588,6 @@ def add_arguments(self, parser: ArgumentParser): multiple times to supoort multiple transport types. Supported outbound\ transport types are 'http' and 'ws'.", ) - parser.add_argument( "-e", "--endpoint", @@ -605,7 +604,6 @@ def add_arguments(self, parser: ArgumentParser): The endpoints are used in the formation of a connection \ with another agent.", ) - parser.add_argument( "-l", "--label", @@ -614,6 +612,13 @@ def add_arguments(self, parser: ArgumentParser): help="Specifies the label for this agent. This label is publicized\ (self-attested) to other agents as part of forming a connection.", ) + parser.add_argument( + "--max-message-size", + default=2097152, + type=ByteSize(min_size=1024), + metavar="", + help="Set the maximum size in bytes for inbound agent messages.", + ) parser.add_argument( "--enable-undelivered-queue", @@ -635,6 +640,9 @@ def get_settings(self, args: Namespace): settings["additional_endpoints"] = args.endpoint[1:] if args.label: settings["default_label"] = args.label + if args.max_message_size: + settings["transport.max_message_size"] = args.max_message_size + return settings diff --git a/aries_cloudagent/config/tests/test_argparse.py b/aries_cloudagent/config/tests/test_argparse.py index bb64b915f2..ae6dd85251 100644 --- a/aries_cloudagent/config/tests/test_argparse.py +++ b/aries_cloudagent/config/tests/test_argparse.py @@ -1,9 +1,10 @@ import itertools -from argparse import ArgumentParser +from argparse import ArgumentParser, ArgumentTypeError from asynctest import TestCase as AsyncTestCase, mock as async_mock from .. import argparse +from ..util import ByteSize class TestArgParse(AsyncTestCase): @@ -41,13 +42,16 @@ async def test_transport_settings(self): "http", "-e", "http://default.endpoint/", - "ws://alternate.endpoint/" + "ws://alternate.endpoint/", ] ) assert result.inbound_transports == [["http", "0.0.0.0", "80"]] assert result.outbound_transports == ["http"] - assert result.endpoint == ["http://default.endpoint/", "ws://alternate.endpoint/"] + assert result.endpoint == [ + "http://default.endpoint/", + "ws://alternate.endpoint/", + ] settings = group.get_settings(result) @@ -55,3 +59,34 @@ async def test_transport_settings(self): assert settings.get("transport.outbound_configs") == ["http"] assert settings.get("default_endpoint") == "http://default.endpoint/" assert settings.get("additional_endpoints") == ["ws://alternate.endpoint/"] + + def test_bytesize(self): + bs = ByteSize() + with self.assertRaises(ArgumentTypeError): + bs(None) + with self.assertRaises(ArgumentTypeError): + bs("") + with self.assertRaises(ArgumentTypeError): + bs("a") + with self.assertRaises(ArgumentTypeError): + bs("1.5") + with self.assertRaises(ArgumentTypeError): + bs("-1") + assert bs("101") == 101 + assert bs("101b") == 101 + assert bs("101KB") == 103424 + assert bs("2M") == 2097152 + assert bs("1G") == 1073741824 + assert bs("1t") == 1099511627776 + + bs = ByteSize(min_size=10) + with self.assertRaises(ArgumentTypeError): + bs("5") + assert bs("12") == 12 + + bs = ByteSize(max_size=10) + with self.assertRaises(ArgumentTypeError): + bs("15") + assert bs("10") == 10 + + assert repr(bs) == "ByteSize" diff --git a/aries_cloudagent/config/util.py b/aries_cloudagent/config/util.py index 5fdfaebe86..a8ad9f9933 100644 --- a/aries_cloudagent/config/util.py +++ b/aries_cloudagent/config/util.py @@ -1,6 +1,9 @@ """Entrypoint.""" import os +import re + +from argparse import ArgumentTypeError from typing import Any, Mapping from .logging import LoggingConfigurator @@ -21,3 +24,43 @@ def common_config(settings: Mapping[str, Any]): and settings.get("wallet.storage_type") == "postgres_storage" ): load_postgres_plugin() + + +class ByteSize: + """Argument value parser for byte sizes.""" + + def __init__(self, min_size: int = 0, max_size: int = 0): + """Initialize the ByteSize parser.""" + self.min_size = min_size + self.max_size = max_size + + def __call__(self, arg: str) -> int: + """Interpret the argument value.""" + if not arg: + raise ArgumentTypeError("Expected value") + parts = re.match(r"^(\d+)([kKmMgGtT]?)[bB]?$", arg) + if not parts: + raise ArgumentTypeError("Invalid format") + size = int(parts[1]) + suffix = parts[2].upper() + if suffix == "K": + size = size << 10 + elif suffix == "M": + size = size << 20 + elif suffix == "G": + size = size << 30 + elif suffix == "T": + size = size << 40 + if size < self.min_size: + raise ArgumentTypeError( + f"Size must be greater than or equal to {self.min_size}" + ) + if self.max_size and size > self.max_size: + raise ArgumentTypeError( + f"Size must be less than or equal to {self.max_size}" + ) + return size + + def __repr__(self): + """Format for in error reporting.""" + return self.__class__.__name__ diff --git a/aries_cloudagent/transport/inbound/base.py b/aries_cloudagent/transport/inbound/base.py index a2c5ea52ec..b3c87a4196 100644 --- a/aries_cloudagent/transport/inbound/base.py +++ b/aries_cloudagent/transport/inbound/base.py @@ -14,12 +14,30 @@ class BaseInboundTransport(ABC): """Base inbound transport class.""" def __init__( - self, scheme: str, create_session: Callable, + self, + scheme: str, + create_session: Callable, + *, + max_message_size: int = 0, + wire_format: BaseWireFormat = None, ): - """Initialize the inbound transport instance.""" + """ + Initialize the inbound transport instance. + + Args: + scheme: The transport scheme identifier + create_session: Method to create a new inbound session + """ + self._create_session = create_session + self._max_message_size = max_message_size self._scheme = scheme - self.wire_format: BaseWireFormat = None + self.wire_format: BaseWireFormat = wire_format + + @property + def max_message_size(self): + """Accessor for this transport's max message size.""" + return self._max_message_size @property def scheme(self): diff --git a/aries_cloudagent/transport/inbound/http.py b/aries_cloudagent/transport/inbound/http.py index 73730aca5f..823dc70f5c 100644 --- a/aries_cloudagent/transport/inbound/http.py +++ b/aries_cloudagent/transport/inbound/http.py @@ -14,9 +14,9 @@ class HttpTransport(BaseInboundTransport): """Http Transport class.""" - def __init__(self, host: str, port: int, create_session) -> None: + def __init__(self, host: str, port: int, create_session, **kwargs) -> None: """ - Initialize a Transport instance. + Initialize an inbound HTTP transport instance. Args: host: Host to listen on @@ -24,14 +24,17 @@ def __init__(self, host: str, port: int, create_session) -> None: create_session: Method to create a new inbound session """ - super().__init__("http", create_session) + super().__init__("http", create_session, **kwargs) self.host = host self.port = port self.site: web.BaseSite = None async def make_application(self) -> web.Application: """Construct the aiohttp application.""" - app = web.Application() + app_args = {} + if self.max_message_size: + app_args["client_max_size"] = self.max_message_size + app = web.Application(**app_args) app.add_routes([web.get("/", self.invite_message_handler)]) app.add_routes([web.post("/", self.inbound_message_handler)]) return app diff --git a/aries_cloudagent/transport/inbound/manager.py b/aries_cloudagent/transport/inbound/manager.py index 4a48e8609d..ce32bc8968 100644 --- a/aries_cloudagent/transport/inbound/manager.py +++ b/aries_cloudagent/transport/inbound/manager.py @@ -37,6 +37,7 @@ def __init__( ): """Initialize an `InboundTransportManager` instance.""" self.context = context + self.max_message_size = 0 self.receive_inbound = receive_inbound self.return_inbound = return_inbound self.registered_transports = {} @@ -48,6 +49,10 @@ def __init__( async def setup(self): """Perform setup operations.""" + # Load config settings + if self.context.settings.get("transport.max_message_size"): + self.max_message_size = self.context.settings["transport.max_message_size"] + inbound_transports = ( self.context.settings.get("transport.inbound_configs") or [] ) @@ -81,7 +86,12 @@ def register(self, config: InboundTransportConfiguration) -> str: ) from e return self.register_transport( - imported_class(config.host, config.port, self.create_session), + imported_class( + config.host, + config.port, + self.create_session, + max_message_size=self.max_message_size, + ), imported_class.__qualname__, ) diff --git a/aries_cloudagent/transport/inbound/ws.py b/aries_cloudagent/transport/inbound/ws.py index d720ce733e..13de94f6c0 100644 --- a/aries_cloudagent/transport/inbound/ws.py +++ b/aries_cloudagent/transport/inbound/ws.py @@ -15,9 +15,9 @@ class WsTransport(BaseInboundTransport): """Websockets Transport class.""" - def __init__(self, host: str, port: int, create_session) -> None: + def __init__(self, host: str, port: int, create_session, **kwargs) -> None: """ - Initialize a Transport instance. + Initialize an inbound WebSocket transport instance. Args: host: Host to listen on @@ -25,7 +25,7 @@ def __init__(self, host: str, port: int, create_session) -> None: create_session: Method to create a new inbound session """ - super().__init__("ws", create_session) + super().__init__("ws", create_session, **kwargs) self.host = host self.port = port self.site: web.BaseSite = None