Skip to content

Commit

Permalink
deploymentId moved to be part of GetCompletions signature (Azure#…
Browse files Browse the repository at this point in the history
…33834)

* Removed custom constructors and reinstated deploymentId as part of GetCompletions's signature

* Small corrections to the documentation

* CodeCheck, Export-API, Update-Snippets run
  • Loading branch information
jpalvarezl authored and richardcho-msft committed Feb 6, 2023
1 parent 1e8866c commit e264784
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 58 deletions.
8 changes: 4 additions & 4 deletions sdk/openai/Azure.AI.OpenAI/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ dotnet add package Azure.Identity

```C# Snippet:CreateOpenAIClientTokenCredential
string endpoint = "https://myaccount.openai.azure.com/";
OpenAIClient client = new OpenAIClient(new Uri(endpoint), "myDeploymentId", new DefaultAzureCredential());
OpenAIClient client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential());
```

## Key concepts
Expand Down Expand Up @@ -102,12 +102,12 @@ The `GenerateChatbotResponse` method authenticates using a DefaultAzureCredentia

```C# Snippet:GenerateChatbotResponse
string endpoint = "https://myaccount.openai.azure.com/";
OpenAIClient client = new OpenAIClient(new Uri(endpoint), "myDeploymentId", new DefaultAzureCredential());
OpenAIClient client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential());

string prompt = "What is Azure OpenAI?";
Console.Write($"Input: {prompt}");

Response<Completions> completionsResponse = client.GetCompletions(prompt);
Response<Completions> completionsResponse = client.GetCompletions("myDeploymentId", prompt);
string completion = completionsResponse.Value.Choices[0].Text;
Console.WriteLine($"Chatbot: {completion}");
```
Expand Down Expand Up @@ -148,7 +148,7 @@ The `SummarizeText` method generates a summarization of the given input prompt.

