From 6f175e26c33e7e49b1d6fb23326f369f6fefa2f8 Mon Sep 17 00:00:00 2001 From: Ronald Mannak Date: Fri, 8 Mar 2024 11:09:11 -0800 Subject: [PATCH] Add support for new embedding models --- Sources/CleverBird/chat/ChatModel.swift | 19 +++++++++----- .../embeddings/EmbeddedDocumentStore.swift | 2 +- .../embeddings/EmbeddingModel.swift | 26 ++++++++++++++----- .../EmbeddingRequestParameters.swift | 3 +++ 4 files changed, 36 insertions(+), 14 deletions(-) diff --git a/Sources/CleverBird/chat/ChatModel.swift b/Sources/CleverBird/chat/ChatModel.swift index 8e77dae..7dc69df 100644 --- a/Sources/CleverBird/chat/ChatModel.swift +++ b/Sources/CleverBird/chat/ChatModel.swift @@ -19,17 +19,22 @@ public enum ChatModel: Codable { public func encode(to encoder: Encoder) throws { var container = encoder.singleValueContainer() - let modelString: String + let modelString = self.description + + try container.encode(modelString) + } +} +extension ChatModel: CustomStringConvertible { + public var description: String { switch self { case .gpt35Turbo: - modelString = "gpt-3.5-turbo" + return "gpt-3.5-turbo" case .gpt4: - modelString = "gpt-4" - case .specific(let specificString): - modelString = specificString + return "gpt-4" + case .specific(let string): + return string } - - try container.encode(modelString) + } } diff --git a/Sources/CleverBird/embeddings/EmbeddedDocumentStore.swift b/Sources/CleverBird/embeddings/EmbeddedDocumentStore.swift index fcc23e2..47b0f3b 100644 --- a/Sources/CleverBird/embeddings/EmbeddedDocumentStore.swift +++ b/Sources/CleverBird/embeddings/EmbeddedDocumentStore.swift @@ -24,7 +24,7 @@ public class EmbeddedDocumentStore { private var similarityMetric: SimilarityMetric public init(connection: OpenAIAPIConnection, - model: EmbeddingModel = .textEmbeddingAda002, + model: EmbeddingModel = .textEmbedding3Small, user: String? = nil, similarityMetric: SimilarityMetric = .cosine) { self.connection = connection diff --git a/Sources/CleverBird/embeddings/EmbeddingModel.swift b/Sources/CleverBird/embeddings/EmbeddingModel.swift index 0357408..fe2210e 100644 --- a/Sources/CleverBird/embeddings/EmbeddingModel.swift +++ b/Sources/CleverBird/embeddings/EmbeddingModel.swift @@ -1,6 +1,8 @@ // Created by B.T. Franklin on 7/25/23 public enum EmbeddingModel: Codable { + case textEmbedding3Small + case textEmbedding3Large case textEmbeddingAda002 case specific(String) @@ -9,6 +11,10 @@ public enum EmbeddingModel: Codable { let modelString = try container.decode(String.self) switch modelString { + case "text-embedding-3-small": + self = .textEmbedding3Small + case "text-embedding-3-large": + self = .textEmbedding3Large case "text-embedding-ada-002": self = .textEmbeddingAda002 default: @@ -18,15 +24,23 @@ public enum EmbeddingModel: Codable { public func encode(to encoder: Encoder) throws { var container = encoder.singleValueContainer() - let modelString: String + let modelString = self.description + try container.encode(modelString) + } +} + +extension EmbeddingModel: CustomStringConvertible { + public var description: String { switch self { + case .textEmbedding3Small: + return "text-embedding-3-small" + case .textEmbedding3Large: + return "text-embedding-3-large" case .textEmbeddingAda002: - modelString = "text-embedding-ada-002" - case .specific(let specificString): - modelString = specificString + return "text-embedding-ada-002" + case .specific(let string): + return string } - - try container.encode(modelString) } } diff --git a/Sources/CleverBird/embeddings/EmbeddingRequestParameters.swift b/Sources/CleverBird/embeddings/EmbeddingRequestParameters.swift index fc8b0df..981bdc3 100644 --- a/Sources/CleverBird/embeddings/EmbeddingRequestParameters.swift +++ b/Sources/CleverBird/embeddings/EmbeddingRequestParameters.swift @@ -3,13 +3,16 @@ public struct EmbeddingRequestParameters: Encodable { public let model: EmbeddingModel public let input: [String] + public let dimensions: Int? public let user: String? public init(model: EmbeddingModel, input: [String], + dimensions: Int? = nil, user: String? = nil) { self.model = model self.input = input + self.dimensions = dimensions self.user = user } }