Skip to content

Commit

Permalink
[ML][Inference] stream inflate to parser + throw when byte limit is r…
Browse files Browse the repository at this point in the history
…eached (#51644)

Three fixes for when the `compressed_definition` is utilized on PUT

* Update the inflate byte limit to be the minimum of 10% of the max heap or 1GB (what it was previously)
* Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out
* Throw when the maximum number of bytes is reached, indicating that this is why the request was rejected
  • Loading branch information
benwtrent authored Jan 30, 2020
1 parent 018c18e commit 8ea9aa2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.monitor.jvm.JvmInfo;
import org.elasticsearch.xpack.core.ml.inference.utils.SimpleBoundedInputStream;

import java.io.IOException;
Expand All @@ -33,7 +34,10 @@
*/
public final class InferenceToXContentCompressor {
private static final int BUFFER_SIZE = 4096;
private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum
// Either 10% of the configured JVM heap, or 1 GB, whichever is smaller
private static final long MAX_INFLATED_BYTES = Math.min(
(long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
1_000_000_000); // 1 gb maximum

private InferenceToXContentCompressor() {}

Expand All @@ -45,33 +49,34 @@ public static <T extends ToXContentObject> String deflate(T objectToCompress) th
/**
 * Inflates {@code compressedString} and applies {@code parserFunction} to an {@link XContentParser} created over
 * the inflated content. The inflated content is limited to {@code MAX_INFLATED_BYTES}.
 *
 * NOTE(review): this is a rendered diff without +/- markers. The two {@code try(} openings below look like the
 * removed ({@code XContentHelper.createParser}) and added ({@code JsonXContent.jsonXContent.createParser})
 * versions of the same statement, each followed by its own argument lines - confirm against the commit.
 *
 * @param compressedString the compressed string to inflate
 * @param parserFunction function applied to the created parser; its result is returned
 * @param xContentRegistry registry handed to the created parser
 * @param <T> type produced by {@code parserFunction}
 * @return the value returned by {@code parserFunction}
 * @throws IOException declared by both the inflate call and {@code parserFunction}
 */
static <T> T inflate(String compressedString,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry) throws IOException {
try(XContentParser parser = XContentHelper.createParser(xContentRegistry,
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, MAX_INFLATED_BYTES),
XContentType.JSON)) {
inflate(compressedString, MAX_INFLATED_BYTES))) {
// The parser is closed by the try-with-resources once the function has produced its result.
return parserFunction.apply(parser);
}
}

/**
 * Inflates {@code compressedString} and returns the content as the map produced by {@code parser.mapOrdered()}.
 * The inflated content is limited to {@code MAX_INFLATED_BYTES}.
 *
 * NOTE(review): this is a rendered diff without +/- markers. The two {@code try(} openings below look like the
 * removed ({@code XContentHelper.createParser}) and added ({@code JsonXContent.jsonXContent.createParser})
 * versions of the same statement - confirm against the commit.
 *
 * @param compressedString the compressed string to inflate
 * @return the map returned by {@code parser.mapOrdered()}
 * @throws IOException declared by the inflate call and the parser
 */
static Map<String, Object> inflateToMap(String compressedString) throws IOException {
// Don't need the xcontent registry as we are not deflating named objects.
try(XContentParser parser = XContentHelper.createParser(NamedXContentRegistry.EMPTY,
try(XContentParser parser = JsonXContent.jsonXContent.createParser(NamedXContentRegistry.EMPTY,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, MAX_INFLATED_BYTES),
XContentType.JSON)) {
inflate(compressedString, MAX_INFLATED_BYTES))) {
return parser.mapOrdered();
}
}

/**
 * Base64 decodes {@code compressedString} and wraps the decoded bytes in a {@link GZIPInputStream} that is in turn
 * wrapped in a {@code SimpleBoundedInputStream} limited to {@code streamSize} bytes.
 *
 * NOTE(review): this is a rendered diff without +/- markers. The two signatures below look like the removed
 * ({@code BytesReference}) and added ({@code InputStream}) versions, and the trailing statements look like the
 * removed ({@code Streams.readFully}) and added (return the bounded stream) endings - confirm against the commit.
 *
 * @param compressedString the base64 encoded, gzip compressed string
 * @param streamSize maximum number of bytes passed to the bounding stream
 * @throws IOException if the decoded compressed data is longer than {@code Math.max(100L, streamSize)}, or if
 *                     thrown while setting up or reading the streams
 */
static BytesReference inflate(String compressedString, long streamSize) throws IOException {
static InputStream inflate(String compressedString, long streamSize) throws IOException {
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8));
// If the compressed length is already too large, it makes sense that the inflated length would be as well.
// In the extremely small string case, the compressed data could actually be longer than the inflated stream,
// which is why the limit checked here is never lower than 100 bytes.
if (compressedBytes.length > Math.max(100L, streamSize)) {
throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]");
}
InputStream gzipStream = new GZIPInputStream(new BytesArray(compressedBytes).streamInput(), BUFFER_SIZE);
InputStream inflateStream = new SimpleBoundedInputStream(gzipStream, streamSize);
return Streams.readFully(inflateStream);
return new SimpleBoundedInputStream(gzipStream, streamSize);
}

