Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(dao) better full update handling in DAO and API #820

Merged
merged 1 commit into from
Dec 23, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions kong/api/crud_helpers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ function _M.paginated_set(self, dao_collection)
if next_total > 0 then
next_url = self:build_url(self.req.parsed_url.path, {
port = self.req.parsed_url.port,
query = ngx.encode_args({
offset = ngx.encode_base64(data.next_page),
size = size
})
query = ngx.encode_args {
offset = ngx.encode_base64(data.next_page),
size = size
}
})
end

Expand All @@ -80,11 +80,22 @@ function _M.paginated_set(self, dao_collection)
-- This check is required otherwise the response is going to be a
-- JSON Object and not a JSON array. The reason is because an empty Lua array `{}`
-- will not be translated as an empty array by cjson, but as an empty object.
local result = #data == 0 and "{\"data\":[],\"total\":0}" or {data=data, ["next"]=next_url, total=total}
local result = #data == 0 and "{\"data\":[],\"total\":0}" or {data = data, ["next"] = next_url, total = total}

return responses.send_HTTP_OK(result, type(result) ~= "table")
end

function _M.get(params, dao_collection)
local rows, err = dao_collection:find_by_keys(params)
if err then
return app_helpers.yield_error(err)
elseif rows[1] == nil then
return responses.send_HTTP_NOT_FOUND()
else
return responses.send_HTTP_OK(rows[1])
end
end

function _M.put(params, dao_collection)
local res, new_entity, err

Expand Down Expand Up @@ -120,14 +131,12 @@ function _M.post(params, dao_collection, success)
end
end

function _M.patch(params, old_entity, dao_collection)
for k, v in pairs(params) do
old_entity[k] = v
end

local updated_entity, err = dao_collection:update(old_entity)
function _M.patch(params, dao_collection)
local updated_entity, err = dao_collection:update(params)
if err then
return app_helpers.yield_error(err)
elseif updated_entity == nil then
return responses.send_HTTP_NOT_FOUND()
else
return responses.send_HTTP_OK(updated_entity)
end
Expand Down
44 changes: 18 additions & 26 deletions kong/api/routes/apis.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local crud = require "kong.api.crud_helpers"
local syslog = require "kong.tools.syslog"
local constants = require "kong.constants"
local validations = require "kong.dao.schemas_validation"

