diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java index 107c132fe1..9f4c3836cb 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java @@ -24,6 +24,9 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { + private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); + private final String COHERE_KEY = System.getenv("COHERE_KEY"); + private final String completionModelConnectorEntity = "{\n" + "\"name\": \"OpenAI Connector\",\n" + "\"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" @@ -39,7 +42,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase { + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") + + OPENAI_KEY + "\"\n" + " },\n" + " \"actions\": [\n" @@ -133,6 +136,10 @@ public void testDeployRemoteModel() throws IOException, InterruptedException { @Ignore public void testPredictRemoteModel() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } Response response = createConnector(completionModelConnectorEntity); Map responseMap = parseResponseToMap(response); String connectorId = (String) responseMap.get("connector_id"); @@ -186,6 +193,10 @@ public void testUndeployRemoteModel() throws IOException, InterruptedException { @Ignore public void testOpenAIChatCompletionModel() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } String entity = "{\n" + " \"name\": \"OpenAI chat model Connector\",\n" + " \"description\": \"The connector to public OpenAI model service for GPT 3.5\",\n" @@ -201,7 +212,7 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") + + OPENAI_KEY + "\"\n" + " },\n" + " \"actions\": [\n" @@ -243,6 +254,10 @@ public void testOpenAIChatCompletionModel() throws IOException, InterruptedExcep @Ignore public void testOpenAIEditsModel() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } String entity = "{\n" + " \"name\": \"OpenAI Edit model Connector\",\n" + " \"description\": \"The connector to public OpenAI edit model service\",\n" @@ -256,7 +271,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") + + OPENAI_KEY + "\"\n" + " },\n" + " \"actions\": [\n" @@ -309,6 +324,10 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException { @Ignore public void testOpenAIModerationsModel() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } String entity = "{\n" + " \"name\": \"OpenAI moderations model Connector\",\n" + " \"description\": \"The connector to public OpenAI moderations model service\",\n" @@ -322,7 +341,7 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") + + OPENAI_KEY + "\"\n" + " },\n" + " \"actions\": [\n" @@ -372,6 +391,10 @@ public void testOpenAIModerationsModel() throws IOException, InterruptedExceptio @Ignore public void testOpenAITextEmbeddingModel() throws IOException, InterruptedException { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } String entity = "{\n" + " \"name\": \"OpenAI text embedding model Connector\",\n" + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" @@ -385,7 +408,7 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept + " },\n" + " \"credential\": {\n" + " \"openAI_key\": \"" - + System.getenv("OPENAI_KEY") + + OPENAI_KEY + "\"\n" + " },\n" + " \"actions\": [\n" @@ -430,6 +453,10 @@ public void testOpenAITextEmbeddingModel() throws IOException, InterruptedExcept } public void testCohereGenerateTextModel() throws IOException, InterruptedException { + // Skip test if key is null + if (COHERE_KEY == null) { + return; + } String entity = "{\n" + " \"name\": \"Cohere generate text model Connector\",\n" + " \"description\": \"The connector to public Cohere generate text model service\",\n" @@ -443,7 +470,7 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti + " },\n" + " \"credential\": {\n" + " \"cohere_key\": \"" - + System.getenv("COHERE_KEY") + + COHERE_KEY + "\"\n" + " },\n" + " \"actions\": [\n" @@ -491,6 +518,10 @@ public void testCohereGenerateTextModel() throws IOException, InterruptedExcepti } public void testCohereClassifyModel() throws IOException, InterruptedException { + // Skip test if key is null + if (COHERE_KEY == null) { + return; + } String entity = "{\n" + " \"name\": \"Cohere classify model Connector\",\n" + " \"description\": \"The connector to public Cohere classify model service\",\n" @@ -504,7 +535,7 @@ public void testCohereClassifyModel() throws IOException, InterruptedException { + " },\n" + " \"credential\": {\n" + " \"cohere_key\": \"" - + System.getenv("COHERE_KEY") + + COHERE_KEY + "\"\n" + " },\n" + " \"actions\": [\n"