Skip to content

Commit

Permalink
feat(router) match routes like "host:port"
Browse files Browse the repository at this point in the history
Allow routes to specify a port in the host parameter. If present, it
matches only requests to that specific port. If absent, requests to
any port matches. Requests without explicit port are handled as if
they included the appropriate default port (80 for HTTP, 443 for
HTTPS). Routes with port have higher priority than those that those
without; even for requests without explicit port.
  • Loading branch information
javierguerragiraldez committed Oct 25, 2019
1 parent 6412a31 commit 949d7e6
Show file tree
Hide file tree
Showing 12 changed files with 539 additions and 116 deletions.
10 changes: 1 addition & 9 deletions kong/db/schema/typedefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,7 @@ local type = type

local function validate_host(host)
local res, err_or_port = utils.normalize_ip(host)
if type(err_or_port) == "string" and err_or_port ~= "invalid port number" then
return nil, "invalid value: " .. host
end

if err_or_port == "invalid port number" or type(res.port) == "number" then
return nil, "must not have a port"
end

return true
return res and true or nil, err_or_port
end


Expand Down
64 changes: 52 additions & 12 deletions kong/router.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ local bit = require "bit"


local hostname_type = utils.hostname_type
local split_port = utils.split_port
local subsystem = ngx.config.subsystem
local get_method = ngx.req.get_method
local get_headers = ngx.req.get_headers
Expand All @@ -15,9 +16,11 @@ local var = ngx.var
local ngx_log = ngx.log
local insert = table.insert
local sort = table.sort
local byte = string.byte
local upper = string.upper
local lower = string.lower
local find = string.find
local format = string.format
local sub = string.sub
local tonumber = tonumber
local ipairs = ipairs
Expand Down Expand Up @@ -88,6 +91,7 @@ end)
local MATCH_SUBRULES = {
HAS_REGEX_URI = 0x01,
PLAIN_HOSTS_ONLY = 0x02,
HAS_HOST_PORT = 0x04,
}

local EMPTY_T = {}
Expand Down Expand Up @@ -216,6 +220,7 @@ local function marshall_route(r)

local has_host_wildcard
local has_host_plain
local host_port

for _, host in ipairs(hosts) do
if type(host) ~= "string" then
Expand All @@ -228,6 +233,12 @@ local function marshall_route(r)

local wildcard_host_regex = host:gsub("%.", "\\.")
:gsub("%*", ".+") .. "$"

_, host_port = utils.split_port(host)
if not host_port then
wildcard_host_regex = wildcard_host_regex:gsub("%$$", [[(?::\d+)?$]])
end

