Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: emit an error reason before closing websocket #7390

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

package io.confluent.ksql.rest.server.resources.streaming;

import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.rest.ApiJsonMapper;
import io.vertx.core.http.ServerWebSocket;
import java.nio.charset.StandardCharsets;
import org.slf4j.Logger;
Expand All @@ -37,7 +39,12 @@ static void closeSilently(
final int code,
final String message) {
try {
webSocket.close((short) code, truncate(message));
final ImmutableMap<String, String> finalMessage = ImmutableMap.of(
"error",
message != null ? message : ""
);
final String json = ApiJsonMapper.INSTANCE.get().writeValueAsString(finalMessage);
webSocket.writeFinalTextFrame(json).close((short) code, truncate(message));
} catch (final Exception e) {
LOG.info("Exception caught closing websocket", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ public void shouldExecutePushQueryThatReturnsStreamOverWebSocketWithV1ContentTyp
);

// Then:
assertThat(messages, hasSize(HEADER + LIMIT));
assertThat(messages, hasSize(HEADER + LIMIT + 1));
assertValidJsonMessages(messages);
assertThat(messages.get(0), is("["
+ "{\"name\":\"PAGEID\",\"schema\":{\"type\":\"STRING\",\"fields\":null,\"memberSchema\":null}},"
Expand All @@ -273,7 +273,7 @@ public void shouldExecutePushQueryThatReturnsTableOverWebSocketWithV1ContentType
);

// Then:
assertThat(messages, hasSize(HEADER + LIMIT));
assertThat(messages, hasSize(HEADER + LIMIT + 1));
assertValidJsonMessages(messages);
assertThat(messages.get(0), is("["
+ "{\"name\":\"VAL\",\"schema\":{\"type\":\"STRING\",\"fields\":null,\"memberSchema\":null}}"
Expand All @@ -296,7 +296,7 @@ public void shouldExecutePushQueryThatReturnsStreamOverWebSocketWithJsonContentT
);

// Then:
assertThat(messages, hasSize(HEADER + LIMIT));
assertThat(messages, hasSize(HEADER + LIMIT + 1));
assertValidJsonMessages(messages);
assertThat(messages.get(0), is("["
+ "{\"name\":\"PAGEID\",\"schema\":{\"type\":\"STRING\",\"fields\":null,\"memberSchema\":null}},"
Expand All @@ -321,7 +321,7 @@ public void shouldExecutePushQueryThatReturnsTableOverWebSocketWithJsonContentTy
);

// Then:
assertThat(messages, hasSize(HEADER + LIMIT));
assertThat(messages, hasSize(HEADER + LIMIT + 1));
assertValidJsonMessages(messages);
assertThat(messages.get(0), is("["
+ "{\"name\":\"VAL\",\"schema\":{\"type\":\"STRING\",\"fields\":null,\"memberSchema\":null}}"
Expand Down Expand Up @@ -510,7 +510,7 @@ public void shouldExecutePullQueryOverWebSocketWithV1ContentType() {
);

// Then:
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 1));
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 2));
assertValidJsonMessages(messages);
assertThat(messages.get(0),
is("["
Expand All @@ -531,7 +531,7 @@ public void shouldExecutePullQueryOverWebSocketWithJsonContentType() {
);

// Then:
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 1));
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 2));
assertValidJsonMessages(messages);
assertThat(messages.get(0),
is("["
Expand All @@ -552,7 +552,7 @@ public void shouldReturnCorrectSchemaForPullQueryWithOnlyKeyInSelect() {
);

// Then:
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 1));
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 2));
assertValidJsonMessages(messages);
assertThat(messages.get(0),
is("["
Expand All @@ -572,7 +572,7 @@ public void shouldReturnCorrectSchemaForPullQueryWithOnlyValueColumnInSelect() {
);

// Then:
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 1));
final List<String> messages = assertThatEventually(call, hasSize(HEADER + 2));
assertValidJsonMessages(messages);
assertThat(messages.get(0),
is("["
Expand Down Expand Up @@ -664,7 +664,7 @@ public void shouldPrintTopicOverWebSocket() {
MediaType.APPLICATION_JSON);

// Then:
assertThat(messages, hasSize(LIMIT));
assertThat(messages, hasSize(LIMIT + 1));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import io.vertx.core.http.ServerWebSocket;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
Expand All @@ -42,12 +44,18 @@ public class SessionUtilTest {
@Captor
private ArgumentCaptor<String> reasonCaptor;

@Before
public void setUp() {
when(websocket.writeFinalTextFrame(any(String.class))).thenReturn(websocket);
}

@Test
public void shouldCloseQuietly() throws Exception {
// Given:
doThrow(new RuntimeException("Boom")).when(websocket)
.close(any(Short.class), any(String.class));


// When:
SessionUtil.closeSilently(websocket, INVALID_MESSAGE_TYPE.code(), "reason");

Expand All @@ -65,6 +73,7 @@ public void shouldNotTruncateShortReasons() throws Exception {
SessionUtil.closeSilently(websocket, INVALID_MESSAGE_TYPE.code(), reason);

// Then:
verify(websocket).writeFinalTextFrame(any(String.class));
verify(websocket).close(codeCaptor.capture(), reasonCaptor.capture());
assertThat(reasonCaptor.getValue(), is(reason));
}
Expand All @@ -80,6 +89,7 @@ public void shouldTruncateMessageLongerThanCloseReasonAllows() throws Exception
SessionUtil.closeSilently(websocket, INVALID_MESSAGE_TYPE.code(), reason);

// Then:
verify(websocket).writeFinalTextFrame(any(String.class));
verify(websocket).close(codeCaptor.capture(), reasonCaptor.capture());
assertThat(reasonCaptor.getValue(), is(
"A long message that is longer than the maximum size that the CloseReason class "
Expand All @@ -99,6 +109,7 @@ public void shouldTruncateLongMessageWithMultiByteChars() throws Exception {
SessionUtil.closeSilently(websocket, INVALID_MESSAGE_TYPE.code(), reason);

// Then:
verify(websocket).writeFinalTextFrame(any(String.class));
verify(websocket).close(codeCaptor.capture(), reasonCaptor.capture());
assertThat(reasonCaptor.getValue(), is(
"A long message that is longer than the maximum size that the CloseReason class will "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ public void shouldSerializeToJson() {
final String jsonRequest = serialize(A_REQUEST);

// Then:
assertThat(jsonRequest, is(A_JSON_REQUEST_WITH_NULL_COMMAND_NUMBER));
assertThat(deserialize(jsonRequest), is(deserialize(A_JSON_REQUEST_WITH_NULL_COMMAND_NUMBER)));
}

@Test
Expand All @@ -168,7 +168,7 @@ public void shouldSerializeToJsonWithCommandNumber() {
final String jsonRequest = serialize(A_REQUEST_WITH_COMMAND_NUMBER);

// Then:
assertThat(jsonRequest, is(A_JSON_REQUEST_WITH_COMMAND_NUMBER));
assertThat(deserialize(jsonRequest), is(deserialize(A_JSON_REQUEST_WITH_COMMAND_NUMBER)));
}

@Test
Expand Down