[AI] [Inference] update readme with scenario samples, add tests (Azure#42008)

* [AI] [Inference] add scenario sample to readme

* add non-streaming tool call sample

* add links to function call samples

* add stubs and links for image chat samples

* add image chat code samples to readme

* make getModelInfo public, add sample

* add readme sample for getModelInfo

* add streaming function call test

* update test recordings

* linter fixes

* add image tests to chat suite

* update test recordings

* add getChoice() (singular) API for ChatCompletions and StreamingChatCompletionsUpdate, update samples and tests accordingly

* review feedback
glharper authored and mssfang committed Oct 21, 2024
1 parent ccd98e0 commit e9ed40f
Showing 21 changed files with 734 additions and 43 deletions.
77 changes: 74 additions & 3 deletions sdk/ai/azure-ai-inference/README.md
@@ -132,7 +132,7 @@ client.completeStream(new ChatCompletionsOptions(chatMessages))
if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
return;
}
- StreamingChatResponseMessageUpdate delta = chatCompletions.getChoices().get(0).getDelta();
+ StreamingChatResponseMessageUpdate delta = chatCompletions.getChoice().getDelta();
if (delta.getRole() != null) {
System.out.println("Role = " + delta.getRole());
}
@@ -145,14 +145,80 @@ client.completeStream(new ChatCompletionsOptions(chatMessages))

To compute tokens in streaming chat completions, see sample [Streaming Chat Completions][sample_get_chat_completions_streaming].
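
The null-or-empty check in the streaming snippet matters because `getChoice()` reads the first element of the choices list. The following is a minimal sketch of collecting streamed deltas into one string. It uses hypothetical stand-in classes (`Delta`, `Update`) rather than the SDK's `StreamingChatResponseMessageUpdate` and `StreamingChatCompletionsUpdate`, so treat it as an illustration of the pattern, not of the SDK's behavior.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class StreamingDeltaSketch {
    // Hypothetical stand-in for a streamed message delta (not an SDK type).
    static final class Delta {
        final String role;
        final String content;

        Delta(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }

    // Hypothetical stand-in for one streaming update (not an SDK type).
    static final class Update {
        private final List<Delta> choices;

        Update(List<Delta> choices) {
            this.choices = choices;
        }

        List<Delta> getChoices() {
            return choices;
        }

        // Same shape as the singular accessor in this commit: first element, no bounds check.
        Delta getChoice() {
            return choices.get(0);
        }
    }

    /** Concatenates delta content, skipping updates that carry no choices. */
    static String collect(List<Update> updates) {
        StringBuilder text = new StringBuilder();
        for (Update update : updates) {
            if (update.getChoices() == null || update.getChoices().isEmpty()) {
                continue; // nothing to read; getChoice() would throw here
            }
            Delta delta = update.getChoice();
            if (delta.content != null) {
                text.append(delta.content);
            }
        }
        return text.toString();
    }

    public static void main(String[] args) {
        List<Update> updates = Arrays.asList(
            new Update(Collections.emptyList()),
            new Update(Collections.singletonList(new Delta("assistant", null))),
            new Update(Collections.singletonList(new Delta(null, "Hello"))),
            new Update(Collections.singletonList(new Delta(null, ", world"))));
        System.out.println(collect(updates)); // prints "Hello, world"
    }
}
```

Skipping empty updates before calling the singular accessor keeps the loop from failing on an update that has no choices.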

<!--
### Chat with image URL

```java readme-sample-chatWithImageUrl
List<ChatMessageContentItem> contentItems = new ArrayList<>();
contentItems.add(new ChatMessageTextContentItem("Describe the image."));
contentItems.add(new ChatMessageImageContentItem(
new ChatMessageImageUrl("<URL>")));

List<ChatRequestMessage> chatMessages = new ArrayList<>();
chatMessages.add(new ChatRequestSystemMessage("You are a helpful assistant."));
chatMessages.add(ChatRequestUserMessage.fromContentItems(contentItems));

ChatCompletions completions = client.complete(new ChatCompletionsOptions(chatMessages));
System.out.printf("%s.%n", completions.getChoice().getMessage().getContent());
```
For a complete sample example, see sample [Image URL][sample_chat_with_image_url].

### Chat with image file

```java readme-sample-chatWithImageFile
Path testFilePath = Paths.get("<path-to-image-file>");
List<ChatMessageContentItem> contentItems = new ArrayList<>();
contentItems.add(new ChatMessageTextContentItem("Describe the image."));
contentItems.add(new ChatMessageImageContentItem(testFilePath, "<image-format>"));

List<ChatRequestMessage> chatMessages = new ArrayList<>();
chatMessages.add(new ChatRequestSystemMessage("You are a helpful assistant."));
chatMessages.add(ChatRequestUserMessage.fromContentItems(contentItems));

ChatCompletions completions = client.complete(new ChatCompletionsOptions(chatMessages));

System.out.printf("%s.%n", completions.getChoice().getMessage().getContent());
```
For a complete sample example, see sample [Image File][sample_chat_with_image_file].

### Text embeddings

```java readme-sample-getEmbedding
EmbeddingsClient client = new EmbeddingsClientBuilder()
.endpoint("{endpoint}")
.credential(new AzureKeyCredential("{key}"))
.buildClient();

List<String> promptList = new ArrayList<>();
String prompt = "Tell me 3 jokes about trains";
promptList.add(prompt);

EmbeddingsResult embeddings = client.embed(promptList);

for (EmbeddingItem item : embeddings.getData()) {
System.out.printf("Index: %d.%n", item.getIndex());
for (Float embedding : item.getEmbeddingList()) {
System.out.printf("%f;", embedding);
}
}
```
For a complete sample example, see sample [Embedding][sample_get_embedding].

-->
### Function calls

For a complete sample example, see sample [Function Calls][sample_function_calls].

### Streaming function calls

For a complete sample example, see sample [Streaming Function Calls][sample_streaming_function_calls].

### Get Model information

```java readme-sample-getModelInfo
ModelInfo modelInfo = client.getModelInfo();

System.out.printf("modelName: %s, modelNameProvider: %s, modelType: %s%n",
modelInfo.getModelName(), modelInfo.getModelProviderName(), modelInfo.getModelType().toString());
```

### Service API versions

@@ -209,6 +275,11 @@ For details on contributing to this repository, see the [contributing guide](htt
[azure_identity]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/identity/azure-identity
[sample_get_chat_completions]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/samples/java/com/azure/ai/inference/usage/BasicChatSample.java
[sample_get_chat_completions_streaming]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/samples/java/com/azure/ai/inference/usage/StreamingChatSample.java
[sample_get_embedding]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/samples/java/com/azure/ai/inference/usage/TextEmbeddingsSample.java
[sample_chat_with_image_url]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/samples/java/com/azure/ai/inference/usage/ImageUrlChatSample.java
[sample_chat_with_image_file]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/samples/java/com/azure/ai/inference/usage/ImageFileChatSample.java
[sample_function_calls]: https://aka.ms/azsdk/azure-ai-inference/java/toolCallSample
[sample_streaming_function_calls]: https://aka.ms/azsdk/azure-ai-inference/java/streamingToolCallSample
[chat_completions_client_async]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/main/java/com/azure/ai/inference/ChatCompletionsAsyncClient.java
[chat_completions_client_builder]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/main/java/com/azure/ai/inference/ChatCompletionsClientBuilder.java
[chat_completions_client_sync]: https://github.com/Azure/azure-sdk-for-java/blob/main/sdk/ai/azure-ai-inference/src/main/java/com/azure/ai/inference/ChatCompletionsClient.java
2 changes: 1 addition & 1 deletion sdk/ai/azure-ai-inference/assets.json
@@ -2,5 +2,5 @@
"AssetsRepo" : "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath" : "java",
"TagPrefix" : "java/ai/azure-ai-inference",
- "Tag" : "java/ai/azure-ai-inference_9ec2ae5fbb"
+ "Tag" : "java/ai/azure-ai-inference_5c740d7f95"
}
@@ -286,9 +286,8 @@ public Mono<ChatCompletions> complete(ChatCompletionsOptions options) {
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return represents some basic information about the AI model on successful completion of {@link Mono}.
*/
- @Generated
@ServiceMethod(returns = ReturnType.SINGLE)
- Mono<ModelInfo> getModelInfo() {
+ public Mono<ModelInfo> getModelInfo() {
// Generated convenience method for getModelInfoWithResponse
RequestOptions requestOptions = new RequestOptions();
return getModelInfoWithResponse(requestOptions).flatMap(FluxUtil::toMono)
@@ -302,9 +302,8 @@ public IterableStream<StreamingChatCompletionsUpdate> completeStream(ChatComplet
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return represents some basic information about the AI model.
*/
- @Generated
@ServiceMethod(returns = ReturnType.SINGLE)
- ModelInfo getModelInfo() {
+ public ModelInfo getModelInfo() {
// Generated convenience method for getModelInfoWithResponse
RequestOptions requestOptions = new RequestOptions();
return getModelInfoWithResponse(requestOptions).getValue().toObject(ModelInfo.class);
@@ -298,9 +298,8 @@ public Mono<EmbeddingsResult> embed(List<String> input) {
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return represents some basic information about the AI model on successful completion of {@link Mono}.
*/
- @Generated
@ServiceMethod(returns = ReturnType.SINGLE)
- Mono<ModelInfo> getModelInfo() {
+ public Mono<ModelInfo> getModelInfo() {
// Generated convenience method for getModelInfoWithResponse
RequestOptions requestOptions = new RequestOptions();
return getModelInfoWithResponse(requestOptions).flatMap(FluxUtil::toMono)
@@ -290,9 +290,8 @@ public Response<EmbeddingsResult> embedWithResponse(List<String> input) {
* @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
* @return represents some basic information about the AI model.
*/
- @Generated
@ServiceMethod(returns = ReturnType.SINGLE)
- ModelInfo getModelInfo() {
+ public ModelInfo getModelInfo() {
// Generated convenience method for getModelInfoWithResponse
RequestOptions requestOptions = new RequestOptions();
return getModelInfoWithResponse(requestOptions).getValue().toObject(ModelInfo.class);
@@ -121,6 +121,15 @@ public CompletionsUsage getUsage() {
return this.usage;
}

/**
* Get the choice property: The chat choice associated with this completion response.
*
* @return the choice value.
*/
public ChatChoice getChoice() {
return this.choices.get(0);
}

/**
* Get the choices property: The collection of completions choices associated with this completions response.
* Generally, `n` choices are generated per provided prompt with a default value of 1.
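
The new accessor returns `this.choices.get(0)` with no bounds check, so it presumably throws when the list is empty. The sketch below contrasts that behavior with a defensive helper. `Completions` is a hypothetical stand-in for the response model and `firstChoice` is an illustrative name; neither is part of the SDK.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

public class FirstChoiceSketch {
    // Hypothetical stand-in for a response model holding a list of choices.
    static final class Completions {
        private final List<String> choices;

        Completions(List<String> choices) {
            this.choices = choices;
        }

        List<String> getChoices() {
            return choices;
        }

        // Same shape as the accessor in the diff: first element, no bounds check.
        String getChoice() {
            return choices.get(0);
        }
    }

    /** Defensive variant: returns an empty Optional instead of throwing. */
    static Optional<String> firstChoice(Completions completions) {
        List<String> choices = completions.getChoices();
        if (choices == null || choices.isEmpty()) {
            return Optional.empty();
        }
        return Optional.of(choices.get(0));
    }

    public static void main(String[] args) {
        Completions several = new Completions(Arrays.asList("first", "second"));
        System.out.println(several.getChoice()); // prints "first"

        Completions none = new Completions(Collections.emptyList());
        System.out.println(firstChoice(none).isPresent()); // prints "false"

        try {
            none.getChoice();
        } catch (IndexOutOfBoundsException e) {
            System.out.println("getChoice() threw on an empty list");
        }
    }
}
```

Callers that cannot rule out an empty response may prefer to check `getChoices()` first, as the streaming README snippet does with `CoreUtils.isNullOrEmpty`.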
@@ -135,6 +135,15 @@ public List<StreamingChatChoiceUpdate> getChoices() {
return this.choices;
}

/**
* Get the choice property: The chat choice associated with this completion response.
*
* @return the choice value.
*/
public StreamingChatChoiceUpdate getChoice() {
return this.choices.get(0);
}

/**
* {@inheritDoc}
*/
@@ -7,11 +7,18 @@
import com.azure.ai.inference.models.ChatChoice;
import com.azure.ai.inference.models.ChatCompletions;
import com.azure.ai.inference.models.ChatCompletionsOptions;
import com.azure.ai.inference.models.ChatMessageContentItem;
import com.azure.ai.inference.models.ChatMessageImageContentItem;
import com.azure.ai.inference.models.ChatMessageImageUrl;
import com.azure.ai.inference.models.ChatMessageTextContentItem;
import com.azure.ai.inference.models.ChatRequestMessage;
import com.azure.ai.inference.models.ChatRequestAssistantMessage;
import com.azure.ai.inference.models.ChatRequestSystemMessage;
import com.azure.ai.inference.models.ChatRequestUserMessage;
import com.azure.ai.inference.models.ChatResponseMessage;
import com.azure.ai.inference.models.EmbeddingItem;
import com.azure.ai.inference.models.EmbeddingsResult;
import com.azure.ai.inference.models.ModelInfo;
import com.azure.ai.inference.models.StreamingChatResponseMessageUpdate;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
@@ -23,6 +30,8 @@
import com.azure.identity.DefaultAzureCredential;
import com.azure.identity.DefaultAzureCredentialBuilder;

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

@@ -90,7 +99,7 @@ public void getChatCompletionsStream() {
if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
return;
}
- StreamingChatResponseMessageUpdate delta = chatCompletions.getChoices().get(0).getDelta();
+ StreamingChatResponseMessageUpdate delta = chatCompletions.getChoice().getDelta();
if (delta.getRole() != null) {
System.out.println("Role = " + delta.getRole());
}
@@ -104,9 +113,68 @@ public void getChatCompletionsStream() {

public void getEmbedding() {
// BEGIN: readme-sample-getEmbedding
EmbeddingsClient client = new EmbeddingsClientBuilder()
.endpoint("{endpoint}")
.credential(new AzureKeyCredential("{key}"))
.buildClient();

List<String> promptList = new ArrayList<>();
String prompt = "Tell me 3 jokes about trains";
promptList.add(prompt);

EmbeddingsResult embeddings = client.embed(promptList);

for (EmbeddingItem item : embeddings.getData()) {
System.out.printf("Index: %d.%n", item.getIndex());
for (Float embedding : item.getEmbeddingList()) {
System.out.printf("%f;", embedding);
}
}
// END: readme-sample-getEmbedding
}

public void getModelInfo() {
// BEGIN: readme-sample-getModelInfo
ModelInfo modelInfo = client.getModelInfo();

System.out.printf("modelName: %s, modelNameProvider: %s, modelType: %s%n",
modelInfo.getModelName(), modelInfo.getModelProviderName(), modelInfo.getModelType().toString());
// END: readme-sample-getModelInfo
}

public void chatWithImageFile() {
// BEGIN: readme-sample-chatWithImageFile
Path testFilePath = Paths.get("<path-to-image-file>");
List<ChatMessageContentItem> contentItems = new ArrayList<>();
contentItems.add(new ChatMessageTextContentItem("Describe the image."));
contentItems.add(new ChatMessageImageContentItem(testFilePath, "<image-format>"));

List<ChatRequestMessage> chatMessages = new ArrayList<>();
chatMessages.add(new ChatRequestSystemMessage("You are a helpful assistant."));
chatMessages.add(ChatRequestUserMessage.fromContentItems(contentItems));

ChatCompletions completions = client.complete(new ChatCompletionsOptions(chatMessages));

System.out.printf("%s.%n", completions.getChoice().getMessage().getContent());
// END: readme-sample-chatWithImageFile
}

public void chatWithImageUrl() {
// BEGIN: readme-sample-chatWithImageUrl
List<ChatMessageContentItem> contentItems = new ArrayList<>();
contentItems.add(new ChatMessageTextContentItem("Describe the image."));
contentItems.add(new ChatMessageImageContentItem(
new ChatMessageImageUrl("<URL>")));

List<ChatRequestMessage> chatMessages = new ArrayList<>();
chatMessages.add(new ChatRequestSystemMessage("You are a helpful assistant."));
chatMessages.add(ChatRequestUserMessage.fromContentItems(contentItems));

ChatCompletions completions = client.complete(new ChatCompletionsOptions(chatMessages));
System.out.printf("%s.%n", completions.getChoice().getMessage().getContent());
// END: readme-sample-chatWithImageUrl
}

public void enableHttpLogging() {
// BEGIN: readme-sample-enablehttplogging
ChatCompletionsClient chatCompletionsClient = new ChatCompletionsClientBuilder()
@@ -32,8 +32,7 @@ public static void main(String[] args) {

ChatCompletions completions = client.complete(prompt);

- for (ChatChoice choice : completions.getChoices()) {
- System.out.printf("%s.%n", choice.getMessage().getContent());
- }
+ ChatChoice choice = completions.getChoice();
+ System.out.printf("%s.%n", choice.getMessage().getContent());
}
}
@@ -6,7 +6,6 @@

import com.azure.ai.inference.ChatCompletionsClient;
import com.azure.ai.inference.ChatCompletionsClientBuilder;
- import com.azure.ai.inference.models.ChatChoice;
import com.azure.ai.inference.models.ChatCompletions;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.util.Configuration;
@@ -27,8 +26,6 @@ public static void main(String[] args) {

ChatCompletions completions = client.complete(prompt);

- for (ChatChoice choice : completions.getChoices()) {
- System.out.printf("%s.%n", choice.getMessage().getContent());
- }
+ System.out.printf("%s.%n", completions.getChoice().getMessage().getContent());
}
}
@@ -6,7 +6,6 @@

import com.azure.ai.inference.ChatCompletionsAsyncClient;
import com.azure.ai.inference.ChatCompletionsClientBuilder;
- import com.azure.ai.inference.models.ChatChoice;
import com.azure.ai.inference.models.ChatResponseMessage;
import com.azure.ai.inference.models.CompletionsUsage;
import com.azure.core.credential.AzureKeyCredential;
@@ -28,12 +27,10 @@ public static void main(String[] args) throws InterruptedException {
client.complete("Tell me about Euler's Identity").subscribe(
chatCompletions -> {
System.out.printf("Model ID=%s.%n", chatCompletions.getId());
- for (ChatChoice choice : chatCompletions.getChoices()) {
- ChatResponseMessage message = choice.getMessage();
- System.out.printf("Index: %d, Chat Role: %s.%n", choice.getIndex(), message.getRole());
- System.out.println("Message:");
- System.out.println(message.getContent());
- }
+ ChatResponseMessage message = chatCompletions.getChoice().getMessage();
+ System.out.printf("Chat Role: %s.%n", message.getRole());
+ System.out.println("Message:");
+ System.out.println(message.getContent());

System.out.println();
CompletionsUsage usage = chatCompletions.getUsage();