Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket inbound ping frames support #629

Merged
merged 28 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2cef977
Adjust padding to accommodate good enough headers and don't include …
akrambek Oct 25, 2023
d201582
Merge branch 'develop' into feature/consumer-group-cont
akrambek Oct 25, 2023
76bf9de
Merge branch 'feature/consumer-group-cont' into develop
akrambek Oct 26, 2023
29ae79c
Merge branch 'aklivity:develop' into develop
akrambek Oct 30, 2023
ec1b39e
Merge branch 'aklivity:develop' into develop
akrambek Oct 30, 2023
51a9f0e
Merge branch 'aklivity:develop' into develop
akrambek Oct 31, 2023
4394783
Merge branch 'aklivity:develop' into develop
akrambek Oct 31, 2023
e8696ce
Merge branch 'aklivity:develop' into develop
akrambek Nov 2, 2023
51c37b1
Merge branch 'aklivity:develop' into develop
akrambek Nov 2, 2023
5da5f04
Merge branch 'aklivity:develop' into develop
akrambek Nov 2, 2023
db1e17c
Merge branch 'aklivity:develop' into develop
akrambek Nov 4, 2023
40f73dc
Merge branch 'aklivity:develop' into develop
akrambek Nov 6, 2023
d1a0492
Merge branch 'aklivity:develop' into develop
akrambek Nov 23, 2023
45799ce
Merge branch 'aklivity:develop' into develop
akrambek Nov 29, 2023
1e55162
Merge branch 'aklivity:develop' into develop
akrambek Nov 30, 2023
0df6dfa
WIP
akrambek Dec 3, 2023
fedc41f
Merge branch 'aklivity:develop' into develop
akrambek Dec 4, 2023
18a8d74
Merge branch 'aklivity:develop' into develop
akrambek Dec 4, 2023
f160aad
Merge branch 'aklivity:develop' into develop
akrambek Dec 4, 2023
e0e7d5a
Merge branch 'aklivity:develop' into develop
akrambek Dec 6, 2023
175b58c
Merge branch 'develop' into bug/ws-ping
akrambek Dec 8, 2023
af50efd
WIP
akrambek Dec 8, 2023
ed59ab1
WIP server ping support
akrambek Dec 8, 2023
1092b5c
Ping client support
akrambek Dec 8, 2023
4ab11f7
Remove file
akrambek Dec 8, 2023
9f4a8a6
Merge branch 'aklivity:develop' into develop
akrambek Dec 8, 2023
8c5cac0
Merge branch 'develop' into bug/ws-ping
akrambek Dec 8, 2023
c23861f
Fix default session timeout value
akrambek Dec 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.FlushFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.HttpBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.ResetFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.SignalFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WindowFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsDataExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsEndExFW;
import io.aklivity.zilla.runtime.engine.EngineContext;
import io.aklivity.zilla.runtime.engine.binding.BindingHandler;
import io.aklivity.zilla.runtime.engine.binding.function.MessageConsumer;
import io.aklivity.zilla.runtime.engine.concurrent.Signaler;
import io.aklivity.zilla.runtime.engine.config.BindingConfig;

public final class WsClientFactory implements WsStreamFactory
Expand All @@ -77,13 +79,15 @@ public final class WsClientFactory implements WsStreamFactory
private static final String WEBSOCKET_UPGRADE = "websocket";
private static final String WEBSOCKET_VERSION_13 = "13";
private static final int MAXIMUM_HEADER_SIZE = 14;
private static final int PONG_SIGNAL_ID = 1;

private static final DirectBuffer CLOSE_PAYLOAD = new UnsafeBuffer(new byte[0]);

private final MessageDigest sha1 = initSHA1();

