diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 455d4425ea92a..87b01af8ef889 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -239,7 +240,6 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception "Finished analysis"); } - @AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/49095") public void testStopAndRestart() throws Exception { initialize("regression_stop_and_restart"); @@ -270,8 +270,12 @@ public void testStopAndRestart() throws Exception { // Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED. assertBusy(() -> { DataFrameAnalyticsState state = getAnalyticsStats(jobId).getState(); - assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING), - equalTo(DataFrameAnalyticsState.STOPPED)))); + assertThat( + state, + is(anyOf( + equalTo(DataFrameAnalyticsState.REINDEXING), + equalTo(DataFrameAnalyticsState.ANALYZING), + equalTo(DataFrameAnalyticsState.STOPPED)))); }); stopAnalytics(jobId); waitUntilAnalyticsIsStopped(jobId); @@ -287,7 +291,7 @@ public void testStopAndRestart() throws Exception { } } - waitUntilAnalyticsIsStopped(jobId); + waitUntilAnalyticsIsStopped(jobId, TimeValue.timeValueMinutes(1)); SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get(); for (SearchHit hit : sourceData.getHits()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 2fe5004aabcd6..0a2c6440bf0cc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.admin.indices.refresh.RefreshAction; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.client.Client; import org.elasticsearch.common.Nullable; @@ -54,7 +53,8 @@ public class AnalyticsProcessManager { private static final Logger LOGGER = LogManager.getLogger(AnalyticsProcessManager.class); private final Client client; - private final ThreadPool threadPool; + private final ExecutorService executorServiceForJob; + private final ExecutorService executorServiceForProcess; private final AnalyticsProcessFactory processFactory; private final ConcurrentMap processContextByAllocation = new ConcurrentHashMap<>(); private final DataFrameAnalyticsAuditor auditor; @@ -65,8 +65,25 @@ public AnalyticsProcessManager(Client client, AnalyticsProcessFactory analyticsProcessFactory, DataFrameAnalyticsAuditor auditor, TrainedModelProvider trainedModelProvider) { + this( + client, + threadPool.generic(), + threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), + analyticsProcessFactory, + auditor, + trainedModelProvider); + } + + // Visible for testing + public AnalyticsProcessManager(Client client, + ExecutorService executorServiceForJob, + ExecutorService executorServiceForProcess, + AnalyticsProcessFactory analyticsProcessFactory, + DataFrameAnalyticsAuditor auditor, + TrainedModelProvider trainedModelProvider) { this.client = Objects.requireNonNull(client); - this.threadPool = Objects.requireNonNull(threadPool); + this.executorServiceForJob = Objects.requireNonNull(executorServiceForJob); + this.executorServiceForProcess = Objects.requireNonNull(executorServiceForProcess); this.processFactory = Objects.requireNonNull(analyticsProcessFactory); this.auditor = Objects.requireNonNull(auditor); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); @@ -74,31 +91,33 @@ public AnalyticsProcessManager(Client client, public void runJob(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, DataFrameDataExtractorFactory dataExtractorFactory, Consumer finishHandler) { - threadPool.generic().execute(() -> { - if (task.isStopping()) { - // The task was requested to stop before we created the process context - finishHandler.accept(null); - return; + executorServiceForJob.execute(() -> { + ProcessContext processContext = new ProcessContext(config.getId()); + synchronized (this) { + if (task.isStopping()) { + // The task was requested to stop before we created the process context + finishHandler.accept(null); + return; + } + if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { + finishHandler.accept( + ExceptionsHelper.serverError("[" + config.getId() + "] Could not create process as one already exists")); + return; + } } - // First we refresh the dest index to ensure data is searchable + // Refresh the dest index to ensure data is searchable refreshDest(config); - ProcessContext processContext = new ProcessContext(config.getId()); - if (processContextByAllocation.putIfAbsent(task.getAllocationId(), processContext) != null) { - finishHandler.accept(ExceptionsHelper.serverError("[" + processContext.id - + "] Could not create process as one already exists")); - return; - } - + // Fetch existing model state (if any) BytesReference state = getModelState(config); if (processContext.startProcess(dataExtractorFactory, config, task, state)) { - ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); - executorService.execute(() -> processResults(processContext)); - executorService.execute(() -> processData(task, config, processContext.dataExtractor, + executorServiceForProcess.execute(() -> processResults(processContext)); + executorServiceForProcess.execute(() -> processData(task, config, processContext.dataExtractor, processContext.process, processContext.resultProcessor, finishHandler, state)); } else { + processContextByAllocation.remove(task.getAllocationId()); finishHandler.accept(null); } }); @@ -111,8 +130,6 @@ private BytesReference getModelState(DataFrameAnalyticsConfig config) { } try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) { - SearchRequest searchRequest = new SearchRequest(AnomalyDetectorsIndex.jobStateIndexPattern()); - searchRequest.source().size(1).query(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId()))); SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern()) .setSize(1) .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId()))) @@ -246,9 +263,8 @@ private void restoreState(DataFrameAnalyticsConfig config, @Nullable BytesRefere private AnalyticsProcess createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) { - ExecutorService executorService = threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME); - AnalyticsProcess process = processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, - executorService, onProcessCrash(task)); + AnalyticsProcess process = + processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task)); if (process.isProcessAlive() == false) { throw ExceptionsHelper.serverError("Failed to start data frame analytics process"); } @@ -285,17 +301,22 @@ private void closeProcess(DataFrameAnalyticsTask task) { } } - public void stop(DataFrameAnalyticsTask task) { + public synchronized void stop(DataFrameAnalyticsTask task) { ProcessContext processContext = processContextByAllocation.get(task.getAllocationId()); if (processContext != null) { - LOGGER.debug("[{}] Stopping process", task.getParams().getId() ); + LOGGER.debug("[{}] Stopping process", task.getParams().getId()); processContext.stop(); } else { - LOGGER.debug("[{}] No process context to stop", task.getParams().getId() ); + LOGGER.debug("[{}] No process context to stop", task.getParams().getId()); task.markAsCompleted(); } } + // Visible for testing + int getProcessContextCount() { + return processContextByAllocation.size(); + } + class ProcessContext { private final String id; @@ -309,31 +330,26 @@ class ProcessContext { this.id = Objects.requireNonNull(id); } - public String getId() { - return id; - } - - public boolean isProcessKilled() { - return processKilled; + synchronized String getFailureReason() { + return failureReason; } - private synchronized void setFailureReason(String failureReason) { + synchronized void setFailureReason(String failureReason) { // Only set the new reason if there isn't one already as we want to keep the first reason - if (failureReason != null) { + if (this.failureReason == null && failureReason != null) { this.failureReason = failureReason; } } - private String getFailureReason() { - return failureReason; - } - - public synchronized void stop() { + synchronized void stop() { LOGGER.debug("[{}] Stopping process", id); processKilled = true; if (dataExtractor != null) { dataExtractor.cancel(); } + if (resultProcessor != null) { + resultProcessor.cancel(); + } if (process != null) { try { process.kill(); @@ -346,8 +362,8 @@ public synchronized void stop() { /** * @return {@code true} if the process was started or {@code false} if it was not because it was stopped in the meantime */ - private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config, - DataFrameAnalyticsTask task, @Nullable BytesReference state) { + synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory, DataFrameAnalyticsConfig config, + DataFrameAnalyticsTask task, @Nullable BytesReference state) { if (processKilled) { // The job was stopped before we started the process so no need to start it return false; @@ -365,8 +381,8 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr process = createProcess(task, config, analyticsProcessConfig, state); DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); - resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(), - trainedModelProvider, auditor, dataExtractor.getFieldNames()); + resultProcessor = new AnalyticsResultProcessor( + config, dataFrameRowsJoiner, task.getProgressTracker(), trainedModelProvider, auditor, dataExtractor.getFieldNames()); return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 3abc3b5e43cce..8eb101652d332 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -31,7 +31,6 @@ import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; public class AnalyticsResultProcessor { @@ -39,21 +38,19 @@ public class AnalyticsResultProcessor { private final DataFrameAnalyticsConfig analytics; private final DataFrameRowsJoiner dataFrameRowsJoiner; - private final Supplier isProcessKilled; private final ProgressTracker progressTracker; private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; private final List fieldNames; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; + private volatile boolean isCancelled; public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, - Supplier isProcessKilled, ProgressTracker progressTracker, - TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor, - List fieldNames) { + ProgressTracker progressTracker, TrainedModelProvider trainedModelProvider, + DataFrameAnalyticsAuditor auditor, List fieldNames) { this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); - this.isProcessKilled = Objects.requireNonNull(isProcessKilled); this.progressTracker = Objects.requireNonNull(progressTracker); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); @@ -74,6 +71,10 @@ public void awaitForCompletion() { } } + public void cancel() { + isCancelled = true; + } + public void process(AnalyticsProcess process) { long totalRows = process.getConfig().rows(); long processedRows = 0; @@ -82,6 +83,9 @@ public void process(AnalyticsProcess process) { try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { + if (isCancelled) { + break; + } AnalyticsResult result = iterator.next(); processResult(result, resultsJoiner); if (result.getRowResults() != null) { @@ -89,13 +93,13 @@ public void process(AnalyticsProcess process) { progressTracker.writingResultsPercent.set(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); } } - if (isProcessKilled.get() == false) { + if (isCancelled == false) { // This means we completed successfully so we need to set the progress to 100. // This is because due to skipped rows, it is possible the processed rows will not reach the total rows. progressTracker.writingResultsPercent.set(100); } } catch (Exception e) { - if (isProcessKilled.get()) { + if (isCancelled) { // No need to log error as it's due to stopping } else { LOGGER.error(new ParameterizedMessage("[{}] Error parsing data frame analytics output", analytics.getId()), e); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java new file mode 100644 index 0000000000000..d86f0397319bb --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -0,0 +1,214 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfigTests; +import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; +import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; +import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyBoolean; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Test for the basic functionality of {@link AnalyticsProcessManager} and {@link AnalyticsProcessManager.ProcessContext}. + * This test does not spawn any threads. Instead: + * - job is run on a current thread (using {@code DirectExecutorService}) + * - {@code processData} and {@code processResults} methods are not run at all (using mock executor) + */ +public class AnalyticsProcessManagerTests extends ESTestCase { + + private static final long TASK_ALLOCATION_ID = 123; + private static final String CONFIG_ID = "config-id"; + private static final int NUM_ROWS = 100; + private static final int NUM_COLS = 4; + private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null); + + private Client client; + private DataFrameAnalyticsAuditor auditor; + private TrainedModelProvider trainedModelProvider; + private ExecutorService executorServiceForJob; + private ExecutorService executorServiceForProcess; + private AnalyticsProcess process; + private AnalyticsProcessFactory processFactory; + private DataFrameAnalyticsTask task; + private DataFrameAnalyticsConfig dataFrameAnalyticsConfig; + private DataFrameDataExtractorFactory dataExtractorFactory; + private DataFrameDataExtractor dataExtractor; + private Consumer finishHandler; + private ArgumentCaptor exceptionCaptor; + private AnalyticsProcessManager processManager; + + @SuppressWarnings("unchecked") + @Before + public void setUpMocks() { + ThreadPool threadPool = mock(ThreadPool.class); + when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY)); + client = mock(Client.class); + when(client.threadPool()).thenReturn(threadPool); + when(client.execute(any(), any())).thenReturn(mock(ActionFuture.class)); + executorServiceForJob = EsExecutors.newDirectExecutorService(); + executorServiceForProcess = mock(ExecutorService.class); + process = mock(AnalyticsProcess.class); + when(process.isProcessAlive()).thenReturn(true); + when(process.readAnalyticsResults()).thenReturn(List.of(PROCESS_RESULT).iterator()); + processFactory = mock(AnalyticsProcessFactory.class); + when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process); + auditor = mock(DataFrameAnalyticsAuditor.class); + trainedModelProvider = mock(TrainedModelProvider.class); + + task = mock(DataFrameAnalyticsTask.class); + when(task.getAllocationId()).thenReturn(TASK_ALLOCATION_ID); + when(task.getProgressTracker()).thenReturn(mock(DataFrameAnalyticsTask.ProgressTracker.class)); + dataFrameAnalyticsConfig = DataFrameAnalyticsConfigTests.createRandom(CONFIG_ID); + dataExtractor = mock(DataFrameDataExtractor.class); + when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS)); + dataExtractorFactory = mock(DataFrameDataExtractorFactory.class); + when(dataExtractorFactory.newExtractor(anyBoolean())).thenReturn(dataExtractor); + finishHandler = mock(Consumer.class); + + exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + processManager = new AnalyticsProcessManager( + client, executorServiceForJob, executorServiceForProcess, processFactory, auditor, trainedModelProvider); + } + + public void testRunJob_TaskIsStopping() { + when(task.isStopping()).thenReturn(true); + + processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler); + assertThat(processManager.getProcessContextCount(), equalTo(0)); + + verify(finishHandler).accept(null); + verifyNoMoreInteractions(finishHandler); + } + + public void testRunJob_ProcessContextAlreadyExists() { + processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler); + assertThat(processManager.getProcessContextCount(), equalTo(1)); + processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler); + assertThat(processManager.getProcessContextCount(), equalTo(1)); + + verify(finishHandler).accept(exceptionCaptor.capture()); + verifyNoMoreInteractions(finishHandler); + + Exception e = exceptionCaptor.getValue(); + assertThat(e.getMessage(), equalTo("[config-id] Could not create process as one already exists")); + } + + public void testRunJob_EmptyDataFrame() { + when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS)); + + processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler); + assertThat(processManager.getProcessContextCount(), equalTo(0)); // Make sure the process context did not leak + + InOrder inOrder = inOrder(dataExtractor, executorServiceForProcess, process, finishHandler); + inOrder.verify(dataExtractor).collectDataSummary(); + inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis()); + inOrder.verify(finishHandler).accept(null); + verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, finishHandler); + } + + public void testRunJob_Ok() { + processManager.runJob(task, dataFrameAnalyticsConfig, dataExtractorFactory, finishHandler); + assertThat(processManager.getProcessContextCount(), equalTo(1)); + + InOrder inOrder = inOrder(dataExtractor, executorServiceForProcess, process, finishHandler); + inOrder.verify(dataExtractor).collectDataSummary(); + inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis()); + inOrder.verify(process).isProcessAlive(); + inOrder.verify(dataExtractor).getFieldNames(); + inOrder.verify(executorServiceForProcess, times(2)).execute(any()); // 'processData' and 'processResults' threads + verifyNoMoreInteractions(dataExtractor, executorServiceForProcess, process, finishHandler); + } + + public void testProcessContext_GetSetFailureReason() { + AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID); + assertThat(processContext.getFailureReason(), is(nullValue())); + + processContext.setFailureReason("reason1"); + assertThat(processContext.getFailureReason(), equalTo("reason1")); + + processContext.setFailureReason(null); + assertThat(processContext.getFailureReason(), equalTo("reason1")); + + processContext.setFailureReason("reason2"); + assertThat(processContext.getFailureReason(), equalTo("reason1")); + + verifyNoMoreInteractions(dataExtractor, process, finishHandler); + } + + public void testProcessContext_StartProcess_ProcessAlreadyKilled() { + AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID); + processContext.stop(); + assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(false)); + + verifyNoMoreInteractions(dataExtractor, process, finishHandler); + } + + public void testProcessContext_StartProcess_EmptyDataFrame() { + when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS)); + + AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID); + assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(false)); + + InOrder inOrder = inOrder(dataExtractor, process, finishHandler); + inOrder.verify(dataExtractor).collectDataSummary(); + inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis()); + verifyNoMoreInteractions(dataExtractor, process, finishHandler); + } + + public void testProcessContext_StartAndStop() throws Exception { + AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID); + assertThat(processContext.startProcess(dataExtractorFactory, dataFrameAnalyticsConfig, task, null), is(true)); + processContext.stop(); + + InOrder inOrder = inOrder(dataExtractor, process, finishHandler); + // startProcess + inOrder.verify(dataExtractor).collectDataSummary(); + inOrder.verify(dataExtractor).getCategoricalFields(dataFrameAnalyticsConfig.getAnalysis()); + inOrder.verify(process).isProcessAlive(); + inOrder.verify(dataExtractor).getFieldNames(); + // stop + inOrder.verify(dataExtractor).cancel(); + inOrder.verify(process).kill(); + verifyNoMoreInteractions(dataExtractor, process, finishHandler); + } + + public void testProcessContext_Stop() { + AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(CONFIG_ID); + processContext.stop(); + + verifyNoMoreInteractions(dataExtractor, process, finishHandler); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index bdccdf8c6722f..79d0ee3fccbd5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -200,7 +200,7 @@ private AnalyticsResultProcessor createResultProcessor() { } private AnalyticsResultProcessor createResultProcessor(List fieldNames) { - return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider, - auditor, fieldNames); + return new AnalyticsResultProcessor( + analyticsConfig, dataFrameRowsJoiner, progressTracker, trainedModelProvider, auditor, fieldNames); } }