diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index b06abe5cf677c..b6ac92134723c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; @@ -37,6 +38,18 @@ public class AnalyticsResultProcessor { private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class); + /** + * While we report progress as we read row results there are other things we need to account for + * to report completion. There are other types of results we can't predict the number of like + * progress objects and the inference model. Thus, we report a max progress until we know we have + * completed processing results. + * + * It is critical to ensure we do not report complete progress too soon as restarting a job + * uses the progress to determine which state to restart from. If we report full progress too soon + * we cannot restart a job as we will think the job was finished. + */ + private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98; + private final DataFrameAnalyticsConfig analytics; private final DataFrameRowsJoiner dataFrameRowsJoiner; private final ProgressTracker progressTracker; @@ -68,7 +81,7 @@ public void awaitForCompletion() { completionLatch.await(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e); + setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for results processor to complete", e)); } } @@ -91,27 +104,32 @@ public void process(AnalyticsProcess process) { processResult(result, resultsJoiner); if (result.getRowResults() != null) { processedRows++; - progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); + updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); } } - if (isCancelled == false) { - // This means we completed successfully so we need to set the progress to 100. - // This is because due to skipped rows, it is possible the processed rows will not reach the total rows. - progressTracker.writingResultsPercent.set(100); - } } catch (Exception e) { if (isCancelled) { // No need to log error as it's due to stopping } else { - LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e); - failure = "error parsing data frame analytics output: [" + e.getMessage() + "]"; + setAndReportFailure(e); } } finally { + if (isCancelled == false && failure == null) { + completeResultsProgress(); + } completionLatch.countDown(); process.consumeAndCloseOutputStream(); } } + private void updateResultsProgress(int progress) { + progressTracker.writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION)); + } + + private void completeResultsProgress() { + progressTracker.writingResultsPercent.set(100); + } + private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) { RowResults rowResults = result.getRowResults(); if (rowResults != null) { @@ -137,7 +155,7 @@ private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferen } } catch (InterruptedException e) { Thread.currentThread().interrupt(); - LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e); + setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored")); } } @@ -168,19 +186,22 @@ private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig) aBoolean -> { if (aBoolean == false) { LOGGER.error("[{}] Storing trained model responded false", analytics.getId()); + setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false")); } else { LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]"); } }, - e -> { - LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", analytics.getId(), - trainedModelConfig.getModelId()), e); - auditor.error(analytics.getId(), "Error storing trained model with id [" + trainedModelConfig.getModelId() - + "]; error message [" + e.getMessage() + "]"); - } + e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e, + trainedModelConfig.getModelId())) ); trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); return latch; } + + private void setAndReportFailure(Exception e) { + LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e); + failure = "error processing results; " + e.getMessage(); + auditor.error(analytics.getId(), "Error processing results; " + e.getMessage()); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index aaa64c13a1c44..0d2b5aea364eb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -39,9 +39,11 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.startsWith; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; @@ -117,6 +119,28 @@ public void testProcess_GivenRowResults() { assertThat(progressTracker.writingResultsPercent.get(), equalTo(100)); } + public void testProcess_GivenDataFrameRowsJoinerFails() { + givenDataFrameRows(2); + RowResults rowResults1 = mock(RowResults.class); + RowResults rowResults2 = mock(RowResults.class); + givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null))); + + doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); + + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + assertThat(resultProcessor.getFailure(), equalTo("error processing results; some failure")); + + ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); + verify(auditor).error(eq(JOB_ID), auditCaptor.capture()); + assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure")); + + assertThat(progressTracker.writingResultsPercent.get(), equalTo(0)); + } + @SuppressWarnings("unchecked") public void testProcess_GivenInferenceModelIsStoredSuccessfully() { givenDataFrameRows(0); @@ -182,9 +206,11 @@ public void testProcess_GivenInferenceModelFailedToStore() { // This test verifies the processor knows how to handle a failure on storing the model and completes normally ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); verify(auditor).error(eq(JOB_ID), auditCaptor.capture()); - assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID)); - assertThat(auditCaptor.getValue(), containsString("[some failure]")); + assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID)); Mockito.verifyNoMoreInteractions(auditor); + + assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID)); + assertThat(progressTracker.writingResultsPercent.get(), equalTo(0)); } private void givenProcessResults(List results) {