//Public for testing (for now)
public static String deflate(BytesReference reference) throws IOException {
private static String deflate(BytesReference reference) throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
try (OutputStream compressedOutput = new GZIPOutputStream(out, BUFFER_SIZE)) {
reference.writeTo(compressedOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@ public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
this.maxBytes = maxBytes;
}


/**
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
* @return The byte read. -1 on internal stream completion or when maxBytes is exceeded.
* @throws IOException on failure
* @return The byte read.
* @throws IOException on failure or when byte limit is exceeded
*/
@Override
public int read() throws IOException {
// We have reached the maximum, signal stream completion.
if (numBytes >= maxBytes) {
return -1;
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
}
numBytes++;
return in.read();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
*/
package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.equalTo;

Expand All @@ -33,20 +35,22 @@ public void testInflateAndDeflate() throws IOException {
}

// Checks the behavior of inflate when the inflated content is larger than the permitted stream size.
// NOTE(review): this is a rendered diff without +/- markers, so removed and added test lines are interleaved.
// The single-line createRandomBuilder().build(), the BytesReference/assertThat(length) lines and the
// XContentHelper parser block look like the removed version; the preprocessor-heavy definition and the
// Streams.readFully expectation look like the added version - confirm against the commit.
public void testInflateTooLargeStream() throws IOException {
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder().build();
// 100 random preprocessors; presumably added so the inflated definition is reliably larger than the
// compressed string plus 10 bytes - TODO confirm.
TrainedModelDefinition definition = TrainedModelDefinitionTests.createRandomBuilder()
.setPreProcessors(Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(),
OneHotEncodingTests.createRandom(),
TargetMeanEncodingTests.createRandom()))
.limit(100)
.collect(Collectors.toList()))
.build();
String firstDeflate = InferenceToXContentCompressor.deflate(definition);
BytesReference inflatedBytes = InferenceToXContentCompressor.inflate(firstDeflate, 10L);
assertThat(inflatedBytes.length(), equalTo(10));
try(XContentParser parser = XContentHelper.createParser(xContentRegistry(),
LoggingDeprecationHandler.INSTANCE,
inflatedBytes,
XContentType.JSON)) {
expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true));
}
// The limit is the length of the compressed string plus 10 bytes.
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
IOException ex = expectThrows(IOException.class,
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
// The message must match the one thrown by SimpleBoundedInputStream when the byte limit is reached.
assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
}

// Expects an IOException when the input is a random alphabetic string rather than a deflated definition.
// NOTE(review): this is a rendered diff without +/- markers. The two expectThrows lines look like the removed
// version (calling inflate directly) and the added version (reading the returned stream fully) - confirm.
public void testInflateGarbage() {
expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L));
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(randomAlphaOfLength(10), 100L)));
}

@Override
Expand Down

0 comments on commit 8ea9aa2

Please sign in to comment.