From 42baf3faec0c378acfd9d3f75ea9e9eb79728034 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Thu, 9 Apr 2020 13:12:45 +0100 Subject: [PATCH] Fix non-deterministic behaviour in ModelLoadingServiceTests --- .../loadingservice/ModelLoadingService.java | 6 +- .../ModelLoadingServiceTests.java | 59 ++++++++++++++----- 2 files changed, 46 insertions(+), 19 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index c092f7873b067..34b4e4607917a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -320,7 +320,7 @@ public void clusterChanged(ClusterChangedEvent event) { // Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels allReferencedModelKeys.removeAll(referencedModels); referencedModels.addAll(allReferencedModelKeys); - + // Populate loadingListeners key so we know that we are currently loading the model for (String modelId : allReferencedModelKeys) { loadingListeners.put(modelId, new ArrayDeque<>()); @@ -353,9 +353,9 @@ private void auditIfNecessary(String modelId, MessageSupplier msg) { logger.trace(() -> new ParameterizedMessage("[{}] {}", modelId, msg.get().getFormattedMessage())); return; } - auditor.warning(modelId, msg.get().getFormattedMessage()); + auditor.info(modelId, msg.get().getFormattedMessage()); shouldNotAudit.add(modelId); - logger.warn("[{}] {}", modelId, msg.get().getFormattedMessage()); + logger.info("[{}] {}", modelId, msg.get().getFormattedMessage()); } private void loadModels(Set modelIds) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index 65f352282a1c4..d1733526287df 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -36,9 +36,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.After; import org.junit.Before; +import org.mockito.ArgumentMatcher; import org.mockito.Mockito; import java.io.IOException; @@ -144,13 +145,13 @@ public void testGetCachedModels() throws Exception { verify(trainedModelProvider, times(4)).getTrainedModel(eq(model3), eq(true), any()); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54986") public void testMaxCachedLimitReached() throws Exception { String model1 = "test-cached-limit-load-model-1"; String model2 = "test-cached-limit-load-model-2"; String model3 = "test-cached-limit-load-model-3"; + String[] modelIds = new String[]{model1, model2, model3}; withTrainedModel(model1, 10L); - withTrainedModel(model2, 5L); + withTrainedModel(model2, 6L); withTrainedModel(model3, 15L); ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, @@ -164,15 +165,15 @@ public void testMaxCachedLimitReached() throws Exception { modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); - // Should have been loaded from the cluster change event - // Verify that we have at least loaded all three so that evictions occur in the following loop + // Should have been loaded from the cluster change event but it is unknown in what order + // the loading occurred or which models are currently in the cache due to evictions. + // Verify that we have at least loaded all three assertBusy(() -> { verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(true), any()); verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); }); - String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { // Only reference models 1 and 2, so that cache is only invalidated once for model3 (after initial load) String model = modelIds[i%2]; @@ -181,28 +182,55 @@ public void testMaxCachedLimitReached() throws Exception { assertThat(future.get(), is(not(nullValue()))); } - verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model1), eq(true), any()); - verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model2), eq(true), any()); + verify(trainedModelProvider, times(2)).getTrainedModel(eq(model1), eq(true), any()); + verify(trainedModelProvider, times(2)).getTrainedModel(eq(model2), eq(true), any()); // Only loaded requested once on the initial load from the change event verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(true), any()); - // Load model 3, should invalidate 1 + // model 3 has been loaded and evicted exactly once + verify(trainedModelStatsService, times(1)).queueStats(argThat(new ArgumentMatcher<>() { + @Override + public boolean matches(final Object o) { + return ((InferenceStats)o).getModelId().equals(model3); + } + })); + + // Load model 3, should invalidate 1 and 2 for(int i = 0; i < 10; i++) { PlainActionFuture> future3 = new PlainActionFuture<>(); modelLoadingService.getModel(model3, future3); assertThat(future3.get(), is(not(nullValue()))); } - verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model3), eq(true), any()); - - // Load model 1, should invalidate 2 + verify(trainedModelProvider, times(2)).getTrainedModel(eq(model3), eq(true), any()); + + verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<>() { + @Override + public boolean matches(final Object o) { + return ((InferenceStats)o).getModelId().equals(model1); + } + })); + verify(trainedModelStatsService, atMost(2)).queueStats(argThat(new ArgumentMatcher<>() { + @Override + public boolean matches(final Object o) { + return ((InferenceStats)o).getModelId().equals(model2); + } + })); + + // Load model 1, should invalidate 3 for(int i = 0; i < 10; i++) { PlainActionFuture> future1 = new PlainActionFuture<>(); modelLoadingService.getModel(model1, future1); assertThat(future1.get(), is(not(nullValue()))); } verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); - - // Load model 2, should invalidate 3 + verify(trainedModelStatsService, times(2)).queueStats(argThat(new ArgumentMatcher<>() { + @Override + public boolean matches(final Object o) { + return ((InferenceStats)o).getModelId().equals(model3); + } + })); + + // Load model 2 for(int i = 0; i < 10; i++) { PlainActionFuture> future2 = new PlainActionFuture<>(); modelLoadingService.getModel(model2, future2); @@ -210,7 +238,6 @@ public void testMaxCachedLimitReached() throws Exception { } verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); - // Test invalidate cache for model3 // Now both model 1 and 2 should fit in cache without issues modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2)); @@ -223,7 +250,7 @@ public void testMaxCachedLimitReached() throws Exception { verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model1), eq(true), any()); verify(trainedModelProvider, atMost(3)).getTrainedModel(eq(model2), eq(true), any()); - verify(trainedModelProvider, Mockito.atLeast(4)).getTrainedModel(eq(model3), eq(true), any()); + verify(trainedModelProvider, Mockito.atLeast(5)).getTrainedModel(eq(model3), eq(true), any()); verify(trainedModelProvider, atMost(5)).getTrainedModel(eq(model3), eq(true), any()); }