diff --git a/kong/core/router.lua b/kong/core/router.lua index fb9f2962e05b..c6f4072addf2 100644 --- a/kong/core/router.lua +++ b/kong/core/router.lua @@ -11,6 +11,7 @@ local upper = string.upper local lower = string.lower local find = string.find local fmt = string.format +local sub = string.sub local tonumber = tonumber local ipairs = ipairs local pairs = pairs @@ -600,8 +601,25 @@ function _M.new(apis) end - if api_t.upstream.path then - new_uri = api_t.upstream.path .. new_uri + local upstream_path = api_t.upstream.path + if upstream_path then + if new_uri == "/" then + new_uri = upstream_path + + else + new_uri = upstream_path .. (sub(upstream_path, -1) == "/" and sub(new_uri, 2) or new_uri) + end + end + + + local req_uri_slash = sub(uri, -1) == "/" + local new_uri_slash = sub(new_uri, -1) == "/" + + if new_uri_slash and not req_uri_slash and new_uri ~= "/" then + new_uri = sub(new_uri, 1, -2) + + elseif not new_uri_slash and req_uri_slash and uri ~= "/" then + new_uri = new_uri .. "/" end diff --git a/kong/dao/schemas/apis.lua b/kong/dao/schemas/apis.lua index 6540a14cbf64..5e594f4bae45 100644 --- a/kong/dao/schemas/apis.lua +++ b/kong/dao/schemas/apis.lua @@ -14,10 +14,6 @@ local function validate_upstream_url(value) end end - if parsed_url.path and string.sub(value, #value) == "/" then - return false, "Cannot end with a slash" - end - return true end diff --git a/spec/01-unit/07-entities_schemas_spec.lua b/spec/01-unit/07-entities_schemas_spec.lua index 9077fe287747..859f94d5d242 100644 --- a/spec/01-unit/07-entities_schemas_spec.lua +++ b/spec/01-unit/07-entities_schemas_spec.lua @@ -78,15 +78,14 @@ describe("Entities Schemas", function() assert.equal("Supported protocols are HTTP and HTTPS", errors.upstream_url) end) - it("should return error with final slash in upstream_url", function() + it("should not return error with final slash in upstream_url", function() local valid, errors = validate_entity({ name = "mockbin", upstream_url = "http://mockbin.com/", hosts = { "mockbin.com" }, }, api_schema) - assert.is_false(valid) - assert.equal("Cannot end with a slash", errors.upstream_url) - + assert.is_nil(errors) + assert.is_true(valid) end) it("should validate with upper case protocol", function() diff --git a/spec/01-unit/11-router_spec.lua b/spec/01-unit/11-router_spec.lua index afd76b68373a..69e1f6934761 100644 --- a/spec/01-unit/11-router_spec.lua +++ b/spec/01-unit/11-router_spec.lua @@ -896,5 +896,58 @@ describe("Router", function() end) end) end) + + describe("trailing slash", function() + local checks = { + -- upstream url request path expected path strip uri + { "/", "/", "/", true }, + { "/", "/foo/bar", "/", true }, + { "/", "/foo/bar/", "/", true }, + { "/foo/bar", "/", "/foo/bar", true }, + { "/foo/bar/", "/", "/foo/bar/", true }, + { "/foo/bar", "/foo/bar", "/foo/bar", true }, + { "/foo/bar/", "/foo/bar", "/foo/bar", true }, + { "/foo/bar", "/foo/bar/", "/foo/bar/", true }, + { "/foo/bar/", "/foo/bar/", "/foo/bar/", true }, + { "/", "/", "/", false }, + { "/", "/foo/bar", "/foo/bar", false }, + { "/", "/foo/bar/", "/foo/bar/", false }, + { "/foo/bar", "/", "/foo/bar", false }, + { "/foo/bar/", "/", "/foo/bar/", false }, + { "/foo/bar", "/foo/bar", "/foo/bar/foo/bar", false }, + { "/foo/bar/", "/foo/bar", "/foo/bar/foo/bar", false }, + { "/foo/bar", "/foo/bar/", "/foo/bar/foo/bar/", false }, + { "/foo/bar/", "/foo/bar/", "/foo/bar/foo/bar/", false }, + } + + for i, args in ipairs(checks) do + + local config = args[4] == true and "(strip_uri = on) " or "(strip_uri = off)" + local space = string.sub(args[1], -1) == "/" and "" or " " + + it(config .. " is not appended to upstream uri " .. args[1] .. + space .. " when requesting " .. args[2], function() + + local use_case_apis = { + { + name = "api-1", + strip_uri = args[4], + upstream_url = "http://httpbin.org" .. args[1], + uris = { + args[2], + }, + } + } + + local router = assert(Router.new(use_case_apis) ) + + local _ngx = mock_ngx("GET", args[2], {}) + local api, upstream = router.exec(_ngx) + assert.same(use_case_apis[1], api) + assert.equal(args[1], upstream.path) + assert.equal(args[3], _ngx.var.uri) + end) + end + end) end) end) diff --git a/spec/02-integration/05-proxy/01-router_spec.lua b/spec/02-integration/05-proxy/01-router_spec.lua index 4a5ed55af0bd..211e9bfb4329 100644 --- a/spec/02-integration/05-proxy/01-router_spec.lua +++ b/spec/02-integration/05-proxy/01-router_spec.lua @@ -387,4 +387,98 @@ describe("Router", function() assert.equal("fixture-api", res.headers["kong-api-name"]) end) end) + + describe("trailing slash", function() + local checks = { + -- upstream url request path expected path strip uri + { "/", "/", "/", nil }, + { "/", "/foo/bar", "/", nil }, + { "/", "/foo/bar/", "/", nil }, + { "/foo/bar", "/", "/foo/bar", nil }, + { "/foo/bar/", "/", "/foo/bar/", nil }, + { "/foo/bar", "/foo/bar", "/foo/bar", nil }, + { "/foo/bar/", "/foo/bar", "/foo/bar", nil }, + { "/foo/bar", "/foo/bar/", "/foo/bar/", nil }, + { "/foo/bar/", "/foo/bar/", "/foo/bar/", nil }, + { "/", "/", "/", true }, + { "/", "/foo/bar", "/", true }, + { "/", "/foo/bar/", "/", true }, + { "/foo/bar", "/", "/foo/bar", true }, + { "/foo/bar/", "/", "/foo/bar/", true }, + { "/foo/bar", "/foo/bar", "/foo/bar", true }, + { "/foo/bar/", "/foo/bar", "/foo/bar", true }, + { "/foo/bar", "/foo/bar/", "/foo/bar/", true }, + { "/foo/bar/", "/foo/bar/", "/foo/bar/", true }, + { "/", "/", "/", false }, + { "/", "/foo/bar", "/foo/bar", false }, + { "/", "/foo/bar/", "/foo/bar/", false }, + { "/foo/bar", "/", "/foo/bar", false }, + { "/foo/bar/", "/", "/foo/bar/", false }, + { "/foo/bar", "/foo/bar", "/foo/bar/foo/bar", false }, + { "/foo/bar/", "/foo/bar", "/foo/bar/foo/bar", false }, + { "/foo/bar", "/foo/bar/", "/foo/bar/foo/bar/", false }, + { "/foo/bar/", "/foo/bar/", "/foo/bar/foo/bar/", false }, + } + + setup(function() + helpers.dao:truncate_tables() + + for i, args in ipairs(checks) do + assert(helpers.dao.apis:insert { + name = "localbin-" .. i, + strip_uri = args[4], + upstream_url = "http://localhost:9999" .. args[1], + hosts = { + "localbin-" .. i .. ".com", + }, + uris = { + args[2], + } + }) + end + + assert(helpers.start_kong { + nginx_conf = "spec/fixtures/custom_nginx.template", + }) + end) + + teardown(function() + helpers.stop_kong() + end) + + local function check(i, request_uri, expected_uri) + return function() + local res = assert(client:send { + method = "GET", + path = request_uri, + headers = { + ["Host"] = "localbin-" .. i .. ".com", + } + }) + + local json = assert.res_status(200, res) + local data = cjson.decode(json) + + assert.equal(expected_uri, data.vars.request_uri) + end + end + + for i, args in ipairs(checks) do + + local config = "(strip_uri = n/a)" + + if args[4] == true then + config = "(strip_uri = on) " + + elseif args[4] == false then + config = "(strip_uri = off)" + end + + local space = string.sub(args[1], -1) == "/" and "" or " " + + it(config .. " is not appended to upstream uri " .. args[1] .. + space .. " when requesting " .. args[2], + check(i, args[2], args[3])) + end + end) end) diff --git a/spec/fixtures/custom_nginx.template b/spec/fixtures/custom_nginx.template index f5967d0f5280..7072e7d8619a 100644 --- a/spec/fixtures/custom_nginx.template +++ b/spec/fixtures/custom_nginx.template @@ -193,6 +193,51 @@ http { return 200; } + location / { + content_by_lua_block { + local cjson = require "cjson" + local var = ngx.var + local req = ngx.req + + req.read_body() + + local json = cjson.encode { + vars = { + uri = var.uri, + host = var.host, + hostname = var.hostname, + https = var.https, + scheme = var.scheme, + is_args = var.is_args, + server_addr = var.server_addr, + server_port = var.server_port, + server_name = var.server_name, + server_protocol = var.server_protocol, + remote_addr = var.remote_addr, + remote_port = var.remote_port, + realip_remote_addr = var.realip_remote_addr, + realip_remote_port = var.realip_remote_port, + binary_remote_addr = var.binary_remote_addr, + request = var.request, + request_uri = var.request_uri, + request_time = var.request_time, + request_length = var.request_length, + request_method = var.request_method, + bytes_received = var.bytes_received, + }, + method = req.get_method(), + headers = req.get_headers(0), + uri_args = req.get_uri_args(0), + post_args = req.get_post_args(0), + http_version = req.http_version(), + } + + ngx.status = 200 + ngx.say(json) + ngx.exit(200) + } + } + location /headers-inspect { content_by_lua_block { local cjson = require "cjson"