Skip to content

Commit

Permalink
[Transform] Make transform _preview request cancellable (#91313)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemekwitek authored Nov 8, 2022
1 parent dcdf587 commit a8a684e
Show file tree
Hide file tree
Showing 10 changed files with 82 additions and 13 deletions.
6 changes: 6 additions & 0 deletions docs/changelog/91313.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 91313
summary: Make transform `_preview` request cancellable
area: Transform
type: bug
issues:
- 91286
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -37,6 +40,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;

public class PreviewTransformAction extends ActionType<PreviewTransformAction.Response> {
Expand Down Expand Up @@ -135,6 +139,11 @@ public boolean equals(Object obj) {
Request other = (Request) obj;
return Objects.equals(config, other.config);
}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new CancellableTask(id, type, action, format("preview_transform[%s]", config.getId()), parentTaskId, headers);
}
}

public static class Response extends ActionResponse implements ToXContentObject {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.DeprecationHandler;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.json.JsonXContent;
Expand All @@ -22,9 +25,11 @@
import org.elasticsearch.xpack.core.transform.transforms.pivot.PivotConfigTests;

import java.io.IOException;
import java.util.Map;

import static org.elasticsearch.xpack.core.transform.transforms.SourceConfigTests.randomSourceConfig;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;

public class PreviewTransformActionRequestTests extends AbstractSerializingTransformTestCase<Request> {
Expand Down Expand Up @@ -132,4 +137,11 @@ private void testParsingOverwrites(
assertThat(request.getConfig().getDestination().getPipeline(), is(equalTo(expectedDestPipeline)));
}
}

public void testCreateTask() {
Request request = createTestInstance();
Task task = request.createTask(123, "type", "action", TaskId.EMPTY_TASK_ID, Map.of());
assertThat(task, is(instanceOf(CancellableTask.class)));
assertThat(task.getDescription(), is(equalTo("preview_transform[transform-preview]")));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ setup:
- match: { generated_dest_index.mappings.properties.by-hour.type: "date" }
- match: { generated_dest_index.mappings.properties.avg_response.type: "double" }

---
"Test preview transform with timeout":
- do:
transform.preview_transform:
timeout: "10s"
body: >
{
"source": { "index": "airline-data" },
"pivot": {
"group_by": {
"airline": {"terms": {"field": "airline"}},
"by-hour": {"date_histogram": {"fixed_interval": "1h", "field": "time"}}},
"aggs": {
"avg_response": {"avg": {"field": "responsetime"}},
"time.max": {"max": {"field": "time"}},
"time.min": {"min": {"field": "time"}}
}
}
}
---
"Test preview transform with disabled mapping deduction":
- do:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ParentTaskAssigningClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.node.DiscoveryNode;
Expand All @@ -25,10 +26,12 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.license.License;
import org.elasticsearch.license.RemoteClusterLicenseChecker;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xcontent.ToXContent;
Expand Down Expand Up @@ -112,6 +115,7 @@ public TransportPreviewTransformAction(

@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
TaskId parentTaskId = new TaskId(clusterService.localNode().getId(), task.getId());
final ClusterState clusterState = clusterService.state();
TransformNodes.throwIfNoTransformNodes(clusterState);

Expand All @@ -137,6 +141,8 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
validateConfigResponse -> useSecondaryAuthIfAvailable(
securityContext,
() -> getPreview(
parentTaskId,
request.timeout(),
config.getId(), // note: @link{PreviewTransformAction} sets an id, so this is never null
function,
config.getSource(),
Expand Down Expand Up @@ -175,7 +181,7 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
securityContext,
indexNameExpressionResolver,
clusterState,
client,
new ParentTaskAssigningClient(client, parentTaskId),
config,
// We don't want to check privileges for a dummy (placeholder) index and the placeholder is inserted as config.dest.index
// early in the REST action so the only possibility we have here is string comparison.
Expand All @@ -189,6 +195,8 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li

@SuppressWarnings("unchecked")
private void getPreview(
TaskId parentTaskId,
TimeValue timeout,
String transformId,
Function function,
SourceConfig source,
Expand All @@ -197,6 +205,8 @@ private void getPreview(
SyncConfig syncConfig,
ActionListener<Response> listener
) {
Client parentTaskAssigningClient = new ParentTaskAssigningClient(client, parentTaskId);

final SetOnce<Map<String, String>> mappings = new SetOnce<>();

ActionListener<SimulatePipelineResponse> pipelineResponseActionListener = ActionListener.wrap(simulatePipelineResponse -> {
Expand Down Expand Up @@ -256,15 +266,16 @@ private void getPreview(
builder.endObject();
var pipelineRequest = new SimulatePipelineRequest(BytesReference.bytes(builder), XContentType.JSON);
pipelineRequest.setId(pipeline);
client.execute(SimulatePipelineAction.INSTANCE, pipelineRequest, pipelineResponseActionListener);
parentTaskAssigningClient.execute(SimulatePipelineAction.INSTANCE, pipelineRequest, pipelineResponseActionListener);
}
}
}, listener::onFailure);

ActionListener<Map<String, String>> deduceMappingsListener = ActionListener.wrap(deducedMappings -> {
mappings.set(deducedMappings);
function.preview(
client,
parentTaskAssigningClient,
timeout,
ClientHelper.getPersistableSafeSecurityHeaders(threadPool.getThreadContext(), clusterService.state()),
source,
deducedMappings,
Expand All @@ -273,6 +284,6 @@ private void getPreview(
);
}, listener::onFailure);

function.deduceMappings(client, source, deduceMappingsListener);
function.deduceMappings(parentTaskAssigningClient, source, deduceMappingsListener);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
if (request.isDeferValidation()) {
validateQueryListener.onResponse(true);
} else {
function.validateQuery(client, config.getSource(), validateQueryListener);
function.validateQuery(client, config.getSource(), request.timeout(), validateQueryListener);
}
}, listener::onFailure);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.BaseRestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.action.RestCancellableNodeClient;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.transform.TransformField;
Expand Down Expand Up @@ -47,7 +49,7 @@ public String getName() {
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient) throws IOException {
String transformId = restRequest.param(TransformField.ID.getPreferredName());

if (Strings.isNullOrEmpty(transformId) && restRequest.hasContentOrSourceParam() == false) {
Expand All @@ -72,6 +74,7 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
previewRequestHolder.set(PreviewTransformAction.Request.fromXContent(restRequest.contentOrSourceParamParser(), timeout));
}

Client client = new RestCancellableNodeClient(nodeClient, restRequest.getHttpChannel());
return channel -> {
RestToXContentListener<PreviewTransformAction.Response> listener = new RestToXContentListener<>(channel);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -124,6 +126,7 @@ interface ChangeCollector {
* Create a preview of the function.
*
* @param client a client instance for querying
* @param timeout search query timeout
* @param headers headers to be used to query only for what the caller is allowed to
* @param sourceConfig the source configuration
* @param fieldTypeMap mapping of field types
Expand All @@ -132,6 +135,7 @@ interface ChangeCollector {
*/
void preview(
Client client,
@Nullable TimeValue timeout,
Map<String, String> headers,
SourceConfig sourceConfig,
Map<String, String> fieldTypeMap,
Expand Down Expand Up @@ -175,9 +179,10 @@ void preview(
*
* @param client a client instance for querying the source
* @param sourceConfig the source configuration
* @param timeout search query timeout
* @param listener the result listener
*/
void validateQuery(Client client, SourceConfig sourceConfig, ActionListener<Boolean> listener);
void validateQuery(Client client, SourceConfig sourceConfig, @Nullable TimeValue timeout, ActionListener<Boolean> listener);

/**
* Create a change collector instance and return it
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.Aggregations;
Expand Down Expand Up @@ -63,6 +64,7 @@ public SearchSourceBuilder buildSearchQuery(SearchSourceBuilder builder, Map<Str
@Override
public void preview(
Client client,
TimeValue timeout,
Map<String, String> headers,
SourceConfig sourceConfig,
Map<String, String> fieldTypeMap,
Expand All @@ -75,7 +77,7 @@ public void preview(
ClientHelper.TRANSFORM_ORIGIN,
client,
SearchAction.INSTANCE,
buildSearchRequest(sourceConfig, null, numberOfBuckets),
buildSearchRequest(sourceConfig, timeout, numberOfBuckets),
ActionListener.wrap(r -> {
try {
final Aggregations aggregations = r.getAggregations();
Expand All @@ -102,8 +104,8 @@ public void preview(
}

@Override
public void validateQuery(Client client, SourceConfig sourceConfig, ActionListener<Boolean> listener) {
SearchRequest searchRequest = buildSearchRequest(sourceConfig, null, TEST_QUERY_PAGE_SIZE);
public void validateQuery(Client client, SourceConfig sourceConfig, TimeValue timeout, ActionListener<Boolean> listener) {
SearchRequest searchRequest = buildSearchRequest(sourceConfig, timeout, TEST_QUERY_PAGE_SIZE);
client.execute(SearchAction.INSTANCE, searchRequest, ActionListener.wrap(response -> {
if (response == null) {
listener.onFailure(new ValidationException().addValidationError("Unexpected null response from test query"));
Expand Down Expand Up @@ -173,9 +175,10 @@ protected abstract Stream<Map<String, Object>> extractResults(
TransformProgress progress
);

private SearchRequest buildSearchRequest(SourceConfig sourceConfig, Map<String, Object> position, int pageSize) {
private SearchRequest buildSearchRequest(SourceConfig sourceConfig, TimeValue timeout, int pageSize) {
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(sourceConfig.getQueryConfig().getQuery())
.runtimeMappings(sourceConfig.getRuntimeMappings());
.runtimeMappings(sourceConfig.getRuntimeMappings())
.timeout(timeout);
buildSearchQuery(sourceBuilder, null, pageSize);
return new SearchRequest(sourceConfig.getIndex()).source(sourceBuilder).indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ private static void assertInvalidTransform(Client client, SourceConfig source, F
private static void validate(Client client, SourceConfig source, Function pivot, boolean expectValid) throws Exception {
CountDownLatch latch = new CountDownLatch(1);
final AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
pivot.validateQuery(client, source, ActionListener.wrap(validity -> {
pivot.validateQuery(client, source, null, ActionListener.wrap(validity -> {
assertEquals(expectValid, validity);
latch.countDown();
}, e -> {
Expand Down

0 comments on commit a8a684e

Please sign in to comment.