diff --git a/src/Orleans.Core/Messaging/InvalidMessageFrameException.cs b/src/Orleans.Core/Messaging/InvalidMessageFrameException.cs new file mode 100644 index 0000000000..28eb9ce971 --- /dev/null +++ b/src/Orleans.Core/Messaging/InvalidMessageFrameException.cs @@ -0,0 +1,41 @@ +#nullable enable + +using System; +using System.Runtime.Serialization; + +namespace Orleans.Runtime.Messaging; + +/// +/// Indicates that a message frame is invalid, either when sending a message or receiving a message. +/// +[GenerateSerializer] +public sealed class InvalidMessageFrameException : OrleansException +{ + /// + /// Initializes a new instance of the class. + /// + public InvalidMessageFrameException() + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The message that describes the error. + public InvalidMessageFrameException(string message) : base(message) + { + } + + /// + /// Initializes a new instance of the class. + /// + /// The message that describes the error. + /// The exception that is the cause of the current exception. + public InvalidMessageFrameException(string message, Exception innerException) : base(message, innerException) + { + } + + protected InvalidMessageFrameException(SerializationInfo info, StreamingContext context) : base(info, context) + { + } +} \ No newline at end of file diff --git a/src/Orleans.Core/Messaging/MessageSerializer.cs b/src/Orleans.Core/Messaging/MessageSerializer.cs index 5de5ad8911..7131d508ae 100644 --- a/src/Orleans.Core/Messaging/MessageSerializer.cs +++ b/src/Orleans.Core/Messaging/MessageSerializer.cs @@ -189,6 +189,7 @@ private ResponseCodec GetRawCodec(Type fieldType) } var bodyLength = bufferWriter.CommittedBytes - headerLength; + // Before completing, check lengths ThrowIfLengthsInvalid(headerLength, bodyLength); @@ -217,8 +218,8 @@ private void ThrowIfLengthsInvalid(int headerLength, int bodyLength) if ((uint)bodyLength > (uint)_maxBodyLength) ThrowInvalidBodyLength(bodyLength); } - private void ThrowInvalidHeaderLength(int headerLength) => throw new OrleansException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})"); - private void ThrowInvalidBodyLength(int bodyLength) => throw new OrleansException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})"); + private void ThrowInvalidHeaderLength(int headerLength) => throw new InvalidMessageFrameException($"Invalid header size: {headerLength} (max configured value is {_maxHeaderLength}, see {nameof(MessagingOptions.MaxMessageHeaderSize)})"); + private void ThrowInvalidBodyLength(int bodyLength) => throw new InvalidMessageFrameException($"Invalid body size: {bodyLength} (max configured value is {_maxBodyLength}, see {nameof(MessagingOptions.MaxMessageBodySize)})"); private void Serialize(ref Writer writer, Message value, PackedHeaders headers) where TBufferWriter : IBufferWriter { diff --git a/src/Orleans.Core/Networking/Connection.cs b/src/Orleans.Core/Networking/Connection.cs index 02944fdacc..d306e641d3 100644 --- a/src/Orleans.Core/Networking/Connection.cs +++ b/src/Orleans.Core/Networking/Connection.cs @@ -288,7 +288,6 @@ private async Task ProcessIncoming() { var input = this._transport.Input; var requiredBytes = 0; - Message message = default; while (true) { var readResult = await input.ReadAsync(); @@ -298,6 +297,7 @@ private async Task ProcessIncoming() { do { + Message message = default; try { int headerLength, bodyLength; @@ -309,11 +309,14 @@ private async Task ProcessIncoming() var handler = MessageHandlerPool.Get(); handler.Set(message, this); ThreadPool.UnsafeQueueUserWorkItem(handler, preferLocal: true); - message = null; } } - catch (Exception exception) when (this.HandleReceiveMessageFailure(message, exception)) + catch (Exception exception) { + if (!HandleReceiveMessageFailure(message, exception)) + { + throw; + } } } while (requiredBytes == 0); } @@ -375,9 +378,12 @@ private async Task ProcessOutgoing() message = null; } } - catch (Exception exception) when (message != default) + catch (Exception exception) { - this.OnMessageSerializationFailure(message, exception); + if (!HandleSendMessageFailure(message, exception)) + { + throw; + } } var flushResult = await output.FlushAsync(); @@ -444,7 +450,7 @@ private static EndPoint NormalizeEndpoint(EndPoint endpoint) /// if the exception should not be caught and if it should be caught. private bool HandleReceiveMessageFailure(Message message, Exception exception) { - this.Log.LogWarning( + this.Log.LogError( exception, "Exception reading message {Message} from remote endpoint {Remote} to local endpoint {Local}", message, @@ -452,7 +458,7 @@ private bool HandleReceiveMessageFailure(Message message, Exception exception) this.LocalEndPoint); // If deserialization completely failed, rethrow the exception so that it can be handled at another level. - if (message is null) + if (message is null || exception is InvalidMessageFrameException) { // Returning false here informs the caller that the exception should not be caught. return false; @@ -486,16 +492,23 @@ private bool HandleReceiveMessageFailure(Message message, Exception exception) return true; } - private void OnMessageSerializationFailure(Message message, Exception exception) + private bool HandleSendMessageFailure(Message message, Exception exception) { - // we only get here if we failed to serialize the msg (or any other catastrophic failure). + // We get here if we failed to serialize the msg (or any other catastrophic failure). // Request msg fails to serialize on the sender, so we just enqueue a rejection msg. // Response msg fails to serialize on the responding silo, so we try to send an error response back. - this.Log.LogWarning( - (int)ErrorCode.Messaging_SerializationError, + this.Log.LogError( exception, - "Unexpected error serializing message {Message}", - message); + "Exception sending message {Message} to remote endpoint {Remote} from local endpoint {Local}", + message, + this.RemoteEndPoint, + this.LocalEndPoint); + + if (message is null || exception is InvalidMessageFrameException) + { + // Returning false here informs the caller that the exception should not be caught. + return false; + } MessagingInstruments.OnFailedSentMessage(message); @@ -527,6 +540,8 @@ private void OnMessageSerializationFailure(Message message, Exception exception) MessagingInstruments.OnDroppedSentMessage(message); } + + return true; } private sealed class MessageHandlerPoolPolicy : PooledObjectPolicy diff --git a/test/NonSilo.Tests/Serialization/MessageSerializerTests.cs b/test/NonSilo.Tests/Serialization/MessageSerializerTests.cs index 110925610a..af8caea53f 100644 --- a/test/NonSilo.Tests/Serialization/MessageSerializerTests.cs +++ b/test/NonSilo.Tests/Serialization/MessageSerializerTests.cs @@ -65,7 +65,7 @@ public void Message_SerializeHeaderTooBig() var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0)); var writer = pipe.Writer; - Assert.Throws(() => this.messageSerializer.Write(writer, message)); + Assert.Throws(() => this.messageSerializer.Write(writer, message)); } finally { @@ -85,7 +85,7 @@ public void Message_SerializeBodyTooBig() var pipe = new Pipe(new PipeOptions(pauseWriterThreshold: 0)); var writer = pipe.Writer; - Assert.Throws(() => this.messageSerializer.Write(writer, message)); + Assert.Throws(() => this.messageSerializer.Write(writer, message)); } [Fact, TestCategory("Functional"), TestCategory("Serialization")] @@ -119,7 +119,7 @@ private void DeserializeFakeMessage(int headerSize, int bodySize) pipe.Reader.TryRead(out var readResult); var reader = readResult.Buffer; - Assert.Throws(() => this.messageSerializer.TryRead(ref reader, out var message)); + Assert.Throws(() => this.messageSerializer.TryRead(ref reader, out var message)); } private Message RoundTripMessage(Message message)