Skip to content

Commit

Permalink
feat(router) match routes like "host:port"
Browse files Browse the repository at this point in the history
Allow routes to specify a port in the host parameter. If present, it
matches only requests to that specific port. If absent, requests to
any port matches. Requests without explicit port are handled as if
they included the appropriate default port (80 for HTTP, 443 for
HTTPS). Routes with port have higher priority than those that those
without; even for requests without explicit port.
  • Loading branch information
javierguerragiraldez committed Oct 9, 2019
1 parent 6412a31 commit 9e1b511
Show file tree
Hide file tree
Showing 11 changed files with 308 additions and 142 deletions.
10 changes: 1 addition & 9 deletions kong/db/schema/typedefs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,7 @@ local type = type

local function validate_host(host)
local res, err_or_port = utils.normalize_ip(host)
if type(err_or_port) == "string" and err_or_port ~= "invalid port number" then
return nil, "invalid value: " .. host
end

if err_or_port == "invalid port number" or type(res.port) == "number" then
return nil, "must not have a port"
end

return true
return res and true or false, err_or_port
end


Expand Down
51 changes: 40 additions & 11 deletions kong/router.lua
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ end)
local MATCH_SUBRULES = {
HAS_REGEX_URI = 0x01,
PLAIN_HOSTS_ONLY = 0x02,
HAS_HOST_PORT = 0x04,
}

local EMPTY_T = {}
Expand Down Expand Up @@ -216,6 +217,7 @@ local function marshall_route(r)

local has_host_wildcard
local has_host_plain
local host_port

for _, host in ipairs(hosts) do
if type(host) ~= "string" then
Expand All @@ -228,6 +230,12 @@ local function marshall_route(r)

local wildcard_host_regex = host:gsub("%.", "\\.")
:gsub("%*", ".+") .. "$"

_, host_port = utils.split_port(host)
if not host_port then
wildcard_host_regex = wildcard_host_regex:gsub('%$$', [[(:\d+)?$]])
end

