diff --git a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs index fae31cd3510e..ca2834411b6e 100644 --- a/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -453,6 +453,7 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Arra // Write the message header data to the buffer. int headerLength; + int? maskOffset = null; if (_isServer) { // The server doesn't send a mask, so the mask offset returned by WriteHeader @@ -463,15 +464,20 @@ private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, Arra { // We need to know where the mask starts so that we can use the mask to manipulate the payload data, // and we need to know the total length for sending it on the wire. - int maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); - headerLength = maskOffset + MaskLength; + maskOffset = WriteHeader(opcode, _sendBuffer, payloadBuffer, endOfMessage, useMask: true); + headerLength = maskOffset.GetValueOrDefault() + MaskLength; + } + + // Write the payload + if (payloadBuffer.Count > 0) + { + Buffer.BlockCopy(payloadBuffer.Array, payloadBuffer.Offset, _sendBuffer, headerLength, payloadBuffer.Count); - // If there is payload data, XOR it with the mask. We do the manipulation in the send buffer so as to avoid + // If we added a mask to the header, XOR the payload with the mask. We do the manipulation in the send buffer so as to avoid // changing the data in the caller-supplied payload buffer. - if (payloadBuffer.Count > 0) + if (maskOffset.HasValue) { - Buffer.BlockCopy(payloadBuffer.Array, payloadBuffer.Offset, _sendBuffer, headerLength, payloadBuffer.Count); - ApplyMask(_sendBuffer, headerLength, _sendBuffer, maskOffset, 0, payloadBuffer.Count); + ApplyMask(_sendBuffer, headerLength, _sendBuffer, maskOffset.Value, 0, payloadBuffer.Count); } } @@ -914,6 +920,9 @@ private bool TryParseMessageHeaderFromReceiveBuffer(out MessageHeader resultHead shouldFail = true; } header.Mask = CombineMaskBytes(_receiveBuffer, _receiveBufferOffset); + + // Consume the mask bytes + ConsumeFromBuffer(4); } // Do basic validation of the header