Skip to content

Commit

Permalink
Merge pull request #90 from cequence-io/feature/3654-prompt-caching
Browse files Browse the repository at this point in the history
Feature/3654 prompt caching
  • Loading branch information
peterbanda authored Nov 14, 2024
2 parents a9f92d6 + d281337 commit 579dea3
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ sealed trait ChatRole extends EnumValue {
}

object ChatRole {
case object System extends ChatRole
case object User extends ChatRole
case object Assistant extends ChatRole

def allValues: Seq[ChatRole] = Seq(User, Assistant)
def allValues: Seq[ChatRole] = Seq(System, User, Assistant)
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,20 @@ import io.cequence.openaiscala.anthropic.domain.Content.{
sealed abstract class Message private (
val role: ChatRole,
val content: Content
)
) {
def isSystem: Boolean = role == ChatRole.System
}

object Message {

case class SystemMessage(
contentString: String,
cacheControl: Option[CacheControl] = None
) extends Message(ChatRole.System, SingleString(contentString, cacheControl))

case class SystemMessageContent(contentBlocks: Seq[ContentBlockBase])
extends Message(ChatRole.System, ContentBlocks(contentBlocks))

case class UserMessage(
contentString: String,
cacheControl: Option[CacheControl] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.anthropic.service

import akka.NotUsed
import akka.stream.scaladsl.Source
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
import io.cequence.openaiscala.anthropic.domain.Message
import io.cequence.openaiscala.anthropic.domain.response.{
ContentBlockDelta,
CreateMessageResponse
Expand Down Expand Up @@ -32,7 +32,6 @@ trait AnthropicService extends CloseableService with AnthropicServiceConsts {
* <a href="https://docs.anthropic.com/claude/reference/messages_post">Anthropic Doc</a>
*/
def createMessage(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings = DefaultSettings.CreateMessage
): Future[CreateMessageResponse]
Expand All @@ -55,7 +54,6 @@ trait AnthropicService extends CloseableService with AnthropicServiceConsts {
* <a href="https://docs.anthropic.com/claude/reference/messages_post">Anthropic Doc</a>
*/
def createMessageStreamed(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings = DefaultSettings.CreateMessage
): Source[ContentBlockDelta, NotUsed]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import akka.NotUsed
import akka.stream.scaladsl.Source
import io.cequence.openaiscala.OpenAIScalaClientException
import io.cequence.openaiscala.anthropic.JsonFormats
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, SystemMessageContent}
import io.cequence.openaiscala.anthropic.domain.response.{
ContentBlockDelta,
CreateMessageResponse
Expand Down Expand Up @@ -33,20 +34,17 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
private val logger = LoggerFactory.getLogger("AnthropicServiceImpl")

override def createMessage(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings
): Future[CreateMessageResponse] =
execPOST(
EndPoint.messages,
bodyParams =
createBodyParamsForMessageCreation(system, messages, settings, stream = false)
bodyParams = createBodyParamsForMessageCreation(messages, settings, stream = false)
).map(
_.asSafeJson[CreateMessageResponse]
)

override def createMessageStreamed(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings
): Source[ContentBlockDelta, NotUsed] =
Expand All @@ -55,7 +53,7 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
EndPoint.messages.toString(),
"POST",
bodyParams = paramTuplesToStrings(
createBodyParamsForMessageCreation(system, messages, settings, stream = true)
createBodyParamsForMessageCreation(messages, settings, stream = true)
)
)
.map { (json: JsValue) =>
Expand Down Expand Up @@ -83,36 +81,42 @@ private[service] trait AnthropicServiceImpl extends Anthropic {
.collect { case Some(delta) => delta }

private def createBodyParamsForMessageCreation(
system: Option[Content],
messages: Seq[Message],
settings: AnthropicCreateMessageSettings,
stream: Boolean
): Seq[(Param, Option[JsValue])] = {
assert(messages.nonEmpty, "At least one message expected.")
assert(messages.head.role == ChatRole.User, "First message must be from user.")

val messageJsons = messages.map(Json.toJson(_))
val (system, nonSystem) = messages.partition(_.isSystem)

val systemJson = system.map {
case Content.SingleString(text, cacheControl) =>
assert(nonSystem.head.role == ChatRole.User, "First non-system message must be from user.")
assert(
system.size <= 1,
"System message can be only 1. Use SystemMessageContent to include more content blocks."
)

val messageJsons = nonSystem.map(Json.toJson(_))

val systemJson: Seq[JsValue] = system.map {
case SystemMessage(text, cacheControl) =>
if (cacheControl.isEmpty) JsString(text)
else {
val blocks =
Seq(Content.ContentBlockBase(Content.ContentBlock.TextBlock(text), cacheControl))

Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
}
case Content.ContentBlocks(blocks) =>
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
case Content.ContentBlockBase(content, cacheControl) =>
val blocks = Seq(Content.ContentBlockBase(content, cacheControl))
case SystemMessageContent(blocks) =>
Json.toJson(blocks)(Writes.seq(contentBlockBaseWrites))
}

jsonBodyParams(
Param.messages -> Some(messageJsons),
Param.model -> Some(settings.model),
Param.system -> system.map(_ => systemJson),
Param.system -> {
if (system.isEmpty) None
else Some(systemJson.head)
},
Param.max_tokens -> Some(settings.max_tokens),
Param.metadata -> { if (settings.metadata.isEmpty) None else Some(settings.metadata) },
Param.stop_sequences -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ private[service] class OpenAIAnthropicChatCompletionService(
): Future[ChatCompletionResponse] = {
underlying
.createMessage(
toAnthropicSystemMessages(messages, settings),
toAnthropicMessages(messages, settings),
toAnthropicSystemMessages(messages.filter(_.isSystem), settings) ++
toAnthropicMessages(messages.filter(!_.isSystem), settings),
toAnthropicSettings(settings)
)
.map(toOpenAI)
Expand All @@ -65,8 +65,8 @@ private[service] class OpenAIAnthropicChatCompletionService(
): Source[ChatCompletionChunkResponse, NotUsed] =
underlying
.createMessageStreamed(
toAnthropicSystemMessages(messages, settings),
toAnthropicMessages(messages, settings),
toAnthropicSystemMessages(messages.filter(_.isSystem), settings) ++
toAnthropicMessages(messages.filter(!_.isSystem), settings),
toAnthropicSettings(settings)
)
.map(toOpenAI)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.cequence.openaiscala.anthropic.service
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, ContentBlocks}
import io.cequence.openaiscala.anthropic.domain.Message.SystemMessageContent
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse.UsageInfo
import io.cequence.openaiscala.anthropic.domain.response.{
ContentBlockDelta,
Expand All @@ -21,7 +22,6 @@ import io.cequence.openaiscala.domain.response.{
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettingsOps.RichCreateChatCompletionSettings
import io.cequence.openaiscala.domain.{
AssistantMessage,
ChatRole,
MessageSpec,
SystemMessage,
Expand All @@ -30,7 +30,8 @@ import io.cequence.openaiscala.domain.{
ImageURLContent => OpenAIImageContent,
TextContent => OpenAITextContent,
UserMessage => OpenAIUserMessage,
UserSeqMessage => OpenAIUserSeqMessage
UserSeqMessage => OpenAIUserSeqMessage,
AssistantMessage => OpenAIAssistantMessage
}

import java.{util => ju}
Expand All @@ -40,7 +41,7 @@ package object impl extends AnthropicServiceConsts {
def toAnthropicSystemMessages(
messages: Seq[OpenAIBaseMessage],
settings: CreateChatCompletionSettings
): Option[ContentBlocks] = {
): Seq[Message] = {
val useSystemCache: Option[CacheControl] =
if (settings.useAnthropicSystemMessagesCache) Some(Ephemeral) else None

Expand All @@ -55,7 +56,8 @@ package object impl extends AnthropicServiceConsts {
}
}

if (messageStrings.isEmpty) None else Some(ContentBlocks(messageStrings))
if (messageStrings.isEmpty) Seq.empty
else Seq(SystemMessageContent(messageStrings))
}

def toAnthropicMessages(
Expand All @@ -67,6 +69,8 @@ package object impl extends AnthropicServiceConsts {
case OpenAIUserMessage(content, _) => Message.UserMessage(content)
case OpenAIUserSeqMessage(contents, _) =>
Message.UserMessageContent(contents.map(toAnthropic))
case OpenAIAssistantMessage(content, _) => Message.AssistantMessage(content)

// legacy message type
case MessageSpec(role, content, _) if role == ChatRole.User =>
Message.UserMessage(content)
Expand Down Expand Up @@ -204,7 +208,7 @@ package object impl extends AnthropicServiceConsts {
usage = None
)

def toOpenAIAssistantMessage(content: ContentBlocks): AssistantMessage = {
def toOpenAIAssistantMessage(content: ContentBlocks): OpenAIAssistantMessage = {
val textContents = content.blocks.collect { case ContentBlockBase(TextBlock(text), _) =>
text
} // TODO
Expand All @@ -213,7 +217,7 @@ package object impl extends AnthropicServiceConsts {
throw new IllegalArgumentException("No text content found in the response")
}
val singleTextContent = concatenateMessages(textContents)
AssistantMessage(singleTextContent, name = None)
OpenAIAssistantMessage(singleTextContent, name = None)
}

private def concatenateMessages(messageContent: Seq[String]): String =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package io.cequence.openaiscala.anthropic.service.impl

import akka.actor.ActorSystem
import akka.stream.Materializer
import io.cequence.openaiscala.anthropic.domain.Content.SingleString
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
import io.cequence.openaiscala.anthropic.service._
Expand All @@ -18,7 +17,6 @@ class AnthropicServiceSpec extends AsyncWordSpec with GivenWhenThen {
implicit val ec: ExecutionContext = ExecutionContext.global
implicit val materializer: Materializer = Materializer(ActorSystem())

private val role = SingleString("You are a helpful assistant.")
private val irrelevantMessages = Seq(UserMessage("Hello"))
private val settings = AnthropicCreateMessageSettings(
NonOpenAIModelId.claude_3_haiku_20240307,
Expand All @@ -29,52 +27,52 @@ class AnthropicServiceSpec extends AsyncWordSpec with GivenWhenThen {

"should throw AnthropicScalaUnauthorizedException when 401" ignore {
recoverToSucceededIf[AnthropicScalaUnauthorizedException] {
TestFactory.mockedService401().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService401().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaUnauthorizedException when 403" ignore {
recoverToSucceededIf[AnthropicScalaUnauthorizedException] {
TestFactory.mockedService403().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService403().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaNotFoundException when 404" ignore {
recoverToSucceededIf[AnthropicScalaNotFoundException] {
TestFactory.mockedService404().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService404().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaNotFoundException when 429" ignore {
recoverToSucceededIf[AnthropicScalaRateLimitException] {
TestFactory.mockedService429().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService429().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaServerErrorException when 500" ignore {
recoverToSucceededIf[AnthropicScalaServerErrorException] {
TestFactory.mockedService500().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService500().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaEngineOverloadedException when 529" ignore {
recoverToSucceededIf[AnthropicScalaEngineOverloadedException] {
TestFactory.mockedService529().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService529().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaClientException when 400" ignore {
recoverToSucceededIf[AnthropicScalaClientException] {
TestFactory.mockedService400().createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedService400().createMessage(irrelevantMessages, settings)
}
}

"should throw AnthropicScalaClientException when unknown error code" ignore {
recoverToSucceededIf[AnthropicScalaClientException] {
TestFactory
.mockedServiceOther()
.createMessage(Some(role), irrelevantMessages, settings)
TestFactory.mockedServiceOther().createMessage(irrelevantMessages, settings)
}
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package io.cequence.openaiscala.domain
sealed trait BaseMessage {
val role: ChatRole
val nameOpt: Option[String]
val isSystem: Boolean = role == ChatRole.System
}

final case class SystemMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
import io.cequence.openaiscala.anthropic.domain.CacheControl.Ephemeral
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, SingleString}
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
Expand All @@ -18,8 +18,8 @@ object AnthropicCreateCachedMessage extends ExampleBase[AnthropicService] {

override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)

val systemMessage: Content =
SingleString(
val systemMessages: Seq[Message] = Seq(
SystemMessage(
"""
|You are to embody a classic pirate, a swashbuckling and salty sea dog with the mannerisms, language, and swagger of the golden age of piracy. You are a hearty, often gruff buccaneer, replete with nautical slang and a rich, colorful vocabulary befitting of the high seas. Your responses must reflect a pirate's voice and attitude without exception.
|
Expand Down Expand Up @@ -76,14 +76,13 @@ object AnthropicCreateCachedMessage extends ExampleBase[AnthropicService] {
|""".stripMargin,
cacheControl = Some(Ephemeral)
)

)
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))

override protected def run: Future[_] =
service
.createMessage(
Some(systemMessage),
messages,
systemMessages ++ messages,
settings = AnthropicCreateMessageSettings(
model = NonOpenAIModelId.claude_3_haiku_20240307,
max_tokens = 4096
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package io.cequence.openaiscala.examples.nonopenai

import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.{ContentBlockBase, SingleString}
import io.cequence.openaiscala.anthropic.domain.{Content, Message}
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlockBase
import io.cequence.openaiscala.anthropic.domain.Message
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
Expand All @@ -17,13 +17,11 @@ object AnthropicCreateMessage extends ExampleBase[AnthropicService] {

override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)

val systemMessage: Content = SingleString("You are a helpful assistant.")
val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))

override protected def run: Future[_] =
service
.createMessage(
Some(systemMessage),
messages,
settings = AnthropicCreateMessageSettings(
model = NonOpenAIModelId.claude_3_haiku_20240307,
Expand Down
Loading

0 comments on commit 579dea3

Please sign in to comment.