Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(resolver) percent-encode querystring before proxying #752

Merged
merged 3 commits into from
Dec 2, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions kong/core/handler.lua
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@
--
-- @see https://github.com/openresty/lua-nginx-module#ngxctx

local url = require "socket.url"
local utils = require "kong.tools.utils"
local reports = require "kong.core.reports"
local stringy = require "stringy"
local resolver = require "kong.core.resolver"
local constants = require "kong.constants"
local certificate = require "kong.core.certificate"

local table_insert = table.insert
local type = type
local ipairs = ipairs
local math_floor = math.floor
local table_insert = table.insert

local MULT = 10^3
local function round(num)
Expand All @@ -43,7 +46,7 @@ return {
access = {
before = function()
ngx.ctx.KONG_ACCESS_START = ngx.now()
ngx.ctx.api, ngx.ctx.upstream_url = resolver.execute()
ngx.ctx.api, ngx.ctx.upstream_url, ngx.var.upstream_host = resolver.execute(ngx.var.request_uri, ngx.req.get_headers())
end,
-- Only executed if the `resolver` module found an API and allows nginx to proxy it.
after = function()
Expand All @@ -53,12 +56,14 @@ return {
ngx.ctx.KONG_PROXIED = true

-- Append any querystring parameters modified during plugins execution
local upstream_url = unpack(stringy.split(ngx.ctx.upstream_url, "?"))
if utils.table_size(ngx.req.get_uri_args()) > 0 then
upstream_url = upstream_url.."?"..ngx.encode_args(ngx.req.get_uri_args())
local upstream_url = ngx.ctx.upstream_url
local uri_args = ngx.req.get_uri_args()
if utils.table_size(uri_args) > 0 then
upstream_url = upstream_url.."?"..utils.encode_args(uri_args)
end

-- Set the `$upstream_url` variable for the `proxy_pass` nginx's directive.
-- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx
-- directive in kong.yml.
ngx.var.upstream_url = upstream_url
end
},
Expand Down
54 changes: 29 additions & 25 deletions kong/core/resolver.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ local cache = require "kong.tools.database_cache"
local stringy = require "stringy"
local constants = require "kong.constants"
local responses = require "kong.tools.responses"

local table_insert = table.insert
local string_match = string.match
local string_find = string.find
Expand Down Expand Up @@ -46,7 +47,7 @@ local function get_upstream_url(api)
return result
end

local function get_host_from_url(val)
local function get_host_from_upstream_url(val)
local parsed_url = url.parse(val)

local port
Expand Down Expand Up @@ -99,7 +100,7 @@ function _M.load_apis_in_memory()
end

function _M.find_api_by_request_host(req_headers, apis_dics)
local all_hosts = {}
local hosts_list = {}
for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do
local hosts = req_headers[header_name]
if hosts then
Expand All @@ -109,9 +110,9 @@ function _M.find_api_by_request_host(req_headers, apis_dics)
-- for all values of this header, try to find an API using the apis_by_dns dictionnary
for _, host in ipairs(hosts) do
host = unpack(stringy.split(host, ":"))
table_insert(all_hosts, host)
table_insert(hosts_list, host)
if apis_dics.by_dns[host] then
return apis_dics.by_dns[host]
return apis_dics.by_dns[host], host
else
-- If the API was not found in the dictionary, maybe it is a wildcard request_host.
-- In that case, we need to loop over all of them.
Expand All @@ -125,7 +126,7 @@ function _M.find_api_by_request_host(req_headers, apis_dics)
end
end

return nil, all_hosts
return nil, nil, hosts_list
end

-- To do so, we have to compare entire URI segments (delimited by "/").
Expand Down Expand Up @@ -180,13 +181,14 @@ end
-- We keep APIs in the database cache for a longer time than usual.
-- @see https://github.com/Mashape/kong/issues/15 for an improvement on this.
--
-- @param `uri` The URI for this request.
-- @return `err` Any error encountered during the retrieval.
-- @return `api` The retrieved API, if any.
-- @return `hosts` The list of headers values found in Host and X-Host-Override.
-- @param `uri` The URI for this request.
-- @return `err` Any error encountered during the retrieval.
-- @return `api` The retrieved API, if any.
-- @return `matched_host` The host that was matched for this API, if matched.
-- @return `hosts` The list of headers values found in Host and X-Host-Override.
-- @return `strip_request_path_pattern` If the API was retrieved by request_path, contain the pattern to strip it from the URI.
local function find_api(uri)
local api, all_hosts, strip_request_path_pattern
local function find_api(uri, headers)
local api, matched_host, hosts_list, strip_request_path_pattern

-- Retrieve all APIs
local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", _M.load_apis_in_memory, 60) -- 60 seconds cache, longer than usual
Expand All @@ -195,37 +197,37 @@ local function find_api(uri)
end

-- Find by Host header
api, all_hosts = _M.find_api_by_request_host(ngx.req.get_headers(), apis_dics)

api, matched_host, hosts_list = _M.find_api_by_request_host(headers, apis_dics)
-- If it was found by Host, return
if api then
return nil, api, all_hosts
return nil, api, matched_host, hosts_list
end

-- Otherwise, we look for it by request_path. We have to loop over all APIs and compare the requested URI.
api, strip_request_path_pattern = _M.find_api_by_request_path(uri, apis_dics.request_path_arr)

return nil, api, all_hosts, strip_request_path_pattern
return nil, api, nil, hosts_list, strip_request_path_pattern
end

local function url_has_path(url)
local _, count_slashes = string.gsub(url, "/", "")
local _, count_slashes = string_gsub(url, "/", "")
return count_slashes > 2
end

function _M.execute()
local uri = stringy.split(ngx.var.request_uri, "?")[1]
local err, api, hosts, strip_request_path_pattern = find_api(uri)
function _M.execute(request_uri, request_headers)
local uri = unpack(stringy.split(request_uri, "?"))
local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers)
if err then
return responses.send_HTTP_INTERNAL_SERVER_ERROR(err)
elseif not api then
return responses.send_HTTP_NOT_FOUND {
message = "API not found with these values",
request_host = hosts,
request_host = hosts_list,
request_path = uri
}
end

local upstream_host
local upstream_url = get_upstream_url(api)

-- If API was retrieved by request_path and the request_path needs to be stripped
Expand All @@ -235,13 +237,15 @@ function _M.execute()

upstream_url = upstream_url..uri

-- Set the
if api.preserve_host then
ngx.var.upstream_host = ngx.req.get_headers()["host"]
else
ngx.var.upstream_host = get_host_from_url(upstream_url)
upstream_host = matched_host
end
return api, upstream_url

if upstream_host == nil then
upstream_host = get_host_from_upstream_url(upstream_url)
end

return api, upstream_url, upstream_host
end

return _M
4 changes: 2 additions & 2 deletions kong/tools/http_client.lua
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ local function with_body(method)
else
headers["content-type"] = "application/x-www-form-urlencoded"
if type(body) == "table" then
body = ngx.encode_args(body)
body = ngx.encode_args(body, true)
end
end

Expand All @@ -75,7 +75,7 @@ local function without_body(method)
if not headers then headers = {} end

if querystring then
url = string.format("%s?%s", url, ngx.encode_args(querystring))
url = string.format("%s?%s", url, ngx.encode_args(querystring, true))
end

return http_call {
Expand Down
131 changes: 98 additions & 33 deletions kong/tools/ngx_stub.lua
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,102 @@
-- Monkeypatches the global `ngx` table.

local reg = require "rex_pcre"
local utils = require "kong.tools.utils"

-- DICT Proxy
-- https://github.com/bsm/fakengx/blob/master/fakengx.lua

local SharedDict = {}

local function set(data, key, value)
data[key] = {
value = value,
info = {expired = false}
}
end

function SharedDict:new()
return setmetatable({data = {}}, {__index = self})
end

function SharedDict:get(key)
return self.data[key] and self.data[key].value, nil
end

function SharedDict:set(key, value)
set(self.data, key, value)
return true, nil, false
end

SharedDict.safe_set = SharedDict.set

function SharedDict:add(key, value)
if self.data[key] ~= nil then
return false, "exists", false
end

set(self.data, key, value)
return true, nil, false
end

function SharedDict:replace(key, value)
if self.data[key] == nil then
return false, "not found", false
end

set(self.data, key, value)
return true, nil, false
end

function SharedDict:delete(key)
self.data[key] = nil
end

function SharedDict:incr(key, value)
if not self.data[key] then
return nil, "not found"
elseif type(self.data[key].value) ~= "number" then
return nil, "not a number"
end

self.data[key].value = self.data[key].value + value
return self.data[key].value, nil
end

function SharedDict:flush_all()
for _, item in pairs(self.data) do
item.info.expired = true
end
end

function SharedDict:flush_expired(n)
local data = self.data
local flushed = 0

for key, item in pairs(self.data) do
if item.info.expired then
data[key] = nil
flushed = flushed + 1
if n and flushed == n then
break
end
end
end

self.data = data

return flushed
end

local shared = {}
local shared_mt = {
__index = function(self, key)
if shared[key] == nil then
shared[key] = SharedDict:new()
end
return shared[key]
end
}

_G.ngx = {
req = {},
Expand All @@ -19,6 +115,7 @@ _G.ngx = {
timer = {
at = function() end
},
shared = setmetatable({}, shared_mt),
re = {
match = reg.match,
gsub = function(str, pattern, sub)
Expand All @@ -29,37 +126,5 @@ _G.ngx = {
encode_base64 = function(str)
return string.format("base64_%s", str)
end,
-- Builds a querystring from a table, separated by `&`
-- @param `tab` The key/value parameters
-- @param `key` The parent key if the value is multi-dimensional (optional)
-- @return `querystring` A string representing the built querystring
encode_args = function(tab, key)
local query = {}
local keys = {}

for k in pairs(tab) do
keys[#keys+1] = k
end

table.sort(keys)

for _, name in ipairs(keys) do
local value = tab[name]
if key then
name = string.format("%s[%s]", tostring(key), tostring(name))
end
if type(value) == "table" then
query[#query+1] = ngx.encode_args(value, name)
else
value = tostring(value)
if value ~= "" then
query[#query+1] = string.format("%s=%s", name, value)
else
query[#query+1] = name
end
end
end

return table.concat(query, "&")
end
encode_args = utils.encode_args
}
Loading