From 3b6078ea676330d0bf09e56f49e790b17a68dd5b Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 17 Sep 2024 12:37:29 -0400 Subject: [PATCH] Add missing Stream overrides It's important for perf to override the base implementations, in particular for Read/WriteByte. I just copied the existing implementations to the new overrides and fixed up the implementations with as minimal changes as possible. I also tried to retain the style of the surrounding code. And none of the existing implementations have tests, so I didn't add any for these, as much as it pained me :) --- .../ChunkedBufferStream.cs | 73 ++++++++++++++++ .../ConcatenatedReadStream.cs | 86 +++++++++++++++++++ src/DotUtils.StreamUtils/StreamExtensions.cs | 23 +++-- src/DotUtils.StreamUtils/SubStream.cs | 42 ++++++++- .../TransparentReadStream.cs | 63 +++++++++++++- 5 files changed, 271 insertions(+), 16 deletions(-) diff --git a/src/DotUtils.StreamUtils/ChunkedBufferStream.cs b/src/DotUtils.StreamUtils/ChunkedBufferStream.cs index 9b6ba1a..177a21e 100644 --- a/src/DotUtils.StreamUtils/ChunkedBufferStream.cs +++ b/src/DotUtils.StreamUtils/ChunkedBufferStream.cs @@ -1,5 +1,7 @@ using System; using System.IO; +using System.Threading; +using System.Threading.Tasks; namespace DotUtils.StreamUtils; @@ -58,6 +60,77 @@ public override void Write(byte[] buffer, int offset, int count) } while (count > 0); } + public override void WriteByte(byte value) + { + if (_position == _buffer.Length) + { + Flush(); + } + + _buffer[_position++] = value; + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + // Appends input to the buffer until it is full - then flushes it to the wrapped stream. + // Repeat above until all input is processed. + + int srcOffset = offset; + do + { + int currentCount = Math.Min(count, _buffer.Length - _position); + Buffer.BlockCopy(buffer, srcOffset, _buffer, _position, currentCount); + _position += currentCount; + count -= currentCount; + srcOffset += currentCount; + + if (_position == _buffer.Length) + { + await FlushAsync(cancellationToken).ConfigureAwait(false); + } + } while (count > 0); + } + +#if NET + public override void Write(ReadOnlySpan buffer) + { + // Appends input to the buffer until it is full - then flushes it to the wrapped stream. + // Repeat above until all input is processed. + + do + { + int currentCount = Math.Min(buffer.Length, _buffer.Length - _position); + buffer.CopyTo(_buffer.AsSpan(_position, currentCount)); + _position += currentCount; + buffer = buffer.Slice(currentCount); + + if (_position == _buffer.Length) + { + Flush(); + } + } while (!buffer.IsEmpty); + } + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + // Appends input to the buffer until it is full - then flushes it to the wrapped stream. + // Repeat above until all input is processed. + + do + { + int currentCount = Math.Min(buffer.Length, _buffer.Length - _position); + buffer.CopyTo(_buffer.AsMemory(_position, currentCount)); + _position += currentCount; + buffer = buffer.Slice(currentCount); + + if (_position == _buffer.Length) + { + await FlushAsync(cancellationToken).ConfigureAwait(false); + } + } while (!buffer.IsEmpty); + } +#endif + public override bool CanRead => false; public override bool CanSeek => false; public override bool CanWrite => _stream.CanWrite; diff --git a/src/DotUtils.StreamUtils/ConcatenatedReadStream.cs b/src/DotUtils.StreamUtils/ConcatenatedReadStream.cs index 7009321..af4cfee 100644 --- a/src/DotUtils.StreamUtils/ConcatenatedReadStream.cs +++ b/src/DotUtils.StreamUtils/ConcatenatedReadStream.cs @@ -2,6 +2,8 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using System.Threading; +using System.Threading.Tasks; namespace DotUtils.StreamUtils; @@ -70,6 +72,90 @@ public override int Read(byte[] buffer, int offset, int count) return totalBytesRead; } + public override int ReadByte() + { + while (_streams.Count > 0) + { + int value = _streams.Peek().ReadByte(); + if (value < 0) + { + _streams.Dequeue().Dispose(); + continue; + } + + _position++; + return value; + } + + return -1; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int totalBytesRead = 0; + + while (count > 0 && _streams.Count > 0) + { + int bytesRead = await _streams.Peek().ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + _streams.Dequeue().Dispose(); + continue; + } + + totalBytesRead += bytesRead; + offset += bytesRead; + count -= bytesRead; + } + + _position += totalBytesRead; + return totalBytesRead; + } + +#if NET + public override int Read(Span buffer) + { + int totalBytesRead = 0; + + while (!buffer.IsEmpty && _streams.Count > 0) + { + int bytesRead = _streams.Peek().Read(buffer); + if (bytesRead == 0) + { + _streams.Dequeue().Dispose(); + continue; + } + + totalBytesRead += bytesRead; + buffer = buffer.Slice(bytesRead); + } + + _position += totalBytesRead; + return totalBytesRead; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + int totalBytesRead = 0; + + while (!buffer.IsEmpty && _streams.Count > 0) + { + int bytesRead = await _streams.Peek().ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + if (bytesRead == 0) + { + _streams.Dequeue().Dispose(); + continue; + } + + totalBytesRead += bytesRead; + buffer = buffer.Slice(bytesRead); + } + + _position += totalBytesRead; + return totalBytesRead; + } +#endif + public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException("ConcatenatedReadStream is forward-only read-only"); diff --git a/src/DotUtils.StreamUtils/StreamExtensions.cs b/src/DotUtils.StreamUtils/StreamExtensions.cs index 1cae8e5..39afc9e 100644 --- a/src/DotUtils.StreamUtils/StreamExtensions.cs +++ b/src/DotUtils.StreamUtils/StreamExtensions.cs @@ -98,29 +98,28 @@ public static int SkipBytes(this Stream stream, long bytesCount, bool throwOnEnd public static byte[] ReadToEnd(this Stream stream) { - if (stream.TryGetLength(out long length)) - { - BinaryReader reader = new(stream); - return reader.ReadBytes((int)length); - } - - using var ms = new MemoryStream(); + MemoryStream ms = stream.TryGetLength(out long length) && length <= int.MaxValue ? new((int)length) : new(); stream.CopyTo(ms); - return ms.ToArray(); + byte[] buffer = ms.GetBuffer(); + return buffer.Length == ms.Length ? buffer : ms.ToArray(); } public static bool TryGetLength(this Stream stream, out long length) { try { - length = stream.Length; - return true; + if (stream.CanSeek) + { + length = stream.Length; + return true; + } } catch (NotSupportedException) { - length = 0; - return false; } + + length = 0; + return false; } public static Stream ToReadableSeekableStream(this Stream stream) diff --git a/src/DotUtils.StreamUtils/SubStream.cs b/src/DotUtils.StreamUtils/SubStream.cs index 2f480f8..fec3163 100644 --- a/src/DotUtils.StreamUtils/SubStream.cs +++ b/src/DotUtils.StreamUtils/SubStream.cs @@ -1,5 +1,7 @@ using System; using System.IO; +using System.Threading; +using System.Threading.Tasks; namespace DotUtils.StreamUtils; @@ -36,7 +38,8 @@ public SubStream(Stream stream, long length) public override long Position { get => _position; set => throw new NotImplementedException(); } - public override void Flush() { } + public override void Flush() => _stream.Flush(); + public override Task FlushAsync(CancellationToken cancellationToken) => _stream.FlushAsync(cancellationToken); public override int Read(byte[] buffer, int offset, int count) { count = Math.Min((int)Math.Max(Length - _position, 0), count); @@ -44,6 +47,43 @@ public override int Read(byte[] buffer, int offset, int count) _position += read; return read; } + public override int ReadByte() + { + if (Length - _position > 0) + { + int value = _stream.ReadByte(); + if (value >= 0) + { + _position++; + return value; + } + } + + return -1; + } + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + count = Math.Min((int)Math.Max(Length - _position, 0), count); + int read = await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + _position += read; + return read; + } +#if NET + public override int Read(Span buffer) + { + buffer = buffer.Slice(0, Math.Min((int)Math.Max(Length - _position, 0), buffer.Length)); + int read = _stream.Read(buffer); + _position += read; + return read; + } + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + buffer = buffer.Slice(0, Math.Min((int)Math.Max(Length - _position, 0), buffer.Length)); + int read = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _position += read; + return read; + } +#endif public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException(); public override void SetLength(long value) => throw new NotImplementedException(); public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException(); diff --git a/src/DotUtils.StreamUtils/TransparentReadStream.cs b/src/DotUtils.StreamUtils/TransparentReadStream.cs index 374df9d..66c4033 100644 --- a/src/DotUtils.StreamUtils/TransparentReadStream.cs +++ b/src/DotUtils.StreamUtils/TransparentReadStream.cs @@ -1,8 +1,7 @@ using System; using System.IO; -#if NET -using System.Buffers; -#endif +using System.Threading; +using System.Threading.Tasks; namespace DotUtils.StreamUtils; @@ -74,6 +73,11 @@ public override void Flush() _stream.Flush(); } + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _stream.FlushAsync(cancellationToken); + } + public override int Read(byte[] buffer, int offset, int count) { if (_position + count > _maxAllowedPosition) @@ -86,6 +90,59 @@ public override int Read(byte[] buffer, int offset, int count) return cnt; } + public override int ReadByte() + { + if (_position + 1 <= _maxAllowedPosition) + { + int value = _stream.ReadByte(); + if (value >= 0) + { + _position++; + return value; + } + } + + return -1; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_position + count > _maxAllowedPosition) + { + count = (int)(_maxAllowedPosition - _position); + } + + int cnt = await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false); + _position += cnt; + return cnt; + } + +#if NET + public override int Read(Span buffer) + { + if (_position + buffer.Length > _maxAllowedPosition) + { + buffer = buffer.Slice(0, (int)(_maxAllowedPosition - _position)); + } + + int cnt = _stream.Read(buffer); + _position += cnt; + return cnt; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (_position + buffer.Length > _maxAllowedPosition) + { + buffer = buffer.Slice(0, (int)(_maxAllowedPosition - _position)); + } + + int cnt = await _stream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false); + _position += cnt; + return cnt; + } +#endif + public override long Seek(long offset, SeekOrigin origin) { if(origin != SeekOrigin.Current)