Skip to content

Commit

Permalink
feat: responses should not flood the network #18
Browse files Browse the repository at this point in the history
  • Loading branch information
richardschneider committed Nov 4, 2018
1 parent dec3947 commit 1d6ae43
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 4 deletions.
64 changes: 60 additions & 4 deletions src/MulticastService.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
using Common.Logging;
using System;
using System.Collections.Generic;
using System.Collections.Concurrent;
using System.Linq;
using System.Net;
using System.Net.NetworkInformation;
using System.Net.Sockets;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Security.Cryptography;
using System.Text;

namespace Makaretu.Dns
{
Expand Down Expand Up @@ -47,6 +50,19 @@ public class MulticastService : IResolver, IDisposable
IPEndPoint mdnsEndpoint;
int maxPacketSize;

/// <summary>
/// Recently sent messages.
/// </summary>
/// <value>
/// The key is the MD5 hash of the <see cref="Message"/> and the
/// value is when the message was sent.
/// </value>
/// <remarks>
/// This is used to avoid floding of responses as per
/// <see href="https://github.com/richardschneider/net-mdns/issues/18"/>
/// </remarks>
ConcurrentDictionary<string, DateTime> sentMessages = new ConcurrentDictionary<string, DateTime>();

/// <summary>
/// Set the default TTLs.
/// </summary>
Expand Down Expand Up @@ -374,7 +390,7 @@ public void SendQuery(string name, DnsClass klass = DnsClass.IN, DnsType type =
/// </exception>
public void SendQuery(Message msg)
{
Send(msg);
Send(msg, checkDuplicate: false);
}

/// <summary>
Expand All @@ -383,6 +399,10 @@ public void SendQuery(Message msg)
/// <param name="answer">
/// The answer message.
/// </param>
/// <param name="checkDuplicate">
/// If <b>true</b>, then if the same <paramref name="answer"/> was
/// recently sent it will not be sent again.
/// </param>
/// <exception cref="InvalidOperationException">
/// When the service has not started.
/// </exception>
Expand All @@ -396,10 +416,14 @@ public void SendQuery(Message msg)
/// The <paramref name="answer"/> is <see cref="Message.Truncate">truncated</see>
/// if exceeds the maximum packet length.
/// </para>
/// <para>
/// <paramref name="checkDuplicate"/> should always be <b>true</b> except
/// when <see href="https://tools.ietf.org/html/rfc6762#section-8.1">answering a probe</see>.
/// </para>
/// </remarks>
/// <see cref="QueryReceived"/>
/// <seealso cref="Message.CreateResponse"/>
public void SendAnswer(Message answer)
public void SendAnswer(Message answer, bool checkDuplicate = true)
{
// All MDNS answers are authoritative and have a transaction
// ID of zero.
Expand All @@ -411,23 +435,55 @@ public void SendAnswer(Message answer)

answer.Truncate(maxPacketSize);

Send(answer);
Send(answer, checkDuplicate);
}

void Send(Message msg)
void Send(Message msg, bool checkDuplicate)
{
var packet = msg.ToByteArray();
if (packet.Length > maxPacketSize)
{
throw new ArgumentOutOfRangeException($"Exceeds max packet size of {maxPacketSize}.");
}

// Get the hash of the packet. MD5 is okay because
// the hash is not used for security.
string hash;
using (var md5 = MD5.Create())
{
var bytes = md5.ComputeHash(packet);
// TODO: there must be a more efficient way.
var s = new StringBuilder();
for (int i = 0; i < bytes.Length; i++)
{
s.Append(bytes[i].ToString("x2"));
}
hash = s.ToString();
}

// Prune the sent messages. Anything older than a second ago
// is removed.
var now = DateTime.Now;
var dead = now.AddSeconds(-1);
foreach (var notrecent in sentMessages.Where(x => x.Value < dead))
{
sentMessages.TryRemove(notrecent.Key, out DateTime _);
}

// If messsage was recently sent, then do not send again.
if (checkDuplicate && sentMessages.ContainsKey(hash))
{
return;
}

lock (senderLock)
{
if (sender == null)
throw new InvalidOperationException("MDNS is not started");
sender.SendAsync(packet, packet.Length, mdnsEndpoint).Wait();
}

sentMessages.AddOrUpdate(hash, DateTime.Now, (key, value) => value);
}

/// <summary>
Expand Down
83 changes: 83 additions & 0 deletions test/MulticastServiceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,89 @@ public void Resolve_NoAnswer()
});
}
}
[TestMethod]
public async Task DuplicateResponse()
{
var service = Guid.NewGuid().ToString() + ".local";

using (var mdns = new MulticastService())
{
var answerCount = 0;
mdns.NetworkInterfaceDiscovered += async (s, e) =>
{
mdns.SendQuery(service);
await Task.Delay(250);
mdns.SendQuery(service);
};
mdns.QueryReceived += (s, e) =>
{
var msg = e.Message;
if (msg.Questions.Any(q => q.Name == service))
{
var res = msg.CreateResponse();
res.Answers.Add(new ARecord
{
Name = service,
Address = IPAddress.Parse("127.1.1.1")
});
mdns.SendAnswer(res);
}
};
mdns.AnswerReceived += (s, e) =>
{
var msg = e.Message;
if (msg.Answers.Any(answer => answer.Name == service))
{
++answerCount;
};
};
mdns.Start();
await Task.Delay(1000);
Assert.AreEqual(1, answerCount);
}
}

[TestMethod]
public async Task NoDuplicateResponse()
{
var service = Guid.NewGuid().ToString() + ".local";

using (var mdns = new MulticastService())
{
var answerCount = 0;
mdns.NetworkInterfaceDiscovered += async (s, e) =>
{
mdns.SendQuery(service);
await Task.Delay(250);
mdns.SendQuery(service);
};
mdns.QueryReceived += (s, e) =>
{
var msg = e.Message;
if (msg.Questions.Any(q => q.Name == service))
{
var res = msg.CreateResponse();
res.Answers.Add(new ARecord
{
Name = service,
Address = IPAddress.Parse("127.1.1.1")
});
mdns.SendAnswer(res, checkDuplicate: false);
}
};
mdns.AnswerReceived += (s, e) =>
{
var msg = e.Message;
if (msg.Answers.Any(answer => answer.Name == service))
{
++answerCount;
};
};
mdns.Start();
await Task.Delay(1000);
Assert.AreEqual(2, answerCount);
}
}

}
}

0 comments on commit 1d6ae43

Please sign in to comment.