Skip to content

Commit

Permalink
Improved the performance of ParallelStream
Browse files Browse the repository at this point in the history
  • Loading branch information
darkfrog26 committed Dec 19, 2024
1 parent 38fe6a5 commit 2d9e1f6
Show file tree
Hide file tree
Showing 9 changed files with 1,506 additions and 161 deletions.
1,408 changes: 1,408 additions & 0 deletions benchmark/results/benchmarks-0.3.0.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions core/jvm/src/test/scala/spec/ParallelStreamSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ class ParallelStreamSpec extends AnyWordSpec with Matchers {
Task(i * 2)
}
val task = stream.toList
task.sync().sorted should be(List(2, 4, 6, 8, 10))
task.sync() should be(List(2, 4, 6, 8, 10))
}
"correctly toList with random sleeps" in {
val stream = Stream.emits(List(1, 2, 3, 4, 5)).par() { i =>
Task.sleep((Math.random() * 1000).toInt.millis).map(_ => i * 2)
}
val task = stream.toList
task.sync().sorted should be(List(2, 4, 6, 8, 10))
task.sync() should be(List(2, 4, 6, 8, 10))
}
"correctly toList with random sleeps and overflowing maxBuffer" in {
val stream = Stream.emits(0 until 100_000).par(maxBuffer = 100) { i =>
Expand Down
30 changes: 30 additions & 0 deletions core/shared/src/main/scala/rapid/LockFreeQueue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package rapid

import java.util.concurrent.atomic.AtomicInteger

class LockFreeQueue[A](capacity: Int) {
private val buffer = new Array[AnyRef](capacity)
private val head = new AtomicInteger(0)
private val tail = new AtomicInteger(0)

def enqueue(value: A): Boolean = {
val currentTail = tail.get()
val nextTail = (currentTail + 1) % capacity
if (nextTail != head.get()) {
buffer(currentTail) = value.asInstanceOf[AnyRef]
tail.set(nextTail)
true
} else false // Queue full
}

def dequeue(): Opt[A] = {
val currentHead = head.get()
if (currentHead == tail.get()) Opt.Empty // Queue empty
else {
val value = buffer(currentHead).asInstanceOf[A]
buffer(currentHead) = null // Avoid memory leaks
head.set((currentHead + 1) % capacity)
Opt.Value(value)
}
}
}
3 changes: 3 additions & 0 deletions core/shared/src/main/scala/rapid/Opt.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package rapid
sealed trait Opt[+A] extends Any {
def isEmpty: Boolean
def isNonEmpty: Boolean = !isEmpty
def foreach(f: A => Unit): Unit = ()
}

object Opt {
case class Value[+A](value: A) extends AnyVal with Opt[A] {
override def isEmpty: Boolean = false

override def foreach(f: A => Unit): Unit = f(value)
}
case object Empty extends Opt[Nothing] {
override def isEmpty: Boolean = true
Expand Down
20 changes: 4 additions & 16 deletions core/shared/src/main/scala/rapid/ParallelStream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ import scala.collection.mutable.ListBuffer
case class ParallelStream[T, R](stream: Stream[T],
f: T => Task[R],
maxThreads: Int,
maxBuffer: Int,
ordered: Boolean) {
maxBuffer: Int) {
def drain: Task[Unit] = Task.unit.flatMap { _ =>
val completable = Task.completable[Unit]
compile(_ => (), _ => completable.success(()))
Expand All @@ -26,22 +25,11 @@ case class ParallelStream[T, R](stream: Stream[T],
completable
}

protected def compile(handle: R => Unit, complete: Int => Unit): Unit = if (ordered) {
ParallelStreamProcessor(
stream = this,
handle = handle,
complete = complete
)
} else {
ParallelUnorderedStreamProcessor(
stream = this,
handle = handle,
complete = complete
)
}
protected def compile(handle: R => Unit, complete: Int => Unit): Unit =
ParallelStreamProcessor(this, handle, complete)
}

object ParallelStream {
val DefaultMaxThreads: Int = Runtime.getRuntime.availableProcessors * 2
val DefaultMaxBuffer: Int = 1_000
val DefaultMaxBuffer: Int = 100_000
}
100 changes: 55 additions & 45 deletions core/shared/src/main/scala/rapid/ParallelStreamProcessor.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package rapid

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec

Expand All @@ -9,60 +8,71 @@ case class ParallelStreamProcessor[T, R](stream: ParallelStream[T, R],
complete: Int => Unit) {
private val iteratorTask: Task[Iterator[T]] = Stream.task(stream.stream)

private val ready = new AtomicIndexedQueue[R](stream.maxBuffer)
private val processing = new ConcurrentLinkedQueue[(Int, Task[R])]
private val queue = new LockFreeQueue[(T, Int)](stream.maxBuffer)
private val ready = new LockFreeQueue[R](stream.maxBuffer)
@volatile private var _total = -1

// Start a fiber that consumes the stream and queues tasks
private val queuingFiber: Fiber[Unit] = iteratorTask.map { iterator =>
var total = 0
iterator.zipWithIndex.foreach {
case (t, index) =>
processing.add(index -> stream.f(t))
total = index + 1
// Feed the iterator into the queue until empty
iteratorTask.map { iterator =>
var counter = 0
iterator.zipWithIndex.foreach { tuple =>
while (!queue.enqueue(tuple)) {
Thread.`yield`()
}
counter += 1
}
_total = counter
}.start()

// Process the queue and feed into ready
{
val counter = new AtomicInteger(0)

@tailrec
def recurse(): Unit = {
val next = queue.dequeue()
next.foreach {
case (t, index) =>
val r = stream.f(t).sync()
while (counter.get() != index) {
Thread.`yield`()
}
while (!ready.enqueue(r)) {
Thread.`yield`()
}
counter.incrementAndGet()
}
if (next.isEmpty && _total != -1) {
// Finished
} else {
recurse()
}
}

val tasks = (0 until stream.maxThreads).toList.map { _ =>
Task(recurse())
}
_total = total
TaskSeqOps(tasks).tasks
}.start()

def total: Option[Int] = if (_total == -1) None else Some(_total)

// Spawn worker fibers to process tasks
(0 until stream.maxThreads).foreach { _ =>
Task(processRecursive()).start()
}
// Processes through the ready queue feeding to handle and finally complete
Task(handleNext(0)).start()

@tailrec
private def processRecursive(): Unit = {
val next = processing.poll()
if (next != null) {
val (index, task) = next
val result = task.sync()
ready.add(index, result)
} else {
// No next task
Thread.sleep(1) // Consider a better signaling mechanism
}
// If total known and no more tasks, stop recursion
if (_total != -1 && processing.isEmpty) {
// Done processing
private def handleNext(counter: Int): Unit = {
val next = ready.dequeue()
if (_total == counter) {
complete(counter)
} else {
processRecursive()
}
}

// Fiber to consume results from 'ready' and handle them
Task {
var count = 0
while (_total == -1 || count < _total) {
ready.blockingPoll() match {
case Some(r) =>
handle(r)
count += 1
case None =>
// No result ready yet
Thread.sleep(1) // Again, consider using proper synchronization
val c = next match {
case Opt.Value(value) =>
handle(value)
counter + 1
case Opt.Empty => counter
}
handleNext(c)
}
complete(_total)
}.start()
}
}

This file was deleted.

8 changes: 2 additions & 6 deletions core/shared/src/main/scala/rapid/Queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ case class Queue[T](maxSize: Int) {

def isEmpty: Boolean = size == 0

def tryAdd(value: T): Boolean = {
def enqueue(value: T): Boolean = {
var incremented = false
s.updateAndGet((operand: Int) => {
if (operand < maxSize) {
Expand All @@ -28,11 +28,7 @@ case class Queue[T](maxSize: Int) {
incremented
}

def add(value: T): Unit = while (!tryAdd(value)) {
Thread.`yield`()
}

def poll(): Opt[T] = {
def dequeue(): Opt[T] = {
val o = Opt(q.poll())
if (o.isNonEmpty) {
s.decrementAndGet()
Expand Down
47 changes: 2 additions & 45 deletions core/shared/src/main/scala/rapid/Stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,58 +115,15 @@ class Stream[Return](private val task: Task[Iterator[Return]]) extends AnyVal {
def count: Task[Int] = task.map(_.size)

def par[R](maxThreads: Int = ParallelStream.DefaultMaxThreads,
maxBuffer: Int = ParallelStream.DefaultMaxBuffer,
ordered: Boolean = false)
maxBuffer: Int = ParallelStream.DefaultMaxBuffer)
(f: Return => Task[R]): ParallelStream[Return, R] = ParallelStream(
stream = this,
f = f,
maxThreads = maxThreads,
maxBuffer = maxBuffer,
ordered = ordered
maxBuffer = maxBuffer
)
}

/*trait Stream[Return] { stream =>
/**
* Produces the next value in the stream, if any.
*
* @return a `Pull` that produces an optional pair of the next value and the remaining stream
*/
def pull: Pull[Option[(Return, Stream[Return])]]
/**
* Transforms the values in the stream using the given function that returns a task, with a maximum concurrency.
*
* @param maxConcurrency the maximum number of concurrent tasks
* @param f the function to transform the values into tasks
* @tparam T the type of the values in the tasks
* @return a new stream with the transformed values
*/
def parEvalMap[T](maxConcurrency: Int)(f: Return => Task[T]): Stream[T] = new Stream[T] {
val semaphore = new Semaphore(maxConcurrency)
def pull: Pull[Option[(T, Stream[T])]] = Pull.suspend {
if (semaphore.tryAcquire()) {
stream.pull.flatMap {
case Some((head, tail)) =>
val task = f(head)
Pull.suspend {
task.map { result =>
semaphore.release()
Option(result -> tail.parEvalMap(maxConcurrency)(f))
}.toPull
}
case None =>
semaphore.release()
Pull.pure(None)
}
} else {
Pull.suspend(pull)
}
}
}
}*/

object Stream {
/**
* Creates a stream that emits a single value.
Expand Down

0 comments on commit 2d9e1f6

Please sign in to comment.