From be67209c13bf69dbe2afe05a62702f03b2397ae3 Mon Sep 17 00:00:00 2001
From: Joshua Palis
Date: Fri, 22 Sep 2023 00:50:32 +0000
Subject: [PATCH] Adds unit tests for create ingest pipeline step, fixes
 pipeline request body generator

Signed-off-by: Joshua Palis
---
 .../workflow/CreateIngestPipelineStep.java    |  16 ++-
 .../CreateIngestPipelineStepTests.java        | 111 ++++++++++++++++++
 2 files changed, 121 insertions(+), 6 deletions(-)
 create mode 100644 src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java

diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java
index a184282bc..7d7f02bbd 100644
--- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java
+++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java
@@ -79,7 +79,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
 
             Map<String, String> parameters = workflowData.getParams();
             Map<String, Object> content = workflowData.getContent();
-            logger.debug("Previous step sent params: {}, content: {}", parameters, content);
+            logger.info("Previous step sent params: {}, content: {}", parameters, content);
 
             for (Entry<String, Object> entry : content.entrySet()) {
                 switch (entry.getKey()) {
@@ -105,13 +105,13 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
             }
 
             // Determmine if fields have been populated, else iterate over remaining workflow data
-            if (Stream.of(pipelineId, description, modelId, inputFieldName, outputFieldName).allMatch(x -> x != null)) {
+            if (Stream.of(pipelineId, description, modelId, type, inputFieldName, outputFieldName).allMatch(x -> x != null)) {
                 try {
                     configuration = BytesReference.bytes(
-                        buildIngestPipelineRequestContent(description, modelId, inputFieldName, outputFieldName)
+                        buildIngestPipelineRequestContent(description, modelId, type, inputFieldName, outputFieldName)
                     );
                 } catch (IOException e) {
-                    logger.error("Failed to create ingest pipeline : " + e.getMessage());
+                    logger.error("Failed to create ingest pipeline configuration: " + e.getMessage());
                     createIngestPipelineFuture.completeExceptionally(e);
                 }
                 break;
@@ -157,7 +157,7 @@ public String getName() {
      *  "description" : "<description>",
      *  "processors" : [
      *      {
-     *          "text_embedding" : {
+     *          "<type>" : {
      *              "model_id" : "<model_id>",
      *              "field_map" : {
      *                  "<input_field_name>" : "<output_field_name>"
@@ -168,6 +168,7 @@ public String getName() {
      *
      * @param description The description of the ingest pipeline configuration
      * @param modelId The ID of the model that will be used in the embedding interface
+     * @param type The processor type
      * @param inputFieldName The field name used to cache text for text embeddings
      * @param outputFieldName The field name in which output text is stored
      * @throws IOException if the request content fails to be generated
@@ -176,6 +177,7 @@ public String getName() {
     private XContentBuilder buildIngestPipelineRequestContent(
         String description,
         String modelId,
+        String type,
         String inputFieldName,
         String outputFieldName
     ) throws IOException {
@@ -183,12 +185,14 @@ private XContentBuilder buildIngestPipelineRequestContent(
             .startObject()
             .field(DESCRIPTION_FIELD, description)
             .startArray(PROCESSORS_FIELD)
-            .startObject(TYPE_FIELD)
+            .startObject()
+            .startObject(type)
             .field(MODEL_ID_FIELD, modelId)
             .startObject(FIELD_MAP)
             .field(inputFieldName, outputFieldName)
             .endObject()
             .endObject()
+            .endObject()
             .endArray()
             .endObject();
     }
diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java
new file mode 100644
index 000000000..1288a0b8f
--- /dev/null
+++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipeline/CreateIngestPipelineStepTests.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+package org.opensearch.flowframework.workflow.CreateIngestPipeline;
+
+import org.opensearch.action.ingest.PutPipelineRequest;
+import org.opensearch.action.support.master.AcknowledgedResponse;
+import org.opensearch.client.AdminClient;
+import org.opensearch.client.Client;
+import org.opensearch.client.ClusterAdminClient;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.flowframework.workflow.CreateIngestPipelineStep;
+import org.opensearch.flowframework.workflow.WorkflowData;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import org.mockito.ArgumentCaptor;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class CreateIngestPipelineStepTests extends OpenSearchTestCase {
+
+    private WorkflowData inputData;
+    private WorkflowData outputData;
+    private Client client;
+    private AdminClient adminClient;
+    private ClusterAdminClient clusterAdminClient;
+
+    @Override
+    public void setUp() throws Exception {
+        super.setUp();
+
+        inputData = new WorkflowData() {
+
+            @Override
+            public Map<String, Object> getContent() {
+                return Map.of(
+                    "id",
+                    "pipelineId",
+                    "description",
+                    "some description",
+                    "type",
+                    "text_embedding",
+                    "model_id",
+                    "model_id",
+                    "input_field_name",
+                    "inputField",
+                    "output_field_name",
+                    "outputField"
+                );
+            }
+
+            @Override
+            public Map<String, String> getParams() {
+                return Map.of();
+            }
+        };
+
+        // Set output data to returned pipelineId
+        outputData = new WorkflowData() {
+
+            @Override
+            public Map<String, Object> getContent() {
+                return Map.of("pipelineId", "pipelineId");
+            }
+
+            @Override
+            public Map<String, String> getParams() {
+                return Map.of();
+            }
+        };
+
+        client = mock(Client.class);
+        adminClient = mock(AdminClient.class);
+        clusterAdminClient = mock(ClusterAdminClient.class);
+
+        when(client.admin()).thenReturn(adminClient);
+        when(adminClient.cluster()).thenReturn(clusterAdminClient);
+    }
+
+    public void testCreateIngestPipelineStep() throws InterruptedException, ExecutionException {
+
+        CreateIngestPipelineStep createIngestPipelineStep = new CreateIngestPipelineStep(client);
+
+        ArgumentCaptor<ActionListener<AcknowledgedResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
+        CompletableFuture<WorkflowData> future = createIngestPipelineStep.execute(List.of(inputData));
+
+        assertFalse(future.isDone());
+
+        // Mock put pipeline request execution and return true
+        verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture());
+        actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true));
+
+        assertTrue(future.isDone());
+        assertEquals(outputData.getContent(), future.get().getContent());
+    }
+
+}
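
Note (not part of the patch): a minimal, self-contained sketch of the request body produced by the corrected builder nesting above. The fix wraps each processor in an anonymous object inside the "processors" array and keys the inner object by the runtime type value instead of the TYPE_FIELD constant name. Package locations for XContentFactory, XContentBuilder, and BytesReference are assumed for the OpenSearch 2.x line, and the class name and inlined string literals stand in for the DESCRIPTION_FIELD, PROCESSORS_FIELD, MODEL_ID_FIELD, and FIELD_MAP constants for illustration only.

import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;

public class IngestPipelineBodyExample {
    public static void main(String[] args) throws Exception {
        // Mirrors the corrected nesting: an unnamed object inside the
        // "processors" array, then an inner object keyed by the processor type.
        XContentBuilder builder = XContentFactory.jsonBuilder()
            .startObject()
            .field("description", "some description")
            .startArray("processors")
            .startObject()
            .startObject("text_embedding")
            .field("model_id", "model_id")
            .startObject("field_map")
            .field("inputField", "outputField")
            .endObject()
            .endObject()
            .endObject()
            .endArray()
            .endObject();

        // Prints:
        // {"description":"some description","processors":[{"text_embedding":{"model_id":"model_id","field_map":{"inputField":"outputField"}}}]}
        System.out.println(BytesReference.bytes(builder).utf8ToString());
    }
}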