diff --git a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java index 6d96d798..244f90b4 100644 --- a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java +++ b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/AutoConfig.java @@ -46,12 +46,14 @@ AzureOpenAiChatModel openAiChatModel(Properties properties) { .presencePenalty(chatModelProperties.presencePenalty()) .frequencyPenalty(chatModelProperties.frequencyPenalty()) .seed(chatModelProperties.seed()) + .strictJsonSchema(chatModelProperties.strictJsonSchema()) .timeout(Duration.ofSeconds(chatModelProperties.timeout() == null ? 0 : chatModelProperties.timeout())) .maxRetries(chatModelProperties.maxRetries()) .proxyOptions(ProxyOptions.fromConfiguration(Configuration.getGlobalConfiguration())) .logRequestsAndResponses(chatModelProperties.logRequestsAndResponses() != null && chatModelProperties.logRequestsAndResponses()) .userAgentSuffix(chatModelProperties.userAgentSuffix()) - .customHeaders(chatModelProperties.customHeaders()); + .customHeaders(chatModelProperties.customHeaders()) + .supportedCapabilities(chatModelProperties.supportedCapabilities()); if (chatModelProperties.nonAzureApiKey() != null) { builder.nonAzureApiKey(chatModelProperties.nonAzureApiKey()); } diff --git a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java index 5835406e..2c957428 100644 --- a/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java +++ b/langchain4j-azure-open-ai-spring-boot-starter/src/main/java/dev/langchain4j/azure/openai/spring/ChatModelProperties.java @@ -1,7 +1,10 @@ package dev.langchain4j.azure.openai.spring; +import dev.langchain4j.model.chat.Capability; + import java.util.List; import java.util.Map; +import java.util.Set; record ChatModelProperties( @@ -18,12 +21,13 @@ record ChatModelProperties( Double presencePenalty, Double frequencyPenalty, Long seed, - String responseFormat, + Boolean strictJsonSchema, Integer timeout, // TODO use Duration instead Integer maxRetries, Boolean logRequestsAndResponses, String userAgentSuffix, Map customHeaders, - String nonAzureApiKey + String nonAzureApiKey, + Set supportedCapabilities ) { } \ No newline at end of file diff --git a/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java b/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java index bd06a23c..5547a8fd 100644 --- a/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java +++ b/langchain4j-azure-open-ai-spring-boot-starter/src/test/java/dev/langchain4j/azure/openai/spring/AutoConfigIT.java @@ -8,6 +8,12 @@ import dev.langchain4j.model.azure.AzureOpenAiStreamingChatModel; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.chat.request.ChatRequest; +import dev.langchain4j.model.chat.request.ResponseFormat; +import dev.langchain4j.model.chat.request.json.JsonArraySchema; +import dev.langchain4j.model.chat.request.json.JsonObjectSchema; +import dev.langchain4j.model.chat.request.json.JsonSchema; +import dev.langchain4j.model.chat.request.json.JsonStringSchema; import dev.langchain4j.model.embedding.EmbeddingModel; import dev.langchain4j.model.image.ImageModel; import dev.langchain4j.model.output.Response; @@ -17,8 +23,12 @@ import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import java.util.List; import java.util.concurrent.CompletableFuture; +import static dev.langchain4j.data.message.UserMessage.userMessage; +import static dev.langchain4j.model.chat.request.ResponseFormatType.JSON; +import static java.util.Collections.singletonList; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; @@ -53,6 +63,52 @@ void should_provide_chat_model(String deploymentName) { }); } + class Person { + + String name; + List favouriteColors; + } + + @ParameterizedTest(name = "Deployment name: {0}") + @CsvSource({ + "gpt-4o-mini" + }) + void should_provide_chat_model_with_json_schema(String deploymentName) { + contextRunner + .withPropertyValues( + "langchain4j.azure-open-ai.chat-model.api-key=" + AZURE_OPENAI_KEY, + "langchain4j.azure-open-ai.chat-model.endpoint=" + AZURE_OPENAI_ENDPOINT, + "langchain4j.azure-open-ai.chat-model.deployment-name=" + deploymentName, + "langchain4j.azure-open-ai.chat-model.strict-json-schema=true" + ) + .run(context -> { + + ChatLanguageModel chatLanguageModel = context.getBean(ChatLanguageModel.class); + + ChatRequest chatRequest = ChatRequest.builder() + .messages(singletonList(userMessage("Julien likes blue, white and red"))) + .responseFormat(ResponseFormat.builder() + .type(JSON) + .jsonSchema(JsonSchema.builder() + .name("Person") + .rootElement(JsonObjectSchema.builder() + .addStringProperty("name") + .addProperty("favouriteColors", JsonArraySchema.builder() + .items(new JsonStringSchema()) + .build()) + .required("name", "favouriteColors") + .build()) + .build()) + .build()) + .build(); + + assertThat(chatLanguageModel).isInstanceOf(AzureOpenAiChatModel.class); + AiMessage aiMessage = chatLanguageModel.chat(chatRequest).aiMessage(); + assertThat(aiMessage.text()).contains("{\"name\":\"Julien\",\"favouriteColors\":[\"blue\",\"white\",\"red\"]}"); + assertThat(context.getBean(AzureOpenAiChatModel.class)).isSameAs(chatLanguageModel); + }); + } + @ParameterizedTest(name = "Deployment name: {0}") @CsvSource({ "gpt-3.5-turbo"