Skip to content

Commit

Permalink
WebSocket inbound ping frames support (#629)
Browse files Browse the repository at this point in the history
  • Loading branch information
akrambek authored Dec 8, 2023
1 parent 9bd1de6 commit e71c430
Show file tree
Hide file tree
Showing 20 changed files with 586 additions and 122 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,15 @@
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.FlushFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.HttpBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.ResetFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.SignalFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WindowFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsDataExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsEndExFW;
import io.aklivity.zilla.runtime.engine.EngineContext;
import io.aklivity.zilla.runtime.engine.binding.BindingHandler;
import io.aklivity.zilla.runtime.engine.binding.function.MessageConsumer;
import io.aklivity.zilla.runtime.engine.concurrent.Signaler;
import io.aklivity.zilla.runtime.engine.config.BindingConfig;

public final class WsClientFactory implements WsStreamFactory
Expand All @@ -77,13 +79,15 @@ public final class WsClientFactory implements WsStreamFactory
private static final String WEBSOCKET_UPGRADE = "websocket";
private static final String WEBSOCKET_VERSION_13 = "13";
private static final int MAXIMUM_HEADER_SIZE = 14;
private static final int PONG_SIGNAL_ID = 1;

private static final DirectBuffer CLOSE_PAYLOAD = new UnsafeBuffer(new byte[0]);

private final MessageDigest sha1 = initSHA1();

private final BeginFW beginRO = new BeginFW();
private final DataFW dataRO = new DataFW();
private final SignalFW signalRO = new SignalFW();
private final EndFW endRO = new EndFW();
private final AbortFW abortRO = new AbortFW();
private final FlushFW flushRO = new FlushFW();
Expand All @@ -106,6 +110,8 @@ public final class WsClientFactory implements WsStreamFactory
private final ResetFW.Builder resetRW = new ResetFW.Builder();
private final ChallengeFW.Builder challengeRW = new ChallengeFW.Builder();

private final OctetsFW.Builder payloadRW = new OctetsFW.Builder();

private final OctetsFW payloadRO = new OctetsFW();

private final HttpBeginExFW httpBeginExRO = new HttpBeginExFW();
Expand All @@ -118,9 +124,11 @@ public final class WsClientFactory implements WsStreamFactory
private final WsHeaderFW.Builder wsHeaderRW = new WsHeaderFW.Builder();

private final MutableDirectBuffer writeBuffer;
private final MutableDirectBuffer extBuffer;
private final BindingHandler streamFactory;
private final LongUnaryOperator supplyInitialId;
private final LongUnaryOperator supplyReplyId;
private final Signaler signaler;

private final Long2ObjectHashMap<WsBindingConfig> bindings;
private final int wsTypeId;
Expand All @@ -131,10 +139,12 @@ public WsClientFactory(
EngineContext context)
{
this.writeBuffer = context.writeBuffer();
this.extBuffer = new UnsafeBuffer(new byte[context.writeBuffer().capacity()]);
this.streamFactory = context.streamFactory();
this.supplyInitialId = context::supplyInitialId;
this.supplyReplyId = context::supplyReplyId;
this.bindings = new Long2ObjectHashMap<>();
this.signaler = context.signaler();
this.wsTypeId = context.supplyTypeId(WsBinding.NAME);
this.httpTypeId = context.supplyTypeId("http");
}
Expand Down Expand Up @@ -713,6 +723,7 @@ private final class WsClient

private int statusLength;
private MutableDirectBuffer status;
private int pingReceived;

private WsClient(
long originId,
Expand Down Expand Up @@ -931,6 +942,10 @@ private void onNetMessage(
final DataFW data = dataRO.wrap(buffer, index, index + length);
onNetData(data);
break;
case SignalFW.TYPE_ID:
final SignalFW signal = signalRO.wrap(buffer, index, index + length);
onNetSignal(signal);
break;
case EndFW.TYPE_ID:
final EndFW end = endRO.wrap(buffer, index, index + length);
onNetEnd(end);
Expand Down Expand Up @@ -1069,6 +1084,23 @@ private void onNetData(
}
}

private void onNetSignal(
SignalFW signal)
{
final int signalId = signal.signalId();
final long traceId = signal.traceId();
final OctetsFW payload = signal.payload();

assert signalId == PONG_SIGNAL_ID;

if (--pingReceived == 0)
{
final int reserved = payload.sizeof() + MAXIMUM_HEADER_SIZE + replyPad;

doNetData(traceId, decodeAuthorization, initialBudgetId, reserved, payload, 0x8a);
}
}

private void onNetEnd(
EndFW end)
{
Expand Down Expand Up @@ -1251,6 +1283,9 @@ private int decodeHeader(
case 0x08:
this.decodeState = this::decodeClose;
break;
case 0x09:
this.decodeState = this::decodePing;
break;
case 0x0a:
this.decodeState = this::decodePong;
break;
Expand Down Expand Up @@ -1395,6 +1430,44 @@ private int decodeClose(
}
}

private int decodePing(
final DirectBuffer buffer,
final int offset,
final int length)
{
if (payloadLength > MAXIMUM_CONTROL_FRAME_PAYLOAD_SIZE)
{
doNetReset(decodeTraceId, decodeAuthorization);
doAppAbort(decodeTraceId, decodeAuthorization, STATUS_PROTOCOL_ERROR);
return length;
}
else
{
final int decodeBytes = Math.min(length, payloadLength - payloadProgress);

OctetsFW payload = payloadRO.wrap(buffer, offset, offset + decodeBytes);

OctetsFW.Builder payloadBuilder = payloadRW.wrap(extBuffer, 0, extBuffer.capacity());
payloadBuilder.set(payload);
xor(extBuffer, 0, payload.sizeof(), maskingKey);
OctetsFW unmaskedPayload = payloadBuilder.build();

pingReceived++;
signaler.signalNow(originId, routedId, initialId, decodeTraceId, PONG_SIGNAL_ID, 0,
unmaskedPayload.value(), 0, unmaskedPayload.sizeof());

payloadProgress += decodeBytes;
maskingKey = rotateMaskingKey(maskingKey, decodeBytes);

if (payloadProgress == payloadLength)
{
this.decodeState = this::decodeHeader;
}

return decodeBytes;
}
}

private int decodePong(
final DirectBuffer buffer,
final int offset,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static io.aklivity.zilla.runtime.binding.ws.internal.types.codec.WsHeaderFW.STATUS_PROTOCOL_ERROR;
import static io.aklivity.zilla.runtime.binding.ws.internal.types.codec.WsHeaderFW.STATUS_UNEXPECTED_CONDITION;
import static io.aklivity.zilla.runtime.binding.ws.internal.util.WsMaskUtil.xor;
import static io.aklivity.zilla.runtime.engine.concurrent.Signaler.NO_CANCEL_ID;
import static java.nio.ByteOrder.BIG_ENDIAN;
import static java.nio.ByteOrder.nativeOrder;
import static java.nio.charset.StandardCharsets.US_ASCII;
Expand Down Expand Up @@ -56,13 +57,15 @@
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.FlushFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.HttpBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.ResetFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.SignalFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WindowFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsBeginExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsDataExFW;
import io.aklivity.zilla.runtime.binding.ws.internal.types.stream.WsEndExFW;
import io.aklivity.zilla.runtime.engine.EngineContext;
import io.aklivity.zilla.runtime.engine.binding.BindingHandler;
import io.aklivity.zilla.runtime.engine.binding.function.MessageConsumer;
import io.aklivity.zilla.runtime.engine.concurrent.Signaler;
import io.aklivity.zilla.runtime.engine.config.BindingConfig;

public final class WsServerFactory implements WsStreamFactory
Expand All @@ -74,12 +77,15 @@ public final class WsServerFactory implements WsStreamFactory
private static final String WEBSOCKET_VERSION_13 = "13";
private static final int MAXIMUM_HEADER_SIZE = 14;

private static final int PONG_SIGNAL_ID = 1;

private static final DirectBuffer CLOSE_PAYLOAD = new UnsafeBuffer(new byte[0]);

private final MessageDigest sha1 = initSHA1();

private final BeginFW beginRO = new BeginFW();
private final DataFW dataRO = new DataFW();
private final SignalFW signalRO = new SignalFW();
private final EndFW endRO = new EndFW();
private final AbortFW abortRO = new AbortFW();
private final FlushFW flushRO = new FlushFW();
Expand All @@ -98,6 +104,7 @@ public final class WsServerFactory implements WsStreamFactory
private final WsDataExFW.Builder wsDataExRW = new WsDataExFW.Builder();
private final WsEndExFW.Builder wsEndExRW = new WsEndExFW.Builder();

private final OctetsFW.Builder payloadRW = new OctetsFW.Builder();
private final WindowFW.Builder windowRW = new WindowFW.Builder();
private final ResetFW.Builder resetRW = new ResetFW.Builder();
private final ChallengeFW.Builder challengeRW = new ChallengeFW.Builder();
Expand All @@ -113,9 +120,11 @@ public final class WsServerFactory implements WsStreamFactory
private final WsHeaderFW.Builder wsHeaderRW = new WsHeaderFW.Builder();

private final MutableDirectBuffer writeBuffer;
private final MutableDirectBuffer extBuffer;
private final BindingHandler streamFactory;
private final LongUnaryOperator supplyInitialId;
private final LongUnaryOperator supplyReplyId;
private final Signaler signaler;

private final Long2ObjectHashMap<WsBindingConfig> bindings;
private final int wsTypeId;
Expand All @@ -126,10 +135,12 @@ public WsServerFactory(
EngineContext context)
{
this.writeBuffer = context.writeBuffer();
this.extBuffer = new UnsafeBuffer(new byte[context.writeBuffer().capacity()]);
this.streamFactory = context.streamFactory();
this.supplyInitialId = context::supplyInitialId;
this.supplyReplyId = context::supplyReplyId;
this.bindings = new Long2ObjectHashMap<>();
this.signaler = context.signaler();
this.wsTypeId = context.supplyTypeId(WsBinding.NAME);
this.httpTypeId = context.supplyTypeId("http");
}
Expand Down Expand Up @@ -284,6 +295,9 @@ private final class WsServer
private int replyMax;
private int replyPad;

private long pongId = NO_CANCEL_ID;
private int pingReceived;

private WsServer(
MessageConsumer receiver,
long originId,
Expand Down Expand Up @@ -506,6 +520,10 @@ private void onNetMessage(
final DataFW data = dataRO.wrap(buffer, index, index + length);
onNetData(data);
break;
case SignalFW.TYPE_ID:
final SignalFW signal = signalRO.wrap(buffer, index, index + length);
onNetSignal(signal);
break;
case EndFW.TYPE_ID:
final EndFW end = endRO.wrap(buffer, index, index + length);
onNetEnd(end);
Expand Down Expand Up @@ -600,6 +618,23 @@ private void onNetData(
}
}

private void onNetSignal(
SignalFW signal)
{
final int signalId = signal.signalId();
final long traceId = signal.traceId();
final OctetsFW payload = signal.payload();

assert signalId == PONG_SIGNAL_ID;

if (--pingReceived == 0)
{
final int reserved = payload.sizeof() + MAXIMUM_HEADER_SIZE + replyPad;

doNetData(traceId, decodeAuthorization, replyBudgetId, reserved, payload, 0x8a);
}
}

private void onNetEnd(
EndFW end)
{
Expand Down Expand Up @@ -747,6 +782,9 @@ private int decodeHeader(
case 0x08:
this.decodeState = this::decodeClose;
break;
case 0x09:
this.decodeState = this::decodePing;
break;
case 0x0a:
this.decodeState = this::decodePong;
break;
Expand Down Expand Up @@ -904,6 +942,44 @@ private int decodePong(
}
}

private int decodePing(
final DirectBuffer buffer,
final int offset,
final int length)
{
if (payloadLength > MAXIMUM_CONTROL_FRAME_PAYLOAD_SIZE)
{
doNetReset(decodeTraceId, decodeAuthorization);
stream.doAppAbort(decodeTraceId, decodeAuthorization, STATUS_PROTOCOL_ERROR);
return length;
}
else
{
final int decodeBytes = Math.min(length, payloadLength - payloadProgress);

OctetsFW payload = payloadRO.wrap(buffer, offset, offset + decodeBytes);

OctetsFW.Builder payloadBuilder = payloadRW.wrap(extBuffer, 0, extBuffer.capacity());
payloadBuilder.set(payload);
xor(extBuffer, 0, payload.sizeof(), maskingKey);
OctetsFW unmaskedPayload = payloadBuilder.build();

pingReceived++;
signaler.signalNow(originId, routedId, replyId, decodeTraceId, PONG_SIGNAL_ID, 0,
unmaskedPayload.value(), 0, unmaskedPayload.sizeof());

payloadProgress += decodeBytes;
maskingKey = rotateMaskingKey(maskingKey, decodeBytes);

if (payloadProgress == payloadLength)
{
this.decodeState = this::decodeHeader;
}

return decodeBytes;
}
}

private int decodeUnexpected(
final DirectBuffer directBuffer,
final int offset,
Expand Down
Loading

0 comments on commit e71c430

Please sign in to comment.