[ML] Only report complete writing_results progress after completion #49551

Merged
@@ -19,6 +19,7 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
@@ -37,6 +38,18 @@ public class AnalyticsResultProcessor {

private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);

/**
* While we report progress as we read row results, there are other things we need to account for
* before we can report completion. There are other types of results, such as progress objects and
* the inference model, whose number we cannot predict. Thus, we cap the reported progress until we
* know we have completed processing results.
*
* It is critical to ensure we do not report complete progress too soon, as restarting a job
* uses the progress to determine which state to restart from. If we report full progress too soon,
* we cannot restart the job as we will think it has already finished.
*/
private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;

private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final ProgressTracker progressTracker;
@@ -68,7 +81,7 @@ public void awaitForCompletion() {
completionLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for results processor to complete", analytics.getId()), e);
setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for results processor to complete", e));
}
}

@@ -91,27 +104,32 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
processResult(result, resultsJoiner);
if (result.getRowResults() != null) {
processedRows++;
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
}
}
if (isCancelled == false) {
// This means we completed successfully so we need to set the progress to 100.
// This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
progressTracker.writingResultsPercent.set(100);
}
} catch (Exception e) {
if (isCancelled) {
// No need to log error as it's due to stopping
} else {
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
failure = "error parsing data frame analytics output: [" + e.getMessage() + "]";
setAndReportFailure(e);
}
} finally {
if (isCancelled == false && failure == null) {
completeResultsProgress();
}
completionLatch.countDown();
process.consumeAndCloseOutputStream();
}
}

private void updateResultsProgress(int progress) {
progressTracker.writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
}

private void completeResultsProgress() {
progressTracker.writingResultsPercent.set(100);
}

private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) {
RowResults rowResults = result.getRowResults();
if (rowResults != null) {
@@ -137,7 +155,7 @@ private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferen
}
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
LOGGER.error(new ParameterizedMessage("[{}] Interrupted waiting for inference model to be stored", analytics.getId()), e);
setAndReportFailure(ExceptionsHelper.serverError("interrupted waiting for inference model to be stored"));
}
}

@@ -168,19 +186,22 @@ private CountDownLatch storeTrainedModel(TrainedModelConfig trainedModelConfig)
aBoolean -> {
if (aBoolean == false) {
LOGGER.error("[{}] Storing trained model responded false", analytics.getId());
setAndReportFailure(ExceptionsHelper.serverError("storing trained model responded false"));
} else {
LOGGER.info("[{}] Stored trained model with id [{}]", analytics.getId(), trainedModelConfig.getModelId());
auditor.info(analytics.getId(), "Stored trained model with id [" + trainedModelConfig.getModelId() + "]");
}
},
e -> {
LOGGER.error(new ParameterizedMessage("[{}] Error storing trained model [{}]", analytics.getId(),
trainedModelConfig.getModelId()), e);
auditor.error(analytics.getId(), "Error storing trained model with id [" + trainedModelConfig.getModelId()
+ "]; error message [" + e.getMessage() + "]");
}
e -> setAndReportFailure(ExceptionsHelper.serverError("error storing trained model with id [{}]", e,
trainedModelConfig.getModelId()))
);
trainedModelProvider.storeTrainedModel(trainedModelConfig, new LatchedActionListener<>(storeListener, latch));
return latch;
}

private void setAndReportFailure(Exception e) {
LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e);
failure = "error processing results; " + e.getMessage();
auditor.error(analytics.getId(), "Error processing results; " + e.getMessage());
}
}
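
The javadoc on MAX_PROGRESS_BEFORE_COMPLETION above captures the core idea of this change. As a minimal, self-contained sketch of that pattern (the CappedProgressSketch class and its main method are invented for illustration; only the helper names, the writingResultsPercent field, and the 98 cap come from the diff above), row-level updates are clamped below 100 and only an explicit completion call reports full progress:

import java.util.concurrent.atomic.AtomicInteger;

// Illustrative sketch only; mirrors the capping behaviour of the new helpers.
public class CappedProgressSketch {

    private static final int MAX_PROGRESS_BEFORE_COMPLETION = 98;

    private final AtomicInteger writingResultsPercent = new AtomicInteger(0);

    // Called as row results stream in; never reports more than the cap, so an
    // interrupted run cannot look like a finished one when the job restarts.
    void updateResultsProgress(int progress) {
        writingResultsPercent.set(Math.min(progress, MAX_PROGRESS_BEFORE_COMPLETION));
    }

    // Called only once all results, including progress objects and the
    // inference model, have been processed successfully.
    void completeResultsProgress() {
        writingResultsPercent.set(100);
    }

    public static void main(String[] args) {
        CappedProgressSketch progress = new CappedProgressSketch();
        progress.updateResultsProgress(100);
        System.out.println(progress.writingResultsPercent.get()); // prints 98
        progress.completeResultsProgress();
        System.out.println(progress.writingResultsPercent.get()); // prints 100
    }
}

Keeping the cap and the completion call separate is what lets a restarted job trust the stored progress value.
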
@@ -39,9 +39,11 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasKey;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
@@ -117,6 +119,28 @@ public void testProcess_GivenRowResults() {
assertThat(progressTracker.writingResultsPercent.get(), equalTo(100));
}

public void testProcess_GivenDataFrameRowsJoinerFails() {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null), new AnalyticsResult(rowResults2, 100, null)));

doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));

AnalyticsResultProcessor resultProcessor = createResultProcessor();

resultProcessor.process(process);
resultProcessor.awaitForCompletion();

assertThat(resultProcessor.getFailure(), equalTo("error processing results; some failure"));

ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
Contributor:

Is it possible to replace auditCaptor.capture() with containsString("Error processing results; some failure")?

Contributor Author:

We need to call capture() on the captor for it to capture the argument to the mocked method.

assertThat(auditCaptor.getValue(), containsString("Error processing results; some failure"));

assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
}
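
As a side note on the ArgumentCaptor exchange above: here is a minimal, hypothetical Mockito sketch (the Auditor interface, the job id, and the class name are invented for illustration and assume mockito-core on the classpath) of why capture() has to appear inside verify(). The call registers a matcher that records whatever argument the mocked method actually received, and the recorded value is then read back with getValue() for assertions:

import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;

import org.mockito.ArgumentCaptor;

public class ArgumentCaptorSketch {

    // Hypothetical collaborator, invented for this illustration only.
    public interface Auditor {
        void error(String jobId, String message);
    }

    public static void main(String[] args) {
        Auditor auditor = mock(Auditor.class);

        // The code under test would normally make this call.
        auditor.error("job-1", "Error processing results; some failure");

        // capture() stands in for the argument inside verify(), which is what
        // tells Mockito to record the value the mocked method received.
        ArgumentCaptor<String> captor = ArgumentCaptor.forClass(String.class);
        verify(auditor).error(eq("job-1"), captor.capture());

        // The recorded value is read back with getValue() and asserted on separately,
        // e.g. with Hamcrest's containsString in the tests above.
        if (captor.getValue().contains("some failure") == false) {
            throw new AssertionError("unexpected audit message: " + captor.getValue());
        }
    }
}
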

@SuppressWarnings("unchecked")
public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
givenDataFrameRows(0);
@@ -182,9 +206,11 @@ public void testProcess_GivenInferenceModelFailedToStore() {
// This test verifies the processor knows how to handle a failure on storing the model and completes normally
ArgumentCaptor<String> auditCaptor = ArgumentCaptor.forClass(String.class);
verify(auditor).error(eq(JOB_ID), auditCaptor.capture());
assertThat(auditCaptor.getValue(), containsString("Error storing trained model with id [" + JOB_ID));
assertThat(auditCaptor.getValue(), containsString("[some failure]"));
assertThat(auditCaptor.getValue(), containsString("Error processing results; error storing trained model with id [" + JOB_ID));
Mockito.verifyNoMoreInteractions(auditor);

assertThat(resultProcessor.getFailure(), startsWith("error processing results; error storing trained model with id [" + JOB_ID));
assertThat(progressTracker.writingResultsPercent.get(), equalTo(0));
}

private void givenProcessResults(List<AnalyticsResult> results) {