-
Notifications
You must be signed in to change notification settings - Fork 10.1k
/
StreamTracker.cs
134 lines (113 loc) · 3.9 KB
/
StreamTracker.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Threading.Channels;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.SignalR;
/// <summary>
/// Tracks open streams by id. Each stream is backed by a bounded channel: incoming stream items are
/// written into it and a reader over it is handed out when the stream is added.
/// </summary>
internal sealed class StreamTracker
{
    // Open generic BuildStream<T>; closed over each stream's item type in AddStream.
    private static readonly MethodInfo _buildConverterMethod =
        typeof(StreamTracker).GetMethod(nameof(BuildStream), BindingFlags.NonPublic | BindingFlags.Static)!;

    // Argument array reused for every BuildStream<T> invocation; its only element is the buffer capacity.
    private readonly object[] _streamConverterArgs;

    // Stream id -> converter owning that stream's channel.
    private readonly ConcurrentDictionary<string, IStreamConverter> _lookup = new();

    /// <summary>
    /// Initializes a new <see cref="StreamTracker"/>.
    /// </summary>
    /// <param name="streamBufferCapacity">Capacity of the bounded channel created for each stream.</param>
    public StreamTracker(int streamBufferCapacity)
    {
        _streamConverterArgs = new object[] { streamBufferCapacity };
    }

    /// <summary>
    /// Creates a new stream and returns the ChannelReader for it as an object.
    /// </summary>
    /// <param name="streamId">Id used to route later stream items and completion to this stream.</param>
    /// <param name="itemType">Element type of the backing channel.</param>
    /// <param name="targetType">
    /// Type the returned reader will be consumed as. When <c>ReflectionHelper.IsIAsyncEnumerable</c> reports it
    /// as an async enumerable, an async enumerable over the channel is returned instead of the ChannelReader.
    /// </param>
    /// <returns>A ChannelReader or an async enumerable over the new channel, typed as <see cref="object"/>.</returns>
    public object AddStream(string streamId, Type itemType, Type targetType)
    {
        var newConverter = (IStreamConverter)_buildConverterMethod.MakeGenericMethod(itemType).Invoke(null, _streamConverterArgs)!;

        // NOTE(review): the indexer silently replaces any existing entry with the same id. The replaced
        // converter is then unreachable by both TryComplete and CompleteAll, so its reader never observes
        // completion. Confirm callers guarantee unique stream ids.
        _lookup[streamId] = newConverter;
        return newConverter.GetReaderAsObject(targetType);
    }

    /// <summary>
    /// Looks up the converter registered for <paramref name="streamId"/>.
    /// </summary>
    private bool TryGetConverter(string streamId, [NotNullWhen(true)] out IStreamConverter? converter)
        => _lookup.TryGetValue(streamId, out converter);

    /// <summary>
    /// Writes the message's item to the stream identified by its invocation id.
    /// </summary>
    /// <param name="message">Stream item; its <c>InvocationId</c> is used as the stream id.</param>
    /// <param name="task">On success, the write task returned by the channel writer; otherwise <see langword="null"/>.</param>
    /// <returns><see langword="true"/> if a stream with that id is tracked; otherwise <see langword="false"/>.</returns>
    public bool TryProcessItem(StreamItemMessage message, [NotNullWhen(true)] out Task? task)
    {
        if (TryGetConverter(message.InvocationId!, out var converter))
        {
            task = converter.WriteToStream(message.Item);
            return true;
        }

        task = default;
        return false;
    }

    /// <summary>
    /// Returns the item type of the stream identified by <paramref name="streamId"/>.
    /// </summary>
    /// <exception cref="KeyNotFoundException">No stream with that id is tracked.</exception>
    public Type GetStreamItemType(string streamId)
    {
        if (TryGetConverter(streamId, out var converter))
        {
            return converter.GetItemType();
        }

        throw new KeyNotFoundException($"No stream with id '{streamId}' could be found.");
    }

    /// <summary>
    /// Stops tracking the stream identified by the message's invocation id and completes it.
    /// </summary>
    /// <param name="message">
    /// Completion message. If it has a result or carries no error, the stream completes normally;
    /// otherwise it completes with a <c>HubException</c> built from the error text.
    /// </param>
    /// <returns><see langword="true"/> if a tracked stream was removed and completed; otherwise <see langword="false"/>.</returns>
    public bool TryComplete(CompletionMessage message)
    {
        // Rely on TryRemove's result rather than null-checking the out value.
        if (!_lookup.TryRemove(message.InvocationId!, out var converter))
        {
            return false;
        }

        converter.TryComplete(message.HasResult || message.Error == null ? null : new HubException(message.Error));
        return true;
    }

    /// <summary>
    /// Completes every tracked stream with <paramref name="ex"/>. Entries are not removed from the lookup.
    /// </summary>
    public void CompleteAll(Exception ex)
    {
        foreach (var entry in _lookup)
        {
            entry.Value.TryComplete(ex);
        }
    }

    // Invoked via reflection (closed over the item type) from AddStream.
    private static IStreamConverter BuildStream<T>(int streamBufferCapacity)
    {
        return new ChannelConverter<T>(streamBufferCapacity);
    }

    /// <summary>
    /// Non-generic view over a typed stream so streams of different item types share one lookup.
    /// </summary>
    private interface IStreamConverter
    {
        Type GetItemType();
        object GetReaderAsObject(Type type);
        Task WriteToStream(object? item);
        void TryComplete(Exception? ex);
    }

    /// <summary>
    /// Stream backed by a bounded <see cref="Channel{T}"/> of <typeparamref name="T"/>.
    /// </summary>
    private sealed class ChannelConverter<T> : IStreamConverter
    {
        private readonly Channel<T?> _channel;

        public ChannelConverter(int streamBufferCapacity)
        {
            _channel = Channel.CreateBounded<T?>(streamBufferCapacity);
        }

        public Type GetItemType()
        {
            return typeof(T);
        }

        // Returns an async enumerable over the channel when the consumer expects one, otherwise the raw reader.
        public object GetReaderAsObject(Type type)
        {
            if (ReflectionHelper.IsIAsyncEnumerable(type))
            {
                return _channel.Reader.ReadAllAsync();
            }

            return _channel.Reader;
        }

        // The item arrives untyped; it is cast to T before being written to the channel.
        public Task WriteToStream(object? o)
        {
            return _channel.Writer.WriteAsync((T?)o).AsTask();
        }

        public void TryComplete(Exception? ex)
        {
            _channel.Writer.TryComplete(ex);
        }
    }
}