-
Notifications
You must be signed in to change notification settings - Fork 25k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML][Inference] stream inflate to parser + throw when byte limit is reached #51644
[ML][Inference] stream inflate to parser + throw when byte limit is reached #51644
Conversation
Pinging @elastic/ml-core (:ml) |
@@ -33,7 +34,10 @@ | |||
*/ | |||
public final class InferenceToXContentCompressor { | |||
private static final int BUFFER_SIZE = 4096; | |||
private static final long MAX_INFLATED_BYTES = 1_000_000_000; // 1 gb maximum | |||
// Either 10% of the configured JVM heap, or 1 GB, which ever is smaller |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is 10% a good number?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Eventually we could have a dynamic limit that's integrated with the real memory circuit breaker (#31767). Maybe we could reserve a percentage of free memory and use that as the dynamic limit for a given request, then give back that reservation after finding out the actual size required. That's something to investigate for 7.7 or 7.8.
However, I think 10% is an OK first step for 7.6 to reduce the risk of someone accidentally triggering an OOM on a node with a small heap.
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); | ||
if (compressedBytes.length > streamSize) { | ||
throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize +"]"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize +"]"); | |
throw new IOException("compressed stream is longer than maximum allowed bytes [" + streamSize + "]"); |
return parser.mapOrdered(); | ||
} | ||
} | ||
|
||
static BytesReference inflate(String compressedString, long streamSize) throws IOException { | ||
static InputStream inflate(String compressedString, long streamSize) throws IOException { | ||
// If the compressed length is already too large, it make sense that the inflated length would be as well |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you move this line after line 70 (compressedBytes
) so that it is clearly visible that it refers to the if
check?
@@ -38,6 +39,9 @@ public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) { | |||
public int read() throws IOException { | |||
// We have reached the maximum, signal stream completion. | |||
if (numBytes >= maxBytes) { | |||
if (throwWhenExceeded) { | |||
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes +"]"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes +"]"); | |
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]"); |
@@ -38,6 +39,9 @@ public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) { | |||
public int read() throws IOException { | |||
// We have reached the maximum, signal stream completion. | |||
if (numBytes >= maxBytes) { | |||
if (throwWhenExceeded) { | |||
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes +"]"); | |||
} | |||
return -1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we sometimes throw and sometimes return -1
? Would it be possible to have only one exit point?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
XContentType.JSON)) { | ||
expectThrows(IOException.class, () -> TrainedModelConfig.fromXContent(parser, true)); | ||
} | ||
expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, 10L))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we also verify that the IOException has the message containing "input stream exceeded maximum bytes"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Adjusting the code for my comment is not essential in the PR, but if you make any other change to the PR before merging then you might as well make my change too.
byte[] compressedBytes = Base64.getDecoder().decode(compressedString.getBytes(StandardCharsets.UTF_8)); | ||
// If the compressed length is already too large, it make sense that the inflated length would be as well |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This comment isn't true in general. If you compress a very small string then the compressed size is bigger than the original. For example echo a | gzip -9 | wc -c
returns 22.
The assumption is OK with the sort of streamSize
values this method is going to be called with given the current code, so it's not essential to change now, but you could make it something like if (compressedBytes.length > Math.max(100L, streamSize))
in case it ever needs to cope with an extreme edge case in the future.
Also, it would be good to adjust the comment to acknowledge the edge case.
…eached (elastic#51644) Three fixes for when the `compressed_definition` is utilized on PUT * Update the inflate byte limit to be the minimum of 10% the max heap, or 1GB (what it was previously) * Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out * Throw when the maximum bytes are reach indicating that is why the request was rejected
…eached (elastic#51644) Three fixes for when the `compressed_definition` is utilized on PUT * Update the inflate byte limit to be the minimum of 10% the max heap, or 1GB (what it was previously) * Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out * Throw when the maximum bytes are reach indicating that is why the request was rejected
…eached (#51644) (#51681) Three fixes for when the `compressed_definition` is utilized on PUT * Update the inflate byte limit to be the minimum of 10% the max heap, or 1GB (what it was previously) * Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out * Throw when the maximum bytes are reach indicating that is why the request was rejected
…eached (#51644) (#51679) Three fixes for when the `compressed_definition` is utilized on PUT * Update the inflate byte limit to be the minimum of 10% the max heap, or 1GB (what it was previously) * Stream data directly to the JSON parser, so if it is invalid, we don't have to inflate the whole stream to find out * Throw when the maximum bytes are reach indicating that is why the request was rejected
Three fixes for when the
compressed_definition
is utilized on PUT