diff --git a/cores/esp32/IPAddress.cpp b/cores/esp32/IPAddress.cpp index 0575363f254..002dccb3fcd 100644 --- a/cores/esp32/IPAddress.cpp +++ b/cores/esp32/IPAddress.cpp @@ -20,78 +20,244 @@ #include #include #include +#include -IPAddress::IPAddress() +IPAddress::IPAddress() : IPAddress(IPv4) {} + +IPAddress::IPAddress(IPType ip_type) { - _address.dword = 0; + _type = ip_type; + memset(_address.bytes, 0, sizeof(_address.bytes)); } IPAddress::IPAddress(uint8_t first_octet, uint8_t second_octet, uint8_t third_octet, uint8_t fourth_octet) { - _address.bytes[0] = first_octet; - _address.bytes[1] = second_octet; - _address.bytes[2] = third_octet; - _address.bytes[3] = fourth_octet; + _type = IPv4; + memset(_address.bytes, 0, sizeof(_address.bytes)); + _address.bytes[IPADDRESS_V4_BYTES_INDEX] = first_octet; + _address.bytes[IPADDRESS_V4_BYTES_INDEX + 1] = second_octet; + _address.bytes[IPADDRESS_V4_BYTES_INDEX + 2] = third_octet; + _address.bytes[IPADDRESS_V4_BYTES_INDEX + 3] = fourth_octet; +} + +IPAddress::IPAddress(uint8_t o1, uint8_t o2, uint8_t o3, uint8_t o4, uint8_t o5, uint8_t o6, uint8_t o7, uint8_t o8, uint8_t o9, uint8_t o10, uint8_t o11, uint8_t o12, uint8_t o13, uint8_t o14, uint8_t o15, uint8_t o16) { + _type = IPv6; + _address.bytes[0] = o1; + _address.bytes[1] = o2; + _address.bytes[2] = o3; + _address.bytes[3] = o4; + _address.bytes[4] = o5; + _address.bytes[5] = o6; + _address.bytes[6] = o7; + _address.bytes[7] = o8; + _address.bytes[8] = o9; + _address.bytes[9] = o10; + _address.bytes[10] = o11; + _address.bytes[11] = o12; + _address.bytes[12] = o13; + _address.bytes[13] = o14; + _address.bytes[14] = o15; + _address.bytes[15] = o16; } IPAddress::IPAddress(uint32_t address) { - _address.dword = address; + // IPv4 only + _type = IPv4; + memset(_address.bytes, 0, sizeof(_address.bytes)); + _address.dword[IPADDRESS_V4_DWORD_INDEX] = address; + + // NOTE on conversion/comparison and uint32_t: + // These conversions are host platform dependent. + // There is a defined integer representation of IPv4 addresses, + // based on network byte order (will be the value on big endian systems), + // e.g. http://2398766798 is the same as http://142.250.70.206, + // However on little endian systems the octets 0x83, 0xFA, 0x46, 0xCE, + // in that order, will form the integer (uint32_t) 3460758158 . +} + +IPAddress::IPAddress(const uint8_t *address) : IPAddress(IPv4, address) {} + +IPAddress::IPAddress(IPType ip_type, const uint8_t *address) +{ + _type = ip_type; + if (ip_type == IPv4) { + memset(_address.bytes, 0, sizeof(_address.bytes)); + memcpy(&_address.bytes[IPADDRESS_V4_BYTES_INDEX], address, sizeof(uint32_t)); + } else { + memcpy(_address.bytes, address, sizeof(_address.bytes)); + } } -IPAddress::IPAddress(const uint8_t *address) +IPAddress::IPAddress(const char *address) { - memcpy(_address.bytes, address, sizeof(_address.bytes)); + fromString(address); } IPAddress& IPAddress::operator=(const uint8_t *address) { - memcpy(_address.bytes, address, sizeof(_address.bytes)); + // IPv4 only conversion from byte pointer + _type = IPv4; + memset(_address.bytes, 0, sizeof(_address.bytes)); + memcpy(&_address.bytes[IPADDRESS_V4_BYTES_INDEX], address, sizeof(uint32_t)); + return *this; +} + +IPAddress& IPAddress::operator=(const char *address) +{ + fromString(address); return *this; } IPAddress& IPAddress::operator=(uint32_t address) { - _address.dword = address; + // IPv4 conversion + // See note on conversion/comparison and uint32_t + _type = IPv4; + memset(_address.bytes, 0, sizeof(_address.bytes)); + _address.dword[IPADDRESS_V4_DWORD_INDEX] = address; return *this; } +bool IPAddress::operator==(const IPAddress& addr) const +{ + return (addr._type == _type) + && (memcmp(addr._address.bytes, _address.bytes, sizeof(_address.bytes)) == 0); +} + bool IPAddress::operator==(const uint8_t* addr) const { - return memcmp(addr, _address.bytes, sizeof(_address.bytes)) == 0; + // IPv4 only comparison to byte pointer + // Can't support IPv6 as we know our type, but not the length of the pointer + return _type == IPv4 && memcmp(addr, &_address.bytes[IPADDRESS_V4_BYTES_INDEX], sizeof(uint32_t)) == 0; +} + +uint8_t IPAddress::operator[](int index) const { + if (_type == IPv4) { + return _address.bytes[IPADDRESS_V4_BYTES_INDEX + index]; + } + return _address.bytes[index]; +} + +uint8_t& IPAddress::operator[](int index) { + if (_type == IPv4) { + return _address.bytes[IPADDRESS_V4_BYTES_INDEX + index]; + } + return _address.bytes[index]; } size_t IPAddress::printTo(Print& p) const { size_t n = 0; - for(int i = 0; i < 3; i++) { - n += p.print(_address.bytes[i], DEC); + + if (_type == IPv6) { + // IPv6 IETF canonical format: compress left-most longest run of two or more zero fields, lower case + int8_t longest_start = -1; + int8_t longest_length = 1; + int8_t current_start = -1; + int8_t current_length = 0; + for (int8_t f = 0; f < 8; f++) { + if (_address.bytes[f * 2] == 0 && _address.bytes[f * 2 + 1] == 0) { + if (current_start == -1) { + current_start = f; + current_length = 1; + } else { + current_length++; + } + if (current_length > longest_length) { + longest_start = current_start; + longest_length = current_length; + } + } else { + current_start = -1; + } + } + for (int f = 0; f < 8; f++) { + if (f < longest_start || f >= longest_start + longest_length) { + uint8_t c1 = _address.bytes[f * 2] >> 4; + uint8_t c2 = _address.bytes[f * 2] & 0xf; + uint8_t c3 = _address.bytes[f * 2 + 1] >> 4; + uint8_t c4 = _address.bytes[f * 2 + 1] & 0xf; + if (c1 > 0) { + n += p.print((char)(c1 < 10 ? '0' + c1 : 'a' + c1 - 10)); + } + if (c1 > 0 || c2 > 0) { + n += p.print((char)(c2 < 10 ? '0' + c2 : 'a' + c2 - 10)); + } + if (c1 > 0 || c2 > 0 || c3 > 0) { + n += p.print((char)(c3 < 10 ? '0' + c3 : 'a' + c3 - 10)); + } + n += p.print((char)(c4 < 10 ? '0' + c4 : 'a' + c4 - 10)); + if (f < 7) { + n += p.print(':'); + } + } else if (f == longest_start) { + if (longest_start == 0) { + n += p.print(':'); + } + n += p.print(':'); + } + } + return n; + } + + // IPv4 + for (int i =0; i < 3; i++) + { + n += p.print(_address.bytes[IPADDRESS_V4_BYTES_INDEX + i], DEC); n += p.print('.'); } - n += p.print(_address.bytes[3], DEC); + n += p.print(_address.bytes[IPADDRESS_V4_BYTES_INDEX + 3], DEC); return n; } -String IPAddress::toString() const +String IPAddress::toString4() const { char szRet[16]; - sprintf(szRet,"%u.%u.%u.%u", _address.bytes[0], _address.bytes[1], _address.bytes[2], _address.bytes[3]); + snprintf(szRet, sizeof(szRet), "%u.%u.%u.%u", _address.bytes[IPADDRESS_V4_BYTES_INDEX], _address.bytes[IPADDRESS_V4_BYTES_INDEX + 1], _address.bytes[IPADDRESS_V4_BYTES_INDEX + 2], _address.bytes[IPADDRESS_V4_BYTES_INDEX + 3]); return String(szRet); } +String IPAddress::toString6() const +{ + StreamString s; + s.reserve(40); + printTo(s); + return s; +} + +String IPAddress::toString() const +{ + if (_type == IPv4) { + return toString4(); + } else { + return toString6(); + } +} + bool IPAddress::fromString(const char *address) +{ + if (!fromString4(address)) + { + return fromString6(address); + } + return true; +} + +bool IPAddress::fromString4(const char *address) { // TODO: add support for "a", "a.b", "a.b.c" formats - uint16_t acc = 0; // Accumulator + int16_t acc = -1; // Accumulator uint8_t dots = 0; + memset(_address.bytes, 0, sizeof(_address.bytes)); while (*address) { char c = *address++; if (c >= '0' && c <= '9') { - acc = acc * 10 + (c - '0'); + acc = (acc < 0) ? (c - '0') : acc * 10 + (c - '0'); if (acc > 255) { // Value out of [0..255] range return false; @@ -100,11 +266,15 @@ bool IPAddress::fromString(const char *address) else if (c == '.') { if (dots == 3) { - // Too much dots (there must be 3 dots) + // Too many dots (there must be 3 dots) return false; } - _address.bytes[dots++] = acc; - acc = 0; + if (acc < 0) { + /* No value between dots, e.g. '1..' */ + return false; + } + _address.bytes[IPADDRESS_V4_BYTES_INDEX + dots++] = acc; + acc = -1; } else { @@ -117,7 +287,80 @@ bool IPAddress::fromString(const char *address) // Too few dots (there must be 3 dots) return false; } - _address.bytes[3] = acc; + if (acc < 0) { + /* No value between dots, e.g. '1..' */ + return false; + } + _address.bytes[IPADDRESS_V4_BYTES_INDEX + 3] = acc; + _type = IPv4; + return true; +} + +bool IPAddress::fromString6(const char *address) { + uint32_t acc = 0; // Accumulator + int colons = 0, double_colons = -1; + + while (*address) + { + char c = tolower(*address++); + if (isalnum(c) && c <= 'f') { + if (c >= 'a') + c -= 'a' - '0' - 10; + acc = acc * 16 + (c - '0'); + if (acc > 0xffff) + // Value out of range + return false; + } + else if (c == ':') { + if (*address == ':') { + if (double_colons >= 0) { + // :: allowed once + return false; + } + if (*address != '\0' && *(address + 1) == ':') { + // ::: not allowed + return false; + } + // remember location + double_colons = colons + !!acc; + address++; + } else if (*address == '\0') { + // can't end with a single colon + return false; + } + if (colons == 7) + // too many separators + return false; + _address.bytes[colons * 2] = acc >> 8; + _address.bytes[colons * 2 + 1] = acc & 0xff; + colons++; + acc = 0; + } + else + // Invalid char + return false; + } + + if (double_colons == -1 && colons != 7) { + // Too few separators + return false; + } + if (double_colons > -1 && colons > 6) { + // Too many segments (double colon must be at least one zero field) + return false; + } + _address.bytes[colons * 2] = acc >> 8; + _address.bytes[colons * 2 + 1] = acc & 0xff; + colons++; + + if (double_colons != -1) { + for (int i = colons * 2 - double_colons * 2 - 1; i >= 0; i--) + _address.bytes[16 - colons * 2 + double_colons * 2 + i] = _address.bytes[double_colons * 2 + i]; + for (int i = double_colons * 2; i < 16 - colons * 2 + double_colons * 2; i++) + _address.bytes[i] = 0; + } + + _type = IPv6; return true; } diff --git a/cores/esp32/IPAddress.h b/cores/esp32/IPAddress.h index 3bedd4f8749..329ca92afe9 100644 --- a/cores/esp32/IPAddress.h +++ b/cores/esp32/IPAddress.h @@ -26,13 +26,23 @@ // A class to make it easier to handle and pass around IP addresses +#define IPADDRESS_V4_BYTES_INDEX 12 +#define IPADDRESS_V4_DWORD_INDEX 3 + +enum IPType +{ + IPv4, + IPv6 +}; + class IPAddress: public Printable { private: union { - uint8_t bytes[4]; // IPv4 address - uint32_t dword; + uint8_t bytes[16]; + uint32_t dword[4]; } _address; + IPType _type; // Access the raw byte array containing the address. Because this returns a pointer // to the internal structure rather than a copy of the address this function should only @@ -40,57 +50,65 @@ class IPAddress: public Printable // stored. uint8_t* raw_address() { - return _address.bytes; + return _type == IPv4 ? &_address.bytes[IPADDRESS_V4_BYTES_INDEX] : _address.bytes; } public: // Constructors IPAddress(); + IPAddress(IPType ip_type); IPAddress(uint8_t first_octet, uint8_t second_octet, uint8_t third_octet, uint8_t fourth_octet); + IPAddress(uint8_t o1, uint8_t o2, uint8_t o3, uint8_t o4, uint8_t o5, uint8_t o6, uint8_t o7, uint8_t o8, uint8_t o9, uint8_t o10, uint8_t o11, uint8_t o12, uint8_t o13, uint8_t o14, uint8_t o15, uint8_t o16); IPAddress(uint32_t address); IPAddress(const uint8_t *address); + IPAddress(IPType ip_type, const uint8_t *address); + // If IPv4 fails tries IPv6 see fromString function + IPAddress(const char *address); virtual ~IPAddress() {} bool fromString(const char *address); bool fromString(const String &address) { return fromString(address.c_str()); } - // Overloaded cast operator to allow IPAddress objects to be used where a pointer - // to a four-byte uint8_t array is expected + // Overloaded cast operator to allow IPAddress objects to be used where a + // uint32_t is expected operator uint32_t() const { - return _address.dword; - } - bool operator==(const IPAddress& addr) const - { - return _address.dword == addr._address.dword; + return _type == IPv4 ? _address.dword[IPADDRESS_V4_DWORD_INDEX] : 0; } + + bool operator==(const IPAddress& addr) const; bool operator==(const uint8_t* addr) const; // Overloaded index operator to allow getting and setting individual octets of the address - uint8_t operator[](int index) const - { - return _address.bytes[index]; - } - uint8_t& operator[](int index) - { - return _address.bytes[index]; - } + uint8_t operator[](int index) const; + uint8_t& operator[](int index); // Overloaded copy operators to allow initialisation of IPAddress objects from other types IPAddress& operator=(const uint8_t *address); IPAddress& operator=(uint32_t address); + // If IPv4 fails tries IPv6 see fromString function + IPAddress& operator=(const char *address); virtual size_t printTo(Print& p) const; String toString() const; + IPType type() const { return _type; } + friend class EthernetClass; friend class UDP; friend class Client; friend class Server; friend class DhcpClass; friend class DNSClient; + +protected: + bool fromString4(const char *address); + bool fromString6(const char *address); + String toString4() const; + String toString6() const; }; // changed to extern because const declaration creates copies in BSS of INADDR_NONE for each CPP unit that includes it extern IPAddress INADDR_NONE; +extern IPAddress IN6ADDR_ANY; #endif diff --git a/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino b/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino index 9221af1eaa2..7e292e1adfb 100644 --- a/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino +++ b/libraries/DNSServer/examples/CaptivePortal/CaptivePortal.ino @@ -1,52 +1,59 @@ +/* +This example enables catch-all Captive portal for ESP32 Access-Point +It will allow modern devices/OSes to detect that WiFi connection is +limited and offer a user to access a banner web-page. +There is no need to find and open device's IP address/URL, i.e. http://192.168.4.1/ +This works for Android, Ubuntu, FireFox, Windows, maybe others... +*/ + +#include #include #include +#include + -const byte DNS_PORT = 53; -IPAddress apIP(8,8,4,4); // The default android DNS DNSServer dnsServer; -WiFiServer server(80); +WebServer server(80); + +static const char responsePortal[] = R"===( +ESP32 CaptivePortal +

