Skip to content

Commit

Permalink
[ML] Prepare to hold additional stats in DF Analytics task (elastic#52134
Browse files Browse the repository at this point in the history
)

Refactors `DataFrameAnalyticsTask` to hold a `StatsHolder` object.
That just has a `ProgressTracker` for now but this is paving the
way to add additional stats like memory usage, analysis stats, etc.
  • Loading branch information
dimitris-athanasiou authored Feb 11, 2020
1 parent 2d6e59c commit 72b84ad
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.StoredProgress;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;

import java.io.IOException;
import java.io.InputStream;
Expand Down Expand Up @@ -106,9 +107,7 @@ protected void taskOperation(GetDataFrameAnalyticsStatsAction.Request request, D
);

ActionListener<Void> reindexingProgressListener = ActionListener.wrap(
aVoid -> {
progressListener.onResponse(task.getProgressTracker().report());
},
aVoid -> progressListener.onResponse(task.getStatsHolder().getProgressTracker().report()),
listener::onFailure
);

Expand Down Expand Up @@ -201,7 +200,7 @@ private void searchStoredProgresses(List<String> configIds, ActionListener<List<
} else {
SearchHit[] hits = itemResponse.getResponse().getHits().getHits();
if (hits.length == 0) {
progresses.add(new StoredProgress(new DataFrameAnalyticsTask.ProgressTracker().report()));
progresses.add(new StoredProgress(new ProgressTracker().report()));
} else {
progresses.add(parseStoredProgress(hits[0]));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.watcher.watch.Payload;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
Expand All @@ -68,7 +68,7 @@ public class DataFrameAnalyticsTask extends AllocatedPersistentTask implements S
private volatile boolean isReindexingFinished;
private volatile boolean isStopping;
private volatile boolean isMarkAsCompletedCalled;
private final ProgressTracker progressTracker = new ProgressTracker();
private final StatsHolder statsHolder = new StatsHolder();

public DataFrameAnalyticsTask(long id, String type, String action, TaskId parentTask, Map<String, String> headers,
Client client, ClusterService clusterService, DataFrameAnalyticsManager analyticsManager,
Expand Down Expand Up @@ -98,8 +98,8 @@ public boolean isStopping() {
return isStopping;
}

public ProgressTracker getProgressTracker() {
return progressTracker;
public StatsHolder getStatsHolder() {
return statsHolder;
}

@Override
Expand Down Expand Up @@ -197,7 +197,7 @@ public void updateReindexTaskProgress(ActionListener<Void> listener) {
// We set reindexing progress at least to 1 for a running process to be able to
// distinguish a job that is running for the first time against a job that is restarting.
reindexTaskProgress -> {
progressTracker.reindexingPercent.set(Math.max(1, reindexTaskProgress));
statsHolder.getProgressTracker().reindexingPercent.set(Math.max(1, reindexTaskProgress));
listener.onResponse(null);
},
listener::onFailure
Expand Down Expand Up @@ -353,25 +353,4 @@ public static StartingState determineStartingState(String jobId, List<PhaseProgr
}
}

/**
 * Holds one percentage counter per phase (reindexing, loading data,
 * analyzing, writing results) and can report them as a list of
 * {@link PhaseProgress}. Every counter starts at 0.
 */
public static class ProgressTracker {

// Phase names used as the first argument of each reported PhaseProgress.
public static final String REINDEXING = "reindexing";
public static final String LOADING_DATA = "loading_data";
public static final String ANALYZING = "analyzing";
public static final String WRITING_RESULTS = "writing_results";

// Public atomic counters, one per phase; callers set them directly.
public final AtomicInteger reindexingPercent = new AtomicInteger(0);
public final AtomicInteger loadingDataPercent = new AtomicInteger(0);
public final AtomicInteger analyzingPercent = new AtomicInteger(0);
public final AtomicInteger writingResultsPercent = new AtomicInteger(0);

/**
 * Reads the current value of each counter and returns one
 * {@link PhaseProgress} per phase, in the fixed order reindexing,
 * loading data, analyzing, writing results.
 *
 * @return the four phase progresses as a list
 */
public List<PhaseProgress> report() {
return Arrays.asList(
new PhaseProgress(REINDEXING, reindexingPercent.get()),
new PhaseProgress(LOADING_DATA, loadingDataPercent.get()),
new PhaseProgress(ANALYZING, analyzingPercent.get()),
new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get())
);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
Expand Down Expand Up @@ -152,7 +153,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
try {
writeHeaderRecord(dataExtractor, process);
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getProgressTracker());
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker());
process.writeEndOfDataMessage();
process.flushStream();

Expand Down Expand Up @@ -199,7 +200,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
}

private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
DataFrameAnalysis analysis, DataFrameAnalyticsTask.ProgressTracker progressTracker) throws IOException {
DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {

CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);

Expand Down Expand Up @@ -427,7 +428,7 @@ private AnalyticsResultProcessor createResultProcessor(DataFrameAnalyticsTask ta
DataFrameRowsJoiner dataFrameRowsJoiner =
new DataFrameRowsJoiner(config.getId(), dataExtractorFactory.newExtractor(true), resultsPersisterService);
return new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
config, dataFrameRowsJoiner, task.getStatsHolder(), trainedModelProvider, auditor, dataExtractor.get().getFieldNames());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;

Expand Down Expand Up @@ -57,7 +57,7 @@ public class AnalyticsResultProcessor {

private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final ProgressTracker progressTracker;
private final StatsHolder statsHolder;
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final List<String> fieldNames;
Expand All @@ -66,11 +66,11 @@ public class AnalyticsResultProcessor {
private volatile boolean isCancelled;

public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
StatsHolder statsHolder, TrainedModelProvider trainedModelProvider,
DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.progressTracker = Objects.requireNonNull(progressTracker);
this.statsHolder = Objects.requireNonNull(statsHolder);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor);
this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames));
Expand Down Expand Up @@ -128,11 +128,11 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
}

private void updateResultsProgress(int progress) {
progressTracker.writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
statsHolder.getProgressTracker().writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
}

private void completeResultsProgress() {
progressTracker.writingResultsPercent.set(100);
statsHolder.getProgressTracker().writingResultsPercent.set(100);
}

private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) {
Expand All @@ -142,7 +142,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
}
Integer progressPercent = result.getProgressPercent();
if (progressPercent != null) {
progressTracker.analyzingPercent.set(progressPercent);
statsHolder.getProgressTracker().analyzingPercent.set(progressPercent);
}
TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
if (inferenceModelBuilder != null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.stats;

import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Keeps one percentage counter per phase of a data frame analytics job and
 * turns the current values into a list of {@link PhaseProgress}.
 * Every counter starts at 0.
 */
public class ProgressTracker {

    /** Name reported for the reindexing phase. */
    public static final String REINDEXING = "reindexing";
    /** Name reported for the data loading phase. */
    public static final String LOADING_DATA = "loading_data";
    /** Name reported for the analyzing phase. */
    public static final String ANALYZING = "analyzing";
    /** Name reported for the results writing phase. */
    public static final String WRITING_RESULTS = "writing_results";

    // The no-arg AtomicInteger constructor starts the counter at 0.
    public final AtomicInteger reindexingPercent = new AtomicInteger();
    public final AtomicInteger loadingDataPercent = new AtomicInteger();
    public final AtomicInteger analyzingPercent = new AtomicInteger();
    public final AtomicInteger writingResultsPercent = new AtomicInteger();

    /**
     * Reads each counter once and reports the phases in the fixed order
     * reindexing, loading data, analyzing, writing results.
     *
     * @return one {@link PhaseProgress} per phase
     */
    public List<PhaseProgress> report() {
        PhaseProgress reindexing = new PhaseProgress(REINDEXING, reindexingPercent.get());
        PhaseProgress loadingData = new PhaseProgress(LOADING_DATA, loadingDataPercent.get());
        PhaseProgress analyzing = new PhaseProgress(ANALYZING, analyzingPercent.get());
        PhaseProgress writingResults = new PhaseProgress(WRITING_RESULTS, writingResultsPercent.get());
        return Arrays.asList(reindexing, loadingData, analyzing, writingResults);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.stats;

/**
 * In-memory container for the stats of a data frame analytics job, kept so
 * that the get stats api can retrieve them efficiently for started jobs.
 * It holds a single {@link ProgressTracker}.
 */
public class StatsHolder {

    private final ProgressTracker progressTracker;

    /** Creates a holder with a fresh {@link ProgressTracker}. */
    public StatsHolder() {
        this.progressTracker = new ProgressTracker();
    }

    /**
     * @return the progress tracker owned by this holder; the same instance on every call
     */
    public ProgressTracker getProgressTracker() {
        return this.progressTracker;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
Expand Down Expand Up @@ -89,7 +90,7 @@ public void setUpMocks() {

task = mock(DataFrameAnalyticsTask.class);
when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID);
when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class));
when(task.getStatsHolder()).thenReturn(new StatsHolder());
dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandomBuilder(CONFIG_ID,
false,
OutlierDetectionTests.createRandom()).build();
Expand Down Expand Up @@ -127,7 +128,7 @@ public void testRunJob_ProcessContextAlreadyExists() {
inOrder.verify(task).isStopping();
inOrder.verify(task).getAllocationId();
inOrder.verify(task).isStopping();
inOrder.verify(task).getProgressTracker();
inOrder.verify(task).getStatsHolder();
inOrder.verify(task).isStopping();
inOrder.verify(task).getAllocationId();
inOrder.verify(task).updateState(DataFrameAnalyticsState.FAILED, "[config-id] Could not create process as one already exists");
Expand Down Expand Up @@ -162,7 +163,7 @@ public void testRunJob_Ok() {
inOrder.verify(dataExtractor).collectDataSummary();
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
inOrder.verify(process).isProcessAlive();
inOrder.verify(task).getProgressTracker();
inOrder.verify(task).getStatsHolder();
inOrder.verify(dataExtractor).getFieldNames();
inOrder.verify(executorServiceForProcess, times(2)).execute(any()); // 'processData' and 'processResults' threads
verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, task);
Expand Down Expand Up @@ -220,7 +221,7 @@ public void testProcessContext_StartAndStop() throws Exception {
inOrder.verify(dataExtractor).collectDataSummary();
inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis());
inOrder.verify(process).isProcessAlive();
inOrder.verify(task).getProgressTracker();
inOrder.verify(task).getStatsHolder();
inOrder.verify(dataExtractor).getFieldNames();
// stop
inOrder.verify(dataExtractor).cancel();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
Expand Down Expand Up @@ -58,7 +58,7 @@ public class AnalyticsResultProcessorTests extends ESTestCase {

private AnalyticsProcess<AnalyticsResult> process;
private DataFrameRowsJoiner dataFrameRowsJoiner;
private ProgressTracker progressTracker = new ProgressTracker();
private StatsHolder statsHolder = new StatsHolder();
private TrainedModelProvider trainedModelProvider;
private DataFrameAnalyticsAuditor auditor;
private DataFrameAnalyticsConfig analyticsConfig;
Expand Down Expand Up @@ -101,7 +101,7 @@ public void testProcess_GivenEmptyResults() {

verify(dataFrameRowsJoiner).close();
Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner);
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100));
}

public void testProcess_GivenRowResults() {
Expand All @@ -118,7 +118,7 @@ public void testProcess_GivenRowResults() {
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults1);
inOrder.verify(dataFrameRowsJoiner).processRowResults(rowResults2);

assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(100));
}

public void testProcess_GivenDataFrameRowsJoinerFails() {
Expand All @@ -140,7 +140,7 @@ public void testProcess_GivenDataFrameRowsJoinerFails() {
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure"));

assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0));
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -212,7 +212,7 @@ public void testProcess_GivenInferenceModelFailedToStore() {
Mockito.verifyNoMoreInteractions(auditor);

assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
assertThat(statsHolder.getProgressTracker().writingResultsPercent.get(), equalTo(0));
}

private void givenProcessResults(List<AnalyticsResult> results) {
Expand All @@ -232,6 +232,6 @@ private AnalyticsResultProcessor createResultProcessor() {

private AnalyticsResultProcessor createResultProcessor(List<String> fieldNames) {
return new AnalyticsResultProcessor(
analyticsConfig, dataFrameRowsJoiner, progressTracker, trainedModelProvider, auditor, fieldNames);
analyticsConfig, dataFrameRowsJoiner, statsHolder, trainedModelProvider, auditor, fieldNames);
}
}

0 comments on commit 72b84ad

Please sign in to comment.