Skip to content

Commit

Permalink
Merge pull request #2071 from rmnre/record-attribute-validation
Browse files Browse the repository at this point in the history
Improved validation of record attributes
  • Loading branch information
ianco authored Jan 13, 2023
2 parents 4c20dcd + 4236b69 commit aedcbd3
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 47 deletions.
18 changes: 3 additions & 15 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,11 +677,7 @@ class Meta:
required=False,
description="Routing state of connection",
validate=validate.OneOf(
[
getattr(ConnRecord, m)
for m in vars(ConnRecord)
if m.startswith("ROUTING_STATE_")
]
ConnRecord.get_attributes_by_prefix("ROUTING_STATE_", walk_mro=False)
),
example=ConnRecord.ROUTING_STATE_ACTIVE,
)
Expand All @@ -690,11 +686,7 @@ class Meta:
description="Connection acceptance: manual or auto",
example=ConnRecord.ACCEPT_AUTO,
validate=validate.OneOf(
[
getattr(ConnRecord, a)
for a in vars(ConnRecord)
if a.startswith("ACCEPT_")
]
ConnRecord.get_attributes_by_prefix("ACCEPT_", walk_mro=False)
),
)
error_msg = fields.Str(
Expand All @@ -707,11 +699,7 @@ class Meta:
description="Invitation mode",
example=ConnRecord.INVITATION_MODE_ONCE,
validate=validate.OneOf(
[
getattr(ConnRecord, i)
for i in vars(ConnRecord)
if i.startswith("INVITATION_MODE_")
]
ConnRecord.get_attributes_by_prefix("INVITATION_MODE_", walk_mro=False)
),
)
alias = fields.Str(
Expand Down
21 changes: 20 additions & 1 deletion aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Meta:
EVENT_NAMESPACE: str = "acapy::record"
LOG_STATE_FLAG = None
TAG_NAMES = {"state"}
STATE_DELETED = "deleted"

def __init__(
self,
Expand Down Expand Up @@ -420,7 +421,7 @@ async def delete_record(self, session: ProfileSession):
storage = session.inject(BaseStorage)
if self.state:
self._previous_state = self.state
self.state = "deleted"
self.state = BaseRecord.STATE_DELETED
await self.emit_event(session, self.serialize())
await storage.delete_record(self.storage_record)

Expand Down Expand Up @@ -497,6 +498,24 @@ def __eq__(self, other: Any) -> bool:
return self.value == other.value and self.tags == other.tags
return False

@classmethod
def get_attributes_by_prefix(cls, prefix: str, walk_mro: bool = True):
"""
List all values for attributes with common prefix.
Args:
prefix: Common prefix to look for
walk_mro: Walk MRO to find attributes inherited from superclasses
"""

bases = cls.__mro__ if walk_mro else [cls]
return [
vars(base)[name]
for base in bases
for name in vars(base)
if name.startswith(prefix)
]


class BaseExchangeRecord(BaseRecord):
"""Represents a base record with event tracing capability."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,35 +296,23 @@ class Meta:
description="Issue-credential exchange initiator: self or external",
example=V20CredExRecord.INITIATOR_SELF,
validate=validate.OneOf(
[
getattr(V20CredExRecord, m)
for m in vars(V20CredExRecord)
if m.startswith("INITIATOR_")
]
V20CredExRecord.get_attributes_by_prefix("INITIATOR_", walk_mro=False)
),
)
role = fields.Str(
required=False,
description="Issue-credential exchange role: holder or issuer",
example=V20CredExRecord.ROLE_ISSUER,
validate=validate.OneOf(
[
getattr(V20CredExRecord, m)
for m in vars(V20CredExRecord)
if m.startswith("ROLE_")
]
V20CredExRecord.get_attributes_by_prefix("ROLE_", walk_mro=False)
),
)
state = fields.Str(
required=False,
description="Issue-credential exchange state",
example=V20CredExRecord.STATE_DONE,
validate=validate.OneOf(
[
getattr(V20CredExRecord, m)
for m in vars(V20CredExRecord)
if m.startswith("STATE_")
]
V20CredExRecord.get_attributes_by_prefix("STATE_", walk_mro=True)
),
)
cred_preview = fields.Nested(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from typing import Any, Mapping, Optional, Union

from marshmallow import fields
from marshmallow import fields, validate

from .....connections.models.conn_record import ConnRecord
from .....core.profile import ProfileSession
Expand Down Expand Up @@ -248,6 +248,9 @@ class Meta:
required=True,
description="Out of band message exchange state",
example=OobRecord.STATE_AWAIT_RESPONSE,
validate=validate.OneOf(
OobRecord.get_attributes_by_prefix("STATE_", walk_mro=True)
),
)
invi_msg_id = fields.Str(
required=True,
Expand Down Expand Up @@ -287,4 +290,7 @@ class Meta:
description="OOB Role",
required=False,
example=OobRecord.ROLE_RECEIVER,
validate=validate.OneOf(
OobRecord.get_attributes_by_prefix("ROLE_", walk_mro=False)
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -244,34 +244,22 @@ class Meta:
description="Present-proof exchange initiator: self or external",
example=V20PresExRecord.INITIATOR_SELF,
validate=validate.OneOf(
[
getattr(V20PresExRecord, m)
for m in vars(V20PresExRecord)
if m.startswith("INITIATOR_")
]
V20PresExRecord.get_attributes_by_prefix("INITIATOR_", walk_mro=False)
),
)
role = fields.Str(
required=False,
description="Present-proof exchange role: prover or verifier",
example=V20PresExRecord.ROLE_PROVER,
validate=validate.OneOf(
[
getattr(V20PresExRecord, m)
for m in vars(V20PresExRecord)
if m.startswith("ROLE_")
]
V20PresExRecord.get_attributes_by_prefix("ROLE_", walk_mro=False)
),
)
state = fields.Str(
required=False,
description="Present-proof exchange state",
validate=validate.OneOf(
[
getattr(V20PresExRecord, m)
for m in vars(V20PresExRecord)
if m.startswith("STATE_")
]
V20PresExRecord.get_attributes_by_prefix("STATE_", walk_mro=True)
),
)
pres_proposal = fields.Nested(
Expand Down

0 comments on commit aedcbd3

Please sign in to comment.