Skip to content

Commit

Permalink
feat(router) support matching ports in Host header
Browse files Browse the repository at this point in the history
Allow routes to specify a port in the host parameter. If present, it
matches only requests to that specific port. If absent, requests to
any port will match. Requests without explicit port are handled as if
they included the appropriate default port (80 for HTTP, 443 for
HTTPs). Routes with a port have higher priority than those without;
even for requests without an explicit port in the Host header.

From #5102
  • Loading branch information
javierguerragiraldez authored and hutchic committed Nov 21, 2019
1 parent e359f7a commit d904bf3
Show file tree
Hide file tree
Showing 12 changed files with 640 additions and 114 deletions.
2 changes: 1 addition & 1 deletion kong/db/schema/entities/services.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ return {
{ retries = { type = "integer", default = 5, between = { 0, 32767 } }, },
-- { tags = { type = "array", array = { type = "string" } }, },
{ protocol = typedefs.protocol { required = true, default = "http" } },
{ host = typedefs.host { required = true } },
{ host = typedefs.host_with_optional_port { required = true } },
{ port = typedefs.port { required = true, default = 80 }, },
{ path = typedefs.path },
{ connect_timeout = nonzero_timeout { default = 60000 }, },
Expand Down
2 changes: 1 addition & 1 deletion kong/db/schema/entities/upstreams.lua
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ local r = {
fields = healthchecks_fields,
}, },
{ tags = typedefs.tags },
{ host_header = typedefs.host },
{ host_header = typedefs.host_with_optional_port },
},
entity_checks = {
-- hash_on_header must be present when hashing on header
Expand Down
14 changes: 13 additions & 1 deletion kong/db/schema/typedefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ local function validate_host(host)
end


local function validate_host_with_optional_port(host)
local res, err_or_port = utils.normalize_ip(host)
return (res and true or nil), err_or_port
end


local function validate_ip(ip)
local res, err = utils.normalize_ip(ip)
if not res then
Expand Down Expand Up @@ -215,6 +221,12 @@ typedefs.host = Schema.define {
}


typedefs.host_with_optional_port = Schema.define {
type = "string",
custom_validator = validate_host_with_optional_port,
}


typedefs.wildcard_host = Schema.define {
type = "string",
custom_validator = validate_wildcard_host,
Expand Down Expand Up @@ -400,7 +412,7 @@ typedefs.protocols_http = Schema.define {

local function validate_host_with_wildcards(host)
local no_wildcards = string.gsub(host, "%*", "abc")
return typedefs.host.custom_validator(no_wildcards)
return typedefs.host_with_optional_port.custom_validator(no_wildcards)
end

local function validate_path_with_regexes(path)
Expand Down
155 changes: 139 additions & 16 deletions kong/router.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ local var = ngx.var
local ngx_log = ngx.log
local insert = table.insert
local sort = table.sort
local byte = string.byte
local upper = string.upper
local lower = string.lower
local find = string.find
local format = string.format
local sub = string.sub
local tonumber = tonumber
local ipairs = ipairs
Expand Down Expand Up @@ -53,6 +55,103 @@ do
end


local split_port
do
local ZERO, NINE, LEFTBRACKET, RIGHTBRACKET = ("09[]"):byte(1, -1)


local function safe_add_port(host, port)
if not port then
return host
end

return host .. ":" .. port
end


local function onlydigits(s, begin)
for i = begin or 1, #s do
local c = byte(s, i)
if c < ZERO or c > NINE then
return false
end
end
return true
end


--- Splits an optional ':port' section from a hostname
-- the port section must be decimal digits only.
-- brackets ('[]') are peeled off the hostname if present.
-- if there's more than one colon and no brackets, no split is possible.
-- on non-parseable input, returns name unchanged,
-- every string input produces at least one string output.
-- @tparam string name the string to split.
-- @tparam number default_port default port number
-- @treturn string hostname without port
-- @treturn string hostname with port
-- @treturn boolean true if input had a port number
local function l_split_port(name, default_port)
if byte(name, 1) == LEFTBRACKET then
if byte(name, -1) == RIGHTBRACKET then
return sub(name, 2, -2), safe_add_port(name, default_port), false
end

local splitpos = find(name, "]:", 2, true)
if splitpos then
if splitpos == #name - 1 then
return sub(name, 2, splitpos - 1), name .. (default_port or ""), false
end

if onlydigits(name, splitpos + 2) then
return sub(name, 2, splitpos - 1), name, true
end
end

return name, safe_add_port(name, default_port), false
end

local firstcolon = find(name, ":", 1, true)
if not firstcolon then
return name, safe_add_port(name, default_port), false
end

if firstcolon == #name then
local host = sub(name, 1, firstcolon - 1)
return host, safe_add_port(host, default_port), false
end

if not onlydigits(name, firstcolon + 1) then
if default_port then
return name, format("[%s]:%s", name, default_port), false
end

return name, name, false
end

return sub(name, 1, firstcolon - 1), name, true
end


-- split_port is a pure function, so we can memoize it.
local memo_h = setmetatable({}, { __mode = "k" })
local memo_hp = setmetatable({}, { __mode = "k" })
local memo_p = setmetatable({}, { __mode = "k" })


split_port = function(name, default_port)
local k = name .. "#" .. (default_port or "")
local h, hp, p = memo_h[k], memo_hp[k], memo_p[k]
if not h then
h, hp, p = l_split_port(name, default_port)
memo_h[k], memo_hp[k], memo_p[k] = h, hp, p
end

return h, hp, p
end
end


--[[
Hypothesis
----------
Expand Down Expand Up @@ -87,8 +186,9 @@ sort(SORTED_MATCH_RULES, function(a, b)
end)

local MATCH_SUBRULES = {
HAS_REGEX_URI = 0x01,
PLAIN_HOSTS_ONLY = 0x02,
HAS_REGEX_URI = 0x01,
PLAIN_HOSTS_ONLY = 0x02,
HAS_WILDCARD_HOST_PORT = 0x04,
}

local EMPTY_T = {}
Expand Down Expand Up @@ -213,6 +313,7 @@ local function marshall_route(r)

local has_host_wildcard
local has_host_plain
local has_port

for _, host in ipairs(hosts) do
if type(host) ~= "string" then
Expand All @@ -225,6 +326,12 @@ local function marshall_route(r)

local wildcard_host_regex = host:gsub("%.", "\\.")
:gsub("%*", ".+") .. "$"

_, _, has_port = split_port(host)
if not has_port then
wildcard_host_regex = wildcard_host_regex:gsub("%$$", [[(?::\d+)?$]])
end

insert(route_t.hosts, {
wildcard = true,
value = host,
Expand Down Expand Up @@ -252,6 +359,11 @@ local function marshall_route(r)
route_t.submatch_weight = bor(route_t.submatch_weight,
MATCH_SUBRULES.PLAIN_HOSTS_ONLY)
end

if has_port then
route_t.submatch_weight = bor(route_t.submatch_weight,
MATCH_SUBRULES.HAS_WILDCARD_HOST_PORT)
end
end


Expand Down Expand Up @@ -658,7 +770,8 @@ end
do
local matchers = {
[MATCH_RULES.HOST] = function(route_t, ctx)
local host = route_t.hosts[ctx.hits.host or ctx.req_host]
local req_host = ctx.hits.host or ctx.req_host
local host = route_t.hosts[req_host] or route_t.hosts[ctx.host_no_port]
if host then
ctx.matches.host = host
return true
Expand All @@ -668,7 +781,7 @@ do
local host_t = route_t.hosts[i]

if host_t.wildcard then
local from, _, err = re_find(ctx.req_host, host_t.regex, "ajo")
local from, _, err = re_find(ctx.host_with_port, host_t.regex, "ajo")
if err then
log(ERR, "could not evaluate wildcard host regex: ", err)
return
Expand Down Expand Up @@ -992,6 +1105,7 @@ _M.has_capturing_groups = has_capturing_groups

-- for unit-testing purposes only
_M._set_ngx = _set_ngx
_M.split_port = split_port


function _M.new(routes)
Expand Down Expand Up @@ -1164,7 +1278,7 @@ function _M.new(routes)

local grab_req_headers = #plain_indexes.headers > 0

local function find_route(req_method, req_uri, req_host,
local function find_route(req_method, req_uri, req_host, req_scheme,
src_ip, src_port,
dst_ip, dst_port,
sni, req_headers)
Expand All @@ -1177,6 +1291,9 @@ function _M.new(routes)
if req_host and type(req_host) ~= "string" then
error("host must be a string", 2)
end
if req_scheme and type(req_scheme) ~= "string" then
error("scheme must be a string", 2)
end
if src_ip and type(src_ip) ~= "string" then
error("src_ip must be a string", 2)
end
Expand Down Expand Up @@ -1233,13 +1350,15 @@ function _M.new(routes)

req_method = upper(req_method)

if req_host ~= "" then
-- strip port number if given because matching ignores ports
local idx = find(req_host, ":", 2, true)
if idx then
ctx.req_host = sub(req_host, 1, idx - 1)
end
end
-- req_host might have port or maybe not, host_no_port definitely doesn't
-- if there wasn't a port, req_port is assumed to be the default port
-- according the protocol scheme
local host_no_port, host_with_port = split_port(req_host,
req_scheme == "https"
and 443 or 80)

ctx.host_with_port = host_with_port
ctx.host_no_port = host_no_port

local hits = ctx.hits
local req_category = 0x00
Expand All @@ -1252,12 +1371,15 @@ function _M.new(routes)

-- host match

if plain_indexes.hosts[ctx.req_host] then
if plain_indexes.hosts[host_with_port]
or plain_indexes.hosts[host_no_port]
then
req_category = bor(req_category, MATCH_RULES.HOST)

elseif ctx.req_host then
for i = 1, #wildcard_hosts do
local from, _, err = re_find(ctx.req_host, wildcard_hosts[i].regex, "ajo")
local from, _, err = re_find(host_with_port, wildcard_hosts[i].regex,
"ajo")
if err then
log(ERR, "could not match wildcard host: ", err)
return
Expand Down Expand Up @@ -1482,6 +1604,7 @@ function _M.new(routes)
local req_method = get_method()
local req_uri = var.request_uri
local req_host = var.http_host or ""
local req_scheme = var.scheme
local sni = var.ssl_server_name

local headers
Expand All @@ -1504,7 +1627,7 @@ function _M.new(routes)
end
end

local match_t = find_route(req_method, req_uri, req_host,
local match_t = find_route(req_method, req_uri, req_host, req_scheme,
nil, nil, -- src_ip, src_port
nil, nil, -- dst_ip, dst_port
sni, headers)
Expand Down Expand Up @@ -1547,7 +1670,7 @@ function _M.new(routes)
local dst_port = tonumber(var.server_port, 10)
local sni = var.ssl_preread_server_name

return find_route(nil, nil, nil,
return find_route(nil, nil, nil, "tcp",
src_ip, src_port,
dst_ip, dst_port,
sni)
Expand Down
14 changes: 2 additions & 12 deletions spec/01-unit/01-db/01-schema/05-services_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -337,28 +337,18 @@ describe("services", function()

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("invalid value: " .. invalid_hosts[i], err.host)
assert.equal("invalid hostname: " .. invalid_hosts[i], err.host)
end
end)

it("rejects values with a valid port", function()
local service = {
host = "example.com:80",
}

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("must not have a port", err.host)
end)

it("rejects values with an invalid port", function()
local service = {
host = "example.com:1000000",
}

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("must not have a port", err.host)
assert.equal("invalid port number", err.host)
end)

-- acceptance
Expand Down
Loading

0 comments on commit d904bf3

Please sign in to comment.