private final BeginFW beginRO = new BeginFW();
private final DataFW dataRO = new DataFW();
private final SignalFW signalRO = new SignalFW();
private final EndFW endRO = new EndFW();
private final AbortFW abortRO = new AbortFW();
private final FlushFW flushRO = new FlushFW();
Expand All @@ -106,6 +110,8 @@ public final class WsClientFactory implements WsStreamFactory
private final ResetFW.Builder resetRW = new ResetFW.Builder();
private final ChallengeFW.Builder challengeRW = new ChallengeFW.Builder();

private final OctetsFW.Builder payloadRW = new OctetsFW.Builder();

private final OctetsFW payloadRO = new OctetsFW();

private final HttpBeginExFW httpBeginExRO = new HttpBeginExFW();
Expand All @@ -118,9 +124,11 @@ public final class WsClientFactory implements WsStreamFactory
private final WsHeaderFW.Builder wsHeaderRW = new WsHeaderFW.Builder();

private final MutableDirectBuffer writeBuffer;
private final MutableDirectBuffer extBuffer;
private final BindingHandler streamFactory;
private final LongUnaryOperator supplyInitialId;
private final LongUnaryOperator supplyReplyId;
private final Signaler signaler;

private final Long2ObjectHashMap<WsBindingConfig> bindings;
private final int wsTypeId;
Expand All @@ -131,10 +139,12 @@ public WsClientFactory(
EngineContext context)
{
this.writeBuffer = context.writeBuffer();
this.extBuffer = new UnsafeBuffer(new byte[context.writeBuffer().capacity()]);
this.streamFactory = context.streamFactory();
this.supplyInitialId = context::supplyInitialId;
this.supplyReplyId = context::supplyReplyId;
this.bindings = new Long2ObjectHashMap<>();
this.signaler = context.signaler();
this.wsTypeId = context.supplyTypeId(WsBinding.NAME);
this.httpTypeId = context.supplyTypeId("http");
}
Expand Down Expand Up @@ -713,6 +723,7 @@ private final class WsClient

private int statusLength;
private MutableDirectBuffer status;
private int pingReceived;

