From 6271a26d102793c3a0548685082bcfb4ce51adba Mon Sep 17 00:00:00 2001 From: Extremelyd1 <10898310+Extremelyd1@users.noreply.github.com> Date: Thu, 9 Jan 2025 23:03:05 +0100 Subject: [PATCH] Fix server-side to a single socket --- HKMP/Networking/Server/DtlsServer.cs | 144 ++++++++++++------ HKMP/Networking/Server/DtlsServerClient.cs | 4 + .../Server/ServerDatagramTransport.cs | 90 ++++++----- 3 files changed, 151 insertions(+), 87 deletions(-) diff --git a/HKMP/Networking/Server/DtlsServer.cs b/HKMP/Networking/Server/DtlsServer.cs index 416d1dc..b71c328 100644 --- a/HKMP/Networking/Server/DtlsServer.cs +++ b/HKMP/Networking/Server/DtlsServer.cs @@ -14,20 +14,19 @@ internal class DtlsServer { private DtlsServerProtocol _serverProtocol; private ServerTlsServer _tlsServer; - private Socket _currentAcceptSocket; + private Socket _socket; private ServerDatagramTransport _currentDatagramTransport; - private CancellationTokenSource _acceptTaskTokenSource; - private CancellationTokenSource _clientUpdateTaskTokenSource; + private CancellationTokenSource _cancellationTokenSource; - private readonly List _acceptedDtlsClients; + private readonly Dictionary _dtlsClients; public event Action DataReceivedEvent; private int _port; public DtlsServer() { - _acceptedDtlsClients = new List(); + _dtlsClients = new Dictionary(); } public void Start(int port) { @@ -36,37 +35,81 @@ public void Start(int port) { _serverProtocol = new DtlsServerProtocol(); _tlsServer = new ServerTlsServer(new BcTlsCrypto()); - _acceptTaskTokenSource = new CancellationTokenSource(); - _clientUpdateTaskTokenSource = new CancellationTokenSource(); + _cancellationTokenSource = new CancellationTokenSource(); + + _socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); + _socket.Bind(new IPEndPoint(IPAddress.Any, _port)); - var cancellationToken = _acceptTaskTokenSource.Token; - new Thread(() => AcceptLoop(cancellationToken)).Start(); + new Thread(() => AcceptLoop(_cancellationTokenSource.Token)).Start(); + new Thread(() => SocketReceiveLoop(_cancellationTokenSource.Token)).Start(); } public void Stop() { - _acceptTaskTokenSource?.Cancel(); - _acceptTaskTokenSource?.Dispose(); - _acceptTaskTokenSource = null; + _cancellationTokenSource?.Cancel(); + _cancellationTokenSource?.Dispose(); + _cancellationTokenSource = null; + } + + public void DisconnectClient(IPEndPoint endPoint) { + if (!_dtlsClients.TryGetValue(endPoint, out var dtlsServerClient)) { + Logger.Warn("Could not find DtlsServerClient to disconnect"); + return; + } + + dtlsServerClient.ReceiveLoopTokenSource?.Cancel(); + dtlsServerClient.ReceiveLoopTokenSource?.Dispose(); - _clientUpdateTaskTokenSource?.Cancel(); - _clientUpdateTaskTokenSource?.Dispose(); - _clientUpdateTaskTokenSource = null; + dtlsServerClient.DatagramTransport?.Close(); + dtlsServerClient.DtlsTransport?.Close(); + } + + private void SocketReceiveLoop(CancellationToken cancellationToken) { + while (!cancellationToken.IsCancellationRequested) { + EndPoint endPoint = new IPEndPoint(IPAddress.Any, 0); + + Logger.Debug("Blocking on server socket receive"); + var numReceived = 0; + var buffer = new byte[1400]; + + try { + numReceived = _socket.ReceiveFrom( + buffer, + SocketFlags.None, + ref endPoint + ); + } catch (SocketException e) { + Logger.Error($"UDP server socket exception:\n{e}"); + } + + var ipEndPoint = (IPEndPoint) endPoint; + + ServerDatagramTransport serverDatagramTransport; + if (!_dtlsClients.TryGetValue(ipEndPoint, out var dtlsServerClient)) { + Logger.Debug("Received data on server socket from unknown IP"); + + serverDatagramTransport = _currentDatagramTransport; + // Set the IP endpoint of the datagram transport instance so it can send data to the correct IP + serverDatagramTransport.IPEndPoint = ipEndPoint; + } else { + serverDatagramTransport = dtlsServerClient.DatagramTransport; + } + + try { + serverDatagramTransport.ReceivedDataCollection.Add(new ServerDatagramTransport.ReceivedData { + Buffer = buffer, + Length = numReceived + }, cancellationToken); + } catch (OperationCanceledException) { + break; + } + } } private void AcceptLoop(CancellationToken cancellationToken) { while (!cancellationToken.IsCancellationRequested) { - Logger.Debug("Creating new socket for accepting connection"); + Logger.Debug("Creating new ServerDatagramTransport for handling new connection"); - _currentAcceptSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp); - // _currentAcceptSocket.DualMode = true; - _currentAcceptSocket.Bind(new IPEndPoint(IPAddress.Any, _port)); - // Logger.Debug("Setting socket to listening mode"); - // _currentAcceptSocket.Listen(1); - // Logger.Debug("Accepting connections for socket"); - // var socket = _currentAcceptSocket.Accept(); - // Logger.Debug($"Connection accepted, remote endpoint: {socket.RemoteEndPoint}"); - - _currentDatagramTransport = new ServerDatagramTransport(_currentAcceptSocket); + _currentDatagramTransport = new ServerDatagramTransport(_socket); DtlsTransport dtlsTransport; try { @@ -76,42 +119,53 @@ private void AcceptLoop(CancellationToken cancellationToken) { break; } - Logger.Debug($"Accepted DTLS connection on socket, endpoint: {_currentAcceptSocket.RemoteEndPoint}"); + var endPoint = _currentDatagramTransport.IPEndPoint; + Logger.Debug($"Accepted DTLS connection on socket, endpoint: {endPoint}"); + + if (_dtlsClients.ContainsKey(endPoint)) { + Logger.Error($"DtlsClient with endpoint ({endPoint}) already exists, cannot add"); + continue; + } + var dtlsServerClient = new DtlsServerClient { DtlsTransport = dtlsTransport, - EndPoint = _currentAcceptSocket.RemoteEndPoint as IPEndPoint + DatagramTransport = _currentDatagramTransport, + EndPoint = endPoint, + ReceiveLoopTokenSource = new CancellationTokenSource() }; - - _acceptedDtlsClients.Add(dtlsServerClient); - - var clientCancellationToken = _clientUpdateTaskTokenSource.Token; - new Thread(() => ReceiveClientLoop(dtlsServerClient, clientCancellationToken)).Start(); + + _dtlsClients.Add(endPoint, dtlsServerClient); + + Logger.Debug("Starting receive loop for client"); + new Thread(() => ClientReceiveLoop( + dtlsServerClient.ReceiveLoopTokenSource.Token, + dtlsServerClient) + ).Start(); } - _currentAcceptSocket?.Close(); - _currentAcceptSocket = null; - _currentDatagramTransport?.Close(); _currentDatagramTransport = null; _serverProtocol = null; - foreach (var dtlsServerClient in _acceptedDtlsClients) { + foreach (var dtlsServerClient in _dtlsClients.Values) { dtlsServerClient.DtlsTransport?.Close(); + dtlsServerClient.DatagramTransport?.Close(); } - _acceptedDtlsClients.Clear(); + _dtlsClients.Clear(); } - private void ReceiveClientLoop(DtlsServerClient dtlsServerClient, CancellationToken cancellationToken) { + private void ClientReceiveLoop(CancellationToken cancellationToken, DtlsServerClient dtlsServerClient) { + var dtlsTransport = dtlsServerClient.DtlsTransport; + while (!cancellationToken.IsCancellationRequested) { - // TODO: change to const defined somewhere central - var buffer = new byte[1400]; - var length = dtlsServerClient.DtlsTransport.Receive(buffer, 0, buffer.Length, 5); - if (length >= 0) { - DataReceivedEvent?.Invoke(dtlsServerClient, buffer, length); - } + var buffer = new byte[dtlsTransport.GetReceiveLimit()]; + + var numReceived = dtlsTransport.Receive(buffer, 0, dtlsTransport.GetReceiveLimit(), 5); + + DataReceivedEvent?.Invoke(dtlsServerClient, buffer, numReceived); } } } diff --git a/HKMP/Networking/Server/DtlsServerClient.cs b/HKMP/Networking/Server/DtlsServerClient.cs index b084429..09759bd 100644 --- a/HKMP/Networking/Server/DtlsServerClient.cs +++ b/HKMP/Networking/Server/DtlsServerClient.cs @@ -1,9 +1,13 @@ using System.Net; +using System.Threading; using Org.BouncyCastle.Tls; namespace Hkmp.Networking.Server; internal class DtlsServerClient { public DtlsTransport DtlsTransport { get; init; } + public ServerDatagramTransport DatagramTransport { get; init; } public IPEndPoint EndPoint { get; init; } + + public CancellationTokenSource ReceiveLoopTokenSource { get; init; } } diff --git a/HKMP/Networking/Server/ServerDatagramTransport.cs b/HKMP/Networking/Server/ServerDatagramTransport.cs index d17ec8f..319175a 100644 --- a/HKMP/Networking/Server/ServerDatagramTransport.cs +++ b/HKMP/Networking/Server/ServerDatagramTransport.cs @@ -1,3 +1,4 @@ +using System.Collections.Concurrent; using System.Net; using System.Net.Sockets; using Hkmp.Logging; @@ -8,10 +9,14 @@ namespace Hkmp.Networking.Server; internal class ServerDatagramTransport : DatagramTransport { private readonly Socket _socket; - private bool _hasReceived; + public IPEndPoint IPEndPoint { get; set; } + + public BlockingCollection ReceivedDataCollection { get; } public ServerDatagramTransport(Socket socket) { _socket = socket; + + ReceivedDataCollection = new BlockingCollection(); } public int GetReceiveLimit() { @@ -25,60 +30,61 @@ public int GetSendLimit() { } public int Receive(byte[] buf, int off, int len, int waitMillis) { - if (!_hasReceived) { - EndPoint endPoint = new IPEndPoint(IPAddress.Any, 0); - - try { - Logger.Debug("Blocking on first receive"); - - // _socket.ReceiveTimeout = waitMillis; - var numReceived = _socket.ReceiveFrom( - buf, - off, - len, - SocketFlags.None, - ref endPoint - ); - - _socket.Connect(endPoint); - _hasReceived = true; - - Logger.Debug($"First receive finished ({numReceived}), connecting socket to endpoint: {endPoint}"); - return numReceived; - } catch (SocketException e) { - Logger.Error($"UDP Socket exception:\n{e}"); + if (!ReceivedDataCollection.TryTake(out var data, waitMillis)) { + return -1; + } + + // If there is more data in the entry we received from the blocking collection than space in the buffer + // from the method, we need to add as much data into the buffer and put the rest back in the collection + if (len < data.Length) { + // Fill the buffer from the method with as much data from the entry as possible + for (var i = off; i < off + len; i++) { + buf[i] = data.Buffer[i - off]; } - } else { - try { - Logger.Debug("Non-first blocking receive called"); - // _socket.ReceiveTimeout = waitMillis; - var numReceived = _socket.Receive( - buf, - off, - len, - SocketFlags.None - ); - Logger.Debug($"End of non-first blocking receive: {numReceived}"); - return numReceived; - } catch (SocketException e) { - Logger.Error($"UDP Socket exception:\n{e}"); + + // Calculate the length of the leftover buffer and instantiate it + var leftoverLength = data.Length - len; + var leftoverBuffer = new byte[leftoverLength]; + + // Fill the leftover buffer with the leftover data from the entry + for (var i = 0; i < leftoverLength; i++) { + leftoverBuffer[i] = data.Buffer[len + i]; } + + // Add the leftover buffer and its length back to the collection + ReceivedDataCollection.Add(new ReceivedData { + Buffer = leftoverBuffer, + Length = leftoverLength + }); + + return len; + } + + // In this case, the space in the buffer from the method is large enough, so we fill it with all the data + // from the collection entry + for (var i = 0; i < data.Length; i++) { + buf[off + i] = data.Buffer[i]; } - return -1; + return data.Length; } public void Send(byte[] buf, int off, int len) { - if (!_hasReceived) { - Logger.Error("Cannot send because socket has not received yet"); + if (IPEndPoint == null) { + Logger.Error("Cannot send because transport has no endpoint"); return; } - Logger.Debug($"Server sending {len} bytes of data"); + Logger.Debug($"Server sending {len} bytes of data to: {IPEndPoint}"); - _socket.Send(buf, off, len, SocketFlags.None); + _socket.SendTo(buf, off, len, SocketFlags.None, IPEndPoint); } public void Close() { } + + public class ReceivedData { + public byte[] Buffer { get; set; } + public int Length { get; set; } + } }