Skip to content

Commit

Permalink
Anthropic - system messages fix + anthropic examples fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbanda committed Nov 26, 2024
1 parent 980cc5c commit a8dfef8
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ package object impl extends AnthropicServiceConsts {
messages: Seq[OpenAIBaseMessage],
settings: CreateChatCompletionSettings
): Seq[Message] = {
assert(
messages.forall(_.isSystem),
"All messages must be system messages"
)

val useSystemCache: Option[CacheControl] =
if (settings.useAnthropicSystemMessagesCache) Some(Ephemeral) else None

Expand All @@ -52,12 +57,15 @@ package object impl extends AnthropicServiceConsts {
if (index == messages.size - 1)
ContentBlockBase(TextBlock(content), Some(cacheControl))
else ContentBlockBase(TextBlock(content), None)

case None => ContentBlockBase(TextBlock(content))
}
}

if (messageStrings.isEmpty) Seq.empty
else Seq(SystemMessageContent(messageStrings))
if (messageStrings.isEmpty)
Seq.empty
else
Seq(SystemMessageContent(messageStrings))
}

def toAnthropicMessages(
Expand All @@ -67,8 +75,10 @@ package object impl extends AnthropicServiceConsts {

val anthropicMessages: Seq[Message] = messages.collect {
case OpenAIUserMessage(content, _) => Message.UserMessage(content)

case OpenAIUserSeqMessage(contents, _) =>
Message.UserMessageContent(contents.map(toAnthropic))

case OpenAIAssistantMessage(content, _) => Message.AssistantMessage(content)

// legacy message type
Expand All @@ -82,27 +92,30 @@ package object impl extends AnthropicServiceConsts {

val anthropicMessagesWithCache: Seq[Message] = anthropicMessages
.foldLeft((List.empty[Message], countUserMessagesToCache)) {
case ((acc, userMessagesToCache), message) =>
case ((acc, userMessagesToCacheCount), message) =>
message match {
case Message.UserMessage(contentString, _) =>
val newCacheControl = if (userMessagesToCache > 0) Some(Ephemeral) else None
val newCacheControl = if (userMessagesToCacheCount > 0) Some(Ephemeral) else None
(
acc :+ Message.UserMessage(contentString, newCacheControl),
userMessagesToCache - newCacheControl.map(_ => 1).getOrElse(0)
userMessagesToCacheCount - newCacheControl.map(_ => 1).getOrElse(0)
)

case Message.UserMessageContent(contentBlocks) =>
val (newContentBlocks, remainingCache) =
contentBlocks.foldLeft((Seq.empty[ContentBlockBase], userMessagesToCache)) {
contentBlocks.foldLeft((Seq.empty[ContentBlockBase], userMessagesToCacheCount)) {
case ((acc, cacheLeft), content) =>
val (block, newCacheLeft) =
toAnthropic(cacheLeft)(content.asInstanceOf[OpenAIContent])
(acc :+ block, newCacheLeft)
}
(acc :+ Message.UserMessageContent(newContentBlocks), remainingCache)

case assistant: Message.AssistantMessage =>
(acc :+ assistant, userMessagesToCache)
(acc :+ assistant, userMessagesToCacheCount)

case assistants: Message.AssistantMessageContent =>
(acc :+ assistants, userMessagesToCache)
(acc :+ assistants, userMessagesToCacheCount)
}
}
._1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.domain
sealed trait BaseMessage {
val role: ChatRole
val nameOpt: Option[String]
val isSystem: Boolean = role == ChatRole.System
def isSystem: Boolean = role == ChatRole.System
}

final case class SystemMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import io.cequence.openaiscala.domain.{NonOpenAIModelId, SystemMessage, UserMessage}
import io.cequence.openaiscala.examples.ExampleBase
import io.cequence.openaiscala.service.OpenAIChatCompletionService
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettingsOps._

import scala.concurrent.Future

Expand All @@ -15,15 +16,17 @@ object AnthropicCreateChatCompletionCachedWithOpenAIAdapter
ChatCompletionProvider.anthropic(withCache = true)

private val messages = Seq(
SystemMessage("You are a helpful assistant."),
SystemMessage("You are a helpful assistant who knows elfs personally."),
UserMessage("What is the weather like in Norway?")
)

override protected def run: Future[_] =
service
.createChatCompletion(
messages = messages,
settings = CreateChatCompletionSettings(NonOpenAIModelId.claude_3_5_sonnet_20241022)
settings = CreateChatCompletionSettings(
NonOpenAIModelId.claude_3_5_sonnet_20241022
).setUseAnthropicSystemMessagesCache(true), // this is how we pass it through the adapter
)
.map { content =>
println(content.choices.headOption.map(_.message.content).getOrElse("N/A"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.cequence.openaiscala.examples.nonopenai
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlock.TextBlock
import io.cequence.openaiscala.anthropic.domain.Content.ContentBlockBase
import io.cequence.openaiscala.anthropic.domain.Message
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
import io.cequence.openaiscala.anthropic.domain.response.CreateMessageResponse
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
import io.cequence.openaiscala.anthropic.service.{AnthropicService, AnthropicServiceFactory}
Expand All @@ -15,16 +15,19 @@ import scala.concurrent.Future
// requires `openai-scala-anthropic-client` as a dependency and `ANTHROPIC_API_KEY` environment variable to be set
object AnthropicCreateMessage extends ExampleBase[AnthropicService] {

override protected val service: AnthropicService = AnthropicServiceFactory(withCache = true)
override protected val service: AnthropicService = AnthropicServiceFactory()

val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))
val messages: Seq[Message] = Seq(
SystemMessage("You are a helpful assistant who knows elfs personally."),
UserMessage("What is the weather like in Norway?")
)

override protected def run: Future[_] =
service
.createMessage(
messages,
settings = AnthropicCreateMessageSettings(
model = NonOpenAIModelId.claude_3_haiku_20240307,
model = NonOpenAIModelId.claude_3_5_haiku_20241022,
max_tokens = 4096
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package io.cequence.openaiscala.examples.nonopenai

import akka.stream.scaladsl.Sink
import io.cequence.openaiscala.anthropic.domain.Message
import io.cequence.openaiscala.anthropic.domain.Message.UserMessage
import io.cequence.openaiscala.anthropic.domain.Message.{SystemMessage, UserMessage}
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
import io.cequence.openaiscala.anthropic.service.{AnthropicService, AnthropicServiceFactory}
import io.cequence.openaiscala.domain.NonOpenAIModelId
Expand All @@ -15,14 +15,19 @@ object AnthropicCreateMessageStreamed extends ExampleBase[AnthropicService] {

override protected val service: AnthropicService = AnthropicServiceFactory()

val messages: Seq[Message] = Seq(UserMessage("What is the weather like in Norway?"))
val messages: Seq[Message] = Seq(
SystemMessage("You are a helpful assistant who knows elfs personally."),
UserMessage("What is the weather like in Norway?")
)

private val modelId = NonOpenAIModelId.claude_3_5_haiku_20241022

override protected def run: Future[_] =
service
.createMessageStreamed(
messages,
settings = AnthropicCreateMessageSettings(
model = NonOpenAIModelId.claude_3_haiku_20240307,
model = modelId,
max_tokens = 4096
)
)
Expand Down

0 comments on commit a8dfef8

Please sign in to comment.