diff --git a/kong/dao/cassandra/base_dao.lua b/kong/dao/cassandra/base_dao.lua index 07661e5eaef9..6a40c4f38b4f 100644 --- a/kong/dao/cassandra/base_dao.lua +++ b/kong/dao/cassandra/base_dao.lua @@ -351,10 +351,11 @@ function BaseDao:insert(t) assert(t ~= nil, "Cannot insert a nil element") assert(type(t) == "table", "Entity to insert must be a table") - local ok, db_err, errors + local ok, db_err, errors, self_err -- Populate the entity with any default/overriden values and validate it - errors = validations.validate(t, self, { + ok, errors, self_err = validations.validate_entity(t, self._schema, { + dao = self._factory, dao_insert = function(field) if field.type == "id" then return uuid() @@ -363,13 +364,10 @@ function BaseDao:insert(t) end end }) - if errors then - return nil, errors - end - - ok, errors = validations.on_insert(t, self._schema, self._factory) - if not ok then - return nil, errors + if self_err then + return nil, self_err + elseif not ok then + return nil, DaoError(errors, error_types.SCHEMA) end ok, errors, db_err = self:check_unique_fields(t) @@ -440,7 +438,7 @@ function BaseDao:update(t, full) assert(t ~= nil, "Cannot update a nil element") assert(type(t) == "table", "Entity to update must be a table") - local ok, db_err, errors + local ok, db_err, errors, self_err -- Check if exists to prevent upsert local res, err = self:find_by_primary_key(t) @@ -455,9 +453,15 @@ function BaseDao:update(t, full) end -- Validate schema - errors = validations.validate(t, self, {partial_update = not full, full_update = full}) - if errors then - return nil, errors + ok, errors, self_err = validations.validate_entity(t, self._schema, { + partial_update = not full, + full_update = full, + dao = self._factory + }) + if self_err then + return nil, self_err + elseif not ok then + return nil, DaoError(errors, error_types.SCHEMA) end ok, errors, db_err = self:check_unique_fields(t, true) diff --git a/kong/dao/schemas/apis.lua b/kong/dao/schemas/apis.lua index 5e2d0b22ee02..67c66a1990b9 100644 --- a/kong/dao/schemas/apis.lua +++ b/kong/dao/schemas/apis.lua @@ -24,12 +24,30 @@ local function check_public_dns_and_path(value, api_t) return false, "At least a 'public_dns' or a 'path' must be specified" end - return true + -- Validate wildcard public_dns + if public_dns then + local _, count = public_dns:gsub("%*", "") + if count > 1 then + return false, "Only one wildcard is allowed: "..public_dns + elseif count > 0 then + local pos = public_dns:find("%*") + local valid + if pos == 1 then + valid = public_dns:match("^%*%.") ~= nil + elseif pos == string.len(public_dns) then + valid = public_dns:match(".%.%*$") ~= nil + end + + if not valid then + return false, "Invalid wildcard placement: "..public_dns + end + end + end end local function check_path(path, api_t) local valid, err = check_public_dns_and_path(path, api_t) - if not valid then + if valid == false then return false, err end diff --git a/kong/dao/schemas/plugins_configurations.lua b/kong/dao/schemas/plugins_configurations.lua index 08092220f48b..5a68c0f213f3 100644 --- a/kong/dao/schemas/plugins_configurations.lua +++ b/kong/dao/schemas/plugins_configurations.lua @@ -26,11 +26,11 @@ return { value = { type = "table", schema = load_value_schema }, enabled = { type = "boolean", default = true } }, - on_insert = function(plugin_t, dao, schema) + self_check = function(self, plugin_t, dao, is_update) -- Load the value schema - local value_schema, err = schema.fields.value.schema(plugin_t) + local value_schema, err = self.fields.value.schema(plugin_t) if err then - return false, err + return false, DaoError(err, constants.DATABASE_ERROR_TYPES.SCHEMA) end -- Check if the schema has a `no_consumer` field @@ -38,20 +38,20 @@ return { return false, DaoError("No consumer can be configured for that plugin", constants.DATABASE_ERROR_TYPES.SCHEMA) end - local res, err = dao.plugins_configurations:find_by_keys({ - name = plugin_t.name, - api_id = plugin_t.api_id, - consumer_id = plugin_t.consumer_id - }) + if not is_update then + local res, err = dao.plugins_configurations:find_by_keys({ + name = plugin_t.name, + api_id = plugin_t.api_id, + consumer_id = plugin_t.consumer_id + }) - if err then - return nil, DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) - end + if err then + return nil, DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) + end - if res and #res > 0 then - return false, DaoError("Plugin configuration already exists", constants.DATABASE_ERROR_TYPES.UNIQUE) - else - return true + if res and #res > 0 then + return false, DaoError("Plugin configuration already exists", constants.DATABASE_ERROR_TYPES.UNIQUE) + end end end } diff --git a/kong/dao/schemas_validation.lua b/kong/dao/schemas_validation.lua index 78ca8517f7dd..3caebaceafa9 100644 --- a/kong/dao/schemas_validation.lua +++ b/kong/dao/schemas_validation.lua @@ -1,8 +1,5 @@ local utils = require "kong.tools.utils" local stringy = require "stringy" -local DaoError = require "kong.dao.error" -local constants = require "kong.constants" -local error_types = constants.DATABASE_ERROR_TYPES local POSSIBLE_TYPES = { id = true, @@ -44,7 +41,7 @@ local _M = {} -- `is_update` For an entity update, check immutable fields. Set to true. -- @return `valid` Success of validation. True or false. -- @return `errors` A list of encountered errors during the validation. -function _M.validate_fields(t, schema, options) +function _M.validate_entity(t, schema, options) if not options then options = {} end local errors @@ -151,7 +148,7 @@ function _M.validate_fields(t, schema, options) if t[column] and type(t[column]) == "table" then -- Actually validating the sub-schema - local s_ok, s_errors = _M.validate_fields(t[column], sub_schema, options) + local s_ok, s_errors = _M.validate_entity(t[column], sub_schema, options) if not s_ok then for s_k, s_v in pairs(s_errors) do errors = utils.add_error(errors, column.."."..s_k, s_v) @@ -172,7 +169,7 @@ function _M.validate_fields(t, schema, options) -- [FUNC] Check field against a custom function -- only if there is no error on that field already. local ok, err, new_fields = v.func(t[column], t, column) - if not ok and err then + if ok == false and err then errors = utils.add_error(errors, column, err) elseif new_fields then for k, v in pairs(new_fields) do @@ -190,29 +187,14 @@ function _M.validate_fields(t, schema, options) end end - return errors == nil, errors -end - -function _M.on_insert(t, schema, dao) - if schema.on_insert and type(schema.on_insert) == "function" then - local valid, err = schema.on_insert(t, dao, schema) - if not valid or err then - return false, err - else - return true + if errors == nil and type(schema.self_check) == "function" then + local ok, err = schema.self_check(schema, t, options.dao, (options.partial_update or options.full_update)) + if ok == false then + return false, nil, err end - else - return true end -end -function _M.validate(t, dao, options) - local ok, errors - - ok, errors = _M.validate_fields(t, dao._schema, options) - if not ok then - return DaoError(errors, error_types.SCHEMA) - end + return errors == nil, errors end local digit = "[0-9a-f]" diff --git a/kong/resolver/access.lua b/kong/resolver/access.lua index 1f3712f526b2..75938dc92ea6 100644 --- a/kong/resolver/access.lua +++ b/kong/resolver/access.lua @@ -6,6 +6,49 @@ local responses = require "kong.tools.responses" local _M = {} +-- Take a public_dns and make it a pattern for wildcard matching. +-- Only do so if the public_dns actually has a wildcard. +local function create_wildcard_pattern(public_dns) + if string.find(public_dns, "*", 1, true) then + local pattern = string.gsub(public_dns, "%.", "%%.") + pattern = string.gsub(pattern, "*", ".+") + pattern = string.format("^%s$", pattern) + return pattern + end +end + +-- Load all APIs in memory. +-- Sort the data for faster lookup: dictionary per public_dns, host, +-- and an array of wildcard public_dns. +local function load_apis_in_memory() + local apis, err = dao.apis:find_all() + if err then + return nil, err + end + + -- build dictionnaries of public_dns:api and path:apis for efficient O(1) lookup. + -- we only do O(n) lookup for wildcard public_dns that are in an array. + local dns_dic, dns_wildcard, path_dic = {}, {}, {} + for _, api in ipairs(apis) do + if api.public_dns then + local pattern = create_wildcard_pattern(api.public_dns) + if pattern then + -- If the public_dns is a wildcard, we have a pattern and we can + -- store it in an array for later lookup. + table.insert(dns_wildcard, {pattern = pattern, api = api}) + else + -- Keep non-wildcard public_dns in a dictionary for faster lookup. + dns_dic[api.public_dns] = api + end + end + if api.path then + path_dic[api.path] = api + end + end + + return {by_dns = dns_dic, wildcard_dns = dns_wildcard, by_path = path_dic} +end + local function get_backend_url(api) local result = api.target_url @@ -37,7 +80,8 @@ end -- matching the API's `public_dns`, either from the `request_uri` matching the API's `path`. -- -- To perform this, we need to query _ALL_ APIs in memory. It is the only way to compare the `request_uri` --- as a regex to the values set in DB. We keep APIs in the database cache for a longer time than usual. +-- as a regex to the values set in DB, as well as matching wildcard dns. +-- We keep APIs in the database cache for a longer time than usual. -- @see https://github.com/Mashape/kong/issues/15 for an improvement on this. -- -- @param `request_uri` The URI for this request. @@ -49,31 +93,14 @@ end local function find_api(request_uri) local retrieved_api - -- retrieve all APIs - local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", function() - local apis, err = dao.apis:find_all() - if err then - return nil, err - end - - -- build dictionnaries of public_dns:api and path:apis for efficient lookup. - local dns_dic, path_dic = {}, {} - for _, api in ipairs(apis) do - if api.public_dns then - dns_dic[api.public_dns] = api - end - if api.path then - path_dic[api.path] = api - end - end - return {dns = dns_dic, path = path_dic} - end, 60) -- 60 seconds cache + -- Retrieve all APIs + local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", load_apis_in_memory, 60) -- 60 seconds cache, longer than usual if err then return err end - -- find by Host header + -- Find by Host header local all_hosts = {} for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do local hosts = ngx.req.get_headers()[header_name] @@ -85,9 +112,18 @@ local function find_api(request_uri) for _, host in ipairs(hosts) do host = unpack(stringy.split(host, ":")) table.insert(all_hosts, host) - if apis_dics.dns[host] then - retrieved_api = apis_dics.dns[host] - break + if apis_dics.by_dns[host] then + retrieved_api = apis_dics.by_dns[host] + --break + else + -- If the API was not found in the dictionary, maybe it is a wildcard public_dns. + -- In that case, we need to loop over all of them. + for _, wildcard_dns in ipairs(apis_dics.wildcard_dns) do + if string.match(host, wildcard_dns.pattern) then + retrieved_api = wildcard_dns.api + break + end + end end end end @@ -99,7 +135,7 @@ local function find_api(request_uri) end -- Otherwise, we look for it by path. We have to loop over all APIs and compare the requested URI. - for path, api in pairs(apis_dics.path) do + for path, api in pairs(apis_dics.by_path) do local m, err = ngx.re.match(request_uri, "^"..path) if err then ngx.log(ngx.ERR, "[resolver] error matching requested path: "..err) diff --git a/spec/integration/proxy/resolver_spec.lua b/spec/integration/proxy/resolver_spec.lua index 3b4bd4443b1f..22e27502afb2 100644 --- a/spec/integration/proxy/resolver_spec.lua +++ b/spec/integration/proxy/resolver_spec.lua @@ -25,11 +25,13 @@ describe("Resolver", function() spec_helper.prepare_db() spec_helper.insert_fixtures { api = { - { name = "tests host resolver 1", public_dns = "mockbin.com", target_url = "http://mockbin.com" }, - { name = "tests host resolver 2", public_dns = "mockbin-auth.com", target_url = "http://mockbin.com" }, - { name = "tests path resolver", target_url = "http://mockbin.com", path = "/status/" }, - { name = "tests stripped path resolver", target_url = "http://mockbin.com", path = "/mockbin/", strip_path = true }, - { name = "tests deep path resolver", target_url = "http://mockbin.com", path = "/deep/path/", strip_path = true } + {name = "tests host resolver 1", public_dns = "mockbin.com", target_url = "http://mockbin.com"}, + {name = "tests host resolver 2", public_dns = "mockbin-auth.com", target_url = "http://mockbin.com"}, + {name = "tests path resolver", target_url = "http://mockbin.com", path = "/status/"}, + {name = "tests stripped path resolver", target_url = "http://mockbin.com", path = "/mockbin/", strip_path = true}, + {name = "tests deep path resolver", target_url = "http://mockbin.com", path = "/deep/path/", strip_path = true}, + {name = "tests wildcard subdomain", target_url = "http://mockbin.com/status/200", public_dns = "*.wildcard.com"}, + {name = "tests wildcard subdomain 2", target_url = "http://mockbin.com/status/201", public_dns = "wildcard.*"} }, plugin_configuration = { { name = "keyauth", value = {key_names = {"apikey"} }, __api = 2 } @@ -47,8 +49,8 @@ describe("Resolver", function() it("should return Not Found when the API is not in Kong", function() local response, status = http_client.get(spec_helper.STUB_GET_URL, nil, { host = "foo.com" }) - assert.are.equal(404, status) - assert.are.equal('{"public_dns":["foo.com"],"message":"API not found with these values","path":"\\/request"}\n', response) + assert.equal(404, status) + assert.equal('{"public_dns":["foo.com"],"message":"API not found with these values","path":"\\/request"}\n', response) end) end) @@ -57,10 +59,10 @@ describe("Resolver", function() it("should work when calling SSL port", function() local response, status = http_client.get(STUB_GET_SSL_URL, nil, { host = "mockbin.com" }) - assert.are.equal(200, status) + assert.equal(200, status) assert.truthy(response) local parsed_response = cjson.decode(response) - assert.are.same("GET", parsed_response.method) + assert.same("GET", parsed_response.method) end) it("should work when manually triggering the handshake on default route", function() @@ -86,13 +88,13 @@ describe("Resolver", function() local cert = parse_cert(conn:getpeercertificate()) - assert.are.same(6, utils.table_size(cert)) - assert.are.same("Kong", cert.organizationName) - assert.are.same("IT", cert.organizationalUnitName) - assert.are.same("US", cert.countryName) - assert.are.same("California", cert.stateOrProvinceName) - assert.are.same("San Francisco", cert.localityName) - assert.are.same("localhost", cert.commonName) + assert.same(6, utils.table_size(cert)) + assert.same("Kong", cert.organizationName) + assert.same("IT", cert.organizationalUnitName) + assert.same("US", cert.countryName) + assert.same("California", cert.stateOrProvinceName) + assert.same("San Francisco", cert.localityName) + assert.same("localhost", cert.commonName) conn:close() end) @@ -103,70 +105,81 @@ describe("Resolver", function() describe("By Host", function() it("should proxy when the API is in Kong", function() - local _, status = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com"}) - assert.are.equal(200, status) + local _, status = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com"}) + assert.equal(200, status) end) it("should proxy when the Host header is not trimmed", function() - local _, status = http_client.get(STUB_GET_URL, nil, { host = " mockbin.com "}) - assert.are.equal(200, status) + local _, status = http_client.get(STUB_GET_URL, nil, {host = " mockbin.com "}) + assert.equal(200, status) end) it("should proxy when the request has no Host header but the X-Host-Override header", function() - local _, status = http_client.get(STUB_GET_URL, nil, { ["X-Host-Override"] = "mockbin.com"}) - assert.are.equal(200, status) + local _, status = http_client.get(STUB_GET_URL, nil, {["X-Host-Override"] = "mockbin.com"}) + assert.equal(200, status) end) it("should proxy when the Host header contains a port", function() - local _, status = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com:80"}) - assert.are.equal(200, status) + local _, status = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com:80"}) + assert.equal(200, status) end) + describe("with wildcard subdomain", function() + + it("should proxy when the public_dns is a wildcard subdomain", function() + local _, status = http_client.get(STUB_GET_URL, nil, {host = "subdomain.wildcard.com"}) + assert.equal(200, status) + + _, status = http_client.get(STUB_GET_URL, nil, {host = "wildcard.org"}) + assert.equal(201, status) + end) + + end) end) describe("By Path", function() it("should proxy when no Host is present but the request_uri matches the API's path", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/status/200") - assert.are.equal(200, status) + assert.equal(200, status) local _, status = http_client.get(spec_helper.PROXY_URL.."/status/301") - assert.are.equal(301, status) + assert.equal(301, status) end) it("should not proxy when the path does not match the start of the request_uri", function() local response, status = http_client.get(spec_helper.PROXY_URL.."/somepath/status/200") local body = cjson.decode(response) - assert.are.equal("API not found with these values", body.message) - assert.are.equal("/somepath/status/200", body.path) - assert.are.equal(404, status) + assert.equal("API not found with these values", body.message) + assert.equal("/somepath/status/200", body.path) + assert.equal(404, status) end) it("should proxy and strip the path if `strip_path` is true", function() local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") - assert.are.equal(200, status) + assert.equal(200, status) local body = cjson.decode(response) - assert.are.equal("http://mockbin.com/request", body.url) + assert.equal("http://mockbin.com/request", body.url) end) it("should proxy when the path has a deep level", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/deep/path/status/200") - assert.are.equal(200, status) + assert.equal(200, status) end) end) it("should return the correct Server and Via headers when the request was proxied", function() local _, status, headers = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com"}) - assert.are.equal(200, status) - assert.are.equal("cloudflare-nginx", headers.server) - assert.are.equal(constants.NAME.."/"..constants.VERSION, headers.via) + assert.equal(200, status) + assert.equal("cloudflare-nginx", headers.server) + assert.equal(constants.NAME.."/"..constants.VERSION, headers.via) end) it("should return the correct Server and no Via header when the request was NOT proxied", function() local _, status, headers = http_client.get(STUB_GET_URL, nil, { host = "mockbin-auth.com"}) - assert.are.equal(401, status) - assert.are.equal(constants.NAME.."/"..constants.VERSION, headers.server) + assert.equal(401, status) + assert.equal(constants.NAME.."/"..constants.VERSION, headers.server) assert.falsy(headers.via) end) diff --git a/spec/unit/dao/cassandra/base_dao_spec.lua b/spec/unit/dao/cassandra/base_dao_spec.lua index 0052a069d1a8..3101c1a1e1eb 100644 --- a/spec/unit/dao/cassandra/base_dao_spec.lua +++ b/spec/unit/dao/cassandra/base_dao_spec.lua @@ -161,7 +161,7 @@ describe("Cassandra", function() assert.are.same("consumer_id "..plugin_t.consumer_id.." does not exist", err.message.consumer_id) end) - it("should do insert checks for entities with `on_insert`", function() + it("should do insert checks for entities with `self_check`", function() local api, err = dao_factory.apis:insert(faker:fake_entity("api")) assert.falsy(err) assert.truthy(api.id) diff --git a/spec/unit/dao/entities_schemas_spec.lua b/spec/unit/dao/entities_schemas_spec.lua index c64dad441bed..c5e7efe91e8f 100644 --- a/spec/unit/dao/entities_schemas_spec.lua +++ b/spec/unit/dao/entities_schemas_spec.lua @@ -2,7 +2,7 @@ local api_schema = require "kong.dao.schemas.apis" local consumer_schema = require "kong.dao.schemas.consumers" local plugins_configurations_schema = require "kong.dao.schemas.plugins_configurations" local validations = require "kong.dao.schemas_validation" -local validate_fields = validations.validate_fields +local validate_entity = validations.validate_entity require "kong.tools.ngx_stub" @@ -26,7 +26,7 @@ describe("Entities Schemas", function() describe("APIs", function() it("should return error with wrong target_url", function() - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ public_dns = "mockbin.com", target_url = "asdasd" }, api_schema) @@ -35,7 +35,7 @@ describe("Entities Schemas", function() end) it("should return error with wrong target_url protocol", function() - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ public_dns = "mockbin.com", target_url = "wot://mockbin.com/" }, api_schema) @@ -44,7 +44,7 @@ describe("Entities Schemas", function() end) it("should validate without a path", function() - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ public_dns = "mockbin.com", target_url = "http://mockbin.com" }, api_schema) @@ -53,7 +53,7 @@ describe("Entities Schemas", function() end) it("should validate with upper case protocol", function() - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ public_dns = "mockbin.com", target_url = "HTTP://mockbin.com/world" }, api_schema) @@ -62,14 +62,14 @@ describe("Entities Schemas", function() end) it("should complain if missing `public_dns` and `path`", function() - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ name = "mockbin" }, api_schema) assert.False(valid) assert.equal("At least a 'public_dns' or a 'path' must be specified", errors.path) assert.equal("At least a 'public_dns' or a 'path' must be specified", errors.public_dns) - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ name = "mockbin", path = true }, api_schema) @@ -81,14 +81,54 @@ describe("Entities Schemas", function() it("should set the name from public_dns if not set", function() local t = { public_dns = "mockbin.com", target_url = "http://mockbin.com" } - local valid, errors = validate_fields(t, api_schema) + local valid, errors = validate_entity(t, api_schema) assert.falsy(errors) assert.True(valid) assert.equal("mockbin.com", t.name) end) + it("should accept valid wildcard public_dns", function() + local valid, errors = validate_entity({ + name = "mockbin", + public_dns = "*.mockbin.org", + target_url = "http://mockbin.com" + }, api_schema) + assert.True(valid) + assert.falsy(errors) + + valid, errors = validate_entity({ + name = "mockbin", + public_dns = "mockbin.*", + target_url = "http://mockbin.com" + }, api_schema) + assert.True(valid) + assert.falsy(errors) + end) + + it("should refuse invalid wildcard public_dns", function() + local api_t = { + name = "mockbin", + public_dns = "*.mockbin.*", + target_url = "http://mockbin.com" + } + + local valid, errors = validate_entity(api_t, api_schema) + assert.False(valid) + assert.equal("Only one wildcard is allowed: *.mockbin.*", errors.public_dns) + + api_t.public_dns = "*mockbin.com" + valid, errors = validate_entity(api_t, api_schema) + assert.False(valid) + assert.equal("Invalid wildcard placement: *mockbin.com", errors.public_dns) + + api_t.public_dns = "www.mockbin*" + valid, errors = validate_entity(api_t, api_schema) + assert.False(valid) + assert.equal("Invalid wildcard placement: www.mockbin*", errors.public_dns) + end) + it("should only accept alphanumeric `path`", function() - local valid, errors = validate_fields({ + local valid, errors = validate_entity({ name = "mockbin", path = "/[a-zA-Z]{3}", target_url = "http://mockbin.com" @@ -96,14 +136,14 @@ describe("Entities Schemas", function() assert.equal("path must only contain alphanumeric and '. -, _, ~, /' characters", errors.path) assert.False(valid) - valid = validate_fields({ + valid = validate_entity({ name = "mockbin", path = "/status/", target_url = "http://mockbin.com" }, api_schema) assert.True(valid) - valid = validate_fields({ + valid = validate_entity({ name = "mockbin", path = "/abcd~user-2", target_url = "http://mockbin.com" @@ -113,42 +153,42 @@ describe("Entities Schemas", function() it("should prefix a `path` with a slash and remove trailing slash", function() local api_t = { name = "mockbin", path = "status", target_url = "http://mockbin.com" } - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "/status" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "status/" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "/status/" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/status", api_t.path) api_t.path = "/deep/nested/status/" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) api_t.path = "deep/nested/status" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) -- Strip all leading slashes api_t.path = "//deep/nested/status" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) -- Strip all trailing slashes api_t.path = "/deep/nested/status//" - validate_fields(api_t, api_schema) + validate_entity(api_t, api_schema) assert.equal("/deep/nested/status", api_t.path) -- Error if invalid path api_t.path = "/deep//nested/status" - local _, errors = validate_fields(api_t, api_schema) + local _, errors = validate_entity(api_t, api_schema) assert.equal("path is invalid: /deep//nested/status", errors.path) end) @@ -157,17 +197,17 @@ describe("Entities Schemas", function() describe("Consumers", function() it("should require a `custom_id` or `username`", function() - local valid, errors = validate_fields({}, consumer_schema) + local valid, errors = validate_entity({}, consumer_schema) assert.False(valid) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.username) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.custom_id) - valid, errors = validate_fields({ username = "" }, consumer_schema) + valid, errors = validate_entity({ username = "" }, consumer_schema) assert.False(valid) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.username) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.custom_id) - valid, errors = validate_fields({ username = true }, consumer_schema) + valid, errors = validate_entity({ username = true }, consumer_schema) assert.False(valid) assert.equal("username is not a string", errors.username) assert.equal("At least a 'custom_id' or a 'username' must be specified", errors.custom_id) @@ -177,8 +217,16 @@ describe("Entities Schemas", function() describe("Plugins Configurations", function() + local dao_stub = { + plugins_configurations = { + find_by_keys = function() + return nil + end + } + } + it("should not validate if the plugin doesn't exist (not installed)", function() - local valid, errors = validate_fields({name = "world domination"}, plugins_configurations_schema) + local valid, errors = validate_entity({name = "world domination"}, plugins_configurations_schema) assert.False(valid) assert.equal("Plugin \"world domination\" not found", errors.value) end) @@ -186,19 +234,19 @@ describe("Entities Schemas", function() it("should validate a plugin configuration's `value` field", function() -- Success local plugin = {name = "keyauth", api_id = "stub", value = {key_names = {"x-kong-key"}}} - local valid = validate_fields(plugin, plugins_configurations_schema) + local valid = validate_entity(plugin, plugins_configurations_schema, {dao = dao_stub}) assert.True(valid) -- Failure plugin = {name = "ratelimiting", api_id = "stub", value = {period = "hello"}} - local valid, errors = validate_fields(plugin, plugins_configurations_schema) + local valid, errors = validate_entity(plugin, plugins_configurations_schema, {dao = dao_stub}) assert.False(valid) assert.equal("limit is required", errors["value.limit"]) assert.equal("\"hello\" is not allowed. Allowed values are: \"second\", \"minute\", \"hour\", \"day\", \"month\", \"year\"", errors["value.period"]) end) - describe("on_insert", function() + describe("self_check", function() it("should refuse `consumer_id` if specified in the value schema", function() local stub_value_schema = { no_consumer = true, @@ -211,19 +259,11 @@ describe("Entities Schemas", function() return stub_value_schema end - local valid, err = validations.on_insert({name = "stub", api_id = "0000", consumer_id = "0000", value = {string = "foo"}}, plugins_configurations_schema) + local valid, _, err = validate_entity({name = "stub", api_id = "0000", consumer_id = "0000", value = {string = "foo"}}, plugins_configurations_schema) assert.False(valid) assert.equal("No consumer can be configured for that plugin", err.message) - local dao_stub = { - plugins_configurations = { - find_by_keys = function() - return nil - end - } - } - - valid, err = validations.on_insert({name = "stub", api_id = "0000", value = {string = "foo"}}, plugins_configurations_schema, dao_stub) + valid, err = validate_entity({name = "stub", api_id = "0000", value = {string = "foo"}}, plugins_configurations_schema, {dao = dao_stub}) assert.True(valid) assert.falsy(err) end) diff --git a/spec/unit/dao/oauth2/oauth2_entities_spec.lua b/spec/unit/dao/oauth2/oauth2_entities_spec.lua index 1663bb3dd5f1..dbc1a049144c 100644 --- a/spec/unit/dao/oauth2/oauth2_entities_spec.lua +++ b/spec/unit/dao/oauth2/oauth2_entities_spec.lua @@ -1,4 +1,4 @@ -local validate_fields = require("kong.dao.schemas_validation").validate_fields +local validate_entity = require("kong.dao.schemas_validation").validate_entity local oauth2_schema = require "kong.plugins.oauth2.schema" require "kong.tools.ngx_stub" @@ -8,26 +8,26 @@ describe("OAuth2 Entities Schemas", function() describe("OAuth2 Configuration", function() it("should not require a `scopes` when `mandatory_scope` is false", function() - local valid, errors = validate_fields({ mandatory_scope = false }, oauth2_schema) + local valid, errors = validate_entity({ mandatory_scope = false }, oauth2_schema) assert.truthy(valid) assert.falsy(errors) end) it("should require a `scopes` when `mandatory_scope` is true", function() - local valid, errors = validate_fields({ mandatory_scope = true }, oauth2_schema) + local valid, errors = validate_entity({ mandatory_scope = true }, oauth2_schema) assert.falsy(valid) assert.equal("To set a mandatory scope you also need to create available scopes", errors.mandatory_scope) end) it("should pass when both `scopes` when `mandatory_scope` are passed", function() - local valid, errors = validate_fields({ mandatory_scope = true, scopes = { "email", "info" } }, oauth2_schema) + local valid, errors = validate_entity({ mandatory_scope = true, scopes = { "email", "info" } }, oauth2_schema) assert.truthy(valid) assert.falsy(errors) end) it("should autogenerate a `provision_key` when it is not being passed", function() local t = { mandatory_scope = true, scopes = { "email", "info" } } - local valid, errors = validate_fields(t, oauth2_schema) + local valid, errors = validate_entity(t, oauth2_schema) assert.truthy(valid) assert.falsy(errors) assert.truthy(t.provision_key) @@ -36,7 +36,7 @@ describe("OAuth2 Entities Schemas", function() it("should not autogenerate a `provision_key` when it is being passed", function() local t = { mandatory_scope = true, scopes = { "email", "info" }, provision_key = "hello" } - local valid, errors = validate_fields(t, oauth2_schema) + local valid, errors = validate_entity(t, oauth2_schema) assert.truthy(valid) assert.falsy(errors) assert.truthy(t.provision_key) diff --git a/spec/unit/schemas_spec.lua b/spec/unit/schemas_spec.lua index e7b86ca6ac4f..2960e722a757 100644 --- a/spec/unit/schemas_spec.lua +++ b/spec/unit/schemas_spec.lua @@ -1,5 +1,5 @@ local schemas = require "kong.dao.schemas_validation" -local validate_fields = schemas.validate_fields +local validate_entity = schemas.validate_entity require "kong.tools.ngx_stub" @@ -7,7 +7,7 @@ describe("Schemas", function() -- Ok kids, today we're gonna test a custom validation schema, -- grab a pair of glasses, this stuff can literally explode. - describe("#validate_fields()", function() + describe("#validate_entity()", function() local schema = { fields = { string = { type = "string", required = true, immutable = true}, @@ -39,7 +39,7 @@ describe("Schemas", function() it("should confirm a valid entity is valid", function() local values = {string = "mockbin entity", url = "mockbin.com"} - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.True(valid) end) @@ -48,7 +48,7 @@ describe("Schemas", function() it("should invalidate entity if required property is missing", function() local values = { url = "mockbin.com" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.False(valid) assert.truthy(err) assert.are.same("string is required", err.string) @@ -60,7 +60,7 @@ describe("Schemas", function() -- Failure local values = { string = "foo", table = "bar" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.False(valid) assert.truthy(err) assert.are.same("table is not a table", err.table) @@ -68,14 +68,14 @@ describe("Schemas", function() -- Success local values = { string = "foo", table = { foo = "bar" }} - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.True(valid) -- Failure local values = { string = 1, table = { foo = "bar" }} - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.False(valid) assert.truthy(err) assert.are.same("string is not a string", err.string) @@ -83,14 +83,14 @@ describe("Schemas", function() -- Success local values = { string = "foo", number = 10 } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.True(valid) -- Success local values = { string = "foo", number = "10" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("number", type(values.number)) @@ -98,7 +98,7 @@ describe("Schemas", function() -- Success local values = { string = "foo", boolean_val = true } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("boolean", type(values.boolean_val)) @@ -106,14 +106,14 @@ describe("Schemas", function() -- Success local values = { string = "foo", boolean_val = "true" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) -- Failure local values = { string = "foo", endpoint = "" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.equal("endpoint is not a url", err.endpoint) @@ -121,28 +121,28 @@ describe("Schemas", function() -- Failure local values = { string = "foo", endpoint = "asdasd" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) -- Failure local values = { string = "foo", endpoint = "http://google.com" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) -- Success local values = { string = "foo", endpoint = "http://google.com/" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.truthy(valid) assert.falsy(err) -- Success local values = { string = "foo", endpoint = "http://google.com/hello/?world=asd" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.truthy(valid) assert.falsy(err) end) @@ -150,7 +150,7 @@ describe("Schemas", function() it("should return error when an invalid boolean value is passed", function() local values = { string = "test", boolean_val = "ciao" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("boolean_val is not a boolean", err.boolean_val) @@ -159,7 +159,7 @@ describe("Schemas", function() it("should not return an error when a true boolean value is passed", function() local values = { string = "test", boolean_val = true } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) end) @@ -167,7 +167,7 @@ describe("Schemas", function() it("should not return an error when a false boolean value is passed", function() local values = { string = "test", boolean_val = false } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) end) @@ -181,7 +181,7 @@ describe("Schemas", function() local values = { id = "123" } - local valid, err = validate_fields(values, s) + local valid, err = validate_entity(values, s) assert.falsy(err) assert.truthy(valid) end) @@ -196,14 +196,14 @@ describe("Schemas", function() -- Success local values = { array = {"hello", "world"} } - local valid, err = validate_fields(values, s) + local valid, err = validate_entity(values, s) assert.True(valid) assert.falsy(err) -- Failure local values = { array = {hello="world"} } - local valid, err = validate_fields(values, s) + local valid, err = validate_entity(values, s) assert.False(valid) assert.truthy(err) assert.equal("array is not a array", err.array) @@ -213,7 +213,7 @@ describe("Schemas", function() it("should not return an error when a `number` is passed as a string", function() local values = { string = "test", number = "10" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.same("number", type(values.number)) @@ -222,7 +222,7 @@ describe("Schemas", function() it("should not return an error when a `boolean` is passed as a string", function() local values = { string = "test", boolean_val = "false" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.same("boolean", type(values.boolean_val)) @@ -238,7 +238,7 @@ describe("Schemas", function() -- It should also strip the resulting strings local values = { array = "hello, world" } - local valid, err = validate_fields(values, s) + local valid, err = validate_entity(values, s) assert.True(valid) assert.falsy(err) assert.same({"hello", "world"}, values.array) @@ -251,7 +251,7 @@ describe("Schemas", function() -- Variables local values = { string = "mockbin entity", url = "mockbin.com" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same(123456, values.date) @@ -259,7 +259,7 @@ describe("Schemas", function() -- Functions local values = { string = "mockbin entity", url = "mockbin.com" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("default", values.default) @@ -269,7 +269,7 @@ describe("Schemas", function() -- Variables local values = { string = "mockbin entity", url = "mockbin.com", date = 654321 } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same(654321, values.date) @@ -277,7 +277,7 @@ describe("Schemas", function() -- Functions local values = { string = "mockbin entity", url = "mockbin.com", default = "abcdef" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) assert.are.same("abcdef", values.default) @@ -289,7 +289,7 @@ describe("Schemas", function() it("should validate a field against a regex", function() local values = { string = "mockbin entity", url = "mockbin_!" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("url has an invalid value", err.url) @@ -301,14 +301,14 @@ describe("Schemas", function() -- Success local values = { string = "somestring", allowed = "hello" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) -- Failure local values = { string = "somestring", allowed = "hello123" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("\"hello123\" is not allowed. Allowed values are: \"hello\", \"world\"", err.allowed) @@ -320,14 +320,14 @@ describe("Schemas", function() -- Success local values = { string = "somestring", custom = true, default = "test_custom_func" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) -- Failure local values = { string = "somestring", custom = true, default = "not the default :O" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("Nah", err.custom) @@ -346,7 +346,7 @@ describe("Schemas", function() it("should call a given function when encountering a field with `dao_insert_value`", function() local values = {string = "hello", id = "0000"} - local valid, err = validate_fields(values, schema, {dao_insert = function(field) + local valid, err = validate_entity(values, schema, {dao_insert = function(field) if field.type == "id" then return "1234" elseif field.type == "timestamp" then @@ -363,7 +363,7 @@ describe("Schemas", function() it("should not raise any error if the function is not given", function() local values = { string = "hello", id = "0000" } - local valid, err = validate_fields(values, schema, { dao_insert = true }) -- invalid type + local valid, err = validate_entity(values, schema, { dao_insert = true }) -- invalid type assert.falsy(err) assert.True(valid) assert.equal("0000", values.id) @@ -375,7 +375,7 @@ describe("Schemas", function() it("should return error when unexpected values are included in the schema", function() local values = { string = "mockbin entity", url = "mockbin.com", unexpected = "abcdef" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) end) @@ -383,7 +383,7 @@ describe("Schemas", function() it("should be able to return multiple errors at once", function() local values = { url = "mockbin.com", unexpected = "abcdef" } - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(valid) assert.truthy(err) assert.are.same("string is required", err.string) @@ -399,7 +399,7 @@ describe("Schemas", function() } assert.has_no_errors(function() - local valid, err = validate_fields({}, schema) + local valid, err = validate_entity({}, schema) assert.False(valid) assert.are.same("property is required", err.property) end) @@ -407,7 +407,7 @@ describe("Schemas", function() describe("Sub-schemas", function() -- To check wether schema_from_function was called, we will simply use booleans because - -- busted's spy methods create tables and metatable magic, but the validate_fields() function + -- busted's spy methods create tables and metatable magic, but the validate_entity() function -- only callse v.schema if the type is a function. Which is not the case with a busted spy. local called, called_with local schema_from_function = function(t) @@ -441,7 +441,7 @@ describe("Schemas", function() -- Success local values = { some_required = "somestring", sub_schema = { sub_field_required = "sub value" }} - local valid, err = validate_fields(values, nested_schema) + local valid, err = validate_entity(values, nested_schema) assert.falsy(err) assert.truthy(valid) assert.are.same("abcd", values.sub_schema.sub_field_default) @@ -449,7 +449,7 @@ describe("Schemas", function() -- Failure local values = { some_required = "somestring", sub_schema = { sub_field_default = "" }} - local valid, err = validate_fields(values, nested_schema) + local valid, err = validate_entity(values, nested_schema) assert.truthy(err) assert.falsy(valid) assert.are.same("sub_field_required is required", err["sub_schema.sub_field_required"]) @@ -464,7 +464,7 @@ describe("Schemas", function() sub_sub_schema = { sub_sub_field_required = "test" } }} - local valid, err = validate_fields(values, nested_schema) + local valid, err = validate_entity(values, nested_schema) assert.falsy(err) assert.truthy(valid) @@ -474,7 +474,7 @@ describe("Schemas", function() sub_sub_schema = {} }} - local valid, err = validate_fields(values, nested_schema) + local valid, err = validate_entity(values, nested_schema) assert.truthy(err) assert.falsy(valid) assert.are.same("sub_sub_field_required is required", err["sub_schema.sub_sub_schema.sub_sub_field_required"]) @@ -487,7 +487,7 @@ describe("Schemas", function() sub_sub_schema = { sub_sub_field_required = "test" } }} - local valid, err = validate_fields(values, nested_schema) + local valid, err = validate_entity(values, nested_schema) assert.falsy(err) assert.truthy(valid) assert.True(called) @@ -502,7 +502,7 @@ describe("Schemas", function() sub_sub_schema = { sub_sub_field_required = "test" } }} - local valid, err = validate_fields(values, nested_schema) + local valid, err = validate_entity(values, nested_schema) assert.truthy(err) assert.falsy(valid) assert.are.same("Error loading the sub-sub-schema", err["sub_schema.sub_sub_schema"]) @@ -523,7 +523,7 @@ describe("Schemas", function() } local obj = {} - local valid, err = validate_fields(obj, schema) + local valid, err = validate_entity(obj, schema) assert.falsy(err) assert.True(valid) assert.are.same("hello", obj.value.some_property) @@ -537,7 +537,7 @@ describe("Schemas", function() } local obj = {} - local valid, err = validate_fields(obj, schema) + local valid, err = validate_entity(obj, schema) assert.truthy(err) assert.False(valid) assert.are.same("value.some_property is required", err.value) @@ -549,7 +549,7 @@ describe("Schemas", function() it("should ignore required properties and defaults if we are updating because the entity might be partial", function() local values = {} - local valid, err = validate_fields(values, schema, {partial_update = true}) + local valid, err = validate_entity(values, schema, {partial_update = true}) assert.falsy(err) assert.True(valid) assert.falsy(values.default) @@ -559,7 +559,7 @@ describe("Schemas", function() it("should still validate set properties", function() local values = { string = 123 } - local valid, err = validate_fields(values, schema, {partial_update = true}) + local valid, err = validate_entity(values, schema, {partial_update = true}) assert.False(valid) assert.equal("string is not a string", err.string) end) @@ -567,7 +567,7 @@ describe("Schemas", function() it("should ignore immutable fields if they are required", function() local values = { string = "somestring" } - local valid, err = validate_fields(values, schema, {partial_update = true}) + local valid, err = validate_entity(values, schema, {partial_update = true}) assert.falsy(err) assert.True(valid) end) @@ -576,12 +576,12 @@ describe("Schemas", function() -- Success local values = {string = "somestring", date = 1234} - local valid, err = validate_fields(values, schema) + local valid, err = validate_entity(values, schema) assert.falsy(err) assert.truthy(valid) -- Failure - local valid, err = validate_fields(values, schema, {partial_update = true}) + local valid, err = validate_entity(values, schema, {partial_update = true}) assert.False(valid) assert.truthy(err) assert.equal("date cannot be updated", err.date) @@ -592,7 +592,7 @@ describe("Schemas", function() it("should not ignore required properties and ignore defaults", function() local values = {} - local valid, err = validate_fields(values, schema, {full_update = true}) + local valid, err = validate_entity(values, schema, {full_update = true}) assert.False(valid) assert.truthy(err) assert.equal("string is required", err.string)