Skip to content

Commit

Permalink
MinMDNS refactor: (#4102)
Browse files Browse the repository at this point in the history
- support legacy respones (include query in the response)
- Some refactoring for better readability
  • Loading branch information
andy31415 authored Dec 8, 2020
1 parent 6d8ff99 commit 0e2b4a4
Show file tree
Hide file tree
Showing 15 changed files with 305 additions and 118 deletions.
1 change: 0 additions & 1 deletion src/lib/mdns/minimal/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ static_library("minimal") {
"QueryReplyFilter.h",
"RecordData.cpp",
"RecordData.h",
"ReplyFilter.h",
"ResponseBuilder.h",
"ResponseSender.cpp",
"ResponseSender.h",
Expand Down
19 changes: 19 additions & 0 deletions src/lib/mdns/minimal/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ bool QueryData::Parse(const BytesRange & validData, const uint8_t ** start)
return true;
}

bool QueryData::Append(HeaderRef & hdr, chip::BufBound & out) const
{
if ((hdr.GetAdditionalCount() != 0) || (hdr.GetAnswerCount() != 0) || (hdr.GetAuthorityCount() != 0))
{
return false;
}

GetName().Put(out);
out.PutBE16(static_cast<uint16_t>(mType));
out.PutBE16(static_cast<uint16_t>(mClass) | (mAnswerViaUnicast ? kQClassUnicastAnswerFlag : 0));

if (out.Fit())
{
hdr.SetQueryCount(static_cast<uint16_t>(hdr.GetQueryCount() + 1));
}

return out.Fit();
}

bool ResourceData::Parse(const BytesRange & validData, const uint8_t ** start)
{
// Structure is:
Expand Down
3 changes: 3 additions & 0 deletions src/lib/mdns/minimal/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ class QueryData
/// returns true on parse success, false on failure.
bool Parse(const BytesRange & validData, const uint8_t ** start);

/// Write out this query data back into an output buffer.
bool Append(HeaderRef & hdr, chip::BufBound & out) const;

private:
QType mType = QType::ANY;
QClass mClass = QClass::ANY;
Expand Down
3 changes: 2 additions & 1 deletion src/lib/mdns/minimal/QueryReplyFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

#pragma once

#include <mdns/minimal/responders/ReplyFilter.h>

#include "Parser.h"
#include "Query.h"
#include "ReplyFilter.h"

namespace mdns {
namespace Minimal {
Expand Down
35 changes: 28 additions & 7 deletions src/lib/mdns/minimal/ResponseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ class ResponseBuilder
{
public:
ResponseBuilder() : mHeader(nullptr) {}
ResponseBuilder(chip::System::PacketBuffer * packet) : mHeader(nullptr) { Reset(packet); }
ResponseBuilder(chip::System::PacketBufferHandle && packet) : mHeader(nullptr) { Reset(std::move(packet)); }

ResponseBuilder & Reset(chip::System::PacketBuffer * packet)
ResponseBuilder & Reset(chip::System::PacketBufferHandle && packet)
{
mPacket = packet;
mPacket = std::move(packet);
mHeader = HeaderRef(mPacket->Start());

if (mPacket->AvailableDataLength() >= HeaderRef::kSizeBytes)
Expand All @@ -52,12 +52,12 @@ class ResponseBuilder
return *this;
}

ResponseBuilder & Invalidate()
CHECK_RETURN_VALUE
chip::System::PacketBufferHandle && ReleasePacket()
{
mPacket = nullptr;
mHeader = HeaderRef(nullptr);
mBuildOk = false;
return *this;
return std::move(mPacket);
}

bool HasResponseRecords() const
Expand Down Expand Up @@ -90,10 +90,31 @@ class ResponseBuilder
return *this;
}

ResponseBuilder & AddQuery(const QueryData & query)
{
if (!mBuildOk)
{
return *this;
}

chip::BufBound out(mPacket->Start() + mPacket->DataLength(), mPacket->AvailableDataLength());

if (!query.Append(mHeader, out))
{
mBuildOk = false;
}
else
{
mPacket->SetDataLength(static_cast<uint16_t>(mPacket->DataLength() + out.Needed()));
}
return *this;
}

bool Ok() const { return mBuildOk; }
bool HasPacketBuffer() const { return !mPacket.IsNull(); }

private:
chip::System::PacketBuffer * mPacket = nullptr;
chip::System::PacketBufferHandle mPacket;
HeaderRef mHeader;
bool mBuildOk = false;
};
Expand Down
129 changes: 78 additions & 51 deletions src/lib/mdns/minimal/ResponseSender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "QueryReplyFilter.h"

#include <support/ReturnMacros.h>
#include <system/SystemClock.h>

#define RETURN_IF_ERROR(err) \
do \
Expand All @@ -45,113 +46,142 @@ constexpr uint16_t kMdnsStandardPort = 5353;
constexpr uint16_t kPacketSizeBytes = 512;

} // namespace
namespace Internal {

CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource)
bool ResponseSendingState::SendUnicast() const
{
mSendError = CHIP_NO_ERROR;
return mQuery->RequestedUnicastAnswer() || (mSource->SrcPort != kMdnsStandardPort);
}

bool ResponseSendingState::IncludeQuery() const
{
return (mSource->SrcPort != kMdnsStandardPort);
}

mCurrentSource = querySource;
mCurrentMessageId = messageId;
mSendUnicast = query.RequestedUnicastAnswer() || (querySource->SrcPort != kMdnsStandardPort);
// TODO: at this point we may want to ensure we protect against excessive multicast packet flooding.
// According to https://tools.ietf.org/html/rfc6762#section-6 we should multicast at most 1/sec
// TBD: do we filter out frequent multicasts or should we switch to unicast in those cases
} // namespace Internal

CHIP_ERROR ResponseSender::Respond(uint32_t messageId, const QueryData & query, const chip::Inet::IPPacketInfo * querySource)
{
mSendState.Reset(messageId, query, querySource);

// Responder has a stateful 'additional replies required' that is used within the response
// loop. 'no additionals required' is set at the start and additionals are marked as the query
// reply is built.
mResponder->ResetAdditionals();

mCurrentResourceType = ResourceType::kAnswer; // direct answer
QueryReplyFilter filter(query);
for (auto it = mResponder->begin(); it != mResponder->end(); it++)
// send all 'Answer' replies
{
Responder * responder = it->responder;
const uint64_t kTimeNowMs = chip::System::Platform::Layer::GetClock_MonotonicMS();

QueryReplyFilter queryReplyFilter(query);

QueryResponderRecordFilter responseFilter;

if (!filter.Accept(responder->GetQType(), responder->GetQClass(), responder->GetQName()))
responseFilter.SetReplyFilter(&queryReplyFilter);

if (!mSendState.SendUnicast())
{
continue;
// According to https://tools.ietf.org/html/rfc6762#section-6 we should multicast at most 1/sec
//
// TODO: the 'last sent' value does NOT track the interface we used to send, so this may cause
// broadcasts on one interface to throttle broadcasts on another interface.
constexpr uint64_t kOneSecondMs = 1000;
responseFilter.SetIncludeOnlyMulticastBeforeMS(kTimeNowMs - kOneSecondMs);
}

responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendError);
for (auto it = mResponder->begin(&responseFilter); it != mResponder->end(); it++)
{
it->responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendState.GetError());

mResponder->MarkAdditionalRepliesFor(it);

mResponder->MarkAdditionalRepliesFor(it);
if (!mSendState.SendUnicast())
{
it->lastMulticastTime = kTimeNowMs;
}
}
}

