-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support OpenAI in Spring Cloud Azure (#35551)
* add openai
- Loading branch information
Showing
19 changed files
with
927 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
.../azure/spring/cloud/autoconfigure/implementation/openai/AzureOpenAIAutoConfiguration.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.azure.spring.cloud.autoconfigure.implementation.openai; | ||
|
||
import com.azure.ai.openai.OpenAIAsyncClient; | ||
import com.azure.ai.openai.OpenAIClient; | ||
import com.azure.ai.openai.OpenAIClientBuilder; | ||
import com.azure.spring.cloud.autoconfigure.AzureServiceConfigurationBase; | ||
import com.azure.spring.cloud.autoconfigure.condition.ConditionalOnAnyProperty; | ||
import com.azure.spring.cloud.autoconfigure.context.AzureGlobalProperties; | ||
import com.azure.spring.cloud.autoconfigure.implementation.openai.properties.AzureOpenAIProperties; | ||
import com.azure.spring.cloud.core.customizer.AzureServiceClientBuilderCustomizer; | ||
import com.azure.spring.cloud.core.implementation.util.AzureSpringIdentifier; | ||
import com.azure.spring.cloud.service.implementation.openai.OpenAIClientBuilderFactory; | ||
import org.springframework.beans.factory.ObjectProvider; | ||
import org.springframework.boot.autoconfigure.EnableAutoConfiguration; | ||
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; | ||
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; | ||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; | ||
import org.springframework.boot.context.properties.ConfigurationProperties; | ||
import org.springframework.context.annotation.Bean; | ||
|
||
/** | ||
* {@link EnableAutoConfiguration Auto-configuration} for Azure Open AI support. | ||
* | ||
* @since 4.9.0-beta.1 | ||
*/ | ||
@ConditionalOnClass(OpenAIClientBuilder.class) | ||
@ConditionalOnProperty(value = "spring.cloud.azure.openai.enabled", havingValue = "true", matchIfMissing = true) | ||
@ConditionalOnAnyProperty(prefix = "spring.cloud.azure.openai", name = "endpoint") | ||
public class AzureOpenAIAutoConfiguration extends AzureServiceConfigurationBase { | ||
|
||
AzureOpenAIAutoConfiguration(AzureGlobalProperties azureGlobalProperties) { | ||
super(azureGlobalProperties); | ||
} | ||
|
||
@Bean | ||
@ConfigurationProperties(AzureOpenAIProperties.PREFIX) | ||
AzureOpenAIProperties azureOpenAIProperties() { | ||
return loadProperties(getAzureGlobalProperties(), new AzureOpenAIProperties()); | ||
} | ||
|
||
@Bean | ||
@ConditionalOnMissingBean | ||
OpenAIClient openAIClient(OpenAIClientBuilder builder) { | ||
return builder.buildClient(); | ||
} | ||
|
||
@Bean | ||
@ConditionalOnMissingBean | ||
OpenAIAsyncClient openAIAsyncClient(OpenAIClientBuilder builder) { | ||
return builder.buildAsyncClient(); | ||
} | ||
|
||
@Bean | ||
@ConditionalOnMissingBean | ||
OpenAIClientBuilder openAIClientBuilder(OpenAIClientBuilderFactory factory) { | ||
return factory.build(); | ||
} | ||
|
||
@Bean | ||
@ConditionalOnMissingBean | ||
OpenAIClientBuilderFactory openAIClientBuilderFactory(AzureOpenAIProperties properties, | ||
ObjectProvider<AzureServiceClientBuilderCustomizer<OpenAIClientBuilder>> customizers) { | ||
OpenAIClientBuilderFactory factory = new OpenAIClientBuilderFactory(properties); | ||
factory.setSpringIdentifier(AzureSpringIdentifier.AZURE_SPRING_OPENAI); | ||
customizers.orderedStream().forEach(factory::addBuilderCustomizer); | ||
return factory; | ||
} | ||
} |
70 changes: 70 additions & 0 deletions
70
...re/spring/cloud/autoconfigure/implementation/openai/properties/AzureOpenAIProperties.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.azure.spring.cloud.autoconfigure.implementation.openai.properties; | ||
|
||
import com.azure.ai.openai.OpenAIServiceVersion; | ||
import com.azure.spring.cloud.autoconfigure.implementation.properties.core.AbstractAzureHttpConfigurationProperties; | ||
import com.azure.spring.cloud.service.implementation.openai.OpenAIClientProperties; | ||
|
||
/** | ||
* Configuration properties for Azure OpenAI. | ||
*/ | ||
public class AzureOpenAIProperties extends AbstractAzureHttpConfigurationProperties implements OpenAIClientProperties { | ||
|
||
public static final String PREFIX = "spring.cloud.azure.openai"; | ||
|
||
/** | ||
* Endpoint of the Azure OpenAI. For instance, 'https://{azure-openai-name}.openai.azure.com/'. | ||
*/ | ||
private String endpoint; | ||
|
||
/** | ||
* Azure OpenAI service version used when making API requests. | ||
*/ | ||
private OpenAIServiceVersion serviceVersion; | ||
|
||
/** | ||
* The API key to authenticate the non-Azure OpenAI service (https://platform.openai.com/docs/api-reference/authentication). | ||
*/ | ||
private String nonAzureOpenAIKey; | ||
|
||
/** | ||
* Key to authenticate for accessing the Azure OpenAI. | ||
*/ | ||
private String key; | ||
|
||
public String getEndpoint() { | ||
return endpoint; | ||
} | ||
|
||
public void setEndpoint(String endpoint) { | ||
this.endpoint = endpoint; | ||
} | ||
|
||
public OpenAIServiceVersion getServiceVersion() { | ||
return serviceVersion; | ||
} | ||
|
||
public void setServiceVersion(OpenAIServiceVersion serviceVersion) { | ||
this.serviceVersion = serviceVersion; | ||
} | ||
|
||
@Override | ||
public String getKey() { | ||
return key; | ||
} | ||
|
||
public void setKey(String key) { | ||
this.key = key; | ||
} | ||
|
||
@Override | ||
public String getNonAzureOpenAIKey() { | ||
return nonAzureOpenAIKey; | ||
} | ||
|
||
public void setNonAzureOpenAIKey(String nonAzureOpenAIKey) { | ||
this.nonAzureOpenAIKey = nonAzureOpenAIKey; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
152 changes: 152 additions & 0 deletions
152
...e/spring/cloud/autoconfigure/implementation/openai/AzureOpenAIAutoConfigurationTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
package com.azure.spring.cloud.autoconfigure.implementation.openai; | ||
|
||
import com.azure.ai.openai.OpenAIServiceVersion; | ||
import com.azure.ai.openai.OpenAIAsyncClient; | ||
import com.azure.ai.openai.OpenAIClient; | ||
import com.azure.ai.openai.OpenAIClientBuilder; | ||
import com.azure.cosmos.CosmosClientBuilder; | ||
import com.azure.spring.cloud.autoconfigure.AbstractAzureServiceConfigurationTests; | ||
import com.azure.spring.cloud.autoconfigure.TestBuilderCustomizer; | ||
import com.azure.spring.cloud.autoconfigure.context.AzureGlobalProperties; | ||
import com.azure.spring.cloud.autoconfigure.implementation.openai.properties.AzureOpenAIProperties; | ||
import com.azure.spring.cloud.service.implementation.openai.OpenAIClientBuilderFactory; | ||
import org.junit.jupiter.api.Test; | ||
import org.springframework.boot.autoconfigure.AutoConfigurations; | ||
import org.springframework.boot.test.context.FilteredClassLoader; | ||
import org.springframework.boot.test.context.runner.ApplicationContextRunner; | ||
|
||
import static org.assertj.core.api.Assertions.assertThat; | ||
import static org.junit.Assert.assertEquals; | ||
import static org.mockito.Mockito.mock; | ||
|
||
class AzureOpenAIAutoConfigurationTests extends AbstractAzureServiceConfigurationTests< | ||
OpenAIClientBuilderFactory, AzureOpenAIProperties> { | ||
|
||
static final String TEST_ENDPOINT_HTTPS = "https://test.openai.azure.com/"; | ||
private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() | ||
.withConfiguration(AutoConfigurations.of(AzureOpenAIAutoConfiguration.class)); | ||
|
||
@Override | ||
protected ApplicationContextRunner getMinimalContextRunner() { | ||
return this.contextRunner | ||
.withPropertyValues("spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS); | ||
} | ||
|
||
@Override | ||
protected String getPropertyPrefix() { | ||
return AzureOpenAIProperties.PREFIX; | ||
} | ||
|
||
@Override | ||
protected Class<OpenAIClientBuilderFactory> getBuilderFactoryType() { | ||
return OpenAIClientBuilderFactory.class; | ||
} | ||
|
||
@Override | ||
protected Class<AzureOpenAIProperties> getConfigurationPropertiesType() { | ||
return AzureOpenAIProperties.class; | ||
} | ||
|
||
@Test | ||
void configureWithoutOpenAIClientBuilder() { | ||
this.contextRunner | ||
.withPropertyValues("spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS) | ||
.withClassLoader(new FilteredClassLoader(OpenAIClientBuilder.class)) | ||
.run(context -> assertThat(context).doesNotHaveBean(AzureOpenAIAutoConfiguration.class)); | ||
} | ||
|
||
@Test | ||
void configureWithOpenAIDisabled() { | ||
this.contextRunner | ||
.withPropertyValues( | ||
"spring.cloud.azure.openai.enabled=false", | ||
"spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS) | ||
.run(context -> assertThat(context).doesNotHaveBean(AzureOpenAIAutoConfiguration.class)); | ||
} | ||
|
||
@Test | ||
void configureWithoutEndpoint() { | ||
this.contextRunner | ||
.run(context -> assertThat(context).doesNotHaveBean(AzureOpenAIAutoConfiguration.class)); | ||
} | ||
|
||
@Test | ||
void configureWithEndpoint() { | ||
this.contextRunner | ||
.withPropertyValues("spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS) | ||
.withBean(AzureGlobalProperties.class, AzureGlobalProperties::new) | ||
.withBean(OpenAIClientBuilder.class, () -> mock(OpenAIClientBuilder.class)) | ||
.run(context -> { | ||
assertThat(context).hasSingleBean(AzureOpenAIAutoConfiguration.class); | ||
assertThat(context).hasSingleBean(AzureOpenAIProperties.class); | ||
assertThat(context).hasSingleBean(OpenAIClientBuilderFactory.class); | ||
assertThat(context).hasSingleBean(OpenAIClientBuilder.class); | ||
assertThat(context).hasSingleBean(OpenAIClient.class); | ||
assertThat(context).hasSingleBean(OpenAIAsyncClient.class); | ||
}); | ||
} | ||
|
||
@Test | ||
void customizerShouldBeCalled() { | ||
OpenAIBuilderCustomizer customizer = new OpenAIBuilderCustomizer(); | ||
this.contextRunner | ||
.withPropertyValues("spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS) | ||
.withBean(AzureGlobalProperties.class, AzureGlobalProperties::new) | ||
.withBean("customizer1", OpenAIBuilderCustomizer.class, () -> customizer) | ||
.withBean("customizer2", OpenAIBuilderCustomizer.class, () -> customizer) | ||
.run(context -> assertThat(customizer.getCustomizedTimes()).isEqualTo(2) | ||
); | ||
} | ||
|
||
@Test | ||
void otherCustomizerShouldNotBeCalled() { | ||
OpenAIBuilderCustomizer customizer = new OpenAIBuilderCustomizer(); | ||
OtherBuilderCustomizer otherCustomizer = new OtherBuilderCustomizer(); | ||
this.contextRunner | ||
.withPropertyValues("spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS) | ||
.withBean(AzureGlobalProperties.class, AzureGlobalProperties::new) | ||
.withBean("customizer1", OpenAIBuilderCustomizer.class, () -> customizer) | ||
.withBean("customizer2", OpenAIBuilderCustomizer.class, () -> customizer) | ||
.withBean("customizer3", OtherBuilderCustomizer.class, () -> otherCustomizer) | ||
.run(context -> { | ||
assertThat(customizer.getCustomizedTimes()).isEqualTo(2); | ||
assertThat(otherCustomizer.getCustomizedTimes()).isEqualTo(0); | ||
}); | ||
} | ||
|
||
@Test | ||
void configurationPropertiesShouldBind() { | ||
String azureKeyCredential = "azure-key-credential"; | ||
String nonAzureOpenAIKeyCredential = "non-azure-key-credential"; | ||
this.contextRunner | ||
.withPropertyValues( | ||
"spring.cloud.azure.openai.endpoint=" + TEST_ENDPOINT_HTTPS, | ||
"spring.cloud.azure.openai.key=" + azureKeyCredential, | ||
"spring.cloud.azure.openai.non-azure-openai-key=" + nonAzureOpenAIKeyCredential, | ||
"spring.cloud.azure.openai.service-version=v2022_12_01", | ||
"spring.cloud.azure.credential.client-id=openai-client-id" | ||
) | ||
.withBean(AzureGlobalProperties.class, AzureGlobalProperties::new) | ||
.withBean(OpenAIClientBuilder.class, () -> mock(OpenAIClientBuilder.class)) | ||
.run(context -> { | ||
assertThat(context).hasSingleBean(AzureOpenAIProperties.class); | ||
AzureOpenAIProperties properties = context.getBean(AzureOpenAIProperties.class); | ||
assertEquals(TEST_ENDPOINT_HTTPS, properties.getEndpoint()); | ||
assertEquals(azureKeyCredential, properties.getKey()); | ||
assertEquals(nonAzureOpenAIKeyCredential, properties.getNonAzureOpenAIKey()); | ||
assertEquals(OpenAIServiceVersion.V2022_12_01, properties.getServiceVersion()); | ||
assertEquals("openai-client-id", properties.getCredential().getClientId()); | ||
}); | ||
} | ||
|
||
private static class OpenAIBuilderCustomizer extends TestBuilderCustomizer<OpenAIClientBuilder> { | ||
|
||
} | ||
|
||
private static class OtherBuilderCustomizer extends TestBuilderCustomizer<CosmosClientBuilder> { | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.