Skip to content

Commit

Permalink
Hides user and credential field from search response (#680)
Browse files Browse the repository at this point in the history
* Hide user and credential field from search response

Signed-off-by: Owais Kazi <[email protected]>

* Hid the user field for Get API as well

Signed-off-by: Owais Kazi <[email protected]>

* Addressed PR Comments

Signed-off-by: Owais Kazi <[email protected]>

* Added user permission and new tests

Signed-off-by: Owais Kazi <[email protected]>

---------

Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 authored Apr 25, 2024
1 parent bae3aa2 commit 757eaaf
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
- Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639))
- Hide user and credential field from search response ([#680](https://github.com/opensearch-project/flow-framework/pull/680))

### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,14 @@
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.action.RestResponseListener;
import org.opensearch.script.Script;
import org.opensearch.script.ScriptType;
import org.opensearch.search.builder.SearchSourceBuilder;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS;
import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.FlowFrameworkSettings.FLOW_FRAMEWORK_ENABLED;
import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Abstract class to handle search request.
Expand Down Expand Up @@ -89,23 +84,9 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
}
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.parseXContent(request.contentOrSourceParamParser());
searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder));
searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true);
searchSourceBuilder.timeout(flowFrameworkSettings.getRequestTimeout());

// Apply credential filter when searching templates
if (index.equals(GLOBAL_CONTEXT_INDEX)) {
searchSourceBuilder.scriptField(
"filter",
new Script(
ScriptType.INLINE,
"painless",
"def filteredSource = new HashMap(params._source); def workflows = filteredSource.get(\"workflows\"); if (workflows != null) { def provision = workflows.get(\"provision\"); if (provision != null) { def nodes = provision.get(\"nodes\"); if (nodes != null) { for (node in nodes) { def userInputs = node.get(\"user_inputs\"); if (userInputs != null) { userInputs.remove(\"credential\"); } } } } } return filteredSource;",
Collections.emptyMap()
)
);
}

SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(index);
return channel -> client.execute(actionType, searchRequest, search(channel));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -79,8 +81,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<GetW
logger.error(errorMessage);
listener.onFailure(new FlowFrameworkException(errorMessage, RestStatus.NOT_FOUND));
} else {
// Remove any credential from response
Template template = encryptorUtils.redactTemplateCredentials(Template.parse(response.getSourceAsString()));
// Remove any secured field from response
User user = ParseUtils.getUserContext(client);
Template template = encryptorUtils.redactTemplateSecuredFields(user, Template.parse(response.getSourceAsString()));
listener.onResponse(new GetWorkflowResponse(template));
}
}, exception -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Transport Action to search workflow states
*/
Expand All @@ -45,8 +50,10 @@ public SearchWorkflowStateTransportAction(TransportService transportService, Act
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflow states in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
client.search(request, ActionListener.runBefore(actionListener, context::restore));
} catch (Exception e) {
logger.error("Failed to search workflow states in global context", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
import org.opensearch.client.Client;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.flowframework.util.RestHandlerUtils.getSourceContext;

/**
* Transport Action to search workflows created
*/
Expand All @@ -45,8 +50,11 @@ public SearchWorkflowTransportAction(TransportService transportService, ActionFi
@Override
protected void doExecute(Task task, SearchRequest request, ActionListener<SearchResponse> actionListener) {
// AccessController should take care of letting the user with right permission to view the workflow
User user = ParseUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
logger.info("Searching workflows in global context");
SearchSourceBuilder searchSourceBuilder = request.source();
searchSourceBuilder.fetchSource(getSourceContext(user, searchSourceBuilder));
client.search(request, ActionListener.runBefore(actionListener, context::restore));
} catch (Exception e) {
logger.error("Failed to search workflows in global context", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
Expand Down Expand Up @@ -122,7 +123,7 @@ public Template decryptTemplateCredentials(Template template) {
/**
* Applies the given cipher function on template credentials
* @param template the template to process
* @param cipher the encryption/decryption function to apply on credential values
* @param cipherFunction the encryption/decryption function to apply on credential values
* @return template with encrypted credentials
*/
private Template processTemplateCredentials(Template template, Function<String, String> cipherFunction) {
Expand Down Expand Up @@ -201,11 +202,13 @@ String decrypt(final String encryptedCredential) {
// TODO : Improve redactTemplateCredentials to redact different fields
/**
* Removes the credential fields from a template
* @param user User
* @param template the template
* @return the redacted template
*/
public Template redactTemplateCredentials(Template template) {
public Template redactTemplateSecuredFields(User user, Template template) {
Map<String, Workflow> processedWorkflows = new HashMap<>();

for (Map.Entry<String, Workflow> entry : template.workflows().entrySet()) {

List<WorkflowNode> processedNodes = new ArrayList<>();
Expand All @@ -227,7 +230,11 @@ public Template redactTemplateCredentials(Template template) {
processedWorkflows.put(entry.getKey(), new Workflow(entry.getValue().userParams(), processedNodes, entry.getValue().edges()));
}

return new Template.Builder(template).workflows(processedWorkflows).build();
if (ParseUtils.isAdmin(user)) {
return new Template.Builder(template).workflows(processedWorkflows).build();
}

return new Template.Builder(template).user(null).workflows(processedWorkflows).build();
}

/**
Expand Down
30 changes: 12 additions & 18 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.FileNotFoundException;
import java.io.IOException;
Expand All @@ -47,8 +46,6 @@
import jakarta.json.bind.JsonbBuilder;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD;
import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID;

/**
* Utility methods for Template parsing
Expand Down Expand Up @@ -113,6 +110,18 @@ public static void buildStringToStringMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* 'all_access' role users are treated as admins.
* @param user of the current role
* @return boolean if the role is admin
*/
public static boolean isAdmin(User user) {
if (user == null) {
return false;
}
return user.getRoles().contains("all_access");
}

/**
* Builds an XContent object representing a map of String keys to Object values.
*
Expand All @@ -132,21 +141,6 @@ public static void buildStringToObjectMap(XContentBuilder xContentBuilder, Map<?
xContentBuilder.endObject();
}

/**
* Builds an XContent object representing a LLMSpec.
*
* @param xContentBuilder An XContent builder whose position is at the start of the map object to build
* @param llm LLMSpec
* @throws IOException on a build failure
*/
public static void buildLLMMap(XContentBuilder xContentBuilder, LLMSpec llm) throws IOException {
String modelId = llm.getModelId();
Map<String, String> parameters = llm.getParameters();
xContentBuilder.field(MODEL_ID, modelId);
xContentBuilder.field(PARAMETERS_FIELD);
buildStringToStringMap(xContentBuilder, parameters);
}

/**
* Parses an XContent object representing a map of String keys to String values.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
package org.opensearch.flowframework.util;

import org.apache.commons.lang3.ArrayUtils;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.Strings;
import org.opensearch.flowframework.common.CommonValue;
import org.opensearch.rest.RestRequest;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;

Expand All @@ -19,26 +20,36 @@
*/
public class RestHandlerUtils {

/** Path to credential field **/
private static final String PATH_TO_CREDENTIAL_FIELD = "workflows.provision.nodes.user_inputs.credential";

/** Fields that need to be excluded from the Search Response*/
public static final String[] USER_EXCLUDE = new String[] { CommonValue.USER_FIELD, CommonValue.UI_METADATA_FIELD };
private static final String[] DASHBOARD_EXCLUDES = new String[] {
CommonValue.USER_FIELD,
CommonValue.UI_METADATA_FIELD,
PATH_TO_CREDENTIAL_FIELD };

private static final String[] EXCLUDES = new String[] { CommonValue.USER_FIELD, PATH_TO_CREDENTIAL_FIELD };

private RestHandlerUtils() {}

/**
* Creates a source context and include/exclude information to be shared based on the user
*
* @param request the REST request
* @param user User
* @param searchSourceBuilder the search request source builder
* @return modified sources
*/
public static FetchSourceContext getSourceContext(RestRequest request, SearchSourceBuilder searchSourceBuilder) {
// TODO
// 1. check if the request came from dashboard and exclude UI_METADATA
public static FetchSourceContext getSourceContext(User user, SearchSourceBuilder searchSourceBuilder) {
if (searchSourceBuilder.fetchSource() != null) {
String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), USER_EXCLUDE);
String[] newArray = (String[]) ArrayUtils.addAll(searchSourceBuilder.fetchSource().excludes(), DASHBOARD_EXCLUDES);
return new FetchSourceContext(true, searchSourceBuilder.fetchSource().includes(), newArray);
} else {
return null;
// When user does not set the _source field in search api request, searchSourceBuilder.fetchSource becomes null
if (ParseUtils.isAdmin(user)) {
return new FetchSourceContext(true, Strings.EMPTY_ARRAY, new String[] { PATH_TO_CREDENTIAL_FIELD });
}
return new FetchSourceContext(true, Strings.EMPTY_ARRAY, EXCLUDES);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -71,6 +72,8 @@ public void testSearchWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = mock(ActionListener.class);
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);

searchWorkflowStateTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(client, times(1)).search(any(SearchRequest.class), any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -34,12 +35,17 @@ public class SearchWorkflowTransportActionTests extends OpenSearchTestCase {
private SearchWorkflowTransportAction searchWorkflowTransportAction;
private Client client;
private ThreadPool threadPool;
ThreadContext threadContext;

@Override
public void setUp() throws Exception {
super.setUp();
this.client = mock(Client.class);

this.threadPool = mock(ThreadPool.class);
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
this.searchWorkflowTransportAction = new SearchWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
Expand Down Expand Up @@ -73,6 +79,8 @@ public void testSearchWorkflow() {
@SuppressWarnings("unchecked")
ActionListener<SearchResponse> listener = mock(ActionListener.class);
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
searchRequest.source(searchSourceBuilder);

searchWorkflowTransportAction.doExecute(mock(Task.class), searchRequest, listener);
verify(client, times(1)).search(any(SearchRequest.class), any());
Expand Down
Loading

0 comments on commit 757eaaf

Please sign in to comment.