diff --git a/kong/core/router.lua b/kong/core/router.lua index e5ac8fe5b9f5..23035db38fc5 100644 --- a/kong/core/router.lua +++ b/kong/core/router.lua @@ -5,7 +5,6 @@ local bit = require "bit" local re_match = ngx.re.match local re_find = ngx.re.find -local re_sub = ngx.re.sub local insert = table.insert local upper = string.upper local lower = string.lower @@ -20,6 +19,7 @@ local next = next local band = bit.band local bor = bit.bor local ERR = ngx.ERR +local clear_tab local log @@ -28,6 +28,16 @@ do log = function(lvl, ...) ngx_log(lvl, "[router] ", ...) end + + local ok + ok, clear_tab = pcall(require, "table.clear") + if not ok then + clear_tab = function(tab) + for k in pairs(tab) do + tab[k] = nil + end + end + end end @@ -73,17 +83,15 @@ local reduce local function marshall_api(api) - local api_t = { - api = api, - strip_uri = api.strip_uri, - preserve_host = api.preserve_host, - match_rules = 0x00, - hosts = {}, - wildcard_hosts = {}, - uris = {}, - uris_prefixes_regexes = {}, - methods = {}, - upstream = {}, + local api_t = { + api = api, + strip_uri = api.strip_uri, + preserve_host = api.preserve_host, + match_rules = 0x00, + hosts = {}, + uris = {}, + methods = {}, + upstream = {}, } @@ -115,9 +123,15 @@ local function marshall_api(api) -- wildcard host matching local wildcard_host_regex = host_value:gsub("%.", "\\.") :gsub("%*", ".+") .. "$" - insert(api_t.wildcard_hosts, { + insert(api_t.hosts, { + wildcard = true, + value = host_value, + regex = wildcard_host_regex + }) + + else + insert(api_t.hosts, { value = host_value, - regex = wildcard_host_regex }) end @@ -138,18 +152,38 @@ local function marshall_api(api) if #api.uris > 0 then api_t.match_rules = bor(api_t.match_rules, MATCH_RULES.URI) - for i, uri in ipairs(api.uris) do - local escaped_uri = [[\Q]] .. uri .. [[\E]] - local strip_regex = escaped_uri .. 
[[\/?(.*)]] + for _, uri in ipairs(api.uris) do + if re_match(uri, [[^[a-zA-Z0-9\.\-_~/%]*$]]) then + -- plain URI or URI prefix + local escaped_uri = [[\Q]] .. uri .. [[\E]] + local strip_regex = escaped_uri .. [[/?(?P<stripped_uri>.*)]] + + api_t.uris[uri] = { + prefix = true, + strip_regex = strip_regex, + } + + insert(api_t.uris, { + prefix = true, + value = uri, + regex = escaped_uri, + strip_regex = strip_regex, + }) - api_t.uris[uri] = { - strip_regex = strip_regex, - } + else + -- regex URI + local strip_regex = uri .. [[/?(?P<stripped_uri>.*)]] - api_t.uris_prefixes_regexes[i] = { - regex = escaped_uri, - strip_regex = strip_regex, - } + api_t.uris[uri] = { + strip_regex = strip_regex, + } + + insert(api_t.uris, { + value = uri, + regex = uri, + strip_regex = strip_regex, + }) + end end end end @@ -206,60 +240,62 @@ local function marshall_api(api) end -local function index_api_t(api_t, plain_indexes, uris_prefixes, wildcard_hosts) - for host in pairs(api_t.hosts) do - plain_indexes.hosts[host] = true - end +local function index_api_t(api_t, plain_indexes, uris_prefixes, uris_regexes, + wildcard_hosts) + for _, host_t in ipairs(api_t.hosts) do + if host_t.wildcard then + insert(wildcard_hosts, host_t) - for uri in pairs(api_t.uris) do - plain_indexes.uris[uri] = true + else + plain_indexes.hosts[host_t.value] = true + end end - for method in pairs(api_t.methods) do - plain_indexes.methods[method] = true - end + for _, uri_t in ipairs(api_t.uris) do + if uri_t.prefix then + plain_indexes.uris[uri_t.value] = true + insert(uris_prefixes, uri_t) - for _, wildcard_host in ipairs(api_t.wildcard_hosts) do - insert(wildcard_hosts, wildcard_host) + else + insert(uris_regexes, uri_t) + end end - api_t.wildcard_hosts = nil - - for _, uri_prefix_regex in ipairs(api_t.uris_prefixes_regexes) do - insert(uris_prefixes, uri_prefix_regex.regex) + for method in pairs(api_t.methods) do + plain_indexes.methods[method] = true end end -local function categorize_api_t(api_t, categories) - local category = 
categories[api_t.match_rules] +local function categorize_api_t(api_t, bit_category, categories) + local category = categories[bit_category] if not category then - category = { - apis_by_plain_hosts = {}, - apis_by_plain_uris = {}, - apis_by_methods = {}, - apis = {}, + category = { + apis_by_hosts = {}, + apis_by_uris = {}, + apis_by_methods = {}, + all = {}, } - categories[api_t.match_rules] = category + categories[bit_category] = category end - insert(category.apis, api_t) + insert(category.all, api_t) - for host in pairs(api_t.hosts) do - if not category.apis_by_plain_hosts[host] then - category.apis_by_plain_hosts[host] = {} + for _, host_t in ipairs(api_t.hosts) do + if not category.apis_by_hosts[host_t.value] then + category.apis_by_hosts[host_t.value] = {} end - insert(category.apis_by_plain_hosts[host], api_t) + insert(category.apis_by_hosts[host_t.value], api_t) end - for uri in pairs(api_t.uris) do - if not category.apis_by_plain_uris[uri] then - category.apis_by_plain_uris[uri] = {} + for _, uri_t in ipairs(api_t.uris) do + if not category.apis_by_uris[uri_t.value] then + category.apis_by_uris[uri_t.value] = {} end - insert(category.apis_by_plain_uris[uri], api_t) + insert(category.apis_by_uris[uri_t.value], api_t) end for method in pairs(api_t.methods) do @@ -274,31 +310,37 @@ end do local matchers = { - [MATCH_RULES.HOST] = function(api_t, _, _, host) + [MATCH_RULES.HOST] = function(api_t, ctx) + local host = ctx.wildcard_host or ctx.host + if api_t.hosts[host] then return true end end, - [MATCH_RULES.URI] = function(api_t, _, uri) + [MATCH_RULES.URI] = function(api_t, ctx) + local uri = ctx.uri_regex or ctx.uri_prefix or ctx.uri + if api_t.uris[uri] then if api_t.strip_uri then - api_t.strip_uri_regex = api_t.uris[uri].strip_regex + ctx.strip_uri_regex = api_t.uris[uri].strip_regex end return true end - for i = 1, #api_t.uris_prefixes_regexes do - local from, _, err = re_find(uri, api_t.uris_prefixes_regexes[i].regex, "ajo") + for i = 1, #api_t.uris 
do + local regex_t = api_t.uris[i] + + local from, _, err = re_find(ctx.uri, regex_t.regex, "ajo") if err then - log(ERR, "could not search for URI prefix: ", err) + log(ERR, "could not evaluate URI prefix/regex: ", err) return end if from then if api_t.strip_uri then - api_t.strip_uri_regex = api_t.uris_prefixes_regexes[i].strip_regex + ctx.strip_uri_regex = regex_t.strip_regex end return true @@ -306,16 +348,16 @@ do end end, - [MATCH_RULES.METHOD] = function(api_t, method) - return api_t.methods[method] + [MATCH_RULES.METHOD] = function(api_t, ctx) + return api_t.methods[ctx.method] end } - match_api = function(api_t, method, uri, host) + match_api = function(api_t, ctx) -- run cached matcher if type(matchers[api_t.match_rules]) == "function" then - return matchers[api_t.match_rules](api_t, method, uri, host) + return matchers[api_t.match_rules](api_t, ctx) end -- build and cache matcher @@ -328,9 +370,9 @@ do end end - matchers[api_t.match_rules] = function(api_t, method, uri, host) + matchers[api_t.match_rules] = function(api_t, ctx) for i = 1, #matchers_set do - if not matchers_set[i](api_t, method, uri, host) then + if not matchers_set[i](api_t, ctx) then return end end @@ -338,31 +380,35 @@ do return true end - return matchers[api_t.match_rules](api_t, method, uri, host) + return matchers[api_t.match_rules](api_t, ctx) end end do local reducers = { - [MATCH_RULES.HOST] = function(category, _, _, host) - return category.apis_by_plain_hosts[host] + [MATCH_RULES.HOST] = function(category, ctx) + local host = ctx.wildcard_host or ctx.host + + return category.apis_by_hosts[host] end, - [MATCH_RULES.URI] = function(category, _, uri) - return category.apis_by_plain_uris[uri] + [MATCH_RULES.URI] = function(category, ctx) + local uri = ctx.uri_regex or ctx.uri_prefix or ctx.uri + + return category.apis_by_uris[uri] end, - [MATCH_RULES.METHOD] = function(category, method) - return category.apis_by_methods[method] + [MATCH_RULES.METHOD] = function(category, ctx) + 
return category.apis_by_methods[ctx.method] end, } - reduce = function(category, bit_category, method, uri, host) + reduce = function(category, bit_category, ctx) -- run cached reducer if type(reducers[bit_category]) == "function" then - return reducers[bit_category](category, method, uri, host), category.apis + return reducers[bit_category](category, ctx), category.all end -- build and cache reducer @@ -375,22 +421,23 @@ do end end - reducers[bit_category] = function(category, method, uri, host) + reducers[bit_category] = function(category, ctx) local min_len = 0 local smallest_set for i = 1, #reducers_set do - local candidates = reducers_set[i](category, method, uri, host) - if candidates ~= nil and (not smallest_set or #candidates < min_len) then + local candidates = reducers_set[i](category, ctx) + if candidates ~= nil and (not smallest_set or #candidates < min_len) + then min_len = #candidates smallest_set = candidates end end - return smallest_set, category.apis + return smallest_set end - return reducers[bit_category](category, method, uri, host) + return reducers[bit_category](category, ctx), category.all end end @@ -407,6 +454,9 @@ function _M.new(apis) local self = {} + local ctx = {} + + -- hash table for fast lookup of plain hosts, uris -- and methods from incoming requests local plain_indexes = { @@ -418,7 +468,8 @@ function _M.new(apis) -- when hash lookup in plain_indexes fails, those are arrays -- of regexes for `uris` as prefixes and `hosts` as wildcards - local uris_prefixes = {} + local uris_prefixes = {} -- will be sorted by length + local uris_regexes = {} local wildcard_hosts = {} @@ -439,56 +490,16 @@ function _M.new(apis) return nil, err end - index_api_t(api_t, plain_indexes, uris_prefixes, wildcard_hosts) - categorize_api_t(api_t, categories) + categorize_api_t(api_t, api_t.match_rules, categories) + index_api_t(api_t, plain_indexes, uris_prefixes, uris_regexes, + wildcard_hosts) end - local function compare_uris_length(a, b, category_bit) 
- if not band(category_bit, MATCH_RULES.URI) then - return - end - - local max_uri_a = 0 - local max_uri_b = 0 - - for _, prefix in ipairs(a.uris_prefixes_regexes) do - if #prefix.regex > max_uri_a then - max_uri_a = #prefix.regex - end - end - - for _, prefix in ipairs(b.uris_prefixes_regexes) do - if #prefix.regex > max_uri_b then - max_uri_b = #prefix.regex - end - end - - return max_uri_a > max_uri_b - end - - table.sort(uris_prefixes, function(a, b) - return #a > #b + table.sort(uris_prefixes, function(uri_t_a, uri_t_b) + return #uri_t_a.value > #uri_t_b.value end) - for category_bit, category in pairs(categories) do - table.sort(category.apis, function(a, b) - return compare_uris_length(a, b, category_bit) - end) - - for _, apis_by_method in pairs(category.apis_by_methods) do - table.sort(apis_by_method, function(a, b) - return compare_uris_length(a, b, category_bit) - end) - end - - for _, apis_by_host in pairs(category.apis_by_plain_hosts) do - table.sort(apis_by_host, function(a, b) - return compare_uris_length(a, b, category_bit) - end) - end - end - local grab_host = #wildcard_hosts > 0 or next(plain_indexes.hosts) ~= nil @@ -529,20 +540,25 @@ function _M.new(apis) local cache_key = fmt("%s:%s:%s", method, uri, host) do - local api_t_from_cache = cache:get(cache_key) - if api_t_from_cache and match_api(api_t_from_cache, method, uri, host) - then - return api_t_from_cache + local cache = cache:get(cache_key) + if cache then + return cache.api_t, cache.ctx end end + clear_tab(ctx) + + ctx.uri = uri + ctx.host = host + ctx.method = method + + local req_category = 0x00 -- router, router, which of these APIs is the fairest? 
-- -- determine which category this request *might* be targeting - - local req_category = 0x00 + -- host match if plain_indexes.hosts[host] then req_category = bor(req_category, MATCH_RULES.HOST) @@ -556,39 +572,56 @@ function _M.new(apis) end if from then - host = wildcard_hosts[i].value - req_category = bor(req_category, MATCH_RULES.HOST) + ctx.wildcard_host = wildcard_hosts[i].value + req_category = bor(req_category, MATCH_RULES.HOST) break end end end + -- uri match + if plain_indexes.uris[uri] then req_category = bor(req_category, MATCH_RULES.URI) else for i = 1, #uris_prefixes do - local from, _, err = re_find(uri, uris_prefixes[i], "ajo") + local from, _, err = re_find(uri, uris_prefixes[i].regex, "ajo") + if err then + log(ERR, "could not evaluate URI prefix: ", err) + return + end + + if from then + ctx.uri_prefix = uris_prefixes[i].value + req_category = bor(req_category, MATCH_RULES.URI) + break + end + end + + for i = 1, #uris_regexes do + local from, _, err = re_find(uri, uris_regexes[i].regex, "ajo") if err then - log(ERR, "could not search for URI prefix: ", err) + log(ERR, "could not evaluate URI regex: ", err) return end if from then - -- strip \Q...\E tokens - uri = sub(uris_prefixes[i], 3, -3) - req_category = bor(req_category, MATCH_RULES.URI) + ctx.uri_regex = uris_regexes[i].value + req_category = bor(req_category, MATCH_RULES.URI) break end end end + -- method match + if plain_indexes.methods[method] then req_category = bor(req_category, MATCH_RULES.METHOD) end - --print("highest potential category: ", req_category) + --print("highest potential category: ", ctx.req_category) -- iterate from the highest matching to the lowest category to -- find our API @@ -603,34 +636,42 @@ function _M.new(apis) local category = categories[bit_category] if category then - local plain_candidates, apis_for_category = reduce(category, - bit_category, - method, uri, host) - if plain_candidates then - -- check for results from a set of reduced plain indexes - -- 
this is our best case scenario with hash lookups only - for i = 1, #plain_candidates do - if match_api(plain_candidates[i], method, uri, host) then - matched_api = plain_candidates[i] + local reduced_candidates, category_candidates = reduce(category, + bit_category, + ctx) + if reduced_candidates then + -- check against a reduced set of APIs that is a strong candidate + -- for this request, instead of iterating over all the APIs of + -- this category + for i = 1, #reduced_candidates do + if match_api(reduced_candidates[i], ctx) then + matched_api = reduced_candidates[i] break end end end if not matched_api then - -- must check for results from the full list of APIs from that - -- category before checking a lower category - for i = 1, #apis_for_category do - if match_api(apis_for_category[i], method, uri, host) then - matched_api = apis_for_category[i] + -- no result from the reduced set, must check for results from the + -- full list of APIs from that category before checking a lower + -- category + for i = 1, #category_candidates do + if match_api(category_candidates[i], ctx) then + matched_api = category_candidates[i] break end end end if matched_api then - cache:set(cache_key, matched_api) - return matched_api + cache:set(cache_key, { + api_t = matched_api, + ctx = { + strip_uri_regex = ctx.strip_uri_regex, + } + }) + + return matched_api, ctx end end @@ -669,21 +710,21 @@ function _M.new(apis) req_host = ngx.var.http_host end - - local api_t = find_api(method, uri, req_host) + local api_t, ctx = find_api(method, uri, req_host) if not api_t then return nil end local uri_root = request_uri == "/" - if not uri_root and api_t.strip_uri_regex then - local _, err - uri, _, err = re_sub(uri, api_t.strip_uri_regex, "/$1", "ajo") - if not uri then + if not uri_root and ctx.strip_uri_regex then + local m, err = re_match(uri, ctx.strip_uri_regex, "ajo") + if not m then log(ERR, "could not strip URI: ", err) return end + + uri = "/" .. 
m.stripped_uri end diff --git a/spec/01-unit/010-router_spec.lua b/spec/01-unit/010-router_spec.lua index 1a3107410231..d33dd8faf2c0 100644 --- a/spec/01-unit/010-router_spec.lua +++ b/spec/01-unit/010-router_spec.lua @@ -268,6 +268,65 @@ describe("Router", function() end) end) + describe("uri as a regex", function() + it("matches with [uri regex]", function() + local use_case = { + { + name = "api-1", + uris = { [[/users/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local api_t = router.select("GET", "/users/123/profile") + assert.truthy(api_t) + assert.same(use_case[1], api_t.api) + end) + + it("matches the right API when several ones have a [uri regex]", function() + local use_case = { + { + name = "api-1", + uris = { [[/api/persons/\d{3}]] }, + }, + { + name = "api-2", + uris = { [[/api/persons/\d{3}/following]] }, + }, + { + name = "api-3", + uris = { [[/api/persons/\d{3}/[a-z]+]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local api_t = router.select("GET", "/api/persons/456") + assert.truthy(api_t) + assert.same(use_case[1], api_t.api) + end) + + it("matches a [uri regex] even if a [prefix uri] got a match", function() + local use_case = { + { + name = "api-1", + uris = { [[/api/persons]] }, + }, + { + name = "api-2", + uris = { [[/api/persons/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local api_t = router.select("GET", "/api/persons/123/profile") + assert.truthy(api_t) + assert.same(use_case[2], api_t.api) + end) + end) + describe("wildcard domains", function() local use_case = { { @@ -365,6 +424,37 @@ describe("Router", function() end) end) + describe("[wildcard host] + [uri regex]", function() + it("matches", function() + local use_case = { + { + name = "api-1", + uris = { [[/users/\d+/profile]] }, + headers = { + ["host"] = { "*.example.com" }, + }, + }, + { + name = "api-2", + uris = { [[/users]] }, + headers = { + ["host"] = { "*.example.com" }, + }, + }, + } + + 
local router = assert(Router.new(use_case)) + + local match_t = router.select("GET", "/users/123/profile", "test.example.com") + assert.truthy(match_t) + assert.same(use_case[1], match_t.api) + + match_t = router.select("GET", "/users", "test.example.com") + assert.truthy(match_t) + assert.same(use_case[2], match_t.api) + end) + end) + describe("edge-cases", function() it("[host] and [uri] have higher priority than [method]", function() -- host @@ -434,6 +524,30 @@ describe("Router", function() assert.same(use_case[2], api_t.api) end) + it("half [uri regex] and [method] match does not supersede another API", function() + local use_case = { + { + name = "api-1", + methods = { "GET" }, + uris = { [[/users/\d+/profile]] }, + }, + { + name = "api-2", + methods = { "POST" }, + uris = { [[/users/\d*/profile]] }, + } + } + + local router = assert(Router.new(use_case)) + local api_t = router.select("GET", "/users/123/profile") + assert.truthy(api_t) + assert.same(use_case[1], api_t.api) + + api_t = router.select("POST", "/users/123/profile") + assert.truthy(api_t) + assert.same(use_case[2], api_t.api) + end) + it("[method] does not supersede non-plain [uri]", function() local use_case = { { @@ -1076,6 +1190,38 @@ describe("Router", function() assert.same(use_case_apis[1], api) assert.equal("/", uri) end) + + it("strips a [uri regex]", function() + local use_case = { + { + name = "api-1", + strip_uri = true, + uris = { [[/users/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/users/123/profile/hello/world", {}) + local _, _, _, uri = router.exec(_ngx) + assert.equal("/hello/world", uri) + end) + + it("strips a [uri regex] with a capture group", function() + local use_case = { + { + name = "api-1", + strip_uri = true, + uris = { [[/users/(\d+)/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/users/123/profile/hello/world", {}) + local _, _, _, uri = 
router.exec(_ngx) + assert.equal("/hello/world", uri) + end) end) describe("preserve Host header", function()