diff --git a/kong/router.lua b/kong/router.lua index 635956b5d3fc..71ed51db6177 100644 --- a/kong/router.lua +++ b/kong/router.lua @@ -5,7 +5,6 @@ local bit = require "bit" local hostname_type = utils.hostname_type -local split_port = utils.split_port local subsystem = ngx.config.subsystem local get_method = ngx.req.get_method local get_headers = ngx.req.get_headers @@ -55,6 +54,99 @@ do end +local split_port +do + local function safe_add_port(host, port) + if not port then + return host + end + + return host .. ':' .. port + end + + local ZERO, NINE, LEFTBRACKET, RIGHTBRACKET = ("09[]"):byte(1, -1) + + local function onlydigits(s, begin) + for i = begin or 1, #s do + local c = byte(s, i) + if c < ZERO or c > NINE then + return false + end + end + return true + end + + --- splits an optional ':port' section from a hostname + -- the port section must be decimal digits only. + -- brackets ('[]') are peeled off the hostname if present. + -- if there's more than one colon and no brackets, no split is possible. + -- on non-parseable input, returns name unchanged, + -- every string input produces at least one string output. + -- @tparam name string the string to split. + -- @tparam default_port number default port number + -- @treturn string hostname without port + -- @treturn string hostname with port + -- @treturn boolean true if input had a port number + local function l_split_port(name, default_port) + if byte(name, 1) == LEFTBRACKET then + if byte(name, -1) == RIGHTBRACKET then + return sub(name, 2, -2), safe_add_port(name, default_port), false + end + + local splitpos = find(name, "]:", 2, true) + if splitpos then + if splitpos == #name - 1 then + return sub(name, 2, splitpos - 1), name .. (default_port or ''), false + end + + if onlydigits(name, splitpos + 2) then + return sub(name, 2, splitpos - 1), name, true + end + end + + return name, safe_add_port(name, default_port), false + end + + local firstcolon = find(name, ":", 1, true) + if not firstcolon then + return name, safe_add_port(name, default_port), false + end + + if firstcolon == #name then + local host = sub(name, 1, firstcolon - 1) + return host, safe_add_port(host, default_port), false + end + + if not onlydigits(name, firstcolon + 1) then + if default_port then + return name, format("[%s]:%s", name, default_port), false + end + + return name, name, false + end + + return sub(name, 1, firstcolon - 1), name, true + end + + -- split_port is a pure function, so we can memoize it. + local memo_h = setmetatable({}, {__mode = 'k'}) + local memo_hp = setmetatable({}, {__mode = 'k'}) + local memo_p = setmetatable({}, {__mode = 'k'}) + + function split_port(name, default_port) + local k = name .. '#' .. (default_port or '') + local h, hp, p = memo_h[k], memo_hp[k], memo_p[k] + if not h then + h, hp, p = l_split_port(name, default_port) + memo_h[k], memo_hp[k], memo_p[k] = h, hp, p + end + + return h, hp, p + end +end + + + --[[ Hypothesis ---------- @@ -220,7 +312,7 @@ local function marshall_route(r) local has_host_wildcard local has_host_plain - local host_port + local has_port for _, host in ipairs(hosts) do if type(host) ~= "string" then @@ -234,8 +326,8 @@ local function marshall_route(r) local wildcard_host_regex = host:gsub("%.", "\\.") :gsub("%*", ".+") .. "$" - _, host_port = split_port(host) - if not host_port then + _, _, has_port = split_port(host) + if not has_port then wildcard_host_regex = wildcard_host_regex:gsub("%$$", [[(?::\d+)?$]]) end @@ -267,7 +359,7 @@ local function marshall_route(r) MATCH_SUBRULES.PLAIN_HOSTS_ONLY) end - if host_port then + if has_port then route_t.submatch_weight = bor(route_t.submatch_weight, MATCH_SUBRULES.HAS_HOST_PORT) end @@ -1013,6 +1105,7 @@ _M.has_capturing_groups = has_capturing_groups -- for unit-testing purposes only _M._set_ngx = _set_ngx +_M.split_port = split_port function _M.new(routes) @@ -1260,25 +1353,9 @@ function _M.new(routes) -- req_host might have port or maybe not, host_no_port definitely doesn't -- if there wasn't a port, req_port is assumed to be the default port -- according the protocol scheme - local host_with_port = req_host - local host_no_port, req_port = split_port(host_with_port) - if not req_port then - req_port = 80 - if req_scheme == 'https' then - req_port = 443 - end + local host_no_port, host_with_port = split_port( + req_host, req_scheme == 'https' and 443 or 80) - -- a literal IPv6 address needs the format [host]:port, - -- the fastest way to check is if the first byte was stripped by split_port() - -- otherwise, check for a colon - if byte(host_with_port) ~= byte(host_no_port) - or find(host_no_port, ":", 1, true) - then - host_with_port = format("[%s]:%d", host_no_port, req_port) - else - host_with_port = format("%s:%d", host_no_port, req_port) - end - end ctx.host_with_port = host_with_port local hits = ctx.hits diff --git a/kong/tools/utils.lua b/kong/tools/utils.lua index e814ea259820..0775a036ea31 100644 --- a/kong/tools/utils.lua +++ b/kong/tools/utils.lua @@ -24,12 +24,10 @@ local tostring = tostring local sort = table.sort local concat = table.concat local insert = table.insert -local byte = string.byte local lower = string.lower local fmt = string.format local find = string.find local gsub = string.gsub -local sub = string.sub local split = pl_stringx.split local re_find = ngx.re.find local re_match = ngx.re.match @@ -663,63 +661,6 @@ _M.hostname_type = function(name) return "name" end -do - local ZERO, NINE, LEFTBRACKET, RIGHTBRACKET = ("09[]"):byte(1, -1) - - --- splits an optional ':port' section from a hostname - -- the port section must be decimal digits only. - -- brackets ('[]') are peeled off the hostname if present. - -- if there's more than one colon and no brackets, no split is possible. - -- on non-parseable input, returns name unchanged, - -- every string input produces at least one string output. - -- @param name (string) the string to split. - -- @treturn string hostname - -- @treturn number port or nil if not found - local function l_split_port(name) - if byte(name, 1) == LEFTBRACKET then - if byte(name, -1) == RIGHTBRACKET then - return sub(name, 2, -2) - end - local splitpos = find(name, "]:", 2, true) - if splitpos then - local port = tonumber(sub(name, splitpos+2)) - if port or splitpos == #name - 1 then - return sub(name, 2, splitpos - 1), port - end - end - return name - end - - local firstcolon = find(name, ":", 1, true) - if not firstcolon then - return name - end - - for i = firstcolon + 1, #name do - local c = byte(name, i) - if c < ZERO or c > NINE then - return name - end - end - return sub(name, 1, firstcolon - 1), tonumber(name:sub(firstcolon+1)) - end - - -- split_port is a pure function, so we can memoize it. - local memo_h = setmetatable({}, {__mode = 'k'}) - local memo_p = setmetatable({}, {__mode = 'k'}) - - function _M.split_port(name) - local h, p = memo_h[name], memo_p[name] - if not h then - h, p = l_split_port(name) - memo_h[name], memo_p[name] = h, p - end - - return h, p - end -end - - --- parses, validates and normalizes an ipv4 address. -- @param address the string containing the address (formats; ipv4, ipv4:port) -- @return normalized address (string) + port (number or nil), or alternatively nil+error diff --git a/spec/01-unit/05-utils_spec.lua b/spec/01-unit/05-utils_spec.lua index e4db0c2614ad..105242cf4cb6 100644 --- a/spec/01-unit/05-utils_spec.lua +++ b/spec/01-unit/05-utils_spec.lua @@ -437,24 +437,6 @@ describe("Utils", function() end) describe("hostnames and ip addresses", function() - it("splits port number", function() - for _, case in ipairs({ - { '', { '', nil } }, - { 'localhost', { 'localhost', nil } }, - { 'localhost:', { 'localhost', nil } }, - { 'localhost:80', { 'localhost', 80 } }, - { 'localhost:23h', { 'localhost:23h', nil } }, - { 'localhost/24', { 'localhost/24', nil } }, - { '::1', { '::1', nil } }, - { '[::1]', { '::1', nil } }, - { '[::1]:', { '::1', nil } }, - { '[::1]:80', { '::1', 80 } }, - { '[::1]:80b', { '[::1]:80b', nil } }, - { '[::1]/96', { '[::1]/96', nil } }, - }) do - assert.same(case[2], { utils.split_port(case[1]) }) - end - end) describe("hostname_type", function() -- no check on "name" type as anything not ipv4 and not ipv6 will be labelled as 'name' anyway it("checks valid IPv4 address types", function() diff --git a/spec/01-unit/08-router_spec.lua b/spec/01-unit/08-router_spec.lua index d208e25c9f03..7a828947b0ce 100644 --- a/spec/01-unit/08-router_spec.lua +++ b/spec/01-unit/08-router_spec.lua @@ -207,6 +207,40 @@ local use_case = { } describe("Router", function() + describe("split_port()", function() + it("splits port number", function() + for _, case in ipairs({ + { {''}, { '', '', false } }, + { {'localhost'}, { 'localhost', 'localhost', false } }, + { {'localhost:'}, { 'localhost', 'localhost', false } }, + { {'localhost:80'}, { 'localhost', 'localhost:80', true } }, + { {'localhost:23h'}, { 'localhost:23h', 'localhost:23h', false } }, + { {'localhost/24'}, { 'localhost/24', 'localhost/24', false } }, + { {'::1'}, { '::1', '::1', false } }, + { {'[::1]'}, { '::1', '[::1]', false } }, + { {'[::1]:'}, { '::1', '[::1]:', false } }, + { {'[::1]:80'}, { '::1', '[::1]:80', true } }, + { {'[::1]:80b'}, { '[::1]:80b', '[::1]:80b', false } }, + { {'[::1]/96'}, { '[::1]/96', '[::1]/96', false } }, + + { {'', 88}, { '', ':88', false } }, + { {'localhost', 88}, { 'localhost', 'localhost:88', false } }, + { {'localhost:', 88}, { 'localhost', 'localhost:88', false } }, + { {'localhost:80', 88}, { 'localhost', 'localhost:80', true } }, + { {'localhost:23h', 88}, { 'localhost:23h', '[localhost:23h]:88', false } }, + { {'localhost/24', 88}, { 'localhost/24', 'localhost/24:88', false } }, + { {'::1', 88}, { '::1', '[::1]:88', false } }, + { {'[::1]', 88}, { '::1', '[::1]:88', false } }, + { {'[::1]:', 88}, { '::1', '[::1]:88', false } }, + { {'[::1]:80', 88}, { '::1', '[::1]:80', true } }, + { {'[::1]:80b', 88}, { '[::1]:80b', '[::1]:80b:88', false } }, + { {'[::1]/96', 88}, { '[::1]/96', '[::1]/96:88', false } }, + }) do + assert.same(case[2], { Router.split_port(unpack(case[1])) }) + end + end) + end) + describe("new()", function() describe("[errors]", function() it("enforces args types", function()