Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/3654 prompt caching #88

Merged
merged 13 commits into from
Nov 13, 2024
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.cequence.openaiscala.anthropic

import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.{ImageBlock, TextBlock}
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.{MediaBlock, TextBlock}
import io.cequence.openaiscala.anthropic.domain.Content.{
ContentBlock,
ContentBlockBase,
ContentBlocks,
SingleString
}
Expand All @@ -19,9 +19,10 @@ import io.cequence.openaiscala.anthropic.domain.response.{
CreateMessageResponse,
DeltaText
}
import io.cequence.openaiscala.anthropic.domain.{ChatRole, Content, Message}
import io.cequence.openaiscala.anthropic.domain.{CacheControl, ChatRole, Content, Message}
import io.cequence.wsclient.JsonUtil
import play.api.libs.functional.syntax._
import play.api.libs.json.JsonNaming.SnakeCase
import play.api.libs.json._

object JsonFormats extends JsonFormats
Expand All @@ -32,6 +33,84 @@ trait JsonFormats {
JsonUtil.enumFormat[ChatRole](ChatRole.allValues: _*)
implicit lazy val usageInfoFormat: Format[UsageInfo] = Json.format[UsageInfo]

def writeJsObject(cacheControl: CacheControl): JsObject = cacheControl match {
case CacheControl.Ephemeral =>
Json.obj("cache_control" -> Json.obj("type" -> "ephemeral"))
}

implicit lazy val cacheControlFormat: Format[CacheControl] = new Format[CacheControl] {
def reads(json: JsValue): JsResult[CacheControl] = json match {
case JsObject(map) =>
if (map == Map("type" -> JsString("ephemeral"))) JsSuccess(CacheControl.Ephemeral)
else JsError(s"Invalid cache control $map")
case x => {
JsError(s"Invalid cache control ${x}")
}
}

def writes(cacheControl: CacheControl): JsValue = writeJsObject(cacheControl)
}

implicit lazy val cacheControlOptionFormat: Format[Option[CacheControl]] =
new Format[Option[CacheControl]] {
def reads(json: JsValue): JsResult[Option[CacheControl]] = json match {
case JsNull => JsSuccess(None)
case _ => cacheControlFormat.reads(json).map(Some(_))
}

def writes(option: Option[CacheControl]): JsValue = option match {
case None => JsNull
case Some(cacheControl) => cacheControlFormat.writes(cacheControl)
}
}

implicit lazy val contentBlockBaseWrites: Writes[ContentBlockBase] = {
case ContentBlockBase(textBlock @ TextBlock(_), cacheControl) =>
Json.obj("type" -> "text") ++
Json.toJson(textBlock)(textBlockWrites).as[JsObject] ++
cacheControlToJsObject(cacheControl)
case ContentBlockBase(media @ MediaBlock(_, _, _, _), maybeCacheControl) =>
Json.toJson(media)(mediaBlockWrites).as[JsObject] ++
cacheControlToJsObject(maybeCacheControl)

}

implicit lazy val contentBlockBaseReads: Reads[ContentBlockBase] =
(json: JsValue) => {
(json \ "type").validate[String].flatMap {
case "text" =>
((json \ "text").validate[String] and
(json \ "cache_control").validateOpt[CacheControl]).tupled.flatMap {
case (text, cacheControl) =>
JsSuccess(ContentBlockBase(TextBlock(text), cacheControl))
case _ => JsError("Invalid text block")
}

case imageOrDocument @ ("image" | "document") =>
for {
source <- (json \ "source").validate[JsObject]
`type` <- (source \ "type").validate[String]
mediaType <- (source \ "media_type").validate[String]
data <- (source \ "data").validate[String]
cacheControl <- (json \ "cache_control").validateOpt[CacheControl]
} yield ContentBlockBase(
MediaBlock(imageOrDocument, `type`, mediaType, data),
cacheControl
)

case _ => JsError("Unsupported or invalid content block")
}
}

implicit lazy val contentBlockBaseFormat: Format[ContentBlockBase] = Format(
contentBlockBaseReads,
contentBlockBaseWrites
)
implicit lazy val contentBlockBaseSeqFormat: Format[Seq[ContentBlockBase]] = Format(
Reads.seq(contentBlockBaseReads),
Writes.seq(contentBlockBaseWrites)
)

implicit lazy val userMessageFormat: Format[UserMessage] = Json.format[UserMessage]
implicit lazy val userMessageContentFormat: Format[UserMessageContent] =
Json.format[UserMessageContent]
Expand All @@ -44,92 +123,114 @@ trait JsonFormats {

implicit lazy val contentBlocksFormat: Format[ContentBlocks] = Json.format[ContentBlocks]

// implicit val textBlockWrites: Writes[TextBlock] = Json.writes[TextBlock]
implicit val textBlockReads: Reads[TextBlock] = Json.reads[TextBlock]
implicit lazy val textBlockReads: Reads[TextBlock] = {
implicit val config: JsonConfiguration = JsonConfiguration(SnakeCase)
Json.reads[TextBlock]
}

implicit lazy val textBlockWrites: Writes[TextBlock] = {
implicit val config: JsonConfiguration = JsonConfiguration(SnakeCase)
Json.writes[TextBlock]
}

implicit val textBlockWrites: Writes[TextBlock] = Json.writes[TextBlock]
implicit val imageBlockWrites: Writes[ImageBlock] =
(block: ImageBlock) =>
implicit lazy val mediaBlockWrites: Writes[MediaBlock] =
(block: MediaBlock) =>
Json.obj(
"type" -> "image",
"type" -> block.`type`,
"source" -> Json.obj(
"type" -> block.`type`,
"type" -> block.encoding,
"media_type" -> block.mediaType,
"data" -> block.data
)
)

implicit val contentBlockWrites: Writes[ContentBlock] = {
case tb: TextBlock =>
Json.obj("type" -> "text") ++ Json.toJson(tb)(textBlockWrites).as[JsObject]
case ib: ImageBlock => Json.toJson(ib)(imageBlockWrites)
}

implicit val contentBlockReads: Reads[ContentBlock] =
(json: JsValue) => {
(json \ "type").validate[String].flatMap {
case "text" => (json \ "text").validate[String].map(TextBlock.apply)
case "image" =>
for {
source <- (json \ "source").validate[JsObject]
`type` <- (source \ "type").validate[String]
mediaType <- (source \ "media_type").validate[String]
data <- (source \ "data").validate[String]
} yield ImageBlock(`type`, mediaType, data)
case _ => JsError("Unsupported or invalid content block")
}
}
private def cacheControlToJsObject(maybeCacheControl: Option[CacheControl]): JsObject =
maybeCacheControl.fold(Json.obj())(cc => writeJsObject(cc))

implicit val contentReads: Reads[Content] = new Reads[Content] {
implicit lazy val contentReads: Reads[Content] = new Reads[Content] {
def reads(json: JsValue): JsResult[Content] = json match {
case JsString(str) => JsSuccess(SingleString(str))
case JsArray(_) => Json.fromJson[Seq[ContentBlock]](json).map(ContentBlocks(_))
case JsArray(_) => Json.fromJson[Seq[ContentBlockBase]](json).map(ContentBlocks(_))
case _ => JsError("Invalid content format")
}
}

implicit val baseMessageWrites: Writes[Message] = new Writes[Message] {
implicit lazy val contentWrites: Writes[Content] = new Writes[Content] {
def writes(content: Content): JsValue = content match {
case SingleString(text, cacheControl) =>
Json.obj("content" -> text) ++ cacheControlToJsObject(cacheControl)
case ContentBlocks(blocks) =>
Json.obj("content" -> Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites)))
}
}

implicit lazy val baseMessageWrites: Writes[Message] = new Writes[Message] {
def writes(message: Message): JsValue = message match {
case UserMessage(content) => Json.obj("role" -> "user", "content" -> content)
case UserMessage(content, cacheControl) =>
val baseObj = Json.obj("role" -> "user", "content" -> content)
baseObj ++ cacheControlToJsObject(cacheControl)

case UserMessageContent(content) =>
Json.obj(
"role" -> "user",
"content" -> content.map(Json.toJson(_)(contentBlockWrites))
"content" -> content.map(Json.toJson(_)(contentBlockBaseWrites))
)
case AssistantMessage(content) => Json.obj("role" -> "assistant", "content" -> content)

case AssistantMessage(content, cacheControl) =>
val baseObj = Json.obj("role" -> "assistant", "content" -> content)
baseObj ++ cacheControlToJsObject(cacheControl)

case AssistantMessageContent(content) =>
Json.obj(
"role" -> "assistant",
"content" -> content.map(Json.toJson(_)(contentBlockWrites))
"content" -> content.map(Json.toJson(_)(contentBlockBaseWrites))
)
// Add cases for other subclasses if necessary
}
}

implicit val baseMessageReads: Reads[Message] = (
implicit lazy val baseMessageReads: Reads[Message] = (
(__ \ "role").read[String] and
(__ \ "content").lazyRead(contentReads)
(__ \ "content").read[JsValue] and
(__ \ "cache_control").readNullable[CacheControl]
).tupled.flatMap {
case ("user", SingleString(text)) => Reads.pure(UserMessage(text))
case ("user", ContentBlocks(blocks)) => Reads.pure(UserMessageContent(blocks))
case ("assistant", SingleString(text)) => Reads.pure(AssistantMessage(text))
case ("assistant", ContentBlocks(blocks)) => Reads.pure(AssistantMessageContent(blocks))
case ("user", JsString(str), cacheControl) => Reads.pure(UserMessage(str, cacheControl))
case ("user", json @ JsArray(_), _) => {
Json.fromJson[Seq[ContentBlockBase]](json) match {
case JsSuccess(contentBlocks, _) =>
Reads.pure(UserMessageContent(contentBlocks))
case JsError(errors) =>
Reads(_ => JsError(errors))
}
}
case ("assistant", JsString(str), cacheControl) =>
Reads.pure(AssistantMessage(str, cacheControl))

case ("assistant", json @ JsArray(_), _) => {
Json.fromJson[Seq[ContentBlockBase]](json) match {
case JsSuccess(contentBlocks, _) =>
Reads.pure(AssistantMessageContent(contentBlocks))
case JsError(errors) =>
Reads(_ => JsError(errors))
}
}
case _ => Reads(_ => JsError("Unsupported role or content type"))
}

implicit val createMessageResponseReads: Reads[CreateMessageResponse] = (
implicit lazy val createMessageResponseReads: Reads[CreateMessageResponse] = (
(__ \ "id").read[String] and
(__ \ "role").read[ChatRole] and
(__ \ "content").read[Seq[ContentBlock]].map(ContentBlocks(_)) and
(__ \ "content").read[Seq[ContentBlockBase]].map(ContentBlocks(_)) and
(__ \ "model").read[String] and
(__ \ "stop_reason").readNullable[String] and
(__ \ "stop_sequence").readNullable[String] and
(__ \ "usage").read[UsageInfo]
)(CreateMessageResponse.apply _)

implicit val createMessageChunkResponseReads: Reads[CreateMessageChunkResponse] =
implicit lazy val createMessageChunkResponseReads: Reads[CreateMessageChunkResponse] =
Json.reads[CreateMessageChunkResponse]

implicit val deltaTextReads: Reads[DeltaText] = Json.reads[DeltaText]
implicit val contentBlockDeltaReads: Reads[ContentBlockDelta] = Json.reads[ContentBlockDelta]
implicit lazy val deltaTextReads: Reads[DeltaText] = Json.reads[DeltaText]
implicit lazy val contentBlockDeltaReads: Reads[ContentBlockDelta] =
Json.reads[ContentBlockDelta]
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,80 @@ package io.cequence.openaiscala.anthropic.domain

sealed trait Content

sealed trait CacheControl
object CacheControl {
case object Ephemeral extends CacheControl
}

trait Cacheable {
def cacheControl: Option[CacheControl]
}

object Content {
case class SingleString(text: String) extends Content
case class SingleString(
text: String,
override val cacheControl: Option[CacheControl] = None
) extends Content
with Cacheable

case class ContentBlocks(blocks: Seq[ContentBlock]) extends Content
case class ContentBlocks(blocks: Seq[ContentBlockBase]) extends Content

case class ContentBlockBase(
content: ContentBlock,
override val cacheControl: Option[CacheControl] = None
) extends Content
with Cacheable

sealed trait ContentBlock

object ContentBlock {
case class TextBlock(text: String) extends ContentBlock
case class ImageBlock(

case class MediaBlock(
`type`: String,
encoding: String,
mediaType: String,
data: String
) extends ContentBlock

object MediaBlock {
def pdf(
data: String,
cacheControl: Option[CacheControl] = None
): ContentBlockBase =
ContentBlockBase(
MediaBlock("document", "base64", "application/pdf", data),
cacheControl
)

def image(
mediaType: String
)(
data: String,
cacheControl: Option[CacheControl] = None
): ContentBlockBase =
ContentBlockBase(MediaBlock("image", "base64", mediaType, data), cacheControl)

def jpeg(
data: String,
cacheControl: Option[CacheControl] = None
): ContentBlockBase = image("image/jpeg")(data, cacheControl)

def png(
data: String,
cacheControl: Option[CacheControl] = None
): ContentBlockBase = image("image/png")(data, cacheControl)

def gif(
data: String,
cacheControl: Option[CacheControl] = None
): ContentBlockBase = image("image/gif")(data, cacheControl)

def webp(
data: String,
cacheControl: Option[CacheControl] = None
): ContentBlockBase = image("image/webp")(data, cacheControl)
}

}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package io.cequence.openaiscala.anthropic.domain

import io.cequence.openaiscala.anthropic.domain.Content.{
ContentBlock,
ContentBlockBase,
ContentBlocks,
SingleString
}
Expand All @@ -13,12 +13,19 @@ sealed abstract class Message private (

object Message {

case class UserMessage(contentString: String)
extends Message(ChatRole.User, SingleString(contentString))
case class UserMessageContent(contentBlocks: Seq[ContentBlock])
case class UserMessage(
contentString: String,
cacheControl: Option[CacheControl] = None
) extends Message(ChatRole.User, SingleString(contentString, cacheControl))

case class UserMessageContent(contentBlocks: Seq[ContentBlockBase])
extends Message(ChatRole.User, ContentBlocks(contentBlocks))
case class AssistantMessage(contentString: String)
extends Message(ChatRole.Assistant, SingleString(contentString))
case class AssistantMessageContent(contentBlocks: Seq[ContentBlock])

case class AssistantMessage(
contentString: String,
cacheControl: Option[CacheControl] = None
) extends Message(ChatRole.Assistant, SingleString(contentString, cacheControl))

case class AssistantMessageContent(contentBlocks: Seq[ContentBlockBase])
extends Message(ChatRole.Assistant, ContentBlocks(contentBlocks))
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ final case class AnthropicCreateMessageSettings(
// See [[models|https://docs.anthropic.com/claude/docs/models-overview]] for additional details and options.
model: String,

// System prompt.
// A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role. See our [[guide to system prompts|https://docs.anthropic.com/claude/docs/system-prompts]].
system: Option[String] = None,
// // System prompt.
// // A system prompt is a way of providing context and instructions to Claude, such as specifying a particular goal or role. See our [[guide to system prompts|https://docs.anthropic.com/claude/docs/system-prompts]].
// system: Option[String] = None,

// The maximum number of tokens to generate before stopping.
// Note that our models may stop before reaching this maximum. This parameter only specifies the absolute maximum number of tokens to generate.
Expand Down
Loading
Loading