Skip to content

Commit

Permalink
Merge pull request openwallet-foundation#2515 from dbluhm/refactor/me…
Browse files Browse the repository at this point in the history
…diation-terms-remove

refactor: drop mediator_terms and recipient_terms
  • Loading branch information
swcurran authored Sep 28, 2023
2 parents 04304aa + 5af00ba commit bad688e
Show file tree
Hide file tree
Showing 15 changed files with 103 additions and 210 deletions.
8 changes: 4 additions & 4 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ class Meta:

def __init__(
self,
id: str = None,
state: str = None,
id: Optional[str] = None,
state: Optional[str] = None,
*,
created_at: Union[str, datetime] = None,
updated_at: Union[str, datetime] = None,
created_at: Union[str, datetime, None] = None,
updated_at: Union[str, datetime, None] = None,
new_with_id: bool = False,
):
"""Initialize a new BaseRecord."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ async def test_create_request_multitenant(self):

async def test_create_request_mediation_id(self):
mediation_record = MediationRecord(
mediation_id="test_medation_id",
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
Expand Down Expand Up @@ -866,6 +867,7 @@ async def test_create_response_multitenant(self):
)

mediation_record = MediationRecord(
mediation_id="test_medation_id",
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
Expand Down Expand Up @@ -936,6 +938,7 @@ async def test_create_response_bad_state(self):

async def test_create_response_mediation(self):
mediation_record = MediationRecord(
mediation_id="test_medation_id",
role=MediationRecord.ROLE_CLIENT,
state=MediationRecord.STATE_GRANTED,
connection_id=self.test_mediator_conn_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from ..mediation_deny_handler import MediationDenyHandler

TEST_CONN_ID = "conn-id"
TEST_MEDIATOR_TERMS = ["test", "mediator", "terms"]
TEST_RECIPIENT_TERMS = ["test", "recipient", "terms"]


class TestMediationDenyHandler(AsyncTestCase):
Expand All @@ -22,9 +20,7 @@ async def setUp(self):
"""Setup test dependencies."""
self.context = RequestContext.test_context()
self.session = await self.context.session()
self.context.message = MediationDeny(
mediator_terms=TEST_MEDIATOR_TERMS, recipient_terms=TEST_RECIPIENT_TERMS
)
self.context.message = MediationDeny()
self.context.connection_ready = True
self.context.connection_record = ConnRecord(connection_id=TEST_CONN_ID)

Expand All @@ -50,5 +46,3 @@ async def test_handler(self):
)
assert record
assert record.state == MediationRecord.STATE_DENIED
assert record.mediator_terms == TEST_MEDIATOR_TERMS
assert record.recipient_terms == TEST_RECIPIENT_TERMS
28 changes: 2 additions & 26 deletions aries_cloudagent/protocols/coordinate_mediation/v1_0/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,11 +139,8 @@ async def receive_request(
"MediationRecord already exists for connection"
)

# TODO: Determine if terms are acceptable
record = MediationRecord(
connection_id=connection_id,
mediator_terms=request.mediator_terms,
recipient_terms=request.recipient_terms,
)
await record.save(session, reason="New mediation request received")
return record
Expand Down Expand Up @@ -186,19 +183,11 @@ async def grant_request(
async def deny_request(
self,
mediation_id: str,
*,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
) -> Tuple[MediationRecord, MediationDeny]:
"""Deny a mediation request and prepare a deny message.
Args:
mediation_id: mediation record ID to deny
mediator_terms (Sequence[str]): updated mediator terms to return to
requester.
recipient_terms (Sequence[str]): updated recipient terms to return to
requester.
Returns:
MediationDeny: message to return to denied client.
Expand All @@ -215,9 +204,7 @@ async def deny_request(
mediation_record.state = MediationRecord.STATE_DENIED
await mediation_record.save(session, reason="Mediation request denied")

deny = MediationDeny(
mediator_terms=mediator_terms, recipient_terms=recipient_terms
)
deny = MediationDeny()
return mediation_record, deny

async def _handle_keylist_update_add(
Expand Down Expand Up @@ -442,15 +429,11 @@ async def clear_default_mediator(self):
async def prepare_request(
self,
connection_id: str,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
) -> Tuple[MediationRecord, MediationRequest]:
"""Prepare a MediationRequest Message, saving a new mediation record.
Args:
connection_id (str): ID representing mediator
mediator_terms (Sequence[str]): mediator_terms
recipient_terms (Sequence[str]): recipient_terms
Returns:
MediationRequest: message to send to mediator
Expand All @@ -459,15 +442,11 @@ async def prepare_request(
record = MediationRecord(
role=MediationRecord.ROLE_CLIENT,
connection_id=connection_id,
mediator_terms=mediator_terms,
recipient_terms=recipient_terms,
)

async with self._profile.session() as session:
await record.save(session, reason="Creating new mediation request.")
request = MediationRequest(
mediator_terms=mediator_terms, recipient_terms=recipient_terms
)
request = MediationRequest()
return record, request

async def request_granted(self, record: MediationRecord, grant: MediationGrant):
Expand Down Expand Up @@ -495,9 +474,6 @@ async def request_denied(self, record: MediationRecord, deny: MediationDeny):
"""
record.state = MediationRecord.STATE_DENIED
# TODO Record terms elsewhere?
record.mediator_terms = deny.mediator_terms
record.recipient_terms = deny.recipient_terms
async with self._profile.session() as session:
await record.save(session, reason="Mediation request denied.")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""mediate-deny message used to notify mediation client of a denied mediation request."""

from typing import Sequence

from marshmallow import fields

from .....messaging.agent_message import AgentMessage, AgentMessageSchema
from ..message_types import MEDIATE_DENY, PROTOCOL_PACKAGE
Expand All @@ -24,20 +21,10 @@ class Meta:

def __init__(
self,
*,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
**kwargs,
):
"""Initialize mediation deny object.
Args:
mediator_terms: Terms that were agreed by the recipient
recipient_terms: Terms that recipient wants to mediator to agree to
"""
"""Initialize mediation deny object."""
super(MediationDeny, self).__init__(**kwargs)
self.mediator_terms = list(mediator_terms) if mediator_terms else []
self.recipient_terms = list(recipient_terms) if recipient_terms else []


class MediationDenySchema(AgentMessageSchema):
Expand All @@ -47,12 +34,3 @@ class Meta:
"""Mediation deny schema metadata."""

model_class = MediationDeny

mediator_terms = fields.List(
fields.Str(metadata={"description": "Terms for mediator to agree"}),
required=False,
)
recipient_terms = fields.List(
fields.Str(metadata={"description": "Terms for recipient to agree"}),
required=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Used to notify mediation client of a granted mediation request.
"""

from typing import Sequence
from typing import Optional, Sequence

from marshmallow import fields

Expand All @@ -29,8 +29,8 @@ class Meta:
def __init__(
self,
*,
endpoint: str = None,
routing_keys: Sequence[str] = None,
endpoint: Optional[str] = None,
routing_keys: Optional[Sequence[str]] = None,
**kwargs,
):
"""Initialize mediation grant object.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""mediate-request message used to request mediation from a mediator."""

from typing import Sequence

from marshmallow import fields

from .....messaging.agent_message import AgentMessage, AgentMessageSchema
from ..message_types import MEDIATE_REQUEST, PROTOCOL_PACKAGE
Expand All @@ -22,22 +19,9 @@ class Meta:
message_type = MEDIATE_REQUEST
schema_class = "MediationRequestSchema"

def __init__(
self,
*,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
**kwargs,
):
"""Initialize mediation request object.
Args:
mediator_terms: Mediator's terms for granting mediation.
recipient_terms: Recipient's proposed mediation terms.
"""
def __init__(self, **kwargs):
"""Initialize mediation request object."""
super(MediationRequest, self).__init__(**kwargs)
self.mediator_terms = list(mediator_terms) if mediator_terms else []
self.recipient_terms = list(recipient_terms) if recipient_terms else []


class MediationRequestSchema(AgentMessageSchema):
Expand All @@ -47,27 +31,3 @@ class Meta:
"""Mediation request schema metadata."""

model_class = MediationRequest

mediator_terms = fields.List(
fields.Str(
metadata={
"description": (
"Indicate terms that the mediator requires the recipient to"
" agree to"
)
}
),
required=False,
metadata={"description": "List of mediator rules for recipient"},
)
recipient_terms = fields.List(
fields.Str(
metadata={
"description": (
"Indicate terms that the recipient requires the mediator to"
" agree to"
)
}
),
required=False,
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class TestMediateDeny(MessageTest, TestCase):
TYPE = MEDIATE_DENY
CLASS = MediationDeny
SCHEMA = MediationDenySchema
VALUES = {"mediator_terms": ["test", "terms"], "recipient_terms": ["test", "terms"]}
VALUES = {}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class TestMediateRequest(MessageTest, TestCase):
TYPE = MEDIATE_REQUEST
CLASS = MediationRequest
SCHEMA = MediationRequestSchema
VALUES = {"mediator_terms": ["test", "terms"], "recipient_terms": ["test", "terms"]}
VALUES = {}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Store state for Mediation requests."""

from typing import Sequence
from typing import Optional, Sequence

from marshmallow import EXCLUDE, fields

Expand Down Expand Up @@ -33,14 +33,15 @@ class Meta:
def __init__(
self,
*,
mediation_id: str = None,
state: str = None,
role: str = None,
connection_id: str = None,
mediator_terms: Sequence[str] = None,
recipient_terms: Sequence[str] = None,
routing_keys: Sequence[str] = None,
endpoint: str = None,
mediation_id: Optional[str] = None,
state: Optional[str] = None,
role: Optional[str] = None,
connection_id: Optional[str] = None,
routing_keys: Optional[Sequence[str]] = None,
endpoint: Optional[str] = None,
# Included for record backwards compat
mediator_terms: Optional[Sequence[str]] = None,
recipient_terms: Optional[Sequence[str]] = None,
**kwargs,
):
"""__init__.
Expand All @@ -50,8 +51,6 @@ def __init__(
state (str): state, defaults to 'request_received'
role (str): role in mediation, defaults to 'server'
connection_id (str): ID of connection requesting or managing mediation
mediator_terms (Sequence[str]): mediator_terms
recipient_terms (Sequence[str]): recipient_terms
routing_keys (Sequence[str]): keys in mediator control used to
receive incoming messages
endpoint (str): mediators endpoint
Expand All @@ -61,8 +60,6 @@ def __init__(
super().__init__(mediation_id, state or self.STATE_REQUEST, **kwargs)
self.role = role if role else self.ROLE_SERVER
self.connection_id = connection_id
self.mediator_terms = list(mediator_terms) if mediator_terms else []
self.recipient_terms = list(recipient_terms) if recipient_terms else []
self.routing_keys = list(routing_keys) if routing_keys else []
self.endpoint = endpoint

Expand All @@ -79,6 +76,8 @@ def __eq__(self, other: "MediationRecord"):
@property
def mediation_id(self) -> str:
"""Get Mediation ID."""
if not self._id:
raise ValueError("Record not yet stored")
return self._id

@property
Expand Down Expand Up @@ -109,8 +108,6 @@ def record_value(self) -> dict:
return {
prop: getattr(self, prop)
for prop in (
"mediator_terms",
"recipient_terms",
"routing_keys",
"endpoint",
)
Expand Down Expand Up @@ -170,10 +167,12 @@ class Meta:
mediation_id = fields.Str(required=False)
role = fields.Str(required=True)
connection_id = fields.Str(required=True)
mediator_terms = fields.List(fields.Str(), required=False)
recipient_terms = fields.List(fields.Str(), required=False)
routing_keys = fields.List(
fields.Str(validate=DID_KEY_VALIDATE, metadata={"example": DID_KEY_EXAMPLE}),
required=False,
)
endpoint = fields.Str(required=False)

# Included for backwards compat with old records
mediator_terms = fields.List(fields.Str(), required=False)
recipient_terms = fields.List(fields.Str(), required=False)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Tests for models."""
Loading

0 comments on commit bad688e

Please sign in to comment.