diff --git a/kong/core/plugins_iterator.lua b/kong/core/plugins_iterator.lua index 265a36afb5af..4a4ac9f07ba3 100644 --- a/kong/core/plugins_iterator.lua +++ b/kong/core/plugins_iterator.lua @@ -41,7 +41,8 @@ local function load_plugin_configuration(api_id, consumer_id, plugin_name) load_plugin_into_memory, api_id, consumer_id, plugin_name) if err then - responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + ngx.ctx.delay_response = false + return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) end if plugin ~= nil and plugin.enabled then return plugin.config or {} diff --git a/kong/init.lua b/kong/init.lua index 33a8ac297b71..c6d2e9b18511 100644 --- a/kong/init.lua +++ b/kong/init.lua @@ -352,10 +352,20 @@ function Kong.access() local ctx = ngx.ctx core.access.before(ctx) + ctx.delay_response = true + for plugin, plugin_conf in plugins_iterator(singletons.loaded_plugins, true) do - plugin.handler:access(plugin_conf) + if not ctx.delayed_response then + plugin.handler:access(plugin_conf) + end end + if ctx.delayed_response then + return responses.flush_delayed_response(ctx) + end + + ctx.delay_response = false + core.access.after(ctx) end diff --git a/kong/plugins/acl/handler.lua b/kong/plugins/acl/handler.lua index 761e38c7a404..70d924191a87 100644 --- a/kong/plugins/acl/handler.lua +++ b/kong/plugins/acl/handler.lua @@ -70,7 +70,7 @@ function ACLHandler:access(conf) local acls, err = singletons.cache:get(cache_key, nil, load_acls_into_memory, consumer_id) if err then - responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) end if not acls then acls = EMPTY diff --git a/kong/plugins/oauth2/access.lua b/kong/plugins/oauth2/access.lua index c2465b2e342f..c47ab39d1768 100644 --- a/kong/plugins/oauth2/access.lua +++ b/kong/plugins/oauth2/access.lua @@ -543,17 +543,14 @@ function _M.execute(conf) if ngx.req.get_method() == "POST" then local uri = ngx.var.uri - local from, _ = string_find(uri, "/oauth2/token", nil, true) - + local from = string_find(uri, "/oauth2/token", nil, true) if from then - issue_token(conf) - - else - from, _ = string_find(uri, "/oauth2/authorize", nil, true) + return issue_token(conf) + end - if from then - authorize(conf) - end + from = string_find(uri, "/oauth2/authorize", nil, true) + if from then + return authorize(conf) end end diff --git a/kong/tools/responses.lua b/kong/tools/responses.lua index 51867805c76f..50c2d278ca54 100644 --- a/kong/tools/responses.lua +++ b/kong/tools/responses.lua @@ -21,6 +21,8 @@ local cjson = require "cjson.safe" local meta = require "kong.meta" +local type = type + --local server_header = _KONG._NAME .. "/" .. _KONG._VERSION local server_header = meta._NAME .. "/" .. meta._VERSION @@ -102,6 +104,18 @@ local function send_response(status_code) -- @param content (Optional) The content to send as a response. -- @return ngx.exit (Exit current context) return function(content, headers) + local ctx = ngx.ctx + + if ctx.delay_response and not ctx.delayed_response then + ctx.delayed_response = { + status_code = status_code, + content = content, + headers = headers, + } + + return + end + if status_code == _M.status_codes.HTTP_INTERNAL_SERVER_ERROR then if content then ngx.log(ngx.ERR, tostring(content)) @@ -137,6 +151,19 @@ local function send_response(status_code) end end +function _M.flush_delayed_response(ctx) + ctx.delay_response = false + + if type(ctx.delayed_response_callback) == "function" then + ctx.delayed_response_callback(ctx) + return -- avoid tail call + end + + _M.send(ctx.delayed_response.status_code, + ctx.delayed_response.content, + ctx.delayed_response.headers) +end + -- Generate sugar methods (closures) for the most used HTTP status codes. for status_code_name, status_code in pairs(_M.status_codes) do _M["send_" .. status_code_name] = send_response(status_code) diff --git a/spec/01-unit/009-responses_spec.lua b/spec/01-unit/009-responses_spec.lua index 04db5ac221ff..7fe313ecf31f 100644 --- a/spec/01-unit/009-responses_spec.lua +++ b/spec/01-unit/009-responses_spec.lua @@ -61,7 +61,7 @@ describe("Response helpers", function() it("calls `ngx.log` if and only if a 500 status code was given", function() responses.send_HTTP_BAD_REQUEST() assert.stub(ngx.log).was_not_called() - + responses.send_HTTP_BAD_REQUEST("error") assert.stub(ngx.log).was_not_called() @@ -119,4 +119,70 @@ describe("Response helpers", function() assert.stub(ngx.exit).was.called_with(501) end) end) + + describe("delayed response", function() + it("does not call ngx.say/ngx.exit if `ctx.delayed_response = true`", function() + ngx.ctx.delay_response = true + + responses.send(401, "Unauthorized", { ["X-Hello"] = "world" }) + assert.stub(ngx.say).was_not_called() + assert.stub(ngx.exit).was_not_called() + assert.not_equal("world", ngx.header["X-Hello"]) + end) + + it("flush_delayed_response() sends delayed response's status/header/body", function() + ngx.ctx.delay_response = true + + responses.send(401, "Unauthorized", { ["X-Hello"] = "world" }) + responses.flush_delayed_response(ngx.ctx) + + assert.stub(ngx.say).was.called_with("{\"message\":\"Unauthorized\"}") + assert.stub(ngx.exit).was.called_with(401) + assert.equal("world", ngx.header["X-Hello"]) + assert.is_false(ngx.ctx.delay_response) + end) + + it("delayed response cannot be overriden", function() + ngx.ctx.delay_response = true + + responses.send(401, "Unauthorized") + responses.send(200, "OK") + responses.flush_delayed_response(ngx.ctx) + + assert.stub(ngx.say).was.called_with("{\"message\":\"Unauthorized\"}") + assert.stub(ngx.exit).was.called_with(401) + end) + + it("flush_delayed_response() use custom callback if set", function() + local s = spy.new(function(ctx) end) + + do + local old_type = _G.type + + -- luacheck: ignore + _G.type = function(v) + if v == s then + return "function" + end + + return old_type(v) + end + + finally(function() + _G.type = old_type + end) + end + + package.loaded["kong.tools.responses"] = nil + responses = require "kong.tools.responses" + + ngx.ctx.delay_response = true + ngx.ctx.delayed_response_callback = s + + responses.send(401, "Unauthorized", { ["X-Hello"] = "world" }) + responses.flush_delayed_response(ngx.ctx) + + assert.spy(s).was.called_with(ngx.ctx) + end) + end) end) diff --git a/spec/02-integration/05-proxy/03-plugins_triggering_spec.lua b/spec/02-integration/05-proxy/03-plugins_triggering_spec.lua index 2b2ebfc19167..1abe01896d3e 100644 --- a/spec/02-integration/05-proxy/03-plugins_triggering_spec.lua +++ b/spec/02-integration/05-proxy/03-plugins_triggering_spec.lua @@ -165,6 +165,132 @@ describe("Plugins triggering", function() assert.equal("5", res.headers["x-ratelimit-limit-hour"]) end) + describe("short-circuited requests", function() + local FILE_LOG_PATH = os.tmpname() + + setup(function() + if client then + client:close() + end + + helpers.stop_kong() + helpers.dao:truncate_tables() + + local api = assert(helpers.dao.apis:insert { + name = "example", + hosts = { "mock_upstream" }, + upstream_url = helpers.mock_upstream_url, + }) + + -- plugin able to short-circuit a request + assert(helpers.dao.plugins:insert { + name = "key-auth", + api_id = api.id, + }) + + -- response/body filter plugin + assert(helpers.dao.plugins:insert { + name = "dummy", + api_id = api.id, + config = { + append_body = "appended from body filtering", + } + }) + + -- log phase plugin + assert(helpers.dao.plugins:insert { + name = "file-log", + api_id = api.id, + config = { + path = FILE_LOG_PATH, + }, + }) + + assert(helpers.start_kong { + nginx_conf = "spec/fixtures/custom_nginx.template", + }) + + client = helpers.proxy_client() + end) + + teardown(function() + if client then + client:close() + end + + os.remove(FILE_LOG_PATH) + + helpers.stop_kong() + end) + + it("execute a log plugin", function() + local utils = require "kong.tools.utils" + local cjson = require "cjson" + local pl_path = require "pl.path" + local pl_file = require "pl.file" + local pl_stringx = require "pl.stringx" + + local uuid = utils.uuid() + + local res = assert(client:send { + method = "GET", + path = "/status/200", + headers = { + ["Host"] = "mock_upstream", + ["X-UUID"] = uuid, + -- /!\ no key credential + } + }) + assert.res_status(401, res) + + -- TEST: ensure that our logging plugin was executed and wrote + -- something to disk. + + helpers.wait_until(function() + return pl_path.exists(FILE_LOG_PATH) and pl_path.getsize(FILE_LOG_PATH) > 0 + end, 3) + + local log = pl_file.read(FILE_LOG_PATH) + local log_message = cjson.decode(pl_stringx.strip(log)) + assert.equal("127.0.0.1", log_message.client_ip) + assert.equal(uuid, log_message.request.headers["x-uuid"]) + end) + + it("execute a header_filter plugin", function() + local res = assert(client:send { + method = "GET", + path = "/status/200", + headers = { + ["Host"] = "mock_upstream", + } + }) + assert.res_status(401, res) + + -- TEST: ensure that the dummy plugin was executed by checking + -- that headers have been injected in the header_filter phase + -- Plugins such as CORS need to run on short-circuited requests + -- as well. + + assert.not_nil(res.headers["dummy-plugin"]) + end) + + it("execute a body_filter plugin", function() + local res = assert(client:send { + method = "GET", + path = "/status/200", + headers = { + ["Host"] = "mock_upstream", + } + }) + local body = assert.res_status(401, res) + + -- TEST: ensure that the dummy plugin was executed by checking + -- that the body filtering phase has run + + assert.matches("appended from body filtering", body, nil, true) + end) + end) + describe("anonymous reports execution", function() -- anonymous reports are implemented as a plugin which is being executed -- by the plugins runloop, but which doesn't have a schema diff --git a/spec/fixtures/custom_plugins/kong/plugins/dummy/handler.lua b/spec/fixtures/custom_plugins/kong/plugins/dummy/handler.lua index 6d7e208ac62e..d60b899dcea0 100644 --- a/spec/fixtures/custom_plugins/kong/plugins/dummy/handler.lua +++ b/spec/fixtures/custom_plugins/kong/plugins/dummy/handler.lua @@ -21,6 +21,19 @@ function DummyHandler:header_filter(conf) DummyHandler.super.header_filter(self) ngx.header["Dummy-Plugin"] = conf.resp_header_value + + if conf.append_body then + ngx.header["Content-Length"] = nil + end +end + + +function DummyHandler:body_filter(conf) + DummyHandler.super.body_filter(self) + + if conf.append_body and not ngx.arg[2] then + ngx.arg[1] = string.sub(ngx.arg[1], 1, -2) .. conf.append_body + end end diff --git a/spec/fixtures/custom_plugins/kong/plugins/dummy/schema.lua b/spec/fixtures/custom_plugins/kong/plugins/dummy/schema.lua index ccee388296a2..b193b332fe39 100644 --- a/spec/fixtures/custom_plugins/kong/plugins/dummy/schema.lua +++ b/spec/fixtures/custom_plugins/kong/plugins/dummy/schema.lua @@ -1,5 +1,6 @@ return { fields = { resp_header_value = { type = "string", default = "1" }, + append_body = { type = "string" }, } }