private WsClient(
long originId,
Expand Down Expand Up @@ -931,6 +942,10 @@ private void onNetMessage(
final DataFW data = dataRO.wrap(buffer, index, index + length);
onNetData(data);
break;
case SignalFW.TYPE_ID:
final SignalFW signal = signalRO.wrap(buffer, index, index + length);
onNetSignal(signal);
break;
case EndFW.TYPE_ID:
final EndFW end = endRO.wrap(buffer, index, index + length);
onNetEnd(end);
Expand Down Expand Up @@ -1069,6 +1084,23 @@ private void onNetData(
}
}

private void onNetSignal(
SignalFW signal)
{
final int signalId = signal.signalId();
final long traceId = signal.traceId();
final OctetsFW payload = signal.payload();

assert signalId == PONG_SIGNAL_ID;

if (--pingReceived == 0)
{
final int reserved = payload.sizeof() + MAXIMUM_HEADER_SIZE + replyPad;

doNetData(traceId, decodeAuthorization, initialBudgetId, reserved, payload, 0x8a);
}
}

private void onNetEnd(
EndFW end)
{
Expand Down Expand Up @@ -1251,6 +1283,9 @@ private int decodeHeader(
case 0x08:
this.decodeState = this::decodeClose;
break;
case 0x09:
this.decodeState = this::decodePing;
break;
case 0x0a:
this.decodeState = this::decodePong;
break;
Expand Down Expand Up @@ -1395,6 +1430,44 @@ private int decodeClose(
}
}

private int decodePing(
final DirectBuffer buffer,
final int offset,
final int length)
{
if (payloadLength > MAXIMUM_CONTROL_FRAME_PAYLOAD_SIZE)
{
doNetReset(decodeTraceId, decodeAuthorization);
doAppAbort(decodeTraceId, decodeAuthorization, STATUS_PROTOCOL_ERROR);
return length;
}
else
{
final int decodeBytes = Math.min(length, payloadLength - payloadProgress);

OctetsFW payload = payloadRO.wrap(buffer, offset, offset + decodeBytes);

OctetsFW.Builder payloadBuilder = payloadRW.wrap(extBuffer, 0, extBuffer.capacity());
payloadBuilder.set(payload);
xor(extBuffer, 0, payload.sizeof(), maskingKey);
OctetsFW unmaskedPayload = payloadBuilder.build();

pingReceived++;
signaler.signalNow(originId, routedId, initialId, decodeTraceId, PONG_SIGNAL_ID, 0,
unmaskedPayload.value(), 0, unmaskedPayload.sizeof());

payloadProgress += decodeBytes;
maskingKey = rotateMaskingKey(maskingKey, decodeBytes);

if (payloadProgress == payloadLength)
{
this.decodeState = this::decodeHeader;
}

return decodeBytes;
}
}

private int decodePong(
final DirectBuffer buffer,
final int offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static io.aklivity.zilla.runtime.binding.ws.internal.types.codec.WsHeaderFW.STATUS_PROTOCOL_ERROR;
import static io.aklivity.zilla.runtime.binding.ws.internal.types.codec.WsHeaderFW.STATUS_UNEXPECTED_CONDITION;
import static io.aklivity.zilla.runtime.binding.ws.internal.util.WsMaskUtil.xor;
import static io.aklivity.zilla.runtime.engine.concurrent.Signaler.NO_CANCEL_ID;
import static java.nio.ByteOrder.BIG_ENDIAN;
import static java.nio.ByteOrder.nativeOrder;
import static java.nio.charset.StandardCharsets.US_ASCII;
Expand Down Expand Up @@ -56,13 +57,15 @@
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.FlushFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.HttpBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.ResetFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.SignalFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WindowFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsDataExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsEndExFW;
import io.aklivity.zilla.runtime.engine.EngineContext;
import io.aklivity.zilla.runtime.engine.binding.BindingHandler;
import io.aklivity.zilla.runtime.engine.binding.function.MessageConsumer;
import io.aklivity.zilla.runtime.engine.concurrent.Signaler;
import io.aklivity.zilla.runtime.engine.config.BindingConfig;

public final class WsServerFactory implements WsStreamFactory
Expand All @@ -74,12 +77,15 @@ public final class WsServerFactory implements WsStreamFactory
private static final String WEBSOCKET_VERSION_13 = "13";
private static final int MAXIMUM_HEADER_SIZE = 14;

private static final int PONG_SIGNAL_ID = 1;

private static final DirectBuffer CLOSE_PAYLOAD = new UnsafeBuffer(new byte[0]);

private final MessageDigest sha1 = initSHA1();

private final BeginFW beginRO = new BeginFW();
private final DataFW dataRO = new DataFW();
private final SignalFW signalRO = new SignalFW();
private final EndFW endRO = new EndFW();
private final AbortFW abortRO = new AbortFW();
private final FlushFW flushRO = new FlushFW();
Expand All @@ -98,6 +104,7 @@ public final class WsServerFactory implements WsStreamFactory
private final WsDataExFW.Builder wsDataExRW = new WsDataExFW.Builder();
private final WsEndExFW.Builder wsEndExRW = new WsEndExFW.Builder();

private final OctetsFW.Builder payloadRW = new OctetsFW.Builder();
private final WindowFW.Builder windowRW = new WindowFW.Builder();
private final ResetFW.Builder resetRW = new ResetFW.Builder();
private final ChallengeFW.Builder challengeRW = new ChallengeFW.Builder();
Expand All @@ -113,9 +120,11 @@ public final class WsServerFactory implements WsStreamFactory
private final WsHeaderFW.Builder wsHeaderRW = new WsHeaderFW.Builder();

private final MutableDirectBuffer writeBuffer;
private final MutableDirectBuffer extBuffer;
private final BindingHandler streamFactory;
private final LongUnaryOperator supplyInitialId;
private final LongUnaryOperator supplyReplyId;
private final Signaler signaler;

private final Long2ObjectHashMap<WsBindingConfig> bindings;
private final int wsTypeId;
Expand All @@ -126,10 +135,12 @@ public WsServerFactory(
EngineContext context)
{
this.writeBuffer = context.writeBuffer();
this.extBuffer = new UnsafeBuffer(new byte[context.writeBuffer().capacity()]);
this.streamFactory = context.streamFactory();
this.supplyInitialId = context::supplyInitialId;
this.supplyReplyId = context::supplyReplyId;
this.bindings = new Long2ObjectHashMap<>();
this.signaler = context.signaler();
this.wsTypeId = context.supplyTypeId(WsBinding.NAME);
this.httpTypeId = context.supplyTypeId("http");
}
Expand Down Expand Up @@ -284,6 +295,9 @@ private final class WsServer
private int replyMax;
private int replyPad;

private long pongId = NO_CANCEL_ID;
private int pingReceived;

private WsServer(
MessageConsumer receiver,
long originId,
Expand Down Expand Up @@ -506,6 +520,10 @@ private void onNetMessage(
final DataFW data = dataRO.wrap(buffer, index, index + length);
onNetData(data);
break;
case SignalFW.TYPE_ID:
final SignalFW signal = signalRO.wrap(buffer, index, index + length);
onNetSignal(signal);
break;
case EndFW.TYPE_ID:
final EndFW end = endRO.wrap(buffer, index, index + length);
onNetEnd(end);
Expand Down Expand Up @@ -600,6 +618,23 @@ private void onNetData(
}
}

private void onNetSignal(
SignalFW signal)
{
final int signalId = signal.signalId();
final long traceId = signal.traceId();
final OctetsFW payload = signal.payload();

assert signalId == PONG_SIGNAL_ID;

if (--pingReceived == 0)
{
final int reserved = payload.sizeof() + MAXIMUM_HEADER_SIZE + replyPad;

doNetData(traceId, decodeAuthorization, replyBudgetId, reserved, payload, 0x8a);
}
}

private void onNetEnd(
EndFW end)
{
Expand Down Expand Up @@ -747,6 +782,9 @@ private int decodeHeader(
case 0x08:
this.decodeState = this::decodeClose;
break;
case 0x09:
this.decodeState = this::decodePing;
break;
case 0x0a:
this.decodeState = this::decodePong;
break;
Expand Down Expand Up @@ -904,6 +942,44 @@ private int decodePong(
}
}

private int decodePing(
final DirectBuffer buffer,
final int offset,
final int length)
{
if (payloadLength > MAXIMUM_CONTROL_FRAME_PAYLOAD_SIZE)
{
doNetReset(decodeTraceId, decodeAuthorization);
stream.doAppAbort(decodeTraceId, decodeAuthorization, STATUS_PROTOCOL_ERROR);
return length;
}
else
{
final int decodeBytes = Math.min(length, payloadLength - payloadProgress);

OctetsFW payload = payloadRO.wrap(buffer, offset, offset + decodeBytes);

OctetsFW.Builder payloadBuilder = payloadRW.wrap(extBuffer, 0, extBuffer.capacity());
payloadBuilder.set(payload);
xor(extBuffer, 0, payload.sizeof(), maskingKey);
OctetsFW unmaskedPayload = payloadBuilder.build();

pingReceived++;
signaler.signalNow(originId, routedId, replyId, decodeTraceId, PONG_SIGNAL_ID, 0,
unmaskedPayload.value(), 0, unmaskedPayload.sizeof());

payloadProgress += decodeBytes;
maskingKey = rotateMaskingKey(maskingKey, decodeBytes);

if (payloadProgress == payloadLength)
{
this.decodeState = this::decodeHeader;
}

return decodeBytes;
}
}

private int decodeUnexpected(
final DirectBuffer directBuffer,
final int offset,
Expand Down
Loading