Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/finish run api #76

Merged
merged 26 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
91cf70e
Add cancel run endpoint
branislav-burdiliak Jul 15, 2024
ba9e33d
Add modify run endpoint
branislav-burdiliak Jul 15, 2024
9a986bd
Add list runs endpoint
branislav-burdiliak Jul 15, 2024
771a8d2
Add newly introduced methods to service wrapper
branislav-burdiliak Jul 15, 2024
584d59e
Finish createRun
branislav-burdiliak Aug 4, 2024
626850a
Separate assistant tool hierarchy to be used by create assistant and …
branislav-burdiliak Aug 4, 2024
c23d710
wip
branislav-burdiliak Aug 6, 2024
20b953c
Add order param to list runs endpoint
branislav-burdiliak Aug 9, 2024
1ec0e1d
Fix submit tool outputs endpoint
branislav-burdiliak Aug 9, 2024
eae9d68
Add retrieve run step endpoint
branislav-burdiliak Aug 9, 2024
5387e5b
Fix assistant tool codec
branislav-burdiliak Aug 9, 2024
c2a80dd
Reformat code
branislav-burdiliak Aug 9, 2024
810373b
Fix serialization of assistant's code interpreter tool's file IDs
branislav-burdiliak Aug 9, 2024
c55fc19
WIP - create thread and run
branislav-burdiliak Sep 5, 2024
a196332
added JsonFormats, compiling
bburdiliak Sep 9, 2024
8f9025a
fix formats, add Assistants example
bburdiliak Sep 10, 2024
b456777
scalafmt
bburdiliak Sep 10, 2024
95c1986
Merge remote-tracking branch 'origin/master' into feature/finish_run_API
bburdiliak Sep 11, 2024
84f94de
compilable after merging master
bburdiliak Sep 11, 2024
eecf951
scalafmt
bburdiliak Sep 11, 2024
bd4c94e
deleteThreadMessage + reorder methods to match the API reference
bburdiliak Sep 12, 2024
c319641
get rid of UploadFileSettings
bburdiliak Sep 12, 2024
a38d2ab
polling example for createThreadAndRun
bburdiliak Sep 12, 2024
6f44021
added methods: modifyVectorStore, retrieveVectorStore, retrieveVector…
bburdiliak Sep 13, 2024
b4b6bf9
scalafmt
bburdiliak Sep 13, 2024
e6ff47d
Merge branch 'master' into feature/finish_run_API
peterbanda Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ object EndPoint {
case object fine_tunes extends EndPoint("fine_tuning/jobs")
case object moderations extends EndPoint
case object threads extends EndPoint
case object threads_and_runs extends EndPoint("threads/runs")
case object batches extends EndPoint
case object assistants extends EndPoint
case object vector_stores extends EndPoint
Expand Down Expand Up @@ -105,11 +106,15 @@ object Param {
case object chunking_strategy extends Param
case object filter extends Param
case object max_prompt_tokens extends Param
case object max_completion_tokens extends Param
case object `object` extends Param
case object assistant_id extends Param
case object thread_id extends Param
case object additional_instructions extends Param
case object additional_messages extends Param
case object truncation_strategy extends Param
case object parallel_tool_calls extends Param
case object thread extends Param
// empty string param to sneak in extra parameters
case object extra_params extends Param(" ")
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ private[service] trait OpenAICoreServiceImpl
with OpenAIChatCompletionServiceImpl
with HandleOpenAIErrorCodes
with CompletionBodyMaker
with RunBodyMaker {
with RunBodyMaker
with ThreadAndRunBodyMaker {

override def listModels: Future[Seq[ModelInfo]] =
execGET(EndPoint.models).map { response =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import io.cequence.openaiscala.service.{HandleOpenAIErrorCodes, OpenAIService}
import io.cequence.wsclient.JsonUtil.JsonOps
import io.cequence.wsclient.ResponseImplicits._
import io.cequence.wsclient.domain.RichResponse
import play.api.libs.json.{JsArray, JsObject, JsValue, Json, Reads}
import play.api.libs.json.{JsObject, JsValue, Json, Reads}

import java.io.File
import java.nio.charset.StandardCharsets
Expand Down Expand Up @@ -44,15 +44,15 @@ private[service] trait OpenAIServiceImpl

override def createChatFunCompletion(
messages: Seq[BaseMessage],
functions: Seq[FunctionSpec],
functions: Seq[ChatCompletionTool],
responseFunctionName: Option[String],
settings: CreateChatCompletionSettings
): Future[ChatFunCompletionResponse] = {
val coreParams =
createBodyParamsForChatCompletion(messages, settings, stream = false)

val extraParams = jsonBodyParams(
Param.functions -> Some(functions.map(Json.toJson(_)(functionSpecFormat))),
Param.functions -> Some(functions.map(Json.toJson(_)(chatCompletionToolWrites))),
Param.function_call -> responseFunctionName.map(name =>
Map("name" -> name)
) // otherwise "auto" is used by default (if functions are present)
Expand All @@ -68,11 +68,11 @@ private[service] trait OpenAIServiceImpl

override def createRun(
threadId: String,
assistantId: String,
assistantId: AssistantId,
instructions: Option[String],
additionalInstructions: Option[String],
additionalMessages: Seq[BaseMessage],
tools: Seq[ForcableTool],
tools: Seq[AssistantTool],
responseToolChoice: Option[ToolChoice] = None,
settings: CreateRunSettings = DefaultSettings.CreateRun,
stream: Boolean
Expand All @@ -84,7 +84,7 @@ private[service] trait OpenAIServiceImpl
val messageJsons = additionalMessages.map(Json.toJson(_)(messageWrites))

val runParams = jsonBodyParams(
Param.assistant_id -> Some(assistantId),
Param.assistant_id -> Some(assistantId.id),
Param.additional_instructions -> instructions,
Param.additional_messages ->
(if (messageJsons.nonEmpty) Some(messageJsons) else None)
Expand All @@ -99,27 +99,82 @@ private[service] trait OpenAIServiceImpl
)
}

def submitToolOutputs(
override def createThreadAndRun(
assistantId: AssistantId,
thread: Option[ThreadAndRun],
instructions: Option[String],
tools: Seq[AssistantTool],
toolResources: Option[ThreadAndRunToolResource],
toolChoice: Option[ToolChoice],
settings: CreateThreadAndRunSettings,
stream: Boolean
): Future[Run] = {
val coreParams = createBodyParamsForThreadAndRun(settings, stream)
val runParams = jsonBodyParams(
Param.assistant_id -> Some(assistantId.id),
Param.thread -> thread.map(Json.toJson(_)),
Param.instructions -> Some(instructions),
// Param.tools -> Some(Json.toJson(tools)),
Param.tool_resources -> toolResources.map(Json.toJson(_)),
Param.tool_choice -> toolChoice.map(Json.toJson(_))
)
execPOST(
EndPoint.threads_and_runs,
bodyParams = coreParams ++ runParams
).map(
_.asSafeJson[Run]
)
}

override def modifyRun(
threadId: String,
runId: String,
toolOutputs: Seq[AssistantToolOutput]
metadata: Map[String, String]
): Future[Run] =
execPOST(
EndPoint.threads,
Some(s"$threadId/runs/$runId/submit_tool_outputs"),
bodyParams = Seq(
Param.tool_outputs -> (
if (toolOutputs.nonEmpty)
Some(
JsArray(toolOutputs.map(Json.toJson(_)(assistantToolOutputFormat)))
)
Some(s"$threadId/runs/$runId"),
bodyParams = jsonBodyParams(
Param.metadata -> (
if (metadata.nonEmpty)
Some(metadata)
else None
)
)
).map(
_.asSafeJson[Run]
)

def submitToolOutputs(
threadId: String,
runId: String,
toolOutputs: Seq[AssistantToolOutput],
stream: Boolean
): Future[Run] =
execPOST(
EndPoint.threads,
Some(s"$threadId/runs/$runId/submit_tool_outputs"),
bodyParams = jsonBodyParams(
Param.tool_outputs -> Some(toolOutputs.map(Json.toJson(_)(assistantToolOutputFormat))),
Param.stream -> Some(stream)
)
).map(
_.asSafeJson[Run]
)

override def cancelRun(
threadId: String,
runId: String
): Future[Run] = {
execPOST(
EndPoint.threads,
Some(s"$threadId/runs/$runId/cancel")
).map(
_.asSafeJson[Run]
)

}

override def retrieveRun(
threadId: String,
runId: String
Expand All @@ -131,6 +186,31 @@ private[service] trait OpenAIServiceImpl
handleNotFoundAndError(response).map(_.asSafeJson[Run])
}

override def listRuns(
threadId: String,
pagination: Pagination,
order: Option[SortOrder] = None
): Future[Seq[Run]] =
execGET(
EndPoint.threads,
Some(s"$threadId/runs"),
params = paginationParams(pagination) :+ Param.order -> order
).map { response =>
readAttribute(response.json, "data").asSafeArray[Run]
}

override def retrieveRunStep(
threadID: String,
runId: String,
stepId: String
): Future[Option[RunStep]] =
execGETRich(
EndPoint.threads,
Some(s"$threadID/runs/$runId/steps/$stepId")
).map { response =>
handleNotFoundAndError(response).map(_.asSafeJson[RunStep])
}

override def listRunSteps(
threadId: String,
runId: String,
Expand All @@ -146,48 +226,29 @@ private[service] trait OpenAIServiceImpl
}

private def toolParams(
tools: Seq[ForcableTool],
tools: Seq[AssistantTool],
maybeResponseToolChoice: Option[ToolChoice]
): Seq[(Param, Option[JsValue])] = {
val toolJsons = tools.map {
case CodeInterpreterSpec => Map("type" -> "code_interpreter")
case FileSearchSpec => Map("type" -> "file_search")
case tool: FunctionSpec => Map("type" -> "function", "function" -> Json.toJson(tool))
}

val maybeToolChoiceParam = maybeResponseToolChoice.map {
case ToolChoice.None => "none"
case ToolChoice.Auto => "auto"
case ToolChoice.Required => "required"
case ToolChoice.EnforcedTool(FileSearchSpec) => Map("type" -> "file_search")
case ToolChoice.EnforcedTool(CodeInterpreterSpec) =>
Map("type" -> "code_interpreter")
case ToolChoice.EnforcedTool(FunctionSpec(name, _, _, _)) =>
Map("type" -> "function", "function" -> Map("name" -> name))
}

val extraParams = jsonBodyParams(
Param.tools -> Some(toolJsons),
Param.tool_choice -> maybeToolChoiceParam
Param.tools -> Some(tools.map(Json.toJson(_))),
Param.tool_choice -> maybeResponseToolChoice.map(Json.toJson(_))
)

extraParams
}

override def createChatToolCompletion(
messages: Seq[BaseMessage],
tools: Seq[ToolSpec],
tools: Seq[ChatCompletionTool],
responseToolChoice: Option[String] = None,
settings: CreateChatCompletionSettings = DefaultSettings.CreateChatFunCompletion
): Future[ChatToolCompletionResponse] = {
val coreParams =
createBodyParamsForChatCompletion(messages, settings, stream = false)

val toolJsons: Seq[Map[String, Object]] = tools.map { case tool: FunctionSpec =>
Map(
"type" -> "function",
"function" -> Json.toJson(tool)
)
val toolJsons: Seq[Map[String, Object]] = tools.map {
case tool: AssistantTool.FunctionTool =>
Map("type" -> "function", "function" -> Json.toJson(tool))
}

val extraParams = jsonBodyParams(
Expand Down Expand Up @@ -372,14 +433,12 @@ private[service] trait OpenAIServiceImpl
override def uploadFile(
file: File,
displayFileName: Option[String],
settings: UploadFileSettings
purpose: FileUploadPurpose
): Future[FileInfo] =
execPOSTMultipart(
EndPoint.files,
fileParams = Seq((Param.file, file, displayFileName)),
bodyParams = Seq(
Param.purpose -> Some(settings.purpose)
)
bodyParams = Seq(Param.purpose -> Some(purpose))
).map(
_.asSafeJson[FileInfo]
)
Expand All @@ -404,13 +463,14 @@ private[service] trait OpenAIServiceImpl
displayFileName: Option[String]
): Future[FileInfo] = {
readFile(file)
// TODO
// parse the fileContent as Seq[BatchRow] solely for the purpose of validating its structure, OpenAIScalaClientException is thrown if the parsing fails

// fileRows.map { row =>
// Json.parse(row).asSafeArray[BatchRow]
// }

uploadFile(file, displayFileName, DefaultSettings.UploadBatchFile)
uploadFile(file, displayFileName, FileUploadPurpose.batch)
}

override def buildAndUploadBatchFile(
Expand Down Expand Up @@ -484,6 +544,22 @@ private[service] trait OpenAIServiceImpl
_.asSafeJson[VectorStore]
)

override def modifyVectorStore(
vectorStoreId: String,
name: Option[String] = None,
metadata: Map[String, Any]
): Future[VectorStore] =
execPOST(
EndPoint.vector_stores,
endPointParam = Some(vectorStoreId),
bodyParams = jsonBodyParams(
Param.name -> name,
Param.metadata -> (if (metadata.nonEmpty) Some(metadata) else None)
)
).map(
_.asSafeJson[VectorStore]
)

override def listVectorStores(
pagination: Pagination,
order: Option[SortOrder]
Expand All @@ -495,6 +571,16 @@ private[service] trait OpenAIServiceImpl
readAttribute(response.json, "data").asSafeArray[VectorStore]
}

override def retrieveVectorStore(
vectorStoreId: String
): Future[Option[VectorStore]] =
execGETRich(
EndPoint.vector_stores,
endPointParam = Some(vectorStoreId)
).map { response =>
handleNotFoundAndError(response).map(_.asSafeJson[VectorStore])
}

override def deleteVectorStore(
vectorStoreId: String
): Future[DeleteResponse] =
Expand Down Expand Up @@ -535,6 +621,18 @@ private[service] trait OpenAIServiceImpl
readAttribute(response.json, "data").asSafeArray[VectorStoreFile]
}

def retrieveVectorStoreFile(
vectorStoreId: String,
fileId: FileId
): Future[VectorStoreFile] = {
execGET(
EndPoint.vector_stores,
endPointParam = Some(s"$vectorStoreId/files/${fileId.file_id}")
).map(
_.asSafeJson[VectorStoreFile]
)
}

override def deleteVectorStoreFile(
vectorStoreId: String,
fileId: String
Expand Down Expand Up @@ -798,6 +896,15 @@ private[service] trait OpenAIServiceImpl
readAttribute(response.json, "data").asSafeArray[ThreadFullMessage]
}

override def deleteThreadMessage(
threadId: String,
messageId: String
): Future[DeleteResponse] =
execDELETERich(
EndPoint.threads,
endPointParam = Some(s"$threadId/messages/$messageId")
).map(handleDeleteEndpointResponse)

override def retrieveThreadMessageFile(
threadId: String,
messageId: String,
Expand Down Expand Up @@ -830,13 +937,12 @@ private[service] trait OpenAIServiceImpl
description: Option[String],
instructions: Option[String],
tools: Seq[AssistantTool],
toolResources: Seq[AssistantToolResource] = Seq.empty[AssistantToolResource],
toolResources: Option[AssistantToolResource] = None,
metadata: Map[String, String]
): Future[Assistant] = {
val toolResourcesJson =
toolResources.map(Json.toJson(_).as[JsObject]).foldLeft(Json.obj()) { case (acc, json) =>
acc.deepMerge(json)
}
toolResources.map(Json.toJson(_).as[JsObject]).foldLeft(Json.obj()) { case (acc, json) =>
acc.deepMerge(json)
}

execPOST(
EndPoint.assistants,
Expand All @@ -846,8 +952,7 @@ private[service] trait OpenAIServiceImpl
Param.description -> Some(description),
Param.instructions -> Some(instructions),
Param.tools -> Some(Json.toJson(tools)),
Param.tool_resources -> (if (toolResources.nonEmpty) Some(toolResourcesJson)
else None),
Param.tool_resources -> toolResources.map(Json.toJson(_)),
Param.metadata -> (if (metadata.nonEmpty) Some(metadata) else None)
)
).map(
Expand Down
Loading
Loading