Merge pull request #56 from megalon/feature/chat
Add Chat endpoint, thanks @megalon!
Showing 5 changed files with 480 additions and 1 deletion.
@@ -0,0 +1,199 @@
using OpenAI_API.Models;
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Text;
using System.Threading.Tasks;

namespace OpenAI_API.Chat
{
    /// <summary>
    /// ChatGPT API endpoint. Use this endpoint to send multiple messages and carry on a conversation.
    /// </summary>
    public class ChatEndpoint : EndpointBase
    {
        /// <summary>
        /// This allows you to set default parameters for every request, for example a default temperature or max tokens. For every request, if you do not have a parameter set on the request but do have it set here as a default, the request will automatically pick up the default value.
        /// </summary>
        public ChatRequest DefaultChatRequestArgs { get; set; } = new ChatRequest() { Model = Model.ChatGPTTurbo };

        /// <summary>
        /// The name of the endpoint, which is the path appended to the API base URL. For this endpoint, "chat/completions".
        /// </summary>
        protected override string Endpoint { get { return "chat/completions"; } }

        /// <summary>
        /// Constructor of the API endpoint. Rather than instantiating this yourself, access it through an instance of <see cref="OpenAIAPI"/>.
        /// </summary>
        /// <param name="api">The API instance this endpoint belongs to.</param>
        internal ChatEndpoint(OpenAIAPI api) : base(api) { }
        #region Non-streaming

        /// <summary>
        /// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
        /// </summary>
        /// <param name="request">The request to send to the API.</param>
        /// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
        public async Task<ChatResult> CreateChatAsync(ChatRequest request)
        {
            return await HttpPost<ChatResult>(postData: request);
        }

        /// <summary>
        /// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
        /// </summary>
        /// <param name="request">The request to send to the API.</param>
        /// <param name="numOutputs">Overrides <see cref="ChatRequest.NumChoicesPerMessage"/> as a convenience.</param>
        /// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
        public Task<ChatResult> CreateChatAsync(ChatRequest request, int numOutputs = 5)
        {
            request.NumChoicesPerMessage = numOutputs;
            return CreateChatAsync(request);
        }
        /// <summary>
        /// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
        /// </summary>
        /// <param name="messages">The array of messages to send to the API.</param>
        /// <param name="model">The model to use. See the ChatGPT models available from <see cref="ModelsEndpoint.GetModelsAsync()"/>.</param>
        /// <param name="temperature">What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. It is generally recommended to use this or <paramref name="top_p"/> but not both.</param>
        /// <param name="top_p">An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. It is generally recommended to use this or <paramref name="temperature"/> but not both.</param>
        /// <param name="numOutputs">How many different choices to request for each prompt.</param>
        /// <param name="max_tokens">How many tokens to complete to. Can return fewer if a stop sequence is hit.</param>
        /// <param name="frequencyPenalty">The scale of the penalty for how often a token is used. Should generally be between 0 and 1, although negative numbers are allowed to encourage token reuse.</param>
        /// <param name="presencePenalty">The scale of the penalty applied if a token is already present at all. Should generally be between 0 and 1, although negative numbers are allowed to encourage token reuse.</param>
        /// <param name="logitBias">Maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.</param>
        /// <param name="stopSequences">One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.</param>
        /// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
        public Task<ChatResult> CreateChatAsync(IEnumerable<ChatMessage> messages,
            Model model = null,
            double? temperature = null,
            double? top_p = null,
            int? numOutputs = null,
            int? max_tokens = null,
            double? frequencyPenalty = null,
            double? presencePenalty = null,
            IReadOnlyDictionary<string, float> logitBias = null,
            params string[] stopSequences)
        {
            ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
            {
                Messages = messages,
                Model = model ?? DefaultChatRequestArgs.Model,
                Temperature = temperature ?? DefaultChatRequestArgs.Temperature,
                TopP = top_p ?? DefaultChatRequestArgs.TopP,
                NumChoicesPerMessage = numOutputs ?? DefaultChatRequestArgs.NumChoicesPerMessage,
                MultipleStopSequences = stopSequences ?? DefaultChatRequestArgs.MultipleStopSequences,
                MaxTokens = max_tokens ?? DefaultChatRequestArgs.MaxTokens,
                FrequencyPenalty = frequencyPenalty ?? DefaultChatRequestArgs.FrequencyPenalty,
                PresencePenalty = presencePenalty ?? DefaultChatRequestArgs.PresencePenalty,
                LogitBias = logitBias ?? DefaultChatRequestArgs.LogitBias
            };
            return CreateChatAsync(request);
        }
        /// <summary>
        /// Ask the API to complete the request using the specified parameters. This is non-streaming, so it will wait until the API returns the full result. Any non-specified parameters will fall back to default values specified in <see cref="DefaultChatRequestArgs"/> if present.
        /// </summary>
        /// <param name="messages">The messages to use in the generation.</param>
        /// <returns>Asynchronously returns the completion result. Look in its <see cref="ChatResult.Choices"/> property for the results.</returns>
        public Task<ChatResult> CreateChatAsync(IEnumerable<ChatMessage> messages)
        {
            ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
            {
                Messages = messages
            };
            return CreateChatAsync(request);
        }

        #endregion
        #region Streaming

        /// <summary>
        /// Ask the API to complete the message(s) using the specified request, and stream the results to the <paramref name="resultHandler"/> as they come in.
        /// If you are on a recent C# version with async enumerable support, you may prefer the cleaner syntax of <see cref="StreamChatEnumerableAsync(ChatRequest)"/> instead.
        /// </summary>
        /// <param name="request">The request to send to the API. This does not fall back to default values specified in <see cref="DefaultChatRequestArgs"/>.</param>
        /// <param name="resultHandler">An action to be called as each new result arrives, which includes the index of the result in the overall result set.</param>
        public async Task StreamCompletionAsync(ChatRequest request, Action<int, ChatResult> resultHandler)
        {
            int index = 0;

            await foreach (var res in StreamChatEnumerableAsync(request))
            {
                resultHandler(index++, res);
            }
        }

        /// <summary>
        /// Ask the API to complete the message(s) using the specified request, and stream the results to the <paramref name="resultHandler"/> as they come in.
        /// If you are on a recent C# version with async enumerable support, you may prefer the cleaner syntax of <see cref="StreamChatEnumerableAsync(ChatRequest)"/> instead.
        /// </summary>
        /// <param name="request">The request to send to the API. This does not fall back to default values specified in <see cref="DefaultChatRequestArgs"/>.</param>
        /// <param name="resultHandler">An action to be called as each new result arrives.</param>
        public async Task StreamChatAsync(ChatRequest request, Action<ChatResult> resultHandler)
        {
            await foreach (var res in StreamChatEnumerableAsync(request))
            {
                resultHandler(res);
            }
        }

        /// <summary>
        /// Ask the API to complete the message(s) using the specified request, and stream the results as they come in.
        /// If you are not using C# 8 with async enumerable support, or if you are on the .NET Framework, you may need to use <see cref="StreamChatAsync(ChatRequest, Action{ChatResult})"/> instead.
        /// </summary>
        /// <param name="request">The request to send to the API. This does not fall back to default values specified in <see cref="DefaultChatRequestArgs"/>.</param>
        /// <returns>An async enumerable with each of the results as they come in. See <see href="https://docs.microsoft.com/en-us/dotnet/csharp/whats-new/csharp-8#asynchronous-streams">the C# docs</see> for more details on how to consume an async enumerable.</returns>
        public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(ChatRequest request)
        {
            request = new ChatRequest(request) { Stream = true };
            return HttpStreamingRequest<ChatResult>(Url, HttpMethod.Post, request);
        }
        /// <summary>
        /// Ask the API to complete the message(s) using the specified parameters, and stream the results as they come in.
        /// If you are not using C# 8 with async enumerable support, or if you are on the .NET Framework, you may need to use <see cref="StreamChatAsync(ChatRequest, Action{ChatResult})"/> instead.
        /// </summary>
        /// <param name="messages">The array of messages to send to the API.</param>
        /// <param name="model">The model to use. See the ChatGPT models available from <see cref="ModelsEndpoint.GetModelsAsync()"/>.</param>
        /// <param name="temperature">What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. It is generally recommended to use this or <paramref name="top_p"/> but not both.</param>
        /// <param name="top_p">An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. It is generally recommended to use this or <paramref name="temperature"/> but not both.</param>
        /// <param name="numOutputs">How many different choices to request for each prompt.</param>
        /// <param name="max_tokens">How many tokens to complete to. Can return fewer if a stop sequence is hit.</param>
        /// <param name="frequencyPenalty">The scale of the penalty for how often a token is used. Should generally be between 0 and 1, although negative numbers are allowed to encourage token reuse.</param>
        /// <param name="presencePenalty">The scale of the penalty applied if a token is already present at all. Should generally be between 0 and 1, although negative numbers are allowed to encourage token reuse.</param>
        /// <param name="logitBias">Maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.</param>
        /// <param name="stopSequences">One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.</param>
        /// <returns>An async enumerable with each of the results as they come in. See <see href="https://docs.microsoft.com/en-us/dotnet/csharp/whats-new/csharp-8#asynchronous-streams">the C# docs</see> for more details on how to consume an async enumerable.</returns>
        public IAsyncEnumerable<ChatResult> StreamChatEnumerableAsync(IEnumerable<ChatMessage> messages,
            Model model = null,
            double? temperature = null,
            double? top_p = null,
            int? numOutputs = null,
            int? max_tokens = null,
            double? frequencyPenalty = null,
            double? presencePenalty = null,
            IReadOnlyDictionary<string, float> logitBias = null,
            params string[] stopSequences)
        {
            ChatRequest request = new ChatRequest(DefaultChatRequestArgs)
            {
                Messages = messages,
                Model = model ?? DefaultChatRequestArgs.Model,
                Temperature = temperature ?? DefaultChatRequestArgs.Temperature,
                TopP = top_p ?? DefaultChatRequestArgs.TopP,
                NumChoicesPerMessage = numOutputs ?? DefaultChatRequestArgs.NumChoicesPerMessage,
                MultipleStopSequences = stopSequences ?? DefaultChatRequestArgs.MultipleStopSequences,
                MaxTokens = max_tokens ?? DefaultChatRequestArgs.MaxTokens,
                FrequencyPenalty = frequencyPenalty ?? DefaultChatRequestArgs.FrequencyPenalty,
                PresencePenalty = presencePenalty ?? DefaultChatRequestArgs.PresencePenalty,
                LogitBias = logitBias ?? DefaultChatRequestArgs.LogitBias
            };
            return StreamChatEnumerableAsync(request);
        }
        #endregion
    }
}
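For context, here is a minimal usage sketch of the new endpoint. It is not part of the commit, and it makes two assumptions not shown in this diff: that OpenAIAPI takes an API key and exposes this endpoint as a Chat property, and that ChatMessage (defined in one of the changed files not displayed on this page) can be constructed from a role string and a content string.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using OpenAI_API;
using OpenAI_API.Chat;

public static class ChatDemo
{
    public static async Task RunAsync()
    {
        // Assumption: OpenAIAPI takes an API key and exposes a Chat property;
        // neither appears in the two files shown above.
        var api = new OpenAIAPI("sk-...");

        // Assumption: ChatMessage takes (role, content); its definition is in
        // one of the changed files not displayed on this page.
        var messages = new List<ChatMessage>
        {
            new ChatMessage("system", "You are a helpful assistant."),
            new ChatMessage("user", "What is the capital of France?")
        };

        // Non-streaming: waits for the full result. Parameters not supplied here
        // fall back to DefaultChatRequestArgs via the ChatRequest copy constructor.
        ChatResult result = await api.Chat.CreateChatAsync(messages, temperature: 0.2);
        // The element type of Choices is defined in ChatResult (not shown);
        // WriteLine falls back to its ToString().
        Console.WriteLine(result.Choices[0]);

        // Streaming: partial results arrive as C# 8 async-enumerable chunks.
        var request = new ChatRequest(api.Chat.DefaultChatRequestArgs) { Messages = messages };
        await foreach (ChatResult chunk in api.Chat.StreamChatEnumerableAsync(request))
        {
            Console.Write(chunk.Choices[0]);
        }
    }
}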
@@ -0,0 +1,148 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace OpenAI_API.Chat
{
    /// <summary>
    /// A request to the Chat API. This is similar to, but not exactly the same as, the <see cref="Completions.CompletionRequest"/>.
    /// Based on the <see href="https://platform.openai.com/docs/api-reference/chat">OpenAI API docs</see>.
    /// </summary>
    public class ChatRequest
    {
        /// <summary>
        /// The model to use for this request
        /// </summary>
        [JsonProperty("model")]
        public string Model { get; set; } = OpenAI_API.Models.Model.ChatGPTTurbo;

        /// <summary>
        /// The messages to send with this Chat Request
        /// </summary>
        [JsonProperty("messages")]
        public IEnumerable<ChatMessage> Messages { get; set; }
        /// <summary>
        /// What sampling temperature to use. Higher values mean the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. It is generally recommended to use this or <see cref="TopP"/> but not both.
        /// </summary>
        [JsonProperty("temperature")]
        public double? Temperature { get; set; }

        /// <summary>
        /// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. It is generally recommended to use this or <see cref="Temperature"/> but not both.
        /// </summary>
        [JsonProperty("top_p")]
        public double? TopP { get; set; }

        /// <summary>
        /// How many different choices to request for each message. Defaults to 1.
        /// </summary>
        [JsonProperty("n")]
        public int? NumChoicesPerMessage { get; set; }

        /// <summary>
        /// Specifies whether the results should stream in as they become available or be returned all at once when complete. Do not set this yourself; use the appropriate methods on <see cref="ChatEndpoint"/> instead.
        /// </summary>
        [JsonProperty("stream")]
        public bool Stream { get; internal set; } = false;
        /// <summary>
        /// This is only used for serializing the request into JSON; do not use it directly.
        /// </summary>
        [JsonProperty("stop")]
        public object CompiledStop
        {
            get
            {
                if (MultipleStopSequences?.Length == 1)
                    return StopSequence;
                else if (MultipleStopSequences?.Length > 0)
                    return MultipleStopSequences;
                else
                    return null;
            }
        }

        /// <summary>
        /// One or more sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
        /// </summary>
        [JsonIgnore]
        public string[] MultipleStopSequences { get; set; }

        /// <summary>
        /// The stop sequence where the API will stop generating further tokens. The returned text will not contain the stop sequence. For convenience, if you are only requesting a single stop sequence, set it here.
        /// </summary>
        [JsonIgnore]
        public string StopSequence
        {
            get => MultipleStopSequences?.FirstOrDefault() ?? null;
            set
            {
                if (value != null)
                    MultipleStopSequences = new string[] { value };
            }
        }
        /// <summary>
        /// How many tokens to complete to. Can return fewer if a stop sequence is hit. Defaults to 16.
        /// </summary>
        [JsonProperty("max_tokens")]
        public int? MaxTokens { get; set; }

        /// <summary>
        /// The scale of the penalty for how often a token is used. Should generally be between 0 and 1, although negative numbers are allowed to encourage token reuse. Defaults to 0.
        /// </summary>
        [JsonProperty("frequency_penalty")]
        public double? FrequencyPenalty { get; set; }

        /// <summary>
        /// The scale of the penalty applied if a token is already present at all. Should generally be between 0 and 1, although negative numbers are allowed to encourage token reuse. Defaults to 0.
        /// </summary>
        [JsonProperty("presence_penalty")]
        public double? PresencePenalty { get; set; }
        /// <summary>
        /// Modify the likelihood of specified tokens appearing in the completion.
        /// Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100.
        /// Mathematically, the bias is added to the logits generated by the model prior to sampling.
        /// The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
        /// </summary>
        [JsonProperty("logit_bias")]
        public IReadOnlyDictionary<string, float> LogitBias { get; set; }

        /// <summary>
        /// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
        /// </summary>
        [JsonProperty("user")]
        public string User { get; set; }
        /// <summary>
        /// Creates a new, empty <see cref="ChatRequest"/>
        /// </summary>
        public ChatRequest()
        {
            this.Model = OpenAI_API.Models.Model.ChatGPTTurbo;
        }

        /// <summary>
        /// Create a new chat request using the data from the input chat request.
        /// </summary>
        /// <param name="basedOn">The request to copy parameter values from.</param>
        public ChatRequest(ChatRequest basedOn)
        {
            this.Model = basedOn.Model;
            this.Messages = basedOn.Messages;
            this.Temperature = basedOn.Temperature;
            this.TopP = basedOn.TopP;
            this.NumChoicesPerMessage = basedOn.NumChoicesPerMessage;
            this.MultipleStopSequences = basedOn.MultipleStopSequences;
            this.MaxTokens = basedOn.MaxTokens;
            this.FrequencyPenalty = basedOn.FrequencyPenalty;
            this.PresencePenalty = basedOn.PresencePenalty;
            this.LogitBias = basedOn.LogitBias;
        }
    }
}
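To make the CompiledStop property above concrete, here is a small sketch (also not part of the commit) of how stop sequences serialize with Newtonsoft.Json, using only the types shown in this file: one stop sequence comes out as a plain JSON string, several come out as an array, and with none set the "stop" field is null.

using System;
using Newtonsoft.Json;
using OpenAI_API.Chat;

public static class StopSerializationDemo
{
    public static void Main()
    {
        // Exactly one stop sequence: CompiledStop returns StopSequence itself,
        // so the serialized request contains "stop":"\n".
        var single = new ChatRequest { StopSequence = "\n" };
        Console.WriteLine(JsonConvert.SerializeObject(single));

        // More than one: CompiledStop returns the whole array,
        // so the output contains "stop":["END","STOP"].
        var multiple = new ChatRequest { MultipleStopSequences = new[] { "END", "STOP" } };
        Console.WriteLine(JsonConvert.SerializeObject(multiple));

        // Note: unset properties (messages, temperature, etc.) serialize as null
        // here because this sketch applies no serializer settings.
    }
}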