Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MinMDNS refactor: support legacy replies and time throttling #4102

Merged
merged 1 commit into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/lib/mdns/minimal/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ static_library("minimal") {
"QueryReplyFilter.h",
"RecordData.cpp",
"RecordData.h",
"ReplyFilter.h",
"ResponseBuilder.h",
"ResponseSender.cpp",
"ResponseSender.h",
Expand Down
19 changes: 19 additions & 0 deletions src/lib/mdns/minimal/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ bool QueryData::Parse(const BytesRange & validData, const uint8_t ** start)
return true;
}

bool QueryData::Append(HeaderRef & hdr, chip::BufBound & out) const
{
if ((hdr.GetAdditionalCount() != 0) || (hdr.GetAnswerCount() != 0) || (hdr.GetAuthorityCount() != 0))
{
return false;
}

GetName().Put(out);
out.PutBE16(static_cast<uint16_t>(mType));
out.PutBE16(static_cast<uint16_t>(mClass) | (mAnswerViaUnicast ? kQClassUnicastAnswerFlag : 0));

if (out.Fit())
{
hdr.SetQueryCount(static_cast<uint16_t>(hdr.GetQueryCount() + 1));
}

return out.Fit();
}

bool ResourceData::Parse(const BytesRange & validData, const uint8_t ** start)
{
// Structure is:
Expand Down
3 changes: 3 additions & 0 deletions src/lib/mdns/minimal/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class QueryData
/// returns true on parse success, false on failure.
bool Parse(const BytesRange & validData, const uint8_t ** start);

/// Write out this query data back into an output buffer.
bool Append(HeaderRef & hdr, chip::BufBound & out) const;

private:
QType mType = QType::ANY;
QClass mClass = QClass::ANY;
Expand Down
3 changes: 2 additions & 1 deletion src/lib/mdns/minimal/QueryReplyFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

#pragma once

#include <mdns/minimal/responders/ReplyFilter.h>

#include "Parser.h"
#include "Query.h"
#include "ReplyFilter.h"

namespace mdns {
namespace Minimal {
Expand Down
35 changes: 28 additions & 7 deletions src/lib/mdns/minimal/ResponseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class ResponseBuilder
{
public:
ResponseBuilder() : mHeader(nullptr) {}
ResponseBuilder(chip::System::PacketBuffer * packet) : mHeader(nullptr) { Reset(packet); }
ResponseBuilder(chip::System::PacketBufferHandle && packet) : mHeader(nullptr) { Reset(std::move(packet)); }

ResponseBuilder & Reset(chip::System::PacketBuffer * packet)
ResponseBuilder & Reset(chip::System::PacketBufferHandle && packet)
{
mPacket = packet;
mPacket = std::move(packet);
mHeader = HeaderRef(mPacket->Start());

if (mPacket->AvailableDataLength() >= HeaderRef::kSizeBytes)
Expand All @@ -52,12 +52,12 @@ class ResponseBuilder
return *this;
}

ResponseBuilder & Invalidate()
CHECK_RETURN_VALUE
chip::System::PacketBufferHandle && ReleasePacket()
{
mPacket = nullptr;
mHeader = HeaderRef(nullptr);
mBuildOk = false;
return *this;
return std::move(mPacket);
}

bool HasResponseRecords() const
Expand Down Expand Up @@ -90,10 +90,31 @@ class ResponseBuilder
return *this;
}

ResponseBuilder & AddQuery(const QueryData & query)
{
if (!mBuildOk)
{
return *this;
}

chip::BufBound out(mPacket->Start() + mPacket->DataLength(), mPacket->AvailableDataLength());

if (!query.Append(mHeader, out))
{
mBuildOk = false;
}
else
{
mPacket->SetDataLength(static_cast<uint16_t>(mPacket->DataLength() + out.Needed()));
}
return *this;
}

bool Ok() const { return mBuildOk; }
bool HasPacketBuffer() const { return !mPacket.IsNull(); }

private:
chip::System::PacketBuffer * mPacket = nullptr;
chip::System::PacketBufferHandle mPacket;
HeaderRef mHeader;
bool mBuildOk = false;
};
Expand Down
129 changes: 78 additions & 51 deletions src/lib/mdns/minimal/ResponseSender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "QueryReplyFilter.h"

#include <support/ReturnMacros.h>
#include <system/SystemClock.h>

#define RETURN_IF_ERROR(err) \
do \
Expand All @@ -45,113 +46,142 @@ constexpr uint16_t kMdnsStandardPort = 5353;
constexpr uint16_t kPacketSizeBytes = 512;

} // namespace
namespace Internal {

CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource)
bool ResponseSendingState::SendUnicast() const
{
mSendError = CHIP_NO_ERROR;
return mQuery->RequestedUnicastAnswer() || (mSource->SrcPort != kMdnsStandardPort);
}

bool ResponseSendingState::IncludeQuery() const
{
return (mSource->SrcPort != kMdnsStandardPort);
}

mCurrentSource = querySource;
mCurrentMessageId = messageId;
mSendUnicast = query.RequestedUnicastAnswer() || (querySource->SrcPort != kMdnsStandardPort);
// TODO: at this point we may want to ensure we protect against excessive multicast packet flooding.
// According to https://tools.ietf.org/html/rfc6762#section-6 we should multicast at most 1/sec
// TBD: do we filter out frequent multicasts or should we switch to unicast in those cases
} // namespace Internal

CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource)
{
mSendState.Reset(messageId, query, querySource);

// Responder has a stateful 'additional replies required' that is used within the response
// loop. 'no additionals required' is set at the start and additionals are marked as the query
// reply is built.
mResponder->ResetAdditionals();

mCurrentResourceType = ResourceType::kAnswer; // direct answer
QueryReplyFilter filter(query);
for (auto it = mResponder->begin(); it != mResponder->end(); it++)
// send all 'Answer' replies
{
Responder * responder = it->responder;
const uint64_t kTimeNowMs = chip::System::Platform::Layer::GetClock_MonotonicMS();

QueryReplyFilter queryReplyFilter(query);

QueryResponderRecordFilter responseFilter;

if (!filter.Accept(responder->GetQType(), responder->GetQClass(), responder->GetQName()))
responseFilter.SetReplyFilter(&queryReplyFilter);

if (!mSendState.SendUnicast())
{
continue;
// According to https://tools.ietf.org/html/rfc6762#section-6 we should multicast at most 1/sec
//
// TODO: the 'last sent' value does NOT track the interface we used to send, so this may cause
// broadcasts on one interface to throttle broadcasts on another interface.
constexpr uint64_t kOneSecondMs = 1000;
responseFilter.SetIncludeOnlyMulticastBeforeMS(kTimeNowMs - kOneSecondMs);
}

responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendError);
for (auto it = mResponder->begin(&responseFilter); it != mResponder->end(); it++)
{
it->responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendState.GetError());

mResponder->MarkAdditionalRepliesFor(it);

mResponder->MarkAdditionalRepliesFor(it);
if (!mSendState.SendUnicast())
{
it->lastMulticastTime = kTimeNowMs;
}
}
}

mCurrentResourceType = ResourceType::kAdditional; // Additional parts
filter.SetIgnoreNameMatch(true);
for (auto it = mResponder->additional_begin(); it != mResponder->additional_end(); it++)
// send all 'Additional' replies
{
Responder * responder = it->responder;
mSendState.SetResourceType(ResourceType::kAdditional);

QueryReplyFilter queryReplyFilter(query);
queryReplyFilter.SetIgnoreNameMatch(true);

if (!filter.Accept(responder->GetQType(), responder->GetQClass(), responder->GetQName()))
QueryResponderRecordFilter responseFilter;
responseFilter
.SetReplyFilter(&queryReplyFilter) //
.SetIncludeAdditionalRepliesOnly(true);

for (auto it = mResponder->begin(&responseFilter); it != mResponder->end(); it++)
{
continue;
it->responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendState.GetError());
}

it->responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendError);
}

return FlushReply();
}

