Skip to content

Commit

Permalink
New examples: CreateChatToolCompletion, CreateChatToolCompletionWithF…
Browse files Browse the repository at this point in the history
…eedback, ListFiles, and ListModels
  • Loading branch information
peterbanda committed Nov 27, 2023
1 parent 983b6be commit 6779db1
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.cequence.openaiscala.examples

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings

import scala.concurrent.Future

object CreateChatToolCompletion extends Example {

val messages = Seq(
SystemMessage("You are a helpful assistant."),
UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")
)

val tools = Seq(
FunctionSpec(
name = "get_current_weather",
description = Some("Get the current weather in a given location"),
parameters = Map(
"type" -> "object",
"properties" -> Map(
"location" -> Map(
"type" -> "string",
"description" -> "The city and state, e.g. San Francisco, CA",
),
"unit" -> Map(
"type" -> "string",
"enum" -> Seq("celsius", "fahrenheit")
)
),
"required" -> Seq("location"),
)
)
)

override protected def run: Future[_] =
service
.createChatToolCompletion(
messages = messages,
tools = tools,
responseToolChoice = None, // means "auto"
settings = CreateChatCompletionSettings(ModelId.gpt_3_5_turbo_1106)
)
.map { response =>
val chatFunCompletionMessage = response.choices.head.message
val toolCalls = chatFunCompletionMessage.tool_calls.collect { case (id, x: FunctionCallSpec) => (id, x) }

println(
"tool call ids : " + toolCalls.map(_._1).mkString(", ")
)
println(
"function/tool call names : " + toolCalls.map(_._2.name).mkString(", ")
)
println(
"function/tool call arguments : " + toolCalls.map(_._2.arguments).mkString(", ")
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package io.cequence.openaiscala.examples

import io.cequence.openaiscala.domain._
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
import play.api.libs.json.Json

import scala.concurrent.Future

// based on: https://platform.openai.com/docs/guides/function-calling
object CreateChatToolCompletionWithFeedback extends Example {

private val modelId = ModelId.gpt_4_turbo_preview

val introMessages = Seq(
SystemMessage("You are a helpful assistant."),
UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"),
)

// as a param type we can use "number", "string", "boolean", "object", "array", and "null"
val tools = Seq(
FunctionSpec(
name = "get_current_weather",
description = Some("Get the current weather in a given location"),
parameters = Map(
"type" -> "object",
"properties" -> Map(
"location" -> Map(
"type" -> "string",
"description" -> "The city and state, e.g. San Francisco, CA",
),
"unit" -> Map(
"type" -> "string",
"enum" -> Seq("celsius", "fahrenheit")
)
),
"required" -> Seq("location"),
)
)
)

override protected def run: Future[_] =
for {
assistantToolResponse <- service.createChatToolCompletion(
messages = introMessages,
tools = tools,
responseToolChoice = None, // means "auto"
settings = CreateChatCompletionSettings(modelId)
)

assistantToolMessage = assistantToolResponse.choices.head.message

toolCalls = assistantToolMessage.tool_calls

// we can handle only function calls (that will change in future)
functionCalls = toolCalls.collect { case (toolCallId, x: FunctionCallSpec) => (toolCallId, x) }

available_functions = Map("get_current_weather" -> getCurrentWeather _)

toolMessages = functionCalls.map { case (toolCallId, functionCallSpec) =>
val functionName = functionCallSpec.name
val functionArgsJson = Json.parse(functionCallSpec.arguments)

// this is not very generic, but it's ok for a demo
val functionResponse = available_functions.get(functionName) match {
case Some(functionToCall) => functionToCall(
(functionArgsJson \ "location").as[String],
(functionArgsJson \ "unit").asOpt[String]
)

case _ => throw new IllegalArgumentException(s"Unknown function: $functionName")
}

ToolMessage(
tool_call_id = toolCallId,
content = Some(functionResponse.toString),
name = functionName
)
}

messages = introMessages ++ Seq(assistantToolMessage) ++ toolMessages

finalAssistantResponse <- service.createChatCompletion(
messages = messages,
settings = CreateChatCompletionSettings(modelId)
)
} yield {
println(finalAssistantResponse.choices.head.message.content)
}

// unit is ignored here
private def getCurrentWeather(location: String, unit: Option[String]) =
location.toLowerCase() match {
case loc if loc.contains("tokyo") =>
Json.obj("location" -> "Tokyo", "temperature" -> "10", "unit" -> "celsius")

case loc if loc.contains("san francisco") =>
Json.obj("location" -> "San Francisco", "temperature" -> "72", "unit" -> "fahrenheit")

case loc if loc.contains("paris") =>
Json.obj("location" -> "Paris", "temperature" -> "22", "unit" -> "celsius")

case _ =>
Json.obj("location" -> location, "temperature" -> "unknown")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.cequence.openaiscala.examples

object ListFiles extends Example {

override protected def run =
service.listFiles
.map(_.foreach(println))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package io.cequence.openaiscala.examples

object ListModels extends Example {

override protected def run =
service.listModels.map(
_.sortBy(_.created).reverse.foreach(fileInfo => println(fileInfo.id))
)
}

0 comments on commit 6779db1

Please sign in to comment.