mCurrentResourceType = ResourceType::kAdditional; // Additional parts
filter.SetIgnoreNameMatch(true);
for (auto it = mResponder->additional_begin(); it != mResponder->additional_end(); it++)
// send all 'Additional' replies
{
Responder * responder = it->responder;
mSendState.SetResourceType(ResourceType::kAdditional);

QueryReplyFilter queryReplyFilter(query);
queryReplyFilter.SetIgnoreNameMatch(true);

if (!filter.Accept(responder->GetQType(), responder->GetQClass(), responder->GetQName()))
QueryResponderRecordFilter responseFilter;
responseFilter
.SetReplyFilter(&queryReplyFilter) //
.SetIncludeAdditionalRepliesOnly(true);

for (auto it = mResponder->begin(&responseFilter); it != mResponder->end(); it++)
{
continue;
it->responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendState.GetError());
}

it->responder->AddAllResponses(querySource, this);
ReturnErrorOnFailure(mSendError);
}

return FlushReply();
}

CHIP_ERROR ResponseSender::FlushReply()
{
ReturnErrorCodeIf(mCurrentPacket.IsNull(), CHIP_NO_ERROR); // nothing to flush
ReturnErrorCodeIf(!mResponseBuilder.HasPacketBuffer(), CHIP_NO_ERROR); // nothing to flush

if (mResponseBuilder.HasResponseRecords())
{

if (mSendUnicast)
if (mSendState.SendUnicast())
{
ChipLogProgress(Discovery, "Directly sending mDns reply to peer on port %d", mCurrentSource->SrcPort);
ReturnErrorOnFailure(mServer->DirectSend(mCurrentPacket.Release_ForNow(), mCurrentSource->SrcAddress,
mCurrentSource->SrcPort, mCurrentSource->Interface));
ChipLogProgress(Discovery, "Directly sending mDns reply to peer on port %d", mSendState.GetSourcePort());
ReturnErrorOnFailure(mServer->DirectSend(mResponseBuilder.ReleasePacket(), mSendState.GetSourceAddress(),
mSendState.GetSourcePort(), mSendState.GetSourceInterfaceId()));
}
else
{
ChipLogProgress(Discovery, "Broadcasting mDns reply");
ReturnErrorOnFailure(
mServer->BroadcastSend(mCurrentPacket.Release_ForNow(), kMdnsStandardPort, mCurrentSource->Interface));
mServer->BroadcastSend(mResponseBuilder.ReleasePacket(), kMdnsStandardPort, mSendState.GetSourceInterfaceId()));
}
mResponseBuilder.Invalidate();
mCurrentPacket.Adopt(nullptr);
}

