Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Set instance from the Runloop #804

Closed
wants to merge 2 commits into the base branch from the head branch (branch names lost in extraction)
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,8 +2,7 @@ package zio.kafka.consumer

import org.apache.kafka.clients.consumer.ConsumerRebalanceListener
import org.apache.kafka.common.TopicPartition
import zio.{ Runtime, Task, Unsafe, ZIO }
import scala.jdk.CollectionConverters._
import zio.{ Chunk, Runtime, Task, Unsafe, ZIO }

/**
* ZIO wrapper around Kafka's `ConsumerRebalanceListener` to work with Scala collection types and ZIO effects.
Expand All @@ -12,9 +11,9 @@ import scala.jdk.CollectionConverters._
* when this is not desired.
*/
final case class RebalanceListener(
onAssigned: (Set[TopicPartition], RebalanceConsumer) => Task[Unit],
onRevoked: (Set[TopicPartition], RebalanceConsumer) => Task[Unit],
onLost: (Set[TopicPartition], RebalanceConsumer) => Task[Unit]
onAssigned: (Chunk[TopicPartition], RebalanceConsumer) => Task[Unit],
onRevoked: (Chunk[TopicPartition], RebalanceConsumer) => Task[Unit],
onLost: (Chunk[TopicPartition], RebalanceConsumer) => Task[Unit]
) {

/**
Expand All @@ -36,7 +35,7 @@ final case class RebalanceListener(
partitions: java.util.Collection[TopicPartition]
): Unit = Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(onRevoked(partitions.asScala.toSet, consumer))
.run(onRevoked(Chunk.fromJavaIterable(partitions), consumer))
.getOrThrowFiberFailure()
()
}
Expand All @@ -45,7 +44,7 @@ final case class RebalanceListener(
partitions: java.util.Collection[TopicPartition]
): Unit = Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(onAssigned(partitions.asScala.toSet, consumer))
.run(onAssigned(Chunk.fromJavaIterable(partitions), consumer))
.getOrThrowFiberFailure()
()
}
Expand All @@ -54,7 +53,7 @@ final case class RebalanceListener(
partitions: java.util.Collection[TopicPartition]
): Unit = Unsafe.unsafe { implicit u =>
runtime.unsafe
.run(onLost(partitions.asScala.toSet, consumer))
.run(onLost(Chunk.fromJavaIterable(partitions), consumer))
.getOrThrowFiberFailure()
()
}
Expand All @@ -64,8 +63,8 @@ final case class RebalanceListener(

object RebalanceListener {
def apply(
onAssigned: (Set[TopicPartition], RebalanceConsumer) => Task[Unit],
onRevoked: (Set[TopicPartition], RebalanceConsumer) => Task[Unit]
onAssigned: (Chunk[TopicPartition], RebalanceConsumer) => Task[Unit],
onRevoked: (Chunk[TopicPartition], RebalanceConsumer) => Task[Unit]
): RebalanceListener =
RebalanceListener(onAssigned, onRevoked, onRevoked)

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,14 +2,15 @@ package zio.kafka.consumer.diagnostics

import org.apache.kafka.clients.consumer.OffsetAndMetadata
import org.apache.kafka.common.TopicPartition
import zio.Chunk

sealed trait DiagnosticEvent
object DiagnosticEvent {

final case class Poll(
tpRequested: Set[TopicPartition],
tpWithData: Set[TopicPartition],
tpWithoutData: Set[TopicPartition]
tpRequested: Chunk[TopicPartition],
tpWithData: Chunk[TopicPartition],
tpWithoutData: Chunk[TopicPartition]
) extends DiagnosticEvent
final case class Request(partition: TopicPartition) extends DiagnosticEvent

Expand All @@ -22,9 +23,9 @@ object DiagnosticEvent {

sealed trait Rebalance extends DiagnosticEvent
object Rebalance {
final case class Revoked(partitions: Set[TopicPartition]) extends Rebalance
final case class Assigned(partitions: Set[TopicPartition]) extends Rebalance
final case class Lost(partitions: Set[TopicPartition]) extends Rebalance
final case class Revoked(partitions: Chunk[TopicPartition]) extends Rebalance
final case class Assigned(partitions: Chunk[TopicPartition]) extends Rebalance
final case class Lost(partitions: Chunk[TopicPartition]) extends Rebalance
}

}
zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala — 45 changes: 24 additions & 21 deletions
Original file line number | Diff line number | Diff line change
Expand Up @@ -270,13 +270,13 @@ private[consumer] final class Runloop private (
private def offerRecordsToStreams(
partitionStreams: Chunk[PartitionStreamControl],
pendingRequests: Chunk[Request],
ignoreRecordsForTps: Set[TopicPartition],
ignoreRecordsForTps: Chunk[TopicPartition],
polledRecords: ConsumerRecords[Array[Byte], Array[Byte]]
): UIO[Runloop.FulfillResult] = {
// The most efficient way to get the records from [[ConsumerRecords]] per
// topic-partition, is by first getting the set of topic-partitions, and
// then requesting the records per topic-partition.
val tps = polledRecords.partitions().asScala.toSet -- ignoreRecordsForTps
val tps = Chunk.fromJavaIterable(polledRecords.partitions()) diff ignoreRecordsForTps
val fulfillResult = Runloop.FulfillResult(pendingRequests = pendingRequests.filter(req => !tps.contains(req.tp)))
val streams =
if (tps.isEmpty) Chunk.empty else partitionStreams.filter(streamControl => tps.contains(streamControl.tp))
Expand Down Expand Up @@ -313,26 +313,29 @@ private[consumer] final class Runloop private (
if (hasGroupId) consumer.withConsumer(_.groupMetadata()).fold(_ => None, Some(_))
else ZIO.none

private def doSeekForNewPartitions(c: ByteArrayKafkaConsumer, tps: Set[TopicPartition]): Task[Set[TopicPartition]] =
private def doSeekForNewPartitions(
c: ByteArrayKafkaConsumer,
tps: Chunk[TopicPartition]
): Task[Chunk[TopicPartition]] =
offsetRetrieval match {
case OffsetRetrieval.Manual(getOffsets) =>
getOffsets(tps)
getOffsets(tps.toSet)
.tap(offsets => ZIO.foreachDiscard(offsets) { case (tp, offset) => ZIO.attempt(c.seek(tp, offset)) })
.when(tps.nonEmpty)
.as(tps)

case OffsetRetrieval.Auto(_) =>
ZIO.succeed(Set.empty)
ZIO.succeed(Chunk.empty)
}

// Pause partitions for which there is no demand and resume those for which there is now demand
private def resumeAndPausePartitions(
c: ByteArrayKafkaConsumer,
assignment: Set[TopicPartition],
requestedPartitions: Set[TopicPartition]
assignment: Chunk[TopicPartition],
requestedPartitions: Chunk[TopicPartition]
): Unit = {
val toResume = assignment intersect requestedPartitions
val toPause = assignment -- requestedPartitions
val toPause = assignment diff requestedPartitions

if (toResume.nonEmpty) c.resume(toResume.asJava)
if (toPause.nonEmpty) c.pause(toPause.asJava)
Expand All @@ -354,8 +357,8 @@ private[consumer] final class Runloop private (
_ <- rebalanceListenerEvent.set(RebalanceEvent.None)
pollResult <-
consumer.withConsumerZIO { c =>
val prevAssigned = c.assignment().asScala.toSet
val requestedPartitions = state.pendingRequests.map(_.tp).toSet
val prevAssigned = Chunk.fromJavaIterable(c.assignment())
val requestedPartitions = state.pendingRequests.map(_.tp)

resumeAndPausePartitions(c, prevAssigned, requestedPartitions)

Expand All @@ -376,18 +379,18 @@ private[consumer] final class Runloop private (
// either because they are restarting, or because they
// are new.
val startingTps =
if (restartStreamsOnRebalancing) c.assignment().asScala.toSet
if (restartStreamsOnRebalancing) Chunk.fromJavaIterable(c.assignment())
else newlyAssigned

for {
ignoreRecordsForTps <- doSeekForNewPartitions(c, newlyAssigned)

_ <- diagnostics.emitIfEnabled {
val providedTps = records.partitions().asScala.toSet
val providedTps = Chunk.fromJavaIterable(records.partitions())
DiagnosticEvent.Poll(
tpRequested = requestedPartitions,
tpWithData = providedTps,
tpWithoutData = requestedPartitions -- providedTps
tpWithoutData = requestedPartitions diff providedTps
)
}

Expand All @@ -408,7 +411,7 @@ private[consumer] final class Runloop private (
runningStreams <- ZIO.filter(state.assignedStreams)(_.acceptsData)
updatedStreams = runningStreams ++ startingStreams
updatedPendingRequests = {
val streamTps = updatedStreams.map(_.tp).toSet
val streamTps = updatedStreams.map(_.tp)
state.pendingRequests.filter(req => streamTps.contains(req.tp))
}
fulfillResult <- offerRecordsToStreams(
Expand Down Expand Up @@ -563,17 +566,17 @@ private[consumer] object Runloop {

private final case class PollResult(
newCommits: Chunk[Commit],
startingTps: Set[TopicPartition],
startingTps: Chunk[TopicPartition],
records: ConsumerRecords[Array[Byte], Array[Byte]],
ignoreRecordsForTps: Set[TopicPartition]
ignoreRecordsForTps: Chunk[TopicPartition]
)
private object PollResult {
def apply(records: ConsumerRecords[Array[Byte], Array[Byte]]): PollResult =
PollResult(
newCommits = Chunk.empty,
startingTps = Set.empty,
startingTps = Chunk.empty,
records = records,
ignoreRecordsForTps = Set.empty
ignoreRecordsForTps = Chunk.empty
)
}

Expand All @@ -583,10 +586,10 @@ private[consumer] object Runloop {

private final case class RebalanceEvent(
wasInvoked: Boolean,
newlyAssigned: Set[TopicPartition],
newlyAssigned: Chunk[TopicPartition],
pendingCommits: Chunk[Commit]
) {
def onAssigned(assigned: Set[TopicPartition], commits: Chunk[Commit]): RebalanceEvent =
def onAssigned(assigned: Chunk[TopicPartition], commits: Chunk[Commit]): RebalanceEvent =
RebalanceEvent(
wasInvoked = true,
newlyAssigned = newlyAssigned ++ assigned,
Expand All @@ -600,7 +603,7 @@ private[consumer] object Runloop {
}

private object RebalanceEvent {
val None: RebalanceEvent = RebalanceEvent(wasInvoked = false, Set.empty, Chunk.empty)
val None: RebalanceEvent = RebalanceEvent(wasInvoked = false, Chunk.empty, Chunk.empty)
}

sealed trait Command
Expand Down