Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved validation of record attributes #2071

Merged
merged 5 commits into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions aries_cloudagent/connections/models/conn_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,11 +677,7 @@ class Meta:
required=False,
description="Routing state of connection",
validate=validate.OneOf(
[
getattr(ConnRecord, m)
for m in vars(ConnRecord)
if m.startswith("ROUTING_STATE_")
]
ConnRecord.get_attributes_by_prefix("ROUTING_STATE_", walk_mro=False)
),
example=ConnRecord.ROUTING_STATE_ACTIVE,
)
Expand All @@ -690,11 +686,7 @@ class Meta:
description="Connection acceptance: manual or auto",
example=ConnRecord.ACCEPT_AUTO,
validate=validate.OneOf(
[
getattr(ConnRecord, a)
for a in vars(ConnRecord)
if a.startswith("ACCEPT_")
]
ConnRecord.get_attributes_by_prefix("ACCEPT_", walk_mro=False)
),
)
error_msg = fields.Str(
Expand All @@ -707,11 +699,7 @@ class Meta:
description="Invitation mode",
example=ConnRecord.INVITATION_MODE_ONCE,
validate=validate.OneOf(
[
getattr(ConnRecord, i)
for i in vars(ConnRecord)
if i.startswith("INVITATION_MODE_")
]
ConnRecord.get_attributes_by_prefix("INVITATION_MODE_", walk_mro=False)
),
)
alias = fields.Str(
Expand Down
21 changes: 20 additions & 1 deletion aries_cloudagent/messaging/models/base_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Meta:
EVENT_NAMESPACE: str = "acapy::record"
LOG_STATE_FLAG = None
TAG_NAMES = {"state"}
STATE_DELETED = "deleted"

def __init__(
self,
Expand Down Expand Up @@ -420,7 +421,7 @@ async def delete_record(self, session: ProfileSession):
storage = session.inject(BaseStorage)
if self.state:
self._previous_state = self.state
self.state = "deleted"
self.state = BaseRecord.STATE_DELETED
await self.emit_event(session, self.serialize())
await storage.delete_record(self.storage_record)

Expand Down Expand Up @@ -497,6 +498,24 @@ def __eq__(self, other: Any) -> bool:
return self.value == other.value and self.tags == other.tags
return False

@classmethod
def get_attributes_by_prefix(cls, prefix: str, walk_mro: bool = True):
"""
List all values for attributes with common prefix.

Args:
prefix: Common prefix to look for
walk_mro: Walk MRO to find attributes inherited from superclasses
"""

bases = cls.__mro__ if walk_mro else [cls]
return [
vars(base)[name]
for base in bases
for name in vars(base)
if name.startswith(prefix)
]


class BaseExchangeRecord(BaseRecord):
"""Represents a base record with event tracing capability."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,35 +296,23 @@ class Meta:
description="Issue-credential exchange initiator: self or external",
example=V20CredExRecord.INITIATOR_SELF,
validate=validate.OneOf(
[
getattr(V20CredExRecord, m)
for m in vars(V20CredExRecord)
if m.startswith("INITIATOR_")
]
V20CredExRecord.get_attributes_by_prefix("INITIATOR_", walk_mro=False)
),
)
role = fields.Str(
required=False,
description="Issue-credential exchange role: holder or issuer",
example=V20CredExRecord.ROLE_ISSUER,
validate=validate.OneOf(
[
getattr(V20CredExRecord, m)
for m in vars(V20CredExRecord)
if m.startswith("ROLE_")
]
V20CredExRecord.get_attributes_by_prefix("ROLE_", walk_mro=False)
),
)
state = fields.Str(
required=False,
description="Issue-credential exchange state",
example=V20CredExRecord.STATE_DONE,
validate=validate.OneOf(
[
getattr(V20CredExRecord, m)
for m in vars(V20CredExRecord)
if m.startswith("STATE_")
]
V20CredExRecord.get_attributes_by_prefix("STATE_", walk_mro=True)
),
)
cred_preview = fields.Nested(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
from typing import Any, Mapping, Optional, Union

from marshmallow import fields
from marshmallow import fields, validate

from .....connections.models.conn_record import ConnRecord
from .....core.profile import ProfileSession
Expand Down Expand Up @@ -248,6 +248,9 @@ class Meta:
required=True,
description="Out of band message exchange state",
example=OobRecord.STATE_AWAIT_RESPONSE,
validate=validate.OneOf(
OobRecord.get_attributes_by_prefix("STATE_", walk_mro=True)
),
)
invi_msg_id = fields.Str(
required=True,
Expand Down Expand Up @@ -287,4 +290,7 @@ class Meta:
description="OOB Role",
required=False,
example=OobRecord.ROLE_RECEIVER,
validate=validate.OneOf(
OobRecord.get_attributes_by_prefix("ROLE_", walk_mro=False)
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -244,34 +244,22 @@ class Meta:
description="Present-proof exchange initiator: self or external",
example=V20PresExRecord.INITIATOR_SELF,
validate=validate.OneOf(
[
getattr(V20PresExRecord, m)
for m in vars(V20PresExRecord)
if m.startswith("INITIATOR_")
]
V20PresExRecord.get_attributes_by_prefix("INITIATOR_", walk_mro=False)
),
)
role = fields.Str(
required=False,
description="Present-proof exchange role: prover or verifier",
example=V20PresExRecord.ROLE_PROVER,
validate=validate.OneOf(
[
getattr(V20PresExRecord, m)
for m in vars(V20PresExRecord)
if m.startswith("ROLE_")
]
V20PresExRecord.get_attributes_by_prefix("ROLE_", walk_mro=False)
),
)
state = fields.Str(
required=False,
description="Present-proof exchange state",
validate=validate.OneOf(
[
getattr(V20PresExRecord, m)
for m in vars(V20PresExRecord)
if m.startswith("STATE_")
]
V20PresExRecord.get_attributes_by_prefix("STATE_", walk_mro=True)
),
)
pres_proposal = fields.Nested(
Expand Down