diff --git a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj index 3419171de4cc71..215cf6b4a91649 100644 --- a/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj +++ b/src/libraries/System.Net.WebSockets/src/System.Net.WebSockets.csproj @@ -7,7 +7,6 @@ - diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs index 36a0049d705fbf..a17fe4f8aeb9a2 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketDeflater.cs @@ -12,7 +12,7 @@ namespace System.Net.WebSockets.Compression /// internal sealed class WebSocketDeflater : IDisposable { - private readonly ZLibStreamPool _streamPool; + private readonly int _windowBits; private ZLibStreamHandle? _stream; private readonly bool _persisted; @@ -20,7 +20,7 @@ internal sealed class WebSocketDeflater : IDisposable internal WebSocketDeflater(int windowBits, bool persisted) { - _streamPool = ZLibStreamPool.GetOrCreate(windowBits); + _windowBits = -windowBits; // Negative for raw deflate _persisted = persisted; } @@ -28,7 +28,7 @@ public void Dispose() { if (_stream is not null) { - _streamPool.ReturnDeflater(_stream); + _stream.Dispose(); _stream = null; } } @@ -84,7 +84,7 @@ public ReadOnlySpan Deflate(ReadOnlySpan payload, bool endOfMessage) private void DeflatePrivate(ReadOnlySpan payload, Span output, bool endOfMessage, out int consumed, out int written, out bool needsMoreOutput) { - _stream ??= _streamPool.GetDeflater(); + _stream ??= CreateDeflater(); if (payload.Length == 0) { @@ -119,7 +119,7 @@ private void DeflatePrivate(ReadOnlySpan payload, Span output, bool if (endOfMessage && !_persisted) { - _streamPool.ReturnDeflater(_stream); + _stream.Dispose(); _stream = null; } } @@ -201,5 +201,33 @@ private static ErrorCode Deflate(ZLibStreamHandle stream, FlushCode flushCode) : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); throw new WebSocketException(message); } + + private ZLibStreamHandle CreateDeflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + try + { + errorCode = CreateZLibStreamForDeflate(out stream, + level: CompressionLevel.DefaultCompression, + windowBits: _windowBits, + memLevel: Deflate_DefaultMemLevel, + strategy: CompressionStrategy.DefaultStrategy); + } + catch (Exception cause) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); + } + + if (errorCode != ErrorCode.Ok) + { + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } + + return stream; + } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs index a13fadd10f208d..6ade12d539a440 100644 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs +++ b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/WebSocketInflater.cs @@ -15,7 +15,7 @@ internal sealed class WebSocketInflater : IDisposable internal const int FlushMarkerLength = 4; internal static ReadOnlySpan FlushMarker => new byte[] { 0x00, 0x00, 0xFF, 0xFF }; - private readonly ZLibStreamPool _streamPool; + private readonly int _windowBits; private ZLibStreamHandle? _stream; private readonly bool _persisted; @@ -48,7 +48,7 @@ internal sealed class WebSocketInflater : IDisposable internal WebSocketInflater(int windowBits, bool persisted) { - _streamPool = ZLibStreamPool.GetOrCreate(windowBits); + _windowBits = -windowBits; // Negative for raw deflate _persisted = persisted; } @@ -60,7 +60,7 @@ public void Dispose() { if (_stream is not null) { - _streamPool.ReturnInflater(_stream); + _stream.Dispose(); _stream = null; } ReleaseBuffer(); @@ -128,7 +128,7 @@ public void AddBytes(int totalBytesReceived, bool endOfMessage) /// public unsafe bool Inflate(Span output, out int written) { - _stream ??= _streamPool.GetInflater(); + _stream ??= CreateInflater(); if (_available > 0 && output.Length > 0) { @@ -192,7 +192,7 @@ private unsafe bool Finish(Span output, ref int written) { if (!_persisted) { - _streamPool.ReturnInflater(_stream); + _stream.Dispose(); _stream = null; } return true; @@ -254,5 +254,32 @@ private static unsafe int Inflate(ZLibStreamHandle stream, Span destinatio }; throw new WebSocketException(message); } + + private ZLibStreamHandle CreateInflater() + { + ZLibStreamHandle stream; + ErrorCode errorCode; + + try + { + errorCode = CreateZLibStreamForInflate(out stream, _windowBits); + } + catch (Exception exception) + { + throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); + } + + if (errorCode == ErrorCode.Ok) + { + return stream; + } + + stream.Dispose(); + + string message = errorCode == ErrorCode.MemError + ? SR.ZLibErrorNotEnoughMemory + : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); + throw new WebSocketException(message); + } } } diff --git a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs b/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs deleted file mode 100644 index 12c644bab7d7c5..00000000000000 --- a/src/libraries/System.Net.WebSockets/src/System/Net/WebSockets/Compression/ZLibStreamPool.cs +++ /dev/null @@ -1,291 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics; -using System.Diagnostics.CodeAnalysis; -using System.Threading; -using static System.IO.Compression.ZLibNative; - -namespace System.Net.WebSockets.Compression -{ - internal sealed class ZLibStreamPool - { - private static readonly ZLibStreamPool?[] s_pools - = new ZLibStreamPool[WebSocketValidate.MaxDeflateWindowBits - WebSocketValidate.MinDeflateWindowBits + 1]; - - /// - /// The default amount of time after which a cached item will be removed. - /// - private const int DefaultTimeoutMilliseconds = 60_000; - - private readonly int _windowBits; - private readonly List _inflaters = new(); - private readonly List _deflaters = new(); - private readonly Timer _cleaningTimer; - - /// - /// The amount of time after which a cached item will be removed. - /// - private readonly int _timeoutMilliseconds; - - /// - /// The number of cached inflaters and deflaters. - /// - private int _activeCount; - - private ZLibStreamPool(int windowBits, int timeoutMilliseconds) - { - // Use negative window bits to for raw deflate data - _windowBits = -windowBits; - _timeoutMilliseconds = timeoutMilliseconds; - - bool restoreFlow = false; - try - { - if (!ExecutionContext.IsFlowSuppressed()) - { - ExecutionContext.SuppressFlow(); - restoreFlow = true; - } - - // There is no need to use weak references here, because these pools are kept - // for the entire lifetime of the application. Also we reset the timer on each tick, - // which prevents the object being rooted forever. - _cleaningTimer = new Timer(x => ((ZLibStreamPool)x!).RemoveStaleItems(), - state: this, Timeout.Infinite, Timeout.Infinite); - } - finally - { - if (restoreFlow) - { - ExecutionContext.RestoreFlow(); - } - } - } - - public static ZLibStreamPool GetOrCreate(int windowBits) - { - Debug.Assert(windowBits >= WebSocketValidate.MinDeflateWindowBits - && windowBits <= WebSocketValidate.MaxDeflateWindowBits); - - int index = windowBits - WebSocketValidate.MinDeflateWindowBits; - ref ZLibStreamPool? pool = ref s_pools[index]; - - return Volatile.Read(ref pool) ?? EnsureInitialized(windowBits, ref pool); - - static ZLibStreamPool EnsureInitialized(int windowBits, ref ZLibStreamPool? target) - { - ZLibStreamPool newPool = new(windowBits, DefaultTimeoutMilliseconds); - if (Interlocked.CompareExchange(ref target, newPool, null) is not null) - { - newPool._cleaningTimer.Dispose(); - } - - Debug.Assert(target != null); - return target; - } - } - - public ZLibStreamHandle GetInflater() - { - if (TryGet(_inflaters, out ZLibStreamHandle? stream)) - { - return stream; - } - - return CreateInflater(); - } - - public void ReturnInflater(ZLibStreamHandle stream) - { - if (stream.InflateReset() != ErrorCode.Ok) - { - stream.Dispose(); - return; - } - - Return(stream, _inflaters); - } - - public ZLibStreamHandle GetDeflater() - { - if (TryGet(_deflaters, out ZLibStreamHandle? stream)) - { - return stream; - } - - return CreateDeflater(); - } - - public void ReturnDeflater(ZLibStreamHandle stream) - { - if (stream.DeflateReset() != ErrorCode.Ok) - { - stream.Dispose(); - return; - } - - Return(stream, _deflaters); - } - - private void Return(ZLibStreamHandle stream, List cache) - { - lock (cache) - { - cache.Add(new CacheItem(stream)); - - if (Interlocked.Increment(ref _activeCount) == 1) - { - _cleaningTimer.Change(_timeoutMilliseconds, Timeout.Infinite); - } - } - } - - private bool TryGet(List cache, [NotNullWhen(true)] out ZLibStreamHandle? stream) - { - lock (cache) - { - int count = cache.Count; - - if (count > 0) - { - CacheItem item = cache[count - 1]; - cache.RemoveAt(count - 1); - Interlocked.Decrement(ref _activeCount); - - stream = item.Stream; - return true; - } - } - - stream = null; - return false; - } - - private void RemoveStaleItems() - { - RemoveStaleItems(_inflaters); - RemoveStaleItems(_deflaters); - - // There is a race condition here, were _activeCount could be decremented - // by a rent operation, but it's not big deal to schedule a timer tick that - // would eventually do nothing. - if (_activeCount > 0) - { - _cleaningTimer.Change(_timeoutMilliseconds, Timeout.Infinite); - } - } - - private void RemoveStaleItems(List cache) - { - long currentTimestamp = Environment.TickCount64; - List? removedStreams = null; - - lock (cache) - { - for (int index = 0; index < cache.Count; ++index) - { - CacheItem item = cache[index]; - - if (currentTimestamp - item.Timestamp > _timeoutMilliseconds) - { - removedStreams ??= new List(); - removedStreams.Add(item.Stream); - Interlocked.Decrement(ref _activeCount); - } - else - { - // The freshest streams are in the back of the collection. - // If we've reached a stream that is not timed out, all - // other after it will not be as well. - break; - } - } - - if (removedStreams is null) - { - return; - } - - cache.RemoveRange(0, removedStreams.Count); - } - - foreach (ZLibStreamHandle stream in removedStreams) - { - stream.Dispose(); - } - } - - private ZLibStreamHandle CreateDeflater() - { - ZLibStreamHandle stream; - ErrorCode errorCode; - try - { - errorCode = CreateZLibStreamForDeflate(out stream, - level: CompressionLevel.DefaultCompression, - windowBits: _windowBits, - memLevel: Deflate_DefaultMemLevel, - strategy: CompressionStrategy.DefaultStrategy); - } - catch (Exception cause) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, cause); - } - - if (errorCode != ErrorCode.Ok) - { - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); - } - - return stream; - } - - private ZLibStreamHandle CreateInflater() - { - ZLibStreamHandle stream; - ErrorCode errorCode; - - try - { - errorCode = CreateZLibStreamForInflate(out stream, _windowBits); - } - catch (Exception exception) - { - throw new WebSocketException(SR.ZLibErrorDLLLoadError, exception); - } - - if (errorCode == ErrorCode.Ok) - { - return stream; - } - - stream.Dispose(); - - string message = errorCode == ErrorCode.MemError - ? SR.ZLibErrorNotEnoughMemory - : string.Format(SR.ZLibErrorUnexpected, (int)errorCode); - throw new WebSocketException(message); - } - - private readonly struct CacheItem - { - public CacheItem(ZLibStreamHandle stream) - { - Stream = stream; - Timestamp = Environment.TickCount64; - } - - public ZLibStreamHandle Stream { get; } - - /// - /// The time when this item was returned to cache. - /// - public long Timestamp { get; } - } - } -} diff --git a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj index 9ce68c7b42f998..4e0bc74ebdaeca 100644 --- a/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj +++ b/src/libraries/System.Net.WebSockets/tests/System.Net.WebSockets.Tests.csproj @@ -10,7 +10,6 @@ - - { - if (x % 2 == 0) - { - object inflater = pool.GetInflater(); - pool.ReturnInflater(inflater); - } - else - { - object deflater = pool.GetDeflater(); - pool.ReturnDeflater(deflater); - } - }); - - Assert.True(pool.ActiveCount >= 2); - Assert.True(pool.ActiveCount <= parallelOptions.MaxDegreeOfParallelism * 2); - await Task.Delay(200); - Assert.Equal(0, pool.ActiveCount); - } - - private sealed class Pool - { - private static Type? s_type; - private static ConstructorInfo? s_constructor; - private static FieldInfo? s_activeCount; - private static MethodInfo? s_rentInflater; - private static MethodInfo? s_returnInflater; - private static MethodInfo? s_rentDeflater; - private static MethodInfo? s_returnDeflater; - - private readonly object _instance; - - public Pool(int timeoutMilliseconds) - { - s_type ??= typeof(WebSocket).Assembly.GetType("System.Net.WebSockets.Compression.ZLibStreamPool", throwOnError: true); - s_constructor ??= s_type.GetConstructors(BindingFlags.Instance | BindingFlags.NonPublic)[0]; - - _instance = s_constructor.Invoke(new object[] { /*windowBits*/9, timeoutMilliseconds }); - } - - public int ActiveCount => (int)(s_activeCount ??= s_type.GetField("_activeCount", BindingFlags.Instance | BindingFlags.NonPublic)).GetValue(_instance); - - public object GetInflater() => GetMethod(ref s_rentInflater).Invoke(_instance, null); - - public void ReturnInflater(object inflater) => GetMethod(ref s_returnInflater).Invoke(_instance, new[] { inflater }); - - public object GetDeflater() => GetMethod(ref s_rentDeflater).Invoke(_instance, null); - - public void ReturnDeflater(object deflater) => GetMethod(ref s_returnDeflater).Invoke(_instance, new[] { deflater }); - - private static MethodInfo GetMethod(ref MethodInfo? method, [CallerMemberName] string? name = null) - { - return method ??= s_type.GetMethod(name) - ?? throw new InvalidProgramException($"Method {name} was not found in {s_type}."); - } - } - } -}