```C# Snippet:SummarizeText
string endpoint = "https://myaccount.openai.azure.com/";
OpenAIClient client = new OpenAIClient(new Uri(endpoint), "myDeploymentId", new DefaultAzureCredential());
OpenAIClient client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential());

string textToSummarize = @"
Two independent experiments reported their results this morning at CERN, Europe's high-energy physics laboratory near Geneva in Switzerland. Both show convincing evidence of a new boson particle weighing around 125 gigaelectronvolts, which so far fits predictions of the Higgs previously made by theoretical physicists.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ public OpenAIClient(System.Uri endpoint, Azure.AzureKeyCredential credential) {
public OpenAIClient(System.Uri endpoint, Azure.AzureKeyCredential credential, Azure.AI.OpenAI.OpenAIClientOptions options) { }
public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential credential) { }
public OpenAIClient(System.Uri endpoint, Azure.Core.TokenCredential credential, Azure.AI.OpenAI.OpenAIClientOptions options) { }
public OpenAIClient(System.Uri endpoint, string deploymentId, Azure.Core.TokenCredential credential) { }
public OpenAIClient(System.Uri endpoint, string deploymentId, Azure.Core.TokenCredential credential, Azure.AI.OpenAI.OpenAIClientOptions options) { }
public virtual Azure.Core.Pipeline.HttpPipeline Pipeline { get { throw null; } }
public virtual Azure.Response<Azure.AI.OpenAI.Models.Completions> GetCompletions(string deploymentId, Azure.AI.OpenAI.Models.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetCompletions(string deploymentId, Azure.Core.RequestContent content, Azure.RequestContext context = null) { throw null; }
public virtual Azure.Response<Azure.AI.OpenAI.Models.Completions> GetCompletions(string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response<Azure.AI.OpenAI.Models.Completions> GetCompletions(string deploymentId, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.Models.Completions>> GetCompletionsAsync(string deploymentId, Azure.AI.OpenAI.Models.CompletionsOptions completionsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response> GetCompletionsAsync(string deploymentId, Azure.Core.RequestContent content, Azure.RequestContext context = null) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.Models.Completions>> GetCompletionsAsync(string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.Models.Completions>> GetCompletionsAsync(string deploymentId, string prompt, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response<Azure.AI.OpenAI.Models.Embeddings> GetEmbeddings(string deploymentId, Azure.AI.OpenAI.Models.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public virtual Azure.Response GetEmbeddings(string deploymentId, Azure.Core.RequestContent content, Azure.RequestContext context = null) { throw null; }
public virtual System.Threading.Tasks.Task<Azure.Response<Azure.AI.OpenAI.Models.Embeddings>> GetEmbeddingsAsync(string deploymentId, Azure.AI.OpenAI.Models.EmbeddingsOptions embeddingsOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
Expand Down
53 changes: 8 additions & 45 deletions sdk/openai/Azure.AI.OpenAI/src/Custom/OpenAIClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,69 +16,32 @@ namespace Azure.AI.OpenAI
/// <summary> Azure OpenAI APIs for completions and search. </summary>
public partial class OpenAIClient
{
private readonly string _completionsDeploymentId;
private readonly string _embeddingsDeploymentId;

/// <summary> Initializes a new instance of OpenAIClient. </summary>
/// <param name="endpoint">
/// Supported Cognitive Services endpoints (protocol and hostname, for example:
/// https://westus.api.cognitive.microsoft.com).
/// </param>
/// <param name="deploymentId"> default deployment id to use for operations </param>
/// <param name="credential"> A credential used to authenticate to an Azure Service. </param>
/// <exception cref="ArgumentNullException"> <paramref name="endpoint"/> or <paramref name="credential"/> is null. </exception>
public OpenAIClient(Uri endpoint, string deploymentId, TokenCredential credential) : this(endpoint, deploymentId, credential, new OpenAIClientOptions())
{
}

/// <summary> Initializes a new instance of OpenAIClient. </summary>
/// <param name="endpoint">
/// Supported Cognitive Services endpoints (protocol and hostname, for example:
/// https://westus.api.cognitive.microsoft.com).
/// </param>
/// <param name="deploymentId"> default deployment id to use for operations </param>
/// <param name="credential"> A credential used to authenticate to an Azure Service. </param>
/// <param name="options"> The options for configuring the client. </param>
/// <exception cref="ArgumentNullException"> <paramref name="endpoint"/> or <paramref name="credential"/> is null. </exception>
public OpenAIClient(Uri endpoint, string deploymentId, TokenCredential credential, OpenAIClientOptions options)
{
Argument.AssertNotNull(endpoint, nameof(endpoint));
Argument.AssertNotNull(credential, nameof(credential));
options ??= new OpenAIClientOptions();

ClientDiagnostics = new ClientDiagnostics(options, true);
_tokenCredential = credential;
_pipeline = HttpPipelineBuilder.Build(options, Array.Empty<HttpPipelinePolicy>(), new HttpPipelinePolicy[] { new BearerTokenAuthenticationPolicy(_tokenCredential, AuthorizationScopes) }, new ResponseClassifier());
_endpoint = endpoint;
_apiVersion = options.Version;
_completionsDeploymentId ??= deploymentId;
_embeddingsDeploymentId ??= deploymentId;
}

/// <summary> Return the completion for a given prompt. </summary>
/// <param name="deploymentId"> Deployment id (also known as model name) to use for operations </param>
/// <param name="prompt"> Input string prompt to create a prompt completion from a deployment. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
public virtual async Task<Response<Completions>> GetCompletionsAsync(string prompt, CancellationToken cancellationToken = default)
public virtual async Task<Response<Completions>> GetCompletionsAsync(string deploymentId, string prompt, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(_completionsDeploymentId, nameof(_completionsDeploymentId));
Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
Argument.AssertNotNullOrEmpty(prompt, nameof(prompt));

CompletionsOptions completionsOptions = new CompletionsOptions();
completionsOptions.Prompt.Add(prompt);
return await GetCompletionsAsync(_completionsDeploymentId, completionsOptions, cancellationToken).ConfigureAwait(false);
return await GetCompletionsAsync(deploymentId, completionsOptions, cancellationToken).ConfigureAwait(false);
}

/// <summary> Return the completions for a given prompt. </summary>
/// <param name="deploymentId"> Deployment id (also known as model name) to use for operations </param>
/// <param name="prompt"> Input string prompt to create a prompt completion from a deployment. </param>
/// <param name="cancellationToken"> The cancellation token to use. </param>
public virtual Response<Completions> GetCompletions(string prompt, CancellationToken cancellationToken = default)
public virtual Response<Completions> GetCompletions(string deploymentId, string prompt, CancellationToken cancellationToken = default)
{
Argument.AssertNotNullOrEmpty(_completionsDeploymentId, nameof(_completionsDeploymentId));
Argument.AssertNotNullOrEmpty(deploymentId, nameof(deploymentId));
Argument.AssertNotNullOrEmpty(prompt, nameof(prompt));

CompletionsOptions completionsOptions = new CompletionsOptions();
completionsOptions.Prompt.Add(prompt);
return GetCompletions(_completionsDeploymentId, completionsOptions, cancellationToken);
return GetCompletions(deploymentId, completionsOptions, cancellationToken);
}
}
}
2 changes: 1 addition & 1 deletion sdk/openai/Azure.AI.OpenAI/tests/OpenAIInferenceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public async Task CompletionTest()
public async Task SimpleCompletionTest()
{
var client = GetClientWithCompletionsDeploymentId();
var response = await client.GetCompletionsAsync("Hello World!");
var response = await client.GetCompletionsAsync(DeploymentId, "Hello World!");
Assert.That(response, Is.InstanceOf<Response<Completions>>());
}

Expand Down
1 change: 0 additions & 1 deletion sdk/openai/Azure.AI.OpenAI/tests/OpenAITestBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ protected OpenAIClient GetClientWithCredential() => InstrumentClient(
protected OpenAIClient GetClientWithCompletionsDeploymentId() => InstrumentClient(
new OpenAIClient(
new Uri(_endpoint),
DeploymentId,
TestEnvironment.Credential,
InstrumentClientOptions(new OpenAIClientOptions(OpenAIClientOptions.ServiceVersion.V2022_12_01))));
}
Expand Down
4 changes: 2 additions & 2 deletions sdk/openai/Azure.AI.OpenAI/tests/Samples/Sample01_Chatbot.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ public void GetChatbotResponse()
#region Snippet:GenerateChatbotResponse
#region Snippet:CreateOpenAIClientTokenCredential
string endpoint = "https://myaccount.openai.azure.com/";
OpenAIClient client = new OpenAIClient(new Uri(endpoint), "myDeploymentId", new DefaultAzureCredential());
OpenAIClient client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential());
#endregion

string prompt = "What is Azure OpenAI?";
Console.Write($"Input: {prompt}");

Response<Completions> completionsResponse = client.GetCompletions(prompt);
Response<Completions> completionsResponse = client.GetCompletions("myDeploymentId", prompt);
string completion = completionsResponse.Value.Choices[0].Text;
Console.WriteLine($"Chatbot: {completion}");
#endregion
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public void SummarizeText()
{
#region Snippet:SummarizeText
string endpoint = "https://myaccount.openai.azure.com/";
OpenAIClient client = new OpenAIClient(new Uri(endpoint), "myDeploymentId", new DefaultAzureCredential());
OpenAIClient client = new OpenAIClient(new Uri(endpoint), new DefaultAzureCredential());

string textToSummarize = @"
Two independent experiments reported their results this morning at CERN, Europe's high-energy physics laboratory near Geneva in Switzerland. Both show convincing evidence of a new boson particle weighing around 125 gigaelectronvolts, which so far fits predictions of the Higgs previously made by theoretical physicists.
Expand Down

0 comments on commit e264784

Please sign in to comment.