Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNSName: correct len and offset types #13723

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 63 additions & 37 deletions pdns/dnsname.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,7 @@ DNSName::DNSName(const std::string_view sw)
}
}


DNSName::DNSName(const char* pos, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, uint16_t minOffset)
DNSName::DNSName(const char* pos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, uint16_t minOffset)
{
if (offset >= len)
throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
Expand All @@ -115,16 +114,18 @@ DNSName::DNSName(const char* pos, int len, int offset, bool uncompress, uint16_t
}

// this should be the __only__ dns name parser in PowerDNS.
void DNSName::packetParser(const char* qpos, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset)
void DNSName::packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset)
{
const unsigned char* pos=(const unsigned char*)qpos;
unsigned char labellen;
const unsigned char *opos = pos;

if (offset >= len)
if (offset >= len) {
throw std::range_error("Trying to read past the end of the buffer ("+std::to_string(offset)+ " >= "+std::to_string(len)+")");
if (offset < (int) minOffset)
}
if (offset < static_cast<size_t>(minOffset)) {
throw std::range_error("Trying to read before the beginning of the buffer ("+std::to_string(offset)+ " < "+std::to_string(minOffset)+")");
}

const unsigned char* end = pos + len;
pos += offset;
Expand All @@ -134,16 +135,19 @@ void DNSName::packetParser(const char* qpos, int len, int offset, bool uncompres
throw std::range_error("Found compressed label, instructed not to follow");

labellen &= (~0xc0);
int newpos = (labellen << 8) + *(const unsigned char*)pos;
size_t newpos = (labellen << 8) + *(const unsigned char*)pos;

if(newpos < offset) {
if(newpos < (int) minOffset)
if (newpos < offset) {
if (newpos < minOffset) {
throw std::range_error("Invalid label position during decompression ("+std::to_string(newpos)+ " < "+std::to_string(minOffset)+")");
if (++depth > 100)
}
if (++depth > 100) {
throw std::range_error("Abort label decompression after 100 redirects");
}
packetParser((const char*)opos, len, newpos, true, nullptr, nullptr, nullptr, depth, minOffset);
} else
} else {
throw std::range_error("Found a forward reference during label decompression");
}
pos++;
break;
} else if(labellen & 0xc0) {
Expand All @@ -152,15 +156,18 @@ void DNSName::packetParser(const char* qpos, int len, int offset, bool uncompres
if (pos + labellen < end) {
appendRawLabel((const char*)pos, labellen);
}
else
else {
throw std::range_error("Found an invalid label length in qname");
}
pos+=labellen;
}
if(d_storage.empty())
if (d_storage.empty()) {
d_storage.append(1, (char)0); // we just parsed the root
if(consumed)
}
if (consumed != nullptr) {
*consumed = pos - opos - offset;
if(qtype) {
}
if (qtype != nullptr) {
if (pos + 2 > end) {
throw std::range_error("Trying to read qtype past the end of the buffer ("+std::to_string((pos - opos) + 2)+ " > "+std::to_string(len)+")");
}
Expand Down Expand Up @@ -225,8 +232,9 @@ std::string DNSName::toLogString() const

std::string DNSName::toDNSString() const
{
if (empty())
if (empty()) {
throw std::out_of_range("Attempt to DNSString an unset dnsname");
}

return std::string(d_storage.c_str(), d_storage.length());
}
Expand All @@ -250,11 +258,13 @@ size_t DNSName::wirelength() const {
// Are WE part of parent
bool DNSName::isPartOf(const DNSName& parent) const
{
if(parent.empty() || empty())
if(parent.empty() || empty()) {
throw std::out_of_range("empty dnsnames aren't part of anything");
}

if(parent.d_storage.size() > d_storage.size())
if(parent.d_storage.size() > d_storage.size()) {
return false;
}

// this is slightly complicated since we can't start from the end, since we can't see where a label begins/ends then
for(auto us=d_storage.cbegin(); us<d_storage.cend(); us+=*us+1) {
Expand Down Expand Up @@ -290,8 +300,9 @@ void DNSName::makeUsRelative(const DNSName& zone)
d_storage.erase(d_storage.size()-zone.d_storage.size());
d_storage.append(1, (char)0); // put back the trailing 0
}
else
else {
clear();
}
}

DNSName DNSName::getCommonLabels(const DNSName& other) const
Expand Down Expand Up @@ -323,8 +334,9 @@ DNSName DNSName::labelReverse() const
{
DNSName ret;

if(isRoot())
if (isRoot()) {
return *this; // we don't create the root automatically below
}

if (!empty()) {
vector<string> l=getRawLabels();
Expand All @@ -343,14 +355,17 @@ void DNSName::appendRawLabel(const std::string& label)

void DNSName::appendRawLabel(const char* start, unsigned int length)
{
if(length==0)
if (length==0) {
throw std::range_error("no such thing as an empty label to append");
if(length > 63)
}
if (length > 63) {
throw std::range_error("label too long to append");
if(d_storage.size() + length > s_maxDNSNameLength - 1) // reserve one byte for the label length
}
if (d_storage.size() + length > s_maxDNSNameLength - 1) { // reserve one byte for the label length
throw std::range_error("name too long to append");
}

if(d_storage.empty()) {
if (d_storage.empty()) {
d_storage.append(1, (char)length);
}
else {
Expand All @@ -362,15 +377,19 @@ void DNSName::appendRawLabel(const char* start, unsigned int length)

void DNSName::prependRawLabel(const std::string& label)
{
if(label.empty())
if (label.empty()) {
throw std::range_error("no such thing as an empty label to prepend");
if(label.size() > 63)
}
if (label.size() > 63) {
throw std::range_error("label too long to prepend");
if(d_storage.size() + label.size() > s_maxDNSNameLength - 1) // reserve one byte for the label length
}
if (d_storage.size() + label.size() > s_maxDNSNameLength - 1) { // reserve one byte for the label length
throw std::range_error("name too long to prepend");
}

if(d_storage.empty())
if (d_storage.empty()) {
d_storage.append(1, (char)0);
}

string_t prep(1, (char)label.size());
prep.append(label.c_str(), label.size());
Expand Down Expand Up @@ -415,16 +434,18 @@ DNSName DNSName::getLastLabel() const

bool DNSName::chopOff()
{
if(d_storage.empty() || d_storage[0]==0)
if (d_storage.empty() || d_storage[0]==0) {
return false;
}
d_storage.erase(0, (unsigned int)d_storage[0]+1);
return true;
}

bool DNSName::isWildcard() const
{
if(d_storage.size() < 2)
if (d_storage.size() < 2) {
return false;
}
auto p = d_storage.begin();
return (*p == 0x01 && *++p == '*');
}
Expand Down Expand Up @@ -454,8 +475,9 @@ unsigned int DNSName::countLabels() const

void DNSName::trimToLabels(unsigned int to)
{
while(countLabels() > to && chopOff())
while(countLabels() > to && chopOff()) {
;
}
}


Expand All @@ -470,12 +492,15 @@ void DNSName::appendEscapedLabel(std::string& appendTo, const char* orig, size_t

while (pos < len) {
auto p = static_cast<uint8_t>(orig[pos]);
if(p=='.')
if (p=='.') {
appendTo+="\\.";
else if(p=='\\')
}
else if (p=='\\') {
appendTo+="\\\\";
else if(p > 0x20 && p < 0x7f)
}
else if (p > 0x20 && p < 0x7f) {
appendTo.append(1, (char)p);
}
else {
char buf[] = "000";
auto got = snprintf(buf, sizeof(buf), "%03" PRIu8, p);
Expand All @@ -498,11 +523,12 @@ bool DNSName::has8bitBytes() const
for (size_t idx = 0; idx < length; idx++) {
++pos;
char c = s.at(pos);
if(!((c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c =='-' || c == '_' || c=='*' || c=='.' || c=='/' || c=='@' || c==' ' || c=='\\' || c==':'))
if (!((c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c =='-' || c == '_' || c=='*' || c=='.' || c=='/' || c=='@' || c==' ' || c=='\\' || c==':')) {
return true;
}
}
++pos;
length = s.at(pos);
Expand Down
4 changes: 2 additions & 2 deletions pdns/dnsname.hh
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public:
DNSName(DNSName&& a) = default;

explicit DNSName(std::string_view sw); //!< Constructs from a human formatted, escaped presentation
DNSName(const char* p, int len, int offset, bool uncompress, uint16_t* qtype = nullptr, uint16_t* qclass = nullptr, unsigned int* consumed = nullptr, uint16_t minOffset = 0); //!< Construct from a DNS Packet, taking the first question if offset=12. If supplied, consumed is set to the number of bytes consumed from the packet, which will not be equal to the wire length of the resulting name in case of compression.
DNSName(const char* p, size_t len, size_t offset, bool uncompress, uint16_t* qtype = nullptr, uint16_t* qclass = nullptr, unsigned int* consumed = nullptr, uint16_t minOffset = 0); //!< Construct from a DNS Packet, taking the first question if offset=12. If supplied, consumed is set to the number of bytes consumed from the packet, which will not be equal to the wire length of the resulting name in case of compression.

bool isPartOf(const DNSName& rhs) const; //!< Are we part of the rhs name? Note that name.isPartOf(name).
inline bool operator==(const DNSName& rhs) const; //!< DNS-native comparison (case insensitive) - empty compares to empty
Expand Down Expand Up @@ -216,7 +216,7 @@ public:
private:
string_t d_storage;

void packetParser(const char* p, int len, int offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset);
void packetParser(const char* qpos, size_t len, size_t offset, bool uncompress, uint16_t* qtype, uint16_t* qclass, unsigned int* consumed, int depth, uint16_t minOffset);
static void appendEscapedLabel(std::string& appendTo, const char* orig, size_t len);
static std::string unescapeLabel(const std::string& orig);
static void throwSafeRangeError(const std::string& msg, const char* buf, size_t length);
Expand Down