diff --git a/kong/core/handler.lua b/kong/core/handler.lua index ed5a3ac7a8e5..c83af12d5659 100644 --- a/kong/core/handler.lua +++ b/kong/core/handler.lua @@ -18,6 +18,7 @@ -- -- @see https://github.com/openresty/lua-nginx-module#ngxctx +local url = require "socket.url" local utils = require "kong.tools.utils" local reports = require "kong.core.reports" local stringy = require "stringy" @@ -25,8 +26,10 @@ local resolver = require "kong.core.resolver" local constants = require "kong.constants" local certificate = require "kong.core.certificate" -local table_insert = table.insert +local type = type +local ipairs = ipairs local math_floor = math.floor +local table_insert = table.insert local MULT = 10^3 local function round(num) @@ -43,7 +46,7 @@ return { access = { before = function() ngx.ctx.KONG_ACCESS_START = ngx.now() - ngx.ctx.api, ngx.ctx.upstream_url = resolver.execute() + ngx.ctx.api, ngx.ctx.upstream_url, ngx.var.upstream_host = resolver.execute(ngx.var.request_uri, ngx.req.get_headers()) end, -- Only executed if the `resolver` module found an API and allows nginx to proxy it. after = function() @@ -53,12 +56,14 @@ return { ngx.ctx.KONG_PROXIED = true -- Append any querystring parameters modified during plugins execution - local upstream_url = unpack(stringy.split(ngx.ctx.upstream_url, "?")) - if utils.table_size(ngx.req.get_uri_args()) > 0 then - upstream_url = upstream_url.."?"..ngx.encode_args(ngx.req.get_uri_args()) + local upstream_url = ngx.ctx.upstream_url + local uri_args = ngx.req.get_uri_args() + if utils.table_size(uri_args) > 0 then + upstream_url = upstream_url.."?"..utils.encode_args(uri_args) end - -- Set the `$upstream_url` variable for the `proxy_pass` nginx's directive. + -- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx + -- directive in kong.yml. ngx.var.upstream_url = upstream_url end }, diff --git a/kong/core/resolver.lua b/kong/core/resolver.lua index 46d020505cd8..f76bb4c7e526 100644 --- a/kong/core/resolver.lua +++ b/kong/core/resolver.lua @@ -3,6 +3,7 @@ local cache = require "kong.tools.database_cache" local stringy = require "stringy" local constants = require "kong.constants" local responses = require "kong.tools.responses" + local table_insert = table.insert local string_match = string.match local string_find = string.find @@ -46,7 +47,7 @@ local function get_upstream_url(api) return result end -local function get_host_from_url(val) +local function get_host_from_upstream_url(val) local parsed_url = url.parse(val) local port @@ -99,7 +100,7 @@ function _M.load_apis_in_memory() end function _M.find_api_by_request_host(req_headers, apis_dics) - local all_hosts = {} + local hosts_list = {} for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do local hosts = req_headers[header_name] if hosts then @@ -109,9 +110,9 @@ function _M.find_api_by_request_host(req_headers, apis_dics) -- for all values of this header, try to find an API using the apis_by_dns dictionnary for _, host in ipairs(hosts) do host = unpack(stringy.split(host, ":")) - table_insert(all_hosts, host) + table_insert(hosts_list, host) if apis_dics.by_dns[host] then - return apis_dics.by_dns[host] + return apis_dics.by_dns[host], host else -- If the API was not found in the dictionary, maybe it is a wildcard request_host. -- In that case, we need to loop over all of them. @@ -125,7 +126,7 @@ function _M.find_api_by_request_host(req_headers, apis_dics) end end - return nil, all_hosts + return nil, nil, hosts_list end -- To do so, we have to compare entire URI segments (delimited by "/"). @@ -180,13 +181,14 @@ end -- We keep APIs in the database cache for a longer time than usual. -- @see https://github.com/Mashape/kong/issues/15 for an improvement on this. -- --- @param `uri` The URI for this request. --- @return `err` Any error encountered during the retrieval. --- @return `api` The retrieved API, if any. --- @return `hosts` The list of headers values found in Host and X-Host-Override. +-- @param `uri` The URI for this request. +-- @return `err` Any error encountered during the retrieval. +-- @return `api` The retrieved API, if any. +-- @return `matched_host` The host that was matched for this API, if matched. +-- @return `hosts` The list of headers values found in Host and X-Host-Override. -- @return `strip_request_path_pattern` If the API was retrieved by request_path, contain the pattern to strip it from the URI. -local function find_api(uri) - local api, all_hosts, strip_request_path_pattern +local function find_api(uri, headers) + local api, matched_host, hosts_list, strip_request_path_pattern -- Retrieve all APIs local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", _M.load_apis_in_memory, 60) -- 60 seconds cache, longer than usual @@ -195,37 +197,37 @@ local function find_api(uri) end -- Find by Host header - api, all_hosts = _M.find_api_by_request_host(ngx.req.get_headers(), apis_dics) - + api, matched_host, hosts_list = _M.find_api_by_request_host(headers, apis_dics) -- If it was found by Host, return if api then - return nil, api, all_hosts + return nil, api, matched_host, hosts_list end -- Otherwise, we look for it by request_path. We have to loop over all APIs and compare the requested URI. api, strip_request_path_pattern = _M.find_api_by_request_path(uri, apis_dics.request_path_arr) - return nil, api, all_hosts, strip_request_path_pattern + return nil, api, nil, hosts_list, strip_request_path_pattern end local function url_has_path(url) - local _, count_slashes = string.gsub(url, "/", "") + local _, count_slashes = string_gsub(url, "/", "") return count_slashes > 2 end -function _M.execute() - local uri = stringy.split(ngx.var.request_uri, "?")[1] - local err, api, hosts, strip_request_path_pattern = find_api(uri) +function _M.execute(request_uri, request_headers) + local uri = unpack(stringy.split(request_uri, "?")) + local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers) if err then return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) elseif not api then return responses.send_HTTP_NOT_FOUND { message = "API not found with these values", - request_host = hosts, + request_host = hosts_list, request_path = uri } end + local upstream_host local upstream_url = get_upstream_url(api) -- If API was retrieved by request_path and the request_path needs to be stripped @@ -235,13 +237,15 @@ function _M.execute() upstream_url = upstream_url..uri - -- Set the if api.preserve_host then - ngx.var.upstream_host = ngx.req.get_headers()["host"] - else - ngx.var.upstream_host = get_host_from_url(upstream_url) + upstream_host = matched_host end - return api, upstream_url + + if upstream_host == nil then + upstream_host = get_host_from_upstream_url(upstream_url) + end + + return api, upstream_url, upstream_host end return _M diff --git a/kong/tools/http_client.lua b/kong/tools/http_client.lua index 14a25957ae1d..5b8e9ccfa7bc 100644 --- a/kong/tools/http_client.lua +++ b/kong/tools/http_client.lua @@ -55,7 +55,7 @@ local function with_body(method) else headers["content-type"] = "application/x-www-form-urlencoded" if type(body) == "table" then - body = ngx.encode_args(body) + body = ngx.encode_args(body, true) end end @@ -75,7 +75,7 @@ local function without_body(method) if not headers then headers = {} end if querystring then - url = string.format("%s?%s", url, ngx.encode_args(querystring)) + url = string.format("%s?%s", url, ngx.encode_args(querystring, true)) end return http_call { diff --git a/kong/tools/ngx_stub.lua b/kong/tools/ngx_stub.lua index 08deee0b54c3..e8ff9b5c872c 100644 --- a/kong/tools/ngx_stub.lua +++ b/kong/tools/ngx_stub.lua @@ -5,6 +5,102 @@ -- Monkeypatches the global `ngx` table. local reg = require "rex_pcre" +local utils = require "kong.tools.utils" + +-- DICT Proxy +-- https://github.com/bsm/fakengx/blob/master/fakengx.lua + +local SharedDict = {} + +local function set(data, key, value) + data[key] = { + value = value, + info = {expired = false} + } +end + +function SharedDict:new() + return setmetatable({data = {}}, {__index = self}) +end + +function SharedDict:get(key) + return self.data[key] and self.data[key].value, nil +end + +function SharedDict:set(key, value) + set(self.data, key, value) + return true, nil, false +end + +SharedDict.safe_set = SharedDict.set + +function SharedDict:add(key, value) + if self.data[key] ~= nil then + return false, "exists", false + end + + set(self.data, key, value) + return true, nil, false +end + +function SharedDict:replace(key, value) + if self.data[key] == nil then + return false, "not found", false + end + + set(self.data, key, value) + return true, nil, false +end + +function SharedDict:delete(key) + self.data[key] = nil +end + +function SharedDict:incr(key, value) + if not self.data[key] then + return nil, "not found" + elseif type(self.data[key].value) ~= "number" then + return nil, "not a number" + end + + self.data[key].value = self.data[key].value + value + return self.data[key].value, nil +end + +function SharedDict:flush_all() + for _, item in pairs(self.data) do + item.info.expired = true + end +end + +function SharedDict:flush_expired(n) + local data = self.data + local flushed = 0 + + for key, item in pairs(self.data) do + if item.info.expired then + data[key] = nil + flushed = flushed + 1 + if n and flushed == n then + break + end + end + end + + self.data = data + + return flushed +end + +local shared = {} +local shared_mt = { + __index = function(self, key) + if shared[key] == nil then + shared[key] = SharedDict:new() + end + return shared[key] + end +} _G.ngx = { req = {}, @@ -19,6 +115,7 @@ _G.ngx = { timer = { at = function() end }, + shared = setmetatable({}, shared_mt), re = { match = reg.match, gsub = function(str, pattern, sub) @@ -29,37 +126,5 @@ _G.ngx = { encode_base64 = function(str) return string.format("base64_%s", str) end, - -- Builds a querystring from a table, separated by `&` - -- @param `tab` The key/value parameters - -- @param `key` The parent key if the value is multi-dimensional (optional) - -- @return `querystring` A string representing the built querystring - encode_args = function(tab, key) - local query = {} - local keys = {} - - for k in pairs(tab) do - keys[#keys+1] = k - end - - table.sort(keys) - - for _, name in ipairs(keys) do - local value = tab[name] - if key then - name = string.format("%s[%s]", tostring(key), tostring(name)) - end - if type(value) == "table" then - query[#query+1] = ngx.encode_args(value, name) - else - value = tostring(value) - if value ~= "" then - query[#query+1] = string.format("%s=%s", name, value) - else - query[#query+1] = name - end - end - end - - return table.concat(query, "&") - end + encode_args = utils.encode_args } diff --git a/kong/tools/utils.lua b/kong/tools/utils.lua index 8423d3cd5e8b..3df7b96a6b75 100644 --- a/kong/tools/utils.lua +++ b/kong/tools/utils.lua @@ -1,8 +1,23 @@ --- --- Module containing some general utility functions +-- Module containing some general utility functions used in many places in Kong. +-- +-- NOTE: Before implementing a function here, consider if it will be used in many places +-- across Kong. If not, a local function in the appropriate module is prefered. +-- +local url = require "socket.url" local uuid = require "lua_uuid" +local type = type +local pairs = pairs +local ipairs = ipairs +local tostring = tostring +local table_sort = table.sort +local table_concat = table.concat +local table_insert = table.insert +local string_find = string.find +local string_format = string.format + local _M = {} --- Generates a random unique string @@ -11,6 +26,63 @@ function _M.random_string() return uuid():gsub("-", "") end +--- URL escape and format key and value +-- An obligatory url.unescape pass must be done to prevent double-encoding +-- already encoded values (which contain a '%' character that `url.escape` escapes) +local function encode_args_value(key, value, raw) + if not raw then + key = url.unescape(key) + key = url.escape(key) + end + if value ~= nil then + if not raw then + value = url.unescape(value) + value = url.escape(value) + end + return string_format("%s=%s", key, value) + else + return key + end +end + +--- Encode a Lua table to a querystring +-- Tries to mimic ngx_lua's `ngx.encode_args`, but also percent-encode querystring values. +-- Supports multi-value query args, boolean values. +-- @TODO drop and use `ngx.encode_args` once it implements percent-encoding. +-- @see https://github.com/Mashape/kong/issues/749 +-- @param[type=table] args A key/value table containing the query args to encode +-- @treturn string A valid querystring (without the prefixing '?') +function _M.encode_args(args, raw) + local query = {} + local keys = {} + + for k in pairs(args) do + keys[#keys+1] = k + end + + table_sort(keys) + + for _, key in ipairs(keys) do + local value = args[key] + if type(value) == "table" then + for _, sub_value in ipairs(value) do + query[#query+1] = encode_args_value(key, sub_value, raw) + end + elseif value == true then + query[#query+1] = encode_args_value(key, nil, raw) + elseif value ~= false and value ~= nil then + value = tostring(value) + if value ~= "" then + query[#query+1] = encode_args_value(key, value, raw) + elseif raw then + query[#query+1] = key + end + end + end + + return table_concat(query, "&") +end + --- Calculates a table size. -- All entries both in array and hash part. -- @param t The table to use @@ -99,7 +171,7 @@ function _M.add_error(errors, k, v) errors[k] = setmetatable({errors[k]}, err_list_mt) end - table.insert(errors[k], v) + table_insert(errors[k], v) else errors[k] = v end @@ -118,7 +190,7 @@ function _M.load_module_if_exists(module_name) if status then return true, res -- Here we match any character because if a module has a dash '-' in its name, we would need to escape it. - elseif type(res) == "string" and string.find(res, "module '"..module_name.."' not found", nil, true) then + elseif type(res) == "string" and string_find(res, "module '"..module_name.."' not found", nil, true) then return false else error(res) diff --git a/spec/integration/admin_api/apis_routes_spec.lua b/spec/integration/admin_api/apis_routes_spec.lua index 572aed7d214f..762cfe67daba 100644 --- a/spec/integration/admin_api/apis_routes_spec.lua +++ b/spec/integration/admin_api/apis_routes_spec.lua @@ -18,77 +18,69 @@ describe("Admin API", function() local BASE_URL = spec_helper.API_URL.."/apis/" describe("POST", function() - it("[SUCCESS] should create an API", function() send_content_types(BASE_URL, "POST", { - name="api-POST-tests", - request_host="api.mockbin.com", - upstream_url="http://mockbin.com" - }, 201, nil, {drop_db=true}) + name = "api-POST-tests", + request_host = "api.mockbin.com", + upstream_url = "http://mockbin.com" + }, 201, nil, {drop_db = true}) end) - it("[FAILURE] should notify of malformed body", function() local response, status = http_client.post(BASE_URL, '{"hello":"world"', {["content-type"] = "application/json"}) assert.are.equal(400, status) assert.are.equal('{"message":"Cannot parse JSON body"}\n', response) end) - it("[FAILURE] should return proper errors", function() send_content_types(BASE_URL, "POST", {}, 400, '{"upstream_url":"upstream_url is required","request_path":"At least a \'request_host\' or a \'request_path\' must be specified","request_host":"At least a \'request_host\' or a \'request_path\' must be specified"}') - send_content_types(BASE_URL, "POST", {request_host="api.mockbin.com"}, + send_content_types(BASE_URL, "POST", {request_host = "api.mockbin.com"}, 400, '{"upstream_url":"upstream_url is required"}') send_content_types(BASE_URL, "POST", { - request_host="api.mockbin.com", - upstream_url="http://mockbin.com" + request_host = "api.mockbin.com", + upstream_url = "http://mockbin.com" }, 409, '{"request_host":"request_host already exists with value \'api.mockbin.com\'"}') end) - end) describe("PUT", function() - setup(function() spec_helper.drop_db() end) it("[SUCCESS] should create and update", function() local api = send_content_types(BASE_URL, "PUT", { - name="api-PUT-tests", - request_host="api.mockbin.com", - upstream_url="http://mockbin.com" - }, 201, nil, {drop_db=true}) + name = "api-PUT-tests", + request_host = "api.mockbin.com", + upstream_url = "http://mockbin.com" + }, 201, nil, {drop_db = true}) api = send_content_types(BASE_URL, "PUT", { - id=api.id, - name="api-PUT-tests-updated", - request_host="updated-api.mockbin.com", - upstream_url="http://mockbin.com" + id = api.id, + name = "api-PUT-tests-updated", + request_host = "updated-api.mockbin.com", + upstream_url = "http://mockbin.com" }, 200) assert.equal("api-PUT-tests-updated", api.name) end) - it("[FAILURE] should return proper errors", function() send_content_types(BASE_URL, "PUT", {}, 400, '{"upstream_url":"upstream_url is required","request_path":"At least a \'request_host\' or a \'request_path\' must be specified","request_host":"At least a \'request_host\' or a \'request_path\' must be specified"}') - send_content_types(BASE_URL, "PUT", {request_host="api.mockbin.com"}, + send_content_types(BASE_URL, "PUT", {request_host = "api.mockbin.com"}, 400, '{"upstream_url":"upstream_url is required"}') send_content_types(BASE_URL, "PUT", { - request_host="updated-api.mockbin.com", - upstream_url="http://mockbin.com" + request_host = "updated-api.mockbin.com", + upstream_url = "http://mockbin.com" }, 409, '{"request_host":"request_host already exists with value \'updated-api.mockbin.com\'"}') end) - end) describe("GET", function() - setup(function() spec_helper.drop_db() spec_helper.seed_db(10) @@ -102,9 +94,8 @@ describe("Admin API", function() assert.equal(10, table.getn(body.data)) assert.equal(10, body.total) end) - it("should retrieve a paginated set", function() - local response, status = http_client.get(BASE_URL, {size=3}) + local response, status = http_client.get(BASE_URL, {size = 3}) assert.equal(200, status) local body_page_1 = json.decode(response) assert.truthy(body_page_1.data) @@ -112,7 +103,7 @@ describe("Admin API", function() assert.truthy(body_page_1.next) assert.equal(10, body_page_1.total) - response, status = http_client.get(BASE_URL, {size=3,offset=body_page_1.next}) + response, status = http_client.get(BASE_URL, {size = 3, offset = body_page_1.next}) assert.equal(200, status) local body_page_2 = json.decode(response) assert.truthy(body_page_2.data) @@ -121,7 +112,7 @@ describe("Admin API", function() assert.not_same(body_page_1, body_page_2) assert.equal(10, body_page_2.total) - response, status = http_client.get(BASE_URL, {size=4,offset=body_page_2.next}) + response, status = http_client.get(BASE_URL, {size = 4, offset = body_page_2.next}) assert.equal(200, status) local body_page_3 = json.decode(response) assert.truthy(body_page_3.data) @@ -130,7 +121,6 @@ describe("Admin API", function() assert.falsy(body_page_3.next) assert.not_same(body_page_2, body_page_3) end) - end) end) @@ -141,71 +131,64 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - api = {{ request_host="mockbin.com", upstream_url="http://mockbin.com" }} + api = { + {request_host = "mockbin.com", upstream_url = "http://mockbin.com"} + } } api = fixtures.api[1] end) describe("GET", function() - it("should retrieve by id", function() local response, status = http_client.get(BASE_URL..api.id) assert.equal(200, status) local body = json.decode(response) assert.same(api, body) end) - it("should retrieve by name", function() local response, status = http_client.get(BASE_URL..api.name) assert.equal(200, status) local body = json.decode(response) assert.same(api, body) end) - end) describe("PATCH", function() - it("[SUCCESS] should update an API", function() - local response, status = http_client.patch(BASE_URL..api.id, {name="patch-updated"}) + local response, status = http_client.patch(BASE_URL..api.id, {name = "patch-updated"}) assert.equal(200, status) local body = json.decode(response) assert.same("patch-updated", body.name) api = body - response, status = http_client.patch(BASE_URL..api.name, {name="patch-updated-json"}, {["content-type"]="application/json"}) + response, status = http_client.patch(BASE_URL..api.name, {name = "patch-updated-json"}, {["content-type"] = "application/json"}) assert.equal(200, status) body = json.decode(response) assert.same("patch-updated-json", body.name) api = body end) - it("[FAILURE] should return proper errors", function() - local _, status = http_client.patch(BASE_URL.."hello", {name="patch-updated"}) + local _, status = http_client.patch(BASE_URL.."hello", {name = "patch-updated"}) assert.equal(404, status) - local response, status = http_client.patch(BASE_URL..api.id, {upstream_url=""}) + local response, status = http_client.patch(BASE_URL..api.id, {upstream_url = ""}) assert.equal(400, status) assert.equal('{"upstream_url":"upstream_url is not a url"}\n', response) end) - end) describe("DELETE", function() - it("[FAILURE] should return proper errors", function() local _, status = http_client.delete(BASE_URL.."hello") assert.equal(404, status) end) - it("[SUCCESS] should delete an API", function() local response, status = http_client.delete(BASE_URL..api.id) assert.equal(204, status) assert.falsy(response) end) - end) describe("/apis/:api/plugins/", function() @@ -214,19 +197,19 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - api = {{ request_host="mockbin.com", upstream_url="http://mockbin.com" }} + api = { + {request_host = "mockbin.com", upstream_url = "http://mockbin.com"} + } } api = fixtures.api[1] BASE_URL = BASE_URL..api.id.."/plugins/" end) describe("POST", function() - it("[FAILURE] should return proper errors", function() send_content_types(BASE_URL, "POST", {}, 400, '{"name":"name is required"}') end) - it("[SUCCESS] should create a plugin configuration", function() local response, status = http_client.post(BASE_URL, { name = "key-auth", @@ -240,15 +223,14 @@ describe("Admin API", function() response, status = http_client.post(BASE_URL, { name = "key-auth", - config = {key_names={"apikey"}} - }, {["content-type"]="application/json"}) + config = {key_names = {"apikey"}} + }, {["content-type"] = "application/json"}) assert.equal(201, status) body = json.decode(response) _, err = dao_plugins:delete({id = body.id, name = body.name}) assert.falsy(err) end) - end) describe("PUT", function() @@ -258,7 +240,6 @@ describe("Admin API", function() send_content_types(BASE_URL, "PUT", {}, 400, '{"name":"name is required"}') end) - it("[SUCCESS] should create and update", function() local response, status = http_client.put(BASE_URL, { name = "key-auth", @@ -288,7 +269,6 @@ describe("Admin API", function() body = json.decode(response) assert.equal("updated_apikey", body.config.key_names[1]) end) - it("should override a plugin's `config` if partial", function() local response, status = http_client.put(BASE_URL, { id = plugin_id, @@ -313,7 +293,6 @@ describe("Admin API", function() end) describe("GET", function() - it("should retrieve all", function() local response, status = http_client.get(BASE_URL) assert.equal(200, status) @@ -321,7 +300,6 @@ describe("Admin API", function() assert.truthy(body.data) assert.equal(1, table.getn(body.data)) end) - end) describe("/apis/:api/plugins/:plugin", function() @@ -331,8 +309,12 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - api = {{ request_host="mockbin.com", upstream_url="http://mockbin.com" }}, - plugin = {{ name = "key-auth", config = { key_names = { "apikey" }}, __api = 1 }} + api = { + {request_host="mockbin.com", upstream_url="http://mockbin.com"} + }, + plugin = { + {name = "key-auth", config = {key_names = {"apikey"}}, __api = 1} + } } api = fixtures.api[1] plugin = fixtures.plugin[1] @@ -340,35 +322,30 @@ describe("Admin API", function() end) describe("GET", function() - it("should retrieve by id", function() local response, status = http_client.get(BASE_URL..plugin.id) assert.equal(200, status) local body = json.decode(response) assert.same(plugin, body) end) - end) describe("PATCH", function() - it("[SUCCESS] should update a plugin", function() - local response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"]={"key_updated"}}) + local response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"] = {"key_updated"}}) assert.equal(200, status) local body = json.decode(response) assert.same("key_updated", body.config.key_names[1]) - response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"]={"key_updated-json"}}, {["content-type"]="application/json"}) + response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"] = {"key_updated-json"}}, {["content-type"] = "application/json"}) assert.equal(200, status) body = json.decode(response) assert.same("key_updated-json", body.config.key_names[1]) end) - it("[FAILURE] should return proper errors", function() local _, status = http_client.patch(BASE_URL.."b6cca0aa-4537-11e5-af97-23a06d98af51", {}) assert.equal(404, status) end) - it("should not override a plugin's `config` if partial", function() -- This is delicate since a plugin's `config` is a text field in a DB like Cassandra local _, status = http_client.patch(BASE_URL..plugin.id, { @@ -385,22 +362,18 @@ describe("Admin API", function() assert.same({"key_set_null_test_updated"}, body.config.key_names) assert.equal(true, body.config.hide_credentials) end) - end) describe("DELETE", function() - it("[FAILURE] should return proper errors", function() local _, status = http_client.delete(BASE_URL.."b6cca0aa-4537-11e5-af97-23a06d98af51") assert.equal(404, status) end) - it("[SUCCESS] should delete a plugin configuration", function() local response, status = http_client.delete(BASE_URL..plugin.id) assert.equal(204, status) assert.falsy(response) end) - end) end) end) diff --git a/spec/integration/proxy/api_resolver_spec.lua b/spec/integration/proxy/resolver_spec.lua similarity index 73% rename from spec/integration/proxy/api_resolver_spec.lua rename to spec/integration/proxy/resolver_spec.lua index 90557dc9401b..597b0386ef4d 100644 --- a/spec/integration/proxy/api_resolver_spec.lua +++ b/spec/integration/proxy/resolver_spec.lua @@ -21,7 +21,6 @@ local function parse_cert(cert) end describe("Resolver", function() - setup(function() spec_helper.prepare_db() spec_helper.insert_fixtures { @@ -55,22 +54,6 @@ describe("Resolver", function() spec_helper.stop_kong() end) - describe("Test URI", function() - - it("should URL decode the URI with querystring", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", { hello = "world"}, {host = "mockbin-uri.com"}) - assert.equal(200, status) - assert.equal("http://mockbin.org/request/hello%2f?hello=world", cjson.decode(response).url) - end) - - it("should URL decode the URI without querystring", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", nil, {host = "mockbin-uri.com"}) - assert.equal(200, status) - assert.equal("http://mockbin.org/request/hello%2f", cjson.decode(response).url) - end) - - end) - describe("Inexistent API", function() it("should return Not Found when the API is not in Kong", function() local response, status, headers = http_client.get(spec_helper.STUB_GET_URL, nil, {host = "foo.com"}) @@ -171,18 +154,6 @@ describe("Resolver", function() assert.equal("/somerequest_path/status/200", body.request_path) assert.equal(404, status) end) - it("should proxy and strip the request_path if `strip_request_path` is true", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") - assert.equal(200, status) - local body = cjson.decode(response) - assert.equal("http://mockbin.com/request", body.url) - end) - it("should proxy and strip the request_path if `strip_request_path` is true if request_path has pattern characters", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request") - assert.equal(200, status) - local body = cjson.decode(response) - assert.equal("http://mockbin.com/request", body.url) - end) it("should proxy when the request_path has a deep level", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/deep/request_path/status/200") assert.equal(200, status) @@ -191,33 +162,11 @@ describe("Resolver", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/mockbin?foo=bar") assert.equal(200, status) end) - it("should not strip if the `request_path` pattern is repeated in the request_uri", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request") - assert.equal(200, status) - local body = cjson.decode(response) - local upstream_url = body.log.entries[1].request.url - assert.equal("http://mockbin.com/har/of/request", upstream_url) - end) - it("should not add a trailing slash when strip_path is enabled", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", { hello = "world"}) - assert.equal(200, status) - assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) - end) it("should not add a trailing slash when strip_path is disabled", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", { hello = "world"}) + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", {hello = "world"}) assert.equal(200, status) assert.equal("http://www.mockbin.org/request/test-trailing-slash2?hello=world", cjson.decode(response).url) end) - it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", { hello = "world"}) - assert.equal(200, status) - assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) - end) - it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", { hello = "world"}) - assert.equal(200, status) - assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) - end) end) it("should return the correct Server and Via headers when the request was proxied", function() @@ -240,7 +189,7 @@ describe("Resolver", function() end) end) - describe("Preseve Host", function() + describe("preserve_host", function() it("should not preserve the host (default behavior)", function() local response, status = http_client.get(PROXY_URL.."/get", nil, {host = "httpbin-nopreserve.com"}) assert.equal(200, status) @@ -255,5 +204,69 @@ describe("Resolver", function() assert.equal("httpbin-preserve.com", parsed_response.headers["Host"]) end) end) - + + describe("strip_path", function() + it("should strip the request_path if `strip_request_path` is true", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") + assert.equal(200, status) + local body = cjson.decode(response) + assert.equal("http://mockbin.com/request", body.url) + end) + it("should strip the request_path if `strip_request_path` is true if `request_path` has pattern characters", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request") + assert.equal(200, status) + local body = cjson.decode(response) + assert.equal("http://mockbin.com/request", body.url) + end) + it("should not strip if the `request_path` pattern is repeated in the request_uri", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request") + assert.equal(200, status) + local body = cjson.decode(response) + local upstream_url = body.log.entries[1].request.url + assert.equal("http://mockbin.com/har/of/request", upstream_url) + end) + it("should not add a trailing slash when strip_path is enabled", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + end) + + describe("Percent-encoding", function() + it("should leave percent-encoded values in URI untouched", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2Fworld", {}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request/hello%2fworld", cjson.decode(response).url) + end) + it("should leave untouched percent-encoded values in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {foo = "abc%7Cdef%2c%20world"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?foo=abc%7cdef%2c%20world", cjson.decode(response).url) + end) + it("should leave untouched percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {["hello%20world"] = "foo"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?hello%20world=foo", cjson.decode(response).url) + end) + it("should percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {["hello world"] = "foo"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?hello%20world=foo", cjson.decode(response).url) + end) + it("should percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {foo = "abc|def, world"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?foo=abc%7cdef%2c%20world", cjson.decode(response).url) + end) + end) end) diff --git a/spec/plugins/acl/api_spec.lua b/spec/plugins/acl/api_spec.lua index 02d31b3b6ff6..13a4cc2a2ce3 100644 --- a/spec/plugins/acl/api_spec.lua +++ b/spec/plugins/acl/api_spec.lua @@ -41,7 +41,7 @@ describe("ACLs API", function() end) end) - + describe("PUT", function() it("[SUCCESS] should create and update", function() @@ -64,9 +64,9 @@ describe("ACLs API", function() end) end) - + end) - + describe("/consumers/:consumer/acl/:id", function() describe("GET", function() @@ -79,7 +79,7 @@ describe("ACLs API", function() end) end) - + describe("PATCH", function() it("[SUCCESS] should update an ACL association", function() @@ -96,7 +96,7 @@ describe("ACLs API", function() end) end) - + describe("DELETE", function() it("[FAILURE] should return proper errors", function() @@ -113,7 +113,7 @@ describe("ACLs API", function() end) end) - + end) - + end) diff --git a/spec/plugins/oauth2/api_spec.lua b/spec/plugins/oauth2/api_spec.lua index 707cd9be4c25..b49f4e2e0a28 100644 --- a/spec/plugins/oauth2/api_spec.lua +++ b/spec/plugins/oauth2/api_spec.lua @@ -18,7 +18,9 @@ describe("OAuth 2 Credentials API", function() setup(function() local fixtures = spec_helper.insert_fixtures { - consumer = {{ username = "bob" }} + consumer = { + {username = "bob"} + } } consumer = fixtures.consumer[1] BASE_URL = spec_helper.API_URL.."/consumers/bob/oauth2/" @@ -27,7 +29,7 @@ describe("OAuth 2 Credentials API", function() describe("POST", function() it("[SUCCESS] should create a oauth2 credential", function() - local response, status = http_client.post(BASE_URL, { name = "Test APP", redirect_uri = "http://google.com/" }) + local response, status = http_client.post(BASE_URL, {name = "Test APP", redirect_uri = "http://google.com/"}) assert.equal(201, status) credential = json.decode(response) assert.equal(consumer.id, credential.consumer_id) @@ -43,11 +45,11 @@ describe("OAuth 2 Credentials API", function() describe("PUT", function() setup(function() - spec_helper.get_env().dao_factory.keyauth_credentials:delete({id=credential.id}) + spec_helper.get_env().dao_factory.keyauth_credentials:delete({id = credential.id}) end) it("[SUCCESS] should create and update", function() - local response, status = http_client.put(BASE_URL, { redirect_uri = "http://google.com/", name = "Test APP" }) + local response, status = http_client.put(BASE_URL, {redirect_uri = "http://google.com/", name = "Test APP"}) assert.equal(201, status) credential = json.decode(response) assert.equal(consumer.id, credential.consumer_id) @@ -89,14 +91,14 @@ describe("OAuth 2 Credentials API", function() describe("PATCH", function() it("[SUCCESS] should update a credential", function() - local response, status = http_client.patch(BASE_URL..credential.id, { redirect_uri = "http://getkong.org/" }) + local response, status = http_client.patch(BASE_URL..credential.id, {redirect_uri = "http://getkong.org/"}) assert.equal(200, status) credential = json.decode(response) assert.equal("http://getkong.org/", credential.redirect_uri) end) it("[FAILURE] should return proper errors", function() - local response, status = http_client.patch(BASE_URL..credential.id, { redirect_uri = "" }) + local response, status = http_client.patch(BASE_URL..credential.id, {redirect_uri = ""}) assert.equal(400, status) assert.equal('{"redirect_uri":"redirect_uri is not a url"}\n', response) end) diff --git a/spec/plugins/request-transformer/access_spec.lua b/spec/plugins/request-transformer/access_spec.lua index bee26658b2b6..a142dd8dc837 100644 --- a/spec/plugins/request-transformer/access_spec.lua +++ b/spec/plugins/request-transformer/access_spec.lua @@ -11,8 +11,8 @@ describe("Request Transformer", function() spec_helper.prepare_db() spec_helper.insert_fixtures { api = { - { name = "tests-request-transformer-1", request_host = "test1.com", upstream_url = "http://mockbin.com" }, - { name = "tests-request-transformer-2", request_host = "test2.com", upstream_url = "http://httpbin.org" } + {name = "tests-request-transformer-1", request_host = "test1.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-2", request_host = "test2.com", upstream_url = "http://httpbin.org"} }, plugin = { { @@ -25,10 +25,10 @@ describe("Request Transformer", function() json = {"newjsonparam:newvalue"} }, remove = { - headers = { "x-to-remove" }, - querystring = { "toremovequery" }, - form = { "toremoveform" }, - json = { "toremovejson" } + headers = {"x-to-remove"}, + querystring = {"toremovequery"}, + form = {"toremoveform"}, + json = {"toremovejson"} } }, __api = 1 @@ -37,12 +37,12 @@ describe("Request Transformer", function() name = "request-transformer", config = { add = { - headers = { "host:mark" } + headers = {"host:mark"} } }, __api = 2 } - }, + } } spec_helper.start_kong() @@ -53,7 +53,6 @@ describe("Request Transformer", function() end) describe("Test adding parameters", function() - it("should add new headers", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) @@ -61,37 +60,32 @@ describe("Request Transformer", function() assert.are.equal("true", body.headers["x-added"]) assert.are.equal("true", body.headers["x-added2"]) end) - it("should add new parameters on POST", function() local response, status = http_client.post(STUB_POST_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal("newvalue", body.postData.params["newformparam"]) end) - it("should add new parameters on POST when existing params exist", function() - local response, status = http_client.post(STUB_POST_URL, { hello = "world" }, {host = "test1.com"}) + local response, status = http_client.post(STUB_POST_URL, {hello = "world"}, {host = "test1.com"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal("world", body.postData.params["hello"]) assert.are.equal("newvalue", body.postData.params["newformparam"]) end) - it("should add new parameters on multipart POST", function() local response, status = http_client.post_multipart(STUB_POST_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal("newvalue", body.postData.params["newformparam"]) end) - it("should add new parameters on multipart POST when existing params exist", function() - local response, status = http_client.post_multipart(STUB_POST_URL, { hello = "world" }, {host = "test1.com"}) + local response, status = http_client.post_multipart(STUB_POST_URL, {hello = "world"}, {host = "test1.com"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal("world", body.postData.params["hello"]) assert.are.equal("newvalue", body.postData.params["newformparam"]) end) - it("should add new paramters on json POST", function() local response, status = http_client.post(STUB_POST_URL, {}, {host = "test1.com", ["content-type"] = "application/json"}) local raw = cjson.decode(response) @@ -99,7 +93,6 @@ describe("Request Transformer", function() assert.are.equal(200, status) assert.are.equal("newvalue", body["newjsonparam"]) end) - it("should add new paramters on json POST when existing params exist", function() local response, status = http_client.post(STUB_POST_URL, {hello = "world"}, {host = "test1.com", ["content-type"] = "application/json"}) local raw = cjson.decode(response) @@ -108,32 +101,27 @@ describe("Request Transformer", function() assert.are.equal("world", body["hello"]) assert.are.equal("newvalue", body["newjsonparam"]) end) - it("should add new parameters on GET", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal("value", body.queryString["newparam"]) end) - it("should change the host header", function() local response, status = http_client.get(spec_helper.PROXY_URL.."/get", {}, {host = "test2.com"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal("mark", body.headers["Host"]) end) - end) describe("Test removing parameters", function() - it("should remove a header", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com", ["x-to-remove"] = "true"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.falsy(body.headers["x-to-remove"]) end) - it("should remove parameters on POST", function() local response, status = http_client.post(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) local body = cjson.decode(response) @@ -141,7 +129,6 @@ describe("Request Transformer", function() assert.falsy(body.postData.params["toremoveform"]) assert.are.same("yes", body.postData.params["nottoremove"]) end) - it("should remove parameters on multipart POST", function() local response, status = http_client.post_multipart(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) local body = cjson.decode(response) @@ -149,7 +136,6 @@ describe("Request Transformer", function() assert.falsy(body.postData.params["toremoveform"]) assert.are.same("yes", body.postData.params["nottoremove"]) end) - it("should remove parameters on json POST", function() local response, status = http_client.post(STUB_POST_URL, {["toremovejson"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com", ["content-type"] = "application/json"}) local raw = cjson.decode(response) @@ -158,7 +144,6 @@ describe("Request Transformer", function() assert.falsy(body["toremovejson"]) assert.are.same("yes", body["nottoremove"]) end) - it("should remove parameters on GET", function() local response, status = http_client.get(STUB_GET_URL, {["toremovequery"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) local body = cjson.decode(response) @@ -166,7 +151,5 @@ describe("Request Transformer", function() assert.falsy(body.queryString["toremovequery"]) assert.are.equal("yes", body.queryString["nottoremove"]) end) - end) - end) diff --git a/spec/unit/core/resolver_spec.lua b/spec/unit/core/resolver_spec.lua index f49b3cd85a05..45a7089625ec 100644 --- a/spec/unit/core/resolver_spec.lua +++ b/spec/unit/core/resolver_spec.lua @@ -1,16 +1,24 @@ -local resolver_access = require "kong.core.resolver" +local resolver = require "kong.core.resolver" -- Stubs require "kong.tools.ngx_stub" + local APIS_FIXTURES = { + -- request_host {name = "mockbin", request_host = "mockbin.com", upstream_url = "http://mockbin.com"}, {name = "mockbin", request_host = "mockbin-auth.com", upstream_url = "http://mockbin.com"}, {name = "mockbin", request_host = "*.wildcard.com", upstream_url = "http://mockbin.com"}, {name = "mockbin", request_host = "wildcard.*", upstream_url = "http://mockbin.com"}, + -- request_path {name = "mockbin", request_path = "/mockbin", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_path = "/mockbin-with-dashes", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_path = "/some/deep/url", upstream_url = "http://mockbin.com"} + {name = "mockbin", request_path = "/mockbin-with-dashes", upstream_url = "http://mockbin.com/some/path"}, + {name = "mockbin", request_path = "/some/deep/url", upstream_url = "http://mockbin.com"}, + -- + {name = "mockbin", request_path = "/strip", upstream_url = "http://mockbin.com/some/path/", strip_request_path = true}, + {name = "mockbin", request_path = "/strip-me", upstream_url = "http://mockbin.com/", strip_request_path = true}, + {name = "preserve-host", request_path = "/preserve-host", request_host = "preserve-host.com", upstream_url = "http://mockbin.com", preserve_host = true} } + _G.dao = { apis = { find_all = function() @@ -21,10 +29,10 @@ _G.dao = { local apis_dics -describe("Resolver Access", function() +describe("Resolver", function() describe("load_apis_in_memory()", function() it("should retrieve all APIs in datastore and return them organized", function() - apis_dics = resolver_access.load_apis_in_memory() + apis_dics = resolver.load_apis_in_memory() assert.equal("table", type(apis_dics)) assert.truthy(apis_dics.by_dns) assert.truthy(apis_dics.request_path_arr) @@ -36,7 +44,7 @@ describe("Resolver Access", function() end) it("should return an array of APIs by request_path", function() assert.equal("table", type(apis_dics.request_path_arr)) - assert.equal(3, #apis_dics.request_path_arr) + assert.equal(6, #apis_dics.request_path_arr) for _, item in ipairs(apis_dics.request_path_arr) do assert.truthy(item.strip_request_path_pattern) assert.truthy(item.request_path) @@ -56,69 +64,215 @@ describe("Resolver Access", function() assert.equal("^wildcard%..+$", apis_dics.wildcard_dns_arr[2].pattern) end) end) - describe("find_api_by_request_path()", function() - it("should return nil when no matching API for that URI", function() - local api = resolver_access.find_api_by_request_path("/", apis_dics.request_path_arr) - assert.falsy(api) + describe("strip_request_path()", function() + it("should strip the api's request_path from the requested URI", function() + assert.equal("/status/200", resolver.strip_request_path("/mockbin/status/200", apis_dics.request_path_arr[1].strip_request_path_pattern)) + assert.equal("/status/200", resolver.strip_request_path("/mockbin-with-dashes/status/200", apis_dics.request_path_arr[2].strip_request_path_pattern)) + assert.equal("/", resolver.strip_request_path("/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) + assert.equal("/", resolver.strip_request_path("/mockbin/", apis_dics.request_path_arr[1].strip_request_path_pattern)) + end) + it("should only strip the first pattern", function() + assert.equal("/mockbin/status/200/mockbin", resolver.strip_request_path("/mockbin/mockbin/status/200/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) + end) + it("should not add final slash", function() + assert.equal("hello", resolver.strip_request_path("hello", apis_dics.request_path_arr[3].strip_request_path_pattern, true)) + assert.equal("/hello", resolver.strip_request_path("hello", apis_dics.request_path_arr[3].strip_request_path_pattern, false)) + end) + end) + + -- Note: ngx.var.request_uri always adds a trailing slash even with a request without any + -- `curl kong:8000` will result in ngx.var.request_uri being '/' + describe("execute()", function() + local DEFAULT_REQUEST_URI = "/" + + it("should find an API by the request's simple Host header", function() + local api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "mockbin.com"}) + assert.same(APIS_FIXTURES[1], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("mockbin.com", upstream_host) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "mockbin-auth.com"}) + assert.same(APIS_FIXTURES[2], api) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = {"example.com", "mockbin.com"}}) + assert.same(APIS_FIXTURES[1], api) end) - it("should return the API for a matching URI", function() - local api = resolver_access.find_api_by_request_path("/mockbin", apis_dics.request_path_arr) + it("should find an API by the request's wildcard Host header", function() + local api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "foobar.wildcard.com"}) + assert.same(APIS_FIXTURES[3], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("mockbin.com", upstream_host) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "something.wildcard.com"}) + assert.same(APIS_FIXTURES[3], api) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "wildcard.com"}) + assert.same(APIS_FIXTURES[4], api) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "wildcard.fr"}) + assert.same(APIS_FIXTURES[4], api) + end) + it("should find an API by the request's URI (path component)", function() + local api, upstream_url, upstream_host = resolver.execute("/mockbin", {}) assert.same(APIS_FIXTURES[5], api) + assert.equal("http://mockbin.com/mockbin", upstream_url) + assert.equal("mockbin.com", upstream_host) - api = resolver_access.find_api_by_request_path("/mockbin-with-dashes", apis_dics.request_path_arr) + api = resolver.execute("/mockbin-with-dashes", {}) assert.same(APIS_FIXTURES[6], api) - api = resolver_access.find_api_by_request_path("/mockbin-with-dashes/and/some/uri", apis_dics.request_path_arr) + api = resolver.execute("/some/deep/url", {}) + assert.same(APIS_FIXTURES[7], api) + + api = resolver.execute("/mockbin-with-dashes/and/some/uri", {}) assert.same(APIS_FIXTURES[6], api) + end) + it("should return a 404 HTTP response if no API was found", function() + local responses = require "kong.tools.responses" + spy.on(responses, "send_HTTP_NOT_FOUND") + finally(function() + responses.send_HTTP_NOT_FOUND:revert() + end) - api = resolver_access.find_api_by_request_path("/dashes-mockbin", apis_dics.request_path_arr) + -- non existant request_path + local api, upstream_url, upstream_host = resolver.execute("/inexistant-mockbin", {}) assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(1) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {}, + request_path = "/inexistant-mockbin" + }) + assert.equal(404, ngx.status) + ngx.status = nil - api = resolver_access.find_api_by_request_path("/some/deep/url", apis_dics.request_path_arr) - assert.same(APIS_FIXTURES[7], api) - end) - end) - describe("find_api_by_request_host()", function() - it("should return nil and a list of all the Host headers in the request when no API was found", function() - local api, all_hosts = resolver_access.find_api_by_request_host({ - Host = "foo.com", - ["X-Host-Override"] = {"bar.com", "hello.com"} - }, apis_dics) + -- non-existant Host + api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "inexistant.com"}) + assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(2) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {"inexistant.com"}, + request_path = "/" + }) + assert.equal(404, ngx.status) + ngx.status = nil + + -- non-existant request_path with many Host headers + api, upstream_url, upstream_host = resolver.execute("/some-path", { + ["Host"] = {"nowhere.com", "inexistant.com"}, + ["X-Host-Override"] = "nowhere.fr" + }) assert.falsy(api) - assert.same({"foo.com", "bar.com", "hello.com"}, all_hosts) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(3) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {"nowhere.com", "inexistant.com", "nowhere.fr"}, + request_path = "/some-path" + }) + assert.equal(404, ngx.status) + ngx.status = nil + + -- when a later part of the URI has a valid request_path + api, upstream_url, upstream_host = resolver.execute("/invalid-part/some-path", {}) + assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(4) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {}, + request_path = "/invalid-part/some-path" + }) + assert.equal(404, ngx.status) + ngx.status = nil end) - it("should return an API when one of the Host headers matches", function() - local api = resolver_access.find_api_by_request_host({Host = "mockbin.com"}, apis_dics) - assert.same(APIS_FIXTURES[1], api) + it("should strip_request_path", function() + local api = resolver.execute("/strip", {}) + assert.same(APIS_FIXTURES[8], api) - api = resolver_access.find_api_by_request_host({Host = "mockbin-auth.com"}, apis_dics) - assert.same(APIS_FIXTURES[2], api) + -- strip when contains pattern characters + api, upstream_url, upstream_host = resolver.execute("/strip-me/hello/world", {}) + assert.same(APIS_FIXTURES[9], api) + assert.equal("http://mockbin.com/hello/world", upstream_url) + assert.equal("mockbin.com", upstream_host) + + -- only strip first match of request_uri + api, upstream_url = resolver.execute("/strip-me/strip-me/hello/world", {}) + assert.same(APIS_FIXTURES[9], api) + assert.equal("http://mockbin.com/strip-me/hello/world", upstream_url) end) - it("should return an API when one of the Host headers matches a wildcard dns", function() - local api = resolver_access.find_api_by_request_host({Host = "wildcard.com"}, apis_dics) - assert.same(APIS_FIXTURES[4], api) - api = resolver_access.find_api_by_request_host({Host = "wildcard.fr"}, apis_dics) - assert.same(APIS_FIXTURES[4], api) + it("should preserve_host", function() + local api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "preserve-host.com"}) + assert.same(APIS_FIXTURES[10], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("preserve-host.com", upstream_host) - api = resolver_access.find_api_by_request_host({Host = "foobar.wildcard.com"}, apis_dics) - assert.same(APIS_FIXTURES[3], api) - api = resolver_access.find_api_by_request_host({Host = "barfoo.wildcard.com"}, apis_dics) - assert.same(APIS_FIXTURES[3], api) + api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, { + ["Host"] = {"inexistant.com", "preserve-host.com"}, + ["X-Host-Override"] = "hello.com" + }) + assert.same(APIS_FIXTURES[10], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("preserve-host.com", upstream_host) + + -- No host given to this request, we extract if from the configured upstream_url + api, upstream_url, upstream_host = resolver.execute("/preserve-host", {}) + assert.same(APIS_FIXTURES[10], api) + assert.equal("http://mockbin.com/preserve-host", upstream_url) + assert.equal("mockbin.com", upstream_host) end) - end) - describe("strip_request_path()", function() - it("should strip the api's request_path from the requested URI", function() - assert.equal("/status/200", resolver_access.strip_request_path("/mockbin/status/200", apis_dics.request_path_arr[1].strip_request_path_pattern)) - assert.equal("/status/200", resolver_access.strip_request_path("/mockbin-with-dashes/status/200", apis_dics.request_path_arr[2].strip_request_path_pattern)) - assert.equal("/", resolver_access.strip_request_path("/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) - assert.equal("/", resolver_access.strip_request_path("/mockbin/", apis_dics.request_path_arr[1].strip_request_path_pattern)) + it("should not decode percent-encoded values in URI", function() + -- they should be forwarded as-is + local api, upstream_url = resolver.execute("/mockbin/path%2Fwith%2Fencoded/values", {}) + assert.same(APIS_FIXTURES[5], api) + assert.equal("http://mockbin.com/mockbin/path%2Fwith%2Fencoded/values", upstream_url) + + api, upstream_url = resolver.execute("/strip-me/path%2Fwith%2Fencoded/values", {}) + assert.same(APIS_FIXTURES[9], api) + assert.equal("http://mockbin.com/path%2Fwith%2Fencoded/values", upstream_url) end) - it("should only strip the first pattern", function() - assert.equal("/mockbin/status/200/mockbin", resolver_access.strip_request_path("/mockbin/mockbin/status/200/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) + it("should not recognized request_path if percent-encoded", function() + local responses = require "kong.tools.responses" + spy.on(responses, "send_HTTP_NOT_FOUND") + finally(function() + responses.send_HTTP_NOT_FOUND:revert() + end) + + local api = resolver.execute("/some/deep%2Furl", {}) + assert.falsy(api) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(1) + assert.equal(404, ngx.status) + ngx.status = nil end) - it("should not add final slash", function() - assert.equal("hello", resolver_access.strip_request_path("hello", apis_dics.request_path_arr[3].strip_request_path_pattern, true)) - assert.equal("/hello", resolver_access.strip_request_path("hello", apis_dics.request_path_arr[3].strip_request_path_pattern, false)) + it("should have or not have a trailing slash depending on the request URI", function() + local api, upstream_url = resolver.execute("/strip/", {}) + assert.same(APIS_FIXTURES[8], api) + assert.equal("http://mockbin.com/some/path/", upstream_url) + + api, upstream_url = resolver.execute("/strip", {}) + assert.same(APIS_FIXTURES[8], api) + assert.equal("http://mockbin.com/some/path", upstream_url) + + api, upstream_url = resolver.execute("/mockbin-with-dashes", {}) + assert.same(APIS_FIXTURES[6], api) + assert.equal("http://mockbin.com/some/path/mockbin-with-dashes", upstream_url) + + api, upstream_url = resolver.execute("/mockbin-with-dashes/", {}) + assert.same(APIS_FIXTURES[6], api) + assert.equal("http://mockbin.com/some/path/mockbin-with-dashes/", upstream_url) + end) + it("should strip the querystring out of the URI", function() + -- it will be re-inserted by core.handler just before proxying, once all plugins have been run and eventually modified it + local api, upstream_url = resolver.execute("/?hello=world&foo=bar", {["Host"] = "mockbin.com"}) + assert.same(APIS_FIXTURES[1], api) + assert.equal("http://mockbin.com/", upstream_url) end) end) end) diff --git a/spec/unit/tools/responses_spec.lua b/spec/unit/tools/responses_spec.lua index 044e3343d9ec..6a8af13a913d 100644 --- a/spec/unit/tools/responses_spec.lua +++ b/spec/unit/tools/responses_spec.lua @@ -14,7 +14,7 @@ describe("Responses", function() ngx.header = {} -- Revert mocked functions for _, v in pairs(ngx) do - if type(v) == "table" and v.revert then + if type(v) == "table" and type(v.revert) == "function" then v:revert() end end diff --git a/spec/unit/tools/utils_spec.lua b/spec/unit/tools/utils_spec.lua index 92e997ca1eb8..84a28b3c4eb1 100644 --- a/spec/unit/tools/utils_spec.lua +++ b/spec/unit/tools/utils_spec.lua @@ -2,17 +2,87 @@ local utils = require "kong.tools.utils" describe("Utils", function() - describe("strings", function() - local first = utils.random_string() - assert.truthy(first) - assert.falsy(first:find("-")) - local second = utils.random_string() - assert.falsy(first == second) - end) + describe("string", function() + describe("random_string()", function() + it("should return a random string", function() + local first = utils.random_string() + assert.truthy(first) + assert.falsy(first:find("-")) + + local second = utils.random_string() + assert.not_equal(first, second) + end) + end) - describe("tables", function() - describe("#table_size()", function() + describe("encode_args()", function() + it("should encode a Lua table to a querystring", function() + local str = utils.encode_args { + foo = "bar", + hello = "world" + } + assert.equal("foo=bar&hello=world", str) + end) + it("should encode multi-value query args", function() + local str = utils.encode_args { + foo = {"bar", "zoo"}, + hello = "world" + } + assert.equal("foo=bar&foo=zoo&hello=world", str) + end) + it("should percent-encode given values", function() + local str = utils.encode_args { + encode = {"abc|def", ",$@|`"} + } + assert.equal("encode=abc%7cdef&encode=%2c%24%40%7c%60", str) + end) + it("should percent-encode given query args keys", function() + local str = utils.encode_args { + ["hello world"] = "foo" + } + assert.equal("hello%20world=foo", str) + end) + it("should support Lua numbers", function() + local str = utils.encode_args { + a = 1, + b = 2 + } + assert.equal("a=1&b=2", str) + end) + it("should support a boolean argument", function() + local str = utils.encode_args { + a = true, + b = 1 + } + assert.equal("a&b=1", str) + end) + it("should ignore nil and false values", function() + local str = utils.encode_args { + a = nil, + b = false + } + assert.equal("", str) + end) + it("should encode complex query args", function() + local str = utils.encode_args { + multiple = {"hello, world"}, + hello = "world", + ignore = false, + ["multiple values"] = true + } + assert.equal("hello=world&multiple=hello%2c%20world&multiple%20values", str) + end) + it("should not percent-encode if given a `raw` option", function() + -- this is useful for kong.tools.http_client + local str = utils.encode_args({ + ["hello world"] = "foo, bar" + }, true) + assert.equal("hello world=foo, bar", str) + end) + end) + end) + describe("table", function() + describe("table_size()", function() it("should return the size of a table", function() assert.are.same(0, utils.table_size(nil)) assert.are.same(0, utils.table_size({})) @@ -20,44 +90,36 @@ describe("Utils", function() assert.are.same(2, utils.table_size({ foo = "bar", bar = "baz" })) assert.are.same(2, utils.table_size({ "foo", "bar" })) end) - end) - describe("#table_contains()", function() - + describe("table_contains()", function() it("should return false if a value is not contained in a nil table", function() assert.False(utils.table_contains(nil, "foo")) end) - it("should return true if a value is contained in a table", function() local t = { foo = "hello", bar = "world" } assert.True(utils.table_contains(t, "hello")) end) - it("should return false if a value is not contained in a table", function() local t = { foo = "hello", bar = "world" } assert.False(utils.table_contains(t, "foo")) end) - end) - describe("#is_array()", function() - + describe("is_array()", function() it("should know when an array ", function() assert.True(utils.is_array({ "a", "b", "c", "d" })) assert.True(utils.is_array({ ["1"] = "a", ["2"] = "b", ["3"] = "c", ["4"] = "d" })) assert.False(utils.is_array({ "a", "b", "c", foo = "d" })) end) - end) - describe("#add_error()", function() + describe("add_error()", function() local add_error = utils.add_error it("should create a table if given `errors` is nil", function() assert.same({hello = "world"}, add_error(nil, "hello", "world")) end) - it("should add a key/value when the key does not exists", function() local errors = {hello = "world"} assert.same({ @@ -65,10 +127,8 @@ describe("Utils", function() foo = "bar" }, add_error(errors, "foo", "bar")) end) - it("should transform previous values to a list if the same key is given again", function() - local e = nil - + local e e = add_error(e, "key1", "value1") e = add_error(e, "key2", "value2") assert.same({key1 = "value1", key2 = "value2"}, e) @@ -82,10 +142,8 @@ describe("Utils", function() e = add_error(e, "key2", "value7") assert.same({key1 = {"value1", "value3", "value4", "value5", "value6"}, key2 = {"value2", "value7"}}, e) end) - it("should also list tables pushed as errors", function() - local e = nil - + local e e = add_error(e, "key1", "value1") e = add_error(e, "key2", "value2") e = add_error(e, "key1", "value3") @@ -100,11 +158,9 @@ describe("Utils", function() keyO = {{message = "some error"}, {message = "another"}} }, e) end) - end) - describe("#load_module_if_exists()", function() - + describe("load_module_if_exists()", function() it("should return false if the module does not exist", function() local loaded, mod assert.has_no.errors(function() @@ -113,7 +169,6 @@ describe("Utils", function() assert.False(loaded) assert.falsy(mod) end) - it("should throw an error if the module is invalid", function() local loaded, mod assert.has.errors(function() @@ -122,7 +177,6 @@ describe("Utils", function() assert.falsy(loaded) assert.falsy(mod) end) - it("should load a module if it was found and valid", function() local loaded, mod assert.has_no.errors(function() @@ -132,7 +186,6 @@ describe("Utils", function() assert.truthy(mod) assert.are.same("All your base are belong to us.", mod.exposed) end) - end) end) end)