Skip to content

Commit

Permalink
Merge pull request #350 from sklump/record-eq
Browse files Browse the repository at this point in the history
Hush LGTM with equality comparison for exhange records, tiny bug fix …
  • Loading branch information
andrewwhitehead authored Jan 30, 2020
2 parents 59ac609 + 10de47f commit 9e30014
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 2 deletions.
3 changes: 1 addition & 2 deletions aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def record_tags(self) -> dict:
return {
tag: getattr(self, prop)
for (prop, tag) in self.get_tag_map().items()
if getattr(self, prop)
if not None
if getattr(self, prop) is not None
}

@property
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Aries#0036 v1.0 credential exchange information with non-secrets storage."""

from typing import Any

from marshmallow import fields
from marshmallow.validate import OneOf

Expand Down Expand Up @@ -129,6 +131,10 @@ async def retrieve_by_connection_and_thread(
await cls.set_cached_key(context, cache_key, record.credential_exchange_id)
return record

def __eq__(self, other: Any) -> bool:
"""Comparison between records."""
return super().__eq__(other)


class V10CredentialExchangeSchema(BaseRecordSchema):
"""Schema to allow serialization/deserialization of credential exchange records."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,44 @@ async def setUp(self):
self.context.injector.bind_instance(BaseLedger, self.ledger)
self.manager = CredentialManager(self.context)

async def test_record_eq(self):
cred_def_id = "LjgpST2rjsoxYegQDRm7EL:3:CL:18:tag"
same = [
V10CredentialExchange(
credential_exchange_id="dummy-0",
thread_id="thread-0",
credential_definition_id=cred_def_id,
role=V10CredentialExchange.ROLE_ISSUER
)
] * 2
diff = [
V10CredentialExchange(
credential_exchange_id="dummy-1",
credential_definition_id=cred_def_id,
role=V10CredentialExchange.ROLE_ISSUER
),
V10CredentialExchange(
credential_exchange_id="dummy-0",
thread_id="thread-1",
credential_definition_id=cred_def_id,
role=V10CredentialExchange.ROLE_ISSUER
),
V10CredentialExchange(
credential_exchange_id="dummy-1",
thread_id="thread-0",
credential_definition_id=f"{cred_def_id}_distinct_tag",
role=V10CredentialExchange.ROLE_ISSUER
)
]

for i in range(len(same) - 1):
for j in range(i, len(same)):
assert same[i] == same[j]

for i in range(len(diff) - 1):
for j in range(i, len(diff)):
assert diff[i] == diff[j] if i == j else diff[i] != diff[j]

async def test_prepare_send(self):
schema_id = "LjgpST2rjsoxYegQDRm7EL:2:bc-reg:1.0"
cred_def_id = "LjgpST2rjsoxYegQDRm7EL:3:CL:18:tag"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Aries#0037 v1.0 presentation exchange information with non-secrets storage."""

from typing import Any

from marshmallow import fields
from marshmallow.validate import OneOf

Expand Down Expand Up @@ -90,6 +92,10 @@ def record_value(self) -> dict:
)
}

def __eq__(self, other: Any) -> bool:
"""Comparison between records."""
return super().__eq__(other)


class V10PresentationExchangeSchema(BaseRecordSchema):
"""Schema for de/serialization of v1.0 presentation exchange records."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,39 @@ async def setUp(self):

self.manager = PresentationManager(self.context)

async def test_record_eq(self):
same = [
V10PresentationExchange(
presentation_exchange_id="dummy-0",
thread_id="thread-0",
role=V10PresentationExchange.ROLE_PROVER
)
] * 2
diff = [
V10PresentationExchange(
presentation_exchange_id="dummy-1",
role=V10PresentationExchange.ROLE_PROVER
),
V10PresentationExchange(
presentation_exchange_id="dummy-0",
thread_id="thread-1",
role=V10PresentationExchange.ROLE_PROVER
),
V10PresentationExchange(
presentation_exchange_id="dummy-1",
thread_id="thread-0",
role=V10PresentationExchange.ROLE_VERIFIER
)
]

for i in range(len(same) - 1):
for j in range(i, len(same)):
assert same[i] == same[j]

for i in range(len(diff) - 1):
for j in range(i, len(diff)):
assert diff[i] == diff[j] if i == j else diff[i] != diff[j]

async def test_create_exchange_for_proposal(self):
self.context.connection_record = async_mock.MagicMock()
self.context.connection_record.connection_id = CONN_ID
Expand Down

0 comments on commit 9e30014

Please sign in to comment.