From 18ebfa0cce993d7881d9a0d04f68b1cb3831d208 Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Fri, 9 Sep 2016 17:27:28 -0700 Subject: [PATCH 1/7] feat(db) implement lua-cassandra 1.0.0 support This makes the DAO functional with lua-cassandra 1.0.0. Only Cassandra 2.x is still supported (more updates required for 3.x). We also retrieve the cluster release_version to avoid nodes with different major versions, and thus prevent migration errors (due to different queries being performed for different major versions) * move DB implementations to `kong.dao.db` * updated rockspec tests to properly evaluate `init` modules (typical Lua idiom) * Get rid of classic.lua and use our own DB implementation helper module, it's pretty cool * instantiating a DAO Factory now returns `factory, err` * instantiating a Cassandra DB in the CLI triggers a cluster refresh (testing the connection and retrieving cluster info at the same time) * as long as https://github.com/openresty/resty-cli/pull/12 is not included in an OpenResty release, this implements a mocked `ngx.shared.DICT` object in `kong.globalpatches`. 
--- bin/busted | 2 + bin/kong | 2 + kong-0.9.3-0.rockspec | 12 +- kong/cmd/cluster.lua | 4 +- kong/cmd/migrations.lua | 2 +- kong/cmd/quit.lua | 3 +- kong/cmd/reload.lua | 4 +- kong/cmd/start.lua | 2 +- kong/cmd/stop.lua | 3 +- kong/core/globalpatches.lua | 108 +++++++++- kong/dao/base_db.lua | 48 ----- kong/dao/dao.lua | 46 ++-- .../{cassandra_db.lua => db/cassandra.lua} | 204 +++++++++++------- kong/dao/db/init.lua | 43 ++++ kong/dao/{postgres_db.lua => db/postgres.lua} | 71 +++--- kong/dao/errors.lua | 4 +- kong/dao/factory.lua | 72 ++++--- kong/kong.lua | 7 +- kong/templates/nginx_kong.lua | 3 +- spec/01-unit/01-rockspec_meta_spec.lua | 2 +- spec/01-unit/13-db/01-cassandra_spec.lua | 80 +++++++ .../02-integration/02-dao/01-factory_spec.lua | 7 +- .../02-dao/02-migrations_spec.lua | 15 +- spec/02-integration/02-dao/03-crud_spec.lua | 11 +- .../02-dao/04-constraints_spec.lua | 2 +- .../02-dao/05-use_cases_spec.lua | 2 +- .../02-dao/06-plugins_daos_spec.lua | 6 +- spec/02-integration/02-dao/07-ttl_spec.lua | 2 +- spec/helpers.lua | 2 +- 29 files changed, 502 insertions(+), 267 deletions(-) delete mode 100644 kong/dao/base_db.lua rename kong/dao/{cassandra_db.lua => db/cassandra.lua} (72%) create mode 100644 kong/dao/db/init.lua rename kong/dao/{postgres_db.lua => db/postgres.lua} (88%) create mode 100644 spec/01-unit/13-db/01-cassandra_spec.lua diff --git a/bin/busted b/bin/busted index 2eb6765b4018..991b418e5680 100755 --- a/bin/busted +++ b/bin/busted @@ -5,5 +5,7 @@ require("kong.core.globalpatches")({ rbusted = true }) +package.path = "?/init.lua;"..package.path + -- Busted command-line runner require 'busted.runner'({ standalone = false }) diff --git a/bin/kong b/bin/kong index 48b017a6a719..296417e2f147 100755 --- a/bin/kong +++ b/bin/kong @@ -4,4 +4,6 @@ require("kong.core.globalpatches")({ cli = true }) +package.path = "?/init.lua;"..package.path + require("kong.cmd.init")(arg) diff --git a/kong-0.9.3-0.rockspec b/kong-0.9.3-0.rockspec index 
b043bec30f14..dfc50a2f7ba3 100644 --- a/kong-0.9.3-0.rockspec +++ b/kong-0.9.3-0.rockspec @@ -20,7 +20,7 @@ dependencies = { "multipart == 0.4", "version == 0.2", "lapis == 1.5.1", - "lua-cassandra == 0.5.4", + "lua-cassandra == 1.0.0", "pgmoon-mashape == 2.0.1", "luatz == 0.3", "lua_system_constants == 0.1.1", @@ -45,8 +45,8 @@ build = { ["kong.vendor.classic"] = "kong/vendor/classic.lua", + ["kong.cmd"] = "kong/cmd/init.lua", ["kong.cmd.roar"] = "kong/cmd/roar.lua", - ["kong.cmd.init"] = "kong/cmd/init.lua", ["kong.cmd.stop"] = "kong/cmd/stop.lua", ["kong.cmd.quit"] = "kong/cmd/quit.lua", ["kong.cmd.start"] = "kong/cmd/start.lua", @@ -65,7 +65,7 @@ build = { ["kong.cmd.utils.prefix_handler"] = "kong/cmd/utils/prefix_handler.lua", ["kong.cmd.utils.dnsmasq_signals"] = "kong/cmd/utils/dnsmasq_signals.lua", - ["kong.api.init"] = "kong/api/init.lua", + ["kong.api"] = "kong/api/init.lua", ["kong.api.api_helpers"] = "kong/api/api_helpers.lua", ["kong.api.crud_helpers"] = "kong/api/crud_helpers.lua", ["kong.api.routes.kong"] = "kong/api/routes/kong.lua", @@ -98,9 +98,9 @@ build = { ["kong.dao.schemas.nodes"] = "kong/dao/schemas/nodes.lua", ["kong.dao.schemas.consumers"] = "kong/dao/schemas/consumers.lua", ["kong.dao.schemas.plugins"] = "kong/dao/schemas/plugins.lua", - ["kong.dao.base_db"] = "kong/dao/base_db.lua", - ["kong.dao.cassandra_db"] = "kong/dao/cassandra_db.lua", - ["kong.dao.postgres_db"] = "kong/dao/postgres_db.lua", + ["kong.dao.db"] = "kong/dao/db/init.lua", + ["kong.dao.db.cassandra"] = "kong/dao/db/cassandra.lua", + ["kong.dao.db.postgres"] = "kong/dao/db/postgres.lua", ["kong.dao.dao"] = "kong/dao/dao.lua", ["kong.dao.factory"] = "kong/dao/factory.lua", ["kong.dao.model_factory"] = "kong/dao/model_factory.lua", diff --git a/kong/cmd/cluster.lua b/kong/cmd/cluster.lua index cb3a1d687041..cb7e05649003 100644 --- a/kong/cmd/cluster.lua +++ b/kong/cmd/cluster.lua @@ -7,7 +7,7 @@ local conf_loader = require "kong.conf_loader" local function execute(args) if 
args.command == "keygen" then local conf = assert(conf_loader(args.conf)) - local dao = DAOFactory(conf) + local dao = assert(DAOFactory(conf)) local serf = Serf.new(conf, dao) print(assert(serf:keygen())) return @@ -21,7 +21,7 @@ local function execute(args) assert(pl_path.exists(default_conf.prefix), "no such prefix: "..default_conf.prefix) local conf = assert(conf_loader(default_conf.kong_conf)) - local dao = DAOFactory(conf) + local dao = assert(DAOFactory.new(conf)) local serf = Serf.new(conf, dao) if args.command == "members" then diff --git a/kong/cmd/migrations.lua b/kong/cmd/migrations.lua index faaa44acfc00..b6d2ffe1c1f1 100644 --- a/kong/cmd/migrations.lua +++ b/kong/cmd/migrations.lua @@ -28,7 +28,7 @@ end local function execute(args) local conf = assert(conf_loader(args.conf)) - local dao = DAOFactory(conf, conf.plugins) + local dao = assert(DAOFactory.new(conf, conf.plugins)) if args.command == "up" then assert(dao:run_migrations()) diff --git a/kong/cmd/quit.lua b/kong/cmd/quit.lua index b8c81a51fc41..2a32b59ea36e 100644 --- a/kong/cmd/quit.lua +++ b/kong/cmd/quit.lua @@ -37,7 +37,8 @@ local function execute(args) assert(nginx_signals.stop(conf)) end - assert(serf_signals.stop(conf, DAOFactory(conf))) + local dao = assert(DAOFactory.new(conf)) + assert(serf_signals.stop(conf, dao)) if conf.dnsmasq then assert(dnsmasq_signals.stop(conf)) diff --git a/kong/cmd/reload.lua b/kong/cmd/reload.lua index 79ec190a6bf2..ea29e461e5c3 100644 --- a/kong/cmd/reload.lua +++ b/kong/cmd/reload.lua @@ -25,7 +25,9 @@ local function execute(args) if conf.dnsmasq then assert(dnsmasq_signals.start(conf)) end - assert(serf_signals.start(conf, DAOFactory(conf))) + + local dao = assert(DAOFactory.new(conf)) + assert(serf_signals.start(conf, dao)) assert(nginx_signals.reload(conf)) log("Kong reloaded") end diff --git a/kong/cmd/start.lua b/kong/cmd/start.lua index 9e520bd38226..863d9f8b2944 100644 --- a/kong/cmd/start.lua +++ b/kong/cmd/start.lua @@ -15,8 +15,8 @@ local 
function execute(args) assert(not kill.is_running(conf.nginx_pid), "Kong is already running in "..conf.prefix) - local dao = DAOFactory(conf) local err + local dao = assert(DAOFactory.new(conf)) xpcall(function() assert(prefix_handler.prepare_prefix(conf, args.nginx_conf)) assert(dao:run_migrations()) diff --git a/kong/cmd/stop.lua b/kong/cmd/stop.lua index b912049de501..9710ab95c35c 100644 --- a/kong/cmd/stop.lua +++ b/kong/cmd/stop.lua @@ -18,8 +18,9 @@ local function execute(args) -- load /kong.conf containing running node's config local conf = assert(conf_loader(default_conf.kong_conf)) + local dao = assert(DAOFactory.new(conf)) assert(nginx_signals.stop(conf)) - assert(serf_signals.stop(conf, DAOFactory(conf))) + assert(serf_signals.stop(conf, dao)) if conf.dnsmasq then assert(dnsmasq_signals.stop(conf)) end diff --git a/kong/core/globalpatches.lua b/kong/core/globalpatches.lua index f63c0223e3c2..e453c05499cd 100644 --- a/kong/core/globalpatches.lua +++ b/kong/core/globalpatches.lua @@ -13,6 +13,113 @@ return function(options) local socket = require(namespace .. 
".socket") socket.force_luasocket(ngx.get_phase(), true) end + + do + -- ngx.shared.DICT proxy + -- https://github.com/bsm/fakengx/blob/master/fakengx.lua + -- with minor fixes and addtions such as exptime + -- + -- See https://github.com/openresty/resty-cli/pull/12 + -- for a definitive solution ot using shms in CLI + local SharedDict = {} + local function set(data, key, value) + data[key] = { + value = value, + info = {expired = false} + } + end + function SharedDict:new() + return setmetatable({data = {}}, {__index = self}) + end + function SharedDict:get(key) + return self.data[key] and self.data[key].value, nil + end + function SharedDict:set(key, value) + set(self.data, key, value) + return true, nil, false + end + SharedDict.safe_set = SharedDict.set + function SharedDict:add(key, value, exptime) + if self.data[key] ~= nil then + return false, "exists", false + end + + if exptime then + ngx.timer.at(exptime, function() + self.data[key] = nil + end) + end + + set(self.data, key, value) + return true, nil, false + end + function SharedDict:replace(key, value) + if self.data[key] == nil then + return false, "not found", false + end + set(self.data, key, value) + return true, nil, false + end + function SharedDict:delete(key) + self.data[key] = nil + return true + end + function SharedDict:incr(key, value) + if not self.data[key] then + return nil, "not found" + elseif type(self.data[key].value) ~= "number" then + return nil, "not a number" + end + self.data[key].value = self.data[key].value + value + return self.data[key].value, nil + end + function SharedDict:flush_all() + for _, item in pairs(self.data) do + item.info.expired = true + end + end + function SharedDict:flush_expired(n) + local data = self.data + local flushed = 0 + + for key, item in pairs(self.data) do + if item.info.expired then + data[key] = nil + flushed = flushed + 1 + if n and flushed == n then + break + end + end + end + self.data = data + return flushed + end + function 
SharedDict:get_keys(n) + n = n or 1024 + local i = 0 + local keys = {} + for k in pairs(self.data) do + keys[#keys+1] = k + i = i + 1 + if n ~= 0 and i == n then + break + end + end + return keys + end + + -- hack + _G.ngx.shared = setmetatable({}, { + __index = function(self, key) + local shm = rawget(self, key) + if not shm then + shm = SharedDict:new() + rawset(self, key, SharedDict:new()) + end + return shm + end + }) + end end if options.rbusted then @@ -93,6 +200,5 @@ return function(options) return seed end end - end end diff --git a/kong/dao/base_db.lua b/kong/dao/base_db.lua deleted file mode 100644 index 376b8f6503bc..000000000000 --- a/kong/dao/base_db.lua +++ /dev/null @@ -1,48 +0,0 @@ -local Object = require "kong.vendor.classic" -local utils = require "kong.tools.utils" - -local BaseDB = Object:extend() - -function BaseDB:new(db_type, conn_opts) - self.options = conn_opts - self.db_type = db_type -end - -function BaseDB:init() - -- to be implemented in child - -- called by init_by_worker for DB specific initialization -end - -function BaseDB:_get_conn_options() - return utils.shallow_copy(self.options) -end - -function BaseDB:query(query) - -- to be implemented in child -end - -function BaseDB:insert(model) - -- to be implemented in child -end - -function BaseDB:find() - -- to be implemented in child -end - -function BaseDB:find_all() - -- to be implemented in child -end - -function BaseDB:count() - -- to be implemented in child -end - -function BaseDB:update() - -- to be implemented in child -end - -function BaseDB:delete() - -- to be implemented in child -end - -return BaseDB diff --git a/kong/dao/dao.lua b/kong/dao/dao.lua index 04954a769cf1..f5f6fd9f794b 100644 --- a/kong/dao/dao.lua +++ b/kong/dao/dao.lua @@ -46,9 +46,13 @@ local function check_utf8(tbl, arg_n) end end -local function ret_error(db_type, res, err, ...) +local function ret_error(db_name, res, err, ...) 
if type(err) == "table" then - err.db_type = db_type + err.db_name = db_name + elseif type(err) == "string" then + local e = Errors.db(err) + e.db_name = db_name + err = tostring(e) end return res, err, ... @@ -113,7 +117,7 @@ function DAO:insert(tbl, options) local model = self.model_mt(tbl) local ok, err = model:validate {dao = self} if not ok then - return ret_error(self.db.db_type, nil, err) + return ret_error(self.db.name, nil, err) end for col, field in pairs(model.__schema.fields) do @@ -129,7 +133,7 @@ function DAO:insert(tbl, options) if not err and not options.quiet then event(self, event_types.ENTITY_CREATED, self.table, self.schema, res) end - return ret_error(self.db.db_type, res, err) + return ret_error(self.db.name, res, err) end --- Find a row. @@ -148,10 +152,10 @@ function DAO:find(tbl) local primary_keys, _, _, err = model:extract_keys() if err then - return ret_error(self.db.db_type, nil, Errors.schema(err)) + return ret_error(self.db.name, nil, Errors.schema(err)) end - return ret_error(self.db.db_type, self.db:find(self.table, self.schema, primary_keys)) + return ret_error(self.db.name, self.db:find(self.table, self.schema, primary_keys)) end --- Find all rows. @@ -167,11 +171,11 @@ function DAO:find_all(tbl) local ok, err = schemas_validation.is_schema_subset(tbl, self.schema) if not ok then - return ret_error(self.db.db_type, nil, Errors.schema(err)) + return ret_error(self.db.name, nil, Errors.schema(err)) end end - return ret_error(self.db.db_type, self.db:find_all(self.table, tbl, self.schema)) + return ret_error(self.db.name, self.db:find_all(self.table, tbl, self.schema)) end --- Find a paginated set of rows. 
@@ -187,7 +191,7 @@ function DAO:find_page(tbl, page_offset, page_size) check_not_empty(tbl, 1) local ok, err = schemas_validation.is_schema_subset(tbl, self.schema) if not ok then - return ret_error(self.db.db_type, nil, Errors.schema(err)) + return ret_error(self.db.name, nil, Errors.schema(err)) end end @@ -197,7 +201,7 @@ function DAO:find_page(tbl, page_offset, page_size) check_arg(page_size, 3, "number") - return ret_error(self.db.db_type, self.db:find_page(self.table, tbl, page_offset, page_size, self.schema)) + return ret_error(self.db.name, self.db:find_page(self.table, tbl, page_offset, page_size, self.schema)) end --- Count the number of rows. @@ -211,7 +215,7 @@ function DAO:count(tbl) check_not_empty(tbl, 1) local ok, err = schemas_validation.is_schema_subset(tbl, self.schema) if not ok then - return ret_error(self.db.db_type, nil, Errors.schema(err)) + return ret_error(self.db.name, nil, Errors.schema(err)) end end @@ -219,7 +223,7 @@ function DAO:count(tbl) tbl = nil end - return ret_error(self.db.db_type, self.db:count(self.table, tbl, self.schema)) + return ret_error(self.db.name, self.db:count(self.table, tbl, self.schema)) end local function fix(old, new, schema) @@ -271,17 +275,17 @@ function DAO:update(tbl, filter_keys, options) local model = self.model_mt(tbl) local ok, err = model:validate {dao = self, update = true, full_update = options.full} if not ok then - return ret_error(self.db.db_type, nil, err) + return ret_error(self.db.name, nil, err) end local primary_keys, values, nils, err = model:extract_keys() if err then - return ret_error(self.db.db_type, nil, Errors.schema(err)) + return ret_error(self.db.name, nil, Errors.schema(err)) end local old, err = self.db:find(self.table, self.schema, primary_keys) if err then - return ret_error(self.db.db_type, nil, err) + return ret_error(self.db.name, nil, err) elseif old == nil then return end @@ -292,7 +296,7 @@ function DAO:update(tbl, filter_keys, options) local res, err = 
self.db:update(self.table, self.schema, self.constraints, primary_keys, values, nils, options.full, model, options) if err then - return ret_error(self.db.db_type, nil, err) + return ret_error(self.db.name, nil, err) elseif res then if not options.quiet then event(self, event_types.ENTITY_UPDATED, self.table, self.schema, old) @@ -321,7 +325,7 @@ function DAO:delete(tbl, options) local primary_keys, _, _, err = model:extract_keys() if err then - return ret_error(self.db.db_type, nil, Errors.schema(err)) + return ret_error(self.db.name, nil, Errors.schema(err)) end -- Find associated entities @@ -331,7 +335,7 @@ function DAO:delete(tbl, options) local f_fetch_keys = {[cascade.f_col] = tbl[cascade.col]} local rows, err = self.db:find_all(cascade.table, f_fetch_keys, cascade.schema) if err then - return ret_error(self.db.db_type, nil, err) + return ret_error(self.db.name, nil, err) end associated_entites[cascade.table] = { schema = cascade.schema, @@ -351,11 +355,11 @@ function DAO:delete(tbl, options) end end end - return ret_error(self.db.db_type, row, err) + return ret_error(self.db.name, row, err) end function DAO:truncate() - return ret_error(self.db.db_type, self.db:truncate_table(self.table)) + return ret_error(self.db.name, self.db:truncate_table(self.table)) end -return DAO \ No newline at end of file +return DAO diff --git a/kong/dao/cassandra_db.lua b/kong/dao/db/cassandra.lua similarity index 72% rename from kong/dao/cassandra_db.lua rename to kong/dao/db/cassandra.lua index 9fad5f9a59fd..fad239a8295c 100644 --- a/kong/dao/cassandra_db.lua +++ b/kong/dao/db/cassandra.lua @@ -1,23 +1,14 @@ +local cassandra = require "cassandra" +local Cluster = require "resty.cassandra.cluster" local timestamp = require "kong.tools.timestamp" local Errors = require "kong.dao.errors" -local BaseDB = require "kong.dao.base_db" local utils = require "kong.tools.utils" +local cjson = require "cjson" local uuid = utils.uuid -local cassandra +local _M = 
require("kong.dao.db").new_db("cassandra") -if ngx.IS_CLI then - local ngx_stub = _G.ngx - _G.ngx = nil - cassandra = require "cassandra" - _G.ngx = ngx_stub -else - cassandra = require "cassandra" -end - -local CassandraDB = BaseDB:extend() - -CassandraDB.dao_insert_values = { +_M.dao_insert_values = { id = function() return uuid() end, @@ -26,42 +17,105 @@ CassandraDB.dao_insert_values = { end } -function CassandraDB:new(kong_config) - local conn_opts = { +function _M.new(kong_config) + local self = _M.super.new() + + local query_opts = { + consistency = cassandra.consistencies[kong_config.cassandra_consistency:lower()], + prepared = true + } + + local cluster_options = { shm = "cassandra", - prepared_shm = "cassandra_prepared", contact_points = kong_config.cassandra_contact_points, + default_port = kong_config.cassandra_port, keyspace = kong_config.cassandra_keyspace, - protocol_options = { - default_port = kong_config.cassandra_port - }, - query_options = { - prepare = true, - consistency = cassandra.consistencies[kong_config.cassandra_consistency:lower()] - }, - socket_options = { - connect_timeout = kong_config.cassandra_timeout, - read_timeout = kong_config.cassandra_timeout, - }, - ssl_options = { - enabled = kong_config.cassandra_ssl, - verify = kong_config.cassandra_ssl_verify, - ca = kong_config.lua_ssl_trusted_certificate - } + connect_timeout = kong_config.cassandra_timeout, + read_timeout = kong_config.cassandra_timeout, + ssl = kong_config.cassandra_ssl, + verify = kong_config.cassandra_ssl_verify } if kong_config.cassandra_username and kong_config.cassandra_password then - conn_opts.auth = cassandra.auth.PlainTextProvider(kong_config.cassandra_username, - kong_config.cassandra_password) + cluster_options.auth = cassandra.auth_providers.plain_text( + kong_config.cassandra_username, + kong_config.cassandra_password + ) + end + + local cluster, err = Cluster.new(cluster_options) + if not cluster then return nil, err end + + self.cluster = cluster + 
self.query_options = query_opts + self.cluster_options = cluster_options + + if ngx.RESTY_CLI then + -- we must manually call our init phase (usually called from `init_by_lua`) + -- to refresh the cluster. + local ok, err = self:init() + if not ok then return nil, err end + end + + return self +end + +local function extract_major(release_version) + return string.match(release_version, "^(%d+)%.%d+%.?%d*$") +end + +local function cluster_release_version(peers) + local first_release_version + local ok = true + + for i = 1, #peers do + local release_version = peers[i].release_version + if not release_version then + return nil, 'no release_version for peer '..peers[i].host + end + + local major_version = extract_major(release_version) + if i == 1 then + first_release_version = major_version + elseif major_version ~= first_release_version then + ok = false + break + end + end + + if not ok then + local err_t = {"different major versions detected (only all of 2.x or 3.x supported):"} + for i = 1, #peers do + err_t[#err_t+1] = string.format("%s (%s)", peers[i].host, peers[i].release_version) + end + + return nil, table.concat(err_t, " ") end - CassandraDB.super.new(self, "cassandra", conn_opts) + return tonumber(first_release_version) +end + +_M.extract_major = extract_major +_M.cluster_release_version = cluster_release_version + +function _M:init() + local ok, err = self.cluster:refresh() + if not ok then return nil, err end + + local peers, err = self.cluster:get_peers() + if err then return nil, err + elseif not peers then return nil, 'no peers in shm' end + + self.release_version, err = cluster_release_version(peers) + if not self.release_version then return nil, err end + + return true end -function CassandraDB:infos() +function _M:infos() return { desc = "keyspace", - name = self:_get_conn_options().keyspace + name = self.cluster_options.keyspace } end @@ -75,19 +129,17 @@ local function serialize_arg(field, value) elseif field.type == "timestamp" then return 
cassandra.timestamp(value) elseif field.type == "table" or field.type == "array" then - local json = require "cjson" - return json.encode(value) + return cjson.encode(value) else return value end end local function deserialize_rows(rows, schema) - local json = require "cjson" for i, row in ipairs(rows) do for col, value in pairs(row) do if schema.fields[col].type == "table" or schema.fields[col].type == "array" then - rows[i][col] = json.decode(value) + rows[i][col] = cjson.decode(value) end end end @@ -170,23 +222,17 @@ local function check_foreign_constaints(self, values, constraints) return Errors.foreign(errors) end -function CassandraDB:query(query, args, opts, schema, no_keyspace) - CassandraDB.super.query(self, query, args) - - local conn_opts = self:_get_conn_options() +function _M:query(query, args, options, schema, no_keyspace) + local opts = self:clone_query_options(options) + local coordinator_opts = {} if no_keyspace then - conn_opts.keyspace = nil - end - - local session, err = cassandra.spawn_session(conn_opts) - if err then - return nil, Errors.db(tostring(err)) + -- defaults to the system keyspace, always present + coordinator_opts.keyspace = "system" end - local res, err = session:execute(query, args, opts) - session:set_keep_alive() - if err then - return nil, Errors.db(tostring(err)) + local res, err = self.cluster:execute(query, args, opts, coordinator_opts) + if not res then + return nil, Errors.db(err) end if schema ~= nil and res.type == "ROWS" then @@ -196,7 +242,7 @@ function CassandraDB:query(query, args, opts, schema, no_keyspace) return res end -function CassandraDB:insert(table_name, schema, model, constraints, options) +function _M:insert(table_name, schema, model, constraints, options) local err = check_unique_constraints(self, table_name, constraints, model) if err then return nil, err @@ -235,7 +281,7 @@ function CassandraDB:insert(table_name, schema, model, constraints, options) return row end -function 
CassandraDB:find(table_name, schema, filter_keys) +function _M:find(table_name, schema, filter_keys) local where, args = get_where(schema, filter_keys) local query = get_select_query(table_name, where) local rows, err = self:query(query, args, nil, schema) @@ -246,24 +292,20 @@ function CassandraDB:find(table_name, schema, filter_keys) end end -function CassandraDB:find_all(table_name, tbl, schema) - local conn_opts = self:_get_conn_options() - local session, err = cassandra.spawn_session(conn_opts) - if err then - return nil, Errors.db(tostring(err)) - end - +function _M:find_all(table_name, tbl, schema) + local opts = self:clone_query_options() local where, args if tbl ~= nil then where, args = get_where(schema, tbl) end + local err local query = get_select_query(table_name, where) - local res_rows, err = {}, nil + local res_rows = {} - for rows, page_err in session:execute(query, args, {auto_paging = true}) do + for rows, page_err in self.cluster:iterate(query, args, opts) do if page_err then - err = Errors.db(tostring(page_err)) + err = Errors.db(page_err) res_rows = nil break end @@ -275,12 +317,10 @@ function CassandraDB:find_all(table_name, tbl, schema) end end - session:set_keep_alive() - return res_rows, err end -function CassandraDB:find_page(table_name, tbl, paging_state, page_size, schema) +function _M:find_page(table_name, tbl, paging_state, page_size, schema) local where, args if tbl ~= nil then where, args = get_where(schema, tbl) @@ -301,7 +341,7 @@ function CassandraDB:find_page(table_name, tbl, paging_state, page_size, schema) end end -function CassandraDB:count(table_name, tbl, schema) +function _M:count(table_name, tbl, schema) local where, args if tbl ~= nil then where, args = get_where(schema, tbl) @@ -316,7 +356,7 @@ function CassandraDB:count(table_name, tbl, schema) end end -function CassandraDB:update(table_name, schema, constraints, filter_keys, values, nils, full, model, options) +function _M:update(table_name, schema, constraints, 
filter_keys, values, nils, full, model, options) -- must check unique constaints manually too local err = check_unique_constraints(self, table_name, constraints, values, filter_keys, true) if err then @@ -403,7 +443,7 @@ local function cascade_delete(self, primary_keys, constraints) end end -function CassandraDB:delete(table_name, schema, primary_keys, constraints) +function _M:delete(table_name, schema, primary_keys, constraints) local row, err = self:find(table_name, schema, primary_keys) if err or row == nil then return nil, err @@ -425,7 +465,7 @@ end -- Migrations -function CassandraDB:queries(queries, no_keyspace) +function _M:queries(queries, no_keyspace) for _, query in ipairs(utils.split(queries, ";")) do if utils.strip(query) ~= "" then local err = select(2, self:query(query, nil, nil, nil, no_keyspace)) @@ -436,19 +476,19 @@ function CassandraDB:queries(queries, no_keyspace) end end -function CassandraDB:drop_table(table_name) +function _M:drop_table(table_name) return select(2, self:query("DROP TABLE "..table_name)) end -function CassandraDB:truncate_table(table_name) +function _M:truncate_table(table_name) return select(2, self:query("TRUNCATE "..table_name)) end -function CassandraDB:current_migrations() +function _M:current_migrations() -- Check if keyspace exists local rows, err = self:query([[ SELECT * FROM system.schema_keyspaces WHERE keyspace_name = ? - ]], {self.options.keyspace}, nil, nil, true) + ]], {self.cluster_options.keyspace}, nil, nil, true) if err then return nil, err elseif #rows == 0 then @@ -460,7 +500,7 @@ function CassandraDB:current_migrations() SELECT COUNT(*) FROM system.schema_columnfamilies WHERE keyspace_name = ? AND columnfamily_name = ? 
]], { - self.options.keyspace, + self.cluster_options.keyspace, "schema_migrations" }) if err then @@ -474,7 +514,7 @@ function CassandraDB:current_migrations() end end -function CassandraDB:record_migration(id, name) +function _M:record_migration(id, name) return select(2, self:query([[ UPDATE schema_migrations SET migrations = migrations + ? WHERE id = ? ]], { @@ -483,4 +523,4 @@ function CassandraDB:record_migration(id, name) })) end -return CassandraDB +return _M diff --git a/kong/dao/db/init.lua b/kong/dao/db/init.lua new file mode 100644 index 000000000000..c4220ad38237 --- /dev/null +++ b/kong/dao/db/init.lua @@ -0,0 +1,43 @@ +local utils = require "kong.tools.utils" + +local _M = {} + +function _M.new_db(name) + local db_mt = { + name = name, + init = function() error('init() not implemented') end, + infos = function() error('infos() not implemented') end, + query = function() error('query() not implemented') end, + insert = function() error('insert() not implemented') end, + find = function() error('find() not implemented') end, + find_all = function() error('find_all() not implemented') end, + count = function() error('count() not implemented') end, + update = function() error('update() not implemented') end, + delete = function() error('delete() not implemented') end, + queries = function() error('queries() not implemented') end, + drop_table = function() error('drop_table() not implemented') end, + truncate_table = function() error('truncate_table() not implemented') end, + current_migrations = function() error('current_migrations() not implemented') end, + record_migration = function() error('record_migration() not implemented') end, + clone_query_options = function(self, options) + options = options or {} + local opts = utils.shallow_copy(self.query_options) + for k, v in pairs(options) do + opts[k] = v + end + return opts + end + } + + db_mt.__index = db_mt + + db_mt.super = { + new = function() + return setmetatable({}, db_mt) + end + } + + return 
setmetatable(db_mt, {__index = db_mt.super}) +end + +return _M diff --git a/kong/dao/postgres_db.lua b/kong/dao/db/postgres.lua similarity index 88% rename from kong/dao/postgres_db.lua rename to kong/dao/db/postgres.lua index ec4642d1252d..ed3e63b5a64c 100644 --- a/kong/dao/postgres_db.lua +++ b/kong/dao/db/postgres.lua @@ -1,23 +1,24 @@ local pgmoon = require "pgmoon-mashape" -local BaseDB = require "kong.dao.base_db" local Errors = require "kong.dao.errors" local utils = require "kong.tools.utils" local uuid = utils.uuid local TTL_CLEANUP_INTERVAL = 60 -- 1 minute -local PostgresDB = BaseDB:extend() +local _M = require("kong.dao.db").new_db("postgres") -PostgresDB.dao_insert_values = { +_M.dao_insert_values = { id = function() return uuid() end } -PostgresDB.additional_tables = {"ttls"} +_M.additional_tables = {"ttls"} -function PostgresDB:new(kong_config) - local conn_opts = { +function _M.new(kong_config) + local self = _M.super.new() + + self.query_options = { host = kong_config.pg_host, port = kong_config.pg_port, user = kong_config.pg_user, @@ -28,14 +29,14 @@ function PostgresDB:new(kong_config) cafile = kong_config.lua_ssl_trusted_certificate } - PostgresDB.super.new(self, "postgres", conn_opts) + return self end -- TTL clean up timer functions local function do_clean_ttl(premature, postgres) if premature then return end - + local ok, err = postgres:clear_expired_ttl() if not ok then ngx.log(ngx.ERR, "failed to cleanup TTLs: ", err) @@ -46,7 +47,7 @@ local function do_clean_ttl(premature, postgres) end end -function PostgresDB:start_ttl_timer() +function _M:start_ttl_timer() if ngx then local ok, err = ngx.timer.at(TTL_CLEANUP_INTERVAL, do_clean_ttl, self) if not ok then @@ -56,14 +57,14 @@ function PostgresDB:start_ttl_timer() end end -function PostgresDB:init() +function _M:init() self:start_ttl_timer() end -function PostgresDB:infos() +function _M:infos() return { desc = "database", - name = self:_get_conn_options().database + name = 
self:clone_query_options().database } end @@ -134,10 +135,8 @@ end -- Querying -function PostgresDB:query(query, schema) - PostgresDB.super.query(self, query) - - local conn_opts = self:_get_conn_options() +function _M:query(query, schema) + local conn_opts = self:clone_query_options() local pg = pgmoon.new(conn_opts) local ok, err = pg:connect() if not ok then @@ -160,7 +159,7 @@ function PostgresDB:query(query, schema) return res end -function PostgresDB:retrieve_primary_key_type(schema, table_name) +function _M:retrieve_primary_key_type(schema, table_name) if schema.primary_key and #schema.primary_key == 1 then if not self.column_types then self.column_types = {} end @@ -181,7 +180,7 @@ function PostgresDB:retrieve_primary_key_type(schema, table_name) end end -function PostgresDB:get_select_query(select_clause, schema, table, where, offset, limit) +function _M:get_select_query(select_clause, schema, table, where, offset, limit) local query local join_ttl = schema.primary_key and #schema.primary_key == 1 @@ -206,7 +205,7 @@ function PostgresDB:get_select_query(select_clause, schema, table, where, offset return query end -function PostgresDB:deserialize_rows(rows, schema) +function _M:deserialize_rows(rows, schema) if schema then local json = require "cjson" for i, row in ipairs(rows) do @@ -220,7 +219,7 @@ function PostgresDB:deserialize_rows(rows, schema) end end -function PostgresDB:deserialize_timestamps(row, schema) +function _M:deserialize_timestamps(row, schema) local result = row for k, v in pairs(schema.fields) do if v.type == "timestamp" and result[k] then @@ -236,7 +235,7 @@ function PostgresDB:deserialize_timestamps(row, schema) return result end -function PostgresDB:serialize_timestamps(tbl, schema) +function _M:serialize_timestamps(tbl, schema) local result = tbl for k, v in pairs(schema.fields) do if v.type == "timestamp" and result[k] then @@ -252,7 +251,7 @@ function PostgresDB:serialize_timestamps(tbl, schema) return result end -function 
PostgresDB:ttl(tbl, table_name, schema, ttl) +function _M:ttl(tbl, table_name, schema, ttl) if not schema.primary_key or #schema.primary_key ~= 1 then return false, "Cannot set a TTL if the entity has no primary key, or has more than one primary key" end @@ -280,7 +279,7 @@ function PostgresDB:ttl(tbl, table_name, schema, ttl) end -- Delete old expired TTL entities -function PostgresDB:clear_expired_ttl() +function _M:clear_expired_ttl() local query = "SELECT * FROM ttls WHERE expire_at < CURRENT_TIMESTAMP(0) at time zone 'utc'" local res, err = self:query(query) if err then @@ -303,7 +302,7 @@ function PostgresDB:clear_expired_ttl() return true end -function PostgresDB:insert(table_name, schema, model, _, options) +function _M:insert(table_name, schema, model, _, options) local values, err = self:serialize_timestamps(model, schema) if err then return nil, err @@ -340,7 +339,7 @@ function PostgresDB:insert(table_name, schema, model, _, options) end end -function PostgresDB:find(table_name, schema, primary_keys) +function _M:find(table_name, schema, primary_keys) local where = get_where(primary_keys) local query = self:get_select_query(get_select_fields(schema), schema, table_name, where) local rows, err = self:query(query, schema) @@ -351,7 +350,7 @@ function PostgresDB:find(table_name, schema, primary_keys) end end -function PostgresDB:find_all(table_name, tbl, schema) +function _M:find_all(table_name, tbl, schema) local where if tbl ~= nil then where = get_where(tbl) @@ -361,7 +360,7 @@ function PostgresDB:find_all(table_name, tbl, schema) return self:query(query, schema) end -function PostgresDB:find_page(table_name, tbl, page, page_size, schema) +function _M:find_page(table_name, tbl, page, page_size, schema) if page == nil then page = 1 end @@ -389,7 +388,7 @@ function PostgresDB:find_page(table_name, tbl, page, page_size, schema) return rows, nil, (next_page <= total_pages and next_page or nil) end -function PostgresDB:count(table_name, tbl, schema) +function 
_M:count(table_name, tbl, schema) local where if tbl ~= nil then where = get_where(tbl) @@ -404,7 +403,7 @@ function PostgresDB:count(table_name, tbl, schema) end end -function PostgresDB:update(table_name, schema, _, filter_keys, values, nils, full, _, options) +function _M:update(table_name, schema, _, filter_keys, values, nils, full, _, options) local args = {} local values, err = self:serialize_timestamps(values, schema) if err then @@ -446,7 +445,7 @@ function PostgresDB:update(table_name, schema, _, filter_keys, values, nils, ful end end -function PostgresDB:delete(table_name, schema, primary_keys) +function _M:delete(table_name, schema, primary_keys) local where = get_where(primary_keys) local query = string.format("DELETE FROM %s WHERE %s RETURNING *", table_name, where) @@ -462,7 +461,7 @@ end -- Migrations -function PostgresDB:queries(queries) +function _M:queries(queries) if utils.strip(queries) ~= "" then local err = select(2, self:query(queries)) if err then @@ -471,15 +470,15 @@ function PostgresDB:queries(queries) end end -function PostgresDB:drop_table(table_name) +function _M:drop_table(table_name) return select(2, self:query("DROP TABLE "..table_name.." CASCADE")) end -function PostgresDB:truncate_table(table_name) +function _M:truncate_table(table_name) return select(2, self:query("TRUNCATE "..table_name.." 
CASCADE")) end -function PostgresDB:current_migrations() +function _M:current_migrations() -- Check if schema_migrations table exists local rows, err = self:query "SELECT to_regclass('schema_migrations')" if err then @@ -493,7 +492,7 @@ function PostgresDB:current_migrations() end end -function PostgresDB:record_migration(id, name) +function _M:record_migration(id, name) return select(2, self:query { [[ CREATE OR REPLACE FUNCTION upsert_schema_migrations(identifier text, migration_name varchar) RETURNS VOID AS $$ @@ -510,4 +509,4 @@ function PostgresDB:record_migration(id, name) }) end -return PostgresDB +return _M diff --git a/kong/dao/errors.lua b/kong/dao/errors.lua index 69b9c5fe09f4..bb232472b5bf 100644 --- a/kong/dao/errors.lua +++ b/kong/dao/errors.lua @@ -8,8 +8,8 @@ local fmt = string.format local error_mt = {} function error_mt.__tostring(t) - if t.db_type then - return fmt("[%s error] %s", t.db_type, tostring(t.message)) + if t.db_name then + return fmt("[%s error] %s", t.db_name, tostring(t.message)) end return tostring(t.message) diff --git a/kong/dao/factory.lua b/kong/dao/factory.lua index bb39eb2b7fc8..9b152d49155b 100644 --- a/kong/dao/factory.lua +++ b/kong/dao/factory.lua @@ -1,25 +1,24 @@ local DAO = require "kong.dao.dao" local utils = require "kong.tools.utils" -local Object = require "kong.vendor.classic" local ModelFactory = require "kong.dao.model_factory" local CORE_MODELS = {"apis", "consumers", "plugins", "nodes"} local _db -- returns db errors as strings, including the initial `nil` -local function ret_error_string(db_type, res, err) - res, err = DAO.ret_error(db_type, res, err) +local function ret_error_string(db_name, res, err) + res, err = DAO.ret_error(db_name, res, err) return res, tostring(err) end -local Factory = Object:extend() +local _M = {} -function Factory:__index(key) +function _M:__index(key) local daos = rawget(self, "daos") if daos and daos[key] then return daos[key] else - return Factory[key] + return _M[key] end end 
@@ -78,25 +77,30 @@ local function load_daos(self, schemas, constraints, events_handler) end end -function Factory:new(kong_config, events_handler) - self.db_type = kong_config.database - self.daos = {} - self.kong_config = kong_config - self.plugin_names = kong_config.plugins or {} +function _M.new(kong_config, events_handler) + local factory = { + db_type = kong_config.database, + daos = {}, + kong_config = kong_config, + plugin_names = kong_config.plugins or {} + } - local schemas = {} - local DB = require("kong.dao."..self.db_type.."_db") - _db = DB(kong_config) + local DB = require("kong.dao.db."..factory.db_type) + local db, err = DB.new(kong_config) + if not db then return ret_error_string(factory.db_type, nil, err) end + + _db = db -- avoid setting a previous upvalue to `nil` in case `DB.new()` fails + local schemas = {} for _, m_name in ipairs(CORE_MODELS) do schemas[m_name] = require("kong.dao.schemas."..m_name) end - for plugin_name in pairs(self.plugin_names) do - local has_dao, plugin_daos = utils.load_module_if_exists("kong.plugins."..plugin_name..".dao."..self.db_type) + for plugin_name in pairs(factory.plugin_names) do + local has_dao, plugin_daos = utils.load_module_if_exists("kong.plugins."..plugin_name..".dao."..factory.db_type) if has_dao then for k, v in pairs(plugin_daos) do - self.daos[k] = v(kong_config) + factory.daos[k] = v(kong_config) end end @@ -110,20 +114,22 @@ function Factory:new(kong_config, events_handler) local constraints = build_constraints(schemas) - load_daos(self, schemas, constraints, events_handler) + load_daos(factory, schemas, constraints, events_handler) + + return setmetatable(factory, _M) end -function Factory:init() +function _M:init() return _db:init() end -- Migrations -function Factory:infos() +function _M:infos() return _db:infos() end -function Factory:drop_schema() +function _M:drop_schema() for _, dao in pairs(self.daos) do _db:drop_table(dao.table) end @@ -137,11 +143,11 @@ function Factory:drop_schema() 
_db:drop_table("schema_migrations") end -function Factory:truncate_table(dao_name) +function _M:truncate_table(dao_name) _db:truncate_table(self.daos[dao_name].table) end -function Factory:truncate_tables() +function _M:truncate_tables() for _, dao in pairs(self.daos) do _db:truncate_table(dao.table) end @@ -153,7 +159,7 @@ function Factory:truncate_tables() end end -function Factory:migrations_modules() +function _M:migrations_modules() local migrations = { core = require("kong.dao.migrations."..self.db_type) } @@ -168,11 +174,9 @@ function Factory:migrations_modules() return migrations end -function Factory:current_migrations() +function _M:current_migrations() local rows, err = _db:current_migrations() - if err then - return nil, err - end + if err then return ret_error_string(_db.name, nil, err) end local cur_migrations = {} for _, row in ipairs(rows) do @@ -218,7 +222,7 @@ local function migrate(self, identifier, migrations_modules, cur_migrations, on_ if on_success then on_success(identifier, migration.name, _db:infos()) end - end +end return true, nil, #to_run end @@ -235,7 +239,7 @@ local function default_on_success(identifier, migration_name, db_infos) identifier, migration_name) end -function Factory:run_migrations(on_migrate, on_success) +function _M:run_migrations(on_migrate, on_success) on_migrate = on_migrate or default_on_migrate on_success = on_success or default_on_success @@ -245,15 +249,15 @@ function Factory:run_migrations(on_migrate, on_success) local migrations_modules = self:migrations_modules() local cur_migrations, err = self:current_migrations() - if err then return ret_error_string(_db.db_type, nil, err) end + if err then return ret_error_string(_db.name, nil, err) end local ok, err, migrations_ran = migrate(self, "core", migrations_modules, cur_migrations, on_migrate, on_success) - if not ok then return ret_error_string(_db.db_type, nil, err) end + if not ok then return ret_error_string(_db.name, nil, err) end for identifier in 
pairs(migrations_modules) do if identifier ~= "core" then local ok, err, n_ran = migrate(self, identifier, migrations_modules, cur_migrations, on_migrate, on_success) - if not ok then return ret_error_string(_db.db_type, nil, err) + if not ok then return ret_error_string(_db.name, nil, err) else migrations_ran = migrations_ran + n_ran end @@ -269,4 +273,4 @@ function Factory:run_migrations(on_migrate, on_success) return true end -return Factory +return _M diff --git a/kong/kong.lua b/kong/kong.lua index bc862c89b8b7..d52336180015 100644 --- a/kong/kong.lua +++ b/kong/kong.lua @@ -120,7 +120,7 @@ function Kong.init() local config = assert(conf_loader(conf_path)) local events = Events() -- retrieve node plugins - local dao = DAOFactory(config, events) -- instanciate long-lived DAO + local dao = assert(DAOFactory.new(config, events)) -- instanciate long-lived DAO assert(dao:run_migrations()) -- migrating in case embedded in custom nginx -- populate singletons @@ -142,7 +142,10 @@ function Kong.init_worker() core.init_worker.before() - singletons.dao:init() -- Executes any initialization by the DB + local ok, err = singletons.dao:init() -- Executes any initialization by the DB + if not ok then + ngx.log(ngx.ERR, err) + end for _, plugin in ipairs(singletons.loaded_plugins) do plugin.handler:init_worker() diff --git a/kong/templates/nginx_kong.lua b/kong/templates/nginx_kong.lua index cea3da821613..d34bb4ce0a70 100644 --- a/kong/templates/nginx_kong.lua +++ b/kong/templates/nginx_kong.lua @@ -39,8 +39,7 @@ lua_shared_dict reports_locks 100k; lua_shared_dict cluster_locks 100k; lua_shared_dict cluster_autojoin_locks 100k; lua_shared_dict cache_locks 100k; -lua_shared_dict cassandra 1m; -lua_shared_dict cassandra_prepared 5m; +lua_shared_dict cassandra 5m; lua_socket_log_errors off; > if lua_ssl_trusted_certificate then lua_ssl_trusted_certificate '${{LUA_SSL_TRUSTED_CERTIFICATE}}'; diff --git a/spec/01-unit/01-rockspec_meta_spec.lua 
b/spec/01-unit/01-rockspec_meta_spec.lua index fd4de0b497fe..6e753fa82e02 100644 --- a/spec/01-unit/01-rockspec_meta_spec.lua +++ b/spec/01-unit/01-rockspec_meta_spec.lua @@ -67,7 +67,7 @@ describe("rockspec/meta", function() it("all modules named as their path", function() for mod_name, mod_path in pairs(rock.build.modules) do if mod_name ~= "kong" then - mod_path = mod_path:gsub("%.lua", ""):gsub("/", '.') + mod_path = mod_path:gsub("%.lua", ""):gsub("/", '.'):gsub("%.init", "") assert(mod_name == mod_path, mod_path.." has different name ("..mod_name..")") end end diff --git a/spec/01-unit/13-db/01-cassandra_spec.lua b/spec/01-unit/13-db/01-cassandra_spec.lua new file mode 100644 index 000000000000..5f1bceff1d6c --- /dev/null +++ b/spec/01-unit/13-db/01-cassandra_spec.lua @@ -0,0 +1,80 @@ +local cassandra_db = require "kong.dao.db.cassandra" + +describe("cassandra_db", function() + describe("extract_major()", function() + it("extract major version digit", function() + assert.equal("3", cassandra_db.extract_major("3.7")) + assert.equal("3", cassandra_db.extract_major("3.7.12")) + assert.equal("2", cassandra_db.extract_major("2.1.14")) + assert.equal("2", cassandra_db.extract_major("2.10")) + assert.equal("10", cassandra_db.extract_major("10.0")) + end) + end) + + describe("cluster_release_version()", function() + it("extracts major release_version from available peers", function() + local release_version = assert(cassandra_db.cluster_release_version { + { + host = "127.0.0.1", + release_version = "3.7", + }, + { + host = "127.0.0.2", + release_version = "3.7", + }, + { + host = "127.0.0.3", + release_version = "3.1.2", + } + }) + assert.equal(3, release_version) + + local release_version = assert(cassandra_db.cluster_release_version { + { + host = "127.0.0.1", + release_version = "2.14", + }, + { + host = "127.0.0.2", + release_version = "2.11.14", + }, + { + host = "127.0.0.3", + release_version = "2.2.4", + } + }) + assert.equal(2, release_version) + end) + 
it("errors with different major versions", function() + local release_version, err = cassandra_db.cluster_release_version { + { + host = "127.0.0.1", + release_version = "3.7", + }, + { + host = "127.0.0.2", + release_version = "3.7", + }, + { + host = "127.0.0.3", + release_version = "2.11.14", + } + } + assert.is_nil(release_version) + assert.equal("different major versions detected (only all of 2.x or 3.x supported): 127.0.0.1 (3.7) 127.0.0.2 (3.7) 127.0.0.3 (2.11.14)", err) + end) + it("errors if a peer is missing release_version", function() + local release_version, err = cassandra_db.cluster_release_version { + { + host = "127.0.0.1", + release_version = "3.7", + }, + { + host = "127.0.0.2" + } + } + assert.is_nil(release_version) + assert.equal("no release_version for peer 127.0.0.2", err) + end) + end) +end) diff --git a/spec/02-integration/02-dao/01-factory_spec.lua b/spec/02-integration/02-dao/01-factory_spec.lua index 37efbcdda95d..093b5ff713cb 100644 --- a/spec/02-integration/02-dao/01-factory_spec.lua +++ b/spec/02-integration/02-dao/01-factory_spec.lua @@ -2,19 +2,18 @@ local helpers = require "spec.02-integration.02-dao.helpers" local Factory = require "kong.dao.factory" helpers.for_each_dao(function(kong_conf) - describe("Model Factory with DB: #"..kong_conf.database, function() + describe("DAO Factory with DB: #"..kong_conf.database, function() it("should be instanciable", function() local factory assert.has_no_errors(function() - factory = Factory(kong_conf) + factory = assert(Factory.new(kong_conf)) end) - assert.True(factory:is(Factory)) assert.is_table(factory.daos) assert.equal(kong_conf.database, factory.db_type) end) it("should have shorthands to access the underlying daos", function() - local factory = Factory(kong_conf) + local factory = assert(Factory.new(kong_conf)) assert.equal(factory.daos.apis, factory.apis) assert.equal(factory.daos.consumers, factory.consumers) assert.equal(factory.daos.plugins, factory.plugins) diff --git 
a/spec/02-integration/02-dao/02-migrations_spec.lua b/spec/02-integration/02-dao/02-migrations_spec.lua index 8a366f3519f8..c48734a78716 100644 --- a/spec/02-integration/02-dao/02-migrations_spec.lua +++ b/spec/02-integration/02-dao/02-migrations_spec.lua @@ -7,11 +7,11 @@ helpers.for_each_dao(function(kong_config) describe("Model migrations with DB: #"..kong_config.database, function() local factory setup(function() - local f = Factory(kong_config) + local f = assert(Factory.new(kong_config)) f:drop_schema() end) before_each(function() - factory = Factory(kong_config) + factory = assert(Factory.new(kong_config)) end) describe("current_migrations()", function() @@ -25,7 +25,7 @@ helpers.for_each_dao(function(kong_config) local invalid_conf = utils.shallow_copy(kong_config) invalid_conf.cassandra_keyspace = "_inexistent_" - local xfactory = Factory(invalid_conf) + local xfactory = assert(Factory.new(invalid_conf)) local cur_migrations, err = xfactory:current_migrations() assert.is_nil(err) assert.same({}, cur_migrations) @@ -110,11 +110,10 @@ helpers.for_each_dao(function(kong_config) kong_config.cassandra_port = 3333 kong_config.cassandra_timeout = 1000 - local fact = Factory(kong_config) - - local apis, err = fact:run_migrations() - assert.matches("["..kong_config.database.." error]", err, nil, true) - assert.is_nil(apis) + assert.error_matches(function() + local fact = assert(Factory.new(kong_config)) + assert(fact:run_migrations()) + end, "["..kong_config.database.." 
error]", nil, true) end) end) end) diff --git a/spec/02-integration/02-dao/03-crud_spec.lua b/spec/02-integration/02-dao/03-crud_spec.lua index 374085337edf..901e4c87e000 100644 --- a/spec/02-integration/02-dao/03-crud_spec.lua +++ b/spec/02-integration/02-dao/03-crud_spec.lua @@ -34,7 +34,7 @@ helpers.for_each_dao(function(kong_config) describe("Model (CRUD) with DB: #"..kong_config.database, function() local factory, apis, oauth2_credentials setup(function() - factory = Factory(kong_config) + factory = assert(Factory.new(kong_config)) apis = factory.apis -- DAO used for testing arrays @@ -706,11 +706,10 @@ helpers.for_each_dao(function(kong_config) kong_config.cassandra_port = 3333 kong_config.cassandra_timeout = 1000 - local fact = Factory(kong_config) - - local apis, err = fact.apis:find_all() - assert.matches("["..kong_config.database.." error]", err, nil, true) - assert.is_nil(apis) + assert.error_matches(function() + local fact = assert(Factory.new(kong_config)) + assert(fact.apis:find_all()) + end, "["..kong_config.database.." 
error]", nil, true) end) end) end) -- describe diff --git a/spec/02-integration/02-dao/04-constraints_spec.lua b/spec/02-integration/02-dao/04-constraints_spec.lua index 9180addf69a7..543aa78d1af5 100644 --- a/spec/02-integration/02-dao/04-constraints_spec.lua +++ b/spec/02-integration/02-dao/04-constraints_spec.lua @@ -19,7 +19,7 @@ helpers.for_each_dao(function(kong_config) local plugin_fixture, api_fixture local factory, apis, plugins setup(function() - factory = Factory(kong_config) + factory = assert(Factory.new(kong_config)) apis = factory.apis plugins = factory.plugins assert(factory:run_migrations()) diff --git a/spec/02-integration/02-dao/05-use_cases_spec.lua b/spec/02-integration/02-dao/05-use_cases_spec.lua index 92ace5dfb01c..69f3cf92058c 100644 --- a/spec/02-integration/02-dao/05-use_cases_spec.lua +++ b/spec/02-integration/02-dao/05-use_cases_spec.lua @@ -5,7 +5,7 @@ helpers.for_each_dao(function(kong_config) describe("Real use-cases with DB: #"..kong_config.database, function() local factory setup(function() - factory = Factory(kong_config) + factory = assert(Factory.new(kong_config)) assert(factory:run_migrations()) factory:truncate_tables() diff --git a/spec/02-integration/02-dao/06-plugins_daos_spec.lua b/spec/02-integration/02-dao/06-plugins_daos_spec.lua index dc5e88cc9f76..4c4fdddde598 100644 --- a/spec/02-integration/02-dao/06-plugins_daos_spec.lua +++ b/spec/02-integration/02-dao/06-plugins_daos_spec.lua @@ -4,7 +4,7 @@ local Factory = require "kong.dao.factory" helpers.for_each_dao(function(kong_config) describe("Plugins DAOs with DB: #"..kong_config.database, function() it("load plugins DAOs", function() - local factory = Factory(kong_config) + local factory = assert(Factory.new(kong_config)) assert.truthy(factory.keyauth_credentials) assert.truthy(factory.basicauth_credentials) assert.truthy(factory.acls) @@ -18,7 +18,7 @@ helpers.for_each_dao(function(kong_config) describe("plugins migrations", function() local factory setup(function() - 
factory = Factory(kong_config) + factory = assert(Factory.new(kong_config)) end) it("migrations_modules()", function() local migrations = factory:migrations_modules() @@ -35,7 +35,7 @@ helpers.for_each_dao(function(kong_config) describe("custom DBs", function() it("loads rate-limiting custom DB", function() - local factory = Factory(kong_config) + local factory = assert(Factory.new(kong_config)) assert.truthy(factory.ratelimiting_metrics) end) end) diff --git a/spec/02-integration/02-dao/07-ttl_spec.lua b/spec/02-integration/02-dao/07-ttl_spec.lua index be927423a279..f691a0fe5a1b 100644 --- a/spec/02-integration/02-dao/07-ttl_spec.lua +++ b/spec/02-integration/02-dao/07-ttl_spec.lua @@ -6,7 +6,7 @@ helpers.for_each_dao(function(kong_config) describe("TTL with #"..kong_config.database, function() local factory setup(function() - factory = Factory(kong_config) + factory = assert(Factory.new(kong_config)) assert(factory:run_migrations()) factory:truncate_tables() diff --git a/spec/helpers.lua b/spec/helpers.lua index cca24ea409e1..10aa2158ed39 100644 --- a/spec/helpers.lua +++ b/spec/helpers.lua @@ -25,7 +25,7 @@ log.set_lvl(log.levels.quiet) -- disable stdout logs in tests -- Conf and DAO --------------- local conf = assert(conf_loader(TEST_CONF_PATH)) -local dao = DAOFactory(conf) +local dao = assert(DAOFactory.new(conf)) -- make sure migrations are up-to-date --assert(dao:run_migrations()) From a0e0b005705b3b0a58598872bac9d37548d8d8fd Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Fri, 30 Sep 2016 11:22:04 +0200 Subject: [PATCH 2/7] feat(dao) support for Cassandra 3.x - new architecture for db modules, this gets rid of `classic.lua` and uses a custom base DB module `kong.dao.db.init` - support for Cassandra 3.x system querying - we detect which major Cassandra version is in use before executing queries - switch `cassandra.unset` for `cassandra.null` which has the same meaning for protocol v3 but the one we intend for protocol v4 (set an existing column 
to NULL) - distinguish `init()` and `init_worker()` methods for dbs (Postgres needs init_worker, Cassandra needs init context (must run before migrations to detect major version) - we do not support clusters made of nodes with different major versions (not sure if this is possible, so we should prevent it) - add exptime support in shm mock in `globalpatches` --- kong-0.9.3-0.rockspec | 10 +- kong/cmd/cluster.lua | 2 +- kong/dao/db/cassandra.lua | 54 +++++++---- kong/dao/db/init.lua | 3 +- kong/dao/db/postgres.lua | 2 +- kong/dao/errors.lua | 4 +- kong/dao/factory.lua | 1 - kong/kong.lua | 6 +- kong/plugins/rate-limiting/dao/cassandra.lua | 70 -------------- kong/plugins/rate-limiting/dao/postgres.lua | 50 ---------- .../rate-limiting/policies/cluster.lua | 92 +++++++++++++++++++ .../{policies.lua => policies/init.lua} | 21 +++-- .../response-ratelimiting/dao/cassandra.lua | 70 -------------- .../response-ratelimiting/dao/postgres.lua | 50 ---------- .../policies/cluster.lua | 91 ++++++++++++++++++ .../{policies.lua => policies/init.lua} | 29 +++--- .../01-cmd/02-start_stop_spec.lua | 2 +- .../02-dao/00-unit_error_spec.lua | 2 +- .../02-dao/06-plugins_daos_spec.lua | 7 -- spec/02-integration/02-dao/07-ttl_spec.lua | 4 +- 20 files changed, 265 insertions(+), 305 deletions(-) delete mode 100644 kong/plugins/rate-limiting/dao/cassandra.lua delete mode 100644 kong/plugins/rate-limiting/dao/postgres.lua create mode 100644 kong/plugins/rate-limiting/policies/cluster.lua rename kong/plugins/rate-limiting/{policies.lua => policies/init.lua} (85%) delete mode 100644 kong/plugins/response-ratelimiting/dao/cassandra.lua delete mode 100644 kong/plugins/response-ratelimiting/dao/postgres.lua create mode 100644 kong/plugins/response-ratelimiting/policies/cluster.lua rename kong/plugins/response-ratelimiting/{policies.lua => policies/init.lua} (81%) diff --git a/kong-0.9.3-0.rockspec b/kong-0.9.3-0.rockspec index dfc50a2f7ba3..14c6d8e302e4 100644 --- a/kong-0.9.3-0.rockspec +++ 
b/kong-0.9.3-0.rockspec @@ -166,9 +166,8 @@ build = { ["kong.plugins.rate-limiting.migrations.postgres"] = "kong/plugins/rate-limiting/migrations/postgres.lua", ["kong.plugins.rate-limiting.handler"] = "kong/plugins/rate-limiting/handler.lua", ["kong.plugins.rate-limiting.schema"] = "kong/plugins/rate-limiting/schema.lua", - ["kong.plugins.rate-limiting.policies"] = "kong/plugins/rate-limiting/policies.lua", - ["kong.plugins.rate-limiting.dao.cassandra"] = "kong/plugins/rate-limiting/dao/cassandra.lua", - ["kong.plugins.rate-limiting.dao.postgres"] = "kong/plugins/rate-limiting/dao/postgres.lua", + ["kong.plugins.rate-limiting.policies"] = "kong/plugins/rate-limiting/policies/init.lua", + ["kong.plugins.rate-limiting.policies.cluster"] = "kong/plugins/rate-limiting/policies/cluster.lua", ["kong.plugins.response-ratelimiting.migrations.cassandra"] = "kong/plugins/response-ratelimiting/migrations/cassandra.lua", ["kong.plugins.response-ratelimiting.migrations.postgres"] = "kong/plugins/response-ratelimiting/migrations/postgres.lua", @@ -177,9 +176,8 @@ build = { ["kong.plugins.response-ratelimiting.header_filter"] = "kong/plugins/response-ratelimiting/header_filter.lua", ["kong.plugins.response-ratelimiting.log"] = "kong/plugins/response-ratelimiting/log.lua", ["kong.plugins.response-ratelimiting.schema"] = "kong/plugins/response-ratelimiting/schema.lua", - ["kong.plugins.response-ratelimiting.policies"] = "kong/plugins/response-ratelimiting/policies.lua", - ["kong.plugins.response-ratelimiting.dao.cassandra"] = "kong/plugins/response-ratelimiting/dao/cassandra.lua", - ["kong.plugins.response-ratelimiting.dao.postgres"] = "kong/plugins/response-ratelimiting/dao/postgres.lua", + ["kong.plugins.response-ratelimiting.policies"] = "kong/plugins/response-ratelimiting/policies/init.lua", + ["kong.plugins.response-ratelimiting.policies.cluster"] = "kong/plugins/response-ratelimiting/policies/cluster.lua", ["kong.plugins.request-size-limiting.handler"] = 
"kong/plugins/request-size-limiting/handler.lua", ["kong.plugins.request-size-limiting.schema"] = "kong/plugins/request-size-limiting/schema.lua", diff --git a/kong/cmd/cluster.lua b/kong/cmd/cluster.lua index cb7e05649003..ca04ebd14bf9 100644 --- a/kong/cmd/cluster.lua +++ b/kong/cmd/cluster.lua @@ -7,7 +7,7 @@ local conf_loader = require "kong.conf_loader" local function execute(args) if args.command == "keygen" then local conf = assert(conf_loader(args.conf)) - local dao = assert(DAOFactory(conf)) + local dao = assert(DAOFactory.new(conf)) local serf = Serf.new(conf, dao) print(assert(serf:keygen())) return diff --git a/kong/dao/db/cassandra.lua b/kong/dao/db/cassandra.lua index fad239a8295c..a090b8a41e68 100644 --- a/kong/dao/db/cassandra.lua +++ b/kong/dao/db/cassandra.lua @@ -8,6 +8,10 @@ local uuid = utils.uuid local _M = require("kong.dao.db").new_db("cassandra") +-- expose cassandra binding serializers +-- ex: cassandra.uuid('') +_M.cassandra = cassandra + _M.dao_insert_values = { id = function() return uuid() @@ -123,7 +127,7 @@ end local function serialize_arg(field, value) if value == nil then - return cassandra.unset + return cassandra.null elseif field.type == "id" then return cassandra.uuid(value) elseif field.type == "timestamp" then @@ -485,30 +489,42 @@ function _M:truncate_table(table_name) end function _M:current_migrations() - -- Check if keyspace exists - local rows, err = self:query([[ - SELECT * FROM system.schema_keyspaces WHERE keyspace_name = ? - ]], {self.cluster_options.keyspace}, nil, nil, true) - if err then - return nil, err - elseif #rows == 0 then - return {} + local q_keyspace_exists, q_migrations_table_exists + + assert(self.release_version, "release_version not set for Cassandra cluster") + + if self.release_version == 3 then + q_keyspace_exists = "SELECT * FROM system_schema.keyspaces WHERE keyspace_name = ?" + q_migrations_table_exists = [[ + SELECT COUNT(*) FROM system_schema.tables + WHERE keyspace_name = ? 
AND table_name = ? + ]] + else + q_keyspace_exists = "SELECT * FROM system.schema_keyspaces WHERE keyspace_name = ?" + q_migrations_table_exists = [[ + SELECT COUNT(*) FROM system.schema_columnfamilies + WHERE keyspace_name = ? AND columnfamily_name = ? + ]] end - -- Check if schema_migrations table exists first - rows, err = self:query([[ - SELECT COUNT(*) FROM system.schema_columnfamilies - WHERE keyspace_name = ? AND columnfamily_name = ? - ]], { + -- Check if keyspace exists + local rows, err = self:query(q_keyspace_exists, { + self.cluster_options.keyspace + }, {prepared = false}, nil, true) + if err then return nil, err + elseif #rows == 0 then return {} end + + -- Check if schema_migrations table exists + rows, err = self:query(q_migrations_table_exists, { self.cluster_options.keyspace, "schema_migrations" - }) - if err then - return nil, err - end + }, {prepared = false}) + if err then return nil, err end if rows[1].count > 0 then - return self:query "SELECT * FROM schema_migrations" + return self:query("SELECT * FROM schema_migrations", nil, { + prepared = false + }) else return {} end diff --git a/kong/dao/db/init.lua b/kong/dao/db/init.lua index c4220ad38237..2bd172fbb5a0 100644 --- a/kong/dao/db/init.lua +++ b/kong/dao/db/init.lua @@ -5,7 +5,8 @@ local _M = {} function _M.new_db(name) local db_mt = { name = name, - init = function() error('init() not implemented') end, + init = function() return true end, + init_worker = function() return true end, infos = function() error('infos() not implemented') end, query = function() error('query() not implemented') end, insert = function() error('insert() not implemented') end, diff --git a/kong/dao/db/postgres.lua b/kong/dao/db/postgres.lua index ed3e63b5a64c..434f40ce4ec2 100644 --- a/kong/dao/db/postgres.lua +++ b/kong/dao/db/postgres.lua @@ -57,7 +57,7 @@ function _M:start_ttl_timer() end end -function _M:init() +function _M:init_worker() self:start_ttl_timer() end diff --git a/kong/dao/errors.lua 
b/kong/dao/errors.lua index bb232472b5bf..18359764bb92 100644 --- a/kong/dao/errors.lua +++ b/kong/dao/errors.lua @@ -48,7 +48,7 @@ local serializers = { } local function build_error(err_type) - return function(err, db_type) + return function(err, db_name) if err == nil then return elseif getmetatable(err) == error_mt then @@ -56,7 +56,7 @@ local function build_error(err_type) end local err_obj = { - db_type = db_type, + db_name = db_name, [err_type] = true } diff --git a/kong/dao/factory.lua b/kong/dao/factory.lua index 9b152d49155b..610006da3f44 100644 --- a/kong/dao/factory.lua +++ b/kong/dao/factory.lua @@ -113,7 +113,6 @@ function _M.new(kong_config, events_handler) end local constraints = build_constraints(schemas) - load_daos(factory, schemas, constraints, events_handler) return setmetatable(factory, _M) diff --git a/kong/kong.lua b/kong/kong.lua index d52336180015..09587640bdb7 100644 --- a/kong/kong.lua +++ b/kong/kong.lua @@ -121,6 +121,7 @@ function Kong.init() local events = Events() -- retrieve node plugins local dao = assert(DAOFactory.new(config, events)) -- instanciate long-lived DAO + assert(dao:init()) assert(dao:run_migrations()) -- migrating in case embedded in custom nginx -- populate singletons @@ -142,10 +143,7 @@ function Kong.init_worker() core.init_worker.before() - local ok, err = singletons.dao:init() -- Executes any initialization by the DB - if not ok then - ngx.log(ngx.ERR, err) - end + singletons.dao:init_worker() for _, plugin in ipairs(singletons.loaded_plugins) do plugin.handler:init_worker() diff --git a/kong/plugins/rate-limiting/dao/cassandra.lua b/kong/plugins/rate-limiting/dao/cassandra.lua deleted file mode 100644 index 8660ea727a1f..000000000000 --- a/kong/plugins/rate-limiting/dao/cassandra.lua +++ /dev/null @@ -1,70 +0,0 @@ -local CassandraDB = require "kong.dao.cassandra_db" -local cassandra = require "cassandra" -local timestamp = require "kong.tools.timestamp" - -local _M = CassandraDB:extend() - -_M.table = 
"ratelimiting_metrics" -_M.schema = require("kong.plugins.response-ratelimiting.schema") - -function _M:increment(api_id, identifier, current_timestamp, value) - local periods = timestamp.get_timestamps(current_timestamp) - local options = self:_get_conn_options() - local session, err = cassandra.spawn_session(options) - if err then - ngx.log(ngx.ERR, "[rate-limiting] could not spawn session to Cassandra: "..tostring(err)) - return nil, err - end - - local ok = true - for period, period_date in pairs(periods) do - local res, err = session:execute([[ - UPDATE ratelimiting_metrics SET value = value + ? WHERE - api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? - ]], { - cassandra.counter(value), - cassandra.uuid(api_id), - identifier, - cassandra.timestamp(period_date), - period - }) - if not res then - ok = false - ngx.log(ngx.ERR, "[rate-limiting] could not increment counter for period '"..period.."': "..tostring(err)) - end - end - - session:set_keep_alive() - - return ok -end - -function _M:find(api_id, identifier, current_timestamp, period) - local periods = timestamp.get_timestamps(current_timestamp) - local rows, err = self:query([[ - SELECT * FROM ratelimiting_metrics WHERE - api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? 
- ]], { - cassandra.uuid(api_id), - identifier, - cassandra.timestamp(periods[period]), - period - }) - if err then - return nil, err - elseif #rows > 0 then - return rows[1] - end -end - -function _M:count() - return _M.super.count(self, _M.table, nil, _M.schema) -end - -return {ratelimiting_metrics = _M} diff --git a/kong/plugins/rate-limiting/dao/postgres.lua b/kong/plugins/rate-limiting/dao/postgres.lua deleted file mode 100644 index ad6d3c95e243..000000000000 --- a/kong/plugins/rate-limiting/dao/postgres.lua +++ /dev/null @@ -1,50 +0,0 @@ -local PostgresDB = require "kong.dao.postgres_db" -local timestamp = require "kong.tools.timestamp" -local fmt = string.format -local concat = table.concat - -local _M = PostgresDB:extend() - -_M.table = "ratelimiting_metrics" -_M.schema = require("kong.plugins.response-ratelimiting.schema") - -function _M:increment(api_id, identifier, current_timestamp, value) - local buf = {} - local periods = timestamp.get_timestamps(current_timestamp) - for period, period_date in pairs(periods) do - buf[#buf + 1] = fmt("SELECT increment_rate_limits('%s', '%s', '%s', to_timestamp('%s') at time zone 'UTC', %d)", - api_id, identifier, period, period_date/1000, value) - end - - local queries = concat(buf, ";") - - local res, err = self:query(queries) - if not res then - return false, err - end - return true -end - -function _M:find(api_id, identifier, current_timestamp, period) - local periods = timestamp.get_timestamps(current_timestamp) - - local q = fmt([[SELECT *, extract(epoch from period_date)*1000 AS period_date FROM ratelimiting_metrics WHERE - api_id = '%s' AND - identifier = '%s' AND - period_date = to_timestamp('%s') at time zone 'UTC' AND - period = '%s' - ]], api_id, identifier, periods[period]/1000, period) - - local res, err = self:query(q) - if not res or err then - return nil, err - end - - return res[1] -end - -function _M:count() - return _M.super.count(self, _M.table, nil, _M.schema) -end - -return {ratelimiting_metrics = 
_M} diff --git a/kong/plugins/rate-limiting/policies/cluster.lua b/kong/plugins/rate-limiting/policies/cluster.lua new file mode 100644 index 000000000000..282888d10640 --- /dev/null +++ b/kong/plugins/rate-limiting/policies/cluster.lua @@ -0,0 +1,92 @@ +local timestamp = require "kong.tools.timestamp" + +local concat = table.concat +local pairs = pairs +local fmt = string.format +local log = ngx.log +local ERR = ngx.ERR + +return { + ["cassandra"] = { + increment = function(db, api_id, identifier, current_timestamp, value) + local periods = timestamp.get_timestamps(current_timestamp) + + for period, period_date in pairs(periods) do + local res, err = db:query([[ + UPDATE ratelimiting_metrics + SET value = value + ? + WHERE api_id = ? AND + identifier = ? AND + period_date = ? AND + period = ? + ]], { + db.cassandra.counter(value), + db.cassandra.uuid(api_id), + identifier, + db.cassandra.timestamp(period_date), + period, + }) + if not res then + log(ERR, "[rate-limiting] cluster policy: could not increment ", + "cassandra counter for period '", period, "': ", err) + end + end + + return true + end, + find = function(db, api_id, identifier, current_timestamp, period) + local periods = timestamp.get_timestamps(current_timestamp) + + local rows, err = db:query([[ + SELECT * + FROM ratelimiting_metrics + WHERE api_id = ? AND + identifier = ? AND + period_date = ? AND + period = ? 
+ ]], { + db.cassandra.uuid(api_id), + identifier, + db.cassandra.timestamp(periods[period]), + period, + }) + if not rows then return nil, err + elseif #rows > 0 then return rows[1] end + end, + }, + ["postgres"] = { + increment = function(db, api_id, identifier, current_timestamp, value) + local buf = {} + local periods = timestamp.get_timestamps(current_timestamp) + + for period, period_date in pairs(periods) do + buf[#buf+1] = fmt([[ + SELECT increment_rate_limits('%s', '%s', '%s', to_timestamp('%s') + at time zone 'UTC', %d) + ]], api_id, identifier, period, period_date/1000, value) + end + + local res, err = db:query(concat(buf, ";")) + if not res then return nil, err end + + return true + end, + find = function(db, api_id, identifier, current_timestamp, period) + local periods = timestamp.get_timestamps(current_timestamp) + + local q = fmt([[ + SELECT *, extract(epoch from period_date)*1000 AS period_date + FROM ratelimiting_metrics + WHERE api_id = '%s' AND + identifier = '%s' AND + period_date = to_timestamp('%s') at time zone 'UTC' AND + period = '%s' + ]], api_id, identifier, periods[period]/1000, period) + + local res, err = db:query(q) + if not res or err then return nil, err end + + return res[1] + end, + } +} diff --git a/kong/plugins/rate-limiting/policies.lua b/kong/plugins/rate-limiting/policies/init.lua similarity index 85% rename from kong/plugins/rate-limiting/policies.lua rename to kong/plugins/rate-limiting/policies/init.lua index 22b8f56530af..89d336e8f8cb 100644 --- a/kong/plugins/rate-limiting/policies.lua +++ b/kong/plugins/rate-limiting/policies/init.lua @@ -2,6 +2,7 @@ local singletons = require "kong.singletons" local timestamp = require "kong.tools.timestamp" local cache = require "kong.tools.database_cache" local redis = require "resty.redis" +local policy_cluster = require "kong.plugins.rate-limiting.policies.cluster" local ngx_log = ngx.log local pairs = pairs @@ -48,17 +49,21 @@ return { }, ["cluster"] = { increment = 
function(conf, api_id, identifier, current_timestamp, value) - local _, stmt_err = singletons.dao.ratelimiting_metrics:increment(api_id, identifier, current_timestamp, value) - if stmt_err then - ngx_log(ngx.ERR, "failed to increment: ", tostring(stmt_err)) + local db = singletons.dao.db + local ok, err = policy_cluster[db.name].increment(db, api_id, identifier, + current_timestamp, value) + if not ok then + ngx_log(ngx.ERR, "[rate-limiting] cluster policy: could not increment ", + db.name, " counter: ", err) end end, usage = function(conf, api_id, identifier, current_timestamp, name) - local current_metric, err = singletons.dao.ratelimiting_metrics:find(api_id, identifier, current_timestamp, name) - if err then - return nil, err - end - return current_metric and current_metric.value or 0 + local db = singletons.dao.db + local rows, err = policy_cluster[db.name].find(db, api_id, identifier, + current_timestamp, name) + if not rows then return nil, err end + + return rows and rows.value or 0 end }, ["redis"] = { diff --git a/kong/plugins/response-ratelimiting/dao/cassandra.lua b/kong/plugins/response-ratelimiting/dao/cassandra.lua deleted file mode 100644 index cb5568b0f512..000000000000 --- a/kong/plugins/response-ratelimiting/dao/cassandra.lua +++ /dev/null @@ -1,70 +0,0 @@ -local CassandraDB = require "kong.dao.cassandra_db" -local cassandra = require "cassandra" -local timestamp = require "kong.tools.timestamp" - -local _M = CassandraDB:extend() - -_M.table = "response_ratelimiting_metrics" -_M.schema = require("kong.plugins.response-ratelimiting.schema") - -function _M:increment(api_id, identifier, current_timestamp, value, name) - local periods = timestamp.get_timestamps(current_timestamp) - local options = self:_get_conn_options() - local session, err = cassandra.spawn_session(options) - if err then - ngx.log(ngx.ERR, "[response-ratelimiting] could not spawn session to Cassandra: "..tostring(err)) - return nil, err - end - - local ok = true - for period, 
period_date in pairs(periods) do - local res, err = session:execute([[ - UPDATE response_ratelimiting_metrics SET value = value + ? WHERE - api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? - ]], { - cassandra.counter(value), - cassandra.uuid(api_id), - identifier, - cassandra.timestamp(period_date), - name.."_"..period - }) - if not res then - ok = false - ngx.log(ngx.ERR, "[response-ratelimiting] could not increment counter for period '"..period.."': "..tostring(err)) - end - end - - session:set_keep_alive() - - return ok -end - -function _M:find(api_id, identifier, current_timestamp, period, name) - local periods = timestamp.get_timestamps(current_timestamp) - local rows, err = self:query([[ - SELECT * FROM response_ratelimiting_metrics WHERE - api_id = ? AND - identifier = ? AND - period_date = ? AND - period = ? - ]], { - cassandra.uuid(api_id), - identifier, - cassandra.timestamp(periods[period]), - name.."_"..period - }) - if err then - return nil, err - elseif #rows > 0 then - return rows[1] - end -end - -function _M:count() - return _M.super.count(self, _M.table, nil, _M.schema) -end - -return {response_ratelimiting_metrics = _M} diff --git a/kong/plugins/response-ratelimiting/dao/postgres.lua b/kong/plugins/response-ratelimiting/dao/postgres.lua deleted file mode 100644 index e532e54c070b..000000000000 --- a/kong/plugins/response-ratelimiting/dao/postgres.lua +++ /dev/null @@ -1,50 +0,0 @@ -local PostgresDB = require "kong.dao.postgres_db" -local timestamp = require "kong.tools.timestamp" -local fmt = string.format -local concat = table.concat - -local _M = PostgresDB:extend() - -_M.table = "response_ratelimiting_metrics" -_M.schema = require("kong.plugins.response-ratelimiting.schema") - -function _M:increment(api_id, identifier, current_timestamp, value, name) - local buf = {} - local periods = timestamp.get_timestamps(current_timestamp) - for period, period_date in pairs(periods) do - buf[#buf + 1] = fmt("SELECT 
increment_response_rate_limits('%s', '%s', '%s', to_timestamp('%s') at time zone 'UTC', %d)", - api_id, identifier, name.."_"..period, period_date/1000, value) - end - - local queries = concat(buf, ";") - - local res, err = self:query(queries) - if not res then - return false, err - end - return true -end - -function _M:find(api_id, identifier, current_timestamp, period, name) - local periods = timestamp.get_timestamps(current_timestamp) - - local q = fmt([[SELECT *, extract(epoch from period_date)*1000 AS period_date FROM response_ratelimiting_metrics WHERE - api_id = '%s' AND - identifier = '%s' AND - period_date = to_timestamp('%s') at time zone 'UTC' AND - period = '%s' - ]], api_id, identifier, periods[period]/1000, name.."_"..period) - - local res, err = self:query(q) - if not res or err then - return nil, err - end - - return res[1] -end - -function _M:count() - return _M.super.count(self, _M.table, nil, _M.schema) -end - -return {response_ratelimiting_metrics = _M} diff --git a/kong/plugins/response-ratelimiting/policies/cluster.lua b/kong/plugins/response-ratelimiting/policies/cluster.lua new file mode 100644 index 000000000000..33c2d2b4a9bb --- /dev/null +++ b/kong/plugins/response-ratelimiting/policies/cluster.lua @@ -0,0 +1,91 @@ +local timestamp = require "kong.tools.timestamp" + +local concat = table.concat +local pairs = pairs +local fmt = string.format +local log = ngx.log +local ERR = ngx.ERR + +return { + ["cassandra"] = { + increment = function(db, api_id, identifier, current_timestamp, value, name) + local periods = timestamp.get_timestamps(current_timestamp) + + for period, period_date in pairs(periods) do + local res, err = db:query([[ + UPDATE response_ratelimiting_metrics + SET value = value + ? + WHERE api_id = ? AND + identifier = ? AND + period_date = ? AND + period = ? 
+ ]], { + db.cassandra.counter(value), + db.cassandra.uuid(api_id), + identifier, + db.cassandra.timestamp(period_date), + name.."_"..period, + }) + if not res then + log(ERR, "[response-ratelimiting] cluster policy: could not increment ", + "cassandra counter for period '", period, "': ", err) + end + end + + return true + end, + find = function(db, api_id, identifier, current_timestamp, period, name) + local periods = timestamp.get_timestamps(current_timestamp) + + local rows, err = db:query([[ + SELECT * FROM response_ratelimiting_metrics + WHERE api_id = ? AND + identifier = ? AND + period_date = ? AND + period = ? + ]], { + db.cassandra.uuid(api_id), + identifier, + db.cassandra.timestamp(periods[period]), + name.."_"..period, + }) + if not rows then return nil, err + elseif #rows > 0 then return rows[1] end + end, + }, + ["postgres"] = { + increment = function(db, api_id, identifier, current_timestamp, value, name) + local buf = {} + local periods = timestamp.get_timestamps(current_timestamp) + + for period, period_date in pairs(periods) do + buf[#buf+1] = fmt([[ + SELECT increment_response_rate_limits('%s', '%s', '%s_%s', to_timestamp('%s') + at time zone 'UTC', %d) + ]], api_id, identifier, name, period, period_date/1000, value) + end + + local res, err = db:query(concat(buf, ";")) + if not res then return nil, err end + + return true + end, + find = function(db, api_id, identifier, current_timestamp, period, name) + local periods = timestamp.get_timestamps(current_timestamp) + + local q = fmt([[ + SELECT *, extract(epoch from period_date)*1000 AS period_date + FROM response_ratelimiting_metrics + WHERE api_id = '%s' AND + identifier = '%s' AND + period_date = to_timestamp('%s') at time zone 'UTC' AND + period = '%s_%s' + ]], api_id, identifier, periods[period]/1000, name, period) + + local res, err = db:query(q) + if not res or err then return nil, err end + + return res[1] + end, + } +} diff --git a/kong/plugins/response-ratelimiting/policies.lua 
b/kong/plugins/response-ratelimiting/policies/init.lua similarity index 81% rename from kong/plugins/response-ratelimiting/policies.lua rename to kong/plugins/response-ratelimiting/policies/init.lua index a726d86631f2..39b6636cb08f 100644 --- a/kong/plugins/response-ratelimiting/policies.lua +++ b/kong/plugins/response-ratelimiting/policies/init.lua @@ -2,6 +2,7 @@ local singletons = require "kong.singletons" local timestamp = require "kong.tools.timestamp" local cache = require "kong.tools.database_cache" local redis = require "resty.redis" +local policy_cluster = require "kong.plugins.response-ratelimiting.policies.cluster" local ngx_log = ngx.log local pairs = pairs @@ -32,7 +33,7 @@ return { local _, err = cache.incr(cache_key, value) if err then - ngx_log("[rate-limiting] could not increment counter for period '"..period.."': "..tostring(err)) + ngx_log("[response-ratelimiting] could not increment counter for period '"..period.."': "..tostring(err)) end end end, @@ -48,17 +49,23 @@ return { }, ["cluster"] = { increment = function(conf, api_id, identifier, current_timestamp, value, name) - local _, stmt_err = singletons.dao.response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, value, name) - if stmt_err then - ngx_log(ngx.ERR, tostring(stmt_err)) + local db = singletons.dao.db + local ok, err = policy_cluster[db.name].increment(db, api_id, identifier, + current_timestamp, value, + name) + if not ok then + ngx_log(ngx.ERR, "[response-ratelimiting] cluster policy: could not increment ", + db.name, " counter: ", err) end end, usage = function(conf, api_id, identifier, current_timestamp, period, name) - local current_metric, err = singletons.dao.response_ratelimiting_metrics:find(api_id, identifier, current_timestamp, period, name) - if err then - return nil, err - end - return current_metric and current_metric.value or 0 + local db = singletons.dao.db + local rows, err = policy_cluster[db.name].find(db, api_id, identifier, + 
current_timestamp, period, + name) + if not rows then return nil, err end + + return rows and rows.value or 0 end }, ["redis"] = { @@ -93,7 +100,7 @@ return { if not exists or exists == 0 then red:expire(cache_key, EXPIRATIONS[period]) end - + local _, err = red:commit_pipeline() if err then ngx_log(ngx.ERR, "failed to commit pipeline in Redis: ", err) @@ -138,4 +145,4 @@ return { return current_metric and current_metric or 0 end } -} \ No newline at end of file +} diff --git a/spec/02-integration/01-cmd/02-start_stop_spec.lua b/spec/02-integration/01-cmd/02-start_stop_spec.lua index ba58712d535b..37c70662eb6b 100644 --- a/spec/02-integration/01-cmd/02-start_stop_spec.lua +++ b/spec/02-integration/01-cmd/02-start_stop_spec.lua @@ -106,7 +106,7 @@ describe("kong start/stop", function() -- it in its resolver directive. As such and until supported by resty-cli, -- we must force the use of LuaSocket in our CLI to resolve localhost. it("resolves cassandra hostname", function() - assert(helpers.kong_exec("start --conf "..helpers.test_conf_path, { + assert(helpers.kong_exec("start --vv --conf "..helpers.test_conf_path, { cassandra_contact_points = "localhost", database = "cassandra" })) diff --git a/spec/02-integration/02-dao/00-unit_error_spec.lua b/spec/02-integration/02-dao/00-unit_error_spec.lua index 9cb91dd93bda..1440461b07e4 100644 --- a/spec/02-integration/02-dao/00-unit_error_spec.lua +++ b/spec/02-integration/02-dao/00-unit_error_spec.lua @@ -8,7 +8,7 @@ describe("Errors", function() assert.True(err.unique) assert.equal("already exists with value 'foo'", err.tbl.name) assert.equal("already exists with value 'bar'", err.tbl.unique_field) - assert.equal("name=already exists with value 'foo' unique_field=already exists with value 'bar'", err.message) + assert.matches("name=already exists with value 'foo' unique_field=already exists with value 'bar'", err.message) assert.equal(err.message, tostring(err)) end) end) diff --git 
a/spec/02-integration/02-dao/06-plugins_daos_spec.lua b/spec/02-integration/02-dao/06-plugins_daos_spec.lua index 4c4fdddde598..0cc3c74d5655 100644 --- a/spec/02-integration/02-dao/06-plugins_daos_spec.lua +++ b/spec/02-integration/02-dao/06-plugins_daos_spec.lua @@ -32,12 +32,5 @@ helpers.for_each_dao(function(kong_config) assert.is_table(migrations["response-ratelimiting"]) end) end) - - describe("custom DBs", function() - it("loads rate-limiting custom DB", function() - local factory = assert(Factory.new(kong_config)) - assert.truthy(factory.ratelimiting_metrics) - end) - end) end) end) diff --git a/spec/02-integration/02-dao/07-ttl_spec.lua b/spec/02-integration/02-dao/07-ttl_spec.lua index f691a0fe5a1b..44953f9f8e31 100644 --- a/spec/02-integration/02-dao/07-ttl_spec.lua +++ b/spec/02-integration/02-dao/07-ttl_spec.lua @@ -68,8 +68,8 @@ helpers.for_each_dao(function(kong_config) if kong_config.database == "postgres" then it("clears old entities", function() - local DB = require "kong.dao.postgres_db" - local _db = DB(kong_config) + local DB = require "kong.dao.db.postgres" + local _db = DB.new(kong_config) for i = 1, 4 do local _, err = factory.apis:insert({ From adc77ff7a94393ded67981eade5bda4957760940 Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Mon, 3 Oct 2016 15:57:19 +0200 Subject: [PATCH 3/7] style(db) improve Cassandra DB code style - cleanup some bad patterns like return values - trying to limit col width to 80 chars - caching some global lookups --- kong/dao/db/cassandra.lua | 401 +++++++++++++++++++------------------- kong/dao/db/postgres.lua | 20 +- kong/dao/factory.lua | 8 +- 3 files changed, 219 insertions(+), 210 deletions(-) diff --git a/kong/dao/db/cassandra.lua b/kong/dao/db/cassandra.lua index a090b8a41e68..fbd51496e158 100644 --- a/kong/dao/db/cassandra.lua +++ b/kong/dao/db/cassandra.lua @@ -1,10 +1,17 @@ +local timestamp = require "kong.tools.timestamp" local cassandra = require "cassandra" local Cluster = require 
"resty.cassandra.cluster" -local timestamp = require "kong.tools.timestamp" local Errors = require "kong.dao.errors" local utils = require "kong.tools.utils" local cjson = require "cjson" + +local tonumber = tonumber +local concat = table.concat +local match = string.match +local fmt = string.format local uuid = utils.uuid +local pairs = pairs +local ipairs = ipairs local _M = require("kong.dao.db").new_db("cassandra") @@ -65,12 +72,12 @@ function _M.new(kong_config) end local function extract_major(release_version) - return string.match(release_version, "^(%d+)%.%d+%.?%d*$") + return match(release_version, "^(%d+)%.%d+%.?%d*$") end local function cluster_release_version(peers) local first_release_version - local ok = true + local mismatch for i = 1, #peers do local release_version = peers[i].release_version @@ -82,18 +89,20 @@ local function cluster_release_version(peers) if i == 1 then first_release_version = major_version elseif major_version ~= first_release_version then - ok = false + mismatch = true break end end - if not ok then - local err_t = {"different major versions detected (only all of 2.x or 3.x supported):"} + if mismatch then + local err_t = { + "different major versions detected (only all of 2.x or 3.x supported):" + } for i = 1, #peers do - err_t[#err_t+1] = string.format("%s (%s)", peers[i].host, peers[i].release_version) + err_t[#err_t+1] = fmt("%s (%s)", peers[i].host, peers[i].release_version) end - return nil, table.concat(err_t, " ") + return nil, concat(err_t, " ") end return tonumber(first_release_version) @@ -123,7 +132,37 @@ function _M:infos() } end --- Formatting +local function deserialize_rows(rows, schema) + for i, row in ipairs(rows) do + for col, value in pairs(row) do + local t = schema.fields[col].type + if t == "table" or t == "array" then + rows[i][col] = cjson.decode(value) + end + end + end +end + +function _M:query(query, args, options, schema, no_keyspace) + local opts = self:clone_query_options(options) + local 
coordinator_opts = {} + if no_keyspace then + -- defaults to the system keyspace, always present + coordinator_opts.keyspace = "system" + end + + local res, err = self.cluster:execute(query, args, opts, coordinator_opts) + if not res then return nil, Errors.db(err) end + + if schema ~= nil and res.type == "ROWS" then + deserialize_rows(res, schema) + end + + return res +end + +--- Query building +-- @section query_building local function serialize_arg(field, value) if value == nil then @@ -139,39 +178,33 @@ local function serialize_arg(field, value) end end -local function deserialize_rows(rows, schema) - for i, row in ipairs(rows) do - for col, value in pairs(row) do - if schema.fields[col].type == "table" or schema.fields[col].type == "array" then - rows[i][col] = cjson.decode(value) - end - end - end -end - local function get_where(schema, filter_keys, args) args = args or {} - local fields = schema.fields local where = {} + local fields = schema.fields for col, value in pairs(filter_keys) do - where[#where + 1] = col.." = ?" - args[#args + 1] = serialize_arg(fields[col], value) + where[#where+1] = col .. " = ?" + args[#args+1] = serialize_arg(fields[col], value) end - return table.concat(where, " AND "), args + return concat(where, " AND "), args end -local function get_select_query(table_name, where, select_clause) - local query = string.format("SELECT %s FROM %s", select_clause or "*", table_name) - if where ~= nil then - query = query.." WHERE "..where.." ALLOW FILTERING" +local function select_query(table_name, where, select_clause) + select_clause = select_clause or "*" + + local query = fmt("SELECT %s FROM %s", select_clause, table_name) + + if where then + query = query .. " WHERE " .. where .. 
" ALLOW FILTERING" end return query end --- Querying +-- @section querying local function check_unique_constraints(self, table_name, constraints, values, primary_keys, update) local errors @@ -180,12 +213,12 @@ local function check_unique_constraints(self, table_name, constraints, values, p -- Only check constraints if value is non-null if values[col] ~= nil then local where, args = get_where(constraint.schema, {[col] = values[col]}) - local query = get_select_query(table_name, where) + local query = select_query(table_name, where) local rows, err = self:query(query, args, nil, constraint.schema) - if err then - return err + if err then return nil, err elseif #rows > 0 then - -- if in update, it's fine if the retrieved row is the same as the one updated + -- if in update, it's fine if the retrieved row is + -- the same as the one being updated if update then local same_row = true for col, val in pairs(primary_keys) do @@ -205,106 +238,79 @@ local function check_unique_constraints(self, table_name, constraints, values, p end end - return Errors.unique(errors) + return errors == nil, Errors.unique(errors) end local function check_foreign_constaints(self, values, constraints) local errors for col, constraint in pairs(constraints.foreign) do - -- Only check foreign keys if value is non-null, if must not be null, field should be required + -- Only check foreign keys if value is non-null, + -- if must not be null, field should be required if values[col] ~= nil then - local res, err = self:find(constraint.table, constraint.schema, {[constraint.col] = values[col]}) - if err then - return err - elseif res == nil then + local res, err = self:find(constraint.table, constraint.schema, { + [constraint.col] = values[col] + }) + if err then return nil, err + elseif not res then errors = utils.add_error(errors, col, values[col]) end end end - return Errors.foreign(errors) -end - -function _M:query(query, args, options, schema, no_keyspace) - local opts = 
self:clone_query_options(options) - local coordinator_opts = {} - if no_keyspace then - -- defaults to the system keyspace, always present - coordinator_opts.keyspace = "system" - end - - local res, err = self.cluster:execute(query, args, opts, coordinator_opts) - if not res then - return nil, Errors.db(err) - end - - if schema ~= nil and res.type == "ROWS" then - deserialize_rows(res, schema) - end - - return res + return errors == nil, Errors.foreign(errors) end function _M:insert(table_name, schema, model, constraints, options) - local err = check_unique_constraints(self, table_name, constraints, model) - if err then - return nil, err - end + options = options or {} - err = check_foreign_constaints(self, model, constraints) - if err then - return nil, err - end + local ok, err = check_unique_constraints(self, table_name, constraints, model) + if not ok then return nil, err end + + ok, err = check_foreign_constaints(self, model, constraints) + if not ok then return nil, err end local cols, binds, args = {}, {}, {} + for col, value in pairs(model) do local field = schema.fields[col] - cols[#cols + 1] = col - args[#args + 1] = serialize_arg(field, value) - binds[#binds + 1] = "?" + cols[#cols+1] = col + args[#args+1] = serialize_arg(field, value) + binds[#binds+1] = "?" 
end - cols = table.concat(cols, ", ") - binds = table.concat(binds, ", ") + local query = fmt("INSERT INTO %s(%s) VALUES(%s)%s", + table_name, + concat(cols, ", "), + concat(binds, ", "), + options.ttl and fmt(" USING TTL %d", options.ttl) or "") - local query = string.format("INSERT INTO %s(%s) VALUES(%s)%s", - table_name, cols, binds, (options and options.ttl) and string.format(" USING TTL %d", options.ttl) or "") - local err = select(2, self:query(query, args)) - if err then - return nil, err - end + local res, err = self:query(query, args) + if not res then return nil, err end local primary_keys = model:extract_keys() - local row, err = self:find(table_name, schema, primary_keys) - if err then - return nil, err - end - - return row + return self:find(table_name, schema, primary_keys) end function _M:find(table_name, schema, filter_keys) local where, args = get_where(schema, filter_keys) - local query = get_select_query(table_name, where) + local query = select_query(table_name, where) local rows, err = self:query(query, args, nil, schema) - if err then - return nil, err - elseif #rows > 0 then - return rows[1] - end + if not rows then return nil, err + elseif #rows <= 1 then return rows[1] + else return nil, "bad rows result" end end function _M:find_all(table_name, tbl, schema) local opts = self:clone_query_options() local where, args - if tbl ~= nil then + if tbl then where, args = get_where(schema, tbl) end local err - local query = get_select_query(table_name, where) + local query = select_query(table_name, where) local res_rows = {} for rows, page_err in self.cluster:iterate(query, args, opts) do @@ -313,11 +319,13 @@ function _M:find_all(table_name, tbl, schema) res_rows = nil break end - if schema ~= nil then + + if schema then deserialize_rows(rows, schema) end + for _, row in ipairs(rows) do - res_rows[#res_rows + 1] = row + res_rows[#res_rows+1] = row end end @@ -326,80 +334,81 @@ end function _M:find_page(table_name, tbl, paging_state, page_size, 
schema) local where, args - if tbl ~= nil then + if tbl then where, args = get_where(schema, tbl) end - local query = get_select_query(table_name, where) + local query = select_query(table_name, where) local rows, err = self:query(query, args, {page_size = page_size, paging_state = paging_state}, schema) - if err then - return nil, err - elseif rows ~= nil then - local paging_state - if rows.meta and rows.meta.has_more_pages then - paging_state = rows.meta.paging_state - end - rows.meta = nil - rows.type = nil - return rows, nil, paging_state + if not rows then return nil, err end + + local paging_state + if rows.meta and rows.meta.has_more_pages then + paging_state = rows.meta.paging_state end + + rows.meta = nil + rows.type = nil + + return rows, nil, paging_state end function _M:count(table_name, tbl, schema) local where, args - if tbl ~= nil then + if tbl then where, args = get_where(schema, tbl) end - local query = get_select_query(table_name, where, "COUNT(*)") + local query = select_query(table_name, where, "COUNT(*)") local res, err = self:query(query, args) - if err then - return nil, err - elseif res and #res > 0 then - return res[1].count - end + if not res then return nil, err + elseif #res == 1 then return res[1].count + else return "bad rows result" end end function _M:update(table_name, schema, constraints, filter_keys, values, nils, full, model, options) + options = options or {} + -- must check unique constaints manually too - local err = check_unique_constraints(self, table_name, constraints, values, filter_keys, true) - if err then - return nil, err - end - err = check_foreign_constaints(self, values, constraints) - if err then - return nil, err - end + local ok, err = check_unique_constraints(self, table_name, constraints, values, filter_keys, true) + if not ok then return nil, err end - -- Cassandra TTL on update is per-column and not per-row, and TTLs cannot be updated on primary keys. 
- -- Not only that, but TTL on other rows can only be incremented, and not decremented. Because of all - -- of these limitations, the only way to make this happen is to do an upsert operation. - -- This implementation can be changed once Cassandra closes this issue: https://issues.apache.org/jira/browse/CASSANDRA-9312 - if options and options.ttl then - if schema.primary_key and #schema.primary_key == 1 and filter_keys[schema.primary_key[1]] then + ok, err = check_foreign_constaints(self, values, constraints) + if not ok then return nil, err end + + -- Cassandra TTLs on update is per-column and not per-row, + -- and TTLs cannot be updated on primary keys. + -- TTLs can also only be incremented and not decremented. + -- Because of these limitations, the current implementation + -- is to use an upsert operation. + -- See: https://issues.apache.org/jira/browse/CASSANDRA-9312 + if options.ttl then + if schema.primary_key and + #schema.primary_key == 1 and + filter_keys[schema.primary_key[1]] then local row, err = self:find(table_name, schema, filter_keys) - if err then - return nil, err + if err then return nil, err elseif row then for k, v in pairs(row) do if not values[k] then model[k] = v -- Populate the model to be used later for the insert end end - - -- Insert without any contraint check, since the check has already been executed + -- insert without any constraint check, since the check has already been executed return self:insert(table_name, schema, model, {unique={}, foreign={}}, options) end else - return nil, "Cannot update TTL on entities that have more than one primary_key" + return nil, "cannot update TTL on entities that have more than one primary_key" end end - local sets, args, where = {}, {} + local where + local sets, args = {}, {} + for col, value in pairs(values) do local field = schema.fields[col] - sets[#sets + 1] = col.." = ?" - args[#args + 1] = serialize_arg(field, value) + sets[#sets+1] = col .. " = ?" 
+ args[#args+1] = serialize_arg(field, value) end -- unset nil fields if asked for @@ -410,82 +419,74 @@ function _M:update(table_name, schema, constraints, filter_keys, values, nils, f end end - sets = table.concat(sets, ", ") - where, args = get_where(schema, filter_keys, args) - local query = string.format("UPDATE %s%s SET %s WHERE %s", - table_name, (options and options.ttl) and string.format(" USING TTL %d", options.ttl) or "", sets, where) - local res, err = self:query(query, args) - if err then - return nil, err - elseif res and res.type == "VOID" then - return self:find(table_name, schema, filter_keys) - end -end - -local function cascade_delete(self, primary_keys, constraints) - if constraints.cascade == nil then return end - for f_entity, cascade in pairs(constraints.cascade) do - local tbl = {[cascade.f_col] = primary_keys[cascade.col]} - local rows, err = self:find_all(cascade.table, tbl, cascade.schema) - if err then - return nil, err - end + local query = fmt("UPDATE %s%s SET %s WHERE %s", + table_name, + options.ttl and fmt(" USING TTL %d", options.ttl) or "", + concat(sets, ", "), + where) - for _, row in ipairs(rows) do - local primary_keys_to_delete = {} - for _, primary_key in ipairs(cascade.schema.primary_key) do - primary_keys_to_delete[primary_key] = row[primary_key] - end - - local ok, err = self:delete(cascade.table, cascade.schema, primary_keys_to_delete) - if not ok then - return nil, err - end - end + local res, err = self:query(query, args) + if not res then return nil, err + elseif res.type == "VOID" then + return self:find(table_name, schema, filter_keys) end end function _M:delete(table_name, schema, primary_keys, constraints) local row, err = self:find(table_name, schema, primary_keys) - if err or row == nil then - return nil, err - end + if not row or err then return nil, err end local where, args = get_where(schema, primary_keys) - local query = string.format("DELETE FROM %s WHERE %s", - table_name, where) + local query = 
fmt("DELETE FROM %s WHERE %s", table_name, where) local res, err = self:query(query, args) - if err then - return nil, err - elseif res and res.type == "VOID" then - if constraints ~= nil then - cascade_delete(self, primary_keys, constraints) + if not res then return nil, err + elseif res.type == "VOID" then + if constraints and constraints.cascade then + for f_entity, cascade in pairs(constraints.cascade) do + local tbl = {[cascade.f_col] = primary_keys[cascade.col]} + local rows, err = self:find_all(cascade.table, tbl, cascade.schema) + if not rows then return nil, err end + + for _, row in ipairs(rows) do + local primary_keys_to_delete = {} + for _, primary_key in ipairs(cascade.schema.primary_key) do + primary_keys_to_delete[primary_key] = row[primary_key] + end + + local ok, err = self:delete(cascade.table, cascade.schema, primary_keys_to_delete) + if not ok then return nil, err end + end + end end return row end end --- Migrations +--- Migrations +-- @section migrations function _M:queries(queries, no_keyspace) for _, query in ipairs(utils.split(queries, ";")) do - if utils.strip(query) ~= "" then - local err = select(2, self:query(query, nil, nil, nil, no_keyspace)) - if err then - return err - end + query = utils.strip(query) + if query ~= "" then + local res, err = self:query(query, nil, nil, nil, no_keyspace) + if not res then return err end end end end function _M:drop_table(table_name) - return select(2, self:query("DROP TABLE "..table_name)) + local res, err = self:query("DROP TABLE "..table_name) + if not res then return nil, err end + return true end function _M:truncate_table(table_name) - return select(2, self:query("TRUNCATE "..table_name)) + local res, err = self:query("TRUNCATE "..table_name) + if not res then return nil, err end + return true end function _M:current_migrations() @@ -494,13 +495,19 @@ function _M:current_migrations() assert(self.release_version, "release_version not set for Cassandra cluster") if self.release_version == 3 then - 
q_keyspace_exists = "SELECT * FROM system_schema.keyspaces WHERE keyspace_name = ?" + q_keyspace_exists = [[ + SELECT * FROM system_schema.keyspaces + WHERE keyspace_name = ? + ]] q_migrations_table_exists = [[ SELECT COUNT(*) FROM system_schema.tables WHERE keyspace_name = ? AND table_name = ? ]] else - q_keyspace_exists = "SELECT * FROM system.schema_keyspaces WHERE keyspace_name = ?" + q_keyspace_exists = [[ + SELECT * FROM system.schema_keyspaces + WHERE keyspace_name = ? + ]] q_migrations_table_exists = [[ SELECT COUNT(*) FROM system.schema_columnfamilies WHERE keyspace_name = ? AND columnfamily_name = ? @@ -511,7 +518,7 @@ function _M:current_migrations() local rows, err = self:query(q_keyspace_exists, { self.cluster_options.keyspace }, {prepared = false}, nil, true) - if err then return nil, err + if not rows then return nil, err elseif #rows == 0 then return {} end -- Check if schema_migrations table exists @@ -519,24 +526,22 @@ function _M:current_migrations() self.cluster_options.keyspace, "schema_migrations" }, {prepared = false}) - if err then return nil, err end - - if rows[1].count > 0 then + if not rows then return nil, err + elseif rows[1] and rows[1].count > 0 then return self:query("SELECT * FROM schema_migrations", nil, { prepared = false }) - else - return {} - end + else return {} end end function _M:record_migration(id, name) - return select(2, self:query([[ - UPDATE schema_migrations SET migrations = migrations + ? WHERE id = ? - ]], { - cassandra.list {name}, - id - })) + local res, err = self:query([[ + UPDATE schema_migrations + SET migrations = migrations + ? + WHERE id = ? 
+ ]], {cassandra.list({name}), id}) + if not res then return nil, err end + return true end return _M diff --git a/kong/dao/db/postgres.lua b/kong/dao/db/postgres.lua index 434f40ce4ec2..5730c24b70ac 100644 --- a/kong/dao/db/postgres.lua +++ b/kong/dao/db/postgres.lua @@ -463,19 +463,21 @@ end function _M:queries(queries) if utils.strip(queries) ~= "" then - local err = select(2, self:query(queries)) - if err then - return err - end + local res, err = self:query(queries) + if not res then return err end end end function _M:drop_table(table_name) - return select(2, self:query("DROP TABLE "..table_name.." CASCADE")) + local res, err = self:query("DROP TABLE "..table_name.." CASCADE") + if not res then return nil, err end + return true end function _M:truncate_table(table_name) - return select(2, self:query("TRUNCATE "..table_name.." CASCADE")) + local res, err = self:query("TRUNCATE "..table_name.." CASCADE") + if not res then return nil, err end + return true end function _M:current_migrations() @@ -493,7 +495,7 @@ function _M:current_migrations() end function _M:record_migration(id, name) - return select(2, self:query { + local res, err = self:query{ [[ CREATE OR REPLACE FUNCTION upsert_schema_migrations(identifier text, migration_name varchar) RETURNS VOID AS $$ DECLARE @@ -506,7 +508,9 @@ function _M:record_migration(id, name) $$ LANGUAGE 'plpgsql'; ]], string.format("SELECT upsert_schema_migrations('%s', %s)", id, escape_literal(name)) - }) + } + if not res then return nil, err end + return true end return _M diff --git a/kong/dao/factory.lua b/kong/dao/factory.lua index 610006da3f44..63c1214c8953 100644 --- a/kong/dao/factory.lua +++ b/kong/dao/factory.lua @@ -209,13 +209,13 @@ local function migrate(self, identifier, migrations_modules, cur_migrations, on_ end if err then - return false, string.format("Error during migration %s: %s", migration.name, err) + return nil, string.format("Error during migration %s: %s", migration.name, err) end -- record success - 
err = _db:record_migration(identifier, migration.name) - if err then - return false, string.format("Error recording migration %s: %s", migration.name, err) + local ok, err = _db:record_migration(identifier, migration.name) + if not ok then + return nil, string.format("Error recording migration %s: %s", migration.name, err) end if on_success then From 7c94af19e9f09eef03a159f5cea5a8e30215776c Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Mon, 3 Oct 2016 22:51:48 +0200 Subject: [PATCH 4/7] style(db) improve Postgres DB code style - cleanup some bad patterns like single return values or such - caching global lookups --- kong/dao/db/cassandra.lua | 2 +- kong/dao/db/postgres.lua | 528 +++++++++++++++++++------------------- 2 files changed, 266 insertions(+), 264 deletions(-) diff --git a/kong/dao/db/cassandra.lua b/kong/dao/db/cassandra.lua index fbd51496e158..d438d49e70b9 100644 --- a/kong/dao/db/cassandra.lua +++ b/kong/dao/db/cassandra.lua @@ -154,7 +154,7 @@ function _M:query(query, args, options, schema, no_keyspace) local res, err = self.cluster:execute(query, args, opts, coordinator_opts) if not res then return nil, Errors.db(err) end - if schema ~= nil and res.type == "ROWS" then + if schema and res.type == "ROWS" then deserialize_rows(res, schema) end diff --git a/kong/dao/db/postgres.lua b/kong/dao/db/postgres.lua index 5730c24b70ac..5387db2928ae 100644 --- a/kong/dao/db/postgres.lua +++ b/kong/dao/db/postgres.lua @@ -1,10 +1,29 @@ local pgmoon = require "pgmoon-mashape" local Errors = require "kong.dao.errors" local utils = require "kong.tools.utils" +local cjson = require "cjson" + +local get_phase = ngx.get_phase +local timer_at = ngx.timer.at +local tostring = tostring +local ngx_log = ngx.log +local concat = table.concat +local ipairs = ipairs +local pairs = pairs +local match = string.match +local type = type +local find = string.find local uuid = utils.uuid +local ceil = math.ceil +local fmt = string.format +local ERR = ngx.ERR local 
TTL_CLEANUP_INTERVAL = 60 -- 1 minute +local function log(lvl, ...) + return ngx_log(lvl, "[postgres] ", ...) +end + local _M = require("kong.dao.db").new_db("postgres") _M.dao_insert_values = { @@ -32,47 +51,124 @@ function _M.new(kong_config) return self end --- TTL clean up timer functions +function _M:infos() + return { + desc = "database", + name = self:clone_query_options().database + } +end -local function do_clean_ttl(premature, postgres) - if premature then return end +local do_clean_ttl - local ok, err = postgres:clear_expired_ttl() - if not ok then - ngx.log(ngx.ERR, "failed to cleanup TTLs: ", err) - end - local ok, err = ngx.timer.at(TTL_CLEANUP_INTERVAL, do_clean_ttl, postgres) +function _M:init_worker() + local ok, err = timer_at(TTL_CLEANUP_INTERVAL, do_clean_ttl, self) if not ok then - ngx.log(ngx.ERR, "failed to create timer: ", err) + log(ERR, "could not create TTL timer: ", err) end end -function _M:start_ttl_timer() - if ngx then - local ok, err = ngx.timer.at(TTL_CLEANUP_INTERVAL, do_clean_ttl, self) - if not ok then - ngx.log(ngx.ERR, "failed to create timer: ", err) +--- TTL utils +-- @section ttl_utils + +local cached_columns_types = {} + +local function retrieve_primary_key_type(self, schema, table_name) + local col_type = cached_columns_types[table_name] + + if not col_type then + local query = fmt([[ + SELECT data_type + FROM information_schema.columns + WHERE table_name = '%s' + and column_name = '%s' + LIMIT 1]], table_name, schema.primary_key[1]) + + local res, err = self:query(query) + if not res then return nil, err + elseif #res > 0 then + col_type = res[1].data_type + cached_columns_types[table_name] = col_type end - self.timer_started = true end + + return col_type end -function _M:init_worker() - self:start_ttl_timer() +local function ttl(self, tbl, table_name, schema, ttl) + if not schema.primary_key or #schema.primary_key ~= 1 then + return nil, "cannot set a TTL if the entity has no primary key, or has more than one primary 
key" + end + + local primary_key_type, err = retrieve_primary_key_type(self, schema, table_name) + if not primary_key_type then return nil, err end + + -- get current server time + local query = [[ + SELECT extract(epoch from now() at time zone 'utc')::bigint*1000 as timestamp; + ]] + local res, err = self:query(query) + if not res then return nil, err end + + -- the expiration is always based on the current time + local expire_at = res[1].timestamp + (ttl * 1000) + + local query = fmt([[ + SELECT upsert_ttl('%s', %s, '%s', '%s', to_timestamp(%d/1000) at time zone 'UTC') + ]], tbl[schema.primary_key[1]], + primary_key_type == "uuid" and "'"..tbl[schema.primary_key[1]].."'" or "NULL", + schema.primary_key[1], table_name, expire_at) + local res, err = self:query(query) + if not res then return nil, err end + return true end -function _M:infos() - return { - desc = "database", - name = self:clone_query_options().database - } +local function clear_expired_ttl(self) + local query = [[ + SELECT * FROM ttls WHERE expire_at < CURRENT_TIMESTAMP(0) at time zone 'utc' + ]] + local res, err = self:query(query) + if not res then return nil, err end + + for _, v in ipairs(res) do + local delete_entity_query = fmt("DELETE FROM %s WHERE %s='%s'", v.table_name, + v.primary_key_name, v.primary_key_value) + local res, err = self:query(delete_entity_query) + if not res then return nil, err end + + local delete_ttl_query = fmt([[ + DELETE FROM ttls + WHERE primary_key_value='%s' + AND table_name='%s']], v.primary_key_value, v.table_name) + res, err = self:query(delete_ttl_query) + if not res then return nil, err end + end + + return true +end + +-- for tests +_M.clear_expired_ttl = clear_expired_ttl + +do_clean_ttl = function(premature, self) + if premature then return end + + local ok, err = clear_expired_ttl(self) + if not ok then + log(ERR, "could not cleanup TTLs: ", err) + end + + ok, err = timer_at(TTL_CLEANUP_INTERVAL, do_clean_ttl, self) + if not ok then + log(ERR, "could not 
create TTL timer: ", err) + end end --- Formatting +--- Query building +-- @section query_building -- @see pgmoon local function escape_identifier(ident) - return '"'..(tostring(ident):gsub('"', '""'))..'"' + return '"' .. (tostring(ident):gsub('"', '""')) .. '"' end -- @see pgmoon @@ -81,37 +177,82 @@ local function escape_literal(val, field) if t_val == "number" then return tostring(val) elseif t_val == "string" then - return "'"..tostring((val:gsub("'", "''"))).."'" + return "'" .. tostring((val:gsub("'", "''"))) .. "'" elseif t_val == "boolean" then return val and "TRUE" or "FALSE" elseif t_val == "table" and field and (field.type == "table" or field.type == "array") then - local json = require "cjson" - return escape_literal(json.encode(val)) + return escape_literal(cjson.encode(val)) end - error("don't know how to escape value: "..tostring(val)) + error("don't know how to escape value: " .. tostring(val)) end local function get_where(tbl) local where = {} for col, value in pairs(tbl) do - where[#where + 1] = string.format("%s = %s", - escape_identifier(col), - escape_literal(value)) + where[#where+1] = fmt("%s = %s", + escape_identifier(col), + escape_literal(value)) + end + + return concat(where, " AND ") +end + +local function get_select_fields(schema) + local fields = {} + for k, v in pairs(schema.fields) do + if v.type == "timestamp" then + fields[#fields+1] = fmt("extract(epoch from %s)::bigint*1000 as %s", k, k) + else + fields[#fields+1] = '"' .. k .. 
'"' + end + end + return concat(fields, ", ") +end + +local function select_query(self, select_clause, schema, table, where, offset, limit) + local query + + local join_ttl = schema.primary_key and #schema.primary_key == 1 + if join_ttl then + local primary_key_type, err = retrieve_primary_key_type(self, schema, table) + if not primary_key_type then return nil, err end + + query = fmt([[ + SELECT %s FROM %s + LEFT OUTER JOIN ttls ON (%s.%s = ttls.primary_%s_value) + WHERE (ttls.primary_key_value IS NULL + OR (ttls.table_name = '%s' AND expire_at > CURRENT_TIMESTAMP(0) at time zone 'utc')) + ]], select_clause, table, table, schema.primary_key[1], + primary_key_type == "uuid" and "uuid" or "key", table) + else + query = fmt("SELECT %s FROM %s", select_clause, table) end - return table.concat(where, " AND ") + if where then + query = query .. (join_ttl and " AND " or " WHERE ") .. where + end + if limit then + query = query .. " LIMIT " .. limit + end + if offset and offset > 0 then + query = query .. " OFFSET " .. 
offset + end + return query end +--- Querying +-- @section querying + local function parse_error(err_str) local err - if string.find(err_str, "Key .* already exists") then - local col, value = string.match(err_str, "%((.+)%)=%((.+)%)") + if find(err_str, "Key .* already exists") then + local col, value = match(err_str, "%((.+)%)=%((.+)%)") if col then err = Errors.unique {[col] = value} end - elseif string.find(err_str, "violates foreign key constraint") then - local col, value = string.match(err_str, "%((.+)%)=%((.+)%)") + elseif find(err_str, "violates foreign key constraint") then + local col, value = match(err_str, "%((.+)%)=%((.+)%)") if col then err = Errors.foreign {[col] = value} end @@ -120,113 +261,47 @@ local function parse_error(err_str) return err or Errors.db(err_str) end -local function get_select_fields(schema) - local fields = {} - local timestamp_fields = {} - for k, v in pairs(schema.fields) do - if v.type == "timestamp" then - table.insert(timestamp_fields, string.format("extract(epoch from %s)::bigint*1000 as %s", k, k)) - else - table.insert(fields, "\""..k.."\"") +local function deserialize_rows(rows, schema) + for i, row in ipairs(rows) do + for col, value in pairs(row) do + if type(value) == "string" and schema.fields[col] and + (schema.fields[col].type == "table" or schema.fields[col].type == "array") then + rows[i][col] = cjson.decode(value) + end end end - return table.concat(fields, ",")..(#timestamp_fields > 0 and ","..table.concat(timestamp_fields, ",") or "") end --- Querying - function _M:query(query, schema) local conn_opts = self:clone_query_options() local pg = pgmoon.new(conn_opts) local ok, err = pg:connect() - if not ok then - return nil, Errors.db(err) - end + if not ok then return nil, Errors.db(err) end local res, err = pg:query(query) - if ngx and ngx.get_phase() ~= "init" then + if get_phase() ~= "init" then pg:keepalive() else pg:disconnect() end - if res == nil then - return nil, parse_error(err) - elseif schema ~= nil 
then - self:deserialize_rows(res, schema) + if not res then return nil, parse_error(err) + elseif schema then + deserialize_rows(res, schema) end return res end -function _M:retrieve_primary_key_type(schema, table_name) - if schema.primary_key and #schema.primary_key == 1 then - if not self.column_types then self.column_types = {} end - - local result = self.column_types[table_name] - if not result then - local query = string.format("SELECT data_type FROM information_schema.columns WHERE table_name = '%s' and column_name = '%s' LIMIT 1", - table_name, schema.primary_key[1]) - local res, err = self:query(query) - if err then - return nil, err - elseif #res > 0 then - result = res[1].data_type - self.column_types[table_name] = result - end - end - - return result - end -end - -function _M:get_select_query(select_clause, schema, table, where, offset, limit) - local query - - local join_ttl = schema.primary_key and #schema.primary_key == 1 - if join_ttl then - local primary_key_type = self:retrieve_primary_key_type(schema, table) - query = string.format([[SELECT %s FROM %s LEFT OUTER JOIN ttls ON (%s.%s = ttls.primary_%s_value) WHERE - (ttls.primary_key_value IS NULL OR (ttls.table_name = '%s' AND expire_at > CURRENT_TIMESTAMP(0) at time zone 'utc'))]], - select_clause, table, table, schema.primary_key[1], primary_key_type == "uuid" and "uuid" or "key", table) - else - query = string.format("SELECT %s FROM %s", select_clause, table) - end - - if where ~= nil then - query = query..(join_ttl and " AND " or " WHERE ")..where - end - if limit ~= nil then - query = query.." LIMIT "..limit - end - if offset ~= nil and offset > 0 then - query = query.." 
OFFSET "..offset - end - return query -end - -function _M:deserialize_rows(rows, schema) - if schema then - local json = require "cjson" - for i, row in ipairs(rows) do - for col, value in pairs(row) do - if type(value) == "string" and schema.fields[col] and - (schema.fields[col].type == "table" or schema.fields[col].type == "array") then - rows[i][col] = json.decode(value) - end - end - end - end -end - -function _M:deserialize_timestamps(row, schema) +local function deserialize_timestamps(self, row, schema) local result = row for k, v in pairs(schema.fields) do if v.type == "timestamp" and result[k] then - local query = string.format("SELECT extract(epoch from timestamp '%s')::bigint*1000 as %s;", result[k], k) + local query = fmt([[ + SELECT extract(epoch from timestamp '%s')::bigint*1000 as %s; + ]], result[k], k) local res, err = self:query(query) - if err then - return nil, err + if not res then return nil, err elseif #res > 0 then result[k] = res[1][k] end @@ -235,15 +310,16 @@ function _M:deserialize_timestamps(row, schema) return result end -function _M:serialize_timestamps(tbl, schema) +local function serialize_timestamps(self, tbl, schema) local result = tbl for k, v in pairs(schema.fields) do if v.type == "timestamp" and result[k] then - local query = string.format("SELECT to_timestamp(%d/1000) at time zone 'UTC' as %s;", result[k], k) + local query = fmt([[ + SELECT to_timestamp(%d/1000) at time zone 'UTC' as %s; + ]], result[k], k) local res, err = self:query(query) - if err then - return nil, err - elseif #res > 0 then + if not res then return nil, err + elseif #res <= 1 then result[k] = res[1][k] end end @@ -251,88 +327,32 @@ function _M:serialize_timestamps(tbl, schema) return result end -function _M:ttl(tbl, table_name, schema, ttl) - if not schema.primary_key or #schema.primary_key ~= 1 then - return false, "Cannot set a TTL if the entity has no primary key, or has more than one primary key" - end - - local primary_key_type = 
self:retrieve_primary_key_type(schema, table_name) - - -- Get current server time - local query = "SELECT extract(epoch from now() at time zone 'utc')::bigint*1000 as timestamp;" - local res, err = self:query(query) - if err then - return false, err - end - - -- The expiration is always based on the current time - local expire_at = res[1].timestamp + (ttl * 1000) - - local query = string.format("SELECT upsert_ttl('%s', %s, '%s', '%s', to_timestamp(%d/1000) at time zone 'UTC')", - tbl[schema.primary_key[1]], primary_key_type == "uuid" and "'"..tbl[schema.primary_key[1]].."'" or "NULL", - schema.primary_key[1], table_name, expire_at) - local _, err = self:query(query) - if err then - return false, err - end - return true -end - --- Delete old expired TTL entities -function _M:clear_expired_ttl() - local query = "SELECT * FROM ttls WHERE expire_at < CURRENT_TIMESTAMP(0) at time zone 'utc'" - local res, err = self:query(query) - if err then - return false, err - end - - for _, v in ipairs(res) do - local delete_entity_query = string.format("DELETE FROM %s WHERE %s='%s'", v.table_name, v.primary_key_name, v.primary_key_value) - local _, err = self:query(delete_entity_query) - if err then - return false, err - end - local delete_ttl_query = string.format("DELETE FROM ttls WHERE primary_key_value='%s' AND table_name='%s'", v.primary_key_value, v.table_name) - local _, err = self:query(delete_ttl_query) - if err then - return false, err - end - end - - return true -end - function _M:insert(table_name, schema, model, _, options) - local values, err = self:serialize_timestamps(model, schema) - if err then - return nil, err - end + options = options or {} + + local values, err = serialize_timestamps(self, model, schema) + if err then return nil, err end local cols, args = {}, {} for col, value in pairs(values) do - cols[#cols + 1] = escape_identifier(col) - args[#args + 1] = escape_literal(value, schema.fields[col]) + cols[#cols+1] = escape_identifier(col) + args[#args+1] = 
escape_literal(value, schema.fields[col]) end - cols = table.concat(cols, ", ") - args = table.concat(args, ", ") - - local query = string.format("INSERT INTO %s(%s) VALUES(%s) RETURNING *", - table_name, cols, args) + local query = fmt("INSERT INTO %s(%s) VALUES(%s) RETURNING *", + table_name, + concat(cols, ", "), + concat(args, ", ")) local res, err = self:query(query, schema) - if err then - return nil, err + if not res then return nil, err elseif #res > 0 then - local res, err = self:deserialize_timestamps(res[1], schema) - if err then - return nil, err + res, err = deserialize_timestamps(self, res[1], schema) + if err then return nil, err else -- Handle options - if options and options.ttl then - local _, err = self:ttl(res, table_name, schema, options.ttl) - if err then - return nil, err - end + if options.ttl then + local ok, err = ttl(self, res, table_name, schema, options.ttl) + if not ok then return nil, err end end return res end @@ -341,48 +361,40 @@ end function _M:find(table_name, schema, primary_keys) local where = get_where(primary_keys) - local query = self:get_select_query(get_select_fields(schema), schema, table_name, where) + local query = select_query(self, get_select_fields(schema), schema, table_name, where) local rows, err = self:query(query, schema) - if err then - return nil, err - elseif rows and #rows > 0 then - return rows[1] - end + if not rows then return nil, err + elseif #rows <= 1 then return rows[1] + else return nil, "bad rows result" end end function _M:find_all(table_name, tbl, schema) local where - if tbl ~= nil then + if tbl then where = get_where(tbl) end - local query = self:get_select_query(get_select_fields(schema), schema, table_name, where) + local query = select_query(self, get_select_fields(schema), schema, table_name, where) return self:query(query, schema) end function _M:find_page(table_name, tbl, page, page_size, schema) - if page == nil then - page = 1 - end + page = page or 1 local total_count, err = 
self:count(table_name, tbl, schema) - if err then - return nil, err - end + if not total_count then return nil, err end - local total_pages = math.ceil(total_count/page_size) + local total_pages = ceil(total_count/page_size) local offset = page_size * (page - 1) local where - if tbl ~= nil then + if tbl then where = get_where(tbl) end - local query = self:get_select_query(get_select_fields(schema), schema, table_name, where, offset, page_size) + local query = select_query(self, get_select_fields(schema), schema, table_name, where, offset, page_size) local rows, err = self:query(query, schema) - if err then - return nil, err - end + if not rows then return nil, err end local next_page = page + 1 return rows, nil, (next_page <= total_pages and next_page or nil) @@ -390,76 +402,68 @@ end function _M:count(table_name, tbl, schema) local where - if tbl ~= nil then + if tbl then where = get_where(tbl) end - local query = self:get_select_query("COUNT(*)", schema, table_name, where) + local query = select_query(self, "COUNT(*)", schema, table_name, where) local res, err = self:query(query) - if err then - return nil, err - elseif res and #res > 0 then - return res[1].count - end + if not res then return nil, err + elseif #res <= 1 then return res[1].count + else return nil, "bad rows result" end end function _M:update(table_name, schema, _, filter_keys, values, nils, full, _, options) + options = options or {} + local args = {} - local values, err = self:serialize_timestamps(values, schema) - if err then - return nil, err - end + local values, err = serialize_timestamps(self, values, schema) + if not values then return nil, err end + for col, value in pairs(values) do - args[#args + 1] = string.format("%s = %s", - escape_identifier(col), escape_literal(value, schema.fields[col])) + args[#args+1] = fmt("%s = %s", + escape_identifier(col), + escape_literal(value, schema.fields[col])) end if full then for col in pairs(nils) do - args[#args + 1] = escape_identifier(col).." 
= NULL" + args[#args+1] = escape_identifier(col) .. " = NULL" end end - args = table.concat(args, ", ") - local where = get_where(filter_keys) - local query = string.format("UPDATE %s SET %s WHERE %s RETURNING *", - table_name, args, where) + local query = fmt("UPDATE %s SET %s WHERE %s RETURNING *", + table_name, + concat(args, ", "), + where) + local res, err = self:query(query, schema) - if err then - return nil, err - elseif res and res.affected_rows == 1 then - local res, err = self:deserialize_timestamps(res[1], schema) - if err then - return nil, err - else - -- Handle options - if options and options.ttl then - local _, err = self:ttl(res, table_name, schema, options.ttl) - if err then - return nil, err - end - end - return res + if not res then return nil, err + elseif res.affected_rows == 1 then + res, err = deserialize_timestamps(self, res[1], schema) + if not res then return nil, err + elseif options.ttl then + local ok, err = ttl(self, res, table_name, schema, options.ttl) + if not ok then return nil, err end end + return res end end function _M:delete(table_name, schema, primary_keys) local where = get_where(primary_keys) - local query = string.format("DELETE FROM %s WHERE %s RETURNING *", - table_name, where) + local query = fmt("DELETE FROM %s WHERE %s RETURNING *", + table_name, where) local res, err = self:query(query, schema) - if err then - return nil, err - end - - if res and res.affected_rows == 1 then - return self:deserialize_timestamps(res[1], schema) + if not res then return nil, err + elseif res.affected_rows == 1 then + return deserialize_timestamps(self, res[1], schema) end end --- Migrations +--- Migrations +-- @section migrations function _M:queries(queries) if utils.strip(queries) ~= "" then @@ -481,11 +485,9 @@ function _M:truncate_table(table_name) end function _M:current_migrations() - -- Check if schema_migrations table exists + -- check if schema_migrations table exists local rows, err = self:query "SELECT 
to_regclass('schema_migrations')" - if err then - return nil, err - end + if not rows then return nil, err end if #rows > 0 and rows[1].to_regclass == "schema_migrations" then return self:query "SELECT * FROM schema_migrations" @@ -507,7 +509,7 @@ function _M:record_migration(id, name) END; $$ LANGUAGE 'plpgsql'; ]], - string.format("SELECT upsert_schema_migrations('%s', %s)", id, escape_literal(name)) + fmt("SELECT upsert_schema_migrations('%s', %s)", id, escape_literal(name)) } if not res then return nil, err end return true From fa67cc2ae92f5bfce78a7b2b779b45c1e596922d Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Tue, 4 Oct 2016 12:40:25 +0200 Subject: [PATCH 5/7] feat(db) make the Cassandra LB policy configurable Allows the use of `DCAwareRoundRobin` load balancing policy for Cassandra, allowing to distribute C* queries across a multi-dc cluster while always trying to hit the local DC first. --- kong.conf.default | 19 ++++++++++++++++--- kong/conf_loader.lua | 18 +++++++++++++----- kong/dao/db/cassandra.lua | 12 ++++++++++++ kong/templates/kong_defaults.lua | 10 ++++++---- spec/01-unit/02-conf_loader_spec.lua | 7 +++++++ 5 files changed, 54 insertions(+), 12 deletions(-) diff --git a/kong.conf.default b/kong.conf.default index 03b529a8210d..e9c5e075d1c7 100644 --- a/kong.conf.default +++ b/kong.conf.default @@ -129,9 +129,6 @@ #cassandra_keyspace = kong # The keyspace to use in your cluster. -#cassandra_consistency = ONE # Consistency setting to use when reading/ - # writing to the Cassandra cluster. - #cassandra_timeout = 5000 # Defines the timeout (in ms), for reading # and writing. @@ -149,6 +146,22 @@ #cassandra_password = kong # Password when using the # `PasswordAuthenticator` scheme. +#cassandra_consistency = ONE # Consistency setting to use when reading/ + # writing to the Cassandra cluster. + +#cassandra_lb_policy = RoundRobin # Load balancing policy to use when + # distributing queries across your Cassandra + # cluster. 
+ # Accepted values are `RoundRobin` and + # `DCAwareRoundRobin`. + # Prefer the latter if and only if you are + # using a multi-datacenter cluster. + +#cassandra_local_datacenter = # When using the `DCAwareRoundRobin` load + # balancing policy, you must specify the name + # of the local (closest) datacenter for this + # Kong node. + #cassandra_repl_strategy = SimpleStrategy # When migrating for the first time, # Kong will use this setting to # create your keyspace. diff --git a/kong/conf_loader.lua b/kong/conf_loader.lua index d35bd5c82d8c..cc04e2725b23 100644 --- a/kong/conf_loader.lua +++ b/kong/conf_loader.lua @@ -63,14 +63,16 @@ local CONF_INFERENCES = { cassandra_contact_points = {typ = "array"}, cassandra_port = {typ = "number"}, - cassandra_repl_strategy = {enum = {"SimpleStrategy", "NetworkTopologyStrategy"}}, - cassandra_repl_factor = {typ = "number"}, - cassandra_data_centers = {typ = "array"}, - cassandra_consistency = {enum = {"ALL", "EACH_QUORUM", "QUORUM", "LOCAL_QUORUM", "ONE", - "TWO", "THREE", "LOCAL_ONE"}}, -- no ANY: this is R/W cassandra_timeout = {typ = "number"}, cassandra_ssl = {typ = "boolean"}, cassandra_ssl_verify = {typ = "boolean"}, + cassandra_consistency = {enum = {"ALL", "EACH_QUORUM", "QUORUM", "LOCAL_QUORUM", "ONE", + "TWO", "THREE", "LOCAL_ONE"}}, -- no ANY: this is R/W + cassandra_lb_policy = {enum = {"RoundRobin", "DCAwareRoundRobin"}}, + cassandra_local_datacenter = {typ = "string"}, + cassandra_repl_strategy = {enum = {"SimpleStrategy", "NetworkTopologyStrategy"}}, + cassandra_repl_factor = {typ = "number"}, + cassandra_data_centers = {typ = "array"}, cluster_profile = {enum = {"local", "lan", "wan"}}, cluster_ttl_on_failure = {typ = "number"}, @@ -162,6 +164,12 @@ local function check_and_infer(conf) -- custom validations --------------------- + if conf.cassandra_lb_policy == "DCAwareRoundRobin" and + not conf.cassandra_local_datacenter then + errors[#errors+1] = "must specify 'cassandra_local_datacenter' when ".. 
+ "DCAwareRoundRobin policy is in use" + end + if conf.ssl then if conf.ssl_cert and not conf.ssl_cert_key then errors[#errors+1] = "ssl_cert_key must be specified" diff --git a/kong/dao/db/cassandra.lua b/kong/dao/db/cassandra.lua index d438d49e70b9..ffcd460af83a 100644 --- a/kong/dao/db/cassandra.lua +++ b/kong/dao/db/cassandra.lua @@ -47,6 +47,10 @@ function _M.new(kong_config) verify = kong_config.cassandra_ssl_verify } + -- + -- cluster options from Kong config + -- + if kong_config.cassandra_username and kong_config.cassandra_password then cluster_options.auth = cassandra.auth_providers.plain_text( kong_config.cassandra_username, kong_config.cassandra_password ) end + if kong_config.cassandra_lb_policy == "RoundRobin" then + local policy = require("resty.cassandra.policies.lb.rr") + cluster_options.lb_policy = policy.new() + elseif kong_config.cassandra_lb_policy == "DCAwareRoundRobin" then + local policy = require("resty.cassandra.policies.lb.dc_rr") + cluster_options.lb_policy = policy.new(kong_config.cassandra_local_datacenter) + end + local cluster, err = Cluster.new(cluster_options) if not cluster then return nil, err end diff --git a/kong/templates/kong_defaults.lua b/kong/templates/kong_defaults.lua index 5519c6516612..76eb4ac24855 100644 --- a/kong/templates/kong_defaults.lua +++ b/kong/templates/kong_defaults.lua @@ -26,15 +26,17 @@ pg_ssl_verify = off cassandra_contact_points = 127.0.0.1 cassandra_port = 9042 cassandra_keyspace = kong -cassandra_repl_strategy = SimpleStrategy -cassandra_repl_factor = 1 -cassandra_data_centers = dc1:2,dc2:3 -cassandra_consistency = ONE cassandra_timeout = 5000 cassandra_ssl = off cassandra_ssl_verify = off cassandra_username = kong cassandra_password = NONE +cassandra_consistency = ONE +cassandra_lb_policy = RoundRobin +cassandra_local_datacenter = NONE +cassandra_repl_strategy = SimpleStrategy +cassandra_repl_factor = 1 +cassandra_data_centers = dc1:2,dc2:3 cluster_listen = 0.0.0.0:7946 
cluster_listen_rpc = 127.0.0.1:7373 diff --git a/spec/01-unit/02-conf_loader_spec.lua b/spec/01-unit/02-conf_loader_spec.lua index 8dee6d1c90e3..9611fc471ae8 100644 --- a/spec/01-unit/02-conf_loader_spec.lua +++ b/spec/01-unit/02-conf_loader_spec.lua @@ -323,6 +323,13 @@ describe("Configuration loader", function() local conf = assert(conf_loader(helpers.test_conf_path)) assert.equal("postgres", conf.database) end) + it("requires cassandra_local_datacenter if DCAwareRoundRobin is in use", function() + local conf, err = conf_loader(nil, { + cassandra_lb_policy = "DCAwareRoundRobin" + }) + assert.is_nil(conf) + assert.equal("must specify 'cassandra_local_datacenter' when DCAwareRoundRobin policy is in use", err) + end) end) describe("errors", function() From e9b9613e5332763187cb25cbc4c86f53f1e9cb21 Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Tue, 4 Oct 2016 14:42:49 +0200 Subject: [PATCH 6/7] fix(db) polishing after merge with next - make db instance accessible from Factory for `init_worker()` call - update to new CLI flag - use the new no_keyspace option from coordinator_options - correctly instantiate Factory in new quiet_spec tests - replace left over `unset` - fix rate-limiting and response-ratelimiting tests (renamed dao_spec to policies_spec) --- kong-0.9.3-0.rockspec | 2 + kong/dao/db/cassandra.lua | 7 +- kong/dao/db/postgres.lua | 1 + kong/dao/factory.lua | 95 ++++++++++-------- kong/kong.lua | 5 +- kong/plugins/rate-limiting/daos.lua | 3 + .../rate-limiting/policies/cluster.lua | 5 +- kong/plugins/rate-limiting/policies/init.lua | 8 +- kong/plugins/response-ratelimiting/daos.lua | 3 + .../policies/cluster.lua | 5 +- .../response-ratelimiting/policies/init.lua | 4 +- spec/01-unit/13-db/01-init_spec.lua | 21 ++++ ...ssandra_spec.lua => 02-cassandra_spec.lua} | 0 spec/02-integration/02-dao/08-quiet_spec.lua | 2 +- .../98-rate-limiting/02-daos_spec.lua | 96 ------------------- .../98-rate-limiting/02-policies_spec.lua | 78 +++++++++++++++ 
.../98-rate-limiting/04-access_spec.lua | 6 +- .../02-daos_spec.lua | 96 ------------------- .../02-policies_spec.lua | 81 ++++++++++++++++ .../04-access_spec.lua | 6 +- 20 files changed, 270 insertions(+), 254 deletions(-) create mode 100644 kong/plugins/rate-limiting/daos.lua create mode 100644 kong/plugins/response-ratelimiting/daos.lua create mode 100644 spec/01-unit/13-db/01-init_spec.lua rename spec/01-unit/13-db/{01-cassandra_spec.lua => 02-cassandra_spec.lua} (100%) delete mode 100644 spec/03-plugins/98-rate-limiting/02-daos_spec.lua create mode 100644 spec/03-plugins/98-rate-limiting/02-policies_spec.lua delete mode 100644 spec/03-plugins/98-response-rate-limiting/02-daos_spec.lua create mode 100644 spec/03-plugins/98-response-rate-limiting/02-policies_spec.lua diff --git a/kong-0.9.3-0.rockspec b/kong-0.9.3-0.rockspec index 14c6d8e302e4..44a1f5892c93 100644 --- a/kong-0.9.3-0.rockspec +++ b/kong-0.9.3-0.rockspec @@ -166,6 +166,7 @@ build = { ["kong.plugins.rate-limiting.migrations.postgres"] = "kong/plugins/rate-limiting/migrations/postgres.lua", ["kong.plugins.rate-limiting.handler"] = "kong/plugins/rate-limiting/handler.lua", ["kong.plugins.rate-limiting.schema"] = "kong/plugins/rate-limiting/schema.lua", + ["kong.plugins.rate-limiting.daos"] = "kong/plugins/rate-limiting/daos.lua", ["kong.plugins.rate-limiting.policies"] = "kong/plugins/rate-limiting/policies/init.lua", ["kong.plugins.rate-limiting.policies.cluster"] = "kong/plugins/rate-limiting/policies/cluster.lua", @@ -176,6 +177,7 @@ build = { ["kong.plugins.response-ratelimiting.header_filter"] = "kong/plugins/response-ratelimiting/header_filter.lua", ["kong.plugins.response-ratelimiting.log"] = "kong/plugins/response-ratelimiting/log.lua", ["kong.plugins.response-ratelimiting.schema"] = "kong/plugins/response-ratelimiting/schema.lua", + ["kong.plugins.response-ratelimiting.daos"] = "kong/plugins/response-ratelimiting/daos.lua", ["kong.plugins.response-ratelimiting.policies"] = 
"kong/plugins/response-ratelimiting/policies/init.lua", ["kong.plugins.response-ratelimiting.policies.cluster"] = "kong/plugins/response-ratelimiting/policies/cluster.lua", diff --git a/kong/dao/db/cassandra.lua b/kong/dao/db/cassandra.lua index ffcd460af83a..3a4535a74117 100644 --- a/kong/dao/db/cassandra.lua +++ b/kong/dao/db/cassandra.lua @@ -73,7 +73,7 @@ function _M.new(kong_config) self.query_options = query_opts self.cluster_options = cluster_options - if ngx.RESTY_CLI then + if ngx.IS_CLI then -- we must manually call our init phase (usually called from `init_by_lua`) -- to refresh the cluster. local ok, err = self:init() @@ -159,8 +159,7 @@ function _M:query(query, args, options, schema, no_keyspace) local opts = self:clone_query_options(options) local coordinator_opts = {} if no_keyspace then - -- defaults to the system keyspace, always present - coordinator_opts.keyspace = "system" + coordinator_opts.no_keyspace = true end local res, err = self.cluster:execute(query, args, opts, coordinator_opts) @@ -427,7 +426,7 @@ function _M:update(table_name, schema, constraints, filter_keys, values, nils, f if full then for col in pairs(nils) do sets[#sets + 1] = col.." = ?" 
- args[#args + 1] = cassandra.unset + args[#args + 1] = cassandra.null end end diff --git a/kong/dao/db/postgres.lua b/kong/dao/db/postgres.lua index 5387db2928ae..dbd4aa4518be 100644 --- a/kong/dao/db/postgres.lua +++ b/kong/dao/db/postgres.lua @@ -65,6 +65,7 @@ function _M:init_worker() if not ok then log(ERR, "could not create TTL timer: ", err) end + return true end --- TTL utils diff --git a/kong/dao/factory.lua b/kong/dao/factory.lua index 63c1214c8953..f52dd1fa7207 100644 --- a/kong/dao/factory.lua +++ b/kong/dao/factory.lua @@ -3,7 +3,6 @@ local utils = require "kong.tools.utils" local ModelFactory = require "kong.dao.model_factory" local CORE_MODELS = {"apis", "consumers", "plugins", "nodes"} -local _db -- returns db errors as strings, including the initial `nil` local function ret_error_string(db_name, res, err) @@ -73,87 +72,103 @@ local function load_daos(self, schemas, constraints, events_handler) end for m_name, schema in pairs(schemas) do - self.daos[m_name] = DAO(_db, ModelFactory(schema), schema, constraints[m_name], events_handler) + self.daos[m_name] = DAO(self.db, ModelFactory(schema), schema, constraints[m_name], events_handler) end end function _M.new(kong_config, events_handler) - local factory = { + local self = { db_type = kong_config.database, daos = {}, + additional_tables = {}, kong_config = kong_config, plugin_names = kong_config.plugins or {} } - local DB = require("kong.dao.db."..factory.db_type) + local DB = require("kong.dao.db."..self.db_type) local db, err = DB.new(kong_config) - if not db then return ret_error_string(factory.db_type, nil, err) end + if not db then return ret_error_string(self.db_type, nil, err) end - _db = db -- avoid setting a previous upvalue to `nil` in case `DB.new()` fails + self.db = db local schemas = {} for _, m_name in ipairs(CORE_MODELS) do schemas[m_name] = require("kong.dao.schemas."..m_name) end - for plugin_name in pairs(factory.plugin_names) do - local has_dao, plugin_daos = 
utils.load_module_if_exists("kong.plugins."..plugin_name..".dao."..factory.db_type) - if has_dao then - for k, v in pairs(plugin_daos) do - factory.daos[k] = v(kong_config) - end - end - + for plugin_name in pairs(self.plugin_names) do local has_schema, plugin_schemas = utils.load_module_if_exists("kong.plugins."..plugin_name..".daos") if has_schema then - for k, v in pairs(plugin_schemas) do - schemas[k] = v + if plugin_schemas.tables then + for _, v in ipairs(plugin_schemas.tables) do + table.insert(self.additional_tables, v) + end + else + for k, v in pairs(plugin_schemas) do + schemas[k] = v + end end end end local constraints = build_constraints(schemas) - load_daos(factory, schemas, constraints, events_handler) + load_daos(self, schemas, constraints, events_handler) - return setmetatable(factory, _M) + return setmetatable(self, _M) end function _M:init() - return _db:init() + return self.db:init() +end + +function _M:init_worker() + return self.db:init_worker() end -- Migrations function _M:infos() - return _db:infos() + return self.db:infos() end function _M:drop_schema() for _, dao in pairs(self.daos) do - _db:drop_table(dao.table) + self.db:drop_table(dao.table) + end + + if self.additional_tables then + for _, v in ipairs(self.additional_tables) do + self.db:drop_table(v) + end end - if _db.additional_tables then - for _, v in ipairs(_db.additional_tables) do - _db:drop_table(v) + if self.db.additional_tables then + for _, v in ipairs(self.db.additional_tables) do + self.db:drop_table(v) end end - _db:drop_table("schema_migrations") + self.db:drop_table("schema_migrations") end function _M:truncate_table(dao_name) - _db:truncate_table(self.daos[dao_name].table) + self.db:truncate_table(self.daos[dao_name].table) end function _M:truncate_tables() for _, dao in pairs(self.daos) do - _db:truncate_table(dao.table) + self.db:truncate_table(dao.table) + end + + if self.db.additional_tables then + for _, v in ipairs(self.db.additional_tables) do + 
self.db:truncate_table(v) + end end - if _db.additional_tables then - for _, v in ipairs(_db.additional_tables) do - _db:truncate_table(v) + if self.additional_tables then + for _, v in ipairs(self.additional_tables) do + self.db:truncate_table(v) end end end @@ -174,8 +189,8 @@ function _M:migrations_modules() end function _M:current_migrations() - local rows, err = _db:current_migrations() - if err then return ret_error_string(_db.name, nil, err) end + local rows, err = self.db:current_migrations() + if err then return ret_error_string(self.db.name, nil, err) end local cur_migrations = {} for _, row in ipairs(rows) do @@ -196,16 +211,16 @@ local function migrate(self, identifier, migrations_modules, cur_migrations, on_ if #to_run > 0 and on_migrate then -- we have some migrations to run - on_migrate(identifier, _db:infos()) + on_migrate(identifier, self.db:infos()) end for _, migration in ipairs(to_run) do local err local mig_type = type(migration.up) if mig_type == "string" then - err = _db:queries(migration.up) + err = self.db:queries(migration.up) elseif mig_type == "function" then - err = migration.up(_db, self.kong_config, self) + err = migration.up(self.db, self.kong_config, self) end if err then @@ -213,13 +228,13 @@ local function migrate(self, identifier, migrations_modules, cur_migrations, on_ end -- record success - local ok, err = _db:record_migration(identifier, migration.name) + local ok, err = self.db:record_migration(identifier, migration.name) if not ok then return nil, string.format("Error recording migration %s: %s", migration.name, err) end if on_success then - on_success(identifier, migration.name, _db:infos()) + on_success(identifier, migration.name, self.db:infos()) end end @@ -248,15 +263,15 @@ function _M:run_migrations(on_migrate, on_success) local migrations_modules = self:migrations_modules() local cur_migrations, err = self:current_migrations() - if err then return ret_error_string(_db.name, nil, err) end + if err then return 
ret_error_string(self.db.name, nil, err) end local ok, err, migrations_ran = migrate(self, "core", migrations_modules, cur_migrations, on_migrate, on_success) - if not ok then return ret_error_string(_db.name, nil, err) end + if not ok then return ret_error_string(self.db.name, nil, err) end for identifier in pairs(migrations_modules) do if identifier ~= "core" then local ok, err, n_ran = migrate(self, identifier, migrations_modules, cur_migrations, on_migrate, on_success) - if not ok then return ret_error_string(_db.name, nil, err) + if not ok then return ret_error_string(self.db.name, nil, err) else migrations_ran = migrations_ran + n_ran end diff --git a/kong/kong.lua b/kong/kong.lua index 09587640bdb7..f6b28f6fc2f4 100644 --- a/kong/kong.lua +++ b/kong/kong.lua @@ -143,7 +143,10 @@ function Kong.init_worker() core.init_worker.before() - singletons.dao:init_worker() + local ok, err = singletons.dao:init_worker() + if not ok then + ngx.log(ngx.ERR, "could not init DB: ", err) + end for _, plugin in ipairs(singletons.loaded_plugins) do plugin.handler:init_worker() diff --git a/kong/plugins/rate-limiting/daos.lua b/kong/plugins/rate-limiting/daos.lua new file mode 100644 index 000000000000..8e60b7d45b80 --- /dev/null +++ b/kong/plugins/rate-limiting/daos.lua @@ -0,0 +1,3 @@ +return { + tables = {"ratelimiting_metrics"} +} diff --git a/kong/plugins/rate-limiting/policies/cluster.lua b/kong/plugins/rate-limiting/policies/cluster.lua index 282888d10640..8c9cff8b28a3 100644 --- a/kong/plugins/rate-limiting/policies/cluster.lua +++ b/kong/plugins/rate-limiting/policies/cluster.lua @@ -50,8 +50,9 @@ return { db.cassandra.timestamp(periods[period]), period, }) - if not rows then return nil, err - elseif #rows > 0 then return rows[1] end + if not rows then return nil, err + elseif #rows <= 1 then return rows[1] + else return nil, "bad rows result" end end, }, ["postgres"] = { diff --git a/kong/plugins/rate-limiting/policies/init.lua 
b/kong/plugins/rate-limiting/policies/init.lua index 89d336e8f8cb..8ec3b1b48dae 100644 --- a/kong/plugins/rate-limiting/policies/init.lua +++ b/kong/plugins/rate-limiting/policies/init.lua @@ -56,14 +56,16 @@ return { ngx_log(ngx.ERR, "[rate-limiting] cluster policy: could not increment ", db.name, " counter: ", err) end + + return ok, err end, usage = function(conf, api_id, identifier, current_timestamp, name) local db = singletons.dao.db - local rows, err = policy_cluster[db.name].find(db, api_id, identifier, + local row, err = policy_cluster[db.name].find(db, api_id, identifier, current_timestamp, name) - if not rows then return nil, err end + if err then return nil, err end - return rows and rows.value or 0 + return row and row.value or 0 end }, ["redis"] = { diff --git a/kong/plugins/response-ratelimiting/daos.lua b/kong/plugins/response-ratelimiting/daos.lua new file mode 100644 index 000000000000..5dc35a5b2e23 --- /dev/null +++ b/kong/plugins/response-ratelimiting/daos.lua @@ -0,0 +1,3 @@ +return { + tables = {"response_ratelimiting_metrics"} +} diff --git a/kong/plugins/response-ratelimiting/policies/cluster.lua b/kong/plugins/response-ratelimiting/policies/cluster.lua index 33c2d2b4a9bb..11a76b5c3b45 100644 --- a/kong/plugins/response-ratelimiting/policies/cluster.lua +++ b/kong/plugins/response-ratelimiting/policies/cluster.lua @@ -49,8 +49,9 @@ return { db.cassandra.timestamp(periods[period]), name.."_"..period, }) - if not rows then return nil, err - elseif #rows > 0 then return rows[1] end + if not rows then return nil, err + elseif #rows <= 1 then return rows[1] + else return nil, "bad rows result" end end, }, ["postgres"] = { diff --git a/kong/plugins/response-ratelimiting/policies/init.lua b/kong/plugins/response-ratelimiting/policies/init.lua index 39b6636cb08f..f3641a11b307 100644 --- a/kong/plugins/response-ratelimiting/policies/init.lua +++ b/kong/plugins/response-ratelimiting/policies/init.lua @@ -57,13 +57,15 @@ return { ngx_log(ngx.ERR, 
"[response-ratelimiting] cluster policy: could not increment ", db.name, " counter: ", err) end + + return ok, err end, usage = function(conf, api_id, identifier, current_timestamp, period, name) local db = singletons.dao.db local rows, err = policy_cluster[db.name].find(db, api_id, identifier, current_timestamp, period, name) - if not rows then return nil, err end + if err then return nil, err end return rows and rows.value or 0 end diff --git a/spec/01-unit/13-db/01-init_spec.lua b/spec/01-unit/13-db/01-init_spec.lua new file mode 100644 index 000000000000..35d043bd1b9d --- /dev/null +++ b/spec/01-unit/13-db/01-init_spec.lua @@ -0,0 +1,21 @@ +local db = require "kong.dao.db" + +describe("kong.dao.db.init", function() + it("has __index set to the init module so we can call base functions", function() + local my_db_module = db.new_db("cassandra") + + function my_db_module.new() + local self = my_db_module.super.new() + self.foo = "bar" + return self + end + + local my_db = my_db_module.new() + assert.equal("bar", my_db.foo) + + assert.has_no_error(function() + my_db:init() + my_db:init_worker() + end) + end) +end) diff --git a/spec/01-unit/13-db/01-cassandra_spec.lua b/spec/01-unit/13-db/02-cassandra_spec.lua similarity index 100% rename from spec/01-unit/13-db/01-cassandra_spec.lua rename to spec/01-unit/13-db/02-cassandra_spec.lua diff --git a/spec/02-integration/02-dao/08-quiet_spec.lua b/spec/02-integration/02-dao/08-quiet_spec.lua index b55ba3a99d43..12de0c6a81da 100644 --- a/spec/02-integration/02-dao/08-quiet_spec.lua +++ b/spec/02-integration/02-dao/08-quiet_spec.lua @@ -11,7 +11,7 @@ helpers.for_each_dao(function(kong_config) describe("Quiet with #"..kong_config.database, function() local factory setup(function() - factory = Factory(kong_config, events) + factory = Factory.new(kong_config, events) assert(factory:run_migrations()) factory:truncate_tables() diff --git a/spec/03-plugins/98-rate-limiting/02-daos_spec.lua 
b/spec/03-plugins/98-rate-limiting/02-daos_spec.lua deleted file mode 100644 index c6f58ccb3337..000000000000 --- a/spec/03-plugins/98-rate-limiting/02-daos_spec.lua +++ /dev/null @@ -1,96 +0,0 @@ -local uuid = require("kong.tools.utils").uuid -local helpers = require "spec.helpers" -local timestamp = require "kong.tools.timestamp" - -local ratelimiting_metrics = helpers.dao.ratelimiting_metrics - -describe("Plugin: rate-limiting (DAO)", function() - local api_id = uuid() - local identifier = uuid() - - setup(function() - helpers.dao:truncate_tables() - end) - after_each(function() - helpers.dao:truncate_tables() - end) - - it("should return nil when rate-limiting metrics are not existing", function() - local current_timestamp = 1424217600 - local periods = timestamp.get_timestamps(current_timestamp) - -- Very first select should return nil - for period, period_date in pairs(periods) do - local metric, err = ratelimiting_metrics:find(api_id, identifier, current_timestamp, period) - assert.falsy(err) - assert.same(nil, metric) - end - end) - - it("should increment rate-limiting metrics with the given period", function() - local current_timestamp = 1424217600 - local periods = timestamp.get_timestamps(current_timestamp) - - -- First increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) - assert.falsy(err) - assert.True(ok) - - -- First select - for period, period_date in pairs(periods) do - local metric, err = ratelimiting_metrics:find(api_id, identifier, current_timestamp, period) - assert.falsy(err) - assert.same({ - api_id = api_id, - identifier = identifier, - period = period, - period_date = period_date, - value = 1 -- The important part - }, metric) - end - - -- Second increment - local ok = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) - assert.True(ok) - - -- Second select - for period, period_date in pairs(periods) do - local metric, err = ratelimiting_metrics:find(api_id, identifier, 
current_timestamp, period) - assert.falsy(err) - assert.same({ - api_id = api_id, - identifier = identifier, - period = period, - period_date = period_date, - value = 2 -- The important part - }, metric) - end - - -- 1 second delay - current_timestamp = 1424217601 - periods = timestamp.get_timestamps(current_timestamp) - - -- Third increment - local ok = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) - assert.True(ok) - - -- Third select with 1 second delay - for period, period_date in pairs(periods) do - - local expected_value = 3 - - if period == "second" then - expected_value = 1 - end - - local metric, err = ratelimiting_metrics:find(api_id, identifier, current_timestamp, period) - assert.falsy(err) - assert.same({ - api_id = api_id, - identifier = identifier, - period = period, - period_date = period_date, - value = expected_value -- The important part - }, metric) - end - end) -end) -- describe rate limiting metrics diff --git a/spec/03-plugins/98-rate-limiting/02-policies_spec.lua b/spec/03-plugins/98-rate-limiting/02-policies_spec.lua new file mode 100644 index 000000000000..c78ac2651690 --- /dev/null +++ b/spec/03-plugins/98-rate-limiting/02-policies_spec.lua @@ -0,0 +1,78 @@ +local uuid = require("kong.tools.utils").uuid +local helpers = require "spec.helpers" +local policies = require "kong.plugins.rate-limiting.policies" +local timestamp = require "kong.tools.timestamp" + +describe("Plugin: rate-limiting (policies)", function() + describe("cluster", function() + local cluster_policy = policies.cluster + + local api_id = uuid() + local identifier = uuid() + + setup(function() + local singletons = require "kong.singletons" + singletons.dao = helpers.dao + + helpers.dao:truncate_tables() + end) + after_each(function() + helpers.dao:truncate_tables() + end) + + it("returns 0 when rate-limiting metrics don't exist yet", function() + local current_timestamp = 1424217600 + local periods = timestamp.get_timestamps(current_timestamp) + 
+ for period, period_date in pairs(periods) do + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period)) + assert.equal(0, metric) + end + end) + + it("increments rate-limiting metrics with the given period", function() + local current_timestamp = 1424217600 + local periods = timestamp.get_timestamps(current_timestamp) + + -- First increment + assert(cluster_policy.increment(nil, api_id, identifier, current_timestamp, 1)) + + -- First select + for period, period_date in pairs(periods) do + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period)) + assert.equal(1, metric) + end + + -- Second increment + assert(cluster_policy.increment(nil, api_id, identifier, current_timestamp, 1)) + + -- Second select + for period, period_date in pairs(periods) do + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period)) + assert.equal(2, metric) + end + + -- 1 second delay + current_timestamp = 1424217601 + periods = timestamp.get_timestamps(current_timestamp) + + -- Third increment + assert(cluster_policy.increment(nil, api_id, identifier, current_timestamp, 1)) + + -- Third select with 1 second delay + for period, period_date in pairs(periods) do + local expected_value = 3 + if period == "second" then + expected_value = 1 + end + + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period)) + assert.equal(expected_value, metric) + end + end) + end) +end) diff --git a/spec/03-plugins/98-rate-limiting/04-access_spec.lua b/spec/03-plugins/98-rate-limiting/04-access_spec.lua index 06295f849760..6f087dccdf46 100644 --- a/spec/03-plugins/98-rate-limiting/04-access_spec.lua +++ b/spec/03-plugins/98-rate-limiting/04-access_spec.lua @@ -391,8 +391,7 @@ for i, policy in ipairs({"local", "cluster", "redis"}) do assert.are.same(5, tonumber(res.headers["x-ratelimit-remaining-minute"])) -- Simulate an error on the database - local 
err = helpers.dao.ratelimiting_metrics:drop_table(helpers.dao.ratelimiting_metrics.table) - assert.falsy(err) + assert(helpers.dao.db:drop_table("ratelimiting_metrics")) -- Make another request local res = assert(helpers.proxy_client():send { @@ -418,8 +417,7 @@ for i, policy in ipairs({"local", "cluster", "redis"}) do assert.are.same(5, tonumber(res.headers["x-ratelimit-remaining-minute"])) -- Simulate an error on the database - local err = helpers.dao.ratelimiting_metrics:drop_table(helpers.dao.ratelimiting_metrics.table) - assert.falsy(err) + assert(helpers.dao.db:drop_table("ratelimiting_metrics")) -- Make another request local res = assert(helpers.proxy_client():send { diff --git a/spec/03-plugins/98-response-rate-limiting/02-daos_spec.lua b/spec/03-plugins/98-response-rate-limiting/02-daos_spec.lua deleted file mode 100644 index e9213138812e..000000000000 --- a/spec/03-plugins/98-response-rate-limiting/02-daos_spec.lua +++ /dev/null @@ -1,96 +0,0 @@ -local uuid = require("kong.tools.utils").uuid -local helpers = require "spec.helpers" -local timestamp = require "kong.tools.timestamp" - -local response_ratelimiting_metrics = helpers.dao.response_ratelimiting_metrics - -describe("Rate Limiting Metrics", function() - local api_id = uuid() - local identifier = uuid() - - setup(function() - helpers.dao:truncate_tables() - end) - - after_each(function() - helpers.dao:truncate_tables() - end) - - it("should return nil when ratelimiting metrics are not existing", function() - local current_timestamp = 1424217600 - local periods = timestamp.get_timestamps(current_timestamp) - -- Very first select should return nil - for period, period_date in pairs(periods) do - local metric, err = response_ratelimiting_metrics:find(api_id, identifier, current_timestamp, period, "video") - assert.falsy(err) - assert.are.same(nil, metric) - end - end) - - it("should increment ratelimiting metrics with the given period", function() - local current_timestamp = 1424217600 - local periods 
= timestamp.get_timestamps(current_timestamp) - - -- First increment - local ok = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") - assert.True(ok) - - -- First select - for period, period_date in pairs(periods) do - local metric, err = response_ratelimiting_metrics:find(api_id, identifier, current_timestamp, period, "video") - assert.falsy(err) - assert.same({ - api_id = api_id, - identifier = identifier, - period = "video_"..period, - period_date = period_date, - value = 1 -- The important part - }, metric) - end - - -- Second increment - local ok = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") - assert.True(ok) - - -- Second select - for period, period_date in pairs(periods) do - local metric, err = response_ratelimiting_metrics:find(api_id, identifier, current_timestamp, period, "video") - assert.falsy(err) - assert.same({ - api_id = api_id, - identifier = identifier, - period = "video_"..period, - period_date = period_date, - value = 2 -- The important part - }, metric) - end - - -- 1 second delay - current_timestamp = 1424217601 - periods = timestamp.get_timestamps(current_timestamp) - - -- Third increment - local ok = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") - assert.True(ok) - - -- Third select with 1 second delay - for period, period_date in pairs(periods) do - - local expected_value = 3 - - if period == "second" then - expected_value = 1 - end - - local metric, err = response_ratelimiting_metrics:find(api_id, identifier, current_timestamp, period, "video") - assert.falsy(err) - assert.same({ - api_id = api_id, - identifier = identifier, - period = "video_"..period, - period_date = period_date, - value = expected_value -- The important part - }, metric) - end - end) -end) diff --git a/spec/03-plugins/98-response-rate-limiting/02-policies_spec.lua b/spec/03-plugins/98-response-rate-limiting/02-policies_spec.lua new file 
mode 100644 index 000000000000..d7329752adb7 --- /dev/null +++ b/spec/03-plugins/98-response-rate-limiting/02-policies_spec.lua @@ -0,0 +1,81 @@ +local uuid = require("kong.tools.utils").uuid +local helpers = require "spec.helpers" +local policies = require "kong.plugins.response-ratelimiting.policies" +local timestamp = require "kong.tools.timestamp" + +describe("Plugin: response-ratelimiting (policies)", function() + describe("cluster", function() + local cluster_policy = policies.cluster + + local api_id = uuid() + local identifier = uuid() + + setup(function() + local singletons = require "kong.singletons" + singletons.dao = helpers.dao + + helpers.dao:truncate_tables() + end) + + after_each(function() + helpers.dao:truncate_tables() + end) + + it("should return nil when ratelimiting metrics are not existing", function() + local current_timestamp = 1424217600 + local periods = timestamp.get_timestamps(current_timestamp) + + for period, period_date in pairs(periods) do + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period, "video")) + assert.equal(0, metric) + end + end) + + it("should increment ratelimiting metrics with the given period", function() + local current_timestamp = 1424217600 + local periods = timestamp.get_timestamps(current_timestamp) + + -- First increment + assert(cluster_policy.increment(nil, api_id, identifier, current_timestamp, 1, "video")) + + -- First select + for period, period_date in pairs(periods) do + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period, "video")) + assert.equal(1, metric) + end + + -- Second increment + assert(cluster_policy.increment(nil, api_id, identifier, current_timestamp, 1, "video")) + + -- Second select + for period, period_date in pairs(periods) do + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period, "video")) + assert.equal(2, metric) + end + + -- 1 second delay + current_timestamp = 
1424217601 + periods = timestamp.get_timestamps(current_timestamp) + + -- Third increment + assert(cluster_policy.increment(nil, api_id, identifier, current_timestamp, 1, "video")) + + -- Third select with 1 second delay + for period, period_date in pairs(periods) do + + local expected_value = 3 + + if period == "second" then + expected_value = 1 + end + + local metric = assert(cluster_policy.usage(nil, api_id, identifier, + current_timestamp, period, "video")) + assert.equal(expected_value, metric) + end + end) + end) +end) diff --git a/spec/03-plugins/98-response-rate-limiting/04-access_spec.lua b/spec/03-plugins/98-response-rate-limiting/04-access_spec.lua index 2f033afb8296..cd39ee55f7f7 100644 --- a/spec/03-plugins/98-response-rate-limiting/04-access_spec.lua +++ b/spec/03-plugins/98-response-rate-limiting/04-access_spec.lua @@ -488,8 +488,7 @@ for i, policy in ipairs({"local", "cluster", "redis"}) do assert.equal(5, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) -- Simulate an error on the database - local err = helpers.dao.response_ratelimiting_metrics:drop_table(helpers.dao.response_ratelimiting_metrics.table) - assert.falsy(err) + assert(helpers.dao.db:drop_table("response_ratelimiting_metrics")) -- Make another request local res = assert(helpers.proxy_client():send { @@ -516,8 +515,7 @@ for i, policy in ipairs({"local", "cluster", "redis"}) do assert.equal(5, tonumber(res.headers["x-ratelimit-remaining-video-minute"])) -- Simulate an error on the database - local err = helpers.dao.response_ratelimiting_metrics:drop_table(helpers.dao.response_ratelimiting_metrics.table) - assert.falsy(err) + assert(helpers.dao.db:drop_table("response_ratelimiting_metrics")) -- Make another request local res = assert(helpers.proxy_client():send { From 3dc9a0ec86bbc2b032b29bf74857e5389845eb85 Mon Sep 17 00:00:00 2001 From: Thibault Charbonnier Date: Fri, 7 Oct 2016 18:56:54 +0200 Subject: [PATCH 7/7] chore(ci) test under C* 3.9 and use dev lua-cassandra --- 
.ci/setup_env.sh | 2 ++ .travis.yml | 10 ++++++---- kong-0.9.3-0.rockspec | 2 +- kong/dao/db/cassandra.lua | 9 ++++++++- spec/kong_tests.conf | 1 + 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/.ci/setup_env.sh b/.ci/setup_env.sh index 7be06febc178..993b6769dbd3 100755 --- a/.ci/setup_env.sh +++ b/.ci/setup_env.sh @@ -90,6 +90,8 @@ export PATH=$PATH:$OPENRESTY_INSTALL/nginx/sbin:$OPENRESTY_INSTALL/bin:$LUAROCKS eval `luarocks path` +luarocks purge --tree=$LUAROCKS_INSTALL + # ------------------------------------- # Install ccm & setup Cassandra cluster # ------------------------------------- diff --git a/.travis.yml b/.travis.yml index 996b0bb6c99b..62d073b78d34 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,9 @@ sudo: false -language: c +language: java -compiler: - - gcc +jdk: + - oraclejdk8 notifications: email: false @@ -22,7 +22,7 @@ env: - SERF=0.7.0 - LUAROCKS=2.4.0 - OPENSSL=1.0.2h - - CASSANDRA=2.2.7 + - CASSANDRA=2.2.8 - OPENRESTY_BASE=1.9.15.1 - OPENRESTY_LATEST=1.11.2.1 - OPENRESTY=$OPENRESTY_BASE @@ -37,8 +37,10 @@ env: OPENRESTY=$OPENRESTY_BASE - TEST_SUITE=integration OPENRESTY=$OPENRESTY_LATEST + CASSANDRA=3.9 - TEST_SUITE=plugins OPENRESTY=$OPENRESTY_LATEST + CASSANDRA=3.9 before_install: - source .ci/setup_env.sh diff --git a/kong-0.9.3-0.rockspec b/kong-0.9.3-0.rockspec index 44a1f5892c93..a94a426b2ad0 100644 --- a/kong-0.9.3-0.rockspec +++ b/kong-0.9.3-0.rockspec @@ -20,7 +20,7 @@ dependencies = { "multipart == 0.4", "version == 0.2", "lapis == 1.5.1", - "lua-cassandra == 1.0.0", + "lua-cassandra == dev-0", "pgmoon-mashape == 2.0.1", "luatz == 0.3", "lua_system_constants == 0.1.1", diff --git a/kong/dao/db/cassandra.lua b/kong/dao/db/cassandra.lua index 3a4535a74117..72997ded6091 100644 --- a/kong/dao/db/cassandra.lua +++ b/kong/dao/db/cassandra.lua @@ -44,9 +44,16 @@ function _M.new(kong_config) connect_timeout = kong_config.cassandra_timeout, read_timeout = kong_config.cassandra_timeout, ssl = kong_config.cassandra_ssl, 
- verify = kong_config.cassandra_ssl_verify + verify = kong_config.cassandra_ssl_verify, + lock_timeout = 30, + silent = ngx.IS_CLI } + if ngx.IS_CLI then + local policy = require("resty.cassandra.policies.reconnection.const") + cluster_options.reconn_policy = policy.new(100) + end + -- -- cluster options from Kong config -- diff --git a/spec/kong_tests.conf b/spec/kong_tests.conf index 060c5cff2190..52d4c1944c87 100644 --- a/spec/kong_tests.conf +++ b/spec/kong_tests.conf @@ -15,6 +15,7 @@ pg_host = 127.0.0.1 pg_port = 5432 pg_database = kong_tests cassandra_keyspace = kong_tests +cassandra_timeout = 10000 anonymous_reports = off lua_package_path = ?/init.lua;./kong/?.lua