Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: OOB - Handling of minor versions #1940

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions aries_cloudagent/core/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
import warnings

from typing import Callable, Coroutine, Union
from typing import Callable, Coroutine, Optional, Union, Tuple
import weakref

from aiohttp.web import HTTPException
Expand All @@ -36,6 +36,13 @@

from .error import ProtocolMinorVersionNotSupported
from .protocol_registry import ProtocolRegistry
from .util import (
get_version_from_message_type,
validate_get_response_version,
# WARNING_DEGRADED_FEATURES,
# WARNING_VERSION_MISMATCH,
# WARNING_VERSION_NOT_SUPPORTED,
)

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -133,16 +140,22 @@ async def handle_message(
inbound_message: The inbound message instance
send_outbound: Async function to send outbound messages

# Raises:
# MessageParseError: If the message type version is not supported

Returns:
The response from the handler

"""
r_time = get_timer()

error_result = None
version_warning = None
message = None
try:
message = await self.make_message(inbound_message.payload)
(message, warning) = await self.make_message(
profile, inbound_message.payload
)
except ProblemReportParseError:
pass # avoid problem report recursion
except MessageParseError as e:
Expand All @@ -155,6 +168,47 @@ async def handle_message(
)
if inbound_message.receipt.thread_id:
error_result.assign_thread_id(inbound_message.receipt.thread_id)
# if warning:
# warning_message_type = inbound_message.payload.get("@type")
# if warning == WARNING_DEGRADED_FEATURES:
# LOGGER.error(
# f"Sending {WARNING_DEGRADED_FEATURES} problem report, "
# "message type received with a minor version at or higher"
# " than protocol minimum supported and current minor version "
# f"for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version at or "
# "higher than protocol minimum supported and current"
# f" minor version for message_type {warning_message_type}"
# ),
# "code": WARNING_DEGRADED_FEATURES,
# }
# )
# elif warning == WARNING_VERSION_MISMATCH:
# LOGGER.error(
# f"Sending {WARNING_VERSION_MISMATCH} problem report, message "
# "type received with a minor version higher than current minor "
# f"version for message_type {warning_message_type}"
# )
# version_warning = ProblemReport(
# description={
# "en": (
# "message type received with a minor version higher"
# " than current minor version for message_type"
# f" {warning_message_type}"
# ),
# "code": WARNING_VERSION_MISMATCH,
# }
# )
# elif warning == WARNING_VERSION_NOT_SUPPORTED:
# raise MessageParseError(
# f"Message type version not supported for {warning_message_type}"
# )
# if version_warning and inbound_message.receipt.thread_id:
# version_warning.assign_thread_id(inbound_message.receipt.thread_id)

trace_event(
self.profile.settings,
Expand Down Expand Up @@ -199,6 +253,8 @@ async def handle_message(

if error_result:
await responder.send_reply(error_result)
elif version_warning:
await responder.send_reply(version_warning)
elif context.message:
context.injector.bind_instance(BaseResponder, responder)

Expand All @@ -215,7 +271,9 @@ async def handle_message(
perf_counter=r_time,
)

async def make_message(self, parsed_msg: dict) -> BaseMessage:
async def make_message(
self, profile: Profile, parsed_msg: dict
) -> Tuple[BaseMessage, Optional[str]]:
"""
Deserialize a message dict into the appropriate message instance.

Expand All @@ -224,6 +282,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:

Args:
parsed_msg: The parsed message
profile: Profile

Returns:
An instance of the corresponding message class for this message
Expand All @@ -237,6 +296,7 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
if not isinstance(parsed_msg, dict):
raise MessageParseError("Expected a JSON object")
message_type = parsed_msg.get("@type")
message_type_rec_version = get_version_from_message_type(message_type)

if not message_type:
raise MessageParseError("Message does not contain '@type' parameter")
Expand All @@ -256,8 +316,10 @@ async def make_message(self, parsed_msg: dict) -> BaseMessage:
if "/problem-report" in message_type:
raise ProblemReportParseError("Error parsing problem report message")
raise MessageParseError(f"Error deserializing message: {e}") from e

return instance
_, warning = await validate_get_response_version(
profile, message_type_rec_version, message_cls
)
return (instance, warning)

async def complete(self, timeout: float = 0.1):
"""Wait for pending tasks to complete."""
Expand Down
94 changes: 82 additions & 12 deletions aries_cloudagent/core/protocol_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging

from string import Template
from typing import Mapping, Sequence

from ..config.injection_context import InjectionContext
Expand Down Expand Up @@ -74,6 +75,73 @@ def parse_type_string(self, message_type):
"minor_version": int(version_string_tokens[1]),
}

def create_msg_types_for_minor_version(self, typesets, version_definition):
"""
Return mapping of message type to module path for minor versions.

Args:
typesets: Mappings of message types to register
version_definition: Optional version definition dict

Returns:
Typesets mapping

"""
updated_typeset = {}
curr_minor_version = version_definition["current_minor_version"]
min_minor_version = version_definition["minimum_minor_version"]
major_version = version_definition["major_version"]
if curr_minor_version >= min_minor_version and curr_minor_version >= 1:
for version_index in range(min_minor_version, curr_minor_version + 1):
to_check = f"{str(major_version)}.{str(version_index)}"
updated_typeset.update(
self._get_updated_tyoeset_dict(typesets, to_check, updated_typeset)
)
return (updated_typeset,)

def _get_updated_tyoeset_dict(self, typesets, to_check, updated_typeset) -> dict:
for typeset in typesets:
for msg_type_string, module_path in typeset.items():
updated_msg_type_string = Template(msg_type_string).substitute(
version=to_check
)
updated_typeset[updated_msg_type_string] = module_path
return updated_typeset

def _template_message_type_check(self, typeset) -> bool:
for msg_type_string, _ in typeset.items():
if "$version" in msg_type_string:
return True
return False

def _create_and_register_updated_typesets(self, typesets, version_definition):
updated_typesets = self.create_msg_types_for_minor_version(
typesets, version_definition
)
update_flag = False
for typeset in updated_typesets:
if typeset:
self._typemap.update(typeset)
update_flag = True
if update_flag:
return updated_typesets
else:
return None

def _update_version_map(self, message_type_string, module_path, version_definition):
parsed_type_string = self.parse_type_string(message_type_string)

if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []

self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
)

def register_message_types(self, *typesets, version_definition=None):
"""
Add new supported message types.
Expand All @@ -85,24 +153,26 @@ def register_message_types(self, *typesets, version_definition=None):
"""

# Maintain support for versionless protocol modules
template_msg_type_version = True
updated_typesets = None
for typeset in typesets:
self._typemap.update(typeset)
if not self._template_message_type_check(typeset):
self._typemap.update(typeset)
template_msg_type_version = False

# Track versioned modules for version routing
if version_definition:
# create updated typesets for minor versions and register them
if template_msg_type_version:
updated_typesets = self._create_and_register_updated_typesets(
typesets, version_definition
)
if updated_typesets:
typesets = updated_typesets
for typeset in typesets:
for message_type_string, module_path in typeset.items():
parsed_type_string = self.parse_type_string(message_type_string)

if version_definition["major_version"] not in self._versionmap:
self._versionmap[version_definition["major_version"]] = []

self._versionmap[version_definition["major_version"]].append(
{
"parsed_type_string": parsed_type_string,
"version_definition": version_definition,
"message_module": module_path,
}
self._update_version_map(
message_type_string, module_path, version_definition
)

def register_controllers(self, *controller_sets, version_definition=None):
Expand Down
Loading