return {
["/apis/"] = {
Expand All @@ -18,18 +19,25 @@ return {
},

["/apis/:name_or_id"] = {
before = crud.find_api_by_name_or_id,
before = function(self, dao_factory)
if validations.is_valid_uuid(self.params.name_or_id) then
self.params.id = self.params.name_or_id
else
self.params.name = self.params.name_or_id
end
self.params.name_or_id = nil
end,

GET = function(self, dao_factory, helpers)
return helpers.responses.send_HTTP_OK(self.api)
crud.get(self.params, dao_factory.apis)
end,

PATCH = function(self, dao_factory)
crud.patch(self.params, self.api, dao_factory.apis)
crud.patch(self.params, dao_factory.apis)
end,

DELETE = function(self, dao_factory)
crud.delete(self.api, dao_factory.apis)
crud.delete(self.params, dao_factory.apis)
end
},

Expand Down Expand Up @@ -57,38 +65,22 @@ return {
end
},

["/apis/:name_or_id/plugins/:plugin_id"] = {
["/apis/:name_or_id/plugins/:id"] = {
before = function(self, dao_factory, helpers)
crud.find_api_by_name_or_id(self, dao_factory, helpers)
self.params.api_id = self.api.id

local fetch_keys = {
api_id = self.api.id,
id = self.params.plugin_id
}
self.params.plugin_id = nil

local data, err = dao_factory.plugins:find_by_keys(fetch_keys)
if err then
return helpers.yield_error(err)
end

self.plugin = data[1]
if not self.plugin then
return helpers.responses.send_HTTP_NOT_FOUND()
end
end,

GET = function(self, dao_factory, helpers)
return helpers.responses.send_HTTP_OK(self.plugin)
GET = function(self, dao_factory)
crud.get(self.params, dao_factory.plugins)
end,

PATCH = function(self, dao_factory, helpers)
crud.patch(self.params, self.plugin, dao_factory.plugins)
PATCH = function(self, dao_factory)
crud.patch(self.params, dao_factory.plugins)
end,

DELETE = function(self, dao_factory)
crud.delete(self.plugin, dao_factory.plugins)
crud.delete(self.params, dao_factory.plugins)
end
}
}
16 changes: 11 additions & 5 deletions kong/api/routes/consumers.lua
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
local validations = require "kong.dao.schemas_validation"
local crud = require "kong.api.crud_helpers"

return {
Expand All @@ -16,20 +17,25 @@ return {
},

["/consumers/:username_or_id"] = {
before = function(self, dao_factory, helpers)
crud.find_consumer_by_username_or_id(self, dao_factory, helpers)
before = function(self, dao_factory)
if validations.is_valid_uuid(self.params.username_or_id) then
self.params.id = self.params.username_or_id
else
self.params.username = self.params.username_or_id
end
self.params.username_or_id = nil
end,

GET = function(self, dao_factory, helpers)
return helpers.responses.send_HTTP_OK(self.consumer)
crud.get(self.params, dao_factory.consumers)
end,

PATCH = function(self, dao_factory, helpers)
crud.patch(self.params, self.consumer, dao_factory.consumers)
crud.patch(self.params, dao_factory.consumers)
end,

DELETE = function(self, dao_factory, helpers)
crud.delete(self.consumer, dao_factory.consumers)
crud.delete(self.params, dao_factory.consumers)
end
}
}
89 changes: 49 additions & 40 deletions kong/dao/cassandra/base_dao.lua
Original file line number Diff line number Diff line change
Expand Up @@ -176,94 +176,103 @@ local function extract_primary_key(t, primary_key, clustering_key)
end

---
-- When updating a row that has a json-as-text column (ex: plugin.config),
-- Complete a partial entity given to `update`.
-- Also handle updating a row that has a json-as-text column (ex: plugin.config),
-- we want to avoid overriding it with a partial value.
-- Ex: config.key_name + config.hide_credential, if we update only one field,
-- the other should be preserved. Of course this only applies in partial update.
local function fix_tables(t, old_t, schema)
for k, v in pairs(schema.fields) do
if t[k] ~= nil and v.schema then
local s = type(v.schema) == "function" and v.schema(t) or v.schema
for s_k, s_v in pairs(s.fields) do
if not t[k][s_k] and old_t[k] then
t[k][s_k] = old_t[k][s_k]
-- the other should be preserved. Of course this only applies in partial updates.
local function complete_partial_entity(new_t, old_t, schema)
for field_key, field_rules in pairs(schema.fields) do
if new_t[field_key] == nil and old_t[field_key] ~= nil then
new_t[field_key] = old_t[field_key]
elseif new_t[field_key] ~= nil and field_rules.schema ~= nil then
-- Retrieve the field's schema
local field_schema = type(field_rules.schema) == "function" and field_rules.schema(old_t) or field_rules.schema

-- Replace each of those subfields with the value from the already inserted entity
-- if not present in the new_fields
for s_k, s_v in pairs(field_schema.fields) do
if new_t[field_key][s_k] == nil and old_t[field_key] ~= nil then
new_t[field_key][s_k] = old_t[field_key][s_k]
end
end
fix_tables(t[k], old_t[k], s)

-- Recursive call for sub fields
complete_partial_entity(new_t[field_key], old_t[field_key], field_schema)
end
end
end

---
-- Update an entity: find the row with the given PRIMARY KEY and update the other values
--- Update an entity.
-- Find the row with the given PRIMARY KEY and update its colums with values
-- from the given table and return the complete entity, with updates taken into account.
-- If asked, the update can be "full", just like an HTTP PUT method would expect to work:
-- any schema field that is not included in the given argument will be set to CQL `null` (unset).
-- Performs schema validation, 'UNIQUE' and 'FOREIGN' checks.
-- @see check_unique_fields
-- @see check_foreign_fields
-- @param[type=table] t A table representing the entity to update. It **must** contain the entity's PRIMARY KEY (can be composite).
-- @param[type=boolean] full If **true**, set to NULL any column not in the `t` parameter, such as a PUT query would do for example.
-- @treturn table Updated entity or nil.
-- @treturn table Error if any during the execution.
-- @param[type=table] t A table representing the entity to update. It must contain the entity's PRIMARY KEY (can be composite).
-- @param[type=boolean] full If true, set to CQL `null` any column not in the `t` argument, such as a PUT query would do for example.
-- @treturn table `result`: Updated entity or nil.
-- @treturn table `error`: Error if any during the execution.
function BaseDao:update(t, full)
assert(t ~= nil, "Cannot update a nil element")
assert(type(t) == "table", "Entity to update must be a table")

local ok, db_err, errors, self_err

-- Check if exists to prevent upsert
local res, err = self:find_by_primary_key(t)
if err then
return false, err
elseif not res then
return false
-- Check if the entity exists to prevent upsert and retrieve its old values
local entity, err = self:find_by_primary_key(t)
if entity == nil or err then
return nil, err
end

if not full then
fix_tables(t, res, self._schema)
complete_partial_entity(t, entity, self._schema)
end

-- Validate schema
ok, errors, self_err = validations.validate_entity(t, self._schema, {
partial_update = not full,
local ok, errors, err = validations.validate_entity(t, self._schema, {
update = true,
old_t = entity,
full_update = full,
dao = self._factory
})
if self_err then
return nil, self_err
if err then
return nil, err
elseif not ok then
return nil, DaoError(errors, error_types.SCHEMA)
end

ok, errors, db_err = self:check_unique_fields(t, true)
if db_err then
return nil, DaoError(db_err, error_types.DATABASE)
ok, errors, err = self:check_unique_fields(t, true)
if err then
return nil, DaoError(err, error_types.DATABASE)
elseif not ok then
return nil, DaoError(errors, error_types.UNIQUE)
end

ok, errors, db_err = self:check_foreign_fields(t)
if db_err then
return nil, DaoError(db_err, error_types.DATABASE)
ok, errors, err = self:check_foreign_fields(t)
if err then
return nil, DaoError(err, error_types.DATABASE)
elseif not ok then
return nil, DaoError(errors, error_types.FOREIGN)
end

-- Extract primary key from the entity
local t_primary_key, t_no_primary_key = extract_primary_key(t, self._primary_key, self._clustering_key)

-- If full, add `null` values to the SET part of the query for nil columns
-- If full, add CQL `null` to the SET part of the query for nil columns
if full then
for k, v in pairs(self._schema.fields) do
if not t[k] and not v.immutable then
if t[k] == nil and not v.immutable then
t_no_primary_key[k] = cassandra.unset
end
end
end

local update_q, columns = query_builder.update(self._table, t_no_primary_key, t_primary_key)

local _, stmt_err = self:build_args_and_execute(update_q, columns, self:_marshall(t))
if stmt_err then
return nil, stmt_err
local _, err = self:build_args_and_execute(update_q, columns, self:_marshall(t))
if err then
return nil, err
else
return self:_unmarshall(t)
end
Expand Down
6 changes: 2 additions & 4 deletions kong/dao/schemas/apis.lua
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,7 @@ local function check_name(name)
if err then
ngx.log(ngx.ERR, err)
return
end

if m then
elseif m then
return false, "name must only contain alphanumeric and '., -, _, ~' characters"
end
end
Expand All @@ -120,7 +118,7 @@ return {
primary_key = {"id"},
fields = {
id = {type = "id", dao_insert_value = true},
created_at = {type = "timestamp", dao_insert_value = true},
created_at = {type = "timestamp", immutable = true, dao_insert_value = true},
name = {type = "string", unique = true, queryable = true, default = default_name, func = check_name},
request_host = {type = "string", unique = true, queryable = true, func = check_request_host_and_path,
regex = "([a-zA-Z0-9-]+(\\.[a-zA-Z0-9-]+)*)"},
Expand Down
2 changes: 1 addition & 1 deletion kong/dao/schemas/consumers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ return {
primary_key = {"id"},
fields = {
id = { type = "id", dao_insert_value = true },
created_at = { type = "timestamp", dao_insert_value = true },
created_at = { type = "timestamp", immutable = true, dao_insert_value = true },
custom_id = { type = "string", unique = true, queryable = true, func = check_custom_id_and_username },
username = { type = "string", unique = true, queryable = true, func = check_custom_id_and_username }
}
Expand Down
1 change: 1 addition & 0 deletions kong/dao/schemas/plugins.lua
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ return {
},
created_at = {
type = "timestamp",
immutable = true,
dao_insert_value = true
},
api_id = {
Expand Down
Loading