From 6086aac568d8dbafe4904118cca3dd74c9b19717 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 17 Sep 2024 16:29:41 +0100 Subject: [PATCH] Revert "[8.15] [ML] Downloaded and write model parts using multiple streams (#112869)" (#113017) This reverts commit f48a1c60a8e24aed95ceee5736c5620a9220f558. --- docs/changelog/111684.yaml | 5 - docs/changelog/112869.yaml | 5 - .../core/common/notifications/Level.java | 20 +- .../MachineLearningPackageLoader.java | 28 +- .../packageloader/action/ModelImporter.java | 313 ++++------------- .../action/ModelLoaderUtils.java | 148 +------- ...ortGetTrainedModelPackageConfigAction.java | 2 +- .../TransportLoadTrainedModelPackage.java | 95 +++--- .../MachineLearningPackageLoaderTests.java | 12 - .../action/ModelDownloadTaskTests.java | 20 +- .../action/ModelImporterTests.java | 316 ------------------ .../action/ModelLoaderUtilsTests.java | 40 +-- ...TransportLoadTrainedModelPackageTests.java | 76 ++--- 13 files changed, 178 insertions(+), 902 deletions(-) delete mode 100644 docs/changelog/111684.yaml delete mode 100644 docs/changelog/112869.yaml delete mode 100644 x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java diff --git a/docs/changelog/111684.yaml b/docs/changelog/111684.yaml deleted file mode 100644 index 32edb5723cb0a..0000000000000 --- a/docs/changelog/111684.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 111684 -summary: Write downloaded model parts async -area: Machine Learning -type: enhancement -issues: [] diff --git a/docs/changelog/112869.yaml b/docs/changelog/112869.yaml deleted file mode 100644 index cff88c5218bd9..0000000000000 --- a/docs/changelog/112869.yaml +++ /dev/null @@ -1,5 +0,0 @@ -pr: 112869 -summary: "[8.15] [ML] Downloaded and write model parts using multiple streams" -area: Machine Learning -type: enhancement -issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/Level.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/Level.java index f559370350972..2db973f8122c1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/Level.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/common/notifications/Level.java @@ -9,23 +9,9 @@ import java.util.Locale; public enum Level { - INFO { - public org.apache.logging.log4j.Level log4jLevel() { - return org.apache.logging.log4j.Level.INFO; - } - }, - WARNING { - public org.apache.logging.log4j.Level log4jLevel() { - return org.apache.logging.log4j.Level.WARN; - } - }, - ERROR { - public org.apache.logging.log4j.Level log4jLevel() { - return org.apache.logging.log4j.Level.ERROR; - } - }; - - public abstract org.apache.logging.log4j.Level log4jLevel(); + INFO, + WARNING, + ERROR; /** * Case-insensitive from string method. diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java index a63d911e9d40d..e927c46e6bd29 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoader.java @@ -15,17 +15,12 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ExecutorBuilder; -import org.elasticsearch.threadpool.FixedExecutorBuilder; import org.elasticsearch.xpack.core.ml.packageloader.action.GetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.ml.packageloader.action.ModelDownloadTask; -import org.elasticsearch.xpack.ml.packageloader.action.ModelImporter; import org.elasticsearch.xpack.ml.packageloader.action.TransportGetTrainedModelPackageConfigAction; import org.elasticsearch.xpack.ml.packageloader.action.TransportLoadTrainedModelPackage; @@ -49,6 +44,9 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Setting.Property.Dynamic ); + // re-using thread pool setup by the ml plugin + public static final String UTILITY_THREAD_POOL_NAME = "ml_utility"; + // This link will be invalid for serverless, but serverless will never be // air-gapped, so this message should never be needed. private static final String MODEL_REPOSITORY_DOCUMENTATION_LINK = format( @@ -56,8 +54,6 @@ public class MachineLearningPackageLoader extends Plugin implements ActionPlugin Build.current().version().replaceFirst("^(\\d+\\.\\d+).*", "$1") ); - public static final String MODEL_DOWNLOAD_THREADPOOL_NAME = "model_download"; - public MachineLearningPackageLoader() {} @Override @@ -85,24 +81,6 @@ public List getNamedWriteables() { ); } - @Override - public List> getExecutorBuilders(Settings settings) { - return List.of(modelDownloadExecutor(settings)); - } - - public static FixedExecutorBuilder modelDownloadExecutor(Settings settings) { - // Threadpool with a fixed number of threads for - // downloading the model definition files - return new FixedExecutorBuilder( - settings, - MODEL_DOWNLOAD_THREADPOOL_NAME, - ModelImporter.NUMBER_OF_STREAMS, - -1, // unbounded queue size - "xpack.ml.model_download_thread_pool", - EsExecutors.TaskTrackingConfig.DO_NOT_TRACK - ); - } - @Override public List getBootstrapChecks() { return List.of(new BootstrapCheck() { diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java index 86711804ed03c..33d5d5982d2b0 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporter.java @@ -10,248 +10,124 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.RefCountingListener; -import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesArray; -import org.elasticsearch.core.Nullable; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.TaskCancelledException; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; -import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; +import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.List; import java.util.Objects; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.core.Strings.format; /** - * For downloading and the vocabulary and model definition file and - * indexing those files in Elasticsearch. - * Holding the large model definition file in memory will consume - * too much memory, instead it is streamed in chunks and each chunk - * written to the index in a non-blocking request. - * The model files may be installed from a local file or download - * from a server. The server download uses {@link #NUMBER_OF_STREAMS} - * connections each using the Range header to split the stream by byte - * range. There is a complication in that the final part of the model - * definition must be uploaded last as writing this part causes an index - * refresh. - * When read from file a single thread is used to read the file - * stream, split into chunks and index those chunks. + * A helper class for abstracting out the use of the ModelLoaderUtils to make dependency injection testing easier. */ -public class ModelImporter { +class ModelImporter { private static final int DEFAULT_CHUNK_SIZE = 1024 * 1024; // 1MB - public static final int NUMBER_OF_STREAMS = 5; private static final Logger logger = LogManager.getLogger(ModelImporter.class); private final Client client; private final String modelId; private final ModelPackageConfig config; private final ModelDownloadTask task; - private final ExecutorService executorService; - private final AtomicInteger progressCounter = new AtomicInteger(); - private final URI uri; - ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task, ThreadPool threadPool) - throws URISyntaxException { + ModelImporter(Client client, String modelId, ModelPackageConfig packageConfig, ModelDownloadTask task) { this.client = client; this.modelId = Objects.requireNonNull(modelId); this.config = Objects.requireNonNull(packageConfig); this.task = Objects.requireNonNull(task); - this.executorService = threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME); - this.uri = ModelLoaderUtils.resolvePackageLocation( - config.getModelRepository(), - config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION - ); } - public void doImport(ActionListener listener) { - executorService.execute(() -> doImportInternal(listener)); - } + public void doImport() throws URISyntaxException, IOException, ElasticsearchStatusException { + long size = config.getSize(); - private void doImportInternal(ActionListener finalListener) { - assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) - : format( - "Model download must execute from [%s] but thread is [%s]", - MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, - Thread.currentThread().getName() - ); + // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the + // download is complete + if (Strings.isNullOrEmpty(config.getVocabularyFile()) == false) { + uploadVocabulary(); - ModelLoaderUtils.VocabularyParts vocabularyParts = null; - try { - if (config.getVocabularyFile() != null) { - vocabularyParts = ModelLoaderUtils.loadVocabulary( - ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) - ); - } + logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); + } - // simple round up - int totalParts = (int) ((config.getSize() + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); + URI uri = ModelLoaderUtils.resolvePackageLocation( + config.getModelRepository(), + config.getPackagedModelId() + ModelLoaderUtils.MODEL_FILE_EXTENSION + ); - if (ModelLoaderUtils.uriIsFile(uri) == false) { - var ranges = ModelLoaderUtils.split(config.getSize(), NUMBER_OF_STREAMS, DEFAULT_CHUNK_SIZE); - var downloaders = new ArrayList(ranges.size()); - for (var range : ranges) { - downloaders.add(new ModelLoaderUtils.HttpStreamChunker(uri, range, DEFAULT_CHUNK_SIZE)); - } - downloadModelDefinition(config.getSize(), totalParts, vocabularyParts, downloaders, finalListener); - } else { - InputStream modelInputStream = ModelLoaderUtils.getFileInputStream(uri); - ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker( - modelInputStream, - DEFAULT_CHUNK_SIZE - ); - readModelDefinitionFromFile(config.getSize(), totalParts, chunkIterator, vocabularyParts, finalListener); - } - } catch (Exception e) { - finalListener.onFailure(e); - return; - } - } + InputStream modelInputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); - void downloadModelDefinition( - long size, - int totalParts, - @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, - List downloaders, - ActionListener finalListener - ) { - try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { - var finalDownloader = downloaders.get(downloaders.size() - 1); - downloadFinalPart(size, totalParts, finalDownloader, finalListener.delegateFailureAndWrap((l, r) -> { - checkDownloadComplete(downloaders); - l.onResponse(AcknowledgedResponse.TRUE); - })); - }), finalListener::onFailure))) { - // Uploading other artefacts of the model first, that way the model is last and a simple search can be used to check if the - // download is complete - if (vocabularyParts != null) { - uploadVocabulary(vocabularyParts, countingListener); - } + ModelLoaderUtils.InputStreamChunker chunkIterator = new ModelLoaderUtils.InputStreamChunker(modelInputStream, DEFAULT_CHUNK_SIZE); - // Download all but the final split. - // The final split is a single chunk - for (int streamSplit = 0; streamSplit < downloaders.size() - 1; ++streamSplit) { - final var downloader = downloaders.get(streamSplit); - var rangeDownloadedListener = countingListener.acquire(); // acquire to keep the counting listener from closing - executorService.execute( - () -> downloadPartInRange(size, totalParts, downloader, executorService, countingListener, rangeDownloadedListener) - ); - } - } - } + // simple round up + int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE); - private void downloadPartInRange( - long size, - int totalParts, - ModelLoaderUtils.HttpStreamChunker downloadChunker, - ExecutorService executorService, - RefCountingListener countingListener, - ActionListener rangeFullyDownloadedListener - ) { - assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) - : format( - "Model download must execute from [%s] but thread is [%s]", - MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, - Thread.currentThread().getName() + for (int part = 0; part < totalParts - 1; ++part) { + task.setProgress(totalParts, part); + BytesArray definition = chunkIterator.next(); + + PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( + modelId, + definition, + part, + size, + totalParts, + true ); - if (countingListener.isFailing()) { - rangeFullyDownloadedListener.onResponse(null); // the error has already been reported elsewhere - return; + executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest); } - try { - throwIfTaskCancelled(); - var bytesAndIndex = downloadChunker.next(); - task.setProgress(totalParts, progressCounter.getAndIncrement()); - - indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes(), countingListener.acquire(ack -> {})); - } catch (Exception e) { - rangeFullyDownloadedListener.onFailure(e); - return; - } + // get the last part, this time verify the checksum and size + BytesArray definition = chunkIterator.next(); - if (downloadChunker.hasNext()) { - executorService.execute( - () -> downloadPartInRange( - size, - totalParts, - downloadChunker, - executorService, - countingListener, - rangeFullyDownloadedListener - ) + if (config.getSha256().equals(chunkIterator.getSha256()) == false) { + String message = format( + "Model sha256 checksums do not match, expected [%s] but got [%s]", + config.getSha256(), + chunkIterator.getSha256() ); - } else { - rangeFullyDownloadedListener.onResponse(null); + + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); } - } - private void downloadFinalPart( - long size, - int totalParts, - ModelLoaderUtils.HttpStreamChunker downloader, - ActionListener lastPartWrittenListener - ) { - assert ThreadPool.assertCurrentThreadPool(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME) - : format( - "Model download must execute from [%s] but thread is [%s]", - MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME, - Thread.currentThread().getName() + if (config.getSize() != chunkIterator.getTotalBytesRead()) { + String message = format( + "Model size does not match, expected [%d] but got [%d]", + config.getSize(), + chunkIterator.getTotalBytesRead() ); - try { - var bytesAndIndex = downloader.next(); - task.setProgress(totalParts, progressCounter.getAndIncrement()); - - indexPart(bytesAndIndex.partIndex(), totalParts, size, bytesAndIndex.bytes(), lastPartWrittenListener); - } catch (Exception e) { - lastPartWrittenListener.onFailure(e); + throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); } - } - - void readModelDefinitionFromFile( - long size, - int totalParts, - ModelLoaderUtils.InputStreamChunker chunkIterator, - @Nullable ModelLoaderUtils.VocabularyParts vocabularyParts, - ActionListener finalListener - ) { - try (var countingListener = new RefCountingListener(1, ActionListener.wrap(ignore -> executorService.execute(() -> { - finalListener.onResponse(AcknowledgedResponse.TRUE); - }), finalListener::onFailure))) { - try { - if (vocabularyParts != null) { - uploadVocabulary(vocabularyParts, countingListener); - } - for (int part = 0; part < totalParts; ++part) { - throwIfTaskCancelled(); - task.setProgress(totalParts, part); - BytesArray definition = chunkIterator.next(); - indexPart(part, totalParts, size, definition, countingListener.acquire(ack -> {})); - } - task.setProgress(totalParts, totalParts); + PutTrainedModelDefinitionPartAction.Request finalModelPartRequest = new PutTrainedModelDefinitionPartAction.Request( + modelId, + definition, + totalParts - 1, + size, + totalParts, + true + ); - checkDownloadComplete(chunkIterator, totalParts); - } catch (Exception e) { - countingListener.acquire().onFailure(e); - } - } + executeRequestIfNotCancelled(PutTrainedModelDefinitionPartAction.INSTANCE, finalModelPartRequest); + logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); } - private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, RefCountingListener countingListener) { + private void uploadVocabulary() throws URISyntaxException { + ModelLoaderUtils.VocabularyParts vocabularyParts = ModelLoaderUtils.loadVocabulary( + ModelLoaderUtils.resolvePackageLocation(config.getModelRepository(), config.getVocabularyFile()) + ); + PutTrainedModelVocabularyAction.Request request = new PutTrainedModelVocabularyAction.Request( modelId, vocabularyParts.vocab(), @@ -260,58 +136,17 @@ private void uploadVocabulary(ModelLoaderUtils.VocabularyParts vocabularyParts, true ); - client.execute(PutTrainedModelVocabularyAction.INSTANCE, request, countingListener.acquire(r -> { - logger.debug(() -> format("[%s] imported model vocabulary [%s]", modelId, config.getVocabularyFile())); - })); - } - - private void indexPart(int partIndex, int totalParts, long totalSize, BytesArray bytes, ActionListener listener) { - PutTrainedModelDefinitionPartAction.Request modelPartRequest = new PutTrainedModelDefinitionPartAction.Request( - modelId, - bytes, - partIndex, - totalSize, - totalParts, - true - ); - - client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, modelPartRequest, listener); - } - - private void checkDownloadComplete(List downloaders) { - long totalBytesRead = downloaders.stream().mapToLong(ModelLoaderUtils.HttpStreamChunker::getTotalBytesRead).sum(); - int totalParts = downloaders.stream().mapToInt(ModelLoaderUtils.HttpStreamChunker::getCurrentPart).sum(); - checkSize(totalBytesRead); - logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); + executeRequestIfNotCancelled(PutTrainedModelVocabularyAction.INSTANCE, request); } - private void checkDownloadComplete(ModelLoaderUtils.InputStreamChunker fileInputStream, int totalParts) { - checkSha256(fileInputStream.getSha256()); - checkSize(fileInputStream.getTotalBytesRead()); - logger.debug(format("finished importing model [%s] using [%d] parts", modelId, totalParts)); - } - - private void checkSha256(String sha256) { - if (config.getSha256().equals(sha256) == false) { - String message = format("Model sha256 checksums do not match, expected [%s] but got [%s]", config.getSha256(), sha256); - - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); - } - } - - private void checkSize(long definitionSize) { - if (config.getSize() != definitionSize) { - String message = format("Model size does not match, expected [%d] but got [%d]", config.getSize(), definitionSize); - throw new ElasticsearchStatusException(message, RestStatus.INTERNAL_SERVER_ERROR); - } - } - - private void throwIfTaskCancelled() { + private void executeRequestIfNotCancelled( + ActionType action, + Request request + ) { if (task.isCancelled()) { - logger.info("Model [{}] download task cancelled", modelId); - throw new TaskCancelledException( - format("Model [%s] download task cancelled with reason [%s]", modelId, task.getReasonCancelled()) - ); + throw new TaskCancelledException(format("task cancelled with reason [%s]", task.getReasonCancelled())); } + + client.execute(action, request).actionGet(); } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java index 42bfbb249b623..2f3f9cbf3f32c 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtils.java @@ -17,7 +17,6 @@ import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.core.Nullable; import org.elasticsearch.core.SuppressForbidden; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.XContentParser; @@ -35,20 +34,16 @@ import java.security.AccessController; import java.security.MessageDigest; import java.security.PrivilegedAction; -import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import static java.net.HttpURLConnection.HTTP_MOVED_PERM; import static java.net.HttpURLConnection.HTTP_MOVED_TEMP; import static java.net.HttpURLConnection.HTTP_NOT_FOUND; import static java.net.HttpURLConnection.HTTP_OK; -import static java.net.HttpURLConnection.HTTP_PARTIAL; import static java.net.HttpURLConnection.HTTP_SEE_OTHER; /** @@ -66,73 +61,6 @@ final class ModelLoaderUtils { record VocabularyParts(List vocab, List merges, List scores) {} - // Range in bytes - record RequestRange(long rangeStart, long rangeEnd, int startPart, int numParts) { - public String bytesRange() { - return "bytes=" + rangeStart + "-" + rangeEnd; - } - } - - static class HttpStreamChunker { - - record BytesAndPartIndex(BytesArray bytes, int partIndex) {} - - private final InputStream inputStream; - private final int chunkSize; - private final AtomicLong totalBytesRead = new AtomicLong(); - private final AtomicInteger currentPart; - private final int lastPartNumber; - - HttpStreamChunker(URI uri, RequestRange range, int chunkSize) { - var inputStream = getHttpOrHttpsInputStream(uri, range); - this.inputStream = inputStream; - this.chunkSize = chunkSize; - this.lastPartNumber = range.startPart() + range.numParts(); - this.currentPart = new AtomicInteger(range.startPart()); - } - - // This ctor exists for testing purposes only. - HttpStreamChunker(InputStream inputStream, RequestRange range, int chunkSize) { - this.inputStream = inputStream; - this.chunkSize = chunkSize; - this.lastPartNumber = range.startPart() + range.numParts(); - this.currentPart = new AtomicInteger(range.startPart()); - } - - public boolean hasNext() { - return currentPart.get() < lastPartNumber; - } - - public BytesAndPartIndex next() throws IOException { - int bytesRead = 0; - byte[] buf = new byte[chunkSize]; - - while (bytesRead < chunkSize) { - int read = inputStream.read(buf, bytesRead, chunkSize - bytesRead); - // EOF?? - if (read == -1) { - break; - } - bytesRead += read; - } - - if (bytesRead > 0) { - totalBytesRead.addAndGet(bytesRead); - return new BytesAndPartIndex(new BytesArray(buf, 0, bytesRead), currentPart.getAndIncrement()); - } else { - return new BytesAndPartIndex(BytesArray.EMPTY, currentPart.get()); - } - } - - public long getTotalBytesRead() { - return totalBytesRead.get(); - } - - public int getCurrentPart() { - return currentPart.get(); - } - } - static class InputStreamChunker { private final InputStream inputStream; @@ -173,14 +101,14 @@ public int getTotalBytesRead() { } } - static InputStream getInputStreamFromModelRepository(URI uri) { + static InputStream getInputStreamFromModelRepository(URI uri) throws IOException { String scheme = uri.getScheme().toLowerCase(Locale.ROOT); // if you add a scheme here, also add it to the bootstrap check in {@link MachineLearningPackageLoader#validateModelRepository} switch (scheme) { case "http": case "https": - return getHttpOrHttpsInputStream(uri, null); + return getHttpOrHttpsInputStream(uri); case "file": return getFileInputStream(uri); default: @@ -188,11 +116,6 @@ static InputStream getInputStreamFromModelRepository(URI uri) { } } - static boolean uriIsFile(URI uri) { - String scheme = uri.getScheme().toLowerCase(Locale.ROOT); - return "file".equals(scheme); - } - static VocabularyParts loadVocabulary(URI uri) { if (uri.getPath().endsWith(".json")) { try (InputStream vocabInputStream = getInputStreamFromModelRepository(uri)) { @@ -251,7 +174,7 @@ private ModelLoaderUtils() {} @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need socket connection to download") - private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestRange range) { + private static InputStream getHttpOrHttpsInputStream(URI uri) throws IOException { assert uri.getUserInfo() == null : "URI's with credentials are not supported"; @@ -263,30 +186,18 @@ private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestR PrivilegedAction privilegedHttpReader = () -> { try { HttpURLConnection conn = (HttpURLConnection) uri.toURL().openConnection(); - if (range != null) { - conn.setRequestProperty("Range", range.bytesRange()); - } switch (conn.getResponseCode()) { case HTTP_OK: - case HTTP_PARTIAL: return conn.getInputStream(); - case HTTP_MOVED_PERM: case HTTP_MOVED_TEMP: case HTTP_SEE_OTHER: throw new IllegalStateException("redirects aren't supported yet"); case HTTP_NOT_FOUND: throw new ResourceNotFoundException("{} not found", uri); - case 416: // Range not satisfiable, for some reason not in the list of constants - throw new IllegalStateException("Invalid request range [" + range.bytesRange() + "]"); default: int responseCode = conn.getResponseCode(); - throw new ElasticsearchStatusException( - "error during downloading {}. Got response code {}", - RestStatus.fromCode(responseCode), - uri, - responseCode - ); + throw new ElasticsearchStatusException("error during downloading {}", RestStatus.fromCode(responseCode), uri); } } catch (IOException e) { throw new UncheckedIOException(e); @@ -298,7 +209,7 @@ private static InputStream getHttpOrHttpsInputStream(URI uri, @Nullable RequestR @SuppressWarnings("'java.lang.SecurityManager' is deprecated and marked for removal ") @SuppressForbidden(reason = "we need load model data from a file") - static InputStream getFileInputStream(URI uri) { + private static InputStream getFileInputStream(URI uri) { SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -321,53 +232,4 @@ static InputStream getFileInputStream(URI uri) { return AccessController.doPrivileged(privilegedFileReader); } - /** - * Split a stream of size {@code sizeInBytes} into {@code numberOfStreams} +1 - * ranges aligned on {@code chunkSizeBytes} boundaries. Each range contains a - * whole number of chunks. - * The first {@code numberOfStreams} ranges will be split evenly (in terms of - * number of chunks not the byte size), the final range split - * is for the single final chunk and will be no more than {@code chunkSizeBytes} - * in size. The separate range for the final chunk is because when streaming and - * uploading a large model definition, writing the last part has to handled - * as a special case. - * @param sizeInBytes The total size of the stream - * @param numberOfStreams Divide the bulk of the size into this many streams. - * @param chunkSizeBytes The size of each chunk - * @return List of {@code numberOfStreams} + 1 ranges. - */ - static List split(long sizeInBytes, int numberOfStreams, long chunkSizeBytes) { - int numberOfChunks = (int) ((sizeInBytes + chunkSizeBytes - 1) / chunkSizeBytes); - - var ranges = new ArrayList(); - - int baseChunksPerStream = numberOfChunks / numberOfStreams; - int remainder = numberOfChunks % numberOfStreams; - long startOffset = 0; - int startChunkIndex = 0; - - for (int i = 0; i < numberOfStreams - 1; i++) { - int numChunksInStream = (i < remainder) ? baseChunksPerStream + 1 : baseChunksPerStream; - long rangeEnd = startOffset + (numChunksInStream * chunkSizeBytes) - 1; // range index is 0 based - ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksInStream)); - startOffset = rangeEnd + 1; // range is inclusive start and end - startChunkIndex += numChunksInStream; - } - - // Want the final range request to be a single chunk - if (baseChunksPerStream > 1) { - int numChunksExcludingFinal = baseChunksPerStream - 1; - long rangeEnd = startOffset + (numChunksExcludingFinal * chunkSizeBytes) - 1; - ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, numChunksExcludingFinal)); - - startOffset = rangeEnd + 1; - startChunkIndex += numChunksExcludingFinal; - } - - // The final range is a single chunk the end of which should not exceed sizeInBytes - long rangeEnd = Math.min(sizeInBytes, startOffset + (baseChunksPerStream * chunkSizeBytes)) - 1; - ranges.add(new RequestRange(startOffset, rangeEnd, startChunkIndex, 1)); - - return ranges; - } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java index d51fc236f19fd..6cdeb93d1e07d 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java @@ -77,7 +77,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A String packagedModelId = request.getPackagedModelId(); logger.debug(() -> format("Fetch package manifest for [%s] from [%s]", packagedModelId, repository)); - threadPool.executor(MachineLearningPackageLoader.MODEL_DOWNLOAD_THREADPOOL_NAME).execute(() -> { + threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME).execute(() -> { try { URI uri = ModelLoaderUtils.resolvePackageLocation(repository, packagedModelId + ModelLoaderUtils.METADATA_FILE_EXTENSION); InputStream inputStream = ModelLoaderUtils.getInputStreamFromModelRepository(uri); diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 689a411324eb3..b0544806d52bd 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -27,7 +27,6 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskAwareRequest; -import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.ThreadPool; @@ -37,12 +36,14 @@ import org.elasticsearch.xpack.core.ml.action.NodeAcknowledgedResponse; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction.Request; +import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; import java.util.Map; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -96,13 +97,11 @@ protected void masterOperation(Task task, Request request, ClusterState state, A parentTaskAssigningClient, request.getModelId(), request.getModelPackageConfig(), - downloadTask, - threadPool + downloadTask ); - var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.noop(); - - importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask); + threadPool.executor(MachineLearningPackageLoader.UTILITY_THREAD_POOL_NAME) + .execute(() -> importModel(client, taskManager, request, modelImporter, listener, downloadTask)); } catch (Exception e) { taskManager.unregister(downloadTask); listener.onFailure(e); @@ -136,38 +135,43 @@ static void importModel( ActionListener listener, Task task ) { - final String modelId = request.getModelId(); - final long relativeStartNanos = System.nanoTime(); + String modelId = request.getModelId(); + final AtomicReference exceptionRef = new AtomicReference<>(); + + try { + final long relativeStartNanos = System.nanoTime(); - logAndWriteNotificationAtLevel(auditClient, modelId, "starting model import", Level.INFO); + logAndWriteNotificationAtInfo(auditClient, modelId, "starting model import"); + + modelImporter.doImport(); - var finishListener = ActionListener.wrap(success -> { final long totalRuntimeNanos = System.nanoTime() - relativeStartNanos; - logAndWriteNotificationAtLevel( + logAndWriteNotificationAtInfo( auditClient, modelId, - format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)), - Level.INFO + format("finished model import after [%d] seconds", TimeUnit.NANOSECONDS.toSeconds(totalRuntimeNanos)) ); - listener.onResponse(AcknowledgedResponse.TRUE); - }, exception -> listener.onFailure(processException(auditClient, modelId, exception))); + } catch (ElasticsearchException e) { + recordError(auditClient, modelId, exceptionRef, e); + } catch (MalformedURLException e) { + recordError(auditClient, modelId, "an invalid URL", exceptionRef, e, RestStatus.INTERNAL_SERVER_ERROR); + } catch (URISyntaxException e) { + recordError(auditClient, modelId, "an invalid URL syntax", exceptionRef, e, RestStatus.INTERNAL_SERVER_ERROR); + } catch (IOException e) { + recordError(auditClient, modelId, "an IOException", exceptionRef, e, RestStatus.SERVICE_UNAVAILABLE); + } catch (Exception e) { + recordError(auditClient, modelId, "an Exception", exceptionRef, e, RestStatus.INTERNAL_SERVER_ERROR); + } finally { + taskManager.unregister(task); - modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task))); - } + if (request.isWaitForCompletion()) { + if (exceptionRef.get() != null) { + listener.onFailure(exceptionRef.get()); + } else { + listener.onResponse(AcknowledgedResponse.TRUE); + } - static Exception processException(Client auditClient, String modelId, Exception e) { - if (e instanceof TaskCancelledException te) { - return recordError(auditClient, modelId, te, Level.WARNING); - } else if (e instanceof ElasticsearchException es) { - return recordError(auditClient, modelId, es, Level.ERROR); - } else if (e instanceof MalformedURLException) { - return recordError(auditClient, modelId, "an invalid URL", e, Level.ERROR, RestStatus.BAD_REQUEST); - } else if (e instanceof URISyntaxException) { - return recordError(auditClient, modelId, "an invalid URL syntax", e, Level.ERROR, RestStatus.BAD_REQUEST); - } else if (e instanceof IOException) { - return recordError(auditClient, modelId, "an IOException", e, Level.ERROR, RestStatus.SERVICE_UNAVAILABLE); - } else { - return recordError(auditClient, modelId, "an Exception", e, Level.ERROR, RestStatus.INTERNAL_SERVER_ERROR); + } } } @@ -195,21 +199,32 @@ public ModelDownloadTask createTask(long id, String type, String action, TaskId }, false); } - private static Exception recordError(Client client, String modelId, ElasticsearchException e, Level level) { - String message = format("Model importing failed due to [%s]", e.getDetailedMessage()); - logAndWriteNotificationAtLevel(client, modelId, message, level); - return e; + private static void recordError(Client client, String modelId, AtomicReference exceptionRef, ElasticsearchException e) { + logAndWriteNotificationAtError(client, modelId, e.getDetailedMessage()); + exceptionRef.set(e); } - private static Exception recordError(Client client, String modelId, String failureType, Exception e, Level level, RestStatus status) { + private static void recordError( + Client client, + String modelId, + String failureType, + AtomicReference exceptionRef, + Exception e, + RestStatus status + ) { String message = format("Model importing failed due to %s [%s]", failureType, e); - logAndWriteNotificationAtLevel(client, modelId, message, level); - return new ElasticsearchStatusException(message, status, e); + logAndWriteNotificationAtError(client, modelId, message); + exceptionRef.set(new ElasticsearchStatusException(message, status, e)); + } + + private static void logAndWriteNotificationAtError(Client client, String modelId, String message) { + writeNotification(client, modelId, message, Level.ERROR); + logger.error(format("[%s] %s", modelId, message)); } - private static void logAndWriteNotificationAtLevel(Client client, String modelId, String message, Level level) { - writeNotification(client, modelId, message, level); - logger.log(level.log4jLevel(), format("[%s] %s", modelId, message)); + private static void logAndWriteNotificationAtInfo(Client client, String modelId, String message) { + writeNotification(client, modelId, message, Level.INFO); + logger.info(format("[%s] %s", modelId, message)); } private static void writeNotification(Client client, String modelId, String message, Level level) { diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java index 2e487b6a9624c..967d1b4ba4b6a 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/MachineLearningPackageLoaderTests.java @@ -7,13 +7,9 @@ package org.elasticsearch.xpack.ml.packageloader; -import org.elasticsearch.common.settings.Setting; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.PathUtils; import org.elasticsearch.test.ESTestCase; -import java.util.List; - import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -84,12 +80,4 @@ public void testValidateModelRepository() { assertEquals("xpack.ml.model_repository does not support authentication", e.getMessage()); } - - public void testThreadPoolHasSingleThread() { - var fixedThreadPool = MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY); - List> settings = fixedThreadPool.getRegisteredSettings(); - var sizeSettting = settings.stream().filter(s -> s.getKey().startsWith("xpack.ml.model_download_thread_pool")).findFirst(); - assertTrue(sizeSettting.isPresent()); - assertEquals(5, sizeSettting.get().get(Settings.EMPTY)); - } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java index 3a682fb6a5094..0afd08c70cf45 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTaskTests.java @@ -20,7 +20,14 @@ public class ModelDownloadTaskTests extends ESTestCase { public void testStatus() { - var task = testTask(); + var task = new ModelDownloadTask( + 0L, + MODEL_IMPORT_TASK_TYPE, + MODEL_IMPORT_TASK_ACTION, + downloadModelTaskDescription("foo"), + TaskId.EMPTY_TASK_ID, + Map.of() + ); task.setProgress(100, 0); var taskInfo = task.taskInfo("node", true); @@ -32,15 +39,4 @@ public void testStatus() { status = Strings.toString(taskInfo.status()); assertThat(status, containsString("{\"total_parts\":100,\"downloaded_parts\":1}")); } - - public static ModelDownloadTask testTask() { - return new ModelDownloadTask( - 0L, - MODEL_IMPORT_TASK_TYPE, - MODEL_IMPORT_TASK_ACTION, - downloadModelTaskDescription("foo"), - TaskId.EMPTY_TASK_ID, - Map.of() - ); - } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java deleted file mode 100644 index 99efb331a350c..0000000000000 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelImporterTests.java +++ /dev/null @@ -1,316 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.ml.packageloader.action; - -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.LatchedActionListener; -import org.elasticsearch.action.support.ActionTestUtils; -import org.elasticsearch.action.support.master.AcknowledgedResponse; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.common.hash.MessageDigests; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction; -import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; -import org.elasticsearch.xpack.ml.packageloader.MachineLearningPackageLoader; -import org.junit.After; -import org.junit.Before; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.net.URISyntaxException; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; - -import static org.hamcrest.Matchers.containsString; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class ModelImporterTests extends ESTestCase { - - private TestThreadPool threadPool; - - @Before - public void createThreadPool() { - threadPool = createThreadPool(MachineLearningPackageLoader.modelDownloadExecutor(Settings.EMPTY)); - } - - @After - public void closeThreadPool() { - threadPool.close(); - } - - public void testDownloadModelDefinition() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = ModelDownloadTaskTests.testTask(); - var config = mockConfigWithRepoLinks(); - var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); - - int totalParts = 5; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); - - var digest = computeDigest(modelDef); - when(config.getSha256()).thenReturn(digest); - when(config.getSize()).thenReturn(size); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); - importer.downloadModelDefinition(size, totalParts, vocab, streamers, latchedListener); - - latch.await(); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - assertEquals(totalParts - 1, task.getStatus().downloadProgress().downloadedParts()); - assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); - } - - public void testReadModelDefinitionFromFile() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = ModelDownloadTaskTests.testTask(); - var config = mockConfigWithRepoLinks(); - var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); - - int totalParts = 3; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - - var digest = computeDigest(modelDef); - when(config.getSha256()).thenReturn(digest); - when(config.getSize()).thenReturn(size); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener(ActionTestUtils.assertNoFailureListener(ignore -> {}), latch); - importer.readModelDefinitionFromFile(size, totalParts, streamChunker, vocab, latchedListener); - - latch.await(); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - assertEquals(totalParts, task.getStatus().downloadProgress().downloadedParts()); - assertEquals(totalParts, task.getStatus().downloadProgress().totalParts()); - } - - public void testSizeMismatch() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 5; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); - - var digest = computeDigest(modelDef); - when(config.getSha256()).thenReturn(digest); - when(config.getSize()).thenReturn(size - 1); // expected size and read size are different - - var exceptionHolder = new AtomicReference(); - - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("Model size does not match")); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - public void testDigestMismatch() throws InterruptedException, URISyntaxException { - var client = mockClient(false); - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 5; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 2); - - when(config.getSha256()).thenReturn("0x"); // digest is different - when(config.getSize()).thenReturn(size); - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - // Message digest can only be calculated for the file reader - var streamChunker = new ModelLoaderUtils.InputStreamChunker(new ByteArrayInputStream(modelDef), chunkSize); - importer.readModelDefinitionFromFile(size, totalParts, streamChunker, null, latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("Model sha256 checksums do not match")); - verify(client, times(totalParts)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - public void testPutFailure() throws InterruptedException, URISyntaxException { - var client = mockClient(true); // client will fail put - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 4; - int chunkSize = 10; - long size = totalParts * chunkSize; - var modelDef = modelDefinition(totalParts, chunkSize); - var streamers = mockHttpStreamChunkers(modelDef, chunkSize, 1); - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(size, totalParts, null, streamers, latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("put model part failed")); - verify(client, times(1)).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - public void testReadFailure() throws IOException, InterruptedException, URISyntaxException { - var client = mockClient(true); - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - int totalParts = 4; - int chunkSize = 10; - long size = totalParts * chunkSize; - - var streamer = mock(ModelLoaderUtils.HttpStreamChunker.class); - when(streamer.hasNext()).thenReturn(true); - when(streamer.next()).thenThrow(new IOException("stream failed")); // fail the read - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(size, totalParts, null, List.of(streamer), latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("stream failed")); - } - - @SuppressWarnings("unchecked") - public void testUploadVocabFailure() throws InterruptedException, URISyntaxException { - var client = mock(Client.class); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[2]; - listener.onFailure(new ElasticsearchStatusException("put vocab failed", RestStatus.BAD_REQUEST)); - return null; - }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); - - var task = mock(ModelDownloadTask.class); - var config = mockConfigWithRepoLinks(); - - var vocab = new ModelLoaderUtils.VocabularyParts(List.of(), List.of(), List.of()); - - var exceptionHolder = new AtomicReference(); - var latch = new CountDownLatch(1); - var latchedListener = new LatchedActionListener( - ActionTestUtils.assertNoSuccessListener(exceptionHolder::set), - latch - ); - - var importer = new ModelImporter(client, "foo", config, task, threadPool); - importer.downloadModelDefinition(100, 5, vocab, List.of(), latchedListener); - - latch.await(); - assertThat(exceptionHolder.get().getMessage(), containsString("put vocab failed")); - verify(client, times(1)).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); - verify(client, never()).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - } - - private List mockHttpStreamChunkers(byte[] modelDef, int chunkSize, int numStreams) { - var ranges = ModelLoaderUtils.split(modelDef.length, numStreams, chunkSize); - - var result = new ArrayList(ranges.size()); - for (var range : ranges) { - int len = range.numParts() * chunkSize; - var modelDefStream = new ByteArrayInputStream(modelDef, (int) range.rangeStart(), len); - result.add(new ModelLoaderUtils.HttpStreamChunker(modelDefStream, range, chunkSize)); - } - - return result; - } - - private byte[] modelDefinition(int totalParts, int chunkSize) { - var bytes = new byte[totalParts * chunkSize]; - for (int i = 0; i < totalParts; i++) { - System.arraycopy(randomByteArrayOfLength(chunkSize), 0, bytes, i * chunkSize, chunkSize); - } - return bytes; - } - - private String computeDigest(byte[] modelDef) { - var digest = MessageDigests.sha256(); - digest.update(modelDef); - return MessageDigests.toHexString(digest.digest()); - } - - @SuppressWarnings("unchecked") - private Client mockClient(boolean failPutPart) { - var client = mock(Client.class); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[2]; - if (failPutPart) { - listener.onFailure(new IllegalStateException("put model part failed")); - } else { - listener.onResponse(AcknowledgedResponse.TRUE); - } - return null; - }).when(client).execute(eq(PutTrainedModelDefinitionPartAction.INSTANCE), any(), any()); - - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse(AcknowledgedResponse.TRUE); - return null; - }).when(client).execute(eq(PutTrainedModelVocabularyAction.INSTANCE), any(), any()); - - return client; - } - - private ModelPackageConfig mockConfigWithRepoLinks() { - var config = mock(ModelPackageConfig.class); - when(config.getModelRepository()).thenReturn("https://models.models"); - when(config.getPackagedModelId()).thenReturn("my-model"); - return config; - } -} diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java index f421a7b44e7f1..661cd12f99957 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java @@ -17,7 +17,6 @@ import java.nio.charset.StandardCharsets; import static org.hamcrest.Matchers.contains; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; public class ModelLoaderUtilsTests extends ESTestCase { @@ -81,13 +80,14 @@ public void testSha256AndSize() throws IOException { assertEquals(64, expectedDigest.length()); int chunkSize = randomIntBetween(100, 10_000); - int totalParts = (bytes.length + chunkSize - 1) / chunkSize; ModelLoaderUtils.InputStreamChunker inputStreamChunker = new ModelLoaderUtils.InputStreamChunker( new ByteArrayInputStream(bytes), chunkSize ); + int totalParts = (bytes.length + chunkSize - 1) / chunkSize; + for (int part = 0; part < totalParts - 1; ++part) { assertEquals(chunkSize, inputStreamChunker.next().length()); } @@ -112,40 +112,4 @@ public void testParseVocabulary() throws IOException { assertThat(parsedVocab.merges(), contains("mergefoo", "mergebar", "mergebaz")); assertThat(parsedVocab.scores(), contains(1.0, 2.0, 3.0)); } - - public void testSplitIntoRanges() { - long totalSize = randomLongBetween(10_000, 50_000_000); - int numStreams = randomIntBetween(1, 10); - int chunkSize = 1024; - var ranges = ModelLoaderUtils.split(totalSize, numStreams, chunkSize); - assertThat(ranges, hasSize(numStreams + 1)); - - int expectedNumChunks = (int) ((totalSize + chunkSize - 1) / chunkSize); - assertThat(ranges.stream().mapToInt(ModelLoaderUtils.RequestRange::numParts).sum(), is(expectedNumChunks)); - - long startBytes = 0; - int startPartIndex = 0; - for (int i = 0; i < ranges.size() - 1; i++) { - assertThat(ranges.get(i).rangeStart(), is(startBytes)); - long end = startBytes + ((long) ranges.get(i).numParts() * chunkSize) - 1; - assertThat(ranges.get(i).rangeEnd(), is(end)); - long expectedNumBytesInRange = (long) chunkSize * ranges.get(i).numParts() - 1; - assertThat(ranges.get(i).rangeEnd() - ranges.get(i).rangeStart(), is(expectedNumBytesInRange)); - assertThat(ranges.get(i).startPart(), is(startPartIndex)); - - startBytes = end + 1; - startPartIndex += ranges.get(i).numParts(); - } - - var finalRange = ranges.get(ranges.size() - 1); - assertThat(finalRange.rangeStart(), is(startBytes)); - assertThat(finalRange.rangeEnd(), is(totalSize - 1)); - assertThat(finalRange.numParts(), is(1)); - } - - public void testRangeRequestBytesRange() { - long start = randomLongBetween(0, 2 << 10); - long end = randomLongBetween(start + 1, 2 << 11); - assertEquals("bytes=" + start + "-" + end, new ModelLoaderUtils.RequestRange(start, end, 0, 1).bytesRange()); - } } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java index bca9c0b6dc6fc..1e10ea48d03db 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java @@ -7,17 +7,14 @@ package org.elasticsearch.xpack.ml.packageloader.action; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; -import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; import org.elasticsearch.xpack.core.ml.packageloader.action.LoadTrainedModelPackageAction; @@ -33,7 +30,7 @@ import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -42,7 +39,7 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase { private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]"; public void testSendsFinishedUploadNotification() { - var uploader = createUploader(null); + var uploader = mock(ModelImporter.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); var client = mock(Client.class); @@ -63,48 +60,41 @@ public void testSendsFinishedUploadNotification() { assertThat(notificationArg.getValue().getMessage(), CoreMatchers.containsString("finished model import after")); } - public void testSendsErrorNotificationForInternalError() throws Exception { + public void testSendsErrorNotificationForInternalError() throws URISyntaxException, IOException { ElasticsearchStatusException exception = new ElasticsearchStatusException("exception", RestStatus.INTERNAL_SERVER_ERROR); - String message = format("Model importing failed due to [%s]", exception.toString()); - assertUploadCallsOnFailure(exception, message, Level.ERROR); + + assertUploadCallsOnFailure(exception, exception.toString()); } - public void testSendsErrorNotificationForMalformedURL() throws Exception { + public void testSendsErrorNotificationForMalformedURL() throws URISyntaxException, IOException { MalformedURLException exception = new MalformedURLException("exception"); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR); } - public void testSendsErrorNotificationForURISyntax() throws Exception { + public void testSendsErrorNotificationForURISyntax() throws URISyntaxException, IOException { URISyntaxException exception = mock(URISyntaxException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an invalid URL syntax", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.BAD_REQUEST, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR); } - public void testSendsErrorNotificationForIOException() throws Exception { + public void testSendsErrorNotificationForIOException() throws URISyntaxException, IOException { IOException exception = mock(IOException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an IOException", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE, Level.ERROR); + assertUploadCallsOnFailure(exception, message, RestStatus.SERVICE_UNAVAILABLE); } - public void testSendsErrorNotificationForException() throws Exception { + public void testSendsErrorNotificationForException() throws URISyntaxException, IOException { RuntimeException exception = mock(RuntimeException.class); String message = format(MODEL_IMPORT_FAILURE_MSG_FORMAT, "an Exception", exception.toString()); - assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR, Level.ERROR); - } - - public void testSendsWarningNotificationForTaskCancelledException() throws Exception { - TaskCancelledException exception = new TaskCancelledException("cancelled"); - String message = format("Model importing failed due to [%s]", exception.toString()); - - assertUploadCallsOnFailure(exception, message, Level.WARNING); + assertUploadCallsOnFailure(exception, message, RestStatus.INTERNAL_SERVER_ERROR); } - public void testCallsOnResponseWithAcknowledgedResponse() throws Exception { + public void testCallsOnResponseWithAcknowledgedResponse() throws URISyntaxException, IOException { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -133,21 +123,18 @@ public void testDoesNotCallListenerWhenNotWaitingForCompletion() { ); } - private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception { + private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status) throws URISyntaxException, IOException { var esStatusException = new ElasticsearchStatusException(message, status, exception); - assertNotificationAndOnFailure(exception, esStatusException, message, level); + + assertNotificationAndOnFailure(exception, esStatusException, message); } - private void assertUploadCallsOnFailure(ElasticsearchException exception, String message, Level level) throws Exception { - assertNotificationAndOnFailure(exception, exception, message, level); + private void assertUploadCallsOnFailure(ElasticsearchStatusException exception, String message) throws URISyntaxException, IOException { + assertNotificationAndOnFailure(exception, exception, message); } - private void assertNotificationAndOnFailure( - Exception thrownException, - ElasticsearchException onFailureException, - String message, - Level level - ) throws Exception { + private void assertNotificationAndOnFailure(Exception thrownException, ElasticsearchStatusException onFailureException, String message) + throws URISyntaxException, IOException { var client = mock(Client.class); var taskManager = mock(TaskManager.class); var task = mock(Task.class); @@ -163,11 +150,9 @@ private void assertNotificationAndOnFailure( var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class); // 2 notifications- the starting message and the failure verify(client, times(2)).execute(eq(AuditMlNotificationAction.INSTANCE), notificationArg.capture(), any()); - var notification = notificationArg.getValue(); - assertThat(notification.getMessage(), is(message)); // the last message is captured - assertThat(notification.getLevel(), is(level)); // the last message is captured + assertThat(notificationArg.getValue().getMessage(), is(message)); // the last message is captured - var receivedException = (ElasticsearchException) failureRef.get(); + var receivedException = (ElasticsearchStatusException) failureRef.get(); assertThat(receivedException.toString(), is(onFailureException.toString())); assertThat(receivedException.status(), is(onFailureException.status())); assertThat(receivedException.getCause(), is(onFailureException.getCause())); @@ -175,18 +160,11 @@ private void assertNotificationAndOnFailure( verify(taskManager).unregister(task); } - @SuppressWarnings("unchecked") - private ModelImporter createUploader(Exception exception) { + private ModelImporter createUploader(Exception exception) throws URISyntaxException, IOException { ModelImporter uploader = mock(ModelImporter.class); - doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments()[0]; - if (exception != null) { - listener.onFailure(exception); - } else { - listener.onResponse(AcknowledgedResponse.TRUE); - } - return null; - }).when(uploader).doImport(any(ActionListener.class)); + if (exception != null) { + doThrow(exception).when(uploader).doImport(); + } return uploader; }