Skip to content

Commit

Permalink
Added embedding endpoint and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
metjuperry committed Feb 1, 2023
1 parent 2fd8b09 commit 864263d
Show file tree
Hide file tree
Showing 5 changed files with 293 additions and 27 deletions.
77 changes: 77 additions & 0 deletions OpenAI_API/Embedding/EmbeddingEndpoint.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Security.Authentication;
using System.Text;
using System.Threading.Tasks;

namespace OpenAI_API.Embedding
{
/// <summary>
/// OpenAI’s text embeddings measure the relatedness of text strings by generating an embedding, which is a vector (list) of floating point numbers. The distance between two vectors measures their relatedness. Small distances suggest high relatedness and large distances suggest low relatedness.
/// </summary>
public class EmbeddingEndpoint
{
OpenAIAPI Api;
/// <summary>
/// This allows you to send request to the recommended model without needing to specify. Every request uses the <see cref="Model.AdaTextEmbedding"/> model
/// </summary>
public EmbeddingRequest DefaultEmbeddingRequestArgs { get; set; } = new EmbeddingRequest() { Model = Model.AdaTextEmbedding };

/// <summary>
/// Constructor of the api endpoint. Rather than instantiating this yourself, access it through an instance of <see cref="OpenAIAPI"/> as <see cref="OpenAIAPI.Embeddings"/>.
/// </summary>
/// <param name="api"></param>
internal EmbeddingEndpoint(OpenAIAPI api)
{
this.Api = api;
}

/// <summary>
/// Ask the API to embedd text using the default embedding model <see cref="Model.AdaTextEmbedding"/>
/// </summary>
/// <param name="input">Text to be embedded</param>
/// <returns>Asynchronously returns the embedding result. Look in its <see cref="Data.Embedding"/> property of <see cref="EmbeddingResult.Data"/> to find the vector of floating point numbers</returns>
public async Task<EmbeddingResult> CreateEmbeddingAsync(string input)
{
DefaultEmbeddingRequestArgs.Input = input;
return await CreateEmbeddingAsync(DefaultEmbeddingRequestArgs);
}

/// <summary>
/// Ask the API to embedd text using a custom request
/// </summary>
/// <param name="request">Request to be send</param>
/// <returns>Asynchronously returns the embedding result. Look in its <see cref="Data.Embedding"/> property of <see cref="EmbeddingResult.Data"/> to find the vector of floating point numbers</returns>
public async Task<EmbeddingResult> CreateEmbeddingAsync(EmbeddingRequest request)
{
if (Api.Auth?.ApiKey is null)
{
throw new AuthenticationException("You must provide API authentication. Please refer to https://github.com/OkGoDoIt/OpenAI-API-dotnet#authentication for details.");
}

HttpClient client = new HttpClient();
client.DefaultRequestHeaders.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", Api.Auth.ApiKey);
client.DefaultRequestHeaders.Add("User-Agent", "okgodoit/dotnet_openai_api");

string jsonContent = JsonConvert.SerializeObject(request, new JsonSerializerSettings() { NullValueHandling = NullValueHandling.Ignore });
var stringContent = new StringContent(jsonContent, Encoding.UTF8, "application/json");

var response = await client.PostAsync($"https://api.openai.com/v1/embeddings", stringContent);
if (response.IsSuccessStatusCode)
{
string resultAsString = await response.Content.ReadAsStringAsync();

var res = JsonConvert.DeserializeObject<EmbeddingResult>(resultAsString);

return res;
}
else
{
throw new HttpRequestException("Error calling OpenAi API to get completion. HTTP status code: " + response.StatusCode.ToString() + ". Request body: " + jsonContent);
}
}

}
}
50 changes: 50 additions & 0 deletions OpenAI_API/Embedding/EmbeddingRequest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace OpenAI_API.Embedding
{
/// <summary>
/// Represents a request to the Completions API. Matches with the docs at <see href="https://platform.openai.com/docs/api-reference/embeddings">the OpenAI docs</see>
/// </summary>
public class EmbeddingRequest
{
/// <summary>
/// ID of the model to use. You can use <see cref="ModelsEndpoint.GetModelsAsync()"/> to see all of your available models, or use a standard model like <see cref="Model.AdaTextEmbedding"/>.
/// </summary>
[JsonIgnore]
public Model Model { get; set; }

/// <summary>
/// The id/name of the model
/// </summary>
[JsonProperty("model")]
public string ModelName => Model.ModelID;

/// <summary>
/// Main text to be embedded
/// </summary>
[JsonProperty("input")]
public string Input { get; set; }

/// <summary>
/// Cretes a new, empty <see cref="EmbeddingRequest"/>
/// </summary>
public EmbeddingRequest()
{

}

/// <summary>
/// Creates a new <see cref="EmbeddingRequest"/> with the specified parameters
/// </summary>
/// <param name="model">The model to use. You can use <see cref="ModelsEndpoint.GetModelsAsync()"/> to see all of your available models, or use a standard model like <see cref="Model.AdaTextEmbedding"/>.</param>
/// <param name="input">The prompt to transform</param>
public EmbeddingRequest(Model model, string input)
{
Model = model;
this.Input = input;
}
}
}
84 changes: 84 additions & 0 deletions OpenAI_API/Embedding/EmbeddingResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace OpenAI_API.Embedding
{
/// <summary>
/// Represents an embedding result returned by the Embedding API.
/// </summary>
public class EmbeddingResult
{
/// <summary>
/// Type of the response. In case of embeddings, this will be "list"
/// </summary>
[JsonProperty("object")]

public string Object { get; set; }

/// <summary>
/// List of results of the embedding
/// </summary>
[JsonProperty("data")]
public Data[] Data { get; set; }

/// <summary>
/// Name of the model used to generate this embedding
/// </summary>
[JsonProperty("model")]
public string Model { get; set; }

/// <summary>
/// Usage statistics of how many tokens have been used for this request
/// </summary>
[JsonProperty("usage")]
public Usage Usage { get; set; }
}

/// <summary>
/// Data returned from the Embedding API.
/// </summary>
public class Data
{
/// <summary>
/// Type of the response. In case of Data, this will be "embedding"
/// </summary>
[JsonProperty("object")]

public string Object { get; set; }

/// <summary>
/// The input text represented as a vector (list) of floating point numbers
/// </summary>
[JsonProperty("embedding")]
public float[] Embedding { get; set; }

/// <summary>
/// Index
/// </summary>
[JsonProperty("index")]
public int Index { get; set; }

}

/// <summary>
/// Usage statistics of how many tokens have been used for this request.
/// </summary>
public class Usage
{
/// <summary>
/// How many tokens did the prompt consist of
/// </summary>
[JsonProperty("prompt_tokens")]
public int PromptTokens { get; set; }

/// <summary>
/// How many tokens did the request consume total
/// </summary>
[JsonProperty("total_tokens")]
public int TotalTokens { get; set; }

}

}
61 changes: 34 additions & 27 deletions OpenAI_API/OpenAIAPI.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Newtonsoft.Json;
using OpenAI_API.Embedding;
using System;
using System.Collections.Generic;
using System.IO;
Expand All @@ -9,32 +10,38 @@

