diff --git a/kong/plugins/rate-limiting/daos.lua b/kong/plugins/rate-limiting/daos.lua index 8e60b7d45b80..631c0d84176a 100644 --- a/kong/plugins/rate-limiting/daos.lua +++ b/kong/plugins/rate-limiting/daos.lua @@ -1,3 +1,5 @@ return { - tables = {"ratelimiting_metrics"} + tables = { + "ratelimiting_metrics" + }, } diff --git a/kong/plugins/rate-limiting/handler.lua b/kong/plugins/rate-limiting/handler.lua index ba2d959e2f02..833f43926cf5 100644 --- a/kong/plugins/rate-limiting/handler.lua +++ b/kong/plugins/rate-limiting/handler.lua @@ -1,49 +1,64 @@ -- Copyright (C) Kong Inc. - local policies = require "kong.plugins.rate-limiting.policies" -local timestamp = require "kong.tools.timestamp" -local responses = require "kong.tools.responses" local BasePlugin = require "kong.plugins.base_plugin" -local ngx_log = ngx.log + +local kong = kong +local ngx = ngx +local max = math.max +local time = ngx.time local pairs = pairs local tostring = tostring -local ngx_timer_at = ngx.timer.at +local timer_at = ngx.timer.at + local RATELIMIT_LIMIT = "X-RateLimit-Limit" local RATELIMIT_REMAINING = "X-RateLimit-Remaining" + local RateLimitingHandler = BasePlugin:extend() + RateLimitingHandler.PRIORITY = 901 -RateLimitingHandler.VERSION = "0.1.0" +RateLimitingHandler.VERSION = "0.2.0" + local function get_identifier(conf) local identifier - -- Consumer is identified by ip address or authenticated_credential id - if conf.limit_by == "consumer" then - identifier = ngx.ctx.authenticated_consumer and ngx.ctx.authenticated_consumer.id - if not identifier and ngx.ctx.authenticated_credential then -- Fallback on credential - identifier = ngx.ctx.authenticated_credential.id + if conf.limit_by == "consumer" or conf.limit_by == "credential" then + local shared_ctx = kong.ctx.shared + local ngx_ctx = ngx.ctx + + local consumer = shared_ctx.authenticated_consumer or + ngx_ctx.authenticated_consumer + + if conf.limit_by == "consumer" then + identifier = consumer and consumer.id + end + + if not identifier then + local credential = shared_ctx.authenticated_credential or + ngx_ctx.authenticated_credential + + identifier = credential and credential.id end - elseif conf.limit_by == "credential" then - identifier = ngx.ctx.authenticated_credential and ngx.ctx.authenticated_credential.id end if not identifier then - identifier = ngx.var.remote_addr + identifier = kong.client.get_forwarded_ip() end return identifier end + local function get_usage(conf, identifier, current_timestamp, limits) local usage = {} local stop - for name, limit in pairs(limits) do - local current_usage, err = policies[conf.policy].usage(conf, identifier, current_timestamp, name) + for period, limit in pairs(limits) do + local current_usage, err = policies[conf.policy].usage(conf, identifier, period, current_timestamp) if err then return nil, nil, err end @@ -52,30 +67,41 @@ local function get_usage(conf, identifier, current_timestamp, limits) local remaining = limit - current_usage -- Recording usage - usage[name] = { + usage[period] = { limit = limit, - remaining = remaining + remaining = remaining, } if remaining <= 0 then - stop = name + stop = period end end return usage, stop end + +local function increment(premature, conf, ...) + if premature then + return + end + + policies[conf.policy].increment(conf, ...) +end + + function RateLimitingHandler:new() RateLimitingHandler.super.new(self, "rate-limiting") end + function RateLimitingHandler:access(conf) RateLimitingHandler.super.access(self) - local current_timestamp = timestamp.get_utc() + + local current_timestamp = time() * 1000 -- Consumer is identified by ip address or authenticated_credential id local identifier = get_identifier(conf) - local policy = conf.policy local fault_tolerant = conf.fault_tolerant -- Load current metric for configured period @@ -85,45 +111,65 @@ function RateLimitingHandler:access(conf) hour = conf.hour, day = conf.day, month = conf.month, - year = conf.year + year = conf.year, } local usage, stop, err = get_usage(conf, identifier, current_timestamp, limits) if err then if fault_tolerant then - ngx_log(ngx.ERR, "failed to get usage: ", tostring(err)) + kong.log.err("failed to get usage: ", tostring(err)) else - return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + kong.log.err(err) + return kong.response.exit(500, { message = "An unexpected error occurred" }) end end if usage then -- Adding headers if not conf.hide_client_headers then + local headers = {} for k, v in pairs(usage) do - ngx.header[RATELIMIT_LIMIT .. "-" .. k] = v.limit - ngx.header[RATELIMIT_REMAINING .. "-" .. k] = math.max(0, (stop == nil or stop == k) and v.remaining - 1 or v.remaining) -- -increment_value for this current request + if stop == nil or stop == k then + v.remaining = v.remaining - 1 + end + + headers[RATELIMIT_LIMIT .. "-" .. k] = v.limit + headers[RATELIMIT_REMAINING .. "-" .. k] = max(0, v.remaining) end + + kong.ctx.plugin.headers = headers end -- If limit is exceeded, terminate the request if stop then - return responses.send(429, "API rate limit exceeded") + return kong.response.exit(429, { message = "API rate limit exceeded" }) end end - local incr = function(premature, conf, limits, identifier, current_timestamp, value) - if premature then - return + kong.ctx.plugin.timer = function() + local ok, err = timer_at(0, increment, conf, limits, identifier, current_timestamp, 1) + if not ok then + kong.log.err("failed to create timer: ", err) end - policies[policy].increment(conf, limits, identifier, current_timestamp, value) end +end - -- Increment metrics for configured periods if the request goes through - local ok, err = ngx_timer_at(0, incr, conf, limits, identifier, current_timestamp, 1) - if not ok then - ngx_log(ngx.ERR, "failed to create timer: ", err) + +function RateLimitingHandler:header_filter(_) + RateLimitingHandler.super.header_filter(self) + + local headers = kong.ctx.plugin.headers + if headers then + kong.response.set_headers(headers) end end + +function RateLimitingHandler:log(_) + if kong.ctx.plugin.timer then + kong.ctx.plugin.timer() + end +end + + return RateLimitingHandler diff --git a/kong/plugins/rate-limiting/migrations/001_14_to_15.lua b/kong/plugins/rate-limiting/migrations/001_14_to_15.lua index 5099f11d5974..865272adf407 100644 --- a/kong/plugins/rate-limiting/migrations/001_14_to_15.lua +++ b/kong/plugins/rate-limiting/migrations/001_14_to_15.lua @@ -4,6 +4,15 @@ return { ALTER TABLE IF EXISTS ONLY "ratelimiting_metrics" ALTER "period_date" TYPE TIMESTAMP WITH TIME ZONE USING "period_date" AT TIME ZONE 'UTC'; ]], + teardown = function(connector) + assert(connector:connect_migrations()) + assert(connector:query([[ + DROP FUNCTION IF EXISTS "increment_rate_limits_api" (UUID, TEXT, TEXT, TIMESTAMP WITH TIME ZONE, INTEGER) CASCADE; + DROP FUNCTION IF EXISTS "increment_rate_limits" (UUID, TEXT, TEXT, TIMESTAMP WITHOUT TIME ZONE, INTEGER) CASCADE; + DROP FUNCTION IF EXISTS "increment_rate_limits" (UUID, TEXT, TEXT, TIMESTAMP WITH TIME ZONE, INTEGER) CASCADE; + DROP FUNCTION IF EXISTS "increment_rate_limits" (UUID, UUID, TEXT, TEXT, TIMESTAMP WITH TIME ZONE, INTEGER) CASCADE; + ]])) + end, }, cassandra = { diff --git a/kong/plugins/rate-limiting/policies/cluster.lua b/kong/plugins/rate-limiting/policies/cluster.lua index 440b8f9ab1fc..b321339a53de 100644 --- a/kong/plugins/rate-limiting/policies/cluster.lua +++ b/kong/plugins/rate-limiting/policies/cluster.lua @@ -1,209 +1,261 @@ local timestamp = require "kong.tools.timestamp" local cassandra = require "cassandra" + +local kong = kong local concat = table.concat local pairs = pairs +local floor = math.floor local fmt = string.format -local log = ngx.log -local ERR = ngx.ERR -local NULL_UUID = "00000000-0000-0000-0000-000000000000" +local EMPTY_UUID = "00000000-0000-0000-0000-000000000000" return { - ["cassandra"] = { - increment = function(connector, limits, route_id, service_id, identifier, current_timestamp, value) + cassandra = { + increment = function(connector, limits, identifier, current_timestamp, service_id, route_id, value) local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then local res, err = connector:query([[ UPDATE ratelimiting_metrics - SET value = value + ? - WHERE route_id = ? - AND service_id = ? - AND api_id = ? - AND identifier = ? - AND period_date = ? - AND period = ? + SET value = value + ? + WHERE identifier = ? + AND period = ? + AND period_date = ? + AND service_id = ? + AND route_id = ? + AND api_id = ? ]], { cassandra.counter(value), - cassandra.uuid(route_id), - cassandra.uuid(service_id), - cassandra.uuid(NULL_UUID), identifier, - cassandra.timestamp(period_date), period, + cassandra.timestamp(period_date), + cassandra.uuid(service_id), + cassandra.uuid(route_id), + cassandra.uuid(EMPTY_UUID), }) if not res then - log(ERR, "[rate-limiting] cluster policy: could not increment ", - "cassandra counter for period '", period, "': ", err) + kong.log.err("cluster policy: could not increment cassandra counter for period '", + period, "': ", err) end end end return true end, - increment_api = function(connector, limits, api_id, identifier, current_timestamp, value) + increment_api = function(connector, limits, identifier, current_timestamp, api_id, value) local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then local res, err = connector:query([[ UPDATE ratelimiting_metrics - SET value = value + ? - WHERE api_id = ? AND - route_id = ? AND - service_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? + SET value = value + ? + WHERE identifier = ? + AND period = ? + AND period_date = ? + AND service_id = ? + AND route_id = ? + AND api_id = ? ]], { cassandra.counter(value), - cassandra.uuid(api_id), - cassandra.uuid(NULL_UUID), - cassandra.uuid(NULL_UUID), identifier, - cassandra.timestamp(period_date), period, + cassandra.timestamp(period_date), + cassandra.uuid(EMPTY_UUID), + cassandra.uuid(EMPTY_UUID), + cassandra.uuid(api_id), + }) if not res then - log(ERR, "[rate-limiting] cluster policy: could not increment ", - "cassandra counter for period '", period, "': ", err) + kong.log.err("cluster policy: could not increment cassandra counter for period '", + period, "': ", err) end end end return true end, - find = function(connector, route_id, service_id, identifier, current_timestamp, period) + find = function(connector, identifier, period, current_timestamp, service_id, route_id) local periods = timestamp.get_timestamps(current_timestamp) local rows, err = connector:query([[ - SELECT * FROM ratelimiting_metrics - WHERE route_id = ? - AND service_id = ? - AND api_id = ? - AND identifier = ? - AND period_date = ? - AND period = ? + SELECT value + FROM ratelimiting_metrics + WHERE identifier = ? + AND period = ? + AND period_date = ? + AND service_id = ? + AND route_id = ? + AND api_id = ? ]], { - cassandra.uuid(route_id), - cassandra.uuid(service_id), - cassandra.uuid(NULL_UUID), identifier, - cassandra.timestamp(periods[period]), period, + cassandra.timestamp(periods[period]), + cassandra.uuid(service_id), + cassandra.uuid(route_id), + cassandra.uuid(EMPTY_UUID), }) - if not rows then return nil, err - elseif #rows <= 1 then return rows[1] - else return nil, "bad rows result" end + + if not rows then + return nil, err + end + + if #rows <= 1 then + return rows[1] + end + + return nil, "bad rows result" end, - find_api = function(connector, api_id, identifier, current_timestamp, period) + find_api = function(connector, identifier, period, current_timestamp, api_id) local periods = timestamp.get_timestamps(current_timestamp) local rows, err = connector:query([[ - SELECT * - FROM ratelimiting_metrics - WHERE api_id = ? AND - route_id = ? AND - service_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? + SELECT value + FROM ratelimiting_metrics + WHERE identifier = ? + AND period = ? + AND period_date = ? + AND service_id = ? + AND route_id = ? + AND api_id = ? ]], { - cassandra.uuid(api_id), - cassandra.uuid(NULL_UUID), - cassandra.uuid(NULL_UUID), identifier, - cassandra.timestamp(periods[period]), period, + cassandra.timestamp(periods[period]), + cassandra.uuid(EMPTY_UUID), + cassandra.uuid(EMPTY_UUID), + cassandra.uuid(api_id), }) - if not rows then return nil, err - elseif #rows <= 1 then return rows[1] - else return nil, "bad rows result" end + + if not rows then + return nil, err + end + + if #rows <= 1 then + return rows[1] + end + + return nil, "bad rows result" end, }, - ["postgres"] = { - increment = function(connector, limits, route_id, service_id, identifier, current_timestamp, value) - local buf = {} + postgres = { + increment = function(connector, limits, identifier, current_timestamp, service_id, route_id, value) + local buf = { "BEGIN" } + local len = 1 local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - buf[#buf + 1] = fmt([[ - SELECT increment_rate_limits('%s', '%s', '%s', '%s', - to_timestamp('%s') at time zone 'UTC', %d) - ]], route_id, service_id, identifier, period, period_date/1000, value) + len = len + 1 + buf[len] = fmt([[ + INSERT INTO "ratelimiting_metrics" ("identifier", "period", "period_date", "service_id", "route_id", "api_id", "value") + VALUES ('%s', '%s', TO_TIMESTAMP('%s') AT TIME ZONE 'UTC', '%s', '%s', '%s', %d) + ON CONFLICT ("identifier", "period", "period_date", "service_id", "route_id", "api_id") DO UPDATE + SET "value" = "ratelimiting_metrics"."value" + EXCLUDED."value" + ]], identifier, period, floor(period_date / 1000), service_id, route_id, EMPTY_UUID, value) end end - local res, err = connector:query(concat(buf, ";")) - if not res then - return nil, err + if len > 1 then + local sql + if len == 2 then + sql = buf[2] + + else + buf[len + 1] = "COMMIT;" + sql = concat(buf, ";\n") + end + + local res, err = connector:query(sql) + if not res then + return nil, err + end end return true end, - increment_api = function(connector, limits, api_id, identifier, current_timestamp, value) - local buf = {} + increment_api = function(connector, limits, identifier, current_timestamp, api_id, value) + local buf = { "BEGIN" } + local len = 1 local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - buf[#buf+1] = fmt([[ - SELECT increment_rate_limits_api('%s', '%s', '%s', to_timestamp('%s') - at time zone 'UTC', %d) - ]], api_id, identifier, period, period_date/1000, value) + len = len + 1 + buf[len] = fmt([[ + INSERT INTO "ratelimiting_metrics" ("identifier", "period", "period_date", "service_id", "route_id", "api_id", "value") + VALUES ('%s', '%s', TO_TIMESTAMP('%s') AT TIME ZONE 'UTC', '%s', '%s', '%s', %d) + ON CONFLICT ("identifier", "period", "period_date", "service_id", "route_id", "api_id") DO UPDATE + SET "value" = "ratelimiting_metrics"."value" + EXCLUDED."value" + ]], identifier, period, floor(period_date / 1000), EMPTY_UUID, EMPTY_UUID, api_id, value) end end - local res, err = connector:query(concat(buf, ";")) - if not res then - return nil, err + if len > 1 then + local sql + if len == 2 then + sql = buf[2] + + else + buf[len + 1] = "COMMIT;" + sql = concat(buf, ";\n") + end + + local res, err = connector:query(sql) + if not res then + return nil, err + end end return true end, - find = function(connector, route_id, service_id, identifier, current_timestamp, period) + find = function(connector, identifier, period, current_timestamp, service_id, route_id) local periods = timestamp.get_timestamps(current_timestamp) - local q = fmt([[ - SELECT *, extract(epoch from period_date)*1000 AS period_date - FROM ratelimiting_metrics - WHERE route_id = '%s' - AND service_id = '%s' - AND identifier = '%s' - AND period_date = to_timestamp('%s') at time zone 'UTC' - AND period = '%s' - ]], route_id, service_id, identifier, periods[period]/1000, period) - - local res, err = connector:query(q) - if not res then + local sql = fmt([[ + SELECT "value" + FROM "ratelimiting_metrics" + WHERE "identifier" = '%s' + AND "period" = '%s' + AND "period_date" = TO_TIMESTAMP('%s') AT TIME ZONE 'UTC' + AND "service_id" = '%s' + AND "route_id" = '%s' + AND "api_id" = '%s' + LIMIT 1; + ]], identifier, period, floor(periods[period] / 1000), service_id, route_id, EMPTY_UUID) + + local res, err = connector:query(sql) + if not res or err then return nil, err end return res[1] end, - find_api = function(connector, api_id, identifier, current_timestamp, period) + find_api = function(connector, identifier, period, current_timestamp, api_id) local periods = timestamp.get_timestamps(current_timestamp) - local q = fmt([[ - SELECT *, extract(epoch from period_date)*1000 AS period_date - FROM ratelimiting_metrics - WHERE api_id = '%s' AND - identifier = '%s' AND - period_date = to_timestamp('%s') at time zone 'UTC' AND - period = '%s' - ]], api_id, identifier, periods[period]/1000, period) - - local res, err = connector:query(q) - if not res then + local sql = fmt([[ + SELECT "value" + FROM "ratelimiting_metrics" + WHERE "identifier" = '%s' + AND "period" = '%s' + AND "period_date" = TO_TIMESTAMP('%s') AT TIME ZONE 'UTC' + AND "service_id" = '%s' + AND "route_id" = '%s' + AND "api_id" = '%s' + LIMIT 1; + ]], identifier, period, floor(periods[period] / 1000), EMPTY_UUID, EMPTY_UUID, api_id) + + local res, err = connector:query(sql) + if not res or err then return nil, err end return res[1] end, - } + }, } diff --git a/kong/plugins/rate-limiting/policies/init.lua b/kong/plugins/rate-limiting/policies/init.lua index 42dadce59879..ef01cb7b97e1 100644 --- a/kong/plugins/rate-limiting/policies/init.lua +++ b/kong/plugins/rate-limiting/policies/init.lua @@ -1,20 +1,21 @@ -local timestamp = require "kong.tools.timestamp" -local redis = require "resty.redis" local policy_cluster = require "kong.plugins.rate-limiting.policies.cluster" +local timestamp = require "kong.tools.timestamp" local reports = require "kong.reports" +local redis = require "resty.redis" -local ngx_log = ngx.log -local shm = ngx.shared.kong_rate_limiting_counters +local kong = kong local pairs = pairs +local null = ngx.null +local shm = ngx.shared.kong_rate_limiting_counters local fmt = string.format -local NULL_UUID = "00000000-0000-0000-0000-000000000000" +local EMPTY_UUID = "00000000-0000-0000-0000-000000000000" local function is_present(str) - return str and str ~= "" and str ~= ngx.null + return str and str ~= "" and str ~= null end @@ -23,45 +24,46 @@ local function get_ids(conf) local api_id = conf.api_id - if api_id and api_id ~= ngx.null then - return nil, nil, api_id + if api_id and api_id ~= null then + return EMPTY_UUID, EMPTY_UUID, api_id end - api_id = NULL_UUID + api_id = EMPTY_UUID - local route_id = conf.route_id local service_id = conf.service_id + local route_id = conf.route_id - if not route_id or route_id == ngx.null then - route_id = NULL_UUID + if not service_id or service_id == null then + service_id = EMPTY_UUID end - if not service_id or service_id == ngx.null then - service_id = NULL_UUID + if not route_id or route_id == null then + route_id = EMPTY_UUID end - return route_id, service_id, api_id + return service_id, route_id, api_id end -local get_local_key = function(conf, identifier, period_date, name) - local route_id, service_id, api_id = get_ids(conf) +local get_local_key = function(conf, identifier, period, period_date) + local service_id, route_id, api_id = get_ids(conf) - if api_id == NULL_UUID then - return fmt("ratelimit:%s:%s:%s:%s:%s", route_id, service_id, identifier, period_date, name) + if api_id == EMPTY_UUID then + return fmt("ratelimit:%s:%s:%s:%s:%s", route_id, service_id, identifier, + period_date, period) end - return fmt("ratelimit:%s:%s:%s:%s", api_id, identifier, period_date, name) + return fmt("ratelimit:%s:%s:%s:%s", api_id, identifier, period_date, period) end local EXPIRATIONS = { second = 1, minute = 60, - hour = 3600, - day = 86400, - month = 2592000, - year = 31536000, + hour = 3600, + day = 86400, + month = 2592000, + year = 31536000, } @@ -71,11 +73,10 @@ return { local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - local cache_key = get_local_key(conf, identifier, period_date, period) + local cache_key = get_local_key(conf, identifier, period, period_date) local newval, err = shm:incr(cache_key, value, 0) if not newval then - ngx_log(ngx.ERR, "[rate-limiting] could not increment counter ", - "for period '", period, "': ", err) + kong.log.err("could not increment counter for period '", period, "': ", err) return nil, err end end @@ -83,59 +84,63 @@ return { return true end, - usage = function(conf, identifier, current_timestamp, name) + usage = function(conf, identifier, period, current_timestamp) local periods = timestamp.get_timestamps(current_timestamp) - local cache_key = get_local_key(conf, identifier, periods[name], name) + local cache_key = get_local_key(conf, identifier, period, periods[period]) local current_metric, err = shm:get(cache_key) if err then return nil, err end - return current_metric and current_metric or 0 + return current_metric or 0 end }, ["cluster"] = { increment = function(conf, limits, identifier, current_timestamp, value) local db = kong.db - local route_id, service_id, api_id = get_ids(conf) + local service_id, route_id, api_id = get_ids(conf) local policy = policy_cluster[db.strategy] local ok, err - if api_id == NULL_UUID then - ok, err = policy.increment(db.connector, limits, route_id, service_id, - identifier, current_timestamp, value) + if api_id == EMPTY_UUID then + ok, err = policy.increment(db.connector, limits, identifier, current_timestamp, + service_id, route_id, value) else - ok, err = policy.increment_api(db.connector, limits, api_id, identifier, - current_timestamp, value) + ok, err = policy.increment_api(db.connector, limits, identifier, + current_timestamp, api_id, value) end if not ok then - ngx_log(ngx.ERR, "[rate-limiting] cluster policy: could not increment ", - db.strategy, " counter: ", err) + kong.log.err("cluster policy: could not increment ", db.strategy, + " counter: ", err) end return ok, err end, - usage = function(conf, identifier, current_timestamp, name) + usage = function(conf, identifier, period, current_timestamp) local db = kong.db - local route_id, service_id, api_id = get_ids(conf) + local service_id, route_id, api_id = get_ids(conf) local policy = policy_cluster[db.strategy] local row, err - if api_id == NULL_UUID then - row, err = policy.find(db.connector, route_id, service_id, - identifier, current_timestamp, name) + if api_id == EMPTY_UUID then + row, err = policy.find(db.connector, identifier, period, + current_timestamp, service_id, route_id) else - row, err = policy.find_api(db.connector, api_id, identifier, - current_timestamp, name) + row, err = policy.find_api(db.connector, identifier, period, + current_timestamp, api_id) end if err then return nil, err end - return row and row.value or 0 + if row and row.value ~= null and row.value > 0 then + return row.value + end + + return 0 end }, ["redis"] = { @@ -144,20 +149,20 @@ return { red:set_timeout(conf.redis_timeout) local ok, err = red:connect(conf.redis_host, conf.redis_port) if not ok then - ngx_log(ngx.ERR, "failed to connect to Redis: ", err) + kong.log.err("failed to connect to Redis: ", err) return nil, err end local times, err = red:get_reused_times() if err then - ngx_log(ngx.ERR, "failed to get connect reused times: ", err) + kong.log.err("failed to get connect reused times: ", err) return nil, err end if times == 0 and is_present(conf.redis_password) then local ok, err = red:auth(conf.redis_password) if not ok then - ngx_log(ngx.ERR, "failed to auth Redis: ", err) + kong.log.err("failed to auth Redis: ", err) return nil, err end end @@ -173,7 +178,7 @@ return { local ok, err = red:select(conf.redis_database or 0) if not ok then - ngx_log(ngx.ERR, "failed to change Redis database: ", err) + kong.log.err("failed to change Redis database: ", err) return nil, err end end @@ -184,10 +189,10 @@ return { local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do if limits[period] then - local cache_key = get_local_key(conf, identifier, period_date, period) + local cache_key = get_local_key(conf, identifier, period, period_date) local exists, err = red:exists(cache_key) if err then - ngx_log(ngx.ERR, "failed to query Redis: ", err) + kong.log.err("failed to query Redis: ", err) return nil, err end @@ -209,36 +214,39 @@ return { local _, err = red:commit_pipeline() if err then - ngx_log(ngx.ERR, "failed to commit pipeline in Redis: ", err) + kong.log.err("failed to commit pipeline in Redis: ", err) return nil, err end + local ok, err = red:set_keepalive(10000, 100) if not ok then - ngx_log(ngx.ERR, "failed to set Redis keepalive: ", err) + kong.log.err("failed to set Redis keepalive: ", err) return nil, err end return true end, - usage = function(conf, identifier, current_timestamp, name) + usage = function(conf, identifier, period, current_timestamp) local red = redis:new() + red:set_timeout(conf.redis_timeout) + local ok, err = red:connect(conf.redis_host, conf.redis_port) if not ok then - ngx_log(ngx.ERR, "failed to connect to Redis: ", err) + kong.log.err("failed to connect to Redis: ", err) return nil, err end local times, err = red:get_reused_times() if err then - ngx_log(ngx.ERR, "failed to get connect reused times: ", err) + kong.log.err("failed to get connect reused times: ", err) return nil, err end if times == 0 and is_present(conf.redis_password) then local ok, err = red:auth(conf.redis_password) if not ok then - ngx_log(ngx.ERR, "failed to connect to Redis: ", err) + kong.log.err("failed to connect to Redis: ", err) return nil, err end end @@ -254,7 +262,7 @@ return { local ok, err = red:select(conf.redis_database or 0) if not ok then - ngx_log(ngx.ERR, "failed to change Redis database: ", err) + kong.log.err("failed to change Redis database: ", err) return nil, err end end @@ -262,19 +270,20 @@ return { reports.retrieve_redis_version(red) local periods = timestamp.get_timestamps(current_timestamp) - local cache_key = get_local_key(conf, identifier, periods[name], name) + local cache_key = get_local_key(conf, identifier, period, periods[period]) + local current_metric, err = red:get(cache_key) if err then return nil, err end - if current_metric == ngx.null then + if current_metric == null then current_metric = nil end local ok, err = red:set_keepalive(10000, 100) if not ok then - ngx_log(ngx.ERR, "failed to set Redis keepalive: ", err) + kong.log.err("failed to set Redis keepalive: ", err) end return current_metric or 0 diff --git a/spec-old-api/03-plugins/24-rate-limiting/02-policies_spec.lua b/spec-old-api/03-plugins/24-rate-limiting/02-policies_spec.lua index 088cea2ab193..a9ab234b5f3b 100644 --- a/spec-old-api/03-plugins/24-rate-limiting/02-policies_spec.lua +++ b/spec-old-api/03-plugins/24-rate-limiting/02-policies_spec.lua @@ -1,21 +1,27 @@ local uuid = require("kong.tools.utils").uuid local helpers = require "spec.helpers" -local policies = require "kong.plugins.rate-limiting.policies" local timestamp = require "kong.tools.timestamp" describe("Plugin: rate-limiting (policies)", function() describe("cluster", function() - local cluster_policy = policies.cluster - local api_id = uuid() local conf = { api_id = api_id } local identifier = uuid() local dao + local policies setup(function() local _, db _, db, dao = helpers.get_db_utils() - _G.kong = _G.kong or { db = db } + + if _G.kong then + _G.kong.db = db + else + _G.kong = { db = db } + end + + package.loaded["kong.plugins.rate-limiting.policies"] = nil + policies = require "kong.plugins.rate-limiting.policies" dao:truncate_tables() end) @@ -28,8 +34,8 @@ describe("Plugin: rate-limiting (policies)", function() local periods = timestamp.get_timestamps(current_timestamp) for period, period_date in pairs(periods) do - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, + current_timestamp)) assert.equal(0, metric) end end) @@ -48,22 +54,22 @@ describe("Plugin: rate-limiting (policies)", function() } -- First increment - assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) + assert(policies.cluster.increment(conf, limits, identifier, current_timestamp, 1)) -- First select for period, period_date in pairs(periods) do - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, + current_timestamp)) assert.equal(1, metric) end -- Second increment - assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) + assert(policies.cluster.increment(conf, limits, identifier, current_timestamp, 1)) -- Second select for period, period_date in pairs(periods) do - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, + current_timestamp)) assert.equal(2, metric) end @@ -72,7 +78,7 @@ describe("Plugin: rate-limiting (policies)", function() periods = timestamp.get_timestamps(current_timestamp) -- Third increment - assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) + assert(policies.cluster.increment(conf, limits, identifier, current_timestamp, 1)) -- Third select with 1 second delay for period, period_date in pairs(periods) do @@ -81,8 +87,8 @@ describe("Plugin: rate-limiting (policies)", function() expected_value = 1 end - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, + current_timestamp)) assert.equal(expected_value, metric) end end) diff --git a/spec/03-plugins/24-rate-limiting/02-policies_spec.lua b/spec/03-plugins/24-rate-limiting/02-policies_spec.lua index fd27794da54b..bb0c3300a4ec 100644 --- a/spec/03-plugins/24-rate-limiting/02-policies_spec.lua +++ b/spec/03-plugins/24-rate-limiting/02-policies_spec.lua @@ -1,24 +1,30 @@ local uuid = require("kong.tools.utils").uuid local helpers = require "spec.helpers" -local policies = require "kong.plugins.rate-limiting.policies" local timestamp = require "kong.tools.timestamp" for _, strategy in helpers.each_strategy() do describe("Plugin: rate-limiting (policies) [#" .. strategy .. "]", function() describe("cluster", function() - local cluster_policy = policies.cluster - local identifier = uuid() local conf = { route = { id = uuid() }, service = { id = uuid() } } local db local dao + local policies setup(function() local _ _, db, dao = helpers.get_db_utils(strategy) - _G.kong = _G.kong or { db = db } + + if _G.kong then + _G.kong.db = db + else + _G.kong = { db = db } + end + + package.loaded["kong.plugins.rate-limiting.policies"] = nil + policies = require "kong.plugins.rate-limiting.policies" end) after_each(function() @@ -31,8 +37,7 @@ for _, strategy in helpers.each_strategy() do local periods = timestamp.get_timestamps(current_timestamp) for period in pairs(periods) do - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, current_timestamp)) assert.equal(0, metric) end end) @@ -51,22 +56,20 @@ for _, strategy in helpers.each_strategy() do } -- First increment - assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) + assert(policies.cluster.increment(conf, limits, identifier, current_timestamp, 1)) -- First select for period in pairs(periods) do - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, current_timestamp)) assert.equal(1, metric) end -- Second increment - assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) + assert(policies.cluster.increment(conf, limits, identifier, current_timestamp, 1)) -- Second select for period in pairs(periods) do - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, current_timestamp)) assert.equal(2, metric) end @@ -75,7 +78,7 @@ for _, strategy in helpers.each_strategy() do periods = timestamp.get_timestamps(current_timestamp) -- Third increment - assert(cluster_policy.increment(conf, limits, identifier, current_timestamp, 1)) + assert(policies.cluster.increment(conf, limits, identifier, current_timestamp, 1)) -- Third select with 1 second delay for period in pairs(periods) do @@ -84,8 +87,7 @@ for _, strategy in helpers.each_strategy() do expected_value = 1 end - local metric = assert(cluster_policy.usage(conf, identifier, - current_timestamp, period)) + local metric = assert(policies.cluster.usage(conf, identifier, period, current_timestamp)) assert.equal(expected_value, metric) end end) diff --git a/spec/03-plugins/24-rate-limiting/04-access_spec.lua b/spec/03-plugins/24-rate-limiting/04-access_spec.lua index 40df225c8f76..538dd15470c2 100644 --- a/spec/03-plugins/24-rate-limiting/04-access_spec.lua +++ b/spec/03-plugins/24-rate-limiting/04-access_spec.lua @@ -9,12 +9,7 @@ local REDIS_PASSWORD = "" local REDIS_DATABASE = 1 -local SLEEP_TIME = 1 - - local fmt = string.format - - local proxy_client = helpers.proxy_client @@ -29,6 +24,27 @@ local function wait(second_offset) end +local function GET(url, opts, res_status) + local client = proxy_client() + local res, err = client:get(url, opts) + if not res then + client:close() + return nil, err + end + + local body, err = assert.res_status(res_status, res) + if not body then + return nil, err + end + + client:close() + + ngx.sleep(0.010) + + return res, body +end + + local function flush_redis() local redis = require "resty.redis" local red = redis:new() @@ -56,7 +72,7 @@ end for _, strategy in helpers.each_strategy() do - for _, policy in ipairs({"local", "cluster", "redis"}) do + for _, policy in ipairs({ "local", "cluster", "redis" }) do describe(fmt("#flaky Plugin: rate-limiting (access) with policy: %s [#%s]", policy, strategy), function() local bp local db @@ -233,7 +249,6 @@ for _, strategy in helpers.each_strategy() do })) end) - teardown(function() helpers.stop_kong() assert(db:truncate()) @@ -246,21 +261,19 @@ for _, strategy in helpers.each_strategy() do describe("Without authentication (IP address)", function() it("blocks if exceeding limit", function() for i = 1, 6 do - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "test1.com" }, - }) + }, 200) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end -- Additonal request, while limit is 6/minute - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "test1.com" }, - }) + }, 429) + local body = assert.res_status(429, res) local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) @@ -268,32 +281,28 @@ for _, strategy in helpers.each_strategy() do it("counts against the same service register from different routes", function() for i = 1, 3 do - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "test-service1.com" }, - }) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + }, 200) - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end for i = 4, 6 do - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "test-service2.com" }, - }) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + }, 200) - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end -- Additonal request, while limit is 6/minute - local res = proxy_client():get("/status/200", { + local _, body = GET("/status/200", { headers = { Host = "test-service1.com" }, - }) - local body = assert.res_status(429, res) + }, 429) + local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) end) @@ -305,24 +314,21 @@ for _, strategy in helpers.each_strategy() do } for i = 1, 3 do - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "test2.com" }, - }) + }, 200) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - - assert.res_status(200, res) assert.are.same(limits.minute, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(limits.minute - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) assert.are.same(limits.hour, tonumber(res.headers["x-ratelimit-limit-hour"])) assert.are.same(limits.hour - i, tonumber(res.headers["x-ratelimit-remaining-hour"])) end - local res = proxy_client():get("/status/200", { + local res, body = GET("/status/200", { path = "/status/200", headers = { Host = "test2.com" }, - }) - local body = assert.res_status(429, res) + }, 429) + local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) assert.are.equal(2, tonumber(res.headers["x-ratelimit-remaining-hour"])) @@ -333,70 +339,61 @@ for _, strategy in helpers.each_strategy() do describe("API-specific plugin", function() it("blocks if exceeding limit", function() for i = 1, 6 do - local res = proxy_client():get("/status/200?apikey=apikey123", { + local res = GET("/status/200?apikey=apikey123", { headers = { Host = "test3.com" }, - }) + }, 200) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end -- Third query, while limit is 2/minute - local res = proxy_client():get("/status/200?apikey=apikey123", { + local res = GET("/status/200?apikey=apikey123", { headers = { Host = "test3.com" }, - }) + }, 429) + local body = assert.res_status(429, res) local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) -- Using a different key of the same consumer works - local res = proxy_client():get("/status/200?apikey=apikey333", { + GET("/status/200?apikey=apikey333", { headers = { Host = "test3.com" }, - }) - assert.res_status(200, res) + }, 200) end) end) describe("Plugin customized for specific consumer and route", function() it("blocks if exceeding limit", function() for i = 1, 8 do - local res = proxy_client():get("/status/200?apikey=apikey122", { + local res = GET("/status/200?apikey=apikey122", { headers = { Host = "test3.com" }, - }) + }, 200) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - - assert.res_status(200, res) assert.are.same(8, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(8 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end - local res = proxy_client():get("/status/200?apikey=apikey122", { + local _, body = GET("/status/200?apikey=apikey122", { headers = { Host = "test3.com" }, - }) - local body = assert.res_status(429, res) + }, 429) + local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) end) it("blocks if the only rate-limiting plugin existing is per consumer and not per API", function() for i = 1, 6 do - local res = proxy_client():get("/status/200?apikey=apikey122", { + local res = GET("/status/200?apikey=apikey122", { headers = { Host = "test4.com" }, - }) - - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + }, 200) - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end - local res = proxy_client():get("/status/200?apikey=apikey122", { + local _, body = GET("/status/200?apikey=apikey122", { headers = { Host = "test4.com" }, - }) - local body = assert.res_status(429, res) + }, 429) + local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) end) @@ -405,11 +402,10 @@ for _, strategy in helpers.each_strategy() do describe("Config with hide_client_headers", function() it("does not send rate-limit headers when hide_client_headers==true", function() - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "test5.com" }, - }) + }, 200) - assert.res_status(200, res) assert.is_nil(res.headers["x-ratelimit-limit-minute"]) assert.is_nil(res.headers["x-ratelimit-remaining-minute"]) end) @@ -455,10 +451,10 @@ for _, strategy in helpers.each_strategy() do end) it("does not work if an error occurs", function() - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "failtest1.com" }, - }) - assert.res_status(200, res) + }, 200) + assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(5, tonumber(res.headers["x-ratelimit-remaining-minute"])) @@ -466,18 +462,21 @@ for _, strategy in helpers.each_strategy() do assert(dao.db:drop_table("ratelimiting_metrics")) -- Make another request - local res = proxy_client():get("/status/200", { + local _, body = GET("/status/200", { headers = { Host = "failtest1.com" }, - }) - local body = assert.res_status(500, res) + }, 500) + local json = cjson.decode(body) assert.same({ message = "An unexpected error occurred" }, json) + + db:reset() + bp, db, dao = helpers.get_db_utils(strategy) end) it("keeps working if an error occurs", function() - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "failtest2.com" }, - }) - assert.res_status(200, res) + }, 200) + assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(5, tonumber(res.headers["x-ratelimit-remaining-minute"])) @@ -485,12 +484,15 @@ for _, strategy in helpers.each_strategy() do assert(dao.db:drop_table("ratelimiting_metrics")) -- Make another request - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "failtest2.com" }, - }) - assert.res_status(200, res) + }, 200) + assert.falsy(res.headers["x-ratelimit-limit-minute"]) assert.falsy(res.headers["x-ratelimit-remaining-minute"]) + + db:reset() + bp, db, dao = helpers.get_db_utils(strategy) end) end) @@ -500,6 +502,9 @@ for _, strategy in helpers.each_strategy() do before_each(function() helpers.kill_all() + assert(db:truncate()) + dao:truncate_tables() + local service1 = bp.services:insert() local route1 = bp.routes:insert { @@ -533,21 +538,25 @@ for _, strategy in helpers.each_strategy() do })) end) + teardown(function() + helpers.kill_all() + assert(db:truncate()) + end) + it("does not work if an error occurs", function() -- Make another request - local res = proxy_client():get("/status/200", { + local _, body = GET("/status/200", { headers = { Host = "failtest3.com" }, - }) - local body = assert.res_status(500, res) + }, 500) + local json = cjson.decode(body) assert.same({ message = "An unexpected error occurred" }, json) end) it("keeps working if an error occurs", function() - -- Make another request - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "failtest4.com" }, - }) - assert.res_status(200, res) + }, 200) + assert.falsy(res.headers["x-ratelimit-limit-minute"]) assert.falsy(res.headers["x-ratelimit-remaining-minute"]) end) @@ -586,25 +595,19 @@ for _, strategy in helpers.each_strategy() do end) describe("expires a counter", function() - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "expire1.com" }, - }) + }, 200) - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit - - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(5, tonumber(res.headers["x-ratelimit-remaining-minute"])) ngx.sleep(61) -- Wait for counter to expire - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = "expire1.com" } - }) - - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + }, 200) - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(5, tonumber(res.headers["x-ratelimit-remaining-minute"])) end) @@ -663,21 +666,19 @@ for _, strategy in helpers.each_strategy() do it("blocks when the consumer exceeds their quota, no matter what service/route used", function() for i = 1, 6 do - local res = proxy_client():get("/status/200?apikey=apikey125", { + local res = GET("/status/200?apikey=apikey125", { headers = { Host = fmt("test%d.com", i) }, - }) - - ngx.sleep(SLEEP_TIME) -- Wait for async timer to increment the limit + }, 200) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end -- Additonal request, while limit is 6/minute - local res = proxy_client():get("/status/200?apikey=apikey125", { + local _, body = GET("/status/200?apikey=apikey125", { headers = { Host = "test1.com" }, - }) - local body = assert.res_status(429, res) + }, 429) + local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) end) @@ -722,27 +723,22 @@ for _, strategy in helpers.each_strategy() do it("blocks if exceeding limit", function() for i = 1, 6 do - local res = proxy_client():get("/status/200", { + local res = GET("/status/200", { headers = { Host = fmt("test%d.com", i) }, - }) + }, 200) - assert.res_status(200, res) assert.are.same(6, tonumber(res.headers["x-ratelimit-limit-minute"])) assert.are.same(6 - i, tonumber(res.headers["x-ratelimit-remaining-minute"])) end - ngx.sleep(SLEEP_TIME) - -- Additonal request, while limit is 6/minute - local res = proxy_client():get("/status/200", { + local _, body = GET("/status/200", { headers = { Host = "test1.com" }, - }) - local body = assert.res_status(429, res) + }, 429) + local json = cjson.decode(body) assert.same({ message = "API rate limit exceeded" }, json) end) end) end end - -