Hello World!

This is a captive portal example page. All unknown http requests will +be redirected here.

+)==="; + +// index page handler +void handleRoot() { + server.send(200, "text/plain", "Hello from esp32!"); +} -String responseHTML = "" - "CaptivePortal" - "

Hello World!

This is a captive portal example. All requests will " - "be redirected here.

"; +// this will redirect unknown http req's to our captive portal page +// based on this redirect various systems could detect that WiFi AP has a captive portal page +void handleNotFound() { + server.sendHeader("Location", "/portal"); + server.send(302, "text/plain", "redirect to captive portal"); +} -void setup() { +void setup() { + Serial.begin(115200); WiFi.mode(WIFI_AP); WiFi.softAP("ESP32-DNSServer"); - WiFi.softAPConfig(apIP, apIP, IPAddress(255, 255, 255, 0)); - // if DNSServer is started with "*" for domain name, it will reply with - // provided IP to all DNS request - dnsServer.start(DNS_PORT, "*", apIP); + // by default DNSServer is started serving any "*" domain name. It will reply + // AccessPoint's IP to all DNS request (this is requred for Captive Portal detection) + dnsServer.start(); + + // serve a simple root page + server.on("/", handleRoot); + + // serve portal page + server.on("/portal",[](){server.send(200, "text/html", responsePortal);}); + // all unknown pages are redirected to captive portal + server.onNotFound(handleNotFound); server.begin(); } void loop() { - dnsServer.processNextRequest(); - WiFiClient client = server.available(); // listen for incoming clients - - if (client) { - String currentLine = ""; - while (client.connected()) { - if (client.available()) { - char c = client.read(); - if (c == '\n') { - if (currentLine.length() == 0) { - client.println("HTTP/1.1 200 OK"); - client.println("Content-type:text/html"); - client.println(); - client.print(responseHTML); - break; - } else { - currentLine = ""; - } - } else if (c != '\r') { - currentLine += c; - } - } - } - client.stop(); - } + server.handleClient(); + delay(5); // give CPU some idle time } diff --git a/libraries/DNSServer/src/DNSServer.cpp b/libraries/DNSServer/src/DNSServer.cpp index aaa3d33344b..a8114733460 100644 --- a/libraries/DNSServer/src/DNSServer.cpp +++ b/libraries/DNSServer/src/DNSServer.cpp @@ -1,6 +1,8 @@ #include "DNSServer.h" #include #include +#include + // #define DEBUG_ESP_DNS #ifdef DEBUG_ESP_PORT @@ -9,45 +11,37 @@ #define DEBUG_OUTPUT Serial #endif -DNSServer::DNSServer() -{ - _ttl = htonl(DNS_DEFAULT_TTL); - _errorReplyCode = DNSReplyCode::NonExistentDomain; - _dnsHeader = (DNSHeader*) malloc( sizeof(DNSHeader) ) ; - _dnsQuestion = (DNSQuestion*) malloc( sizeof(DNSQuestion) ) ; - _buffer = NULL; - _currentPacketSize = 0; - _port = 0; -} +#define DNS_MIN_REQ_LEN 17 // minimal size for DNS request asking ROOT = DNS_HEADER_SIZE + 1 null byte for Name + 4 bytes type/class -DNSServer::~DNSServer() -{ - if (_dnsHeader) { - free(_dnsHeader); - _dnsHeader = NULL; - } - if (_dnsQuestion) { - free(_dnsQuestion); - _dnsQuestion = NULL; - } - if (_buffer) { - free(_buffer); - _buffer = NULL; - } +DNSServer::DNSServer() : _port(DNS_DEFAULT_PORT), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain){} + +DNSServer::DNSServer(const String &domainName) : _port(DNS_DEFAULT_PORT), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain), _domainName(domainName){}; + + +bool DNSServer::start(){ + if (_resolvedIP.operator uint32_t() == 0){ // no address is set, try to obtain AP interface's IP + if (WiFi.getMode() & WIFI_AP){ + _resolvedIP = WiFi.softAPIP(); + } else return false; // won't run if WiFi is not in AP mode + } + + _udp.close(); + _udp.onPacket([this](AsyncUDPPacket& pkt){ this->_handleUDP(pkt); }); + return _udp.listen(_port); } -bool DNSServer::start(const uint16_t &port, const String &domainName, - const IPAddress &resolvedIP) -{ +bool DNSServer::start(uint16_t port, const String &domainName, const IPAddress &resolvedIP){ _port = port; - _buffer = NULL; - _domainName = domainName; - _resolvedIP[0] = resolvedIP[0]; - _resolvedIP[1] = resolvedIP[1]; - _resolvedIP[2] = resolvedIP[2]; - _resolvedIP[3] = resolvedIP[3]; - downcaseAndRemoveWwwPrefix(_domainName); - return _udp.begin(_port) == 1; + if (domainName != "*"){ + _domainName = domainName; + downcaseAndRemoveWwwPrefix(_domainName); + } else + _domainName.clear(); + + _resolvedIP = resolvedIP; + _udp.close(); + _udp.onPacket([this](AsyncUDPPacket& pkt){ this->_handleUDP(pkt); }); + return _udp.listen(_port); } void DNSServer::setErrorReplyCode(const DNSReplyCode &replyCode) @@ -62,9 +56,7 @@ void DNSServer::setTTL(const uint32_t &ttl) void DNSServer::stop() { - _udp.stop(); - free(_buffer); - _buffer = NULL; + _udp.close(); } void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName) @@ -73,151 +65,125 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName) domainName.replace("www.", ""); } -void DNSServer::processNextRequest() +void DNSServer::_handleUDP(AsyncUDPPacket& pkt) { - _currentPacketSize = _udp.parsePacket(); - if (_currentPacketSize) - { - // Allocate buffer for the DNS query - if (_buffer != NULL) - free(_buffer); - _buffer = (unsigned char*)malloc(_currentPacketSize * sizeof(char)); - if (_buffer == NULL) - return; + if (pkt.length() < DNS_MIN_REQ_LEN) return; // truncated packet or not a DNS req + + // get DNS header (beginning of message) + DNSHeader dnsHeader; + DNSQuestion dnsQuestion; + memcpy( &dnsHeader, pkt.data(), DNS_HEADER_SIZE ); + if (dnsHeader.QR != DNS_QR_QUERY) return; // ignore non-query mesages - // Put the packet received in the buffer and get DNS header (beginning of message) - // and the question - _udp.read(_buffer, _currentPacketSize); - memcpy( _dnsHeader, _buffer, DNS_HEADER_SIZE ) ; - if ( requestIncludesOnlyOneQuestion() ) + if ( requestIncludesOnlyOneQuestion(dnsHeader) ) { +/* // The QName has a variable length, maximum 255 bytes and is comprised of multiple labels. // Each label contains a byte to describe its length and the label itself. The list of // labels terminates with a zero-valued byte. In "github.com", we have two labels "github" & "com" - // Iterate through the labels and copy them as they come into a single buffer (for simplicity's sake) - _dnsQuestion->QNameLength = 0 ; - while ( _buffer[ DNS_HEADER_SIZE + _dnsQuestion->QNameLength ] != 0 ) - { - memcpy( (void*) &_dnsQuestion->QName[_dnsQuestion->QNameLength], (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], _buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength] + 1 ) ; - _dnsQuestion->QNameLength += _buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength] + 1 ; - } - _dnsQuestion->QName[_dnsQuestion->QNameLength] = 0 ; - _dnsQuestion->QNameLength++ ; - - // Copy the QType and QClass - memcpy( &_dnsQuestion->QType, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], sizeof(_dnsQuestion->QType) ) ; - memcpy( &_dnsQuestion->QClass, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength + sizeof(_dnsQuestion->QType)], sizeof(_dnsQuestion->QClass) ) ; +*/ + const char * enoflbls = strchr(reinterpret_cast(pkt.data()) + DNS_HEADER_SIZE, 0); // find end_of_label marker + ++enoflbls; // advance after null terminator + dnsQuestion.QName = pkt.data() + DNS_HEADER_SIZE; // we can reference labels from the request + dnsQuestion.QNameLength = enoflbls - (char*)pkt.data() - DNS_HEADER_SIZE; + /* + check if we aint going out of pkt bounds + proper dns req should have label terminator at least 4 bytes before end of packet + */ + if (dnsQuestion.QNameLength > pkt.length() - DNS_HEADER_SIZE - sizeof(dnsQuestion.QType) - sizeof(dnsQuestion.QClass)) return; // malformed packet + + // Copy the QType and QClass + memcpy( &dnsQuestion.QType, enoflbls, sizeof(dnsQuestion.QType) ); + memcpy( &dnsQuestion.QClass, enoflbls + sizeof(dnsQuestion.QType), sizeof(dnsQuestion.QClass) ); } - - if (_dnsHeader->QR == DNS_QR_QUERY && - _dnsHeader->OPCode == DNS_OPCODE_QUERY && - requestIncludesOnlyOneQuestion() && - (_domainName == "*" || getDomainNameWithoutWwwPrefix() == _domainName) + // will reply with IP only to "*" or if doman matches without www. subdomain + if (dnsHeader.OPCode == DNS_OPCODE_QUERY && + requestIncludesOnlyOneQuestion(dnsHeader) && + (_domainName.isEmpty() || + getDomainNameWithoutWwwPrefix(static_cast(dnsQuestion.QName), dnsQuestion.QNameLength) == _domainName) ) { - replyWithIP(); - } - else if (_dnsHeader->QR == DNS_QR_QUERY) - { - replyWithCustomCode(); + replyWithIP(pkt, dnsHeader, dnsQuestion); + return; } - free(_buffer); - _buffer = NULL; - } + // otherwise reply with custom code + replyWithCustomCode(pkt, dnsHeader); } -bool DNSServer::requestIncludesOnlyOneQuestion() +bool DNSServer::requestIncludesOnlyOneQuestion(DNSHeader& dnsHeader) { - return ntohs(_dnsHeader->QDCount) == 1 && - _dnsHeader->ANCount == 0 && - _dnsHeader->NSCount == 0 && - _dnsHeader->ARCount == 0; + return ntohs(dnsHeader.QDCount) == 1 && + dnsHeader.ANCount == 0 && + dnsHeader.NSCount == 0 && + dnsHeader.ARCount == 0; } -String DNSServer::getDomainNameWithoutWwwPrefix() +String DNSServer::getDomainNameWithoutWwwPrefix(const unsigned char* start, size_t len) { - // Error checking : if the buffer containing the DNS request is a null pointer, return an empty domain - String parsedDomainName = ""; - if (_buffer == NULL) - return parsedDomainName; - - // Set the start of the domain just after the header (12 bytes). If equal to null character, return an empty domain - unsigned char *start = _buffer + DNS_OFFSET_DOMAIN_NAME; - if (*start == 0) - { - return parsedDomainName; - } + String parsedDomainName(start, --len); // exclude trailing null byte from labels length, String constructor will add it anyway int pos = 0; - while(true) + while(posQR = DNS_QR_RESPONSE; - _dnsHeader->ANCount = _dnsHeader->QDCount; - _udp.write( (unsigned char*) _dnsHeader, DNS_HEADER_SIZE ) ; + dnsHeader.QR = DNS_QR_RESPONSE; + dnsHeader.ANCount = dnsHeader.QDCount; + rpl.write( (unsigned char*) &dnsHeader, DNS_HEADER_SIZE ) ; // Write the question - _udp.write(_dnsQuestion->QName, _dnsQuestion->QNameLength) ; - _udp.write( (unsigned char*) &_dnsQuestion->QType, 2 ) ; - _udp.write( (unsigned char*) &_dnsQuestion->QClass, 2 ) ; + rpl.write(dnsQuestion.QName, dnsQuestion.QNameLength) ; + rpl.write( (uint8_t*) &dnsQuestion.QType, 2 ) ; + rpl.write( (uint8_t*) &dnsQuestion.QClass, 2 ) ; // Write the answer // Use DNS name compression : instead of repeating the name in this RNAME occurence, // set the two MSB of the byte corresponding normally to the length to 1. The following // 14 bits must be used to specify the offset of the domain name in the message // (<255 here so the first byte has the 6 LSB at 0) - _udp.write((uint8_t) 0xC0); - _udp.write((uint8_t) DNS_OFFSET_DOMAIN_NAME); + rpl.write((uint8_t) 0xC0); + rpl.write((uint8_t) DNS_OFFSET_DOMAIN_NAME); // DNS type A : host address, DNS class IN for INternet, returning an IPv4 address uint16_t answerType = htons(DNS_TYPE_A), answerClass = htons(DNS_CLASS_IN), answerIPv4 = htons(DNS_RDLENGTH_IPV4) ; - _udp.write((unsigned char*) &answerType, 2 ); - _udp.write((unsigned char*) &answerClass, 2 ); - _udp.write((unsigned char*) &_ttl, 4); // DNS Time To Live - _udp.write((unsigned char*) &answerIPv4, 2 ); - _udp.write(_resolvedIP, sizeof(_resolvedIP)); // The IP address to return - _udp.endPacket(); + rpl.write((unsigned char*) &answerType, 2 ); + rpl.write((unsigned char*) &answerClass, 2 ); + rpl.write((unsigned char*) &_ttl, 4); // DNS Time To Live + rpl.write((unsigned char*) &answerIPv4, 2 ); + uint32_t ip = _resolvedIP; + rpl.write(reinterpret_cast(&ip), sizeof(uint32_t)); // The IPv4 address to return + + _udp.sendTo(rpl, req.remoteIP(), req.remotePort()); #ifdef DEBUG_ESP_DNS DEBUG_OUTPUT.printf("DNS responds: %s for %s\n", - IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix().c_str() ); + _resolvedIP.toString().c_str(), getDomainNameWithoutWwwPrefix(static_cast(dnsQuestion.QName), dnsQuestion.QNameLength).c_str() ); #endif } -void DNSServer::replyWithCustomCode() +void DNSServer::replyWithCustomCode(AsyncUDPPacket& req, DNSHeader& dnsHeader) { - _dnsHeader->QR = DNS_QR_RESPONSE; - _dnsHeader->RCode = (unsigned char)_errorReplyCode; - _dnsHeader->QDCount = 0; + dnsHeader.QR = DNS_QR_RESPONSE; + dnsHeader.RCode = static_cast(_errorReplyCode); + dnsHeader.QDCount = 0; - _udp.beginPacket(_udp.remoteIP(), _udp.remotePort()); - _udp.write((unsigned char*)_dnsHeader, sizeof(DNSHeader)); - _udp.endPacket(); + AsyncUDPMessage rpl(sizeof(DNSHeader)); + rpl.write(reinterpret_cast(&dnsHeader), sizeof(DNSHeader)); + _udp.sendTo(rpl, req.remoteIP(), req.remotePort()); } diff --git a/libraries/DNSServer/src/DNSServer.h b/libraries/DNSServer/src/DNSServer.h index 1250f5ce960..860507ac9ec 100644 --- a/libraries/DNSServer/src/DNSServer.h +++ b/libraries/DNSServer/src/DNSServer.h @@ -1,15 +1,15 @@ -#ifndef DNSServer_h -#define DNSServer_h -#include +#pragma once +#include #define DNS_QR_QUERY 0 #define DNS_QR_RESPONSE 1 #define DNS_OPCODE_QUERY 0 #define DNS_DEFAULT_TTL 60 // Default Time To Live : time interval in seconds that the resource record should be cached before being discarded -#define DNS_OFFSET_DOMAIN_NAME 12 // Offset in bytes to reach the domain name in the DNS message #define DNS_HEADER_SIZE 12 +#define DNS_OFFSET_DOMAIN_NAME DNS_HEADER_SIZE // Offset in bytes to reach the domain name labels in the DNS message +#define DNS_DEFAULT_PORT 53 -enum class DNSReplyCode +enum class DNSReplyCode:uint16_t { NoError = 0, FormError = 1, @@ -59,14 +59,14 @@ struct DNSHeader uint16_t Flags; }; uint16_t QDCount; // number of question entries - uint16_t ANCount; // number of answer entries + uint16_t ANCount; // number of ANswer entries uint16_t NSCount; // number of authority entries - uint16_t ARCount; // number of resource entries + uint16_t ARCount; // number of Additional Resource entries }; struct DNSQuestion { - uint8_t QName[256] ; //need 1 Byte for zero termination! + const uint8_t* QName; uint16_t QNameLength ; uint16_t QType ; uint16_t QClass ; @@ -75,36 +75,116 @@ struct DNSQuestion class DNSServer { public: + /** + * @brief Construct a new DNSServer object + * by default server is configured to run in "Captive-portal" mode + * it must be started with start() call to establish a listening socket + * + */ DNSServer(); - ~DNSServer(); - void processNextRequest(); + + /** + * @brief Construct a new DNSServer object + * builds DNS server with default parameters + * @param domainName - domain name to serve + */ + DNSServer(const String &domainName); + ~DNSServer(){}; // default d-tor + + // Copy semantics not implemented (won't run on same UDP port anyway) + DNSServer(const DNSServer&) = delete; + DNSServer& operator=(const DNSServer&) = delete; + + + /** + * @brief stub, left for compatibility with an old version + * does nothing actually + * + */ + void processNextRequest(){}; + + /** + * @brief Set the Error Reply Code for all req's not matching predifined domain + * + * @param replyCode + */ void setErrorReplyCode(const DNSReplyCode &replyCode); + + /** + * @brief set TTL for successfull replies + * + * @param ttl in seconds + */ void setTTL(const uint32_t &ttl); - // Returns true if successful, false if there are no sockets available - bool start(const uint16_t &port, + /** + * @brief (re)Starts a server with current configuration or with default parameters + * if it's the first call. + * Defaults are: + * port: 53 + * domainName: any + * ip: WiFi AP's IP address + * + * @return true on success + * @return false if IP or socket error + */ + bool start(); + + /** + * @brief (re)Starts a server with provided configuration + * + * @return true on success + * @return false if IP or socket error + */ + bool start(uint16_t port, const String &domainName, const IPAddress &resolvedIP); - // stops the DNS server + + /** + * @brief stops the server and close UDP socket + * + */ void stop(); + /** + * @brief returns true if DNS server runs in captive-portal mode + * i.e. all requests are served with AP's ip address + * + * @return true if catch-all mode active + * @return false otherwise + */ + inline bool isCaptive() const { return _domainName.isEmpty(); }; + + /** + * @brief returns 'true' if server is up and UDP socket is listening for UDP req's + * + * @return true if server is up + * @return false otherwise + */ + inline bool isUp() { return _udp.connected(); }; + private: - WiFiUDP _udp; + AsyncUDP _udp; uint16_t _port; - String _domainName; - unsigned char _resolvedIP[4]; - int _currentPacketSize; - unsigned char* _buffer; - DNSHeader* _dnsHeader; uint32_t _ttl; DNSReplyCode _errorReplyCode; - DNSQuestion* _dnsQuestion ; + String _domainName; + IPAddress _resolvedIP; void downcaseAndRemoveWwwPrefix(String &domainName); - String getDomainNameWithoutWwwPrefix(); - bool requestIncludesOnlyOneQuestion(); - void replyWithIP(); - void replyWithCustomCode(); + + /** + * @brief Get the Domain Name Without Www Prefix object + * scan labels in DNS packet and build a string of a domain name + * truncate any www. label if found + * @param start a pointer to the start of labels records in DNS packet + * @param len labels length + * @return String + */ + String getDomainNameWithoutWwwPrefix(const unsigned char* start, size_t len); + inline bool requestIncludesOnlyOneQuestion(DNSHeader& dnsHeader); + void replyWithIP(AsyncUDPPacket& req, DNSHeader& dnsHeader, DNSQuestion& dnsQuestion); + inline void replyWithCustomCode(AsyncUDPPacket& req, DNSHeader& dnsHeader); + void _handleUDP(AsyncUDPPacket& pkt); }; -#endif diff --git a/libraries/WebServer/src/WebServer.cpp b/libraries/WebServer/src/WebServer.cpp index 0d99f3680a1..ea9b4d5692e 100644 --- a/libraries/WebServer/src/WebServer.cpp +++ b/libraries/WebServer/src/WebServer.cpp @@ -135,7 +135,7 @@ bool WebServer::authenticate(const char * username, const char * password){ authReq = authReq.substring(6); authReq.trim(); char toencodeLen = strlen(username)+strlen(password)+1; - char *toencode = (char *)malloc[toencodeLen + 1]; + char *toencode = (char *)malloc(toencodeLen + 1); if(toencode == NULL){ authReq = ""; return false;