diff --git a/Documentation~/README.md b/Documentation~/README.md index 865b58a7..3ac7f923 100644 --- a/Documentation~/README.md +++ b/Documentation~/README.md @@ -257,7 +257,39 @@ Debug.Log(result.FirstChoice); ##### [Chat Streaming](https://platform.openai.com/docs/api-reference/chat/create#chat/create-stream) ```csharp -TODO +var api = new OpenAIClient(); +var chatPrompts = new List +{ + new ChatPrompt("system", "You are a helpful assistant."), + new ChatPrompt("user", "Who won the world series in 2020?"), + new ChatPrompt("assistant", "The Los Angeles Dodgers won the World Series in 2020."), + new ChatPrompt("user", "Where was it played?"), +}; +var chatRequest = new ChatRequest(chatPrompts, Model.GPT3_5_Turbo); + +await api.ChatEndpoint.StreamCompletionAsync(chatRequest, result => +{ + Debug.Log(result.FirstChoice); +}); +``` + +Or if using [`IAsyncEnumerable{T}`](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.iasyncenumerable-1?view=net-5.0) ([C# 8.0+](https://docs.microsoft.com/en-us/archive/msdn-magazine/2019/november/csharp-iterating-with-async-enumerables-in-csharp-8)) + +```csharp +var api = new OpenAIClient(); +var chatPrompts = new List +{ + new ChatPrompt("system", "You are a helpful assistant."), + new ChatPrompt("user", "Who won the world series in 2020?"), + new ChatPrompt("assistant", "The Los Angeles Dodgers won the World Series in 2020."), + new ChatPrompt("user", "Where was it played?"), +}; +var chatRequest = new ChatRequest(chatPrompts, Model.GPT3_5_Turbo); + +await foreach (var result in api.ChatEndpoint.StreamCompletionEnumerableAsync(chatRequest)) +{ + Debug.Log(result.FirstChoice); +} ``` ### [Edits](https://beta.openai.com/docs/api-reference/edits) diff --git a/Runtime/Audio/AudioTranscriptionRequest.cs b/Runtime/Audio/AudioTranscriptionRequest.cs index fb3d2076..3a5d3e74 100644 --- a/Runtime/Audio/AudioTranscriptionRequest.cs +++ b/Runtime/Audio/AudioTranscriptionRequest.cs @@ -149,11 +149,11 @@ public AudioTranscriptionRequest( AudioName = audioName; - Model = model ?? new Model("whisper-1"); + Model = model ?? Models.Model.Whisper1; if (!Model.Contains("whisper")) { - throw new ArgumentException(nameof(model), $"{Model} is not supported."); + throw new ArgumentException($"{Model} is not supported", nameof(model)); } Prompt = prompt; diff --git a/Runtime/Audio/AudioTranslationRequest.cs b/Runtime/Audio/AudioTranslationRequest.cs index e68168b0..41b2b55e 100644 --- a/Runtime/Audio/AudioTranslationRequest.cs +++ b/Runtime/Audio/AudioTranslationRequest.cs @@ -119,11 +119,11 @@ public AudioTranslationRequest( AudioName = audioName; - Model = model ?? new Model("whisper-1"); + Model = model ?? Models.Model.Whisper1; if (!Model.Contains("whisper")) { - throw new ArgumentException(nameof(model), $"{Model} is not supported."); + throw new ArgumentException($"{Model} is not supported", nameof(model)); } Prompt = prompt; diff --git a/Runtime/Chat/ChatEndpoint.cs b/Runtime/Chat/ChatEndpoint.cs index d5bd1d5e..aca74fdd 100644 --- a/Runtime/Chat/ChatEndpoint.cs +++ b/Runtime/Chat/ChatEndpoint.cs @@ -1,6 +1,11 @@ // Licensed under the MIT License. See LICENSE in the project root for license information. using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.IO; +using System.Net.Http; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -16,19 +21,100 @@ protected override string GetEndpoint() => $"{Api.BaseUrl}chat"; /// - /// Creates a completion for the chat message + /// Creates a completion for the chat message. /// /// The chat request which contains the message content. /// Optional, . /// . + /// Raised when the HTTP request fails public async Task GetCompletionAsync(ChatRequest chatRequest, CancellationToken cancellationToken = default) { var payload = JsonConvert.SerializeObject(chatRequest, Api.JsonSerializationOptions).ToJsonStringContent(); - var result = await Api.Client.PostAsync($"{GetEndpoint()}/completions", payload, cancellationToken); - var resultAsString = await result.ReadAsStringAsync(); - return JsonConvert.DeserializeObject(resultAsString, Api.JsonSerializationOptions); + var response = await Api.Client.PostAsync($"{GetEndpoint()}/completions", payload, cancellationToken); + var responseAsString = await response.ReadAsStringAsync(); + return response.DeserializeResponse(responseAsString, Api.JsonSerializationOptions); } - // TODO Streaming endpoints + /// + /// Created a completion for the chat message and stream the results to the as they come in. + /// + /// The chat request which contains the message content. + /// An action to be called as each new result arrives. + /// Optional, . + /// . + /// Raised when the HTTP request fails + public async Task StreamCompletionAsync(ChatRequest chatRequest, Action resultHandler, CancellationToken cancellationToken = default) + { + chatRequest.Stream = true; + var jsonContent = JsonConvert.SerializeObject(chatRequest, Api.JsonSerializationOptions); + using var request = new HttpRequestMessage(HttpMethod.Post, $"{GetEndpoint()}/completions") + { + Content = jsonContent.ToJsonStringContent() + }; + var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + await response.CheckResponseAsync(cancellationToken); + await using var stream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(stream); + + while (await reader.ReadLineAsync() is { } line) + { + if (line.StartsWith("data: ")) + { + line = line["data: ".Length..]; + } + + if (line == "[DONE]") + { + return; + } + + if (!string.IsNullOrWhiteSpace(line)) + { + resultHandler(response.DeserializeResponse(line.Trim(), Api.JsonSerializationOptions)); + } + } + } + + /// + /// Created a completion for the chat message and stream the results as they come in.
+ /// If you are not using C# 8 supporting IAsyncEnumerable{T} or if you are using the .NET Framework, + /// you may need to use instead. + ///
+ /// The chat request which contains the message content. + /// Optional, . + /// . + /// Raised when the HTTP request fails + public async IAsyncEnumerable StreamCompletionEnumerableAsync(ChatRequest chatRequest, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + chatRequest.Stream = true; + var jsonContent = JsonConvert.SerializeObject(chatRequest, Api.JsonSerializationOptions); + using var request = new HttpRequestMessage(HttpMethod.Post, $"{GetEndpoint()}/completions") + { + Content = jsonContent.ToJsonStringContent() + }; + var response = await Api.Client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + await response.CheckResponseAsync(cancellationToken); + await using var stream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(stream); + + while (await reader.ReadLineAsync() is { } line && + !cancellationToken.IsCancellationRequested) + { + if (line.StartsWith("data: ")) + { + line = line["data: ".Length..]; + } + + if (line == "[DONE]") + { + yield break; + } + + if (!string.IsNullOrWhiteSpace(line)) + { + yield return response.DeserializeResponse(line.Trim(), Api.JsonSerializationOptions); + } + } + } } } diff --git a/Runtime/Chat/ChatRequest.cs b/Runtime/Chat/ChatRequest.cs index 8fd92c79..a680d6dd 100644 --- a/Runtime/Chat/ChatRequest.cs +++ b/Runtime/Chat/ChatRequest.cs @@ -10,6 +10,62 @@ namespace OpenAI.Chat { public sealed class ChatRequest { + /// + /// Constructor. + /// + /// + /// + /// ID of the model to use. Currently, only gpt-3.5-turbo and gpt-3.5-turbo-0301 are supported. + /// + /// + /// What sampling temperature to use, between 0 and 2. + /// Higher values like 0.8 will make the output more random, while lower values like 0.2 will + /// make it more focused and deterministic. + /// We generally recommend altering this or top_p but not both.
+ /// Defaults to 1 + /// + /// + /// An alternative to sampling with temperature, called nucleus sampling, + /// where the model considers the results of the tokens with top_p probability mass. + /// So 0.1 means only the tokens comprising the top 10% probability mass are considered. + /// We generally recommend altering this or temperature but not both.
+ /// Defaults to 1 + /// + /// + /// How many chat completion choices to generate for each input message.
+ /// Defaults to 1 + /// + /// + /// Up to 4 sequences where the API will stop generating further tokens. + /// + /// + /// The maximum number of tokens allowed for the generated answer. + /// By default, the number of tokens the model can return will be (4096 - prompt tokens). + /// + /// + /// Number between -2.0 and 2.0. + /// Positive values penalize new tokens based on whether they appear in the text so far, + /// increasing the model's likelihood to talk about new topics.
+ /// Defaults to 0 + /// + /// + /// Number between -2.0 and 2.0. + /// Positive values penalize new tokens based on their existing frequency in the text so far, + /// decreasing the model's likelihood to repeat the same line verbatim.
+ /// Defaults to 0 + /// + /// + /// Modify the likelihood of specified tokens appearing in the completion. + /// Accepts a json object that maps tokens(specified by their token ID in the tokenizer) + /// to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits + /// generated by the model prior to sampling.The exact effect will vary per model, but values between + /// -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result + /// in a ban or exclusive selection of the relevant token.
+ /// Defaults to null + /// + /// + /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + /// public ChatRequest( IEnumerable messages, Model model = null, @@ -23,12 +79,11 @@ public ChatRequest( Dictionary logitBias = null, string user = null) { - const string defaultModel = "gpt-3.5-turbo"; Model = model ?? Models.Model.GPT3_5_Turbo; - if (!Model.Contains(defaultModel)) + if (!Model.Contains("turbo")) { - throw new ArgumentException(nameof(model), $"{Model} not supported"); + throw new ArgumentException($"{Model} is not supported", nameof(model)); } Messages = messages?.ToList(); @@ -127,7 +182,8 @@ public ChatRequest( [JsonProperty("frequency_penalty")] public double? FrequencyPenalty { get; } - /// Modify the likelihood of specified tokens appearing in the completion. + /// + /// Modify the likelihood of specified tokens appearing in the completion. /// Accepts a json object that maps tokens(specified by their token ID in the tokenizer) /// to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits /// generated by the model prior to sampling.The exact effect will vary per model, but values between @@ -136,7 +192,7 @@ public ChatRequest( /// Defaults to null /// [JsonProperty("logit_bias")] - public IReadOnlyDictionary LogitBias { get; set; } + public IReadOnlyDictionary LogitBias { get; } /// /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. @@ -144,6 +200,7 @@ public ChatRequest( [JsonProperty("user")] public string User { get; } + /// public override string ToString() => JsonConvert.SerializeObject(this); } } diff --git a/Runtime/Chat/ChatResponse.cs b/Runtime/Chat/ChatResponse.cs index 3b266022..e76eccc1 100644 --- a/Runtime/Chat/ChatResponse.cs +++ b/Runtime/Chat/ChatResponse.cs @@ -6,7 +6,7 @@ namespace OpenAI.Chat { - public sealed class ChatResponse + public sealed class ChatResponse : BaseResponse { [JsonConstructor] public ChatResponse( @@ -15,8 +15,7 @@ public ChatResponse( [JsonProperty("created")] int created, [JsonProperty("model")] string model, [JsonProperty("usage")] Usage usage, - [JsonProperty("choices")] List choices - ) + [JsonProperty("choices")] List choices) { Id = id; Object = @object; diff --git a/Runtime/Chat/Choice.cs b/Runtime/Chat/Choice.cs index d84638b8..b2a15c3a 100644 --- a/Runtime/Chat/Choice.cs +++ b/Runtime/Chat/Choice.cs @@ -9,10 +9,12 @@ public sealed class Choice [JsonConstructor] public Choice( [JsonProperty("message")] Message message, + [JsonProperty("delta")] Delta delta, [JsonProperty("finish_reason")] string finishReason, [JsonProperty("index")] int index) { Message = message; + Delta = delta; FinishReason = finishReason; Index = index; } @@ -20,13 +22,16 @@ public Choice( [JsonProperty("message")] public Message Message { get; } + [JsonProperty("delta")] + public Delta Delta { get; } + [JsonProperty("finish_reason")] public string FinishReason { get; } [JsonProperty("index")] public int Index { get; } - public override string ToString() => Message.ToString(); + public override string ToString() => Message?.ToString() ?? Delta.Content; public static implicit operator string(Choice choice) => choice.ToString(); } diff --git a/Runtime/Chat/Delta.cs b/Runtime/Chat/Delta.cs new file mode 100644 index 00000000..934101b0 --- /dev/null +++ b/Runtime/Chat/Delta.cs @@ -0,0 +1,24 @@ +// Licensed under the MIT License. See LICENSE in the project root for license information. + +using Newtonsoft.Json; + +namespace OpenAI.Chat +{ + public sealed class Delta + { + [JsonConstructor] + public Delta( + [JsonProperty("role")] string role, + [JsonProperty("content")] string content) + { + Role = role; + Content = content; + } + + [JsonProperty("role")] + public string Role { get; } + + [JsonProperty("content")] + public string Content { get; } + } +} diff --git a/Runtime/Chat/Delta.cs.meta b/Runtime/Chat/Delta.cs.meta new file mode 100644 index 00000000..f277ecc8 --- /dev/null +++ b/Runtime/Chat/Delta.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: dc31f29bf97d67442ab164d5fd5df7fa +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/Runtime/Completions/CompletionRequest.cs b/Runtime/Completions/CompletionRequest.cs index 4f988ace..22da9698 100644 --- a/Runtime/Completions/CompletionRequest.cs +++ b/Runtime/Completions/CompletionRequest.cs @@ -202,7 +202,7 @@ public CompletionRequest(CompletionRequest basedOn) return; } - Model = basedOn.Model; + Model = basedOn.Model ?? DefaultCompletionRequestArgs?.Model ?? Models.Model.Davinci; Prompts = basedOn.Prompts; Suffix = basedOn.Suffix ?? DefaultCompletionRequestArgs?.Suffix; MaxTokens = basedOn.MaxTokens ?? DefaultCompletionRequestArgs?.MaxTokens; @@ -280,7 +280,17 @@ public CompletionRequest( throw new ArgumentNullException($"Missing required {nameof(prompt)}(s)"); } - Model = model ?? DefaultCompletionRequestArgs?.Model; + Model = model ?? DefaultCompletionRequestArgs?.Model ?? Models.Model.Davinci; + + if (!Model.Contains("davinci") && + !Model.Contains("curie") && + !Model.Contains("cushman") && + !Model.Contains("babbage") && + !Model.Contains("ada")) + { + throw new ArgumentException($"{Model} is not supported", nameof(model)); + } + Suffix = suffix ?? DefaultCompletionRequestArgs?.Suffix; MaxTokens = maxTokens ?? DefaultCompletionRequestArgs?.MaxTokens; Temperature = temperature ?? DefaultCompletionRequestArgs?.Temperature; diff --git a/Runtime/Completions/CompletionsEndpoint.cs b/Runtime/Completions/CompletionsEndpoint.cs index eebf76a0..b25266bb 100644 --- a/Runtime/Completions/CompletionsEndpoint.cs +++ b/Runtime/Completions/CompletionsEndpoint.cs @@ -59,7 +59,7 @@ protected override string GetEndpoint() /// One or more sequences where the API will stop generating further tokens. /// The returned text will not contain the stop sequence. /// Optional, to use when calling the API. - /// Defaults to . + /// Defaults to . /// Optional, . /// Asynchronously returns the completion result. /// Look in its property for the completions. @@ -80,7 +80,7 @@ public async Task CreateCompletionAsync( CancellationToken cancellationToken = default) { var request = new CompletionRequest( - model ?? Api.DefaultModel, + model ?? Model.Davinci, prompt, prompts, suffix, @@ -112,7 +112,7 @@ public async Task CreateCompletionAsync(CompletionRequest comp var jsonContent = JsonConvert.SerializeObject(completionRequest, Api.JsonSerializationOptions); var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent(), cancellationToken); var responseAsString = await response.ReadAsStringAsync(); - return DeserializeResult(response, responseAsString); + return response.DeserializeResponse(responseAsString, Api.JsonSerializationOptions); } #endregion Non-Streaming @@ -149,7 +149,7 @@ public async Task CreateCompletionAsync(CompletionRequest comp /// One or more sequences where the API will stop generating further tokens. /// The returned text will not contain the stop sequence. /// Optional, to use when calling the API. - /// Defaults to . + /// Defaults to . /// Optional, . /// An async enumerable with each of the results as they come in. /// See the C# docs @@ -172,7 +172,7 @@ public async Task StreamCompletionAsync( CancellationToken cancellationToken = default) { var request = new CompletionRequest( - model ?? Api.DefaultModel, + model ?? Model.Davinci, prompt, prompts, suffix, @@ -224,13 +224,13 @@ public async Task StreamCompletionAsync(CompletionRequest completionRequest, Act if (!string.IsNullOrWhiteSpace(line)) { - resultHandler(DeserializeResult(response, line.Trim())); + resultHandler(response.DeserializeResponse(line.Trim(), Api.JsonSerializationOptions)); } } } /// - /// Ask the API to complete the prompt(s) using the specified parameters. + /// Ask the API to complete the prompt(s) using the specified parameters.
/// If you are not using C# 8 supporting IAsyncEnumerable{T} or if you are using the .NET Framework, /// you may need to use instead. ///
@@ -259,7 +259,7 @@ public async Task StreamCompletionAsync(CompletionRequest completionRequest, Act /// One or more sequences where the API will stop generating further tokens. /// The returned text will not contain the stop sequence. /// Optional, to use when calling the API. - /// Defaults to . + /// Defaults to . /// Optional, . /// An async enumerable with each of the results as they come in. /// See the C# docs @@ -281,7 +281,7 @@ public IAsyncEnumerable StreamCompletionEnumerableAsync( CancellationToken cancellationToken = default) { var request = new CompletionRequest( - model ?? Api.DefaultModel, + model ?? Model.Davinci, prompt, prompts, suffix, @@ -299,7 +299,7 @@ public IAsyncEnumerable StreamCompletionEnumerableAsync( } /// - /// Ask the API to complete the prompt(s) using the specified request, and stream the results as they come in. + /// Ask the API to complete the prompt(s) using the specified request, and stream the results as they come in.
/// If you are not using C# 8 supporting IAsyncEnumerable{T} or if you are using the .NET Framework, /// you may need to use instead. ///
@@ -337,25 +337,11 @@ public async IAsyncEnumerable StreamCompletionEnumerableAsync( if (!string.IsNullOrWhiteSpace(line)) { - yield return DeserializeResult(response, line.Trim()); + yield return response.DeserializeResponse(line.Trim(), Api.JsonSerializationOptions); } } } #endregion Streaming - - private CompletionResult DeserializeResult(HttpResponseMessage response, string json) - { - var result = JsonConvert.DeserializeObject(json, Api.JsonSerializationOptions); - - if (result?.Completions == null || result.Completions.Count == 0) - { - throw new HttpRequestException($"{nameof(DeserializeResult)} no completions! HTTP status code: {response.StatusCode}. Response body: {json}"); - } - - result.SetResponseData(response.Headers); - - return result; - } } } diff --git a/Runtime/Edits/EditRequest.cs b/Runtime/Edits/EditRequest.cs index 63fbdf5b..90a1d8e1 100644 --- a/Runtime/Edits/EditRequest.cs +++ b/Runtime/Edits/EditRequest.cs @@ -38,7 +38,7 @@ public EditRequest( if (!Model.Contains("-edit-")) { - throw new ArgumentException(nameof(model), $"{Model} does not support editing"); + throw new ArgumentException($"{Model} is not supported", nameof(model)); } Input = input; diff --git a/Runtime/Edits/EditsEndpoint.cs b/Runtime/Edits/EditsEndpoint.cs index 88c3ba80..3a4532f9 100644 --- a/Runtime/Edits/EditsEndpoint.cs +++ b/Runtime/Edits/EditsEndpoint.cs @@ -1,9 +1,8 @@ // Licensed under the MIT License. See LICENSE in the project root for license information. using Newtonsoft.Json; -using System.Net.Http; -using System.Threading.Tasks; using OpenAI.Models; +using System.Threading.Tasks; namespace OpenAI.Edits { @@ -53,24 +52,16 @@ public async Task CreateEditAsync( } /// - /// Creates a new edit for the provided input, instruction, and parameters + /// Creates a new edit for the provided input, instruction, and parameters. /// - /// - /// + /// . + /// . public async Task CreateEditAsync(EditRequest request) { var jsonContent = JsonConvert.SerializeObject(request, Api.JsonSerializationOptions); var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent()); var resultAsString = await response.ReadAsStringAsync(); - var editResponse = JsonConvert.DeserializeObject(resultAsString, Api.JsonSerializationOptions); - editResponse.SetResponseData(response.Headers); - - if (editResponse == null) - { - throw new HttpRequestException($"{nameof(CreateEditAsync)} returned no results! HTTP status code: {response.StatusCode}. Response body: {resultAsString}"); - } - - return editResponse; + return response.DeserializeResponse(resultAsString, Api.JsonSerializationOptions); } } } diff --git a/Runtime/Embeddings/EmbeddingsEndpoint.cs b/Runtime/Embeddings/EmbeddingsEndpoint.cs index 45deaa27..d925f090 100644 --- a/Runtime/Embeddings/EmbeddingsEndpoint.cs +++ b/Runtime/Embeddings/EmbeddingsEndpoint.cs @@ -3,7 +3,6 @@ using Newtonsoft.Json; using OpenAI.Models; using System.Collections.Generic; -using System.Net.Http; using System.Threading.Tasks; namespace OpenAI.Embeddings @@ -67,16 +66,8 @@ public async Task CreateEmbeddingAsync(EmbeddingsRequest req { var jsonContent = JsonConvert.SerializeObject(request, Api.JsonSerializationOptions); var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent()); - var resultAsString = await response.ReadAsStringAsync(); - var embeddingsResponse = JsonConvert.DeserializeObject(resultAsString, Api.JsonSerializationOptions); - embeddingsResponse.SetResponseData(response.Headers); - - if (embeddingsResponse == null) - { - throw new HttpRequestException($"{nameof(CreateEmbeddingAsync)} returned no results! HTTP status code: {response.StatusCode}. Response body: {resultAsString}"); - } - - return embeddingsResponse; + var responseAsString = await response.ReadAsStringAsync(); + return response.DeserializeResponse(responseAsString, Api.JsonSerializationOptions); } } } diff --git a/Runtime/Embeddings/EmbeddingsRequest.cs b/Runtime/Embeddings/EmbeddingsRequest.cs index 0530515b..71c1d70b 100644 --- a/Runtime/Embeddings/EmbeddingsRequest.cs +++ b/Runtime/Embeddings/EmbeddingsRequest.cs @@ -44,8 +44,8 @@ public EmbeddingsRequest(string input, Model model = null, string user = null) /// Each input must not exceed 8192 tokens in length. /// /// - /// The to use. - /// Defaults to: text-embedding-ada-002 + /// The to use.
+ /// Defaults to: /// /// /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. @@ -60,11 +60,13 @@ public EmbeddingsRequest(IEnumerable input, Model model = null, string u throw new ArgumentNullException(nameof(input), $"Missing required {nameof(input)} parameter"); } - Model = model ?? new Model("text-embedding-ada-002"); + Model = model ?? Models.Model.Embedding_Ada_002; - if (!Model.Contains("text-embedding")) + if (!Model.Contains("embedding") && + !Model.Contains("search") && + !Model.Contains("similarity")) { - throw new ArgumentException(nameof(model), $"{Model} is not supported for embedding."); + throw new ArgumentException($"{Model} is not supported", nameof(model)); } User = user; diff --git a/Runtime/FineTuning/FineTuningEndpoint.cs b/Runtime/FineTuning/FineTuningEndpoint.cs index 26cd0122..f202646a 100644 --- a/Runtime/FineTuning/FineTuningEndpoint.cs +++ b/Runtime/FineTuning/FineTuningEndpoint.cs @@ -53,9 +53,7 @@ public async Task CreateFineTuneJobAsync(CreateFineTuneJobRequest j var jsonContent = JsonConvert.SerializeObject(jobRequest, Api.JsonSerializationOptions); var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent()); var responseAsString = await response.ReadAsStringAsync(); - var result = JsonConvert.DeserializeObject(responseAsString, Api.JsonSerializationOptions); - result.SetResponseData(response.Headers); - return result; + return response.DeserializeResponse(responseAsString, Api.JsonSerializationOptions); } /// @@ -80,9 +78,7 @@ public async Task RetrieveFineTuneJobInfoAsync(string jobId) { var response = await Api.Client.GetAsync($"{GetEndpoint()}/{jobId}"); var responseAsString = await response.ReadAsStringAsync(); - var result = JsonConvert.DeserializeObject(responseAsString, Api.JsonSerializationOptions); - result.SetResponseData(response.Headers); - return result; + return response.DeserializeResponse(responseAsString, Api.JsonSerializationOptions); } /// @@ -95,8 +91,7 @@ public async Task CancelFineTuneJobAsync(string jobId) { var response = await Api.Client.PostAsync($"{GetEndpoint()}/{jobId}/cancel", null!); var responseAsString = await response.ReadAsStringAsync(); - var result = JsonConvert.DeserializeObject(responseAsString, Api.JsonSerializationOptions); - result.SetResponseData(response.Headers); + var result = response.DeserializeResponse(responseAsString, Api.JsonSerializationOptions); return result.Status == "cancelled"; } diff --git a/Runtime/Models/Model.cs b/Runtime/Models/Model.cs index d6face40..be945ab1 100644 --- a/Runtime/Models/Model.cs +++ b/Runtime/Models/Model.cs @@ -1,12 +1,14 @@ // Licensed under the MIT License. See LICENSE in the project root for license information. using Newtonsoft.Json; +using System; using System.Collections.Generic; namespace OpenAI.Models { /// - /// Represents a language model. + /// Represents a language model.
+ /// ///
public sealed class Model { @@ -69,6 +71,7 @@ public Model( /// /// The default Model to use in the case no other is specified. Defaults to /// + [Obsolete("Will be removed in next major release.")] public static Model Default => Davinci; /// @@ -100,5 +103,20 @@ public Model( /// Good at: Parsing text, simple classification, address correction, keywords /// public static Model Ada { get; } = new Model("text-ada-001") { OwnedBy = "openai" }; + + /// + /// The default model for . + /// + public static Model Embedding_Ada_002 { get; } = new Model("text-embedding-ada-002") { OwnedBy = "openai" }; + + /// + /// The default model for . + /// + public static Model Whisper1 { get; } = new Model("whisper-1") { OwnedBy = "openai" }; + + /// + /// The default model for . + /// + public static Model Moderation_Latest { get; } = new Model("text-moderation-latest") { OwnedBy = "openai" }; } } diff --git a/Runtime/Moderations/ModerationsEndpoint.cs b/Runtime/Moderations/ModerationsEndpoint.cs index a60fb286..1c2f273e 100644 --- a/Runtime/Moderations/ModerationsEndpoint.cs +++ b/Runtime/Moderations/ModerationsEndpoint.cs @@ -56,18 +56,10 @@ public async Task GetModerationAsync(string input, Model model = null) /// Raised when the HTTP request fails public async Task CreateModerationAsync(ModerationsRequest request) { - var jsonContent = JsonConvert.SerializeObject(request, Api.JsonSerializationOptions); - var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent.ToJsonStringContent()); + var jsonContent = JsonConvert.SerializeObject(request, Api.JsonSerializationOptions).ToJsonStringContent(); + var response = await Api.Client.PostAsync(GetEndpoint(), jsonContent); var resultAsString = await response.ReadAsStringAsync(); - var moderationResponse = JsonConvert.DeserializeObject(resultAsString, Api.JsonSerializationOptions); - - if (moderationResponse == null) - { - throw new HttpRequestException($"{nameof(CreateModerationAsync)} returned no results! HTTP status code: {response.StatusCode}. Response body: {resultAsString}"); - } - - moderationResponse.SetResponseData(response.Headers); - return moderationResponse; + return response.DeserializeResponse(resultAsString, Api.JsonSerializationOptions); } } } diff --git a/Runtime/Moderations/ModerationsRequest.cs b/Runtime/Moderations/ModerationsRequest.cs index 5f2ac03c..198e1e7b 100644 --- a/Runtime/Moderations/ModerationsRequest.cs +++ b/Runtime/Moderations/ModerationsRequest.cs @@ -26,11 +26,11 @@ public sealed class ModerationsRequest public ModerationsRequest(string input, Model model = null) { Input = input; - Model = model ?? new Model("text-moderation-latest"); + Model = model ?? Models.Model.Moderation_Latest; if (!Model.Contains("text-moderation")) { - throw new ArgumentException(nameof(model), $"{Model} is not supported."); + throw new ArgumentException($"{Model} is not supported", nameof(model)); } } diff --git a/Runtime/OpenAIClient.cs b/Runtime/OpenAIClient.cs index 83066274..12b32cd0 100644 --- a/Runtime/OpenAIClient.cs +++ b/Runtime/OpenAIClient.cs @@ -11,6 +11,7 @@ using OpenAI.Images; using OpenAI.Models; using OpenAI.Moderations; +using System; using System.Net.Http; using System.Net.Http.Headers; using System.Security.Authentication; @@ -28,18 +29,27 @@ public sealed class OpenAIClient /// The /model to use for API calls, /// defaulting to if not specified. /// Raised when authentication details are missing or invalid. + [Obsolete("Will be removed in next major release.")] public OpenAIClient(Model model) : this(null, model) { } + /// + /// Creates a new entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints + /// + /// + /// The /model to use for API calls, + /// defaulting to if not specified. + /// Raised when authentication details are missing or invalid. + [Obsolete("Will be removed in next major release.")] + public OpenAIClient(OpenAIAuthentication openAIAuthentication, Model model) : this(openAIAuthentication) { } + /// /// Creates a new entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints /// /// The API authentication information to use for API calls, /// or to attempt to use the , /// potentially loading from environment vars or from a config file. - /// The to use for API calls, - /// defaulting to if not specified. /// Raised when authentication details are missing or invalid. - public OpenAIClient(OpenAIAuthentication openAIAuthentication = null, Model model = null) + public OpenAIClient(OpenAIAuthentication openAIAuthentication = null) { OpenAIAuthentication = openAIAuthentication ?? OpenAIAuthentication.Default; @@ -62,7 +72,6 @@ public OpenAIClient(OpenAIAuthentication openAIAuthentication = null, Model mode { DefaultValueHandling = DefaultValueHandling.Ignore }; - DefaultModel = model ?? Model.Default; ModelsEndpoint = new ModelsEndpoint(this); CompletionsEndpoint = new CompletionsEndpoint(this); ChatEndpoint = new ChatEndpoint(this); @@ -93,6 +102,7 @@ public OpenAIClient(OpenAIAuthentication openAIAuthentication = null, Model mode /// /// Specifies which to use for API calls /// + [Obsolete("Will be removed in next major release.")] public Model DefaultModel { get; set; } private int version; diff --git a/Runtime/ResponseExtensions.cs b/Runtime/ResponseExtensions.cs index 3fae1f88..3a2fd01b 100644 --- a/Runtime/ResponseExtensions.cs +++ b/Runtime/ResponseExtensions.cs @@ -1,5 +1,6 @@ // Licensed under the MIT License. See LICENSE in the project root for license information. +using Newtonsoft.Json; using System; using System.Linq; using System.Net.Http; @@ -49,5 +50,12 @@ internal static async Task CheckResponseAsync(this HttpResponseMessage response, throw new HttpRequestException($"{methodName} Failed! HTTP status code: {response.StatusCode} | Response body: {responseAsString}"); } } + + internal static T DeserializeResponse(this HttpResponseMessage response, string json, JsonSerializerSettings settings) where T : BaseResponse + { + var result = JsonConvert.DeserializeObject(json, settings); + result.SetResponseData(response.Headers); + return result; + } } } diff --git a/Tests/TestFixture_00_Authentication.cs b/Tests/TestFixture_00_Authentication.cs index 66bd3122..bb7a286c 100644 --- a/Tests/TestFixture_00_Authentication.cs +++ b/Tests/TestFixture_00_Authentication.cs @@ -59,11 +59,12 @@ public void Test_04_GetAuthFromConfiguration() var config = ScriptableObject.CreateInstance(); config.apiKey = "sk-test12"; config.organizationId = "org-testOrg"; - Debug.Log(Application.dataPath); + var cleanup = false; if (!Directory.Exists($"{Application.dataPath}/Resources")) { Directory.CreateDirectory($"{Application.dataPath}/Resources"); + cleanup = true; } AssetDatabase.CreateAsset(config, "Assets/Resources/OpenAIConfigurationSettings-Test.asset"); @@ -79,6 +80,11 @@ public void Test_04_GetAuthFromConfiguration() Assert.IsNotEmpty(auth.OrganizationId); Assert.AreEqual(auth.OrganizationId, config.OrganizationId); AssetDatabase.DeleteAsset(configPath); + + if (cleanup) + { + AssetDatabase.DeleteAsset("Assets/Resources"); + } } [Test] diff --git a/Tests/TestFixture_02_Completions.cs b/Tests/TestFixture_02_Completions.cs index f079611a..24429ee6 100644 --- a/Tests/TestFixture_02_Completions.cs +++ b/Tests/TestFixture_02_Completions.cs @@ -13,7 +13,7 @@ namespace OpenAI.Tests { internal class TestFixture_02_Completions { - private readonly string completionPrompts = "One Two Three Four Five Six Seven Eight Nine One Two Three Four Five Six Seven Eight"; + private const string CompletionPrompts = "One Two Three Four Five Six Seven Eight Nine One Two Three Four Five Six Seven Eight"; [UnityTest] public IEnumerator Test_1_GetBasicCompletion() @@ -23,7 +23,7 @@ public IEnumerator Test_1_GetBasicCompletion() var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv()); Assert.IsNotNull(api.CompletionsEndpoint); var result = await api.CompletionsEndpoint.CreateCompletionAsync( - completionPrompts, + CompletionPrompts, temperature: 0.1, maxTokens: 5, numOutputs: 5, @@ -51,14 +51,37 @@ await api.CompletionsEndpoint.StreamCompletionAsync(result => Assert.NotNull(result.Completions); Assert.NotZero(result.Completions.Count); allCompletions.AddRange(result.Completions); + }, CompletionPrompts, temperature: 0.1, maxTokens: 5, numOutputs: 5); - foreach (var choice in result.Completions) - { - Debug.Log(choice); - } - }, completionPrompts, temperature: 0.1, maxTokens: 5, numOutputs: 5); + Assert.That(allCompletions.Any(c => c.Text.Trim().ToLower().StartsWith("nine"))); + Debug.Log(allCompletions.FirstOrDefault()); + }); + } + + [UnityTest] + public IEnumerator Test_3_GetStreamingCompletionEnumerableAsync() + { + yield return AwaitTestUtilities.Await(async () => + { + var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv()); + Assert.IsNotNull(api.CompletionsEndpoint); + var allCompletions = new List(); + + await foreach (var result in api.CompletionsEndpoint.StreamCompletionEnumerableAsync( + CompletionPrompts, + temperature: 0.1, + maxTokens: 5, + numOutputs: 5, + model: Model.Davinci)) + { + Assert.IsNotNull(result); + Assert.NotNull(result.Completions); + Assert.NotZero(result.Completions.Count); + allCompletions.AddRange(result.Completions); + } Assert.That(allCompletions.Any(c => c.Text.Trim().ToLower().StartsWith("nine"))); + Debug.Log(allCompletions.FirstOrDefault()); }); } } diff --git a/Tests/TestFixture_03_Chat.cs b/Tests/TestFixture_03_Chat.cs index 7e02b37c..d722fca9 100644 --- a/Tests/TestFixture_03_Chat.cs +++ b/Tests/TestFixture_03_Chat.cs @@ -34,5 +34,63 @@ public IEnumerator Test_1_GetChatCompletion() Debug.Log(result.FirstChoice); }); } + + [UnityTest] + public IEnumerator Test_2_GetChatStreamingCompletion() + { + yield return AwaitTestUtilities.Await(async () => + { + var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv()); + Assert.IsNotNull(api.ChatEndpoint); + var chatPrompts = new List + { + new ChatPrompt("system", "You are a helpful assistant."), + new ChatPrompt("user", "Who won the world series in 2020?"), + new ChatPrompt("assistant", "The Los Angeles Dodgers won the World Series in 2020."), + new ChatPrompt("user", "Where was it played?"), + }; + var chatRequest = new ChatRequest(chatPrompts, Model.GPT3_5_Turbo); + var allContent = new List(); + + await api.ChatEndpoint.StreamCompletionAsync(chatRequest, result => + { + Assert.IsNotNull(result); + Assert.NotNull(result.Choices); + Assert.NotZero(result.Choices.Count); + allContent.Add(result.FirstChoice); + }); + + Debug.Log(string.Join("", allContent)); + }); + } + + [UnityTest] + public IEnumerator Test_3_GetChatStreamingCompletionEnumerableAsync() + { + yield return AwaitTestUtilities.Await(async () => + { + var api = new OpenAIClient(OpenAIAuthentication.LoadFromEnv()); + Assert.IsNotNull(api.ChatEndpoint); + var chatPrompts = new List + { + new ChatPrompt("system", "You are a helpful assistant."), + new ChatPrompt("user", "Who won the world series in 2020?"), + new ChatPrompt("assistant", "The Los Angeles Dodgers won the World Series in 2020."), + new ChatPrompt("user", "Where was it played?"), + }; + var chatRequest = new ChatRequest(chatPrompts, Model.GPT3_5_Turbo); + var allContent = new List(); + + await foreach (var result in api.ChatEndpoint.StreamCompletionEnumerableAsync(chatRequest)) + { + Assert.IsNotNull(result); + Assert.NotNull(result.Choices); + Assert.NotZero(result.Choices.Count); + allContent.Add(result.FirstChoice); + } + + Debug.Log(string.Join("", allContent)); + }); + } } } diff --git a/package.json b/package.json index 593fbcee..75281c48 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,7 @@ "displayName": "OpenAI", "description": "A OpenAI package for the Unity Game Engine to use GPT-3 and Dall-E though their RESTful API (currently in beta).\n\nIndependently developed, this is not an official library and I am not affiliated with OpenAI.\n\nAn OpenAI API account is required.", "keywords": [], - "version": "3.1.0", + "version": "3.1.1", "unity": "2021.3", "documentationUrl": "https://github.com/RageAgainstThePixel/com.openai.unity#documentation", "changelogUrl": "https://github.com/RageAgainstThePixel/com.openai.unity/releases",