return CHIP_NO_ERROR;
}

CHIP_ERROR ResponseSender::PrepareNewReplyPacket()
{
mCurrentPacket = chip::System::PacketBuffer::NewWithAvailableSize(kPacketSizeBytes);
ReturnErrorCodeIf(mCurrentPacket.IsNull(), CHIP_ERROR_NO_MEMORY);
chip::System::PacketBufferHandle buffer = chip::System::PacketBuffer::NewWithAvailableSize(kPacketSizeBytes);
ReturnErrorCodeIf(buffer.IsNull(), CHIP_ERROR_NO_MEMORY);

mResponseBuilder.Reset(mCurrentPacket.Get_ForNow());
mResponseBuilder.Reset(std::move(buffer));
mResponseBuilder.Header().SetMessageId(mSendState.GetMessageId());

mResponseBuilder.Header().SetMessageId(mCurrentMessageId);
if (mSendState.IncludeQuery())
{
mResponseBuilder.AddQuery(*mSendState.GetQuery());
}

return CHIP_NO_ERROR;
}

void ResponseSender::AddResponse(const ResourceRecord & record)
{
RETURN_IF_ERROR(mSendError);
RETURN_IF_ERROR(mSendState.GetError());

if (mCurrentPacket.IsNull())
if (!mResponseBuilder.HasPacketBuffer())
{
mSendError = PrepareNewReplyPacket();
RETURN_IF_ERROR(mSendError);
mSendState.SetError(PrepareNewReplyPacket());
RETURN_IF_ERROR(mSendState.GetError());
}

if (!mResponseBuilder.Ok())
{
mSendError = CHIP_ERROR_INCORRECT_STATE;
mSendState.SetError(CHIP_ERROR_INCORRECT_STATE);
return;
}

mResponseBuilder.AddRecord(mCurrentResourceType, record);
mResponseBuilder.AddRecord(mSendState.GetResourceType(), record);

// ResponseBuilder AddRecord will only fail if insufficient space is available (or at least this is
// the assumption here). It also guarantees that existing data and header are unchanged on
Expand All @@ -160,18 +190,15 @@ void ResponseSender::AddResponse(const ResourceRecord & record)
{
mResponseBuilder.Header().SetFlags(mResponseBuilder.Header().GetFlags().SetTruncated(true));

mSendError = FlushReply();
RETURN_IF_ERROR(mSendError);

mSendError = PrepareNewReplyPacket();
RETURN_IF_ERROR(mSendError);
RETURN_IF_ERROR(mSendState.SetError(FlushReply()));
RETURN_IF_ERROR(mSendState.SetError(PrepareNewReplyPacket()));

mResponseBuilder.AddRecord(mCurrentResourceType, record);
mResponseBuilder.AddRecord(mSendState.GetResourceType(), record);
if (!mResponseBuilder.Ok())
{
// Very much unexpected: single record addtion should fit (our records should not be that big).
ChipLogError(Discovery, "Failed to add single record to mDNS response.");
mSendError = CHIP_ERROR_INTERNAL;
mSendState.SetError(CHIP_ERROR_INTERNAL);
}
}
}
Expand Down
Loading

0 comments on commit 0e2b4a4

Please sign in to comment.