Skip to content

Commit

Permalink
[7.x][ML] Apply source query on data frame analytics memory estimation (
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitris-athanasiou committed Nov 25, 2019
1 parent 777f6d5 commit 60c2746
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.integration;

import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;

import java.io.IOException;

import static org.hamcrest.Matchers.lessThanOrEqualTo;

public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase {

public void testSourceQueryIsApplied() throws IOException {
// To test the source query is applied when we extract data,
// we set up a job where we have a query which excludes all but one document.
// We then assert the memory estimation is low enough.

String sourceIndex = "test-source-query-is-applied";

client().admin().indices().prepareCreate(sourceIndex)
.addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword")
.get();

BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

for (int i = 0; i < 30; i++) {
IndexRequest indexRequest = new IndexRequest(sourceIndex);

// We insert one odd value out of 5 for one feature
indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal");
bulkRequestBuilder.add(indexRequest);
}
BulkResponse bulkResponse = bulkRequestBuilder.get();
if (bulkResponse.hasFailures()) {
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
}

String id = "test_source_query_is_applied";

DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
.setId(id)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one"))))
.setAnalysis(new Classification("categorical"))
.buildForExplain();

ExplainDataFrameAnalyticsAction.Response explainResponse = explainDataFrame(config);

assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk().getKb(), lessThanOrEqualTo(500L));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.search.sort.SortOrder;
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
Expand Down Expand Up @@ -145,6 +146,11 @@ protected GetDataFrameAnalyticsStatsAction.Response.Stats getAnalyticsStats(Stri
return stats.get(0);
}

protected ExplainDataFrameAnalyticsAction.Response explainDataFrame(DataFrameAnalyticsConfig config) {
PutDataFrameAnalyticsAction.Request request = new PutDataFrameAnalyticsAction.Request(config);
return client().execute(ExplainDataFrameAnalyticsAction.INSTANCE, request).actionGet();
}

protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evaluation evaluation) {
EvaluateDataFrameAction.Request request =
new EvaluateDataFrameAction.Request()
Expand All @@ -155,12 +161,12 @@ protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evalu

protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
@Nullable String resultsField, DataFrameAnalysis analysis) {
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
configBuilder.setId(id);
configBuilder.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null));
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
configBuilder.setAnalysis(analysis);
return configBuilder.build();
return new DataFrameAnalyticsConfig.Builder()
.setId(id)
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null))
.setDest(new DataFrameAnalyticsDest(destIndex, resultsField))
.setAnalysis(analysis)
.build();
}

protected void assertIsStopped(String id) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@ public class DataFrameDataExtractorFactory {
private final Client client;
private final String analyticsId;
private final List<String> indices;
private final QueryBuilder sourceQuery;
private final ExtractedFields extractedFields;
private final Map<String, String> headers;
private final boolean includeRowsWithMissingValues;

public DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, ExtractedFields extractedFields,
Map<String, String> headers, boolean includeRowsWithMissingValues) {
private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, QueryBuilder sourceQuery,
ExtractedFields extractedFields, Map<String, String> headers,
boolean includeRowsWithMissingValues) {
this.client = Objects.requireNonNull(client);
this.analyticsId = Objects.requireNonNull(analyticsId);
this.indices = Objects.requireNonNull(indices);
this.sourceQuery = Objects.requireNonNull(sourceQuery);
this.extractedFields = Objects.requireNonNull(extractedFields);
this.headers = headers;
this.includeRowsWithMissingValues = includeRowsWithMissingValues;
Expand All @@ -54,7 +57,12 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) {
}

private QueryBuilder createQuery() {
return includeRowsWithMissingValues ? QueryBuilders.matchAllQuery() : allExtractedFieldsExistQuery();
BoolQueryBuilder query = QueryBuilders.boolQuery();
query.filter(sourceQuery);
if (includeRowsWithMissingValues == false) {
query.filter(allExtractedFieldsExistQuery());
}
return query;
}

private QueryBuilder allExtractedFieldsExistQuery() {
Expand All @@ -77,8 +85,8 @@ private QueryBuilder allExtractedFieldsExistQuery() {
*/
public static DataFrameDataExtractorFactory createForSourceIndices(Client client, String taskId, DataFrameAnalyticsConfig config,
ExtractedFields extractedFields) {
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()), extractedFields,
config.getHeaders(), config.getAnalysis().supportsMissingValues());
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()),
config.getSource().getParsedQuery(), extractedFields, config.getHeaders(), config.getAnalysis().supportsMissingValues());
}

/**
Expand All @@ -100,8 +108,8 @@ public static void createForDestinationIndex(Client client,
extractedFieldsDetector -> {
ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(),
Collections.singletonList(config.getDest().getIndex()), extractedFields, config.getHeaders(),
config.getAnalysis().supportsMissingValues());
Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields,
config.getHeaders(), config.getAnalysis().supportsMissingValues());
listener.onResponse(extractorFactory);
},
listener::onFailure
Expand Down

0 comments on commit 60c2746

Please sign in to comment.