diff --git a/kong/db/schema/typedefs.lua b/kong/db/schema/typedefs.lua index 24efd46e1d2f..cc46cdd81af9 100644 --- a/kong/db/schema/typedefs.lua +++ b/kong/db/schema/typedefs.lua @@ -20,15 +20,7 @@ local type = type local function validate_host(host) local res, err_or_port = utils.normalize_ip(host) - if type(err_or_port) == "string" and err_or_port ~= "invalid port number" then - return nil, "invalid value: " .. host - end - - if err_or_port == "invalid port number" or type(res.port) == "number" then - return nil, "must not have a port" - end - - return true + return res and true or nil, err_or_port end diff --git a/kong/router.lua b/kong/router.lua index e722d0cd6297..0d95a7ddf01b 100644 --- a/kong/router.lua +++ b/kong/router.lua @@ -5,6 +5,7 @@ local bit = require "bit" local hostname_type = utils.hostname_type +local split_port = utils.split_port local subsystem = ngx.config.subsystem local get_method = ngx.req.get_method local get_headers = ngx.req.get_headers @@ -15,9 +16,11 @@ local var = ngx.var local ngx_log = ngx.log local insert = table.insert local sort = table.sort +local byte = string.byte local upper = string.upper local lower = string.lower local find = string.find +local format = string.format local sub = string.sub local tonumber = tonumber local ipairs = ipairs @@ -88,6 +91,7 @@ end) local MATCH_SUBRULES = { HAS_REGEX_URI = 0x01, PLAIN_HOSTS_ONLY = 0x02, + HAS_HOST_PORT = 0x04, } local EMPTY_T = {} @@ -216,6 +220,7 @@ local function marshall_route(r) local has_host_wildcard local has_host_plain + local host_port for _, host in ipairs(hosts) do if type(host) ~= "string" then @@ -228,6 +233,12 @@ local function marshall_route(r) local wildcard_host_regex = host:gsub("%.", "\\.") :gsub("%*", ".+") .. "$" + + _, host_port = utils.split_port(host) + if not host_port then + wildcard_host_regex = wildcard_host_regex:gsub("%$$", [[(?::\d+)?$]]) + end + insert(route_t.hosts, { wildcard = true, value = host, @@ -255,6 +266,11 @@ local function marshall_route(r) route_t.submatch_weight = bor(route_t.submatch_weight, MATCH_SUBRULES.PLAIN_HOSTS_ONLY) end + + if host_port then + route_t.submatch_weight = bor(route_t.submatch_weight, + MATCH_SUBRULES.HAS_HOST_PORT) + end end @@ -661,7 +677,9 @@ end do local matchers = { [MATCH_RULES.HOST] = function(route_t, ctx) - local host = route_t.hosts[ctx.hits.host or ctx.req_host] + local req_host = ctx.hits.host or ctx.req_host + local host = route_t.hosts[req_host] + or route_t.hosts[split_port(req_host)] if host then ctx.matches.host = host return true @@ -671,7 +689,7 @@ do local host_t = route_t.hosts[i] if host_t.wildcard then - local from, _, err = re_find(ctx.req_host, host_t.regex, "ajo") + local from, _, err = re_find(ctx.host_with_port, host_t.regex, "ajo") if err then log(ERR, "could not evaluate wildcard host regex: ", err) return @@ -1167,7 +1185,7 @@ function _M.new(routes) local grab_req_headers = #plain_indexes.headers > 0 - local function find_route(req_method, req_uri, req_host, + local function find_route(req_method, req_uri, req_host, req_scheme, src_ip, src_port, dst_ip, dst_port, sni, req_headers) @@ -1180,6 +1198,9 @@ function _M.new(routes) if req_host and type(req_host) ~= "string" then error("host must be a string", 2) end + if req_scheme and type(req_scheme) ~= "string" then + error("scheme must be a string", 2) + end if src_ip and type(src_ip) ~= "string" then error("src_ip must be a string", 2) end @@ -1236,13 +1257,29 @@ function _M.new(routes) req_method = upper(req_method) - if req_host ~= "" then - -- strip port number if given because matching ignores ports - local idx = find(req_host, ":", 2, true) - if idx then - ctx.req_host = sub(req_host, 1, idx - 1) + -- req_host might have port or maybe not, host_no_port definitely doesn't + -- if there wasn't a port, req_port is assumed to be the default port + -- according the protocol scheme + local host_with_port = req_host + local host_no_port, req_port = utils.split_port(host_with_port) + if not req_port then + req_port = 80 + if req_scheme == 'https' then + req_port = 443 + end + + -- a literal IPv6 address needs the format [host]:port, + -- the fastest way to check is if the first byte was stripped by split_port() + -- otherwise, check for a colon + if byte(host_with_port) ~= byte(host_no_port) + or find(host_no_port, ":", 1, true) + then + host_with_port = format("[%s]:%d", host_no_port, req_port) + else + host_with_port = format("%s:%d", host_no_port, req_port) end end + ctx.host_with_port = host_with_port local hits = ctx.hits local req_category = 0x00 @@ -1255,12 +1292,14 @@ function _M.new(routes) -- host match - if plain_indexes.hosts[ctx.req_host] then + if plain_indexes.hosts[host_with_port] + or plain_indexes.hosts[host_no_port] + then req_category = bor(req_category, MATCH_RULES.HOST) elseif ctx.req_host then for i = 1, #wildcard_hosts do - local from, _, err = re_find(ctx.req_host, wildcard_hosts[i].regex, "ajo") + local from, _, err = re_find(host_with_port, wildcard_hosts[i].regex, "ajo") if err then log(ERR, "could not match wildcard host: ", err) return @@ -1485,6 +1524,7 @@ function _M.new(routes) local req_method = get_method() local req_uri = var.request_uri local req_host = var.http_host or "" + local req_scheme = var.scheme local sni = var.ssl_server_name local headers @@ -1507,7 +1547,7 @@ function _M.new(routes) end end - local match_t = find_route(req_method, req_uri, req_host, + local match_t = find_route(req_method, req_uri, req_host, req_scheme, nil, nil, -- src_ip, src_port nil, nil, -- dst_ip, dst_port sni, headers) @@ -1550,7 +1590,7 @@ function _M.new(routes) local dst_port = tonumber(var.server_port, 10) local sni = var.ssl_preread_server_name - return find_route(nil, nil, nil, + return find_route(nil, nil, nil, "tcp", src_ip, src_port, dst_ip, dst_port, sni) diff --git a/kong/tools/utils.lua b/kong/tools/utils.lua index 0775a036ea31..e814ea259820 100644 --- a/kong/tools/utils.lua +++ b/kong/tools/utils.lua @@ -24,10 +24,12 @@ local tostring = tostring local sort = table.sort local concat = table.concat local insert = table.insert +local byte = string.byte local lower = string.lower local fmt = string.format local find = string.find local gsub = string.gsub +local sub = string.sub local split = pl_stringx.split local re_find = ngx.re.find local re_match = ngx.re.match @@ -661,6 +663,63 @@ _M.hostname_type = function(name) return "name" end +do + local ZERO, NINE, LEFTBRACKET, RIGHTBRACKET = ("09[]"):byte(1, -1) + + --- splits an optional ':port' section from a hostname + -- the port section must be decimal digits only. + -- brackets ('[]') are peeled off the hostname if present. + -- if there's more than one colon and no brackets, no split is possible. + -- on non-parseable input, returns name unchanged, + -- every string input produces at least one string output. + -- @param name (string) the string to split. + -- @treturn string hostname + -- @treturn number port or nil if not found + local function l_split_port(name) + if byte(name, 1) == LEFTBRACKET then + if byte(name, -1) == RIGHTBRACKET then + return sub(name, 2, -2) + end + local splitpos = find(name, "]:", 2, true) + if splitpos then + local port = tonumber(sub(name, splitpos+2)) + if port or splitpos == #name - 1 then + return sub(name, 2, splitpos - 1), port + end + end + return name + end + + local firstcolon = find(name, ":", 1, true) + if not firstcolon then + return name + end + + for i = firstcolon + 1, #name do + local c = byte(name, i) + if c < ZERO or c > NINE then + return name + end + end + return sub(name, 1, firstcolon - 1), tonumber(name:sub(firstcolon+1)) + end + + -- split_port is a pure function, so we can memoize it. + local memo_h = setmetatable({}, {__mode = 'k'}) + local memo_p = setmetatable({}, {__mode = 'k'}) + + function _M.split_port(name) + local h, p = memo_h[name], memo_p[name] + if not h then + h, p = l_split_port(name) + memo_h[name], memo_p[name] = h, p + end + + return h, p + end +end + + --- parses, validates and normalizes an ipv4 address. -- @param address the string containing the address (formats; ipv4, ipv4:port) -- @return normalized address (string) + port (number or nil), or alternatively nil+error diff --git a/spec/01-unit/01-db/01-schema/05-services_spec.lua b/spec/01-unit/01-db/01-schema/05-services_spec.lua index 5418762745fe..f211195c1a24 100644 --- a/spec/01-unit/01-db/01-schema/05-services_spec.lua +++ b/spec/01-unit/01-db/01-schema/05-services_spec.lua @@ -337,20 +337,10 @@ describe("services", function() local ok, err = Services:validate(service) assert.falsy(ok) - assert.equal("invalid value: " .. invalid_hosts[i], err.host) + assert.equal("invalid hostname: " .. invalid_hosts[i], err.host) end end) - it("rejects values with a valid port", function() - local service = { - host = "example.com:80", - } - - local ok, err = Services:validate(service) - assert.falsy(ok) - assert.equal("must not have a port", err.host) - end) - it("rejects values with an invalid port", function() local service = { host = "example.com:1000000", @@ -358,7 +348,7 @@ describe("services", function() local ok, err = Services:validate(service) assert.falsy(ok) - assert.equal("must not have a port", err.host) + assert.equal("invalid port number", err.host) end) -- acceptance diff --git a/spec/01-unit/01-db/01-schema/06-routes_spec.lua b/spec/01-unit/01-db/01-schema/06-routes_spec.lua index 966499b3a0b0..5a693fc67420 100644 --- a/spec/01-unit/01-db/01-schema/06-routes_spec.lua +++ b/spec/01-unit/01-db/01-schema/06-routes_spec.lua @@ -401,21 +401,10 @@ describe("routes schema", function() local ok, err = Routes:validate(route) assert.falsy(ok) - assert.equal("invalid value: " .. invalid_hosts[i], err.hosts[1]) + assert.equal("invalid hostname: " .. invalid_hosts[i], err.hosts[1]) end end) - it("rejects values with a valid port", function() - local route = { - hosts = { "example.com:80" }, - protocols = { "http" }, - } - - local ok, err = Routes:validate(route) - assert.falsy(ok) - assert.equal("must not have a port", err.hosts[1]) - end) - it("rejects values with an invalid port", function() local route = { hosts = { "example.com:1000000" }, @@ -424,7 +413,7 @@ describe("routes schema", function() local ok, err = Routes:validate(route) assert.falsy(ok) - assert.equal("must not have a port", err.hosts[1]) + assert.equal("invalid port number", err.hosts[1]) end) it("rejects invalid wildcard placement", function() @@ -482,6 +471,8 @@ describe("routes schema", function() "hello.abcd", "example_api.com", "localhost", + "example.com:80", + "example.com:8080", -- below: -- punycode examples from RFC3492; -- https://tools.ietf.org/html/rfc3492#page-14 @@ -515,6 +506,7 @@ describe("routes schema", function() local valid_hosts = { "example.*", "*.example.org", + "*.example.org:321", } for i = 1, #valid_hosts do diff --git a/spec/01-unit/01-db/01-schema/09-upstreams_spec.lua b/spec/01-unit/01-db/01-schema/09-upstreams_spec.lua index 3d04030d1d08..b7e31a3941b3 100644 --- a/spec/01-unit/01-db/01-schema/09-upstreams_spec.lua +++ b/spec/01-unit/01-db/01-schema/09-upstreams_spec.lua @@ -219,33 +219,28 @@ describe("load upstreams", function() describe("upstream attribute", function() -- refusals it("requires a valid hostname", function() - local ok, err = Upstreams:validate({ - name = "host.test", - host_header = "ahostname:80" } - ) - assert.falsy(ok) - assert.same({ host_header = "must not have a port" }, err) + local ok, err ok, err = Upstreams:validate({ name = "host.test", host_header = "http://ahostname.test" } ) assert.falsy(ok) - assert.same({ host_header = "invalid value: http://ahostname.test" }, err) + assert.same({ host_header = "invalid hostname: http://ahostname.test" }, err) ok, err = Upstreams:validate({ name = "host.test", host_header = "ahostname-" } ) assert.falsy(ok) - assert.same({ host_header = "invalid value: ahostname-" }, err) + assert.same({ host_header = "invalid hostname: ahostname-" }, err) ok, err = Upstreams:validate({ name = "host.test", host_header = "a hostname" } ) assert.falsy(ok) - assert.same({ host_header = "invalid value: a hostname" }, err) + assert.same({ host_header = "invalid hostname: a hostname" }, err) end) -- acceptance diff --git a/spec/01-unit/05-utils_spec.lua b/spec/01-unit/05-utils_spec.lua index 105242cf4cb6..e4db0c2614ad 100644 --- a/spec/01-unit/05-utils_spec.lua +++ b/spec/01-unit/05-utils_spec.lua @@ -437,6 +437,24 @@ describe("Utils", function() end) describe("hostnames and ip addresses", function() + it("splits port number", function() + for _, case in ipairs({ + { '', { '', nil } }, + { 'localhost', { 'localhost', nil } }, + { 'localhost:', { 'localhost', nil } }, + { 'localhost:80', { 'localhost', 80 } }, + { 'localhost:23h', { 'localhost:23h', nil } }, + { 'localhost/24', { 'localhost/24', nil } }, + { '::1', { '::1', nil } }, + { '[::1]', { '::1', nil } }, + { '[::1]:', { '::1', nil } }, + { '[::1]:80', { '::1', 80 } }, + { '[::1]:80b', { '[::1]:80b', nil } }, + { '[::1]/96', { '[::1]/96', nil } }, + }) do + assert.same(case[2], { utils.split_port(case[1]) }) + end + end) describe("hostname_type", function() -- no check on "name" type as anything not ipv4 and not ipv6 will be labelled as 'name' anyway it("checks valid IPv4 address types", function() diff --git a/spec/01-unit/08-router_spec.lua b/spec/01-unit/08-router_spec.lua index a042c7c78b2f..d208e25c9f03 100644 --- a/spec/01-unit/08-router_spec.lua +++ b/spec/01-unit/08-router_spec.lua @@ -185,6 +185,25 @@ local use_case = { }, }, }, + -- 13. host + port + { + service = service, + route = { + hosts = { + "domain-1.org:321", + "domain-2.org" + }, + }, + }, + -- 14. no "any-port" route + { + service = service, + route = { + hosts = { + "domain-3.org:321", + }, + }, + }, } describe("Router", function() @@ -227,8 +246,18 @@ describe("Router", function() assert.same(nil, match_t.matches.uri_captures) end) - it("[host] ignores port", function() + it("[host] ignores default port", function() -- host + local match_t = router.select("GET", "/", "domain-1.org:80") + assert.truthy(match_t) + assert.equal(use_case[1].route, match_t.route) + assert.same(use_case[1].route.hosts[1], match_t.matches.host) + assert.same(nil, match_t.matches.method) + assert.same(nil, match_t.matches.uri) + assert.same(nil, match_t.matches.uri_captures) + end) + + it("[host] weird port matches no-port route", function() local match_t = router.select("GET", "/", "domain-1.org:123") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) @@ -238,6 +267,34 @@ describe("Router", function() assert.same(nil, match_t.matches.uri_captures) end) + it("[host] matches specific port", function() + -- host + local match_t = router.select("GET", "/", "domain-1.org:321") + assert.truthy(match_t) + assert.equal(use_case[13].route, match_t.route) + assert.same(use_case[13].route.hosts[1], match_t.matches.host) + assert.same(nil, match_t.matches.method) + assert.same(nil, match_t.matches.uri) + assert.same(nil, match_t.matches.uri_captures) + end) + + it("[host] matches specific port on port-only route", function() + -- host + local match_t = router.select("GET", "/", "domain-3.org:321") + assert.truthy(match_t) + assert.equal(use_case[14].route, match_t.route) + assert.same(use_case[14].route.hosts[1], match_t.matches.host) + assert.same(nil, match_t.matches.method) + assert.same(nil, match_t.matches.uri) + assert.same(nil, match_t.matches.uri_captures) + end) + + it("[host] fails just because of port on port-only route", function() + -- host + local match_t = router.select("GET", "/", "domain-3.org:123") + assert.falsy(match_t) + end) + it("[uri]", function() -- uri local match_t = router.select("GET", "/my-route", "domain.org") @@ -319,7 +376,7 @@ describe("Router", function() it("single [headers] value", function() -- headers (single) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = "my-location-1" }) assert.truthy(match_t) @@ -329,7 +386,7 @@ describe("Router", function() assert.same(nil, match_t.matches.uri_captures) assert.same({ location = "my-location-1" }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = "my-location-2" }) assert.truthy(match_t) @@ -339,7 +396,7 @@ describe("Router", function() assert.same(nil, match_t.matches.uri_captures) assert.same({ location = "my-location-2" }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = { "my-location-3", "my-location-2" } }) assert.truthy(match_t) @@ -349,12 +406,12 @@ describe("Router", function() assert.same(nil, match_t.matches.uri_captures) assert.same({ location = "my-location-2" }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = "my-location-3" }) assert.is_nil(match_t) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = { "my-location-3", "foo" } }) assert.is_nil(match_t) @@ -362,7 +419,7 @@ describe("Router", function() it("multiple [headers] values", function() -- headers (multiple) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = "my-location-1", version = "v1", }) @@ -374,7 +431,7 @@ describe("Router", function() assert.same({ location = "my-location-1", version = "v1", }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = "my-location-1", version = "v2", }) @@ -386,7 +443,7 @@ describe("Router", function() assert.same({ location = "my-location-1", version = "v2", }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = { "my-location-3", "my-location-1" }, version = "v2", }) @@ -398,7 +455,7 @@ describe("Router", function() assert.same({ location = "my-location-1", version = "v2", }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, { + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = { "my-location-3", "my-location-2" }, version = "v2", }) @@ -413,7 +470,7 @@ describe("Router", function() it("[headers + uri]", function() -- headers + uri - local match_t = router.select("GET", "/headers-uri", nil, nil, nil, nil, + local match_t = router.select("GET", "/headers-uri", nil, "http", nil, nil, nil, nil, nil, { location = "my-location-2" }) assert.truthy(match_t) assert.same(use_case[11].route, match_t.route) @@ -426,7 +483,7 @@ describe("Router", function() it("[host + headers + uri + method]", function() -- host + headers + uri + method local match_t = router.select("PUT", "/headers-host-uri-method", - "domain-with-headers-1.org", + "domain-with-headers-1.org", "http", nil, nil, nil, nil, nil, { location = "my-location-2", }) @@ -449,6 +506,81 @@ describe("Router", function() assert.same(use_case[8].route.paths[1], match_t.matches.uri) end) + describe("[IPv6 #literal host]", function() + local use_case = { + -- 1: no port, with and without brackets, unique IPs + { + service = service, + route = { + hosts = { "::11", "[::12]" }, + }, + }, + + -- 2: no port, with and without brackets, same hosts as 4 + { + service = service, + route = { + hosts = { "::21", "[::22]" }, + }, + }, + + -- 3: unique IPs, with port + { + service = service, + route = { + hosts = { "[::31]:321", "[::32]:321" }, + }, + }, + + -- 4: same hosts as 2, with port, needs brackets + { + service = service, + route = { + hosts = { "[::21]:321", "[::22]:321" }, + }, + }, + } + local router = assert(Router.new(use_case)) + + describe("no-port route is any-port", function() + describe("no-port request", function() + it("plain match", function() + local match_t = assert(router.select("GET", "/", "::11")) + assert.equal(use_case[1].route, match_t.route) + end) + it("with brackets", function() + local match_t = assert(router.select("GET", "/", "[::11]")) + assert.equal(use_case[1].route, match_t.route) + end) + end) + + it("explicit port still matches", function() + local match_t = assert(router.select("GET", "/", "[::11]:654")) + assert.equal(use_case[1].route, match_t.route) + end) + end) + + describe("port-specific route", function() + it("matches by port", function() + local match_t = assert(router.select("GET", "/", "[::21]:321")) + assert.equal(use_case[4].route, match_t.route) + + local match_t = assert(router.select("GET", "/", "[::31]:321")) + assert.equal(use_case[3].route, match_t.route) + end) + + it("matches other ports to any-port fallback", function() + local match_t = assert(router.select("GET", "/", "[::21]:654")) + assert.equal(use_case[2].route, match_t.route) + end) + + it("fails if there's no any-port route", function() + local match_t = router.select("GET", "/", "[::31]:654") + assert.falsy(match_t) + end) + end) + end) + describe("[uri prefix]", function() it("matches when given [uri] is in request URI prefix", function() -- uri prefix @@ -777,6 +909,105 @@ describe("Router", function() assert.equal(use_case[2].route, match_t.route) end) + it("matches any port in request", function() + local match_t = router.select("GET", "/", "route.org:123") + assert.truthy(match_t) + assert.equal(use_case[2].route, match_t.route) + + local match_t = router.select("GET", "/", "foo.route.com:123", "domain.org") + assert.truthy(match_t) + assert.equal(use_case[1].route, match_t.route) + end) + + it("matches port-specific routes", function() + table.insert(use_case, { + service = service, + route = { + hosts = { "*.route.net:123" }, + }, + }) + table.insert(use_case, { + service = service, + route = { + hosts = { "route.*:123" }, -- same as [2] but port-specific + }, + }) + router = assert(Router.new(use_case)) + + finally(function() + table.remove(use_case) + table.remove(use_case) + router = assert(Router.new(use_case)) + end) + + -- match the right port + local match_t = router.select("GET", "/", "foo.route.net:123") + assert.truthy(match_t) + assert.equal(use_case[3].route, match_t.route) + + -- fail different port + assert.is_nil(router.select("GET", "/", "foo.route.net:456")) + + -- port-specific is higher priority + local match_t = router.select("GET", "/", "route.org:123") + assert.truthy(match_t) + assert.equal(use_case[4].route, match_t.route) + end) + + it("prefers port-specific even for http default port", function() + table.insert(use_case, { + service = service, + route = { + hosts = { "route.*:80" }, -- same as [2] but port-specific + }, + }) + router = assert(Router.new(use_case)) + + finally(function() + table.remove(use_case) + router = assert(Router.new(use_case)) + end) + + -- non-port matches any + local match_t = assert(router.select("GET", "/", "route.org:123")) + assert.equal(use_case[2].route, match_t.route) + + -- port 80 goes to port-specific route + local match_t = assert(router.select("GET", "/", "route.org:80")) + assert.equal(use_case[3].route, match_t.route) + + -- even if it's implicit port 80 + local match_t = assert(router.select("GET", "/", "route.org")) + assert.equal(use_case[3].route, match_t.route) + end) + + it("prefers port-specific even for https default port", function() + table.insert(use_case, { + service = service, + route = { + hosts = { "route.*:443" }, -- same as [2] but port-specific + }, + }) + router = assert(Router.new(use_case)) + + finally(function() + table.remove(use_case) + router = assert(Router.new(use_case)) + end) + + -- non-port matches any + local match_t = assert(router.select("GET", "/", "route.org:123")) + assert.equal(use_case[2].route, match_t.route) + + -- port 80 goes to port-specific route + local match_t = assert(router.select("GET", "/", "route.org:443")) + assert.equal(use_case[3].route, match_t.route) + + -- even if it's implicit port 80 + local match_t = assert(router.select("GET", "/", "route.org", "https")) + assert.equal(use_case[3].route, match_t.route) + end) + it("does not take precedence over a plain host", function() table.insert(use_case, 1, { service = service, @@ -982,7 +1213,7 @@ describe("Router", function() local router = assert(Router.new(use_case)) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { version = "v1", user_agent = "foo", @@ -1014,7 +1245,7 @@ describe("Router", function() local router = assert(Router.new(use_case)) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, setmetatable({ user_agent = "foo", }, headers_mt)) @@ -1022,7 +1253,7 @@ describe("Router", function() assert.equal(use_case[1].route, match_t.route) assert.same({ user_agent = "foo" }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, setmetatable({ ["USER_AGENT"] = "baz", }, headers_mt)) @@ -1053,7 +1284,7 @@ describe("Router", function() local router = assert(Router.new(use_case)) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { user_agent = "FOO", }) @@ -1061,7 +1292,7 @@ describe("Router", function() assert.equal(use_case[1].route, match_t.route) assert.same({ user_agent = "foo" }, match_t.matches.headers) - local match_t = router.select("GET", "/", nil, nil, nil, nil, nil, nil, + local match_t = router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { user_agent = "baz", }) @@ -1073,7 +1304,97 @@ describe("Router", function() describe("edge-cases", function() it("[host] and [uri] have higher priority than [method]", function() - -- host + local use_case = { + -- 1. host + { + service = service, + route = { + hosts = { + "domain-1.org", + "domain-2.org" + }, + }, + }, + -- 2. method + { + service = service, + route = { + methods = { + "TRACE" + }, + } + }, + -- 3. uri + { + service = service, + route = { + paths = { + "/my-route" + }, + } + }, + -- 4. host + uri + { + service = service, + route = { + paths = { + "/route-4" + }, + hosts = { + "domain-1.org", + "domain-2.org" + }, + }, + }, + -- 5. host + method + { + service = service, + route = { + hosts = { + "domain-1.org", + "domain-2.org" + }, + methods = { + "POST", + "PUT", + "PATCH" + }, + }, + }, + -- 6. uri + method + { + service = service, + route = { + methods = { + "POST", + "PUT", + "PATCH", + }, + paths = { + "/route-6" + }, + } + }, + -- 7. host + uri + method + { + service = service, + route = { + hosts = { + "domain-with-uri-1.org", + "domain-with-uri-2.org" + }, + methods = { + "POST", + "PUT", + "PATCH", + }, + paths = { + "/my-route-uri" + }, + }, + }, + } + local router = assert(Router.new(use_case)) local match_t = router.select("TRACE", "/", "domain-2.org") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) @@ -1350,7 +1671,7 @@ describe("Router", function() local router = assert(Router.new(use_case)) - local match_t = router.select("GET", "/my-route/hello", "domain.org", + local match_t = router.select("GET", "/my-route/hello", "domain.org", "http", nil, nil, nil, nil, nil, { version = "v1", location = "us-east", @@ -1361,7 +1682,7 @@ describe("Router", function() assert.same({ version = "v1", location = "us-east" }, match_t.matches.headers) - local match_t = router.select("GET", "/my-route/hello/world", + local match_t = router.select("GET", "/my-route/hello/world", "http", "domain.org", nil, nil, nil, nil, nil, { version = "v1", location = "us-east", @@ -1403,25 +1724,25 @@ describe("Router", function() end) it("invalid [headers]", function() - assert.is_nil(router.select("GET", "/", nil, nil, nil, nil, nil, nil, + assert.is_nil(router.select("GET", "/", nil, "http", nil, nil, nil, nil, nil, { location = "invalid-location" })) end) it("invalid headers in [headers + uri]", function() assert.is_nil(router.select("GET", "/headers-uri", - nil, nil, nil, nil, nil, nil, + nil, "http", nil, nil, nil, nil, nil, { location = "invalid-location" })) end) it("invalid headers in [headers + uri + method]", function() assert.is_nil(router.select("PUT", "/headers-uri-method", - nil, nil, nil, nil, nil, nil, + nil, "http", nil, nil, nil, nil, nil, { location = "invalid-location" })) end) it("invalid headers in [headers + host + uri + method]", function() assert.is_nil(router.select("PUT", "/headers-host-uri-method", - nil, nil, nil, nil, nil, nil, + nil, "http", nil, nil, nil, nil, nil, { location = "invalid-location", host = "domain-with-headers-1.org" })) end) @@ -1534,7 +1855,7 @@ describe("Router", function() it("takes < 1ms", function() local match_t = router.select("GET", "/", - nil, nil, nil, nil, nil, nil, + nil, "http", nil, nil, nil, nil, nil, { location = target_location }) assert.truthy(match_t) assert.same(benchmark_use_cases[#benchmark_use_cases].route, @@ -1569,7 +1890,7 @@ describe("Router", function() it("takes < 1ms", function() local match_t = router.select("GET", "/", - nil, nil, nil, nil, nil, nil, + nil, "http", nil, nil, nil, nil, nil, { [target_key] = target_val }) assert.truthy(match_t) assert.same(benchmark_use_cases[#benchmark_use_cases].route, @@ -1666,7 +1987,7 @@ describe("Router", function() end) it("takes < 1ms", function() - local match_t = router.select("POST", target_uri, target_domain, + local match_t = router.select("POST", target_uri, target_domain, "http", nil, nil, nil, nil, nil, { location = target_location, }) @@ -1693,26 +2014,30 @@ describe("Router", function() assert.error_matches(function() router.select("GET", "/", "", 1) + end, "scheme must be a string", nil, true) + + assert.error_matches(function() + router.select("GET", "/", "", "http", 1) end, "src_ip must be a string", nil, true) assert.error_matches(function() - router.select("GET", "/", "", nil, "") + router.select("GET", "/", "", "http", nil, "") end, "src_port must be a number", nil, true) assert.error_matches(function() - router.select("GET", "/", "", nil, nil, 1) + router.select("GET", "/", "", "http", nil, nil, 1) end, "dst_ip must be a string", nil, true) assert.error_matches(function() - router.select("GET", "/", "", nil, nil, nil, "") + router.select("GET", "/", "", "http", nil, nil, nil, "") end, "dst_port must be a number", nil, true) assert.error_matches(function() - router.select("GET", "/", "", nil, nil, nil, nil, 1) + router.select("GET", "/", "", "http", nil, nil, nil, nil, 1) end, "sni must be a string", nil, true) assert.error_matches(function() - router.select("GET", "/", "", nil, nil, nil, nil, nil, 1) + router.select("GET", "/", "", "http", nil, nil, nil, nil, nil, 1) end, "headers must be a table", nil, true) end) end) @@ -2641,44 +2966,44 @@ describe("Router", function() local router = assert(Router.new(use_case)) it("[src_ip]", function() - local match_t = router.select(nil, nil, nil, "127.0.0.1") + local match_t = router.select(nil, nil, nil, "tcp", "127.0.0.1") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) - match_t = router.select(nil, nil, nil, "127.0.0.1") + match_t = router.select(nil, nil, nil, "tcp", "127.0.0.1") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) end) it("[src_port]", function() - local match_t = router.select(nil, nil, nil, "127.0.0.3", 65001) + local match_t = router.select(nil, nil, nil, "tcp", "127.0.0.3", 65001) assert.truthy(match_t) assert.equal(use_case[2].route, match_t.route) end) it("[src_ip] range match", function() - local match_t = router.select(nil, nil, nil, "127.168.0.1") + local match_t = router.select(nil, nil, nil, "tcp", "127.168.0.1") assert.truthy(match_t) assert.equal(use_case[3].route, match_t.route) end) it("[src_ip] + [src_port]", function() - local match_t = router.select(nil, nil, nil, "127.0.0.1", 65001) + local match_t = router.select(nil, nil, nil, "tcp", "127.0.0.1", 65001) assert.truthy(match_t) assert.equal(use_case[4].route, match_t.route) end) it("[src_ip] range match + [src_port]", function() - local match_t = router.select(nil, nil, nil, "127.168.10.1", 65301) + local match_t = router.select(nil, nil, nil, "tcp", "127.168.10.1", 65301) assert.truthy(match_t) assert.equal(use_case[5].route, match_t.route) end) it("[src_ip] no match", function() - local match_t = router.select(nil, nil, nil, "10.0.0.1") + local match_t = router.select(nil, nil, nil, "tcp", "10.0.0.1") assert.falsy(match_t) - match_t = router.select(nil, nil, nil, "10.0.0.2", 65301) + match_t = router.select(nil, nil, nil, "tcp", "10.0.0.2", 65301) assert.falsy(match_t) end) end) @@ -2737,51 +3062,51 @@ describe("Router", function() local router = assert(Router.new(use_case)) it("[dst_ip]", function() - local match_t = router.select(nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, "127.0.0.1") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) - match_t = router.select(nil, nil, nil, nil, nil, + match_t = router.select(nil, nil, nil, "tcp", nil, nil, "127.0.0.1") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) end) it("[dst_port]", function() - local match_t = router.select(nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, "127.0.0.3", 65001) assert.truthy(match_t) assert.equal(use_case[2].route, match_t.route) end) it("[dst_ip] range match", function() - local match_t = router.select(nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, "127.168.0.1") assert.truthy(match_t) assert.equal(use_case[3].route, match_t.route) end) it("[dst_ip] + [dst_port]", function() - local match_t = router.select(nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, "127.0.0.1", 65001) assert.truthy(match_t) assert.equal(use_case[4].route, match_t.route) end) it("[dst_ip] range match + [dst_port]", function() - local match_t = router.select(nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, "127.168.10.1", 65301) assert.truthy(match_t) assert.equal(use_case[5].route, match_t.route) end) it("[dst_ip] no match", function() - local match_t = router.select(nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, "10.0.0.1") assert.falsy(match_t) - match_t = router.select(nil, nil, nil, nil, nil, + match_t = router.select(nil, nil, nil, "tcp", nil, nil, "10.0.0.2", 65301) assert.falsy(match_t) end) @@ -2801,7 +3126,7 @@ describe("Router", function() local router = assert(Router.new(use_case)) it("[sni]", function() - local match_t = router.select(nil, nil, nil, nil, nil, nil, nil, + local match_t = router.select(nil, nil, nil, "tcp", nil, nil, nil, nil, "www.example.org") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) @@ -2837,12 +3162,12 @@ describe("Router", function() local router = assert(Router.new(use_case)) - local match_t = router.select(nil, nil, nil, "127.0.0.1", nil, + local match_t = router.select(nil, nil, nil, "tcp", "127.0.0.1", nil, nil, nil, "www.example.org") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) - match_t = router.select(nil, nil, nil, nil, nil, + match_t = router.select(nil, nil, nil, "tcp", nil, nil, "172.168.0.1", nil, "www.example.org") assert.truthy(match_t) assert.equal(use_case[1].route, match_t.route) @@ -2871,7 +3196,7 @@ describe("Router", function() local router = assert(Router.new(use_case)) - local match_t = router.select(nil, nil, nil, "127.0.0.1", nil, + local match_t = router.select(nil, nil, nil, "tcp", "127.0.0.1", nil, "172.168.0.1", nil, "www.example.org") assert.truthy(match_t) assert.equal(use_case[2].route, match_t.route) diff --git a/spec/02-integration/02-cmd/02-start_stop_spec.lua b/spec/02-integration/02-cmd/02-start_stop_spec.lua index e2215afeac27..f6e6b4657019 100644 --- a/spec/02-integration/02-cmd/02-start_stop_spec.lua +++ b/spec/02-integration/02-cmd/02-start_stop_spec.lua @@ -438,7 +438,7 @@ describe("kong start/stop #" .. strategy, function() in 'routes': - in entry 1 of 'routes': in 'hosts': - - in entry 2 of 'hosts': invalid value: \\99 + - in entry 2 of 'hosts': invalid hostname: \\99 ]], err, nil, true) end) end diff --git a/spec/02-integration/05-proxy/02-router_spec.lua b/spec/02-integration/05-proxy/02-router_spec.lua index 480ca50d86a2..10b7df1a2eec 100644 --- a/spec/02-integration/05-proxy/02-router_spec.lua +++ b/spec/02-integration/05-proxy/02-router_spec.lua @@ -352,12 +352,20 @@ for _, strategy in helpers.each_strategy() do routes = insert_routes(bp, { { protocols = { "grpc", "grpcs" }, - hosts = { "grpc1" }, + hosts = { + "grpc1", + "grpc1:" .. helpers.get_proxy_port(false, true), + "grpc1:"..helpers.get_proxy_port(true, true), + }, service = service, }, { protocols = { "grpc", "grpcs" }, - hosts = { "grpc2" }, + hosts = { + "grpc2", + "grpc2:" .. helpers.get_proxy_port(false, true), + "grpc2:"..helpers.get_proxy_port(true, true), + }, service = service, }, { @@ -865,7 +873,7 @@ for _, strategy in helpers.each_strategy() do routes = insert_routes(bp, { { preserve_host = true, - hosts = { "preserved.com" }, + hosts = { "preserved.com", "preserved.com:123" }, service = { path = "/request" }, diff --git a/spec/02-integration/05-proxy/10-balancer/01-ring-balancer_spec.lua b/spec/02-integration/05-proxy/10-balancer/01-ring-balancer_spec.lua index 2011a28a980d..2144eb385a6a 100644 --- a/spec/02-integration/05-proxy/10-balancer/01-ring-balancer_spec.lua +++ b/spec/02-integration/05-proxy/10-balancer/01-ring-balancer_spec.lua @@ -1874,7 +1874,11 @@ for _, strategy in helpers.each_strategy() do add_target(bp, upstream_id, localhost, port1) add_target(bp, upstream_id, localhost, port2) - local _, service_id, route_id = add_api(bp, upstream_name, 500, 500, nil, nil, "tcp") + local _, service_id, route_id = add_api(bp, upstream_name, { + read_timeout = 500, + write_timeout = 500, + route_protocol = "tcp", + }) end_testcase_setup(strategy, bp) finally(function() diff --git a/spec/02-integration/05-proxy/19-grpc_proxy_spec.lua b/spec/02-integration/05-proxy/19-grpc_proxy_spec.lua index cbc751c7e0e7..09a2c35b4921 100644 --- a/spec/02-integration/05-proxy/19-grpc_proxy_spec.lua +++ b/spec/02-integration/05-proxy/19-grpc_proxy_spec.lua @@ -57,7 +57,7 @@ for _, strategy in helpers.each_strategy() do end) it("proxies grpc", function() - local ok, resp = proxy_client_grpc({ + local ok, resp = assert(proxy_client_grpc({ service = "hello.HelloService.SayHello", body = { greeting = "world!" @@ -65,13 +65,13 @@ for _, strategy in helpers.each_strategy() do opts = { ["-authority"] = "grpc", } - }) + })) assert.truthy(ok) assert.truthy(resp) end) it("proxies grpcs", function() - local ok, resp = proxy_client_grpcs({ + local ok, resp = assert(proxy_client_grpcs({ service = "hello.HelloService.SayHello", body = { greeting = "world!" @@ -79,7 +79,7 @@ for _, strategy in helpers.each_strategy() do opts = { ["-authority"] = "grpcs", } - }) + })) assert.truthy(ok) assert.truthy(resp) end)