CHIP_ERROR ResponseSender::FlushReply()
{
ReturnErrorCodeIf(mCurrentPacket.IsNull(), CHIP_NO_ERROR); // nothing to flush
ReturnErrorCodeIf(!mResponseBuilder.HasPacketBuffer(), CHIP_NO_ERROR); // nothing to flush

if (mResponseBuilder.HasResponseRecords())
{

if (mSendUnicast)
if (mSendState.SendUnicast())
{
ChipLogProgress(Discovery, "Directly sending mDns reply to peer on port %d", mCurrentSource->SrcPort);
ReturnErrorOnFailure(mServer->DirectSend(mCurrentPacket.Release_ForNow(), mCurrentSource->SrcAddress,
mCurrentSource->SrcPort, mCurrentSource->Interface));
ChipLogProgress(Discovery, "Directly sending mDns reply to peer on port %d", mSendState.GetSourcePort());
ReturnErrorOnFailure(mServer->DirectSend(mResponseBuilder.ReleasePacket(), mSendState.GetSourceAddress(),
mSendState.GetSourcePort(), mSendState.GetSourceInterfaceId()));
}
else
{
ChipLogProgress(Discovery, "Broadcasting mDns reply");
ReturnErrorOnFailure(
mServer->BroadcastSend(mCurrentPacket.Release_ForNow(), kMdnsStandardPort, mCurrentSource->Interface));
mServer->BroadcastSend(mResponseBuilder.ReleasePacket(), kMdnsStandardPort, mSendState.GetSourceInterfaceId()));
}
mResponseBuilder.Invalidate();
mCurrentPacket.Adopt(nullptr);
}

return CHIP_NO_ERROR;
}

CHIP_ERROR ResponseSender::PrepareNewReplyPacket()
{
mCurrentPacket = chip::System::PacketBuffer::NewWithAvailableSize(kPacketSizeBytes);
ReturnErrorCodeIf(mCurrentPacket.IsNull(), CHIP_ERROR_NO_MEMORY);
chip::System::PacketBufferHandle buffer = chip::System::PacketBuffer::NewWithAvailableSize(kPacketSizeBytes);
ReturnErrorCodeIf(buffer.IsNull(), CHIP_ERROR_NO_MEMORY);

mResponseBuilder.Reset(mCurrentPacket.Get_ForNow());
mResponseBuilder.Reset(std::move(buffer));
mResponseBuilder.Header().SetMessageId(mSendState.GetMessageId());

mResponseBuilder.Header().SetMessageId(mCurrentMessageId);
if (mSendState.IncludeQuery())
{
mResponseBuilder.AddQuery(*mSendState.GetQuery());
}

return CHIP_NO_ERROR;
}

void ResponseSender::AddResponse(const ResourceRecord & record)
{
RETURN_IF_ERROR(mSendError);
RETURN_IF_ERROR(mSendState.GetError());

if (mCurrentPacket.IsNull())
if (!mResponseBuilder.HasPacketBuffer())
{
mSendError = PrepareNewReplyPacket();
RETURN_IF_ERROR(mSendError);
mSendState.SetError(PrepareNewReplyPacket());
RETURN_IF_ERROR(mSendState.GetError());
}

if (!mResponseBuilder.Ok())
{
mSendError = CHIP_ERROR_INCORRECT_STATE;
mSendState.SetError(CHIP_ERROR_INCORRECT_STATE);
return;
}

mResponseBuilder.AddRecord(mCurrentResourceType, record);
mResponseBuilder.AddRecord(mSendState.GetResourceType(), record);

// ResponseBuilder AddRecord will only fail if insufficient space is available (or at least this is
// the assumption here). It also guarantees that existing data and header are unchanged on
Expand All @@ -160,18 +190,15 @@ void ResponseSender::AddResponse(const ResourceRecord & record)
{
mResponseBuilder.Header().SetFlags(mResponseBuilder.Header().GetFlags().SetTruncated(true));

mSendError = FlushReply();
RETURN_IF_ERROR(mSendError);

mSendError = PrepareNewReplyPacket();
RETURN_IF_ERROR(mSendError);
RETURN_IF_ERROR(mSendState.SetError(FlushReply()));
RETURN_IF_ERROR(mSendState.SetError(PrepareNewReplyPacket()));

mResponseBuilder.AddRecord(mCurrentResourceType, record);
mResponseBuilder.AddRecord(mSendState.GetResourceType(), record);
if (!mResponseBuilder.Ok())
{
// Very much unexpected: single record addtion should fit (our records should not be that big).
ChipLogError(Discovery, "Failed to add single record to mDNS response.");
mSendError = CHIP_ERROR_INTERNAL;
mSendState.SetError(CHIP_ERROR_INTERNAL);
}
}
}
Expand Down
Loading