From 2d041f1927d3f16e1f6f0294f957e43f6dccd9a4 Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Fri, 7 Jul 2017 12:39:09 -0700 Subject: [PATCH] feat(router) support for URIs as user-specified regexes When a value from an API's `uris` property does not respect the reserved character set from RFC 3986, the router now assumes this value is a user-specified regex. User regexes are stored in another index than what we consider "URI prefixes" (non-regex `uris` values), because while the later needs to be evaluated based on their length, the same does not apply for the former. The order in which regexes are specified (and ultimately, APIs are registered) will be the order in which they will be evaluated. Because users might specify their own capturing groups once we support parameters extraction, the URI stripping now uses named capturing groups. This requires PCRE 7.2+. The original implementation used numbered capturing groups, but this is not aligned with our long-term goal of allowing dynamic URI rewriting as part of our Plugins API or request-transformer plugin. Breaking changes: requires PCRE 7.2+ Implements: #677 (partially) --- kong/core/router.lua | 371 ++++++++++++++++++-------------- spec/01-unit/11-router_spec.lua | 146 +++++++++++++ 2 files changed, 352 insertions(+), 165 deletions(-) diff --git a/kong/core/router.lua b/kong/core/router.lua index b7fe8ede9360..dff5fd9f3cd1 100644 --- a/kong/core/router.lua +++ b/kong/core/router.lua @@ -5,7 +5,6 @@ local bit = require "bit" local re_match = ngx.re.match local re_find = ngx.re.find -local re_sub = ngx.re.sub local insert = table.insert local upper = string.upper local lower = string.lower @@ -20,6 +19,7 @@ local next = next local band = bit.band local bor = bit.bor local ERR = ngx.ERR +local clear_tab local log @@ -28,6 +28,16 @@ do log = function(lvl, ...) ngx_log(lvl, "[router] ", ...) end + + local ok + ok, clear_tab = pcall(require, "table.clear") + if not ok then + clear_tab = function(tab) + for k in pairs(tab) do + tab[k] = nil + end + end + end end @@ -63,17 +73,15 @@ local reduce local function marshall_api(api) - local api_t = { - api = api, - strip_uri = api.strip_uri, - preserve_host = api.preserve_host, - match_rules = 0x00, - hosts = {}, - wildcard_hosts = {}, - uris = {}, - uris_prefixes_regexes = {}, - methods = {}, - upstream = {}, + local api_t = { + api = api, + strip_uri = api.strip_uri, + preserve_host = api.preserve_host, + match_rules = 0x00, + hosts = {}, + uris = {}, + methods = {}, + upstream = {}, } @@ -105,9 +113,15 @@ local function marshall_api(api) -- wildcard host matching local wildcard_host_regex = host_value:gsub("%.", "\\.") :gsub("%*", ".+") .. "$" - insert(api_t.wildcard_hosts, { + insert(api_t.hosts, { + wildcard = true, + value = host_value, + regex = wildcard_host_regex + }) + + else + insert(api_t.hosts, { value = host_value, - regex = wildcard_host_regex }) end @@ -128,18 +142,38 @@ local function marshall_api(api) if #api.uris > 0 then api_t.match_rules = bor(api_t.match_rules, MATCH_RULES.URI) - for i, uri in ipairs(api.uris) do - local escaped_uri = [[\Q]] .. uri .. [[\E]] - local strip_regex = escaped_uri .. [[\/?(.*)]] + for _, uri in ipairs(api.uris) do + if re_match(uri, [[^[a-zA-Z0-9\.\-_~/%]*$]]) then + -- plain URI or URI prefix + local escaped_uri = [[\Q]] .. uri .. [[\E]] + local strip_regex = escaped_uri .. [[/?(?P.*)]] + + api_t.uris[uri] = { + prefix = true, + strip_regex = strip_regex, + } + + insert(api_t.uris, { + prefix = true, + value = uri, + regex = escaped_uri, + strip_regex = strip_regex, + }) - api_t.uris[uri] = { - strip_regex = strip_regex, - } + else + -- regex URI + local strip_regex = uri .. [[/?(?P.*)]] - api_t.uris_prefixes_regexes[i] = { - regex = escaped_uri, - strip_regex = strip_regex, - } + api_t.uris[uri] = { + strip_regex = strip_regex, + } + + insert(api_t.uris, { + value = uri, + regex = uri, + strip_regex = strip_regex, + }) + end end end end @@ -196,60 +230,62 @@ local function marshall_api(api) end -local function index_api_t(api_t, plain_indexes, uris_prefixes, wildcard_hosts) - for host in pairs(api_t.hosts) do - plain_indexes.hosts[host] = true - end +local function index_api_t(api_t, plain_indexes, uris_prefixes, uris_regexes, + wildcard_hosts) + for _, host_t in ipairs(api_t.hosts) do + if host_t.wildcard then + insert(wildcard_hosts, host_t) - for uri in pairs(api_t.uris) do - plain_indexes.uris[uri] = true + else + plain_indexes.hosts[host_t.value] = true + end end - for method in pairs(api_t.methods) do - plain_indexes.methods[method] = true - end + for _, uri_t in ipairs(api_t.uris) do + if uri_t.prefix then + plain_indexes.uris[uri_t.value] = true + insert(uris_prefixes, uri_t) - for _, wildcard_host in ipairs(api_t.wildcard_hosts) do - insert(wildcard_hosts, wildcard_host) + else + insert(uris_regexes, uri_t) + end end - api_t.wildcard_hosts = nil - - for _, uri_prefix_regex in ipairs(api_t.uris_prefixes_regexes) do - insert(uris_prefixes, uri_prefix_regex.regex) + for method in pairs(api_t.methods) do + plain_indexes.methods[method] = true end end -local function categorize_api_t(api_t, categories) - local category = categories[api_t.match_rules] +local function categorize_api_t(api_t, bit_category, categories) + local category = categories[bit_category] if not category then - category = { - apis_by_plain_hosts = {}, - apis_by_plain_uris = {}, - apis_by_methods = {}, - apis = {}, + category = { + apis_by_hosts = {}, + apis_by_uris = {}, + apis_by_methods = {}, + all = {}, } - categories[api_t.match_rules] = category + categories[bit_category] = category end - insert(category.apis, api_t) + insert(category.all, api_t) - for host in pairs(api_t.hosts) do - if not category.apis_by_plain_hosts[host] then - category.apis_by_plain_hosts[host] = {} + for _, host_t in ipairs(api_t.hosts) do + if not category.apis_by_hosts[host_t.value] then + category.apis_by_hosts[host_t.value] = {} end - insert(category.apis_by_plain_hosts[host], api_t) + insert(category.apis_by_hosts[host_t.value], api_t) end - for uri in pairs(api_t.uris) do - if not category.apis_by_plain_uris[uri] then - category.apis_by_plain_uris[uri] = {} + for _, uri_t in ipairs(api_t.uris) do + if not category.apis_by_uris[uri_t.value] then + category.apis_by_uris[uri_t.value] = {} end - insert(category.apis_by_plain_uris[uri], api_t) + insert(category.apis_by_uris[uri_t.value], api_t) end for method in pairs(api_t.methods) do @@ -264,31 +300,37 @@ end do local matchers = { - [MATCH_RULES.HOST] = function(api_t, _, _, host) + [MATCH_RULES.HOST] = function(api_t, ctx) + local host = ctx.wildcard_host or ctx.host + if api_t.hosts[host] then return true end end, - [MATCH_RULES.URI] = function(api_t, _, uri) + [MATCH_RULES.URI] = function(api_t, ctx) + local uri = ctx.uri_regex or ctx.uri_prefix or ctx.uri + if api_t.uris[uri] then if api_t.strip_uri then - api_t.strip_uri_regex = api_t.uris[uri].strip_regex + ctx.strip_uri_regex = api_t.uris[uri].strip_regex end return true end - for i = 1, #api_t.uris_prefixes_regexes do - local from, _, err = re_find(uri, api_t.uris_prefixes_regexes[i].regex, "ajo") + for i = 1, #api_t.uris do + local regex_t = api_t.uris[i] + + local from, _, err = re_find(ctx.uri, regex_t.regex, "ajo") if err then - log(ERR, "could not search for URI prefix: ", err) + log(ERR, "could not evaluate URI prefix/regex: ", err) return end if from then if api_t.strip_uri then - api_t.strip_uri_regex = api_t.uris_prefixes_regexes[i].strip_regex + ctx.strip_uri_regex = regex_t.strip_regex end return true @@ -296,16 +338,16 @@ do end end, - [MATCH_RULES.METHOD] = function(api_t, method) - return api_t.methods[method] + [MATCH_RULES.METHOD] = function(api_t, ctx) + return api_t.methods[ctx.method] end } - match_api = function(api_t, method, uri, host) + match_api = function(api_t, ctx) -- run cached matcher if type(matchers[api_t.match_rules]) == "function" then - return matchers[api_t.match_rules](api_t, method, uri, host) + return matchers[api_t.match_rules](api_t, ctx) end -- build and cache matcher @@ -318,9 +360,9 @@ do end end - matchers[api_t.match_rules] = function(api_t, method, uri, host) + matchers[api_t.match_rules] = function(api_t, ctx) for i = 1, #matchers_set do - if not matchers_set[i](api_t, method, uri, host) then + if not matchers_set[i](api_t, ctx) then return end end @@ -328,31 +370,35 @@ do return true end - return matchers[api_t.match_rules](api_t, method, uri, host) + return matchers[api_t.match_rules](api_t, ctx) end end do local reducers = { - [MATCH_RULES.HOST] = function(category, _, _, host) - return category.apis_by_plain_hosts[host] + [MATCH_RULES.HOST] = function(category, ctx) + local host = ctx.wildcard_host or ctx.host + + return category.apis_by_hosts[host] end, - [MATCH_RULES.URI] = function(category, _, uri) - return category.apis_by_plain_uris[uri] + [MATCH_RULES.URI] = function(category, ctx) + local uri = ctx.uri_regex or ctx.uri_prefix or ctx.uri + + return category.apis_by_uris[uri] end, - [MATCH_RULES.METHOD] = function(category, method) - return category.apis_by_methods[method] + [MATCH_RULES.METHOD] = function(category, ctx) + return category.apis_by_methods[ctx.method] end, } - reduce = function(category, bit_category, method, uri, host) + reduce = function(category, bit_category, ctx) -- run cached reducer if type(reducers[bit_category]) == "function" then - return reducers[bit_category](category, method, uri, host), category.apis + return reducers[bit_category](category, ctx), category.all end -- build and cache reducer @@ -365,22 +411,23 @@ do end end - reducers[bit_category] = function(category, method, uri, host) + reducers[bit_category] = function(category, ctx) local min_len = 0 local smallest_set for i = 1, #reducers_set do - local candidates = reducers_set[i](category, method, uri, host) - if candidates ~= nil and (not smallest_set or #candidates < min_len) then + local candidates = reducers_set[i](category, ctx) + if candidates ~= nil and (not smallest_set or #candidates < min_len) + then min_len = #candidates smallest_set = candidates end end - return smallest_set, category.apis + return smallest_set end - return reducers[bit_category](category, method, uri, host) + return reducers[bit_category](category, ctx), category.all end end @@ -397,6 +444,9 @@ function _M.new(apis) local self = {} + local ctx = {} + + -- hash table for fast lookup of plain hosts, uris -- and methods from incoming requests local plain_indexes = { @@ -408,7 +458,8 @@ function _M.new(apis) -- when hash lookup in plain_indexes fails, those are arrays -- of regexes for `uris` as prefixes and `hosts` as wildcards - local uris_prefixes = {} + local uris_prefixes = {} -- will be sorted by length + local uris_regexes = {} local wildcard_hosts = {} @@ -429,56 +480,16 @@ function _M.new(apis) return nil, err end - index_api_t(api_t, plain_indexes, uris_prefixes, wildcard_hosts) - categorize_api_t(api_t, categories) + categorize_api_t(api_t, api_t.match_rules, categories) + index_api_t(api_t, plain_indexes, uris_prefixes, uris_regexes, + wildcard_hosts) end - local function compare_uris_length(a, b, category_bit) - if not band(category_bit, MATCH_RULES.URI) then - return - end - - local max_uri_a = 0 - local max_uri_b = 0 - - for _, prefix in ipairs(a.uris_prefixes_regexes) do - if #prefix.regex > max_uri_a then - max_uri_a = #prefix.regex - end - end - - for _, prefix in ipairs(b.uris_prefixes_regexes) do - if #prefix.regex > max_uri_b then - max_uri_b = #prefix.regex - end - end - - return max_uri_a > max_uri_b - end - - table.sort(uris_prefixes, function(a, b) - return #a > #b + table.sort(uris_prefixes, function(uri_t_a, uri_t_b) + return #uri_t_a.value > #uri_t_b.value end) - for category_bit, category in pairs(categories) do - table.sort(category.apis, function(a, b) - return compare_uris_length(a, b, category_bit) - end) - - for _, apis_by_method in pairs(category.apis_by_methods) do - table.sort(apis_by_method, function(a, b) - return compare_uris_length(a, b, category_bit) - end) - end - - for _, apis_by_host in pairs(category.apis_by_plain_hosts) do - table.sort(apis_by_host, function(a, b) - return compare_uris_length(a, b, category_bit) - end) - end - end - local grab_host = #wildcard_hosts > 0 or next(plain_indexes.hosts) ~= nil @@ -519,20 +530,25 @@ function _M.new(apis) local cache_key = fmt("%s:%s:%s", method, uri, host) do - local api_t_from_cache = cache:get(cache_key) - if api_t_from_cache and match_api(api_t_from_cache, method, uri, host) - then - return api_t_from_cache + local cache = cache:get(cache_key) + if cache then + return cache.api_t, cache.ctx end end + clear_tab(ctx) + + ctx.uri = uri + ctx.host = host + ctx.method = method + + local req_category = 0x00 -- router, router, which of these APIs is the fairest? -- -- determine which category this request *might* be targeting - - local req_category = 0x00 + -- host match if plain_indexes.hosts[host] then req_category = bor(req_category, MATCH_RULES.HOST) @@ -546,39 +562,56 @@ function _M.new(apis) end if from then - host = wildcard_hosts[i].value - req_category = bor(req_category, MATCH_RULES.HOST) + ctx.wildcard_host = wildcard_hosts[i].value + req_category = bor(req_category, MATCH_RULES.HOST) break end end end + -- uri match + if plain_indexes.uris[uri] then req_category = bor(req_category, MATCH_RULES.URI) else for i = 1, #uris_prefixes do - local from, _, err = re_find(uri, uris_prefixes[i], "ajo") + local from, _, err = re_find(uri, uris_prefixes[i].regex, "ajo") + if err then + log(ERR, "could not evaluate URI prefix: ", err) + return + end + + if from then + ctx.uri_prefix = uris_prefixes[i].value + req_category = bor(req_category, MATCH_RULES.URI) + break + end + end + + for i = 1, #uris_regexes do + local from, _, err = re_find(uri, uris_regexes[i].regex, "ajo") if err then - log(ERR, "could not search for URI prefix: ", err) + log(ERR, "could not evaluate URI regex: ", err) return end if from then - -- strip \Q...\E tokens - uri = sub(uris_prefixes[i], 3, -3) - req_category = bor(req_category, MATCH_RULES.URI) + ctx.uri_regex = uris_regexes[i].value + req_category = bor(req_category, MATCH_RULES.URI) break end end end + -- method match + if plain_indexes.methods[method] then req_category = bor(req_category, MATCH_RULES.METHOD) end - --print("highest potential category: ", req_category) + --print("highest potential category: ", ctx.req_category) -- iterate from the highest matching to the lowest category to -- find our API @@ -593,34 +626,42 @@ function _M.new(apis) local category = categories[bit_category] if category then - local plain_candidates, apis_for_category = reduce(category, - bit_category, - method, uri, host) - if plain_candidates then - -- check for results from a set of reduced plain indexes - -- this is our best case scenario with hash lookups only - for i = 1, #plain_candidates do - if match_api(plain_candidates[i], method, uri, host) then - matched_api = plain_candidates[i] + local reduced_candidates, category_candidates = reduce(category, + bit_category, + ctx) + if reduced_candidates then + -- check against a reduced set of APIs that is a strong candidate + -- for this request, instead of iterating over all the APIs of + -- this category + for i = 1, #reduced_candidates do + if match_api(reduced_candidates[i], ctx) then + matched_api = reduced_candidates[i] break end end end if not matched_api then - -- must check for results from the full list of APIs from that - -- category before checking a lower category - for i = 1, #apis_for_category do - if match_api(apis_for_category[i], method, uri, host) then - matched_api = apis_for_category[i] + -- no result from the reduced set, must check for results from the + -- full list of APIs from that category before checking a lower + -- category + for i = 1, #category_candidates do + if match_api(category_candidates[i], ctx) then + matched_api = category_candidates[i] break end end end if matched_api then - cache:set(cache_key, matched_api) - return matched_api + cache:set(cache_key, { + api_t = matched_api, + ctx = { + strip_uri_regex = ctx.strip_uri_regex, + } + }) + + return matched_api, ctx end end @@ -659,21 +700,21 @@ function _M.new(apis) req_host = ngx.var.http_host end - - local api_t = find_api(method, uri, req_host) + local api_t, ctx = find_api(method, uri, req_host) if not api_t then return nil end local uri_root = request_uri == "/" - if not uri_root and api_t.strip_uri_regex then - local _, err - uri, _, err = re_sub(uri, api_t.strip_uri_regex, "/$1", "ajo") - if not uri then + if not uri_root and ctx.strip_uri_regex then + local m, err = re_match(uri, ctx.strip_uri_regex, "ajo") + if not m then log(ERR, "could not strip URI: ", err) return end + + uri = "/" .. m.stripped_uri end diff --git a/spec/01-unit/11-router_spec.lua b/spec/01-unit/11-router_spec.lua index 1a3107410231..d33dd8faf2c0 100644 --- a/spec/01-unit/11-router_spec.lua +++ b/spec/01-unit/11-router_spec.lua @@ -268,6 +268,65 @@ describe("Router", function() end) end) + describe("uri as a regex", function() + it("matches with [uri regex]", function() + local use_case = { + { + name = "api-1", + uris = { [[/users/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local api_t = router.select("GET", "/users/123/profile") + assert.truthy(api_t) + assert.same(use_case[1], api_t.api) + end) + + it("matches the right API when several ones have a [uri regex]", function() + local use_case = { + { + name = "api-1", + uris = { [[/api/persons/\d{3}]] }, + }, + { + name = "api-2", + uris = { [[/api/persons/\d{3}/following]] }, + }, + { + name = "api-3", + uris = { [[/api/persons/\d{3}/[a-z]+]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local api_t = router.select("GET", "/api/persons/456") + assert.truthy(api_t) + assert.same(use_case[1], api_t.api) + end) + + it("matches a [uri regex] even if a [prefix uri] got a match", function() + local use_case = { + { + name = "api-1", + uris = { [[/api/persons]] }, + }, + { + name = "api-2", + uris = { [[/api/persons/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local api_t = router.select("GET", "/api/persons/123/profile") + assert.truthy(api_t) + assert.same(use_case[2], api_t.api) + end) + end) + describe("wildcard domains", function() local use_case = { { @@ -365,6 +424,37 @@ describe("Router", function() end) end) + describe("[wildcard host] + [uri regex]", function() + it("matches", function() + local use_case = { + { + name = "api-1", + uris = { [[/users/\d+/profile]] }, + headers = { + ["host"] = { "*.example.com" }, + }, + }, + { + name = "api-2", + uris = { [[/users]] }, + headers = { + ["host"] = { "*.example.com" }, + }, + }, + } + + local router = assert(Router.new(use_case)) + + local match_t = router.select("GET", "/users/123/profile", "test.example.com") + assert.truthy(match_t) + assert.same(use_case[1], match_t.api) + + match_t = router.select("GET", "/users", "test.example.com") + assert.truthy(match_t) + assert.same(use_case[2], match_t.api) + end) + end) + describe("edge-cases", function() it("[host] and [uri] have higher priority than [method]", function() -- host @@ -434,6 +524,30 @@ describe("Router", function() assert.same(use_case[2], api_t.api) end) + it("half [uri regex] and [method] match does not supersede another API", function() + local use_case = { + { + name = "api-1", + methods = { "GET" }, + uris = { [[/users/\d+/profile]] }, + }, + { + name = "api-2", + methods = { "POST" }, + uris = { [[/users/\d*/profile]] }, + } + } + + local router = assert(Router.new(use_case)) + local api_t = router.select("GET", "/users/123/profile") + assert.truthy(api_t) + assert.same(use_case[1], api_t.api) + + api_t = router.select("POST", "/users/123/profile") + assert.truthy(api_t) + assert.same(use_case[2], api_t.api) + end) + it("[method] does not supersede non-plain [uri]", function() local use_case = { { @@ -1076,6 +1190,38 @@ describe("Router", function() assert.same(use_case_apis[1], api) assert.equal("/", uri) end) + + it("strips a [uri regex]", function() + local use_case = { + { + name = "api-1", + strip_uri = true, + uris = { [[/users/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/users/123/profile/hello/world", {}) + local _, _, _, uri = router.exec(_ngx) + assert.equal("/hello/world", uri) + end) + + it("strips a [uri regex] with a capture group", function() + local use_case = { + { + name = "api-1", + strip_uri = true, + uris = { [[/users/(\d+)/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/users/123/profile/hello/world", {}) + local _, _, _, uri = router.exec(_ngx) + assert.equal("/hello/world", uri) + end) end) describe("preserve Host header", function()