Make AnalyticsProcessManager class more robust #49282

Merged
File 1: data frame analytics regression integration test
@@ -11,6 +11,7 @@
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -239,7 +240,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
"Finished analysis");
}

@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/49095")
public void testStopAndRestart() throws Exception {
initialize("regression_stop_and_restart");

@@ -270,8 +270,12 @@ public void testStopAndRestart() throws Exception {
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
assertBusy(() -> {
DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState();
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
equalTo(DataFrameAnalyticsState.STOPPED))));
assertThat(
state,
is(anyOf(
equalTo(DataFrameAnalyticsState.REINDEXING),
equalTo(DataFrameAnalyticsState.ANALYZING),
equalTo(DataFrameAnalyticsState.STOPPED))));
});
stopAnalytics(jobId);
waitUntilAnalyticsIsStopped(jobId);
@@ -287,7 +291,7 @@ public void testStopAndRestart() throws Exception {
}
}

waitUntilAnalyticsIsStopped(jobId);
waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1));

SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
for (SearchHit hit : sourceData.getHits()) {
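The test changes above do two things: the assertBusy block now also accepts STOPPED as a valid state (the job may finish stopping before the poll ever observes REINDEXING or ANALYZING), and the final wait gets an explicit one-minute bound via the new waitUntilAnalyticsIsStopped(jobId, TimeValue) overload. Below is a minimal sketch of the underlying pattern, polling until a value reaches one of several accepted states before a deadline; the helper class and its signature are illustrative, not Elasticsearch API:

import java.time.Duration;
import java.time.Instant;
import java.util.Set;
import java.util.function.Supplier;

final class PollingAssert {

    // Polls the supplier until its value is one of the accepted states,
    // or fails once the timeout elapses.
    static <T> void awaitAnyOf(Supplier<T> current, Set<T> accepted, Duration timeout) throws InterruptedException {
        Instant deadline = Instant.now().plus(timeout);
        while (Instant.now().isBefore(deadline)) {
            if (accepted.contains(current.get())) {
                return;
            }
            Thread.sleep(100); // back off briefly between polls
        }
        throw new AssertionError("state " + current.get() + " never reached any of " + accepted + " within " + timeout);
    }
}

The test itself gets the same behavior from the assertBusy helper plus the bounded waitUntilAnalyticsIsStopped overload, which keeps a stuck stop from hanging the suite indefinitely.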
File 2: AnalyticsProcessManager.java
@@ -10,7 +10,6 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.Nullable;
@@ -54,7 +53,8 @@ public class AnalyticsProcessManager {
private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class);

private final Client client;
private final ThreadPool threadPool;
private final ExecutorService executorServiceForJob;
private final ExecutorService executorServiceForProcess;
private final AnalyticsProcessFactory<AnalyticsResult> processFactory;
private final ConcurrentMap<Long, ProcessContext> processContextByAllocation = new ConcurrentHashMap<>();
private final DataFrameAnalyticsAuditor auditor;
@@ -65,40 +65,59 @@ public AnalyticsProcessManager(Client client,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
this(
client,
threadPool.generic(),
threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME),
analyticsProcessFactory,
auditor,
trainedModelProvider);
}

// Visible for testing
public AnalyticsProcessManager(Client client,
ExecutorService executorServiceForJob,
ExecutorService executorServiceForProcess,
AnalyticsProcessFactory<AnalyticsResult> analyticsProcessFactory,
DataFrameAnalyticsAuditor auditor,
TrainedModelProvider trainedModelProvider) {
this.client = Objects.requireNonNull(client);
this.threadPool = Objects.requireNonNull(threadPool);
this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob);
this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess);
this.processFactory = Objects.requireNonNull(analyticsProcessFactory);
this.auditor = Objects.requireNonNull(auditor);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
}

public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory,
Consumer<Exception> finishHandler) {
threadPool.generic().execute(() -> {
if (task.isStopping()) {
// The task was requested to stop before we created the process context
finishHandler.accept(null);
return;
executorServiceForJob.execute(() -> {
ProcessContext processContext = new ProcessContext(config.getId());
synchronized (this) {
if (task.isStopping()) {
// The task was requested to stop before we created the process context
finishHandler.accept(null);
return;
}
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
finishHandler.accept(
ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists"));
return;
}
}

// First we refresh the dest index to ensure data is searchable
// Refresh the dest index to ensure data is searchable
refreshDest(config);

ProcessContext processContext = new ProcessContext(config.getId());
if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) {
finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id
+ "] Could not create process as one already exists"));
return;
}

// Fetch existing model state (if any)
BytesReference state = getModelState(config);

if (processContext.startProcess(dataExtractorFactory, config, task, state)) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
executorService.execute(() -> processResults(processContext));
executorService.execute(() -> processData(task, config, processContext.dataExtractor,
executorServiceForProcess.execute(() -> processResults(processContext));
executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor,
processContext.process, processContext.resultProcessor, finishHandler, state));
} else {
processContextByAllocation.remove(task.getAllocationId());
finishHandler.accept(null);
}
});
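The core robustness fix in runJob is visible above: the ProcessContext is now created and registered under the same monitor that stop() acquires, so a stop request can no longer slip in between the isStopping check and the putIfAbsent registration. A stripped-down sketch of the pattern, with placeholder types standing in for the real task and context classes:

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.function.BooleanSupplier;
import java.util.function.Consumer;

class ProcessRegistry {

    private final ConcurrentMap<Long, Object> contextByAllocation = new ConcurrentHashMap<>();
    private final ExecutorService jobExecutor;

    ProcessRegistry(ExecutorService jobExecutor) {
        this.jobExecutor = jobExecutor;
    }

    void run(long allocationId, BooleanSupplier isStopping, Consumer<Exception> finishHandler) {
        jobExecutor.execute(() -> {
            Object processContext = new Object(); // stand-in for ProcessContext
            // Check-then-register happens under the same lock stop() takes,
            // so a stop request cannot arrive between the two steps.
            synchronized (this) {
                if (isStopping.getAsBoolean()) {
                    finishHandler.accept(null);
                    return;
                }
                if (contextByAllocation.putIfAbsent(allocationId, processContext) != null) {
                    finishHandler.accept(new IllegalStateException("process already exists"));
                    return;
                }
            }
            // The expensive work (index refresh, process start) stays outside the lock.
        });
    }

    synchronized void stop(long allocationId) {
        Object processContext = contextByAllocation.get(allocationId);
        if (processContext != null) {
            // cancel extractors, kill the native process, etc.
        }
    }
}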
Expand All @@ -111,8 +130,6 @@ private BytesReference getModelState(DataFrameAnalyticsConfig config) {
}

try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern());
searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())));
SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
.setSize(1)
.setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
@@ -246,9 +263,8 @@ private void restoreState(DataFrameAnalyticsConfig config, @Nullable BytesRefere

private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME);
AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state,
executorService, onProcessCrash(task));
AnalyticsProcess<AnalyticsResult> process =
processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
if (process.isProcessAlive() == false) {
throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
}
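Replacing the ThreadPool field with two injected ExecutorService instances is what makes the new "visible for testing" constructor useful: a unit test can pass executors that run tasks inline on the calling thread and exercise runJob/stop deterministically. A sketch of such an executor follows; this particular class is illustrative (test utilities with the same behavior already exist, for example Guava's directExecutorService), not something the PR adds:

import java.util.Collections;
import java.util.List;
import java.util.concurrent.AbstractExecutorService;
import java.util.concurrent.TimeUnit;

// Runs every submitted task inline on the calling thread, so a test sees
// all side effects of runJob() before the call returns.
final class DirectExecutorService extends AbstractExecutorService {
    @Override public void execute(Runnable command) { command.run(); }
    @Override public void shutdown() {}
    @Override public List<Runnable> shutdownNow() { return Collections.emptyList(); }
    @Override public boolean isShutdown() { return true; }
    @Override public boolean isTerminated() { return true; }
    @Override public boolean awaitTermination(long timeout, TimeUnit unit) { return true; }
}

With two of these passed as executorServiceForJob and executorServiceForProcess, assertions such as getProcessContextCount() can run immediately after runJob() returns, which is presumably why that package-private accessor was added.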
@@ -285,17 +301,22 @@ private void closeProcess(DataFrameAnalyticsTask task) {
}
}

public void stop(DataFrameAnalyticsTask task) {
public synchronized void stop(DataFrameAnalyticsTask task) {
ProcessContext processContext = processContextByAllocation.get(task.getAllocationId());
if (processContext != null) {
LOGGER.debug("[{}] Stopping process", task.getParams().getId() );
LOGGER.debug("[{}] Stopping process", task.getParams().getId());
processContext.stop();
} else {
LOGGER.debug("[{}] No process context to stop", task.getParams().getId() );
LOGGER.debug("[{}] No process context to stop", task.getParams().getId());
task.markAsCompleted();
}
}

// Visible for testing
int getProcessContextCount() {
return processContextByAllocation.size();
}

class ProcessContext {

private final String id;
@@ -309,31 +330,26 @@ class ProcessContext {
this.id = Objects.requireNonNull(id);
}

public String getId() {
return id;
}

public boolean isProcessKilled() {
return processKilled;
synchronized String getFailureReason() {
return failureReason;
}

private synchronized void setFailureReason(String failureReason) {
synchronized void setFailureReason(String failureReason) {
// Only set the new reason if there isn't one already as we want to keep the first reason
if (failureReason != null) {
if (this.failureReason == null && failureReason != null) {
this.failureReason = failureReason;
}
}

private String getFailureReason() {
return failureReason;
}

public synchronized void stop() {
synchronized void stop() {
LOGGER.debug("[{}] Stopping process", id);
processKilled = true;
if (dataExtractor != null) {
dataExtractor.cancel();
}
if (resultProcessor != null) {
resultProcessor.cancel();
}
if (process != null) {
try {
process.kill();
@@ -346,8 +362,8 @@ public synchronized void stop() {
/**
* @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime
*/
private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
DataFrameAnalyticsTask task, @Nullable BytesReference state) {
synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config,
DataFrameAnalyticsTask task, @Nullable BytesReference state) {
if (processKilled) {
// The job was stopped before we started the process so no need to start it
return false;
@@ -365,8 +381,8 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr
process = createProcess(task, config, analyticsProcessConfig, state);
DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client,
dataExtractorFactory.newExtractor(true));
resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(),
trainedModelProvider, auditor, dataExtractor.getFieldNames());
resultProcessor = new AnalyticsResultProcessor(
config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames());
return true;
}

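One small but easy-to-miss change in ProcessContext: setFailureReason now keeps only the first reason (this.failureReason == null && failureReason != null), so a cascade of follow-on errors cannot overwrite the root cause. For comparison, the same first-writer-wins rule can be expressed lock-free with an AtomicReference; this is an alternative formulation, not what the PR uses:

import java.util.concurrent.atomic.AtomicReference;

final class FailureHolder {

    private final AtomicReference<String> failureReason = new AtomicReference<>();

    // Records only the first non-null reason; later calls are ignored.
    void setFailureReason(String reason) {
        if (reason != null) {
            failureReason.compareAndSet(null, reason);
        }
    }

    String getFailureReason() {
        return failureReason.get();
    }
}

The synchronized version in the diff is arguably the better fit here, since the surrounding ProcessContext state is already guarded by the same monitor.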
File 3: AnalyticsResultProcessor.java
@@ -31,29 +31,26 @@
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

public class AnalyticsResultProcessor {

private static final Logger LOGGER = LogManager.getLogger(AnalyticsResultProcessor.class);

private final DataFrameAnalyticsConfig analytics;
private final DataFrameRowsJoiner dataFrameRowsJoiner;
private final Supplier<Boolean> isProcessKilled;
private final ProgressTracker progressTracker;
private final TrainedModelProvider trainedModelProvider;
private final DataFrameAnalyticsAuditor auditor;
private final List<String> fieldNames;
private final CountDownLatch completionLatch = new CountDownLatch(1);
private volatile String failure;
private volatile boolean isCancelled;

public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner,
Supplier<Boolean> isProcessKilled, ProgressTracker progressTracker,
TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor,
List<String> fieldNames) {
ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider,
DataFrameAnalyticsAuditor auditor, List<String> fieldNames) {
this.analytics = Objects.requireNonNull(analytics);
this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner);
this.isProcessKilled = Objects.requireNonNull(isProcessKilled);
this.progressTracker = Objects.requireNonNull(progressTracker);
this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider);
this.auditor = Objects.requireNonNull(auditor);
@@ -74,6 +71,10 @@ public void awaitForCompletion() {
}
}

public void cancel() {
isCancelled = true;
}

public void process(AnalyticsProcess<AnalyticsResult> process) {
long totalRows = process.getConfig().rows();
long processedRows = 0;
@@ -82,20 +83,23 @@ public void process(AnalyticsProcess<AnalyticsResult> process) {
try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) {
Iterator<AnalyticsResult> iterator = process.readAnalyticsResults();
while (iterator.hasNext()) {
if (isCancelled) {
break;
}
AnalyticsResult result = iterator.next();
processResult(result, resultsJoiner);
if (result.getRowResults() != null) {
processedRows++;
progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows));
}
}
if (isProcessKilled.get() == false) {
if (isCancelled == false) {
// This means we completed successfully so we need to set the progress to 100.
// This is because due to skipped rows, it is possible the processed rows will not reach the total rows.
progressTracker.writingResultsPercent.set(100);
}
} catch (Exception e) {
if (isProcessKilled.get()) {
if (isCancelled) {
// No need to log error as it's due to stopping
} else {
LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e);
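In AnalyticsResultProcessor, the Supplier<Boolean> isProcessKilled dependency on the process context is gone; the processor now owns a volatile isCancelled flag and exposes cancel(), which ProcessContext.stop() calls alongside dataExtractor.cancel(). A minimal sketch of this cooperative-cancellation shape, with simplified types:

import java.util.Iterator;

class ResultConsumer {

    private volatile boolean isCancelled;

    void cancel() {
        isCancelled = true; // visible to the consuming thread without extra locking
    }

    void consume(Iterator<String> results) {
        while (results.hasNext()) {
            if (isCancelled) {
                break; // stop promptly; the producer is being torn down anyway
            }
            handle(results.next());
        }
    }

    private void handle(String result) {
        // ... process one result ...
    }
}

Checking the flag at the top of the loop, as the diff does, means cancellation takes effect between results rather than mid-result, and the catch block can stay quiet when the failure was caused by a deliberate stop.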