insert(route_t.hosts, {
wildcard = true,
value = host,
Expand Down Expand Up @@ -255,6 +263,11 @@ local function marshall_route(r)
route_t.submatch_weight = bor(route_t.submatch_weight,
MATCH_SUBRULES.PLAIN_HOSTS_ONLY)
end

if host_port then
route_t.submatch_weight = bor(route_t.submatch_weight,
MATCH_SUBRULES.HAS_HOST_PORT)
end
end


Expand Down Expand Up @@ -662,6 +675,7 @@ do
local matchers = {
[MATCH_RULES.HOST] = function(route_t, ctx)
local host = route_t.hosts[ctx.hits.host or ctx.req_host]
or route_t.hosts[utils.split_port(ctx.hits.host or ctx.req_host)]
if host then
ctx.matches.host = host
return true
Expand All @@ -671,7 +685,7 @@ do
local host_t = route_t.hosts[i]

if host_t.wildcard then
local from, _, err = re_find(ctx.req_host, host_t.regex, "ajo")
local from, _, err = re_find(ctx.host_with_port, host_t.regex, "ajo")
if err then
log(ERR, "could not evaluate wildcard host regex: ", err)
return
Expand Down Expand Up @@ -1167,7 +1181,7 @@ function _M.new(routes)

local grab_req_headers = #plain_indexes.headers > 0

local function find_route(req_method, req_uri, req_host,
local function find_route(req_method, req_uri, req_host, req_scheme,
src_ip, src_port,
dst_ip, dst_port,
sni, req_headers)
Expand All @@ -1180,6 +1194,9 @@ function _M.new(routes)
if req_host and type(req_host) ~= "string" then
error("host must be a string", 2)
end
if req_scheme and type(req_scheme) ~= "string" then
error("scheme must be a string", 2)
end
if src_ip and type(src_ip) ~= "string" then
error("src_ip must be a string", 2)
end
Expand Down Expand Up @@ -1236,13 +1253,22 @@ function _M.new(routes)

req_method = upper(req_method)

if req_host ~= "" then
-- strip port number if given because matching ignores ports
local idx = find(req_host, ":", 2, true)
if idx then
ctx.req_host = sub(req_host, 1, idx - 1)
local host_with_port = ctx.req_host
local host_no_port, req_port = utils.split_port(host_with_port)
if not req_port then
req_port = 80
if req_scheme == 'https' then
req_port = 443
end
if host_with_port:byte() ~= host_no_port:byte() or
host_no_port:find(':')
then
host_with_port = ('[%s]:%d'):format(host_no_port, req_port)
else
host_with_port = ('%s:%d'):format(host_no_port, req_port)
end
end
ctx.host_with_port = host_with_port

local hits = ctx.hits
local req_category = 0x00
Expand All @@ -1255,12 +1281,14 @@ function _M.new(routes)

-- host match

if plain_indexes.hosts[ctx.req_host] then
if plain_indexes.hosts[host_with_port] or
plain_indexes.hosts[host_no_port]
then
req_category = bor(req_category, MATCH_RULES.HOST)

elseif ctx.req_host then
for i = 1, #wildcard_hosts do
local from, _, err = re_find(ctx.req_host, wildcard_hosts[i].regex, "ajo")
local from, _, err = re_find(host_with_port, wildcard_hosts[i].regex, "ajo")
if err then
log(ERR, "could not match wildcard host: ", err)
return
Expand Down Expand Up @@ -1485,6 +1513,7 @@ function _M.new(routes)
local req_method = get_method()
local req_uri = var.request_uri
local req_host = var.http_host or ""
local req_scheme = var.scheme
local sni = var.ssl_server_name

local headers
Expand All @@ -1507,7 +1536,7 @@ function _M.new(routes)
end
end

local match_t = find_route(req_method, req_uri, req_host,
local match_t = find_route(req_method, req_uri, req_host, req_scheme,
nil, nil, -- src_ip, src_port
nil, nil, -- dst_ip, dst_port
sni, headers)
Expand Down Expand Up @@ -1550,7 +1579,7 @@ function _M.new(routes)
local dst_port = tonumber(var.server_port, 10)
local sni = var.ssl_preread_server_name

return find_route(nil, nil, nil,
return find_route(nil, nil, nil, "tcp",
src_ip, src_port,
dst_ip, dst_port,
sni)
Expand Down
41 changes: 41 additions & 0 deletions kong/tools/utils.lua
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,47 @@ _M.hostname_type = function(name)
return "name"
end

--- splits an optional ':port' section from a hostname
-- the port section must be decimal digits only.
-- brackets ('[]') are peeled off the hostname if present.
-- if there's more than one colon and no brackets, no split is possible.
-- on non-parseable input, returns name unchanged,
-- every string input produces at least one string output.
-- @param name (string) the string to split.
-- @return hostname (string)
-- @return port (as number) or nil if not found
function _M.split_port(name)
local ZERO, NINE, LEFTBRACKET, RIGHTBRACKET = ('09[]'):byte(1, -1)

if name:byte(1) == LEFTBRACKET then
if name:byte(-1) == RIGHTBRACKET then
return name:sub(2, -2)
end
local splitpos = name:find(']:', 2, true)
if splitpos then
local port = tonumber(name:sub(splitpos+2))
if port or splitpos == #name-1 then
return name:sub(2, splitpos-1), port
end
end
return name
end

local firstcolon = name:find(':', 1, true)
if not firstcolon then
return name
end

for i = firstcolon+1, #name do
local c = name:byte(i)
if c < ZERO or c > NINE then
return name
end
end
return name:sub(1, firstcolon-1), tonumber(name:sub(firstcolon+1))
end


--- parses, validates and normalizes an ipv4 address.
-- @param address the string containing the address (formats; ipv4, ipv4:port)
-- @return normalized address (string) + port (number or nil), or alternatively nil+error
Expand Down
14 changes: 2 additions & 12 deletions spec/01-unit/01-db/01-schema/05-services_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -337,28 +337,18 @@ describe("services", function()

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("invalid value: " .. invalid_hosts[i], err.host)
assert.equal("invalid hostname: " .. invalid_hosts[i], err.host)
end
end)

it("rejects values with a valid port", function()
local service = {
host = "example.com:80",
}

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("must not have a port", err.host)
end)

it("rejects values with an invalid port", function()
local service = {
host = "example.com:1000000",
}

local ok, err = Services:validate(service)
assert.falsy(ok)
assert.equal("must not have a port", err.host)
assert.equal("invalid port number", err.host)
end)

-- acceptance
Expand Down
32 changes: 12 additions & 20 deletions spec/01-unit/01-db/01-schema/06-routes_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ describe("routes schema", function()
end)
end)

describe("hosts attribute", function()
describe("hosts attribute #host", function()
-- refusals
it("must be a string", function()
local route = {
Expand All @@ -379,7 +379,7 @@ describe("routes schema", function()
assert.equal("length must be at least 1", err.hosts[1])
end)

it("rejects invalid hostnames", function()
it("rejects invalid hostnames #host", function()
local invalid_hosts = {
"/example",
".example",
Expand All @@ -401,33 +401,22 @@ describe("routes schema", function()

local ok, err = Routes:validate(route)
assert.falsy(ok)
assert.equal("invalid value: " .. invalid_hosts[i], err.hosts[1])
assert.equal("invalid hostname: " .. invalid_hosts[i], err.hosts[1])
end
end)

it("rejects values with a valid port", function()
local route = {
hosts = { "example.com:80" },
protocols = { "http" },
}

local ok, err = Routes:validate(route)
assert.falsy(ok)
assert.equal("must not have a port", err.hosts[1])
end)

it("rejects values with an invalid port", function()
it("rejects values with an invalid port #host", function()
local route = {
hosts = { "example.com:1000000" },
protocols = { "http" },
}

local ok, err = Routes:validate(route)
assert.falsy(ok)
assert.equal("must not have a port", err.hosts[1])
assert.equal("invalid port number", err.hosts[1])
end)

it("rejects invalid wildcard placement", function()
it("rejects invalid wildcard placement #host", function()
local invalid_hosts = {
"*example.com",
"www.example*",
Expand All @@ -447,7 +436,7 @@ describe("routes schema", function()
end
end)

it("rejects host with too many wildcards", function()
it("rejects #host with too many wildcards", function()
local invalid_hosts = {
"*.example.*",
"**.example.com",
Expand All @@ -468,7 +457,7 @@ describe("routes schema", function()
end)

-- acceptance
it("accepts valid hosts", function()
it("accepts valid hosts #host", function()
local valid_hosts = {
"hello.com",
"hello.fr",
Expand All @@ -482,6 +471,8 @@ describe("routes schema", function()
"hello.abcd",
"example_api.com",
"localhost",
"example.com:80",
"example.com:8080",
-- below:
-- punycode examples from RFC3492;
-- https://tools.ietf.org/html/rfc3492#page-14
Expand Down Expand Up @@ -511,10 +502,11 @@ describe("routes schema", function()
end
end)

it("accepts hosts with valid wildcard", function()
it("accepts hosts with valid wildcard #host", function()
local valid_hosts = {
"example.*",
"*.example.org",
"*.example.org:321",
}

for i = 1, #valid_hosts do
Expand Down
13 changes: 4 additions & 9 deletions spec/01-unit/01-db/01-schema/09-upstreams_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -219,33 +219,28 @@ describe("load upstreams", function()
describe("upstream attribute", function()
-- refusals
it("requires a valid hostname", function()
local ok, err = Upstreams:validate({
name = "host.test",
host_header = "ahostname:80" }
)
assert.falsy(ok)
assert.same({ host_header = "must not have a port" }, err)
local ok, err

ok, err = Upstreams:validate({
name = "host.test",
host_header = "http://ahostname.test" }
)
assert.falsy(ok)
assert.same({ host_header = "invalid value: http://ahostname.test" }, err)
assert.same({ host_header = "invalid hostname: http://ahostname.test" }, err)

ok, err = Upstreams:validate({
name = "host.test",
host_header = "ahostname-" }
)
assert.falsy(ok)
assert.same({ host_header = "invalid value: ahostname-" }, err)
assert.same({ host_header = "invalid hostname: ahostname-" }, err)

ok, err = Upstreams:validate({
name = "host.test",
host_header = "a hostname" }
)
assert.falsy(ok)
assert.same({ host_header = "invalid value: a hostname" }, err)
assert.same({ host_header = "invalid hostname: a hostname" }, err)
end)

-- acceptance
Expand Down
Loading

0 comments on commit 9e1b511

Please sign in to comment.