Various changes, including refactoring that breaks out the Connection concept to not be part of Thread instantiation.
btfranklin committed Jan 16, 2024
1 parent 6ebaaad commit 41172ee
Showing 6 changed files with 33 additions and 64 deletions.
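
In practice, the refactoring moves the OpenAIAPIConnection out of ChatThread's initializer and into the completion call itself. A minimal sketch of the new call shape, pieced together from the diffs below (the API key and messages are placeholders, and the calls must run in an async throwing context):

    import CleverBird

    // The connection now lives independently of any thread.
    let connection = OpenAIAPIConnection(apiKey: "YOUR_API_KEY")

    // Threads are created without a connection...
    let thread = ChatThread()
        .addSystemMessage("You are a helpful assistant.")
        .addUserMessage("Who won the world series in 2020?")

    // ...and the connection is supplied only when a completion is requested.
    let reply = try await thread.complete(using: connection)
    print(reply)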
2 changes: 1 addition & 1 deletion Sources/CleverBird/chat/ChatMessage.swift
@@ -46,7 +46,7 @@ public struct ChatMessage: Codable, Identifiable {
self.content = content
self.name = functionCall?.name
if role == .function {
// Attention: if the role is function I need to set the functionCall to nil, otherwise this will
// If the role is "function" I need to set functionCall to nil, otherwise this will
// be encoded into the message which leads to an error.
self.functionCall = nil
} else {
23 changes: 12 additions & 11 deletions Sources/CleverBird/chat/ChatThread+complete.swift
@@ -1,8 +1,9 @@
// Created by B.T. Franklin on 5/5/23

extension ChatThread {
public func complete(model: ChatModel? = nil,
temperature: Percentage? = nil,
public func complete(using connection: OpenAIAPIConnection,
model: ChatModel = .gpt4,
temperature: Percentage = 0.7,
topP: Percentage? = nil,
stop: [String]? = nil,
maxTokens: Int? = nil,
@@ -12,13 +13,13 @@ extension ChatThread {
functionCallMode: FunctionCallMode? = nil) async throws -> ChatMessage {

let requestBody = ChatCompletionRequestParameters(
model: model ?? self.model,
temperature: temperature ?? self.temperature,
topP: topP ?? self.topP,
stop: stop ?? self.stop,
maxTokens: maxTokens ?? self.maxTokens,
presencePenalty: presencePenalty ?? self.presencePenalty,
frequencyPenalty: frequencyPenalty ?? self.frequencyPenalty,
model: model,
temperature: temperature,
topP: topP,
stop: stop,
maxTokens: maxTokens,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
user: self.user,
messages: self.messages,
functions: functions ?? self.functions,
@@ -36,8 +37,8 @@ extension ChatThread {
FunctionRegistry.shared.clearFunctions()
}

let request = try await self.connection.createChatCompletionRequest(for: requestBody)
let response = try await self.connection.client.send(request)
let request = try await connection.createChatCompletionRequest(for: requestBody)
let response = try await connection.client.send(request)
let completion = response.value
guard let firstChoiceMessage = completion.choices.first?.message else {
throw CleverBirdError.responseParsingFailed(message: "No message choice was available in completion response.")
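Because the thread no longer stores per-request settings, any overrides are now passed directly to complete alongside the connection. A short hedged example continuing the sketch above (parameter values are illustrative only):

    let tunedReply = try await thread.complete(
        using: connection,
        model: .gpt35Turbo,   // defaults to .gpt4 if omitted
        temperature: 0.2,     // defaults to 0.7 if omitted
        maxTokens: 256)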
4 changes: 2 additions & 2 deletions Sources/CleverBird/chat/ChatThread+tokenCount.swift
@@ -3,7 +3,7 @@
import Foundation

extension ChatThread {
public func tokenCount() throws -> Int {
public func tokenCount(using model: ChatModel = .gpt4) throws -> Int {

let tokenEncoder: TokenEncoder
do {
@@ -14,7 +14,7 @@ extension ChatThread {

var tokensPerMessage: Int

switch self.model {
switch model {
case .gpt35Turbo:
tokensPerMessage = 4
case .gpt4:
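Since the thread no longer knows which model it will be completed against, tokenCount also takes the model as a parameter. For example, continuing the same hypothetical thread:

    let gpt4Tokens = try thread.tokenCount()                    // defaults to .gpt4
    let turboTokens = try thread.tokenCount(using: .gpt35Turbo) // per-message overhead differs by model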
30 changes: 3 additions & 27 deletions Sources/CleverBird/chat/ChatThread.swift
@@ -1,38 +1,14 @@
// Created by B.T. Franklin on 5/5/23

public class ChatThread {
public class ChatThread: Codable {

let connection: OpenAIAPIConnection
let model: ChatModel
let temperature: Percentage
let topP: Percentage?
let stop: [String]?
let maxTokens: Int?
let presencePenalty: Penalty?
let frequencyPenalty: Penalty?
let user: String?

var messages: [ChatMessage] = []
var functions: [Function]?

public init(connection: OpenAIAPIConnection,
model: ChatModel = .gpt4,
temperature: Percentage = 0.7,
topP: Percentage? = nil,
stop: [String]? = nil,
maxTokens: Int? = nil,
presencePenalty: Penalty? = nil,
frequencyPenalty: Penalty? = nil,
user: String? = nil,
public init(user: String? = nil,
functions: [Function]? = nil) {
self.connection = connection
self.model = model
self.temperature = temperature
self.topP = topP
self.stop = stop
self.maxTokens = maxTokens
self.presencePenalty = presencePenalty
self.frequencyPenalty = frequencyPenalty
self.user = user
self.functions = functions
}
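Removing the stored connection (which wraps a live HTTP client) is also what makes the new Codable conformance on ChatThread possible, so a thread's state can be persisted and resumed later. A sketch of what that enables, assuming the remaining stored properties all encode cleanly and reusing the thread and connection from the earlier sketch:

    import Foundation

    // With no connection captured inside it, the thread is plain data.
    let archived = try JSONEncoder().encode(thread)
    let restored = try JSONDecoder().decode(ChatThread.self, from: archived)

    // The restored thread can be completed later with whatever connection is current.
    let continued = try await restored.complete(using: connection)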
@@ -4,23 +4,24 @@ import Foundation

extension StreamableChatThread {

public func complete(model: ChatModel? = nil,
temperature: Percentage? = nil,
public func complete(using connection: OpenAIAPIConnection,
model: ChatModel = .gpt4,
temperature: Percentage = 0.7,
topP: Percentage? = nil,
stop: [String]? = nil,
maxTokens: Int? = nil,
presencePenalty: Penalty? = nil,
frequencyPenalty: Penalty? = nil) async throws -> AsyncThrowingStream<String, Swift.Error> {

let requestBody = ChatCompletionRequestParameters(
model: model ?? self.chatThread.model,
temperature: temperature ?? self.chatThread.temperature,
topP: topP ?? self.chatThread.topP,
model: model,
temperature: temperature,
topP: topP,
stream: true,
stop: stop ?? self.chatThread.stop,
maxTokens: maxTokens ?? self.chatThread.maxTokens,
presencePenalty: presencePenalty ?? self.chatThread.presencePenalty,
frequencyPenalty: frequencyPenalty ?? self.chatThread.frequencyPenalty,
stop: stop,
maxTokens: maxTokens,
presencePenalty: presencePenalty,
frequencyPenalty: frequencyPenalty,
user: self.chatThread.user,
messages: self.chatThread.messages
)
@@ -30,7 +31,7 @@ extension StreamableChatThread {
self.addMessage(message)
}

let asyncByteStream = try await self.chatThread.connection.createChatCompletionAsyncByteStream(for: requestBody)
let asyncByteStream = try await connection.createChatCompletionAsyncByteStream(for: requestBody)

return AsyncThrowingStream { [weak self] continuation in
guard let strongSelf = self else {
@@ -64,6 +65,7 @@ extension StreamableChatThread {
do {
for try await line in asyncByteStream.lines {
guard let responseChunk = ChatStreamedResponseChunk.decode(from: line) else {
print(line)
break
}

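The streaming path mirrors the change: StreamableChatThread's complete also takes the connection as an argument and returns an AsyncThrowingStream of content fragments. A hedged sketch, assuming streamableThread is an already-constructed StreamableChatThread (its construction is outside this diff) and connection is as above:

    let stream = try await streamableThread.complete(using: connection)
    for try await fragment in stream {
        print(fragment, terminator: "")  // each element is a chunk of the assistant's reply
    }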
16 changes: 3 additions & 13 deletions Tests/CleverBirdTests/OpenAIChatThreadTests.swift
@@ -7,8 +7,7 @@ class OpenAIChatThreadTests: XCTestCase {

func testThreadLength() async {
let userMessageContent = "Who won the world series in 2020?"
let openAIAPIConnection = OpenAIAPIConnection(apiKey: "fake_api_key")
let chatThread = ChatThread(connection: openAIAPIConnection)
let chatThread = ChatThread()
.addSystemMessage("You are a helpful assistant.")
.addUserMessage(userMessageContent)

@@ -18,8 +17,7 @@ }
}

func testTokenCount() throws {
let openAIAPIConnection = OpenAIAPIConnection(apiKey: "fake_api_key")
let chatThread = ChatThread(connection: openAIAPIConnection)
let chatThread = ChatThread()
.addSystemMessage("You are a helpful assistant.")
.addUserMessage("Who won the world series in 2020?")
let tokenCount = try chatThread.tokenCount()
@@ -59,24 +57,16 @@ class OpenAIChatThreadTests: XCTestCase {
description: "Get an N-day weather forecast",
parameters: getNDayWeatherForecastParameters)

let openAIAPIConnection = OpenAIAPIConnection(apiKey: "fake_api_key")
let functionCall = FunctionCall(name: "testFunc", arguments: ["arg1": .string("value1")])
let chatThread = ChatThread(connection: openAIAPIConnection)
_ = ChatThread()
.addSystemMessage("You are a helpful assistant.")
.setFunctions([getCurrentWeather, getNDayWeatherForecast])
.addMessage(try! ChatMessage(role: .assistant, functionCall: functionCall))
let tokenCount = try chatThread.tokenCount()

XCTAssertEqual(tokenCount, 241, "Unexpected token count")
}

func testInvalidMessageCreation() {
let functionCall = FunctionCall(name: "testFunc", arguments: ["arg1": .string("value1")])
XCTAssertThrowsError(try ChatMessage(role: .assistant)) { error in
XCTAssertEqual(error as? CleverBirdError, CleverBirdError.invalidMessageContent)
}
XCTAssertThrowsError(try ChatMessage(role: .function, functionCall: functionCall)) { error in
XCTAssertEqual(error as? CleverBirdError, CleverBirdError.invalidFunctionMessage)
}
}
}