insert(route_t.hosts, {
wildcard = true,
value = host,
Expand Down Expand Up @@ -255,6 +266,11 @@ local function marshall_route(r)
route_t.submatch_weight = bor(route_t.submatch_weight,
MATCH_SUBRULES.PLAIN_HOSTS_ONLY)
end

if host_port then
route_t.submatch_weight = bor(route_t.submatch_weight,
MATCH_SUBRULES.HAS_HOST_PORT)
end
end


Expand Down Expand Up @@ -661,7 +677,9 @@ end
do
local matchers = {
[MATCH_RULES.HOST] = function(route_t, ctx)
local host = route_t.hosts[ctx.hits.host or ctx.req_host]
local req_host = ctx.hits.host or ctx.req_host
local host = route_t.hosts[req_host]
or route_t.hosts[split_port(req_host)]
if host then
ctx.matches.host = host
return true
Expand All @@ -671,7 +689,7 @@ do
local host_t = route_t.hosts[i]

if host_t.wildcard then
local from, _, err = re_find(ctx.req_host, host_t.regex, "ajo")
local from, _, err = re_find(ctx.host_with_port, host_t.regex, "ajo")
if err then
log(ERR, "could not evaluate wildcard host regex: ", err)
return
Expand Down Expand Up @@ -1167,7 +1185,7 @@ function _M.new(routes)

local grab_req_headers = #plain_indexes.headers > 0

local function find_route(req_method, req_uri, req_host,
local function find_route(req_method, req_uri, req_host, req_scheme,
src_ip, src_port,
dst_ip, dst_port,
sni, req_headers)
Expand All @@ -1180,6 +1198,9 @@ function _M.new(routes)
if req_host and type(req_host) ~= "string" then
error("host must be a string", 2)
end
if req_scheme and type(req_scheme) ~= "string" then
error("scheme must be a string", 2)
end
if src_ip and type(src_ip) ~= "string" then
error("src_ip must be a string", 2)
end
Expand Down Expand Up @@ -1236,13 +1257,29 @@ function _M.new(routes)

req_method = upper(req_method)

if req_host ~= "" then
-- strip port number if given because matching ignores ports
local idx = find(req_host, ":", 2, true)
if idx then
ctx.req_host = sub(req_host, 1, idx - 1)
-- req_host might have port or maybe not, host_no_port definitely doesn't
-- if there wasn't a port, req_port is assumed to be the default port
-- according the protocol scheme
local host_with_port = req_host
local host_no_port, req_port = utils.split_port(host_with_port)
if not req_port then
req_port = 80
if req_scheme == 'https' then
req_port = 443
end

-- a literal IPv6 address needs the format [host]:port,
-- the fastest way to check is if the first byte was stripped by split_port()
-- otherwise, check for a colon
if byte(host_with_port) ~= byte(host_no_port)
or find(host_no_port, ":", 1, true)
then
host_with_port = format("[%s]:%d", host_no_port, req_port)
else
host_with_port = format("%s:%d", host_no_port, req_port)
end
end
ctx.host_with_port = host_with_port

local hits = ctx.hits
local req_category = 0x00
Expand All @@ -1255,12 +1292,14 @@ function _M.new(routes)

-- host match

if plain_indexes.hosts[ctx.req_host] then
if plain_indexes.hosts[host_with_port]
or plain_indexes.hosts[host_no_port]
then
req_category = bor(req_category, MATCH_RULES.HOST)

elseif ctx.req_host then
for i = 1, #wildcard_hosts do
local from, _, err = re_find(ctx.req_host, wildcard_hosts[i].regex, "ajo")
local from, _, err = re_find(host_with_port, wildcard_hosts[i].regex, "ajo")
if err then
log(ERR, "could not match wildcard host: ", err)
return
Expand Down Expand Up @@ -1485,6 +1524,7 @@ function _M.new(routes)
local req_method = get_method()
local req_uri = var.request_uri
local req_host = var.http_host or ""
local req_scheme = var.scheme
local sni = var.ssl_server_name

local headers
Expand All @@ -1507,7 +1547,7 @@ function _M.new(routes)
end
end

local match_t = find_route(req_method, req_uri, req_host,
local match_t = find_route(req_method, req_uri, req_host, req_scheme,
nil, nil, -- src_ip, src_port
nil, nil, -- dst_ip, dst_port
sni, headers)
Expand Down Expand Up @@ -1550,7 +1590,7 @@ function _M.new(routes)
local dst_port = tonumber(var.server_port, 10)
local sni = var.ssl_preread_server_name

return find_route(nil, nil, nil,
return find_route(nil, nil, nil, "tcp",
src_ip, src_port,
dst_ip, dst_port,
sni)
Expand Down
59 changes: 59 additions & 0 deletions kong/tools/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ local tostring = tostring
local sort = table.sort
local concat = table.concat
local insert = table.insert
local byte = string.byte
local lower = string.lower
local fmt = string.format
local find = string.find
local gsub = string.gsub
local sub = string.sub
local split = pl_stringx.split
local re_find = ngx.re.find
local re_match = ngx.re.match
Expand Down Expand Up @@ -661,6 +663,63 @@ _M.hostname_type = function(name)
return "name"
end

do
local ZERO, NINE, LEFTBRACKET, RIGHTBRACKET = ("09[]"):byte(1, -1)

--- splits an optional ':port' section from a hostname
-- the port section must be decimal digits only.
-- brackets ('[]') are peeled off the hostname if present.
-- if there's more than one colon and no brackets, no split is possible.
-- on non-parseable input, returns name unchanged,
-- every string input produces at least one string output.
-- @param name (string) the string to split.
-- @treturn string hostname
-- @treturn number port or nil if not found
local function l_split_port(name)
if byte(name, 1) == LEFTBRACKET then
if byte(name, -1) == RIGHTBRACKET then
return sub(name, 2, -2)
end
local splitpos = find(name, "]:", 2, true)
if splitpos then
local port = tonumber(sub(name, splitpos+2))
if port or splitpos == #name - 1 then
return sub(name, 2, splitpos - 1), port
end
end
return name
end

local firstcolon = find(name, ":", 1, true)
if not firstcolon then
return name
end

for i = firstcolon + 1, #name do
local c = byte(name, i)
if c < ZERO or c > NINE then
return name
end
end
return sub(name, 1, firstcolon - 1), tonumber(name:sub(firstcolon+1))
end

-- split_port is a pure function, so we can memoize it.
local memo_h = setmetatable({}, {__mode = 'k'})
local memo_p = setmetatable({}, {__mode = 'k'})

function _M.split_port(name)
local h, p = memo_h[name], memo_p[name]
if not h then
h, p = l_split_port(name)
memo_h[name], memo_p[name] = h, p
end

return h, p
end
end


--- parses, validates and normalizes an ipv4 address.
-- @param address the string containing the address (formats; ipv4, ipv4:port)
-- @return normalized address (string) + port (number or nil), or alternatively nil+error
Expand Down
14 changes: 2 additions & 12 deletions spec/01-unit/01-db/01-schema/05-services_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -337,28 +337,18 @@ describe("services", function()

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("invalid value: " .. invalid_hosts[i], err.host)
assert.equal("invalid hostname: " .. invalid_hosts[i], err.host)
end
end)

it("rejects values with a valid port", function()
local service = {
host = "example.com:80",
}

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("must not have a port", err.host)
end)

it("rejects values with an invalid port", function()
local service = {
host = "example.com:1000000",
}

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("must not have a port", err.host)
assert.equal("invalid port number", err.host)
end)

-- acceptance
Expand Down
18 changes: 5 additions & 13 deletions spec/01-unit/01-db/01-schema/06-routes_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -401,21 +401,10 @@ describe("routes schema", function()

local ok, err = Routes:validate(route)
assert.falsy(ok)
assert.equal("invalid value: " .. invalid_hosts[i], err.hosts[1])
assert.equal("invalid hostname: " .. invalid_hosts[i], err.hosts[1])
end
end)

it("rejects values with a valid port", function()
local route = {
hosts = { "example.com:80" },
protocols = { "http" },
}

local ok, err = Routes:validate(route)
assert.falsy(ok)
assert.equal("must not have a port", err.hosts[1])
end)

it("rejects values with an invalid port", function()
local route = {
hosts = { "example.com:1000000" },
Expand All @@ -424,7 +413,7 @@ describe("routes schema", function()

local ok, err = Routes:validate(route)
assert.falsy(ok)
assert.equal("must not have a port", err.hosts[1])
assert.equal("invalid port number", err.hosts[1])
end)

it("rejects invalid wildcard placement", function()
Expand Down Expand Up @@ -482,6 +471,8 @@ describe("routes schema", function()
"hello.abcd",
"example_api.com",
"localhost",
"example.com:80",
"example.com:8080",
-- below:
-- punycode examples from RFC3492;
-- https://tools.ietf.org/html/rfc3492#page-14
Expand Down Expand Up @@ -515,6 +506,7 @@ describe("routes schema", function()
local valid_hosts = {
"example.*",
"*.example.org",
"*.example.org:321",
}

for i = 1, #valid_hosts do
Expand Down
13 changes: 4 additions & 9 deletions spec/01-unit/01-db/01-schema/09-upstreams_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -219,33 +219,28 @@ describe("load upstreams", function()
describe("upstream attribute", function()
-- refusals
it("requires a valid hostname", function()
local ok, err = Upstreams:validate({
name = "host.test",
host_header = "ahostname:80" }
)
assert.falsy(ok)
assert.same({ host_header = "must not have a port" }, err)
local ok, err

ok, err = Upstreams:validate({
name = "host.test",
host_header = "http://ahostname.test" }
)
assert.falsy(ok)
assert.same({ host_header = "invalid value: http://ahostname.test" }, err)
assert.same({ host_header = "invalid hostname: http://ahostname.test" }, err)

ok, err = Upstreams:validate({
name = "host.test",
host_header = "ahostname-" }
)
assert.falsy(ok)
assert.same({ host_header = "invalid value: ahostname-" }, err)
assert.same({ host_header = "invalid hostname: ahostname-" }, err)

ok, err = Upstreams:validate({
name = "host.test",
host_header = "a hostname" }
)
assert.falsy(ok)
assert.same({ host_header = "invalid value: a hostname" }, err)
assert.same({ host_header = "invalid hostname: a hostname" }, err)
end)

-- acceptance
Expand Down
Loading

0 comments on commit 949d7e6

Please sign in to comment.