Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Close connection when invalid framing is detected #8692

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/Orleans.Core/Messaging/InvalidMessageFrameException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#nullable enable

using System;
using System.Runtime.Serialization;

namespace Orleans.Runtime.Messaging;

/// <summary>
/// Indicates that a message frame is invalid, either when sending a message or receiving a message.
/// </summary>
[GenerateSerializer]
public sealed class InvalidMessageFrameException : OrleansException
{
/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
public InvalidMessageFrameException()
{
}

/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
/// <param name="message">The message that describes the error.</param>
public InvalidMessageFrameException(string message) : base(message)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="InvalidMessageFrameException"/> class.
/// </summary>
/// <param name="message">The message that describes the error.</param>
/// <param name="innerException">The exception that is the cause of the current exception.</param>
public InvalidMessageFrameException(string message, Exception innerException) : base(message, innerException)
{
}

protected InvalidMessageFrameException(SerializationInfo info, StreamingContext context) : base(info, context)
{
}
}
5 changes: 3 additions & 2 deletions src/Orleans.Core/Messaging/MessageSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ private ResponseCodec GetRawCodec(Type fieldType)
}

var bodyLength = bufferWriter.CommittedBytes - headerLength;

// Before completing, check lengths
ThrowIfLengthsInvalid(headerLength, bodyLength);

Expand Down Expand Up @@ -217,8 +218,8 @@ private void ThrowIfLengthsInvalid(int headerLength, int bodyLength)
if ((uint)bodyLength > (uint)_maxBodyLength) ThrowInvalidBodyLength(bodyLength);
}

private void ThrowInvalidHeaderLength(int headerLength) => throw new OrleansException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})");
private void ThrowInvalidBodyLength(int bodyLength) => throw new OrleansException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})");
private void ThrowInvalidHeaderLength(int headerLength) => throw new InvalidMessageFrameException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})");
private void ThrowInvalidBodyLength(int bodyLength) => throw new InvalidMessageFrameException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})");

private void Serialize<TBufferWriter>(ref Writer<TBufferWriter> writer, Message value, PackedHeaders headers) where TBufferWriter : IBufferWriter<byte>
{
Expand Down
41 changes: 28 additions & 13 deletions src/Orleans.Core/Networking/Connection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ private async Task ProcessIncoming()
{
var input = this._transport.Input;
var requiredBytes = 0;
Message message = default;
while (true)
{
var readResult = await input.ReadAsync();
Expand All @@ -298,6 +297,7 @@ private async Task ProcessIncoming()
{
do
{
Message message = default;
try
{
int headerLength, bodyLength;
Expand All @@ -309,11 +309,14 @@ private async Task ProcessIncoming()
var handler = MessageHandlerPool.Get();
handler.Set(message, this);
ThreadPool.UnsafeQueueUserWorkItem(handler, preferLocal: true);
message = null;
}
}
catch (Exception exception) when (this.HandleReceiveMessageFailure(message, exception))
catch (Exception exception)
{
if (!HandleReceiveMessageFailure(message, exception))
{
throw;
}
}
} while (requiredBytes == 0);
}
Expand Down Expand Up @@ -375,9 +378,12 @@ private async Task ProcessOutgoing()
message = null;
}
}
catch (Exception exception) when (message != default)
catch (Exception exception)
{
this.OnMessageSerializationFailure(message, exception);
if (!HandleSendMessageFailure(message, exception))
{
throw;
}
}

var flushResult = await output.FlushAsync();
Expand Down Expand Up @@ -444,15 +450,15 @@ private static EndPoint NormalizeEndpoint(EndPoint endpoint)
/// <returns><see langword="true"/> if the exception should not be caught and <see langword="false"/> if it should be caught.</returns>
private bool HandleReceiveMessageFailure(Message message, Exception exception)
{
this.Log.LogWarning(
this.Log.LogError(
exception,
"Exception reading message {Message} from remote endpoint {Remote} to local endpoint {Local}",
message,
this.RemoteEndPoint,
this.LocalEndPoint);

// If deserialization completely failed, rethrow the exception so that it can be handled at another level.
if (message is null)
if (message is null || exception is InvalidMessageFrameException)
{
// Returning false here informs the caller that the exception should not be caught.
return false;
Expand Down Expand Up @@ -486,16 +492,23 @@ private bool HandleReceiveMessageFailure(Message message, Exception exception)
return true;
}

private void OnMessageSerializationFailure(Message message, Exception exception)
private bool HandleSendMessageFailure(Message message, Exception exception)
{
// we only get here if we failed to serialize the msg (or any other catastrophic failure).
// We get here if we failed to serialize the msg (or any other catastrophic failure).
// Request msg fails to serialize on the sender, so we just enqueue a rejection msg.
// Response msg fails to serialize on the responding silo, so we try to send an error response back.
this.Log.LogWarning(
(int)ErrorCode.Messaging_SerializationError,
this.Log.LogError(
exception,
"Unexpected error serializing message {Message}",
message);
"Exception sending message {Message} to remote endpoint {Remote} from local endpoint {Local}",
message,
this.RemoteEndPoint,
this.LocalEndPoint);

if (message is null || exception is InvalidMessageFrameException)
{
// Returning false here informs the caller that the exception should not be caught.
return false;
}

MessagingInstruments.OnFailedSentMessage(message);

Expand Down Expand Up @@ -527,6 +540,8 @@ private void OnMessageSerializationFailure(Message message, Exception exception)

MessagingInstruments.OnDroppedSentMessage(message);
}

return true;
}

private sealed class MessageHandlerPoolPolicy : PooledObjectPolicy<MessageHandler>
Expand Down
6 changes: 3 additions & 3 deletions test/NonSilo.Tests/Serialization/MessageSerializerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void Message_SerializeHeaderTooBig()

var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0));
var writer = pipe.Writer;
Assert.Throws<OrleansException>(() => this.messageSerializer.Write(writer, message));
Assert.Throws<InvalidMessageFrameException>(() => this.messageSerializer.Write(writer, message));
}
finally
{
Expand All @@ -85,7 +85,7 @@ public void Message_SerializeBodyTooBig()

var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0));
var writer = pipe.Writer;
Assert.Throws<OrleansException>(() => this.messageSerializer.Write(writer, message));
Assert.Throws<InvalidMessageFrameException>(() => this.messageSerializer.Write(writer, message));
}

[Fact, TestCategory("Functional"), TestCategory("Serialization")]
Expand Down Expand Up @@ -119,7 +119,7 @@ private void DeserializeFakeMessage(int headerSize, int bodySize)

pipe.Reader.TryRead(out var readResult);
var reader = readResult.Buffer;
Assert.Throws<OrleansException>(() => this.messageSerializer.TryRead(ref reader, out var message));
Assert.Throws<InvalidMessageFrameException>(() => this.messageSerializer.TryRead(ref reader, out var message));
}

private Message RoundTripMessage(Message message)
Expand Down
Loading