diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java index 96b4eb9d6493a..1b06402952603 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java @@ -9,7 +9,6 @@ import org.elasticsearch.common.CheckedFunction; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; @@ -17,6 +16,8 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.monitor.jvm.JvmInfo; import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream; import java.io.IOException; @@ -33,7 +34,10 @@ */ public final class InferenceToXContentCompressor { private static final int BUFFER_SIZE = 4096; - private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum + // Either 10% of the configured JVM heap, or 1 GB, which ever is smaller + private static final long MAX_INFLATED_BYTES = Math.min( + (long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()), + 1_000_000_000); // 1 gb maximum private InferenceToXContentCompressor() {} @@ -45,33 +49,34 @@ public static String deflate(T objectToCompress) th static T inflate(String compressedString, CheckedFunction parserFunction, NamedXContentRegistry xContentRegistry) throws IOException { - try(XContentParser parser = XContentHelper.createParser(xContentRegistry, + try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, MAX_INFLATED_BYTES), - XContentType.JSON)) { + inflate(compressedString, MAX_INFLATED_BYTES))) { return parserFunction.apply(parser); } } static Map inflateToMap(String compressedString) throws IOException { // Don't need the xcontent registry as we are not deflating named objects. - try(XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY, + try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, - inflate(compressedString, MAX_INFLATED_BYTES), - XContentType.JSON)) { + inflate(compressedString, MAX_INFLATED_BYTES))) { return parser.mapOrdered(); } } - static BytesReference inflate(String compressedString, long streamSize) throws IOException { + static InputStream inflate(String compressedString, long streamSize) throws IOException { byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); + // If the compressed length is already too large, it make sense that the inflated length would be as well + // In the extremely small string case, the compressed data could actually be longer than the compressed stream + if (compressedBytes.length > Math.max(100L, streamSize)) { + throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]"); + } InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE); - InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize); - return Streams.readFully(inflateStream); + return new SimpleBoundedInputStream(gzipStream, streamSize); } - //Public for testing (for now) - public static String deflate(BytesReference reference) throws IOException { + private static String deflate(BytesReference reference) throws IOException { BytesStreamOutput out = new BytesStreamOutput(); try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) { reference.writeTo(compressedOutput); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/SimpleBoundedInputStream.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/SimpleBoundedInputStream.java index 5eb875e0065d5..1d845d3b33245 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/SimpleBoundedInputStream.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/SimpleBoundedInputStream.java @@ -28,17 +28,16 @@ public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) { this.maxBytes = maxBytes; } - /** * A simple wrapper around the injected input stream that restricts the total number of bytes able to be read. - * @return The byte read. -1 on internal stream completion or when maxBytes is exceeded. - * @throws IOException on failure + * @return The byte read. + * @throws IOException on failure or when byte limit is exceeded */ @Override public int read() throws IOException { // We have reached the maximum, signal stream completion. if (numBytes >= maxBytes) { - return -1; + throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]"); } numBytes++; return in.read(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java index 1f12fbf393ca1..099c9ea4465d1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressorTests.java @@ -5,15 +5,17 @@ */ package org.elasticsearch.xpack.core.ml.inference; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.io.Streams; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -33,20 +35,22 @@ public void testInflateAndDeflate() throws IOException { } public void testInflateTooLargeStream() throws IOException { - TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build(); + TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder() + .setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(), + OneHotEncodingTests.createRandom(), + TargetMeanEncodingTests.createRandom())) + .limit(100) + .collect(Collectors.toList())) + .build(); String firstDeflate = InferenceToXContentCompressor.deflate(definition); - BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L); - assertThat(inflatedBytes.length(), equalTo(10)); - try(XContentParser parser = XContentHelper.createParser(xContentRegistry(), - LoggingDeprecationHandler.INSTANCE, - inflatedBytes, - XContentType.JSON)) { - expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true)); - } + int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10; + IOException ex = expectThrows(IOException.class, + () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max))); + assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]")); } public void testInflateGarbage() { - expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)); + expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L))); } @Override