diff --git a/kong/core/router.lua b/kong/core/router.lua index 692e0545a9d9..02a9616f7c0d 100644 --- a/kong/core/router.lua +++ b/kong/core/router.lua @@ -72,6 +72,15 @@ local match_api local reduce +local function has_capturing_groups(subj) + local s = find(subj, "[^\\]%(.-[^\\]%)") + s = s or find(subj, "^%(.-[^\\]%)") + s = s or find(subj, "%(%)") + + return s ~= nil +end + + local function marshall_api(api) local api_t = { api = api, @@ -162,16 +171,19 @@ local function marshall_api(api) else -- regex URI - local strip_regex = uri .. [[/?(?P.*)]] + local strip_regex = uri .. [[/?(?P.*)]] + local has_captures = has_capturing_groups(uri) api_t.uris[uri] = { + has_captures = has_captures, strip_regex = strip_regex, } insert(api_t.uris, { - value = uri, - regex = uri, - strip_regex = strip_regex, + value = uri, + regex = uri, + has_captures = has_captures, + strip_regex = strip_regex, }) end end @@ -309,31 +321,77 @@ do end, [MATCH_RULES.URI] = function(api_t, ctx) - local uri = ctx.uri_prefix or ctx.uri_regex or ctx.uri + do + local uri = ctx.uri_prefix or ctx.uri_regex or ctx.uri + local regex_t = api_t.uris[uri] + + if regex_t then + if api_t.strip_uri or regex_t.has_captures then + local m, err = re_match(ctx.uri, regex_t.strip_regex, "ajo") + if err then + log(ERR, "could not evaluate URI prefix/regex: ", err) + return + end - if api_t.uris[uri] then - if api_t.strip_uri then - ctx.strip_uri_regex = api_t.uris[uri].strip_regex - end + if m then + if m.stripped_uri then + ctx.stripped_uri = "/" .. m.stripped_uri + -- remove the stripped_uri group + m[#m] = nil + m.stripped_uri = nil + end - return true + if regex_t.has_captures then + m[0] = nil + ctx.uri_captures = m + end + + return true + end + end + + -- plain or prefix match from the index without strip_uri + return true + end end for i = 1, #api_t.uris do local regex_t = api_t.uris[i] - local from, _, err = re_find(ctx.uri, regex_t.regex, "ajo") - if err then - log(ERR, "could not evaluate URI prefix/regex: ", err) - return - end + if api_t.strip_uri or regex_t.has_captures then + local m, err = re_match(ctx.uri, regex_t.strip_regex, "ajo") + if err then + log(ERR, "could not evaluate URI prefix/regex: ", err) + return + end - if from then - if api_t.strip_uri then - ctx.strip_uri_regex = regex_t.strip_regex + if m then + if m.stripped_uri then + ctx.stripped_uri = "/" .. m.stripped_uri + -- remove the stripped_uri group + m[#m] = nil + m.stripped_uri = nil + end + + if regex_t.has_captures then + m[0] = nil + ctx.uri_captures = m + end + + return true end - return true + else + -- prefix match without strip_uri + local from, _, err = re_find(ctx.uri, regex_t.regex, "ajo") + if err then + log(ERR, "could not evaluate URI prefix/regex: ", err) + return + end + + if from then + return true + end end end end, @@ -435,6 +493,9 @@ end local _M = {} +_M.has_capturing_groups = has_capturing_groups + + function _M.new(apis) if type(apis) ~= "table" then return error("expected arg #1 apis to be a table") @@ -659,7 +720,8 @@ function _M.new(apis) cache:set(cache_key, { api_t = matched_api, ctx = { - strip_uri_regex = ctx.strip_uri_regex, + stripped_uri = ctx.stripped_uri, + uri_captures = ctx.uri_captures, } }) @@ -709,14 +771,8 @@ function _M.new(apis) local uri_root = request_uri == "/" - if not uri_root and ctx.strip_uri_regex then - local m, err = re_match(uri, ctx.strip_uri_regex, "ajo") - if not m then - log(ERR, "could not strip URI: ", err) - return - end - - uri = "/" .. m.stripped_uri + if not uri_root and api_t.strip_uri and ctx.stripped_uri then + uri = ctx.stripped_uri end @@ -757,7 +813,7 @@ function _M.new(apis) ngx.header["Kong-Api-Name"] = api_t.api.name end - return api_t.api, api_t.upstream, host_header, uri + return api_t.api, api_t.upstream, host_header, uri, ctx.uri_captures end diff --git a/spec/01-unit/11-router_spec.lua b/spec/01-unit/11-router_spec.lua index ba07fabb9ec4..8285bc311e8d 100644 --- a/spec/01-unit/11-router_spec.lua +++ b/spec/01-unit/11-router_spec.lua @@ -1126,6 +1126,85 @@ describe("Router", function() local _, _, _, uri = router.exec(_ngx) assert.equal("/hello/world", uri) end) + + it("returns groups captures from a [uri regex]", function() + local use_case = { + { + name = "api-1", + uris = { [[/users/(?P\d+)/profile/?(?P[a-z]*)]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/users/1984/profile", {}) + local _, _, _, _, uri_captures = router.exec(_ngx) + assert.equal("1984", uri_captures[1]) + assert.equal("1984", uri_captures.user_id) + assert.equal("", uri_captures[2]) + assert.equal("", uri_captures.scope) + -- no full match + assert.is_nil(uri_captures[0]) + -- no stripped_uri capture + assert.is_nil(uri_captures.stripped_uri) + assert.equal(2, #uri_captures) + + -- again, this time from the LRU cache + local _, _, _, _, uri_captures = router.exec(_ngx) + assert.equal("1984", uri_captures[1]) + assert.equal("1984", uri_captures.user_id) + assert.equal("", uri_captures[2]) + assert.equal("", uri_captures.scope) + -- no full match + assert.is_nil(uri_captures[0]) + -- no stripped_uri capture + assert.is_nil(uri_captures.stripped_uri) + assert.equal(2, #uri_captures) + + local _ngx = mock_ngx("GET", "/users/1984/profile/email", {}) + local _, _, _, _, uri_captures = router.exec(_ngx) + assert.equal("1984", uri_captures[1]) + assert.equal("1984", uri_captures.user_id) + assert.equal("email", uri_captures[2]) + assert.equal("email", uri_captures.scope) + -- no full match + assert.is_nil(uri_captures[0]) + -- no stripped_uri capture + assert.is_nil(uri_captures.stripped_uri) + assert.equal(2, #uri_captures) + end) + + it("returns no group capture from a [uri prefix] match", function() + local use_case = { + { + name = "api-1", + uris = { "/hello" }, + strip_uri = true, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/hello/world", {}) + local _, _, _, uri, uri_captures = router.exec(_ngx) + assert.equal("/world", uri) + assert.is_nil(uri_captures) + end) + + it("returns no group capture from a [uri regex] match without groups", function() + local use_case = { + { + name = "api-1", + uris = { [[/users/\d+/profile]] }, + }, + } + + local router = assert(Router.new(use_case)) + + local _ngx = mock_ngx("GET", "/users/1984/profile", {}) + local _, _, _, _, uri_captures = router.exec(_ngx) + assert.is_nil(uri_captures) + end) end) describe("preserve Host header", function() @@ -1289,4 +1368,43 @@ describe("Router", function() end end) end) + + describe("has_capturing_groups()", function() + -- load the `assert.fail` assertion + require "spec.helpers" + + it("detects if a string has capturing groups", function() + local uris = { + ["/users/(foo)"] = true, + ["/users/()"] = true, + ["/users/()/foo"] = true, + ["/users/(hello(foo)world)"] = true, + ["/users/(hello(foo)world"] = true, + ["/users/(foo)/thing/(bar)"] = true, + ["/users/\\(foo\\)/thing/(bar)"] = true, + -- 0-indexed capture groups + ["()/world"] = true, + ["(/hello)/world"] = true, + + ["/users/\\(foo\\)"] = false, + ["/users/\\(\\)"] = false, + -- unbalanced capture groups + ["(/hello\\)/world"] = false, + ["/users/(foo"] = false, + ["/users/\\(foo)"] = false, + ["/users/(foo\\)"] = false, + } + + for uri, expected_to_match in pairs(uris) do + local has_captures = Router.has_capturing_groups(uri) + if expected_to_match and not has_captures then + assert.fail(uri, "has capturing groups that were not detected") + + elseif not expected_to_match and has_captures then + assert.fail(uri, "has no capturing groups but false-positives " .. + "were detected") + end + end + end) + end) end)