Skip to content

Commit

Permalink
perf(connect): Update to Mercury 0.17.0 (#346)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioPinheiro authored Feb 7, 2023
1 parent ebf583c commit 8823325
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@ import io.iohk.atala.connect.core.model.error.ConnectionServiceError
import io.iohk.atala.connect.core.model.error.ConnectionServiceError._
import io.iohk.atala.connect.core.model.ConnectionRecord
import io.iohk.atala.connect.core.model.ConnectionRecord._
import io.iohk.atala.mercury.protocol.connection.ConnectionRequest
import java.util.UUID
import io.iohk.atala.mercury._
import io.iohk.atala.mercury.model.DidId
import io.iohk.atala.mercury.protocol.invitation.v2.Invitation
import io.iohk.atala.mercury.protocol.connection._
import java.time.Instant
import java.rmi.UnexpectedException
import io.iohk.atala.mercury.protocol.invitation.v2.Invitation
import io.iohk.atala.mercury.protocol.connection.ConnectionResponse
import io.iohk.atala.shared.utils.Base64Utils

private class ConnectionServiceImpl(
Expand All @@ -26,7 +25,7 @@ private class ConnectionServiceImpl(
pairwiseDID: DidId
): IO[ConnectionServiceError, ConnectionRecord] =
for {
invitation <- ZIO.succeed(Invitation.invitation2Connect(pairwiseDID))
invitation <- ZIO.succeed(ConnectionInvitation.makeConnectionInvitation(pairwiseDID))
record <- ZIO.succeed(
ConnectionRecord(
id = UUID.fromString(invitation.id),
Expand Down Expand Up @@ -115,7 +114,9 @@ private class ConnectionServiceImpl(
): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- getRecordWithState(recordId, ProtocolState.InvitationReceived)
request = createDidCommConnectionRequest(record, pairwiseDid)
request = ConnectionRequest
.makeFromInvitation(record.invitation, pairwiseDid)
.copy(thid = Some(record.invitation.id)) // This logic shound be move to the SQL when fetching the record
count <- connectionRepository
.updateWithConnectionRequest(recordId, request, ProtocolState.ConnectionRequestPending, maxRetries)
.mapError(RepositoryError.apply)
Expand All @@ -138,7 +139,10 @@ private class ConnectionServiceImpl(
request: ConnectionRequest
): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- getRecordFromThreadIdAndState(request.thid, ProtocolState.InvitationGenerated)
record <- getRecordFromThreadIdAndState(
Some(request.thid.orElse(request.pthid).getOrElse(request.id)),
ProtocolState.InvitationGenerated
)
_ <- connectionRepository
.updateWithConnectionRequest(record.id, request, ProtocolState.ConnectionRequestReceived, maxRetries)
.flatMap {
Expand All @@ -154,7 +158,13 @@ private class ConnectionServiceImpl(
override def acceptConnectionRequest(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- getRecordWithState(recordId, ProtocolState.ConnectionRequestReceived)
response = createDidCommConnectionResponse(record)
response <- {
record.connectionRequest.map(_.makeMessage).map(ConnectionResponse.makeResponseFromRequest(_)) match
case None => ZIO.fail(RepositoryError.apply(new RuntimeException("Unable to make Message")))
case Some(Left(value)) => ZIO.fail(RepositoryError.apply(new RuntimeException(value)))
case Some(Right(response)) => ZIO.succeed(response)
}
// response = createDidCommConnectionResponse(record)
count <- connectionRepository
.updateWithConnectionResponse(recordId, response, ProtocolState.ConnectionResponsePending, maxRetries)
.mapError(RepositoryError.apply)
Expand All @@ -177,7 +187,10 @@ private class ConnectionServiceImpl(
response: ConnectionResponse
): IO[ConnectionServiceError, Option[ConnectionRecord]] =
for {
record <- getRecordFromThreadIdAndState(response.thid, ProtocolState.ConnectionRequestSent)
record <- getRecordFromThreadIdAndState(
response.thid.orElse(response.pthid),
ProtocolState.ConnectionRequestSent
)
_ <- connectionRepository
.updateWithConnectionResponse(record.id, response, ProtocolState.ConnectionResponseReceived, maxRetries)
.flatMap {
Expand Down Expand Up @@ -208,18 +221,6 @@ private class ConnectionServiceImpl(
} yield record
}

private[this] def createDidCommConnectionRequest(record: ConnectionRecord, pairwiseDid: DidId): ConnectionRequest = {
ConnectionRequest(
from = pairwiseDid,
to = record.invitation.from,
thid = record.thid.map(_.toString),
body = ConnectionRequest.Body(goal_code = Some("Connect"))
)
}

private[this] def createDidCommConnectionResponse(record: ConnectionRecord): ConnectionResponse =
ConnectionResponse.makeResponseFromRequest(record.connectionRequest.get.makeMessage) // TODO: get

private[this] def updateConnectionProtocolState(
recordId: UUID,
from: ProtocolState,
Expand All @@ -240,7 +241,7 @@ private class ConnectionServiceImpl(
}

private[this] def getRecordFromThreadIdAndState(
thid: Option[String],
thid: Option[String], // TODO this should not be optional
state: ProtocolState
): IO[ConnectionServiceError, ConnectionRecord] = {
for {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ object ConnectionRepositorySpecSuite {
Invitation(
id = UUID.randomUUID().toString,
from = DidId("did:prism:aaa"),
body = Invitation.Body(goal_code = "connect", goal = "Establish a trust connection between two peers", Nil)
body = Invitation
.Body(goal_code = "io.atalaprism.connect", goal = "Establish a trust connection between two peers", Nil)
),
None,
None,
Expand All @@ -43,8 +44,9 @@ object ConnectionRepositorySpecSuite {
private def connectionRequest = ConnectionRequest(
from = DidId("did:prism:aaa"),
to = DidId("did:prism:bbb"),
thid = Some(UUID.randomUUID().toString),
body = ConnectionRequest.Body(goal_code = Some("Connect"))
thid = None,
pthid = Some(UUID.randomUUID().toString),
body = ConnectionRequest.Body(goal_code = Some("io.atalaprism.connect"))
)

val testSuite = suite("CRUD operations")(
Expand Down Expand Up @@ -220,7 +222,7 @@ object ConnectionRepositorySpecSuite {
aRecord = connectionRecord
_ <- repo.createConnectionRecord(aRecord)
record <- repo.getConnectionRecord(aRecord.id)
response = ConnectionResponse.makeResponseFromRequest(connectionRequest.makeMessage)
response = ConnectionResponse.makeResponseFromRequest(connectionRequest.makeMessage).toOption.get
count <- repo.updateWithConnectionResponse(
aRecord.id,
response,
Expand All @@ -243,7 +245,7 @@ object ConnectionRepositorySpecSuite {
record <- repo.getConnectionRecord(aRecord.id)
count <- repo.updateAfterFail(aRecord.id, Some("Just to test")) // TEST
updatedRecord1 <- repo.getConnectionRecord(aRecord.id)
response = ConnectionResponse.makeResponseFromRequest(connectionRequest.makeMessage)
response = ConnectionResponse.makeResponseFromRequest(connectionRequest.makeMessage).toOption.get
count <- repo.updateWithConnectionResponse(
aRecord.id,
response,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
assertTrue(record.updatedAt.isEmpty) &&
assertTrue(record.invitation.from == did) &&
assertTrue(record.invitation.attachments.isEmpty) &&
assertTrue(record.invitation.body.goal_code == "connect") &&
assertTrue(record.invitation.body.goal_code == "io.atalaprism.connect") &&
assertTrue(record.invitation.body.accept.isEmpty)
}
}, {
Expand Down Expand Up @@ -232,7 +232,7 @@ object ConnectionServiceImplSpec extends ZIOSpecDefault {
)
_ <- inviterSvc.markConnectionResponseSent(inviterRecord.id)
maybeReceivedResponseConnectionRecord <- inviteeSvc.receiveConnectionResponse(
ConnectionResponse.readFromMessage(connectionResponseMessage)
ConnectionResponse.fromMessage(connectionResponseMessage).toOption.get
)
allInviteeRecords <- inviteeSvc.getConnectionRecords()
} yield {
Expand Down
2 changes: 1 addition & 1 deletion connect/lib/project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object Dependencies {
val doobie = "1.0.0-RC2"
val zioCatsInterop = "3.3.0"
val iris = "0.1.0"
val mercury = "0.16.0"
val mercury = "0.17.0"
val flyway = "9.8.3"
val shared = "0.2.0"
val testContainersScalaPostgresql = "0.40.11"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepositor
| meta_last_failure
| FROM public.connection_records
| WHERE thid = $thid
""".stripMargin
""".stripMargin // | WHERE thid = $thid OR id = $thid
.query[ConnectionRecord]
.option

Expand Down

0 comments on commit 8823325

Please sign in to comment.