diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a2c0552e..3920ca210 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Bug Fixes - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) - Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639)) +- Hide user and credential field from search response ([#680](https://github.com/opensearch-project/flow-framework/pull/680)) ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java index fa552c2f3..9fbc0f6d5 100644 --- a/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/AbstractSearchWorkflowAction.java @@ -23,19 +23,14 @@ import org.opensearch.rest.RestRequest; import org.opensearch.rest.RestResponse; import org.opensearch.rest.action.RestResponseListener; -import org.opensearch.script.Script; -import org.opensearch.script.ScriptType; import org.opensearch.search.builder.SearchSourceBuilder; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED; -import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; /** * Abstract class to handle search request. @@ -89,23 +84,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); - searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout()); - // Apply credential filter when searching templates - if (index.equals(GLOBAL_CONTEXT_INDEX)) { - searchSourceBuilder.scriptField( - "filter", - new Script( - ScriptType.INLINE, - "painless", - "def filteredSource = new HashMap(params._source); def workflows = filteredSource.get(\"workflows\"); if (workflows != null) { def provision = workflows.get(\"provision\"); if (provision != null) { def nodes = provision.get(\"nodes\"); if (nodes != null) { for (node in nodes) { def userInputs = node.get(\"user_inputs\"); if (userInputs != null) { userInputs.remove(\"credential\"); } } } } } return filteredSource;", - Collections.emptyMap() - ) - ); - } - SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index); return channel -> client.execute(actionType, searchRequest, search(channel)); } diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java index 11c135210..d713e8f48 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowTransportAction.java @@ -17,12 +17,14 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -79,8 +81,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java index 3ce1ceffa..afe3b85d4 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportAction.java @@ -17,10 +17,15 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; + /** * Transport Action to search workflow states */ @@ -45,8 +50,10 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { // AccessController should take care of letting the user with right permission to view the workflow + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - logger.info("Searching workflow states in global context"); + SearchSourceBuilder searchSourceBuilder = request.source(); + searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); client.search(request, ActionListener.runBefore(actionListener, context::restore)); } catch (Exception e) { logger.error("Failed to search workflow states in global context", e); diff --git a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java index 28dad8c78..41a8b23f9 100644 --- a/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/SearchWorkflowTransportAction.java @@ -17,10 +17,15 @@ import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext; + /** * Transport Action to search workflows created */ @@ -45,8 +50,11 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi @Override protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { // AccessController should take care of letting the user with right permission to view the workflow + User user = ParseUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { logger.info("Searching workflows in global context"); + SearchSourceBuilder searchSourceBuilder = request.source(); + searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder)); client.search(request, ActionListener.runBefore(actionListener, context::restore)); } catch (Exception e) { logger.error("Failed to search workflows in global context", e); diff --git a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java index df7e66e07..946c68908 100644 --- a/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/EncryptorUtils.java @@ -17,6 +17,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -122,7 +123,7 @@ public Template decryptTemplateCredentials(Template template) { /** * Applies the given cipher function on template credentials * @param template the template to process - * @param cipher the encryption/decryption function to apply on credential values + * @param cipherFunction the encryption/decryption function to apply on credential values * @return template with encrypted credentials */ private Template processTemplateCredentials(Template template, Function cipherFunction) { @@ -201,11 +202,13 @@ String decrypt(final String encryptedCredential) { // TODO : Improve redactTemplateCredentials to redact different fields /** * Removes the credential fields from a template + * @param user User * @param template the template * @return the redacted template */ - public Template redactTemplateCredentials(Template template) { + public Template redactTemplateSecuredFields(User user, Template template) { Map processedWorkflows = new HashMap<>(); + for (Map.Entry entry : template.workflows().entrySet()) { List processedNodes = new ArrayList<>(); @@ -227,7 +230,11 @@ public Template redactTemplateCredentials(Template template) { processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges())); } - return new Template.Builder(template).workflows(processedWorkflows).build(); + if (ParseUtils.isAdmin(user)) { + return new Template.Builder(template).workflows(processedWorkflows).build(); + } + + return new Template.Builder(template).user(null).workflows(processedWorkflows).build(); } /** diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index f7a1da0d4..ccf9ab686 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -25,7 +25,6 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.workflow.WorkflowData; -import org.opensearch.ml.common.agent.LLMSpec; import java.io.FileNotFoundException; import java.io.IOException; @@ -47,8 +46,6 @@ import jakarta.json.bind.JsonbBuilder; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; -import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; -import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; /** * Utility methods for Template parsing @@ -113,6 +110,18 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map parameters = llm.getParameters(); - xContentBuilder.field(MODEL_ID, modelId); - xContentBuilder.field(PARAMETERS_FIELD); - buildStringToStringMap(xContentBuilder, parameters); - } - /** * Parses an XContent object representing a map of String keys to String values. * diff --git a/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java b/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java index f300c90f6..a119d6809 100644 --- a/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/RestHandlerUtils.java @@ -9,8 +9,9 @@ package org.opensearch.flowframework.util; import org.apache.commons.lang3.ArrayUtils; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.common.Strings; import org.opensearch.flowframework.common.CommonValue; -import org.opensearch.rest.RestRequest; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -19,26 +20,36 @@ */ public class RestHandlerUtils { + /** Path to credential field **/ + private static final String PATH_TO_CREDENTIAL_FIELD = "workflows.provision.nodes.user_inputs.credential"; + /** Fields that need to be excluded from the Search Response*/ - public static final String[] USER_EXCLUDE = new String[] { CommonValue.USER_FIELD, CommonValue.UI_METADATA_FIELD }; + private static final String[] DASHBOARD_EXCLUDES = new String[] { + CommonValue.USER_FIELD, + CommonValue.UI_METADATA_FIELD, + PATH_TO_CREDENTIAL_FIELD }; + + private static final String[] EXCLUDES = new String[] { CommonValue.USER_FIELD, PATH_TO_CREDENTIAL_FIELD }; private RestHandlerUtils() {} /** * Creates a source context and include/exclude information to be shared based on the user * - * @param request the REST request + * @param user User * @param searchSourceBuilder the search request source builder * @return modified sources */ - public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) { - // TODO - // 1. check if the request came from dashboard and exclude UI_METADATA + public static FetchSourceContext getSourceContext(User user, SearchSourceBuilder searchSourceBuilder) { if (searchSourceBuilder.fetchSource() != null) { - String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), USER_EXCLUDE); + String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), DASHBOARD_EXCLUDES); return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray); } else { - return null; + // When user does not set the _source field in search api request, searchSourceBuilder.fetchSource becomes null + if (ParseUtils.isAdmin(user)) { + return new FetchSourceContext(true, Strings.EMPTY_ARRAY, new String[] { PATH_TO_CREDENTIAL_FIELD }); + } + return new FetchSourceContext(true, Strings.EMPTY_ARRAY, EXCLUDES); } } } diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java index 69d0e6bc3..f3f55c052 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowStateTransportActionTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -71,6 +72,8 @@ public void testSearchWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener); verify(client, times(1)).search(any(SearchRequest.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java index d60316085..763ae73b5 100644 --- a/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/SearchWorkflowTransportActionTests.java @@ -15,6 +15,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -34,12 +35,17 @@ public class SearchWorkflowTransportActionTests extends OpenSearchTestCase { private SearchWorkflowTransportAction searchWorkflowTransportAction; private Client client; private ThreadPool threadPool; + ThreadContext threadContext; @Override public void setUp() throws Exception { super.setUp(); this.client = mock(Client.class); - + this.threadPool = mock(ThreadPool.class); + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); this.searchWorkflowTransportAction = new SearchWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), @@ -73,6 +79,8 @@ public void testSearchWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchRequest.source(searchSourceBuilder); searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener); verify(client, times(1)).search(any(SearchRequest.class), any()); diff --git a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java index cae595430..167dd634f 100644 --- a/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/EncryptorUtilsTests.java @@ -17,6 +17,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -26,6 +27,7 @@ import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; @@ -199,11 +201,39 @@ public void testRedactTemplateCredential() { WorkflowNode node = testTemplate.workflows().get("provision").nodes().get(0); assertNotNull(node.userInputs().get(CREDENTIAL_FIELD)); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + // Redact template with credential field - Template redactedTemplate = encryptorUtils.redactTemplateCredentials(testTemplate); + Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate); // Validate the credential field has been removed WorkflowNode redactedNode = redactedTemplate.workflows().get("provision").nodes().get(0); assertNull(redactedNode.userInputs().get(CREDENTIAL_FIELD)); } + + public void testRedactTemplateUserField() { + // Confirm user is present in the non-redacted template + assertNotNull(testTemplate.getUser()); + + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + // Redact template with user field + Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate); + + // Validate the user field has been removed + assertNull(redactedTemplate.getUser()); + } + + public void testAdminUserTemplate() { + // Confirm user is present in the non-redacted template + assertNotNull(testTemplate.getUser()); + + List roles = new ArrayList<>(); + roles.add("all_access"); + + User user = new User("admin", roles, roles, Collections.emptyList()); + + // Redact template with user field + Template redactedTemplate = encryptorUtils.redactTemplateSecuredFields(user, testTemplate); + assertNotNull(redactedTemplate.getUser()); + } } diff --git a/src/test/java/org/opensearch/flowframework/util/RestHandlerUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/RestHandlerUtilsTests.java new file mode 100644 index 000000000..76d80feea --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/util/RestHandlerUtilsTests.java @@ -0,0 +1,47 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.util; + +import org.opensearch.commons.authuser.User; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class RestHandlerUtilsTests extends OpenSearchTestCase { + + public void testGetSourceContextFromClientWithDashboardExcludes() { + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + testSearchSourceBuilder.fetchSource(new String[] { "a" }, new String[] { "b" }); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 4); + } + + public void testGetSourceContextFromClientWithExcludes() { + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + User user = new User("user", Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 2); + } + + public void testGetSourceContextAdminUser() { + SearchSourceBuilder testSearchSourceBuilder = new SearchSourceBuilder(); + List roles = new ArrayList<>(); + roles.add("all_access"); + + User user = new User("admin", roles, roles, Collections.emptyList()); + FetchSourceContext sourceContext = RestHandlerUtils.getSourceContext(user, testSearchSourceBuilder); + assertEquals(sourceContext.excludes().length, 1); + } + +}