Skip to content

Commit

Permalink
Merge pull request #752 from Mashape/fix/resolver-query-encoding
Browse files Browse the repository at this point in the history
fix(resolver) percent-encode querystring before proxying
  • Loading branch information
thibaultcha committed Dec 2, 2015
2 parents c9b3a0e + 6274cc4 commit 51918e8
Show file tree
Hide file tree
Showing 13 changed files with 640 additions and 316 deletions.
17 changes: 11 additions & 6 deletions kong/core/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
--
-- @see https://github.com/openresty/lua-nginx-module#ngxctx

local url = require "socket.url"
local utils = require "kong.tools.utils"
local reports = require "kong.core.reports"
local stringy = require "stringy"
local resolver = require "kong.core.resolver"
local constants = require "kong.constants"
local certificate = require "kong.core.certificate"

local table_insert = table.insert
local type = type
local ipairs = ipairs
local math_floor = math.floor
local table_insert = table.insert

local MULT = 10^3
local function round(num)
Expand All @@ -43,7 +46,7 @@ return {
access = {
before = function()
ngx.ctx.KONG_ACCESS_START = ngx.now()
ngx.ctx.api, ngx.ctx.upstream_url = resolver.execute()
ngx.ctx.api, ngx.ctx.upstream_url, ngx.var.upstream_host = resolver.execute(ngx.var.request_uri, ngx.req.get_headers())
end,
-- Only executed if the `resolver` module found an API and allows nginx to proxy it.
after = function()
Expand All @@ -53,12 +56,14 @@ return {
ngx.ctx.KONG_PROXIED = true

-- Append any querystring parameters modified during plugins execution
local upstream_url = unpack(stringy.split(ngx.ctx.upstream_url, "?"))
if utils.table_size(ngx.req.get_uri_args()) > 0 then
upstream_url = upstream_url.."?"..ngx.encode_args(ngx.req.get_uri_args())
local upstream_url = ngx.ctx.upstream_url
local uri_args = ngx.req.get_uri_args()
if utils.table_size(uri_args) > 0 then
upstream_url = upstream_url.."?"..utils.encode_args(uri_args)
end

-- Set the `$upstream_url` variable for the `proxy_pass` nginx's directive.
-- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx
-- directive in kong.yml.
ngx.var.upstream_url = upstream_url
end
},
Expand Down
54 changes: 29 additions & 25 deletions kong/core/resolver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ local cache = require "kong.tools.database_cache"
local stringy = require "stringy"
local constants = require "kong.constants"
local responses = require "kong.tools.responses"

local table_insert = table.insert
local string_match = string.match
local string_find = string.find
Expand Down Expand Up @@ -46,7 +47,7 @@ local function get_upstream_url(api)
return result
end

local function get_host_from_url(val)
local function get_host_from_upstream_url(val)
local parsed_url = url.parse(val)

local port
Expand Down Expand Up @@ -99,7 +100,7 @@ function _M.load_apis_in_memory()
end

function _M.find_api_by_request_host(req_headers, apis_dics)
local all_hosts = {}
local hosts_list = {}
for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do
local hosts = req_headers[header_name]
if hosts then
Expand All @@ -109,9 +110,9 @@ function _M.find_api_by_request_host(req_headers, apis_dics)
-- for all values of this header, try to find an API using the apis_by_dns dictionnary
for _, host in ipairs(hosts) do
host = unpack(stringy.split(host, ":"))
table_insert(all_hosts, host)
table_insert(hosts_list, host)
if apis_dics.by_dns[host] then
return apis_dics.by_dns[host]
return apis_dics.by_dns[host], host
else
-- If the API was not found in the dictionary, maybe it is a wildcard request_host.
-- In that case, we need to loop over all of them.
Expand All @@ -125,7 +126,7 @@ function _M.find_api_by_request_host(req_headers, apis_dics)
end
end

return nil, all_hosts
return nil, nil, hosts_list
end

-- To do so, we have to compare entire URI segments (delimited by "/").
Expand Down Expand Up @@ -180,13 +181,14 @@ end
-- We keep APIs in the database cache for a longer time than usual.
-- @see https://github.com/Mashape/kong/issues/15 for an improvement on this.
--
-- @param `uri` The URI for this request.
-- @return `err` Any error encountered during the retrieval.
-- @return `api` The retrieved API, if any.
-- @return `hosts` The list of headers values found in Host and X-Host-Override.
-- @param `uri` The URI for this request.
-- @return `err` Any error encountered during the retrieval.
-- @return `api` The retrieved API, if any.
-- @return `matched_host` The host that was matched for this API, if matched.
-- @return `hosts` The list of headers values found in Host and X-Host-Override.
-- @return `strip_request_path_pattern` If the API was retrieved by request_path, contain the pattern to strip it from the URI.
local function find_api(uri)
local api, all_hosts, strip_request_path_pattern
local function find_api(uri, headers)
local api, matched_host, hosts_list, strip_request_path_pattern

-- Retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", _M.load_apis_in_memory, 60) -- 60 seconds cache, longer than usual
Expand All @@ -195,37 +197,37 @@ local function find_api(uri)
end

-- Find by Host header
api, all_hosts = _M.find_api_by_request_host(ngx.req.get_headers(), apis_dics)

api, matched_host, hosts_list = _M.find_api_by_request_host(headers, apis_dics)
-- If it was found by Host, return
if api then
return nil, api, all_hosts
return nil, api, matched_host, hosts_list
end

-- Otherwise, we look for it by request_path. We have to loop over all APIs and compare the requested URI.
api, strip_request_path_pattern = _M.find_api_by_request_path(uri, apis_dics.request_path_arr)

return nil, api, all_hosts, strip_request_path_pattern
return nil, api, nil, hosts_list, strip_request_path_pattern
end

local function url_has_path(url)
local _, count_slashes = string.gsub(url, "/", "")
local _, count_slashes = string_gsub(url, "/", "")
return count_slashes > 2
end

function _M.execute()
local uri = stringy.split(ngx.var.request_uri, "?")[1]
local err, api, hosts, strip_request_path_pattern = find_api(uri)
function _M.execute(request_uri, request_headers)
local uri = unpack(stringy.split(request_uri, "?"))
local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers)
if err then
return responses.send_HTTP_INTERNAL_SERVER_ERROR(err)
elseif not api then
return responses.send_HTTP_NOT_FOUND {
message = "API not found with these values",
request_host = hosts,
request_host = hosts_list,
request_path = uri
}
end

local upstream_host
local upstream_url = get_upstream_url(api)

-- If API was retrieved by request_path and the request_path needs to be stripped
Expand All @@ -235,13 +237,15 @@ function _M.execute()

upstream_url = upstream_url..uri

-- Set the
if api.preserve_host then
ngx.var.upstream_host = ngx.req.get_headers()["host"]
else
ngx.var.upstream_host = get_host_from_url(upstream_url)
upstream_host = matched_host
end
return api, upstream_url

if upstream_host == nil then
upstream_host = get_host_from_upstream_url(upstream_url)
end

return api, upstream_url, upstream_host
end

return _M
4 changes: 2 additions & 2 deletions kong/tools/http_client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ local function with_body(method)
else
headers["content-type"] = "application/x-www-form-urlencoded"
if type(body) == "table" then
body = ngx.encode_args(body)
body = ngx.encode_args(body, true)
end
end

Expand All @@ -75,7 +75,7 @@ local function without_body(method)
if not headers then headers = {} end

if querystring then
url = string.format("%s?%s", url, ngx.encode_args(querystring))
url = string.format("%s?%s", url, ngx.encode_args(querystring, true))
end

return http_call {
Expand Down
131 changes: 98 additions & 33 deletions kong/tools/ngx_stub.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,102 @@
-- Monkeypatches the global `ngx` table.

local reg = require "rex_pcre"
local utils = require "kong.tools.utils"

-- DICT Proxy
-- https://github.com/bsm/fakengx/blob/master/fakengx.lua

local SharedDict = {}

local function set(data, key, value)
data[key] = {
value = value,
info = {expired = false}
}
end

function SharedDict:new()
return setmetatable({data = {}}, {__index = self})
end

function SharedDict:get(key)
return self.data[key] and self.data[key].value, nil
end

function SharedDict:set(key, value)
set(self.data, key, value)
return true, nil, false
end

SharedDict.safe_set = SharedDict.set

function SharedDict:add(key, value)
if self.data[key] ~= nil then
return false, "exists", false
end

set(self.data, key, value)
return true, nil, false
end

function SharedDict:replace(key, value)
if self.data[key] == nil then
return false, "not found", false
end

set(self.data, key, value)
return true, nil, false
end

function SharedDict:delete(key)
self.data[key] = nil
end

function SharedDict:incr(key, value)
if not self.data[key] then
return nil, "not found"
elseif type(self.data[key].value) ~= "number" then
return nil, "not a number"
end

self.data[key].value = self.data[key].value + value
return self.data[key].value, nil
end

function SharedDict:flush_all()
for _, item in pairs(self.data) do
item.info.expired = true
end
end

function SharedDict:flush_expired(n)
local data = self.data
local flushed = 0

for key, item in pairs(self.data) do
if item.info.expired then
data[key] = nil
flushed = flushed + 1
if n and flushed == n then
break
end
end
end

self.data = data

return flushed
end

local shared = {}
local shared_mt = {
__index = function(self, key)
if shared[key] == nil then
shared[key] = SharedDict:new()
end
return shared[key]
end
}

_G.ngx = {
req = {},
Expand All @@ -19,6 +115,7 @@ _G.ngx = {
timer = {
at = function() end
},
shared = setmetatable({}, shared_mt),
re = {
match = reg.match,
gsub = function(str, pattern, sub)
Expand All @@ -29,37 +126,5 @@ _G.ngx = {
encode_base64 = function(str)
return string.format("base64_%s", str)
end,
-- Builds a querystring from a table, separated by `&`
-- @param `tab` The key/value parameters
-- @param `key` The parent key if the value is multi-dimensional (optional)
-- @return `querystring` A string representing the built querystring
encode_args = function(tab, key)
local query = {}
local keys = {}

for k in pairs(tab) do
keys[#keys+1] = k
end

table.sort(keys)

for _, name in ipairs(keys) do
local value = tab[name]
if key then
name = string.format("%s[%s]", tostring(key), tostring(name))
end
if type(value) == "table" then
query[#query+1] = ngx.encode_args(value, name)
else
value = tostring(value)
if value ~= "" then
query[#query+1] = string.format("%s=%s", name, value)
else
query[#query+1] = name
end
end
end

return table.concat(query, "&")
end
encode_args = utils.encode_args
}
Loading

0 comments on commit 51918e8

Please sign in to comment.