[7.x][ML] Only report complete writing_results progress after completion (#49551) (#49577)

We depend on the number of data frame rows in order to report progress
for the writing of results, the last phase of a job run. However, the results
include objects other than data frame rows (e.g. progress, the inference model, etc.).

The problem this commit fixes is that once we receive the results for the last data frame row
we report that progress is complete, even though we may still have more results to process.
If the job is stopped for any reason at this point, we cannot restart it properly
because we think the job already completed.

This commit addresses this by capping the progress we can report for the
writing_results phase at 98 until the results processor completes.
Only when the processor is done do we set the progress to 100.

The commit also improves failure capturing and reporting in the results processor.

Backport of #49551
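
For illustration, a minimal, self-contained sketch of the capping behaviour described above (not the actual patch); the field and method names mirror those introduced in the diff below, while the surrounding task and process plumbing is omitted:

import java.util.concurrent.atomic.AtomicInteger;

class WritingResultsProgress {

    // Progress stays below 100 until we know all results have been processed.
    private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;

    private final AtomicInteger writingResultsPercent = new AtomicInteger(0);

    // Called as data frame row results are read; never reports completion on its own.
    void updateResultsProgress(int progress) {
        writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
    }

    // Called only once the results processor has finished without failure.
    void completeResultsProgress() {
        writingResultsPercent.set(100);
    }
}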
  • Loading branch information
dimitris-athanasiou authored Nov 26, 2019
1 parent 5d306ae commit c23a218
Showing 2 changed files with 65 additions and 18 deletions.
AnalyticsResultProcessor.java
@@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@@ -37,6 +38,18 @@ public class AnalyticsResultProcessor {

private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);

+/**
+ * While we report progress as we read row results there are other things we need to account for
+ * to report completion. There are other types of results we can't predict the number of like
+ * progress objects and the inference model. Thus, we report a max progress until we know we have
+ * completed processing results.
+ *
+ * It is critical to ensure we do not report complete progress too soon as restarting a job
+ * uses the progress to determine which state to restart from. If we report full progress too soon
+ * we cannot restart a job as we will think the job was finished.
+ */
+private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;
+
private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final ProgressTracker progressTracker;
@@ -68,7 +81,7 @@ public void awaitForCompletion() {
completionLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e);
setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for results processor to complete", e));
}
}

@@ -91,27 +104,32 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
processResult(result, resultsJoiner);
if (result.getRowResults() != null) {
processedRows++;
-progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
+updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
}
}
-if (isCancelled == false) {
-// This means we completed successfully so we need to set the progress to 100.
-// This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
-progressTracker.writingResultsPercent.set(100);
-}
} catch (Exception e) {
if (isCancelled) {
// No need to log error as it's due to stopping
} else {
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
failure = "error parsing data frame analytics output: [" + e.getMessage() + "]";
setAndReportFailure(e);
}
} finally {
+if (isCancelled == false && failure == null) {
+completeResultsProgress();
+}
completionLatch.countDown();
process.consumeAndCloseOutputStream();
}
}

+private void updateResultsProgress(int progress) {
+progressTracker.writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
+}
+
+private void completeResultsProgress() {
+progressTracker.writingResultsPercent.set(100);
+}
+
private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) {
RowResults rowResults = result.getRowResults();
if (rowResults != null) {
@@ -137,7 +155,7 @@ private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferen
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e);
setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
}
}

@@ -168,19 +186,22 @@ private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig)
aBoolean -> {
if (aBoolean == false) {
LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
} else {
LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
}
},
-e -> {
-LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", analytics.getId(),
-trainedModelConfig.getModelId()), e);
-auditor.error(analytics.getId(), "Error storing trained model with id [" + trainedModelConfig.getModelId()
-+ "]; error message [" + e.getMessage() + "]");
-}
+e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
+trainedModelConfig.getModelId()))
);
trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
return latch;
}
+
+private void setAndReportFailure(Exception e) {
+LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
+failure = "error processing results; " + e.getMessage();
+auditor.error(analytics.getId(), "Error processing results; " + e.getMessage());
+}
}
AnalyticsResultProcessorTests.java
@@ -39,9 +39,11 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
+import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -117,6 +119,28 @@ public void testProcess_GivenRowResults() {
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
}

+public void testProcess_GivenDataFrameRowsJoinerFails() {
+givenDataFrameRows(2);
+RowResults rowResults1 = mock(RowResults.class);
+RowResults rowResults2 = mock(RowResults.class);
+givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null)));
+
+doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));
+
+AnalyticsResultProcessor resultProcessor = createResultProcessor();
+
+resultProcessor.process(process);
+resultProcessor.awaitForCompletion();
+
+assertThat(resultProcessor.getFailure(), equalTo("error processing results; some failure"));
+
+ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
+verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
+assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure"));
+
+assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
+}
+
@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
givenDataFrameRows(0);
@@ -182,9 +206,11 @@ public void testProcess_GivenInferenceModelFailedToStore() {
// This test verifies the processor knows how to handle a failure on storing the model and completes normally
ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID));
assertThat(auditCaptor.getValue(), containsString("[some failure]"));
assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID));
Mockito.verifyNoMoreInteractions(auditor);
+
+assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
+assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
}

private void givenProcessResults(List<AnalyticsResult> results) {