Skip to content

Commit

Permalink
Close connection when invalid framing is detected (#8692)
Browse files Browse the repository at this point in the history
* Ensure message is always set to null before deframing/deserializing to ensure that corruption triggers connection shutdown.

* Close connection when invalid framing is detected

* Fix tests

* Avoid complex exception filters

* Update src/Orleans.Core/Messaging/InvalidMessageFrameException.cs

Co-authored-by: David Pine <[email protected]>

---------

Co-authored-by: David Pine <[email protected]>
  • Loading branch information
ReubenBond and IEvangelist authored Oct 31, 2023
1 parent eef2802 commit d3e7d94
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 18 deletions.
41 changes: 41 additions & 0 deletions src/Orleans.Core/Messaging/InvalidMessageFrameException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#nullable enable

using System;
using System.Runtime.Serialization;

namespace Orleans.Runtime.Messaging;

/// <summary>
/// Indicates that a message frame is invalid, either when sending a message or receiving a message.
/// </summary>
[GenerateSerializer]
public sealed class InvalidMessageFrameException : OrleansException
{
/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
public InvalidMessageFrameException()
{
}

/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
/// <param name="message">The message that describes the error.</param>
public InvalidMessageFrameException(string message) : base(message)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
/// <param name="message">The message that describes the error.</param>
/// <param name="innerException">The exception that is the cause of the current exception.</param>
public InvalidMessageFrameException(string message, Exception innerException) : base(message, innerException)
{
}

protected InvalidMessageFrameException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
5 changes: 3 additions & 2 deletions src/Orleans.Core/Messaging/MessageSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ private ResponseCodec GetRawCodec(Type fieldType)
}

var bodyLength = bufferWriter.CommittedBytes - headerLength;

// Before completing, check lengths
ThrowIfLengthsInvalid(headerLength, bodyLength);

Expand Down Expand Up @@ -217,8 +218,8 @@ private void ThrowIfLengthsInvalid(int headerLength, int bodyLength)
if ((uint)bodyLength > (uint)_maxBodyLength) ThrowInvalidBodyLength(bodyLength);
}

private void ThrowInvalidHeaderLength(int headerLength) => throw new OrleansException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})");
private void ThrowInvalidBodyLength(int bodyLength) => throw new OrleansException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})");
private void ThrowInvalidHeaderLength(int headerLength) => throw new InvalidMessageFrameException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})");
private void ThrowInvalidBodyLength(int bodyLength) => throw new InvalidMessageFrameException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})");

private void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, Message value, PackedHeaders headers) where TBufferWriter : IBufferWriter<byte>
{
Expand Down
41 changes: 28 additions & 13 deletions src/Orleans.Core/Networking/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ private async Task ProcessIncoming()
{
var input = this._transport.Input;
var requiredBytes = 0;
Message message = default;
while (true)
{
var readResult = await input.ReadAsync();
Expand All @@ -298,6 +297,7 @@ private async Task ProcessIncoming()
{
do
{
Message message = default;
try
{
int headerLength, bodyLength;
Expand All @@ -309,11 +309,14 @@ private async Task ProcessIncoming()
var handler = MessageHandlerPool.Get();
handler.Set(message, this);
ThreadPool.UnsafeQueueUserWorkItem(handler, preferLocal: true);
message = null;
}
}
catch (Exception exception) when (this.HandleReceiveMessageFailure(message, exception))
catch (Exception exception)
{
if (!HandleReceiveMessageFailure(message, exception))
{
throw;
}
}
} while (requiredBytes == 0);
}
Expand Down Expand Up @@ -375,9 +378,12 @@ private async Task ProcessOutgoing()
message = null;
}
}
catch (Exception exception) when (message != default)
catch (Exception exception)
{
this.OnMessageSerializationFailure(message, exception);
if (!HandleSendMessageFailure(message, exception))
{
throw;
}
}

var flushResult = await output.FlushAsync();
Expand Down Expand Up @@ -444,15 +450,15 @@ private static EndPoint NormalizeEndpoint(EndPoint endpoint)
/// <returns><see langword="true"/> if the exception should not be caught and <see langword="false"/> if it should be caught.</returns>
private bool HandleReceiveMessageFailure(Message message, Exception exception)
{
this.Log.LogWarning(
this.Log.LogError(
exception,
"Exception reading message {Message} from remote endpoint {Remote} to local endpoint {Local}",
message,
this.RemoteEndPoint,
this.LocalEndPoint);

// If deserialization completely failed, rethrow the exception so that it can be handled at another level.
if (message is null)
if (message is null || exception is InvalidMessageFrameException)
{
// Returning false here informs the caller that the exception should not be caught.
return false;
Expand Down Expand Up @@ -486,16 +492,23 @@ private bool HandleReceiveMessageFailure(Message message, Exception exception)
return true;
}

private void OnMessageSerializationFailure(Message message, Exception exception)
private bool HandleSendMessageFailure(Message message, Exception exception)
{
// we only get here if we failed to serialize the msg (or any other catastrophic failure).
// We get here if we failed to serialize the msg (or any other catastrophic failure).
// Request msg fails to serialize on the sender, so we just enqueue a rejection msg.
// Response msg fails to serialize on the responding silo, so we try to send an error response back.
this.Log.LogWarning(
(int)ErrorCode.Messaging_SerializationError,
this.Log.LogError(
exception,
"Unexpected error serializing message {Message}",
message);
"Exception sending message {Message} to remote endpoint {Remote} from local endpoint {Local}",
message,
this.RemoteEndPoint,
this.LocalEndPoint);

if (message is null || exception is InvalidMessageFrameException)
{
// Returning false here informs the caller that the exception should not be caught.
return false;
}

MessagingInstruments.OnFailedSentMessage(message);

Expand Down Expand Up @@ -527,6 +540,8 @@ private void OnMessageSerializationFailure(Message message, Exception exception)

MessagingInstruments.OnDroppedSentMessage(message);
}

return true;
}

private sealed class MessageHandlerPoolPolicy : PooledObjectPolicy<MessageHandler>
Expand Down
6 changes: 3 additions & 3 deletions test/NonSilo.Tests/Serialization/MessageSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void Message_SerializeHeaderTooBig()

var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0));
var writer = pipe.Writer;
Assert.Throws<OrleansException>(() => this.messageSerializer.Write(writer, message));
Assert.Throws<InvalidMessageFrameException>(() => this.messageSerializer.Write(writer, message));
}
finally
{
Expand All @@ -85,7 +85,7 @@ public void Message_SerializeBodyTooBig()

var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0));
var writer = pipe.Writer;
Assert.Throws<OrleansException>(() => this.messageSerializer.Write(writer, message));
Assert.Throws<InvalidMessageFrameException>(() => this.messageSerializer.Write(writer, message));
}

[Fact, TestCategory("Functional"), TestCategory("Serialization")]
Expand Down Expand Up @@ -119,7 +119,7 @@ private void DeserializeFakeMessage(int headerSize, int bodySize)

pipe.Reader.TryRead(out var readResult);
var reader = readResult.Buffer;
Assert.Throws<OrleansException>(() => this.messageSerializer.TryRead(ref reader, out var message));
Assert.Throws<InvalidMessageFrameException>(() => this.messageSerializer.TryRead(ref reader, out var message));
}

private Message RoundTripMessage(Message message)
Expand Down

0 comments on commit d3e7d94

Please sign in to comment.