Skip to content

Commit

Permalink
feat(connect): support connect records retrieval by states (#349)
Browse files Browse the repository at this point in the history
* feat(connect): support connect records retrieval by states

* feat(connect): merge latest changes from main

* chore(connect): run scalafmt
  • Loading branch information
bvoiturier authored Feb 7, 2023
1 parent f810ee3 commit 7673278
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ trait ConnectionRepository[F[_]] {

def getConnectionRecords(): F[Seq[ConnectionRecord]]

def getConnectionRecordsByStates(states: ConnectionRecord.ProtocolState*): F[Seq[ConnectionRecord]]

def getConnectionRecord(recordId: UUID): F[Option[ConnectionRecord]]

def deleteConnectionRecord(recordId: UUID): F[Int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,12 @@ class ConnectionRepositoryInMemory(storeRef: Ref[Map[UUID, ConnectionRecord]]) e
} yield store.values.toSeq
}

override def getConnectionRecordsByStates(states: ConnectionRecord.ProtocolState*): Task[Seq[ConnectionRecord]] = {
for {
store <- storeRef.get
} yield store.values.filter(rec => states.contains(rec.protocolState)).toSeq
}

override def createConnectionRecord(record: ConnectionRecord): Task[Int] = {
for {
_ <- record.thid match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ trait ConnectionService {

def getConnectionRecords(): IO[ConnectionServiceError, Seq[ConnectionRecord]]

def getConnectionRecordsByStates(
states: ConnectionRecord.ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]]

def getConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]]

def deleteConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Int]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ private class ConnectionServiceImpl(
} yield records
}

override def getConnectionRecordsByStates(
states: ProtocolState*
): IO[ConnectionServiceError, Seq[ConnectionRecord]] = {
for {
records <- connectionRepository
.getConnectionRecordsByStates(states: _*)
.mapError(RepositoryError.apply)
} yield records
}

override def getConnectionRecord(recordId: UUID): IO[ConnectionServiceError, Option[ConnectionRecord]] = {
for {
record <- connectionRepository
Expand Down Expand Up @@ -189,6 +199,7 @@ private class ConnectionServiceImpl(
for {
record <- getRecordFromThreadIdAndState(
response.thid.orElse(response.pthid),
ProtocolState.ConnectionRequestPending,
ProtocolState.ConnectionRequestSent
)
_ <- connectionRepository
Expand Down Expand Up @@ -241,8 +252,8 @@ private class ConnectionServiceImpl(
}

private[this] def getRecordFromThreadIdAndState(
thid: Option[String], // TODO this should not be optional
state: ProtocolState
thid: Option[String],
states: ProtocolState*
): IO[ConnectionServiceError, ConnectionRecord] = {
for {
thid <- ZIO
Expand All @@ -256,8 +267,8 @@ private class ConnectionServiceImpl(
.fromOption(maybeRecord)
.mapError(_ => ThreadIdNotFound(thid))
_ <- record.protocolState match {
case s if s == state => ZIO.unit
case state => ZIO.fail(InvalidFlowStateError(s"Invalid protocol state for operation: $state"))
case s if states.contains(s) => ZIO.unit
case state => ZIO.fail(InvalidFlowStateError(s"Invalid protocol state for operation: $state"))
}
} yield record
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,55 @@ object ConnectionRepositorySpecSuite {
assertTrue(records.contains(bRecord))
}
},
test("deleteRecord deletes an exsiting record") {
test("getConnectionRecordsByStates returns correct records") {
for {
repo <- ZIO.service[ConnectionRepository[Task]]
aRecord = connectionRecord
bRecord = connectionRecord
cRecord = connectionRecord
_ <- repo.createConnectionRecord(aRecord)
_ <- repo.createConnectionRecord(bRecord)
_ <- repo.createConnectionRecord(cRecord)
_ <- repo.updateConnectionProtocolState(
aRecord.id,
ProtocolState.InvitationGenerated,
ProtocolState.ConnectionRequestReceived,
1
)
_ <- repo.updateConnectionProtocolState(
cRecord.id,
ProtocolState.InvitationGenerated,
ProtocolState.ConnectionResponsePending,
1
)
invitationGeneratedRecords <- repo.getConnectionRecordsByStates(ProtocolState.InvitationGenerated)
otherRecords <- repo.getConnectionRecordsByStates(
ProtocolState.ConnectionRequestReceived,
ProtocolState.ConnectionResponsePending
)
} yield {
assertTrue(invitationGeneratedRecords.size == 1) &&
assertTrue(invitationGeneratedRecords.contains(bRecord)) &&
assertTrue(otherRecords.size == 2) &&
assertTrue(otherRecords.exists(_.id == aRecord.id)) &&
assertTrue(otherRecords.exists(_.id == cRecord.id))
}
},
test("getConnectionRecordsByStates returns an empty list if 'states' parameter is empty") {
for {
repo <- ZIO.service[ConnectionRepository[Task]]
aRecord = connectionRecord
bRecord = connectionRecord
cRecord = connectionRecord
_ <- repo.createConnectionRecord(aRecord)
_ <- repo.createConnectionRecord(bRecord)
_ <- repo.createConnectionRecord(cRecord)
records <- repo.getConnectionRecordsByStates()
} yield {
assertTrue(records.isEmpty)
}
},
test("deleteRecord deletes an existing record") {
for {
repo <- ZIO.service[ConnectionRepository[Task]]
aRecord = connectionRecord
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
CREATE INDEX protocol_state_idx
ON public.connection_records (protocol_state);
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import zio.interop.catz.*

import java.time.Instant
import java.util.UUID
import cats.data.NonEmptyList

class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepository[Task] {

Expand Down Expand Up @@ -102,6 +103,38 @@ class JdbcConnectionRepository(xa: Transactor[Task]) extends ConnectionRepositor
.transact(xa)
}

override def getConnectionRecordsByStates(states: ConnectionRecord.ProtocolState*): Task[Seq[ConnectionRecord]] = {
states match
case Nil =>
ZIO.succeed(Nil)
case head +: tail =>
val nel = NonEmptyList.of(head, tail: _*)
val inClauseFragment = Fragments.in(fr"protocol_state", nel)
val cxnIO = sql"""
| SELECT
| id,
| created_at,
| updated_at,
| thid,
| label,
| role,
| protocol_state,
| invitation,
| connection_request,
| connection_response,
| meta_retries,
| meta_last_failure
| FROM public.connection_records
| WHERE $inClauseFragment
| LIMIT 50
""".stripMargin
.query[ConnectionRecord]
.to[Seq]

cxnIO
.transact(xa)
}

override def getConnectionRecord(recordId: UUID): Task[Option[ConnectionRecord]] = {
val cxnIO = sql"""
| SELECT
Expand Down

0 comments on commit 7673278

Please sign in to comment.