Skip to content

Commit

Permalink
Fix server-side to a single socket
Browse files Browse the repository at this point in the history
  • Loading branch information
Extremelyd1 committed Jan 9, 2025
1 parent 0647a58 commit 6271a26
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 87 deletions.
144 changes: 99 additions & 45 deletions HKMP/Networking/Server/DtlsServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,19 @@ internal class DtlsServer {
private DtlsServerProtocol _serverProtocol;
private ServerTlsServer _tlsServer;

private Socket _currentAcceptSocket;
private Socket _socket;
private ServerDatagramTransport _currentDatagramTransport;

private CancellationTokenSource _acceptTaskTokenSource;
private CancellationTokenSource _clientUpdateTaskTokenSource;
private CancellationTokenSource _cancellationTokenSource;

private readonly List<DtlsServerClient> _acceptedDtlsClients;
private readonly Dictionary<IPEndPoint, DtlsServerClient> _dtlsClients;

public event Action<DtlsServerClient, byte[], int> DataReceivedEvent;

private int _port;

public DtlsServer() {
_acceptedDtlsClients = new List<DtlsServerClient>();
_dtlsClients = new Dictionary<IPEndPoint, DtlsServerClient>();
}

public void Start(int port) {
Expand All @@ -36,37 +35,81 @@ public void Start(int port) {
_serverProtocol = new DtlsServerProtocol();
_tlsServer = new ServerTlsServer(new BcTlsCrypto());

_acceptTaskTokenSource = new CancellationTokenSource();
_clientUpdateTaskTokenSource = new CancellationTokenSource();
_cancellationTokenSource = new CancellationTokenSource();

_socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
_socket.Bind(new IPEndPoint(IPAddress.Any, _port));

var cancellationToken = _acceptTaskTokenSource.Token;
new Thread(() => AcceptLoop(cancellationToken)).Start();
new Thread(() => AcceptLoop(_cancellationTokenSource.Token)).Start();
new Thread(() => SocketReceiveLoop(_cancellationTokenSource.Token)).Start();
}

public void Stop() {
_acceptTaskTokenSource?.Cancel();
_acceptTaskTokenSource?.Dispose();
_acceptTaskTokenSource = null;
_cancellationTokenSource?.Cancel();
_cancellationTokenSource?.Dispose();
_cancellationTokenSource = null;
}

public void DisconnectClient(IPEndPoint endPoint) {
if (!_dtlsClients.TryGetValue(endPoint, out var dtlsServerClient)) {
Logger.Warn("Could not find DtlsServerClient to disconnect");
return;
}

dtlsServerClient.ReceiveLoopTokenSource?.Cancel();
dtlsServerClient.ReceiveLoopTokenSource?.Dispose();

_clientUpdateTaskTokenSource?.Cancel();
_clientUpdateTaskTokenSource?.Dispose();
_clientUpdateTaskTokenSource = null;
dtlsServerClient.DatagramTransport?.Close();
dtlsServerClient.DtlsTransport?.Close();
}

private void SocketReceiveLoop(CancellationToken cancellationToken) {
while (!cancellationToken.IsCancellationRequested) {
EndPoint endPoint = new IPEndPoint(IPAddress.Any, 0);

Logger.Debug("Blocking on server socket receive");
var numReceived = 0;
var buffer = new byte[1400];

try {
numReceived = _socket.ReceiveFrom(
buffer,
SocketFlags.None,
ref endPoint
);
} catch (SocketException e) {
Logger.Error($"UDP server socket exception:\n{e}");
}

var ipEndPoint = (IPEndPoint) endPoint;

ServerDatagramTransport serverDatagramTransport;
if (!_dtlsClients.TryGetValue(ipEndPoint, out var dtlsServerClient)) {
Logger.Debug("Received data on server socket from unknown IP");

serverDatagramTransport = _currentDatagramTransport;
// Set the IP endpoint of the datagram transport instance so it can send data to the correct IP
serverDatagramTransport.IPEndPoint = ipEndPoint;
} else {
serverDatagramTransport = dtlsServerClient.DatagramTransport;
}

try {
serverDatagramTransport.ReceivedDataCollection.Add(new ServerDatagramTransport.ReceivedData {
Buffer = buffer,
Length = numReceived
}, cancellationToken);
} catch (OperationCanceledException) {
break;
}
}
}

private void AcceptLoop(CancellationToken cancellationToken) {
while (!cancellationToken.IsCancellationRequested) {
Logger.Debug("Creating new socket for accepting connection");
Logger.Debug("Creating new ServerDatagramTransport for handling new connection");

_currentAcceptSocket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
// _currentAcceptSocket.DualMode = true;
_currentAcceptSocket.Bind(new IPEndPoint(IPAddress.Any, _port));
// Logger.Debug("Setting socket to listening mode");
// _currentAcceptSocket.Listen(1);
// Logger.Debug("Accepting connections for socket");
// var socket = _currentAcceptSocket.Accept();
// Logger.Debug($"Connection accepted, remote endpoint: {socket.RemoteEndPoint}");

_currentDatagramTransport = new ServerDatagramTransport(_currentAcceptSocket);
_currentDatagramTransport = new ServerDatagramTransport(_socket);

DtlsTransport dtlsTransport;
try {
Expand All @@ -76,42 +119,53 @@ private void AcceptLoop(CancellationToken cancellationToken) {
break;
}

Logger.Debug($"Accepted DTLS connection on socket, endpoint: {_currentAcceptSocket.RemoteEndPoint}");
var endPoint = _currentDatagramTransport.IPEndPoint;

Logger.Debug($"Accepted DTLS connection on socket, endpoint: {endPoint}");

if (_dtlsClients.ContainsKey(endPoint)) {
Logger.Error($"DtlsClient with endpoint ({endPoint}) already exists, cannot add");
continue;
}

var dtlsServerClient = new DtlsServerClient {
DtlsTransport = dtlsTransport,
EndPoint = _currentAcceptSocket.RemoteEndPoint as IPEndPoint
DatagramTransport = _currentDatagramTransport,
EndPoint = endPoint,
ReceiveLoopTokenSource = new CancellationTokenSource()
};

_acceptedDtlsClients.Add(dtlsServerClient);

var clientCancellationToken = _clientUpdateTaskTokenSource.Token;
new Thread(() => ReceiveClientLoop(dtlsServerClient, clientCancellationToken)).Start();

_dtlsClients.Add(endPoint, dtlsServerClient);

Logger.Debug("Starting receive loop for client");
new Thread(() => ClientReceiveLoop(
dtlsServerClient.ReceiveLoopTokenSource.Token,
dtlsServerClient)
).Start();
}

_currentAcceptSocket?.Close();
_currentAcceptSocket = null;

_currentDatagramTransport?.Close();
_currentDatagramTransport = null;

_serverProtocol = null;

foreach (var dtlsServerClient in _acceptedDtlsClients) {
foreach (var dtlsServerClient in _dtlsClients.Values) {
dtlsServerClient.DtlsTransport?.Close();
dtlsServerClient.DatagramTransport?.Close();
}

_acceptedDtlsClients.Clear();
_dtlsClients.Clear();
}

private void ReceiveClientLoop(DtlsServerClient dtlsServerClient, CancellationToken cancellationToken) {
private void ClientReceiveLoop(CancellationToken cancellationToken, DtlsServerClient dtlsServerClient) {
var dtlsTransport = dtlsServerClient.DtlsTransport;

while (!cancellationToken.IsCancellationRequested) {
// TODO: change to const defined somewhere central
var buffer = new byte[1400];
var length = dtlsServerClient.DtlsTransport.Receive(buffer, 0, buffer.Length, 5);
if (length >= 0) {
DataReceivedEvent?.Invoke(dtlsServerClient, buffer, length);
}
var buffer = new byte[dtlsTransport.GetReceiveLimit()];

var numReceived = dtlsTransport.Receive(buffer, 0, dtlsTransport.GetReceiveLimit(), 5);

DataReceivedEvent?.Invoke(dtlsServerClient, buffer, numReceived);
}
}
}
4 changes: 4 additions & 0 deletions HKMP/Networking/Server/DtlsServerClient.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
using System.Net;
using System.Threading;
using Org.BouncyCastle.Tls;

namespace Hkmp.Networking.Server;

internal class DtlsServerClient {
public DtlsTransport DtlsTransport { get; init; }
public ServerDatagramTransport DatagramTransport { get; init; }
public IPEndPoint EndPoint { get; init; }

public CancellationTokenSource ReceiveLoopTokenSource { get; init; }
}
90 changes: 48 additions & 42 deletions HKMP/Networking/Server/ServerDatagramTransport.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Collections.Concurrent;
using System.Net;
using System.Net.Sockets;
using Hkmp.Logging;
Expand All @@ -8,10 +9,14 @@ namespace Hkmp.Networking.Server;
internal class ServerDatagramTransport : DatagramTransport {
private readonly Socket _socket;

private bool _hasReceived;
public IPEndPoint IPEndPoint { get; set; }

public BlockingCollection<ReceivedData> ReceivedDataCollection { get; }

public ServerDatagramTransport(Socket socket) {
_socket = socket;

ReceivedDataCollection = new BlockingCollection<ReceivedData>();
}

public int GetReceiveLimit() {
Expand All @@ -25,60 +30,61 @@ public int GetSendLimit() {
}

public int Receive(byte[] buf, int off, int len, int waitMillis) {
if (!_hasReceived) {
EndPoint endPoint = new IPEndPoint(IPAddress.Any, 0);

try {
Logger.Debug("Blocking on first receive");

// _socket.ReceiveTimeout = waitMillis;
var numReceived = _socket.ReceiveFrom(
buf,
off,
len,
SocketFlags.None,
ref endPoint
);

_socket.Connect(endPoint);
_hasReceived = true;

Logger.Debug($"First receive finished ({numReceived}), connecting socket to endpoint: {endPoint}");
return numReceived;
} catch (SocketException e) {
Logger.Error($"UDP Socket exception:\n{e}");
if (!ReceivedDataCollection.TryTake(out var data, waitMillis)) {
return -1;
}

// If there is more data in the entry we received from the blocking collection than space in the buffer
// from the method, we need to add as much data into the buffer and put the rest back in the collection
if (len < data.Length) {
// Fill the buffer from the method with as much data from the entry as possible
for (var i = off; i < off + len; i++) {
buf[i] = data.Buffer[i - off];
}
} else {
try {
Logger.Debug("Non-first blocking receive called");
// _socket.ReceiveTimeout = waitMillis;
var numReceived = _socket.Receive(
buf,
off,
len,
SocketFlags.None
);
Logger.Debug($"End of non-first blocking receive: {numReceived}");
return numReceived;
} catch (SocketException e) {
Logger.Error($"UDP Socket exception:\n{e}");

// Calculate the length of the leftover buffer and instantiate it
var leftoverLength = data.Length - len;
var leftoverBuffer = new byte[leftoverLength];

// Fill the leftover buffer with the leftover data from the entry
for (var i = 0; i < leftoverLength; i++) {
leftoverBuffer[i] = data.Buffer[len + i];
}

// Add the leftover buffer and its length back to the collection
ReceivedDataCollection.Add(new ReceivedData {
Buffer = leftoverBuffer,
Length = leftoverLength
});

return len;
}

// In this case, the space in the buffer from the method is large enough, so we fill it with all the data
// from the collection entry
for (var i = 0; i < data.Length; i++) {
buf[off + i] = data.Buffer[i];
}

return -1;
return data.Length;
}

public void Send(byte[] buf, int off, int len) {
if (!_hasReceived) {
Logger.Error("Cannot send because socket has not received yet");
if (IPEndPoint == null) {
Logger.Error("Cannot send because transport has no endpoint");
return;
}

Logger.Debug($"Server sending {len} bytes of data");
Logger.Debug($"Server sending {len} bytes of data to: {IPEndPoint}");

_socket.Send(buf, off, len, SocketFlags.None);
_socket.SendTo(buf, off, len, SocketFlags.None, IPEndPoint);
}

public void Close() {
}

public class ReceivedData {
public byte[] Buffer { get; set; }
public int Length { get; set; }
}
}

0 comments on commit 6271a26

Please sign in to comment.