Skip to content

Commit

Permalink
AWS stream bytes decoder, event parser, and frame decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbanda committed Dec 18, 2024
1 parent b67242b commit 38cbced
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.cequence.openaiscala.anthropic.service.impl

import akka.NotUsed
import akka.stream.scaladsl.Flow

import java.util.Base64
import play.api.libs.json.{JsString, JsValue, Json}

object AwsEventStreamBytesDecoder {
def flow: Flow[JsValue, JsValue, NotUsed] = Flow[JsValue].map { eventJson =>
// eventJson might look like:
// { ":message-type":"event", ":event-type":"...", "bytes":"base64string" }

val base64Str = (eventJson \ "bytes").asOpt[String]
base64Str match {
case Some(encoded) =>
val decoded = Base64.getDecoder.decode(encoded)
Json.parse(decoded)
case None =>
// If there's no "bytes" field, return the original JSON (or handle differently)
eventJson
}
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.cequence.openaiscala.anthropic.service.impl

import akka.NotUsed
import play.api.libs.json.{JsValue, Json}
import akka.stream._
import akka.stream.scaladsl.Flow
import akka.util.ByteString

object AwsEventStreamEventParser {
def flow: Flow[ByteString, Option[JsValue], NotUsed] = Flow[ByteString].map { frame =>
val rawString = new String(frame.toArray)

if (rawString.contains("message-type")) {
val jsonString = rawString.dropWhile(_ != '{').takeWhile(_ != '}') + "}"
Some(Json.parse(jsonString))
} else
None
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package io.cequence.openaiscala.anthropic.service.impl

import akka.stream._
import akka.stream.stage._
import akka.util.ByteString

class AwsEventStreamFrameDecoder extends GraphStage[FlowShape[ByteString, ByteString]] {
val in = Inlet[ByteString]("AwsEventStreamFrameDecoder.in")
val out = Outlet[ByteString]("AwsEventStreamFrameDecoder.out")
override val shape = FlowShape(in, out)

private implicit val order = java.nio.ByteOrder.BIG_ENDIAN

override def createLogic(attrs: Attributes): GraphStageLogic = new GraphStageLogic(shape) {
var buffer = ByteString.empty

setHandler(in, new InHandler {
override def onPush(): Unit = {
buffer ++= grab(in)
emitFrames()
}
override def onUpstreamFinish(): Unit = {
emitFrames()
if (buffer.isEmpty) completeStage()
else failStage(new RuntimeException("Truncated frame at stream end"))
}
})

setHandler(out, new OutHandler {
override def onPull(): Unit = {
if (!hasBeenPulled(in)) pull(in)
}
})

def emitFrames(): Unit = {
while (buffer.size >= 4) {
val totalLength = buffer.iterator.getInt
println("buffer size: " + buffer.size)
println("total length: " + totalLength)
println("buffer: " + buffer.utf8String)

if (buffer.size < 4 + totalLength) {
// not enough data yet
return
}
val frame = buffer.slice(4, 4 + totalLength)
buffer = buffer.drop(4 + totalLength)
emit(out, frame)
}

if (!hasBeenPulled(in) && !isClosed(in)) {
pull(in)
}
}
}
}

0 comments on commit 38cbced

Please sign in to comment.