diff --git a/README.markdown b/README.markdown index 28b062a..ad809cc 100644 --- a/README.markdown +++ b/README.markdown @@ -125,7 +125,7 @@ It accepts a `opts` table argument. The following options are supported: * `nameservers` - a list of nameservers to be used. Each nameserver entry can be either a single hostname string or a table holding both the hostname string and the port number. The nameserver is picked up by a simple round-robin algorithm for each `query` method call. This option is required. + a list of nameservers to be used. Each nameserver entry can be either a single hostname string, DoH url with optional port or a table holding both the hostname string and the port number. The nameserver is picked up by a simple round-robin algorithm for each `query` method call. This option is required. * `retrans` the total number of times of retransmitting the DNS request when receiving a DNS response times out according to the `timeout` setting. Defaults to `5` times. When trying to retransmit the query, the next nameserver according to the round-robin algorithm will be picked up. @@ -138,7 +138,17 @@ It accepts a `opts` table argument. The following options are supported: * `no_random` a boolean flag controls whether to randomly pick the nameserver to query first, if `true` will always start with the first nameserver listed. Defaults to `false`. +* `doh` + a boolean flag controls whether to use DNS over Https (DoH) + +* `doh_method` + + type of DoH query possible values are `POST` or `GET` or boolean false, Defaults to nil. +* `doh_bootstrap` + + list of nameservers used to perform initial query for IP address of DoH servers. + [Back to TOC](#table-of-contents) query diff --git a/lib/resty/dns/resolver.lua b/lib/resty/dns/resolver.lua index a67b3c1..351ff5f 100644 --- a/lib/resty/dns/resolver.lua +++ b/lib/resty/dns/resolver.lua @@ -2,8 +2,29 @@ -- local socket = require "socket" -local bit = require "bit" + +local ok, b64 = pcall(require,"ngx.base64") +if not ok then + return false +end + +local ok, bit = pcall(require, "bit") +if not ok then + return false +end + +local ok, wire = pcall(require, "resty.dns.wireformat") +if not ok then + return false +end + +local ok, new_tab = pcall(require, "table.new") +if not ok then + new_tab = function (narr, nrec) return {} end +end + local udp = ngx.socket.udp +local tcp = ngx.socket.tcp local rand = math.random local char = string.char local byte = string.byte @@ -12,832 +33,933 @@ local gsub = string.gsub local sub = string.sub local rep = string.rep local format = string.format -local band = bit.band -local rshift = bit.rshift -local lshift = bit.lshift local insert = table.insert local concat = table.concat local re_sub = ngx.re.sub -local tcp = ngx.socket.tcp +local re_match = ngx.re.match +local re_find = ngx.re.find local log = ngx.log local DEBUG = ngx.DEBUG local unpack = unpack local setmetatable = setmetatable local type = type local ipairs = ipairs +local agent = "ngx_lua/" .. ngx.config.ngx_lua_version +local str_lower = string.lower +local tolower = string.lower +local ngx_get = ngx.HTTP_GET +local ngx_post = ngx.HTT_POST +local band = bit.band +local wire_build = wire.build_request +local wire_parse = wire.parse_response +local bit = require "bit" +local band = bit.band +local rshift = bit.rshift +local lshift = bit.lshift - -local ok, new_tab = pcall(require, "table.new") -if not ok then - new_tab = function (narr, nrec) return {} end -end - - -local DOT_CHAR = byte(".") -local ZERO_CHAR = byte("0") -local COLON_CHAR = byte(":") +local arpa_tmpl = new_tab(72, 0) local IP6_ARPA = "ip6.arpa" -local TYPE_A = 1 -local TYPE_NS = 2 -local TYPE_CNAME = 5 -local TYPE_SOA = 6 -local TYPE_PTR = 12 -local TYPE_MX = 15 -local TYPE_TXT = 16 -local TYPE_AAAA = 28 -local TYPE_SRV = 33 -local TYPE_SPF = 99 - -local CLASS_IN = 1 +for i = 1, #IP6_ARPA do + arpa_tmpl[64 + i] = byte(IP6_ARPA, i) +end -local SECTION_AN = 1 -local SECTION_NS = 2 -local SECTION_AR = 3 +for i = 2, 64, 2 do + arpa_tmpl[i] = DOT_CHAR +end +local COLON_CHAR = byte(":") local _M = { _VERSION = '0.22', - TYPE_A = TYPE_A, - TYPE_NS = TYPE_NS, - TYPE_CNAME = TYPE_CNAME, - TYPE_SOA = TYPE_SOA, - TYPE_PTR = TYPE_PTR, - TYPE_MX = TYPE_MX, - TYPE_TXT = TYPE_TXT, - TYPE_AAAA = TYPE_AAAA, - TYPE_SRV = TYPE_SRV, - TYPE_SPF = TYPE_SPF, - CLASS_IN = CLASS_IN, - SECTION_AN = SECTION_AN, - SECTION_NS = SECTION_NS, - SECTION_AR = SECTION_AR + TYPE_A = wire.TYPE.A, + TYPE_NS = wire.TYPE.NS, + TYPE_CNAME = wire.TYPE.CNAME, + TYPE_SOA = wire.TYPE.SOA, + TYPE_PTR = wire.TYPE.PTR, + TYPE_MX = wire.TYPE.MX, + TYPE_TXT = wire.TYPE.TXT, + TYPE_AAAA = wire.TYPE.AAAA, + TYPE_SRV = wire.TYPE.SRV, + TYPE_SPF = wire.TYPE.SPF, + CLASS_IN = wire.CLASS.IN, + SECTION_AN = wire.SECTION.AN, + SECTION_NS = wire.SECTION.NS, + SECTION_AR = wire.SECTION.AR, + MODE = { + UDP = 1, + TCP = 2, + UDP_TCP = 3, + DOT = 4, + DOH = 8 + } } +local MODE_UDP = _M.MODE.UDP +local MODE_TCP = _M.MODE.TCP +local MODE_DOT = _M.MODE.DOT +local MODE_DOH = _M.MODE.DOH +local MODE_UDP_TCP = _M.MODE.UDP_TCP -local resolver_errstrs = { - "format error", -- 1 - "server failure", -- 2 - "name error", -- 3 - "not implemented", -- 4 - "refused", -- 5 +local DOH_METHOD = { + GET = ngx_get, + POST = ngx_post } -local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" } - -local mt = { __index = _M } +local function _is_ip(str) + if type(str) ~= "string" then + return false + end + + local ret, err = re_match(str,"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)") + + if ret then + return true + end + + return false, err +end +local function _gen_id() -- (self) + --local id = self._id -- for regression testing + --if id then + -- return id + --end + return rand(0, 65535) -- two bytes +end -local arpa_tmpl = new_tab(72, 0) +----------------------[[ private implementation ]]---------------------------------------- -for i = 1, #IP6_ARPA do - arpa_tmpl[64 + i] = byte(IP6_ARPA, i) +local function _build_wire_request(self, qname, id, no_recurse, opts) + return wire_build(qname,id,no_recurse,opts) end -for i = 2, 64, 2 do - arpa_tmpl[i] = DOT_CHAR +local function _parse_wire_response(self, data, id, opts) + return wire_parse(data, id, opts) end +local function _build_post_wire_request(self, qname, id, no_recurse, opts) + local opts = { + method = ngx_post, + body = wire_build(qname, id, self.no_recurse, opts) + } + + return id, opts +end -function _M.new(class, opts) - if not opts then - return nil, "no options table specified" - end +local function _build_post_json_request(self, qname, id, no_recurse, opts) + local opts = {} + + return opts +end - local servers = opts.nameservers - if not servers or #servers == 0 then - return nil, "no nameservers specified" - end +local function _build_get_json_request(self, qname, id, no_recurse, opts) + local opts = { + method = ngx_get, + param = self.doh_encode and b64.encode_base64url(qname) or qname + } + + return opts +end - local timeout = opts.timeout or 2000 -- default 2 sec +local function _build_get_wire_request(self, qname, id, no_recurse, opts) + local opts = { + method = ngx_get, + param = self.doh_encode and b64.encode_base64url(qname) or qname + } + + return opts +end - local n = #servers +local function _parse_doh_wire_response(self, qname, id, no_recurse, opts) + + +end - local socks = {} +local function _parse_doh_json_response(self, qname, id, no_recurse, opts) + +end - for i = 1, n do - local server = servers[i] - local sock, err = udp() - if not sock then - return nil, "failed to create udp socket: " .. err - end - local host, port - if type(server) == 'table' then - host = server[1] - port = server[2] or 53 +--[[ sockets implementation ]]-- - else - host = server - port = 53 - servers[i] = {host, port} - end +local function _sock_write(self, data) + return self.fd:send(data) +end - local ok, err = sock:setpeername(host, port) - if not ok then - return nil, "failed to set peer name: " .. err - end +local function _sock_read(self) + return self.fd:receive() +end - sock:settimeout(timeout) +local function _sock_close(self) + return self.fd:close() +end - insert(socks, sock) - end +local function _sock_settimeout(self,timeout) + return self:settimeout(timeout) +end - local tcp_sock, err = tcp() - if not tcp_sock then - return nil, "failed to create tcp socket: " .. err +--[[ udp stream implementation ]]-- +local function _udp_open(self, host, port, opts, ip) + local fd = udp() + local addr = ip or self.ip or host or self.host + local pnum = port or self.port + local setts = opts or self.opts + local timeout = setts and setts.timeout or 2000 + + fd:settimeout(timeout) + + local ok, err = fd:setpeername(addr,pnum) + if not ok then + return false, err end - - tcp_sock:settimeout(timeout) - - return setmetatable( - { cur = opts.no_random and 1 or rand(1, n), - socks = socks, - tcp_sock = tcp_sock, - servers = servers, - retrans = opts.retrans or 5, - no_recurse = opts.no_recurse, - }, mt) + + self.fd = fd + --self.id = _gen_id() + + return true end -local function pick_sock(self, socks) - local cur = self.cur +local _udp_stream_mt = { + open = _udp_open, --_udp_cached_open(), + read = _sock_read, + write = _sock_write, + close = _sock_close, + settimeout = _sock_settimeout +} - if cur == #socks then - self.cur = 1 - else - self.cur = cur + 1 +--[[ tcp stream implementation ]]-- +local function _tcp_open(self, host, port, opts, ip) + local fd = tcp() + local addr = ip or self.ip or host or self.host + local pnum = port or self.port + local setts = opts or self.opts + local timeout = setts and setts.timeout or 2000 + + fd:settimeout(timeout) + + local ok, err = fd:connect(addr,pnum) + if not ok then + return false, err end + + self.fd = fd + --self.id = _gen_id() + + return true +end - return socks[cur] +local function _tcp_sock_write(self, data) + local query = concat(data,'') + local len = #query + local len_hi = char(rshift(len, 8)) + local len_lo = char(band(len, 0xff)) + + return self.fd:send({len_hi, len_lo, query}) end +local function _tcp_sock_read(self) + local buf, err = self.fd:receive(2) + local len_hi = byte(buf, 1) + local len_lo = byte(buf, 2) + local len = lshift(len_hi, 8) + len_lo + + return self.fd:receive(len) +end -local function _get_cur_server(self) - local cur = self.cur +local _tcp_stream_mt = { + open = _tcp_open, + read = _tcp_sock_read, + write = _tcp_sock_write, + close = _sock_close, + settimeout = _sock_settimeout +} - local servers = self.servers +-------------[[ encrypted ssl/tls tcp stream ]]------------ - if cur == 1 then - return servers[#servers] +local function _enc_tcp_open(self, host, port, opts, ip) + local addr = ip or host + local ok, err = _tcp_open(self,addr,port,opts,ip) + if not ok then + return false, err end - - return servers[cur - 1] + + if self.fd:getreusedtimes() == 0 then + local session, err = self.fd:sslhandshake(nil,host) + if not session then + return false, err + end + end + + return true end +local _enc_tcp_stream_mt = { + open = _enc_tcp_open, + read = _sock_read, + write = _sock_write, + close = _sock_close, + settimeout = _sock_settimeout +} -function _M.set_timeout(self, timeout) - local socks = self.socks - if not socks then - return nil, "not initialized" - end +-----------------[[ streams ]]--------------------------- - for i = 1, #socks do - local sock = socks[i] - sock:settimeout(timeout) - end +local function _new_stream_int(class,mt) + return setmetatable({ + host = class.host, + port = class.port, + ip = class.ip + },{ __index = mt}) +end - local tcp_sock = self.tcp_sock - if not tcp_sock then - return nil, "not initialized" - end - tcp_sock:settimeout(timeout) +local function _new_udp_stream(class) + return _new_stream_int(class,_udp_stream_mt) end -local function _encode_name(s) - return char(#s) .. s +local function _new_tcp_stream(class) + return _new_stream_int(class,_tcp_stream_mt) end -local function _decode_name(buf, pos) - local labels = {} - local nptrs = 0 - local p = pos - while nptrs < 128 do - local fst = byte(buf, p) - - if not fst then - return nil, 'truncated'; - end +local function _new_enc_stream(class) + return _new_stream_int(class,_enc_tcp_stream_mt) +end - -- print("fst at ", p, ": ", fst) - if fst == 0 then - if nptrs == 0 then - pos = pos + 1 - end - break - end +local _udp_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, --(buf, id, opts), + stream = _new_udp_stream +} - if band(fst, 0xc0) ~= 0 then - -- being a pointer - if nptrs == 0 then - pos = pos + 2 - end +local _tcp_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, --(buf, id, opts), + stream = _new_tcp_stream +} - nptrs = nptrs + 1 +local _udp_tcp_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, --(buf, id, opts), + stream = _new_udp_stream +} - local snd = byte(buf, p + 1) - if not snd then - return nil, 'truncated' - end +local _dot_pimpl = { + build = _build_wire_request, -- (qname, id, no_recurse, opts) + parse = _parse_wire_response, -- (buf, id, opts), + stream = _new_enc_stream +} - p = lshift(band(fst, 0x3f), 8) + snd + 1 +local _doh_wire_get_pimpl = { + build = _build_get_wire_request, + parse = _parse_wire_response, + stream = _new_enc_stream +} - -- print("resolving ptr ", p, ": ", byte(buf, p)) +local _doh_json_get_pimpl = { + build = _build_get_json_request, + parse = _parse_json_response, + stream = _new_enc_stream +} - else - -- being a label - local label = sub(buf, p + 1, p + fst) - insert(labels, label) +local _doh_wire_post_pimpl = { + build = _build_post_wire_request, + parse = _parse_doh_wire_response, + stream = _new_enc_stream +} - -- print("resolved label ", label) +local doh_json_post_pimpl = { + build = _build_post_json_request, + parse = _parse_doh_json_response, + stream = _new_enc_stream +} - p = p + fst + 1 +---------------------------[[ server parsers ]]------------------------------- - if nptrs == 0 then - pos = p - end - end +local function _udp_tcp_server_parser_int(server, opts, pimpl, mode) + local host, port + + if type(server) == 'table' then + host = server[1] + port = server[2] or 53 + else + host = server + port = 53 end - return concat(labels, "."), pos + return setmetatable({ + host = host, + port = port, + mode = mode + }, { __index = pimpl }) end -local function _build_request(qname, id, no_recurse, opts) - local qtype +local function _udp_server_parser(server, opts) + return _udp_tcp_server_parser_int(server,opts,_udp_pimpl, MODE_UDP) +end - if opts then - qtype = opts.qtype - end - if not qtype then - qtype = 1 -- A record - end +local function _tcp_server_parser(server, opts) + return _udp_tcp_server_parser_int(server,opts,_tcp_pimpl, MODE_TCP) +end - local ident_hi = char(rshift(id, 8)) - local ident_lo = char(band(id, 0xff)) - local flags - if no_recurse then - -- print("found no recurse") - flags = "\0\0" - else - flags = "\1\0" - end +local function _udp_tcp_server_parser(server, opts) + return _udp_tcp_server_parser_int(server,opts,_udp_pimpl, MODE_UDP_TCP) +end - local nqs = "\0\1" - local nan = "\0\0" - local nns = "\0\0" - local nar = "\0\0" - local typ = char(rshift(qtype, 8), band(qtype, 0xff)) - local class = "\0\1" -- the Internet class - if byte(qname, 1) == DOT_CHAR then - return nil, "bad name" +local function _dot_server_parser(server, opts) + local res, err = _tcp_servers_parser(server,opts) + if not res then + return nil, err end - - local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0' - - return { - ident_hi, ident_lo, flags, nqs, nan, nns, nar, - name, typ, class - } + + return setmetatable({ + host = host, + port = port, + mode = MODE_DOT + }, { __index = _dot_pimpl }) end -local function parse_section(answers, section, buf, start_pos, size, - should_skip) - local pos = start_pos - - for _ = 1, size do - -- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass)) - local ans = {} - - if not should_skip then - insert(answers, ans) - end - - ans.section = section - - local name - name, pos = _decode_name(buf, pos) - if not name then - return nil, pos - end - - ans.name = name - - -- print("name: ", name) - - local type_hi = byte(buf, pos) - local type_lo = byte(buf, pos + 1) - local typ = lshift(type_hi, 8) + type_lo - - ans.type = typ - - -- print("type: ", typ) - - local class_hi = byte(buf, pos + 2) - local class_lo = byte(buf, pos + 3) - local class = lshift(class_hi, 8) + class_lo - - ans.class = class - - -- print("class: ", class) - - local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7) - - local ttl = lshift(byte_1, 24) + lshift(byte_2, 16) - + lshift(byte_3, 8) + byte_4 - - -- print("ttl: ", ttl) - - ans.ttl = ttl - - local len_hi = byte(buf, pos + 8) - local len_lo = byte(buf, pos + 9) - local len = lshift(len_hi, 8) + len_lo - - -- print("record len: ", len) +local function _doh_server_parser(server, opts) + local method = (type(server) == 'table') and server.method or 'GET' + method = DOH_METHOD[method] + if not method then + return false, "invalid DoH mode specified" + end + + local res, err = _tcp_server_parser(server, opts) + if not res then + return false, err + end + + local url + local method + local ct + local ac + + if type(server) == 'table' then + url = server[1] or server.url + method = server[2] or server.method or ngx_get + ct = server[3] or server.ct or 'application/dns-message' + ac = server[4] or server.ac or 'application/dns-message' + else + url = server + method = ngx_get + ct = 'application/dns-message' + ac = 'application/dns-message' + end + + local captures, err = re_match(url,"^((https?)(://))?([A-Za-z0-9\\.-]+)(:[1-9][0-9]*)?(/.+)$") + if not captures then + return false, err + end + + local host = captures[4] + local ssl = (captures[1] == 'https://') and true or false + local port + + if captures[5] then + port = tonumber(sub(captures[5],2)) + elseif not ssl then + port = 80 + else + port = 443 + end + + if not port then + return false, "invalid port specified" + end + + local hoststr + if (ssl and port ~= 443) or (not ssl and port ~= 80) then + hoststr = host..":"..port + else + hoststr = host + end + + local query = { + 'Host: '..hoststr..'\r\n', + 'User-Agent: '..agent..'\r\n', + 'Connection: keep-alive'..'\r\n', + 'Accept: '..ac..'\r\n' + } - pos = pos + 10 + insert(query,(method == ngx_post) and 'Content-Type: '..ct..'\r\n' or "\r\n") - if typ == TYPE_A then + return setmetatable({ + host = host, + port = port, + url = captures[6], + ssl = ssl, + query = query, + mode = MODE_DOH + }, { __index = _doh_pimpl }) +end - if len ~= 4 then - return nil, "bad A record value length: " .. len - end +local _server_parser_tbl = { + { MODE_UDP, _udp_server_parser }, + { MODE_TCP, _tcp_server_parser }, + { MODE_UDP_TCP, _udp_tcp_server_parser }, + { MODE_DOT, _dot_server_parser }, + { MODE_DOH, _doh_server_parser } +} - local addr_bytes = { byte(buf, pos, pos + 3) } - local addr = concat(addr_bytes, ".") - -- print("ipv4 address: ", addr) +----------------------------[[ servers array ]]------------------------------------ - ans.address = addr +local function _servers_at(self, at) + return self.servers[at] +end - pos = pos + 4 - elseif typ == TYPE_CNAME then +local function _servers_size(self) + return #self.servers +end - local cname, p = _decode_name(buf, pos) - if not cname then - return nil, pos - end - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) +local function _servers_array_new(opts) + log(ngx.ERR,"NEW SERVER ARRAY") + + local servers = opts.nameservers + local pservers = {} + local n = #servers + local pn = #_server_parser_tbl + + for i = 1, n do + local server = servers[i] + local mode = (type(server) == 'table') and server.mode or MODE_UDP_TCP + local pserver_tbl, err + + for k = 1, pn do + local f_tbl = _server_parser_tbl[k] + + if f_tbl[1] == mode then + local pserver_tbl, err = f_tbl[2](server, opts) + + if not pserver_tbl then + return nil, "failed to create server at: "..k.."with error: "..err + end + + insert(pservers,pserver_tbl) + break end + end + + end + + local servers_mt = { + size = _servers_size, + at = _servers_at + } + + local servers_tbl = { + current = no_random and 1 or rand(1, #servers), + servers = pservers + } + + return setmetatable(servers_tbl, { __index = servers_mt }) +end - pos = p - - -- print("cname: ", cname) - - ans.cname = cname - - elseif typ == TYPE_AAAA then +-------------------------------------------------------------------------------- - if len ~= 16 then - return nil, "bad AAAA record value length: " .. len - end +local function _generic_query() - local addr_bytes = { byte(buf, pos, pos + 15) } - local flds = {} - for i = 1, 16, 2 do - local a = addr_bytes[i] - local b = addr_bytes[i + 1] - if a == 0 then - insert(flds, format("%x", b)) +end +local function answers__to_string(self) + local ret ='' + for k,v in pairs(self) do + local typ = type(v) + if typ ~= 'function' then + if typ == 'table' then + ret = ret..'[\r\n' + ret = ret..answers__to_string(v) + ret = ret..']\r\n' else - insert(flds, format("%x%02x", a, b)) + ret = ret..k..': '..v..'\r\n' end end + end + return ret + end - -- we do not compress the IPv6 addresses by default - -- due to performance considerations - - ans.address = concat(flds, ":") - - pos = pos + 16 - - elseif typ == TYPE_MX then - - -- print("len = ", len) - - if len < 3 then - return nil, "bad MX record value length: " .. len - end - - local pref_hi = byte(buf, pos) - local pref_lo = byte(buf, pos + 1) - - ans.preference = lshift(pref_hi, 8) + pref_lo - - local host, p = _decode_name(buf, pos + 2) - if not host then - return nil, pos - end - - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end - - ans.exchange = host - - pos = p +local answers_mt = { + __tostring = answers__to_string +} - elseif typ == TYPE_SRV then - if len < 7 then - return nil, "bad SRV record value length: " .. len - end +--[[ +Perform DNS TCP query over connected socket +]] +local function _tcp_query(self, server, qname, no_recurse, opts) + if server == nil or qname == nil then + return nil, 'invalid arguments', nil + end + + local srv, err = _tcp_server_parser(server, opts) + if srv == nil then + return nil, err, nil + end + + local id = _gen_id() + local query, err = srv:build(qname,id,no_recurse, opts) + if query == nil then + return nil, err, nil + end + + local stream, err = srv:stream() + if stream == nil then + return nil, err, nil + end - local prio_hi = byte(buf, pos) - local prio_lo = byte(buf, pos + 1) - ans.priority = lshift(prio_hi, 8) + prio_lo + local ok, err = stream:open() + if not ok then + return nil, err, nil + end + + local bytes, err = stream:write(query) + if not bytes then + return nil, "failed to send query to TCP server " + .. stream.host .. ":" .. stream.port .. ": " .. err, nil + end - local weight_hi = byte(buf, pos + 2) - local weight_lo = byte(buf, pos + 3) - ans.weight = lshift(weight_hi, 8) + weight_lo + local buf, err = stream:read() + if not buf then + return nil, "failed to receive the reply length field from TCP server " + .. stream.host, ":" .. stream.port.. ": " .. err, {} + end - local port_hi = byte(buf, pos + 4) - local port_lo = byte(buf, pos + 5) - ans.port = lshift(port_hi, 8) + port_lo + local answers, err = srv:parse(buf,id) + if not answers then + return nil, err + end + + return setmetatable(answers,answers_mt), nil, {} +end - local name, p = _decode_name(buf, pos + 6) - if not name then - return nil, pos - end +local function _udp_query(self, server, qname, no_recurse, opts) + if server == nil or qname == nil then + return nil, 'invalid arguments', nil + end + + local srv, err = _udp_server_parser(server, opts) + if srv == nil then + return nil, err, nil + end + + local id = _gen_id() + local query, err = srv:build(qname,id,no_recurse, opts) + if query == nil then + return nil, err, nil + end + + local stream, err = srv:stream() + if stream == nil then + return nil, err, nil + end - if p - pos ~= len then - return nil, format("bad srv record length: %d ~= %d", - p - pos, len) - end + local ok, err = stream:open() + if not ok then + return nil, err, nil + end + + local bytes, err = stream:write(query) + if not bytes then + return nil, "failed to send query to UDP server " + .. stream.host .. ":" .. stream.port .. ": " .. err, nil + end - ans.target = name + local buf, err = stream:read() + if not buf then + return nil, "failed to receive the reply UDP server " + .. stream.host, ":" .. stream.port.. ": " .. err, {} + end - pos = p + local answers, err = srv:parse(buf,id) + if not answers then + return nil, err + end - elseif typ == TYPE_NS then + return setmetatable(answers,answers_mt) +end - local name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end +local function _dot_query(self, server, qname, no_recurse, opts) + if server == nil or qname == nil then + return nil, 'invalid arguments', nil + end + + local srv, err = _dot_server_parser(server, opts) + if srv == nil then + return nil, err, nil + end + + local id = _gen_id() + local query, err = srv:build(qname,id,no_recurse, opts) + if query == nil then + return nil, err, nil + end + + local stream, err = srv:stream() + if stream == nil then + return nil, err, nil + end - pos = p + local ok, err = stream:open() + if not ok then + return nil, err, nil + end + + local bytes, err = stream:write(query) + if not bytes then + return nil, "failed to send query to DoT server " + .. stream.host .. ":" .. stream.port .. ": " .. err, nil + end - -- print("name: ", name) + local buf, err = stream:read() + if not buf then + return nil, "failed to receive the reply DoT server " + .. stream.host, ":" .. stream.port.. ": " .. err, {} + end - ans.nsdname = name + local answers, err = srv:parse(buf,id) + if not answers then + return nil, err + end - elseif typ == TYPE_TXT or typ == TYPE_SPF then + return setmetatable(answers,answers_mt) +end - local key = (typ == TYPE_TXT) and "txt" or "spf" - local slen = byte(buf, pos) - if slen + 1 > len then - -- truncate the over-run TXT record data - slen = len - end +--[[ local function _http_connect(sock,host) + local ok, err = sock:connect(host[1], host[2]) + if not ok then + return nil, "failed to connect to HTTP server " + .. host[1] .. ":" .. host[2] .. ": " .. err + end - -- print("slen: ", len) - - local val = sub(buf, pos + 1, pos + slen) - local last = pos + len - pos = pos + slen + 1 - - if pos < last then - -- more strings to be processed - -- this code path is usually cold, so we do not - -- merge the following loop on this code path - -- with the processing logic above. - - val = {val} - local idx = 2 - repeat - local slen = byte(buf, pos) - if pos + slen + 1 > last then - -- truncate the over-run TXT record data - slen = last - pos - 1 - end + if host[4] and sock:getreusedtimes() == 0 then + local session, err = sock:sslhandshake(nil,host[1]) + if not session then + return nil, err + end + end - val[idx] = sub(buf, pos + 1, pos + slen) - idx = idx + 1 - pos = pos + slen + 1 + return sock +end +]]-- - until pos >= last - end +local function _http_status_receive(sock) + local line, err, partial = sock:receive("*l") + if not line then + return nil, nil, nil, "failed to read http header status line: "..err + end - ans[key] = val + local ret, err = re_match(line,"(HTTP/[0-3](\\.[0-1])?) ([1-5][0-9]{2}) ([A-Za-z ]+)") - elseif typ == TYPE_PTR then + if not ret then + return nil, nil, nil, "failed to parse http status with error: "..err + end - local name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end + return ret[1], tonumber(ret[3]), ret[4] +end - if p - pos ~= len then - return nil, format("bad cname record length: %d ~= %d", - p - pos, len) - end - pos = p +local function _http_header_receive(sock) + local ret = {} - -- print("name: ", name) + repeat + local line, err = sock:receive("*l") + if not line then + return nil, err + end - ans.ptrdname = name + local m, err = re_match(line, "([^:\\s]+):\\s*(.*)", "jo") + if err then log(DEBUG, err) end - elseif typ == TYPE_SOA then - local name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end - ans.mname = name + if not m then + break + end - pos = p - name, p = _decode_name(buf, pos) - if not name then - return nil, pos - end - ans.rname = name + local key = string.lower(m[1]) + local val = m[2] - for _, field in ipairs(soa_int32_fields) do - local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3) - ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16) - + lshift(byte_3, 8) + byte_4 - p = p + 4 + if ret[key] then + if type(ret[key]) ~= "table" then + ret[key] = { ret[key] } end - - pos = p - + insert(ret[key], tostring(val)) else - -- for unknown types, just forward the raw value - - ans.rdata = sub(buf, pos, pos + len - 1) - pos = pos + len + ret[key] = tostring(val) end - end + until re_find(line, "^\\s*$", "jo") - return pos + return ret end -local function parse_response(buf, id, opts) - local n = #buf - if n < 12 then - return nil, 'truncated'; - end - - -- header layout: ident flags nqs nan nns nar - - local ident_hi = byte(buf, 1) - local ident_lo = byte(buf, 2) - local ans_id = lshift(ident_hi, 8) + ident_lo - - -- print("id: ", id, ", ans id: ", ans_id) - - if ans_id ~= id then - -- identifier mismatch and throw it away - log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id) - return nil, "id mismatch" - end - - local flags_hi = byte(buf, 3) - local flags_lo = byte(buf, 4) - local flags = lshift(flags_hi, 8) + flags_lo - - -- print(format("flags: 0x%x", flags)) - - if band(flags, 0x8000) == 0 then - return nil, format("bad QR flag in the DNS response") - end - - if band(flags, 0x200) ~= 0 then - return nil, "truncated" - end - - local code = band(flags, 0xf) - - -- print(format("code: %d", code)) +local function _http_header_send(sock, host, method, length, param) + local hoststr - local nqs_hi = byte(buf, 5) - local nqs_lo = byte(buf, 6) - local nqs = lshift(nqs_hi, 8) + nqs_lo + -- HEADER - -- print("nqs: ", nqs) + local bytes, err = sock:send(query) - if nqs ~= 1 then - return nil, format("bad number of questions in DNS response: %d", nqs) + if not bytes then + return 0, err end - local nan_hi = byte(buf, 7) - local nan_lo = byte(buf, 8) - local nan = lshift(nan_hi, 8) + nan_lo - - -- print("nan: ", nan) - - local nns_hi = byte(buf, 9) - local nns_lo = byte(buf, 10) - local nns = lshift(nns_hi, 8) + nns_lo + return bytes +end - local nar_hi = byte(buf, 11) - local nar_lo = byte(buf, 12) - local nar = lshift(nar_hi, 8) + nar_lo - -- skip the question part +local function _http_body_receive(sock, header) + local len = header["content-length"] - local ans_qname, pos = _decode_name(buf, 13) - if not ans_qname then - return nil, pos + if header["content-type"] ~= "application/dns-message" then + return nil, "http query failed invalid Content-Type: "..header["content-type"] end - -- print("qname in reply: ", ans_qname) - - -- print("question: ", sub(buf, 13, pos)) + local data, err = sock:receiveany(tonumber(len)) - if pos + 3 + nan * 12 > n then - -- print(format("%d > %d", pos + 3 + nan * 12, n)) - return nil, 'truncated'; + if not data then + return nil, "http query failed to receive body "..err end - -- question section layout: qname qtype(2) qclass(2) - - --[[ - local type_hi = byte(buf, pos) - local type_lo = byte(buf, pos + 1) - local ans_type = lshift(type_hi, 8) + type_lo - ]] - - -- print("ans qtype: ", ans_type) - - local class_hi = byte(buf, pos + 2) - local class_lo = byte(buf, pos + 3) - local qclass = lshift(class_hi, 8) + class_lo + return data +end - -- print("ans qclass: ", qclass) - if qclass ~= 1 then - return nil, format("unknown query class %d in DNS response", qclass) +local function _http_query(sock,host,opts) + + local sock, err = _http_connect(sock, host) + if not sock then + return nil, err end - pos = pos + 4 - - local answers = {} - - if code ~= 0 then - answers.errcode = code - answers.errstr = resolver_errstrs[code] or "unknown" + local bytes, err = _http_header_send(sock, host, opts.method, opts.body and #opts.body or 0, opts.param) + if not bytes then + return nil, err end - local authority_section, additional_section - - if opts then - authority_section = opts.authority_section - additional_section = opts.additional_section - if opts.qtype == TYPE_SOA then - authority_section = true + if opts.body then + local bytes, err = sock:send(opts.body) + if not bytes or bytes < #opts.body then + return nil, "http POST query failed body not sent" end end - local err - - pos, err = parse_section(answers, SECTION_AN, buf, pos, nan) + local version, status, reason, err = _http_status_receive(sock) - if not pos then + if err then return nil, err end - if not authority_section and not additional_section then - return answers + if status ~= 200 then + return nil, "http query failed status code is: "..status.." reason: "..reason end - pos, err = parse_section(answers, SECTION_NS, buf, pos, nns, - not authority_section) - - if not pos then + local header, err = _http_header_receive(sock) + if not header then return nil, err end - if not additional_section then - return answers - end - - pos, err = parse_section(answers, SECTION_AR, buf, pos, nar) - - if not pos then + local data, err = _http_body_receive(sock, header) + if not data then return nil, err end - return answers -end + sock:setkeepalive() - -local function _gen_id(self) - local id = self._id -- for regression testing - if id then - return id - end - return rand(0, 65535) -- two bytes + return { + status = status, + version = version, + body = data + } end -local function _tcp_query(self, query, id, opts) - local sock = self.tcp_sock - if not sock then - return nil, "not initialized" - end - - log(DEBUG, "query the TCP server due to reply truncation") - - local server = _get_cur_server(self) - - local ok, err = sock:connect(server[1], server[2]) - if not ok then - return nil, "failed to connect to TCP server " - .. concat(server, ":") .. ": " .. err - end - - query = concat(query, "") - local len = #query - - local len_hi = char(rshift(len, 8)) - local len_lo = char(band(len, 0xff)) - - local bytes, err = sock:send({len_hi, len_lo, query}) - if not bytes then - return nil, "failed to send query to TCP server " - .. concat(server, ":") .. ": " .. err - end - - local buf, err = sock:receive(2) - if not buf then - return nil, "failed to receive the reply length field from TCP server " - .. concat(server, ":") .. ": " .. err - end - - len_hi = byte(buf, 1) - len_lo = byte(buf, 2) - len = lshift(len_hi, 8) + len_lo - - -- print("tcp message len: ", len) - - buf, err = sock:receive(len) - if not buf then - return nil, "failed to receive the reply message body from TCP server " - .. concat(server, ":") .. ": " .. err +local function _doh_query(qname, opts, tries, servers) + --local sock = self.tcp_sock + --if not sock then + -- return nil, "not initialized" + --end + + local servers = self.servers + if not servers:size() then + return nil, "no servers available" end - local answers, err = parse_response(buf, id, opts) - if not answers then - return nil, "failed to parse the reply from the TCP server " - .. concat(server, ":") .. ": " .. err + local retrans = self.retrans + if tries then + tries[1] = nil end + + local method = self.doh_method + local err - sock:close() - - return answers -end - - -function _M.tcp_query(self, qname, opts) - local socks = self.socks - if not socks then - return nil, "not initialized" - end + --if method == ngx_post then + -- id = _gen_id(self) + -- opts = { + -- method = ngx_post, + -- body = table.concat(_build_request(qname, id, self.no_recurse, opts)) + -- } + --else + -- opts = { + -- method = ngx_get, + -- param = self.doh_encode and b64.encode_base64url(qname) or qname + -- } + --end + + for i = 1, retrans do + local id, opts + local server = servers:pick() + + local res, err = _http_query(sock, server, opts) + if not res then + return nil, err, tries + end - pick_sock(self, socks) + if res and res.status == 200 and res.body then + local answers + if method == ngx_get then + local ident_hi = byte(res.body, 1) + local ident_lo = byte(res.body, 2) + id = lshift(ident_hi, 8) + ident_lo + end + answers, err = _parse_response(res.body, id, opts) + if answers then + return answers, nil, tries + end - local id = _gen_id(self) + if err and err ~= "id mismatch" then + break + else + log(DEBUG,"DoH query failed to parse response",err) + end + end - local query, err = _build_request(qname, id, self.no_recurse, opts) - if not query then - return nil, err + if tries then + tries[i] = err + tries[i + 1] = nil -- ensure termination for user supplied table + end end - return _tcp_query(self, query, id, opts) + return nil, err, tries end - -function _M.query(self, qname, opts, tries) - local socks = self.socks - if not socks then +local function _udp_tcp_query(qname, opts, tries, servers) + if not servers then return nil, "not initialized" end - local id = _gen_id(self) + --local id = _gen_id(self) local query, err = _build_request(qname, id, self.no_recurse, opts) if not query then @@ -855,12 +977,11 @@ function _M.query(self, qname, opts, tries) -- print("retrans: ", retrans) for i = 1, retrans do - local sock = pick_sock(self, socks) - - local ok - ok, err = sock:send(query) + local sock = servers:pick_sock() + + local ok, err = sock:send(query) if not ok then - local server = _get_cur_server(self) + local server = servers:current_server() err = "failed to send request to UDP server " .. concat(server, ":") .. ": " .. err @@ -870,7 +991,7 @@ function _M.query(self, qname, opts, tries) for _ = 1, 128 do buf, err = sock:receive(4096) if err then - local server = _get_cur_server(self) + local server = servers:current_server() err = "failed to receive reply from UDP server " .. concat(server, ":") .. ": " .. err break @@ -878,9 +999,9 @@ function _M.query(self, qname, opts, tries) if buf then local answers - answers, err = parse_response(buf, id, opts) + answers, err = _parse_response(buf, id, opts) if err == "truncated" then - answers, err = _tcp_query(self, query, id, opts) + answers, err = _tcp_query(sock, query, id, opts, servers) end if err and err ~= "id mismatch" then @@ -905,12 +1026,14 @@ function _M.query(self, qname, opts, tries) end -function _M.compress_ipv6_addr(addr) +---------------------------[[ private functions ]]------------------------ + +local function _compress_ipv6_addr(addr) local addr = re_sub(addr, "^(0:)+|(:0)+$|:(0:)+", "::", "jo") if addr == "::0" then addr = "::" end - + return addr end @@ -918,36 +1041,33 @@ end local function _expand_ipv6_addr(addr) if find(addr, "::", 1, true) then local ncol, addrlen = 8, #addr - + for i = 1, addrlen do if byte(addr, i) == COLON_CHAR then ncol = ncol - 1 end end - + if byte(addr, 1) == COLON_CHAR then addr = "0" .. addr end - + if byte(addr, -1) == COLON_CHAR then addr = addr .. "0" end - + addr = re_sub(addr, "::", ":" .. rep("0:", ncol), "jo") end - + return addr end -_M.expand_ipv6_addr = _expand_ipv6_addr - - -function _M.arpa_str(addr) +local function _arpa_str(addr) if find(addr, ":", 1, true) then addr = _expand_ipv6_addr(addr) local idx, hidx, addrlen = 1, 1, #addr - + for i = addrlen, 0, -1 do local s = byte(addr, i) if s == COLON_CHAR or not s then @@ -962,21 +1082,141 @@ function _M.arpa_str(addr) hidx = hidx + 1 end end - + addr = char(unpack(arpa_tmpl)) else addr = re_sub(addr, [[(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})]], "$4.$3.$2.$1.in-addr.arpa", "ajo") end - + return addr end -function _M.reverse_query(self, addr) - return self.query(self, self.arpa_str(addr), - {qtype = self.TYPE_PTR}) +------------------[[ public instance methods ]]--------------- + +local function _query(self, qname, opts, tries) + log(ngx.ERR,"QUERY") + + local servers = self.servers + if not servers then + return nil, "not initialized" + end + + local retrans = self.retrans + -- print("retrans: ", retrans) + if tries then + tries[1] = nil + end + + local id = _gen_id(self) + local query, err = _build_wire_request(qname, id, self.no_recurse, opts) + if not query then + return nil, err + end + + -- local cjson = require "cjson" + -- print("query: ", cjson.encode(concat(query, ""))) + + + + + + for i = 1, retrans do + local sock = servers:pick() + + --[[ Abstract send ]] + local ok, err = sock:send(query) + if not ok then + local server = servers:current_server() + err = "failed to send request to UDP server " + .. concat(server, ":") .. ": " .. err + --[[ End Of Send ]] + else + + local buf + for _ = 1, 128 do + --[[ Receive ]] + buf, err = sock:receive(4096) + if err then + local server = servers:current_server() + err = "failed to receive reply from UDP server " + .. concat(server, ":") .. ": " .. err + break + end + --[[ End Of Receive]] + + --[[ Parse ]] + if buf then + local answers + answers, err = _parse_response(buf, id, opts) + if err == "truncated" then + answers, err = _tcp_query(sock, query, id, opts, servers) + end + + if err and err ~= "id mismatch" then + break + end + + if answers then + return answers, nil, tries + end + end + --[[ End Of Parse ]] + -- only here in case of an "id mismatch" + end + --[[ END ]] + end + + if tries then + tries[i] = err + tries[i + 1] = nil -- ensure termination for user supplied table + end + end + + return nil, err, tries +end + + +local function _reverse_query(class, addr) + log(ngx.ERR,"REVERSE QUERY") + return query(class, arpa_str(addr), + {qtype = class.TYPE_PTR}) end +local resolver_mt = { + udp_query = _udp_query, + udp_tcp_query = _udp_tcp_query, + tcp_query = _tcp_query, + dot_query = _dot_query, + doh_query = _doh_query, + query = _query, + reverse_query = _reverse_query +} + +----------------------------------------------------------------------------- + +function _M.new(class, opts) + if not opts then + return nil, "no options table specified" + end + + local nameservers = opts.nameservers + if not nameservers or #nameservers == 0 then + return nil, "no nameservers specified" + end + + local servers, err = _servers_array_new(opts) + if not servers then + return nil, err + end + + return setmetatable({ + servers = servers, + retrans = opts.retrans or 5, + no_recurse = opts.no_recurse + }, { __index = resolver_mt }) +end + return _M diff --git a/lib/resty/dns/wireformat.lua b/lib/resty/dns/wireformat.lua new file mode 100644 index 0000000..efcda00 --- /dev/null +++ b/lib/resty/dns/wireformat.lua @@ -0,0 +1,596 @@ +local bit = require "bit" +local band = bit.band +local rshift = bit.rshift +local lshift = bit.lshift +local insert = table.insert +local concat = table.concat +local byte = string.byte +local char= string.char +local byte = string.byte +local sub = string.sub +local gsub = string.gsub + +local log = ngx.log +local DEBUG = ngx.DEBUG + +local DOT_CHAR = byte(".") +local ZERO_CHAR = byte("0") + +local TYPE_A = 1 +local TYPE_NS = 2 +local TYPE_CNAME = 5 +local TYPE_SOA = 6 +local TYPE_PTR = 12 +local TYPE_MX = 15 +local TYPE_TXT = 16 +local TYPE_AAAA = 28 +local TYPE_SRV = 33 +local TYPE_SPF = 99 + +local CLASS_IN = 1 + +local SECTION_AN = 1 +local SECTION_NS = 2 +local SECTION_AR = 3 + +local soa_int32_fields = { "serial", "refresh", "retry", "expire", "minimum" } + +local resolver_errstrs = { + "format error", -- 1 + "server failure", -- 2 + "name error", -- 3 + "not implemented", -- 4 + "refused", -- 5 +} + +local function _encode_name(s) + return char(#s) .. s +end + + +local function _decode_name(buf, pos) + local labels = {} + local nptrs = 0 + local p = pos + while nptrs < 128 do + local fst = byte(buf, p) + + if not fst then + return nil, 'truncated'; + end + + -- print("fst at ", p, ": ", fst) + + if fst == 0 then + if nptrs == 0 then + pos = pos + 1 + end + break + end + + if band(fst, 0xc0) ~= 0 then + -- being a pointer + if nptrs == 0 then + pos = pos + 2 + end + + nptrs = nptrs + 1 + + local snd = byte(buf, p + 1) + if not snd then + return nil, 'truncated' + end + + p = lshift(band(fst, 0x3f), 8) + snd + 1 + + -- print("resolving ptr ", p, ": ", byte(buf, p)) + + else + -- being a label + local label = sub(buf, p + 1, p + fst) + insert(labels, label) + + -- print("resolved label ", label) + + p = p + fst + 1 + + if nptrs == 0 then + pos = p + end + end + end + + return concat(labels, "."), pos +end + + +local function _parse_wire_section(answers, section, buf, start_pos, size, + should_skip) + local pos = start_pos + + for _ = 1, size do + -- print(format("ans %d: qtype:%d qclass:%d", i, qtype, qclass)) + local ans = {} + + if not should_skip then + insert(answers, ans) + end + + ans.section = section + + local name + name, pos = _decode_name(buf, pos) + if not name then + return nil, pos + end + + ans.name = name + + -- print("name: ", name) + + local type_hi = byte(buf, pos) + local type_lo = byte(buf, pos + 1) + local typ = lshift(type_hi, 8) + type_lo + + ans.type = typ + + -- print("type: ", typ) + + local class_hi = byte(buf, pos + 2) + local class_lo = byte(buf, pos + 3) + local class = lshift(class_hi, 8) + class_lo + + ans.class = class + + -- print("class: ", class) + + local byte_1, byte_2, byte_3, byte_4 = byte(buf, pos + 4, pos + 7) + + local ttl = lshift(byte_1, 24) + lshift(byte_2, 16) + + lshift(byte_3, 8) + byte_4 + + -- print("ttl: ", ttl) + + ans.ttl = ttl + + local len_hi = byte(buf, pos + 8) + local len_lo = byte(buf, pos + 9) + local len = lshift(len_hi, 8) + len_lo + + -- print("record len: ", len) + + pos = pos + 10 + + if typ == TYPE_A then + + if len ~= 4 then + return nil, "bad A record value length: " .. len + end + + local addr_bytes = { byte(buf, pos, pos + 3) } + local addr = concat(addr_bytes, ".") + -- print("ipv4 address: ", addr) + + ans.address = addr + + pos = pos + 4 + + elseif typ == TYPE_CNAME then + + local cname, p = _decode_name(buf, pos) + if not cname then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("cname: ", cname) + + ans.cname = cname + + elseif typ == TYPE_AAAA then + + if len ~= 16 then + return nil, "bad AAAA record value length: " .. len + end + + local addr_bytes = { byte(buf, pos, pos + 15) } + local flds = {} + for i = 1, 16, 2 do + local a = addr_bytes[i] + local b = addr_bytes[i + 1] + if a == 0 then + insert(flds, format("%x", b)) + + else + insert(flds, format("%x%02x", a, b)) + end + end + + -- we do not compress the IPv6 addresses by default + -- due to performance considerations + + ans.address = concat(flds, ":") + + pos = pos + 16 + + elseif typ == TYPE_MX then + + -- print("len = ", len) + + if len < 3 then + return nil, "bad MX record value length: " .. len + end + + local pref_hi = byte(buf, pos) + local pref_lo = byte(buf, pos + 1) + + ans.preference = lshift(pref_hi, 8) + pref_lo + + local host, p = _decode_name(buf, pos + 2) + if not host then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + ans.exchange = host + + pos = p + + elseif typ == TYPE_SRV then + if len < 7 then + return nil, "bad SRV record value length: " .. len + end + + local prio_hi = byte(buf, pos) + local prio_lo = byte(buf, pos + 1) + ans.priority = lshift(prio_hi, 8) + prio_lo + + local weight_hi = byte(buf, pos + 2) + local weight_lo = byte(buf, pos + 3) + ans.weight = lshift(weight_hi, 8) + weight_lo + + local port_hi = byte(buf, pos + 4) + local port_lo = byte(buf, pos + 5) + ans.port = lshift(port_hi, 8) + port_lo + + local name, p = _decode_name(buf, pos + 6) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad srv record length: %d ~= %d", + p - pos, len) + end + + ans.target = name + + pos = p + + elseif typ == TYPE_NS then + + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("name: ", name) + + ans.nsdname = name + + elseif typ == TYPE_TXT or typ == TYPE_SPF then + + local key = (typ == TYPE_TXT) and "txt" or "spf" + + local slen = byte(buf, pos) + if slen + 1 > len then + -- truncate the over-run TXT record data + slen = len + end + + -- print("slen: ", len) + + local val = sub(buf, pos + 1, pos + slen) + local last = pos + len + pos = pos + slen + 1 + + if pos < last then + -- more strings to be processed + -- this code path is usually cold, so we do not + -- merge the following loop on this code path + -- with the processing logic above. + + val = {val} + local idx = 2 + repeat + local slen = byte(buf, pos) + if pos + slen + 1 > last then + -- truncate the over-run TXT record data + slen = last - pos - 1 + end + + val[idx] = sub(buf, pos + 1, pos + slen) + idx = idx + 1 + pos = pos + slen + 1 + + until pos >= last + end + + ans[key] = val + + elseif typ == TYPE_PTR then + + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + + if p - pos ~= len then + return nil, format("bad cname record length: %d ~= %d", + p - pos, len) + end + + pos = p + + -- print("name: ", name) + + ans.ptrdname = name + + elseif typ == TYPE_SOA then + local name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + ans.mname = name + + pos = p + name, p = _decode_name(buf, pos) + if not name then + return nil, pos + end + ans.rname = name + + for _, field in ipairs(soa_int32_fields) do + local byte_1, byte_2, byte_3, byte_4 = byte(buf, p, p + 3) + ans[field] = lshift(byte_1, 24) + lshift(byte_2, 16) + + lshift(byte_3, 8) + byte_4 + p = p + 4 + end + + pos = p + + else + -- for unknown types, just forward the raw value + + ans.rdata = sub(buf, pos, pos + len - 1) + pos = pos + len + end + end + + return pos +end + + +local function _parse_wire_response(buf, id, opts) + local n = #buf + if n < 12 then + return nil, 'truncated' + end + + -- header layout: ident flags nqs nan nns nar + + local ident_hi = byte(buf, 1) + local ident_lo = byte(buf, 2) + local ans_id = lshift(ident_hi, 8) + ident_lo + + -- print("id: ", id, ", ans id: ", ans_id) + + if ans_id ~= id then + -- identifier mismatch and throw it away + log(DEBUG, "id mismatch in the DNS reply: ", ans_id, " ~= ", id) + return nil, "id mismatch" + end + + local flags_hi = byte(buf, 3) + local flags_lo = byte(buf, 4) + local flags = lshift(flags_hi, 8) + flags_lo + + -- print(format("flags: 0x%x", flags)) + + if band(flags, 0x8000) == 0 then + return nil, format("bad QR flag in the DNS response") + end + + if band(flags, 0x200) ~= 0 then + return nil, "truncated" + end + + local code = band(flags, 0xf) + + -- print(format("code: %d", code)) + + local nqs_hi = byte(buf, 5) + local nqs_lo = byte(buf, 6) + local nqs = lshift(nqs_hi, 8) + nqs_lo + + -- print("nqs: ", nqs) + + if nqs ~= 1 then + return nil, format("bad number of questions in DNS response: %d", nqs) + end + + local nan_hi = byte(buf, 7) + local nan_lo = byte(buf, 8) + local nan = lshift(nan_hi, 8) + nan_lo + + -- print("nan: ", nan) + + local nns_hi = byte(buf, 9) + local nns_lo = byte(buf, 10) + local nns = lshift(nns_hi, 8) + nns_lo + + local nar_hi = byte(buf, 11) + local nar_lo = byte(buf, 12) + local nar = lshift(nar_hi, 8) + nar_lo + + -- skip the question part + + local ans_qname, pos = _decode_name(buf, 13) + if not ans_qname then + return nil, pos + end + + -- print("qname in reply: ", ans_qname) + + -- print("question: ", sub(buf, 13, pos)) + + if pos + 3 + nan * 12 > n then + -- print(format("%d > %d", pos + 3 + nan * 12, n)) + return nil, 'truncated' + end + + -- question section layout: qname qtype(2) qclass(2) + + --[[ + local type_hi = byte(buf, pos) + local type_lo = byte(buf, pos + 1) + local ans_type = lshift(type_hi, 8) + type_lo + ]] + + -- print("ans qtype: ", ans_type) + + local class_hi = byte(buf, pos + 2) + local class_lo = byte(buf, pos + 3) + local qclass = lshift(class_hi, 8) + class_lo + + -- print("ans qclass: ", qclass) + + if qclass ~= 1 then + return nil, format("unknown query class %d in DNS response", qclass) + end + + pos = pos + 4 + + local answers = {} + + if code ~= 0 then + answers.errcode = code + answers.errstr = resolver_errstrs[code] or "unknown" + end + + local authority_section, additional_section + + if opts then + authority_section = opts.authority_section + additional_section = opts.additional_section + if opts.qtype == TYPE_SOA then + authority_section = true + end + end + + local err + + pos, err = _parse_wire_section(answers, SECTION_AN, buf, pos, nan) + + if not pos then + return nil, err + end + + if not authority_section and not additional_section then + return answers + end + + pos, err = _parse_wire_section(answers, SECTION_NS, buf, pos, nns, + not authority_section) + + if not pos then + return nil, err + end + + if not additional_section then + return answers + end + + pos, err = _parse_wire_section(answers, SECTION_AR, buf, pos, nar) + + if not pos then + return nil, err + end + + return answers +end + +local function _build_wire_request(qname, id, no_recurse, opts) + local qtype = opts and opts.qtype or 1 + local ident_hi = char(rshift(id, 8)) + local ident_lo = char(band(id, 0xff)) + + local flags + if no_recurse then + -- print("found no recurse") + flags = "\0\0" + else + flags = "\1\0" + end + + local nqs = "\0\1" + local nan = "\0\0" + local nns = "\0\0" + local nar = "\0\0" + local typ = char(rshift(qtype, 8), band(qtype, 0xff)) + local class = "\0\1" -- the Internet class + + if byte(qname, 1) == DOT_CHAR then + return nil, "bad name" + end + + local name = gsub(qname, "([^.]+)%.?", _encode_name) .. '\0' + + return { + ident_hi, ident_lo, flags, nqs, nan, nns, nar, + name, typ, class + } +end + + +return { + TYPE = { + A = TYPE_A, + NS = TYPE_NS, + CNAME = TYPE_CNAME, + SOA = TYPE_SOA, + PTR = TYPE_PTR, + MX = TYPE_MX, + TXT = TYPE_TXT, + AAAA = TYPE_AAAA, + SRV = TYPE_SRV, + SPF = TYPE_SPF + }, + CLASS = { + IN = CLASS_IN + }, + SECTION = { + AN = SECTION_AN, + NS = SECTION_NS, + AR = SECTION_AR + }, + build_request = _build_wire_request, + parse_response = _parse_wire_response +} diff --git a/t/mock.t b/t/mock.t index 43c186f..92a5d05 100644 --- a/t/mock.t +++ b/t/mock.t @@ -1994,3 +1994,97 @@ failed to query: 3: failed to receive reply from UDP server 127.0.0.1:20002: connection refused --- error_log Connection refused + + + +=== TEST 40: single answer DoH GET request, good A answer +--- http_config eval: $::HttpConfig +--- config + location /t { + content_by_lua ' + local resolver = require "resty.dns.resolver" + + local r, err = resolver:new{ + nameservers = { "https://cloudflare-dns.com/dns-query?name=" }, + doh = true, + doh_method = 'GET' + } + if not r then + ngx.say("failed to instantiate resolver: ", err) + return + end + + r._id = 125 + + local ans, err = r:query("www.google.com", { qtype = r.TYPE_A }) + if not ans then + ngx.say("failed to query: ", err) + return + end + + local ljson = require "ljson" + ngx.say("records: ", ljson.encode(ans)) + '; + } +--- doh_reply dns +{ + id => 125, + opcode => 0, + qname => 'www.google.com', + answer => [{ name => "www.google.com", ipv4 => "127.0.0.1", ttl => 123456 }], +} +--- request +GET /t +--- doh_query eval +"\x{00}}\x{01}\x{00}\x{00}\x{01}\x{00}\x{00}\x{00}\x{00}\x{00}\x{00}\x{03}www\x{06}google\x{03}com\x{00}\x{00}\x{01}\x{00}\x{01}" +--- response_body +records: [{"address":"127.0.0.1","class":1,"name":"www.google.com","section":1,"ttl":123456,"type":1}] +--- no_error_log +[error] + + + +=== TEST 41: single answer DoH POST reply, good A answer +--- http_config eval: $::HttpConfig +--- config + location /t { + content_by_lua ' + local resolver = require "resty.dns.resolver" + + local r, err = resolver:new{ + nameservers = { "https://cloudflare-dns.com/dns-query" }, + doh = true, + doh_method = 'POST' + } + if not r then + ngx.say("failed to instantiate resolver: ", err) + return + end + + r._id = 125 + + local ans, err = r:query("www.google.com", { qtype = r.TYPE_A }) + if not ans then + ngx.say("failed to query: ", err) + return + end + + local ljson = require "ljson" + ngx.say("records: ", ljson.encode(ans)) + '; + } +--- doh_reply dns +{ + id => 125, + opcode => 0, + qname => 'www.google.com', + answer => [{ name => "www.google.com", ipv4 => "127.0.0.1", ttl => 123456 }], +} +--- request +GET /t +--- doh_query eval +"\x{00}}\x{01}\x{00}\x{00}\x{01}\x{00}\x{00}\x{00}\x{00}\x{00}\x{00}\x{03}www\x{06}google\x{03}com\x{00}\x{00}\x{01}\x{00}\x{01}" +--- response_body +records: [{"address":"127.0.0.1","class":1,"name":"www.google.com","section":1,"ttl":123456,"type":1}] +--- no_error_log +[error]