namespace OpenAI_API
{
/// <summary>
/// Entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints
/// </summary>
public class OpenAIAPI
{
/// <summary>
/// The API authentication information to use for API calls
/// </summary>
public APIAuthentication Auth { get; set; }

/// <summary>
/// Creates a new entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints
/// </summary>
/// <param name="apiKeys">The API authentication information to use for API calls, or <see langword="null"/> to attempt to use the <see cref="APIAuthentication.Default"/>, potentially loading from environment vars or from a config file.</param>
public OpenAIAPI(APIAuthentication apiKeys = null)
{
this.Auth = apiKeys.ThisOrDefault();
Completions = new CompletionEndpoint(this);
Models = new ModelsEndpoint(this);
Search = new SearchEndpoint(this);
}

/// <summary>
/// Text generation is the core function of the API. You give the API a prompt, and it generates a completion. The way you “program” the API to do a task is by simply describing the task in plain english or providing a few written examples. This simple approach works for a wide range of use cases, including summarization, translation, grammar correction, question answering, chatbots, composing emails, and much more (see the prompt library for inspiration).
/// </summary>
public CompletionEndpoint Completions { get; }
/// <summary>
/// Entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints
/// </summary>
public class OpenAIAPI
{
/// <summary>
/// The API authentication information to use for API calls
/// </summary>
public APIAuthentication Auth { get; set; }

/// <summary>
/// Creates a new entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints
/// </summary>
/// <param name="apiKeys">The API authentication information to use for API calls, or <see langword="null"/> to attempt to use the <see cref="APIAuthentication.Default"/>, potentially loading from environment vars or from a config file.</param>
public OpenAIAPI(APIAuthentication apiKeys = null)
{
this.Auth = apiKeys.ThisOrDefault();
Completions = new CompletionEndpoint(this);
Models = new ModelsEndpoint(this);
Search = new SearchEndpoint(this);
Embeddings = new EmbeddingEndpoint(this);
}

/// <summary>
/// Text generation is the core function of the API. You give the API a prompt, and it generates a completion. The way you “program” the API to do a task is by simply describing the task in plain english or providing a few written examples. This simple approach works for a wide range of use cases, including summarization, translation, grammar correction, question answering, chatbots, composing emails, and much more (see the prompt library for inspiration).
/// </summary>
public CompletionEndpoint Completions { get; }

/// <summary>
/// The API lets you transform text into a vector (list) of floating point numbers. The distance between two vectors measures their relatedness. Small distances suggest high relatedness and large distances suggest low relatedness.
/// </summary>
public EmbeddingEndpoint Embeddings { get; }

/// <summary>
/// The API endpoint for querying available Engines/models
Expand All @@ -51,5 +58,5 @@ public OpenAIAPI(APIAuthentication apiKeys = null)



}
}
}
48 changes: 48 additions & 0 deletions OpenAI_Tests/EmbeddingEndpointTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using NUnit.Framework;
using OpenAI_API;
using OpenAI_API.Embedding;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace OpenAI_Tests
{
public class EmbeddingEndpointTests
{
[SetUp]
public void Setup()
{
OpenAI_API.APIAuthentication.Default = new OpenAI_API.APIAuthentication(Environment.GetEnvironmentVariable("TEST_OPENAI_SECRET_KEY"));
}

[Test]
public void GetBasicEmbedding()
{
var api = new OpenAI_API.OpenAIAPI();

Assert.IsNotNull(api.Embeddings);

var results = api.Embeddings.CreateEmbeddingAsync(new EmbeddingRequest(Model.AdaTextEmbedding, "A test text for embedding")).Result;
Assert.IsNotNull(results);
Assert.NotNull(results.Object);
Assert.NotZero(results.Data.Length);
Assert.That(results.Data.First().Embedding.Length == 1536);
}


[Test]
public void GetSimpleEmbedding()
{
var api = new OpenAI_API.OpenAIAPI();

Assert.IsNotNull(api.Embeddings);

var results = api.Embeddings.CreateEmbeddingAsync("A test text for embedding").Result;
Assert.IsNotNull(results);
Assert.NotNull(results.Object);
Assert.NotZero(results.Data.Length);
Assert.That(results.Data.First().Embedding.Length == 1536);
}
}
}

0 comments on commit 864263d

Please sign in to comment.