Skip to content

Commit

Permalink
[ML] Remove regex (elastic#113210)
Browse files Browse the repository at this point in the history
Regex is having trouble parsing some of the larger UTF8 characters, so
instead we are just going to use our non-regex parser.

Fix elastic#113179
Fix elastic#113148

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
prwhelan and elasticmachine committed Sep 23, 2024
1 parent 080ee4c commit c16f372
Showing 1 changed file with 17 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand All @@ -59,7 +62,6 @@
import java.util.concurrent.Flow;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Predicate;
Expand All @@ -80,9 +82,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
private static final String NO_STREAM_ROUTE = "/_inference_no_stream";
private static final Exception expectedException = new IllegalStateException("hello there");
private static final String expectedExceptionAsServerSentEvent = """
\uFEFF\
event: error
data: {\
{\
"error":{"root_cause":[{"type":"illegal_state_exception","reason":"hello there",\
"caused_by":{"type":"illegal_state_exception","reason":"hello there"}}],\
"type":"illegal_state_exception","reason":"hello there"},"status":500\
Expand Down Expand Up @@ -323,30 +323,16 @@ protected void releaseResources() {}
}

private static class RandomStringCollector {
private static final Pattern jsonPattern = Pattern.compile("^\uFEFFevent: message\ndata: \\{.*}$");
private static final Pattern endPattern = Pattern.compile("^\uFEFFevent: message\ndata: \\[DONE\\]$");
private final AtomicBoolean hasDoneChunk = new AtomicBoolean(false);
private final Deque<String> stringsVerified = new LinkedBlockingDeque<>();
private volatile String previousTokens = "";
private final ServerSentEventParser sseParser = new ServerSentEventParser();

private void collect(String str) throws IOException {
str = previousTokens + str;
String[] events = str.split("\n\n", -1);
for (var i = 0; i < events.length - 1; i++) {
var line = events[i];
if (jsonPattern.matcher(line).matches() || expectedExceptionAsServerSentEvent.equals(line)) {
stringsVerified.offer(line);
} else if (endPattern.matcher(line).matches()) {
hasDoneChunk.set(true);
} else {
throw new IOException("Line does not match expected JSON message or DONE message. Line: " + line);
}
}

previousTokens = events[events.length - 1];
if (endPattern.matcher(previousTokens.trim()).matches()) {
hasDoneChunk.set(true);
}
sseParser.parse(str.getBytes(StandardCharsets.UTF_8))
.stream()
.filter(event -> event.name() == ServerSentEventField.DATA)
.filter(ServerSentEvent::hasValue)
.map(ServerSentEvent::value)
.forEach(stringsVerified::offer);
}
}

Expand All @@ -363,8 +349,8 @@ public void testResponse() {

var response = callAsync(request);
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_OK));
assertThat(collector.stringsVerified.size(), equalTo(expectedTestCount));
assertThat(collector.hasDoneChunk.get(), equalTo(true));
assertThat(collector.stringsVerified.size(), equalTo(expectedTestCount + 1)); // normal payload count + done byte
assertThat(collector.stringsVerified.peekLast(), equalTo("[DONE]"));
}

private Response callAsync(Request request) {
Expand Down Expand Up @@ -409,10 +395,9 @@ public void testOnFailure() throws IOException {
} catch (ResponseException e) {
var response = e.getResponse();
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_INTERNAL_SERVER_ERROR));
assertThat(
EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8),
equalTo(expectedExceptionAsServerSentEvent + "\n\n")
);
assertThat(EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8), equalTo("""
\uFEFFevent: error
data:\s""" + expectedExceptionAsServerSentEvent + "\n\n"));
}
}

Expand All @@ -431,7 +416,7 @@ public void testErrorMidStream() {
var response = callAsync(request);
assertThat(response.getStatusLine().getStatusCode(), is(HttpStatus.SC_OK)); // error still starts with 200-OK
assertThat(collector.stringsVerified.size(), equalTo(expectedTestCount + 1)); // normal payload count + last error byte
assertThat("DONE chunk is not sent on error", collector.hasDoneChunk.get(), equalTo(false));
assertThat("DONE chunk is not sent on error", collector.stringsVerified.stream().anyMatch("[DONE]"::equals), equalTo(false));
assertThat(collector.stringsVerified.getLast(), equalTo(expectedExceptionAsServerSentEvent));
}

Expand Down

0 comments on commit c16f372

Please sign in to comment.