diff --git a/.ci/setup_cassandra.sh b/.ci/setup_cassandra.sh index de57890592a6..fffd68be04bb 100644 --- a/.ci/setup_cassandra.sh +++ b/.ci/setup_cassandra.sh @@ -2,6 +2,18 @@ CASSANDRA_BASE=apache-cassandra-$CASSANDRA_VERSION -sudo rm -rf /var/lib/cassandra/* -curl http://apache.arvixe.com/cassandra/$CASSANDRA_VERSION/$CASSANDRA_BASE-bin.tar.gz | tar xz +n=0 +until [ $n -ge 5 ] +do + sudo rm -rf /var/lib/cassandra/* + curl http://archive.apache.org/dist/cassandra/$CASSANDRA_VERSION/$CASSANDRA_BASE-bin.tar.gz | tar xz && break + n=$[$n+1] + sleep 5 +done + +if [[ ! -f $CASSANDRA_BASE/bin/cassandra ]] ; then + echo 'Failed downloading and unpacking cassandra. Aborting.' + exit 1 +fi + sudo sh $CASSANDRA_BASE/bin/cassandra diff --git a/.ci/setup_kong.sh b/.ci/setup_kong.sh index 06b074144877..87785ca6b6fd 100644 --- a/.ci/setup_kong.sh +++ b/.ci/setup_kong.sh @@ -5,10 +5,10 @@ KONG_VERSION=0.5.0 sudo apt-get update # Installing dependencies required to build development rocks -sudo apt-get install wget curl tar make gcc unzip git liblua5.1-0-dev +sudo apt-get install wget curl tar make gcc unzip git # Installing dependencies required for Kong -sudo apt-get install sudo netcat lua5.1 openssl libpcre3 dnsmasq +sudo apt-get install sudo netcat openssl libpcre3 dnsmasq uuid-dev # Installing Kong and its dependencies sudo apt-get install lsb-release @@ -18,4 +18,4 @@ curl -L -o $KONG_FILE https://github.com/Mashape/kong/releases/download/$KONG_VE sudo dpkg -i $KONG_FILE sudo luarocks remove kong --force -sudo rm -rf /etc/kong \ No newline at end of file +sudo rm -rf /etc/kong diff --git a/.luacheckrc b/.luacheckrc index b5652c1fca3d..1eb3eaf1a7cd 100644 --- a/.luacheckrc +++ b/.luacheckrc @@ -1,6 +1,6 @@ redefined = false unused_args = false -globals = {"ngx", "dao", "app", "configuration"} +globals = {"ngx", "dao", "app", "configuration", "process_id"} files["kong/"] = { std = "luajit" diff --git a/.travis.yml b/.travis.yml index 28ebb51b02d3..044a0bc1c98d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,14 +2,10 @@ language: erlang env: global: - - CASSANDRA_VERSION=2.1.10 + - CASSANDRA_VERSION=2.1.11 matrix: - LUA=lua5.1 -branches: - only: - - master - before_install: - "bash .ci/setup_kong.sh" - "bash .ci/setup_cassandra.sh" diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f1ef02adced..ac8b479406a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ ## [Unreleased][unreleased] +### Breaking changes + +- Drop the Lua 5.1 dependency which was only used for Kong's CLI. The CLI now runs against LuaJIT, which is consistent with other Kong components (Luarocks and OpenResty) already relying on LuaJIT. Make sure the LuaJIT interpreter is included in your `$PATH`. [#799](https://github.com/Mashape/kong/pull/799) + +### Added + +- A new `total` field in API responses, that counts the total number of entities in the response body. [#635](https://github.com/Mashape/kong/pull/635) +- Dnsmasq is now optional. You can specify a custom DNS resolver address that Kong will use when resolving hostnames. This can be configured in `kong.yml`. [#625](https://github.com/Mashape/kong/pull/635) + +### Changed + +- Disable access logs for `/status` endpoint. +- The `/status` endpoint now includes `database` statistics, while the previous stats have been moved to a `server` field. [#635](https://github.com/Mashape/kong/pull/635) + +### Fixed + +- In the Admin API responses, the `next` link is not being displayed anymore if there are no more entities to be returned. [#635](https://github.com/Mashape/kong/pull/635) + ## [0.5.4] - 2015/12/03 ### Fixed @@ -14,8 +32,8 @@ ### Fixed - Avoids additional URL encoding when proxying to an upstream service. [#691](https://github.com/Mashape/kong/pull/691) -- Fixing potential timing comparison bug in HMAC plugin. [#704](https://github.com/Mashape/kong/pull/704) -- Fixed a missing "env" statement in the Nginx configuration. [#706](https://github.com/Mashape/kong/pull/706) +- Potential timing comparison bug in HMAC plugin. [#704](https://github.com/Mashape/kong/pull/704) +- A missing "env" statement in the Nginx configuration. [#706](https://github.com/Mashape/kong/pull/706) ### Added @@ -73,7 +91,6 @@ Several breaking changes are introduced. You will have to slightly change your c - `strip_path` -> `strip_request_path` - `target_url` -> `upstream_url` - `plugins_configurations` have been renamed to `plugins`, and their `value` property has been renamed to `config` to avoid confusions. [#513](https://github.com/Mashape/kong/issues/513) ->>>>>>> dbocs(changelog) 0.5.0 changes - The database schema has been updated to handle the separation of plugins outside of the core repository. - The Key authentication and Basic authentication plugins routes have changed: diff --git a/bin/kong b/bin/kong index d34526c1fba8..589b197fbc8a 100755 --- a/bin/kong +++ b/bin/kong @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit -- Kong CLI entry-point (bin/kong). -- @@ -49,7 +49,5 @@ elseif not commands[cmd] then os.exit(1) end -require "kong.tools.ngx_stub" - -- Load and execute desired command require(commands[cmd]) diff --git a/circle.yml b/circle.yml index c7508cb25b12..0560db0364c1 100644 --- a/circle.yml +++ b/circle.yml @@ -10,6 +10,9 @@ dependencies: - sudo make dev test: override: - - busted -o spec/busted-print.lua --coverage spec/ + - busted -v -o spec/busted-print.lua --coverage : + parallel: true + files: + - spec/**/*_spec.lua post: - make lint diff --git a/config.ld b/config.ld deleted file mode 100644 index 8484b69582f9..000000000000 --- a/config.ld +++ /dev/null @@ -1,21 +0,0 @@ --- LDoc configuration file --- See: https://github.com/stevedonovan/LDoc --- --- LDoc installation; --- luarocks install ldoc --- --- Generate the documentation from the Kong codebase; --- ldoc . --- - -project='Kong' -title='Kong by Mashape' -description='Kong manages Microservices & APIs in a secure and extensible way' -format='discount' -file={'./kong/', './bin/kong'} -dir='doc' -readme='readme.md' -sort=true -sort_modules=true -not_luadoc=true -all=false diff --git a/kong-0.5.4-1.rockspec b/kong-0.5.4-1.rockspec index ce2f5a46d76c..1a60a78d3bd7 100644 --- a/kong-0.5.4-1.rockspec +++ b/kong-0.5.4-1.rockspec @@ -14,12 +14,13 @@ dependencies = { "lua ~> 5.1", "luasec ~> 0.5-2", - "uuid ~> 0.2-1", + "lua_uuid ~> 0.1-8", + "lua_system_constants ~> 0.1-3", "luatz ~> 0.3-1", - "yaml ~> 1.1.1-1", - "lapis ~> 1.1.0-1", + "yaml ~> 1.1.2-1", + "lapis ~> 1.3.1-1", "stringy ~> 0.4-1", - "lua-cassandra ~> 0.3.6-0", + "lua-cassandra ~> 0.4.1-0", "multipart ~> 0.2-1", "lua-path ~> 0.2.3-1", "lua-cjson ~> 2.1.0-1", @@ -27,10 +28,11 @@ dependencies = { "lbase64 ~> 20120820-1", "lua-resty-iputils ~> 0.2.0-1", - "luasocket ~> 2.0.2-5", + "luasocket ~> 2.0.2-6", "lrexlib-pcre ~> 2.7.2-1", "lua-llthreads2 ~> 0.1.3-1", - "luacrypto >= 0.3.2-1" + "luacrypto >= 0.3.2-1", + "luasyslog >= 1.0.0-2" } build = { type = "builtin", @@ -44,7 +46,7 @@ build = { ["kong.constants"] = "kong/constants.lua", - ["kong.cli.utils"] = "kong/cli/utils/utils.lua", + ["kong.cli.utils"] = "kong/cli/utils.lua", ["kong.cli.utils.dnsmasq"] = "kong/cli/utils/dnsmasq.lua", ["kong.cli.utils.ssl"] = "kong/cli/utils/ssl.lua", ["kong.cli.utils.signal"] = "kong/cli/utils/signal.lua", @@ -70,15 +72,15 @@ build = { ["kong.tools.migrations"] = "kong/tools/migrations.lua", ["kong.tools.http_client"] = "kong/tools/http_client.lua", ["kong.tools.database_cache"] = "kong/tools/database_cache.lua", + ["kong.tools.config_defaults"] = "kong/tools/config_defaults.lua", + ["kong.tools.config_loader"] = "kong/tools/config_loader.lua", + ["kong.tools.dao_loader"] = "kong/tools/dao_loader.lua", - ["kong.resolver.handler"] = "kong/resolver/handler.lua", - ["kong.resolver.access"] = "kong/resolver/access.lua", - ["kong.resolver.header_filter"] = "kong/resolver/header_filter.lua", - ["kong.resolver.certificate"] = "kong/resolver/certificate.lua", - - ["kong.reports.handler"] = "kong/reports/handler.lua", - ["kong.reports.init_worker"] = "kong/reports/init_worker.lua", - ["kong.reports.log"] = "kong/reports/log.lua", + ["kong.core.handler"] = "kong/core/handler.lua", + ["kong.core.certificate"] = "kong/core/certificate.lua", + ["kong.core.resolver"] = "kong/core/resolver.lua", + ["kong.core.plugins_iterator"] = "kong/core/plugins_iterator.lua", + ["kong.core.reports"] = "kong/core/reports.lua", ["kong.dao.cassandra.schema.migrations"] = "kong/dao/cassandra/schema/migrations.lua", ["kong.dao.error"] = "kong/dao/error.lua", @@ -220,7 +222,15 @@ build = { ["kong.plugins.hmac-auth.access"] = "kong/plugins/hmac-auth/access.lua", ["kong.plugins.hmac-auth.schema"] = "kong/plugins/hmac-auth/schema.lua", ["kong.plugins.hmac-auth.api"] = "kong/plugins/hmac-auth/api.lua", - ["kong.plugins.hmac-auth.daos"] = "kong/plugins/hmac-auth/daos.lua" + ["kong.plugins.hmac-auth.daos"] = "kong/plugins/hmac-auth/daos.lua", + + ["kong.plugins.syslog.handler"] = "kong/plugins/syslog/handler.lua", + ["kong.plugins.syslog.log"] = "kong/plugins/syslog/log.lua", + ["kong.plugins.syslog.schema"] = "kong/plugins/syslog/schema.lua", + + ["kong.plugins.loggly.handler"] = "kong/plugins/loggly/handler.lua", + ["kong.plugins.loggly.log"] = "kong/plugins/loggly/log.lua", + ["kong.plugins.loggly.schema"] = "kong/plugins/loggly/schema.lua" }, install = { conf = { "kong.yml" }, diff --git a/kong.yml b/kong.yml index c86d7f9e6afd..5dae1d966c85 100644 --- a/kong.yml +++ b/kong.yml @@ -1,70 +1,135 @@ -## Available plugins on this server -plugins_available: - - ssl - - jwt - - acl - - cors - - oauth2 - - tcp-log - - udp-log - - file-log - - http-log - - key-auth - - hmac-auth - - basic-auth - - ip-restriction - - mashape-analytics - - request-transformer - - response-transformer - - request-size-limiting - - rate-limiting - - response-ratelimiting - -## The Kong working directory -## (Make sure you have read and write permissions) -nginx_working_dir: /usr/local/kong/ - -## Port configuration -proxy_port: 8000 -proxy_ssl_port: 8443 -admin_api_port: 8001 - -## Secondary port configuration -dnsmasq_port: 8053 - -## Specify the DAO to use -database: cassandra - -## Databases configuration -databases_available: - cassandra: - properties: - contact_points: - - "localhost:9042" - timeout: 1000 - keyspace: kong - keepalive: 60000 # in milliseconds - # ssl: false - # ssl_verify: false - # ssl_certificate: "/path/to/cluster-ca-certificate.pem" - # user: cassandra - # password: cassandra - -## Cassandra cache configuration -database_cache_expiration: 5 # in seconds - -## SSL Settings -## (Uncomment the two properties below to set your own certificate) +###### +## Kong configuration file. All commented values are default values. +## Uncomment and update a value to configure Kong to your needs. +## +## Lines starting with `##` are comments. +## Lines starting with `#` are properties that can be updated. +## Beware of YAML formatting for nested properties. + +###### +## Plugins that this node needs to execute. +## By default, Kong will try to execute all installed plugins on every request. +## If you are sure to only use a few plugins, uncomment and update this property to contain +## only those. +## Custom plugins also need to be added to this list. +# plugins_available: +# - ssl +# - jwt +# - ... + +###### +## The Kong working directory. Equivalent to nginx's prefix path. +## This is where this running nginx instance will keep server files including logs. +## Make sure it has the appropriate permissions. +# nginx_working_dir: /usr/local/kong/ + +###### +## Port which Kong proxies HTTP requests through, consumers will make requests against this port +## so make sure it is publicly available. +# proxy_port: 8000 + +###### +## Same as proxy_port, but for HTTPS requests. +# proxy_ssl_port: 8443 + +###### +## Specify how Kong performs DNS resolution (in the `dns_resolvers_available` property) you want to use. +## Options are: "dnsmasq" (You will need dnsmasq to be installed) or "server". +# dns_resolver: dnsmasq + +###### +## DNS resolvers configuration. Specify a DNS server or the port on which you want +## dnsmasq to run. +# dns_resolvers_available: + # server: + # address: "8.8.8.8:53" + # dnsmasq: + # port: 8053 + +###### +## Port on which the admin API will listen to. The admin API is a private API which lets you +## manage your Kong infrastructure. It needs to be secured appropriatly. +# admin_api_port: 8001 + +###### +## Specify which database to use from the databases_available property. +# database: cassandra + +###### +## Databases configuration. +# databases_available: + # cassandra: + ###### + ## Contact points to your Cassandra cluster. + # contact_points: + # - "localhost:9042" + + ###### + ## Name of the keyspace used by Kong. Will be created if it does not exist. + # keyspace: kong + + ###### + ## Keyspace options. Set those before running Kong or any migration. + ## Those settings will be used to create a keyspace with the desired options + ## when first running the migrations. + ## See http://docs.datastax.com/en/cql/3.1/cql/cql_reference/create_keyspace_r.html + ###### + ## The name of the replica placement strategy class for the keyspace. + ## Can be "SimpleStrategy" or "NetworkTopologyStrategy". + # replication_strategy: SimpleStrategy + ###### + ## For SimpleStrategy only. + ## The number of replicas of data on multiple nodes. + # replication_factor: 1 + ###### + ## For NetworkTopologyStrategy only. + ## The number of replicas of data on multiple nodes in each data center. + # data_centers: + # dc1: 2 + # dc2: 3 + + ##### + ## Client-to-node TLS options. + ## `enabled`: if true, will connect to your Cassandra instance using TLS. + ## `verify`: if true, will verify the server certificate using the given CA file. + ## `certificate_authority`: an absolute path to the trusted CA certificate in PEM format used to verify the server certificate. + ## For additional SSL settings, see the ngx_lua `lua_ssl_*` directives. + # ssl: + # enabled: false + # verify: false + # certificate_authority: "/path/to/cluster-ca-certificate.pem" + + ###### + ## Cluster authentication options. Provide a user and a password here if your cluster uses the + ## PasswordAuthenticator scheme. + # user: cassandra + # password: cassandra + +###### +## Time (in seconds) for which entities from the database (APIs, plugins configurations...) +## are cached by Kong. Increase this value if you want to lower the number of requests made +## to your database. +# database_cache_expiration: 5 + +###### +## SSL certificates to use. # ssl_cert_path: /path/to/certificate.pem # ssl_key_path: /path/to/certificate.key -## Sends anonymous error reports -send_anonymous_reports: true +###### +## Sending anonymous error reports helps Kong developers to understand how it performs. +# send_anonymous_reports: true -## In-memory cache size (MB) -memory_cache_size: 128 +###### +## Size (in MB) of the Lua cache. This value may not be smaller than 32MB. +# memory_cache_size: 128 -## Nginx configuration +###### +## The nginx configuration file which allows Kong to run. +## The placeholders will be computed and this property will be written as a file +## by Kong at `/nginx.conf` during startup. +## This file can tweaked to some extent, but many directives are necessary for Kong to work. +## /!\ BE CAREFUL nginx: | worker_processes auto; error_log logs/error.log error; @@ -122,6 +187,8 @@ nginx: | lua_max_pending_timers 16384; lua_shared_dict locks 100k; lua_shared_dict cache {{memory_cache_size}}m; + lua_shared_dict cassandra 1m; + lua_shared_dict cassandra_prepared 5m; lua_socket_log_errors off; {{lua_ssl_trusted_certificate}} @@ -151,18 +218,19 @@ nginx: | default_type 'text/plain'; # These properties will be used later by proxy_pass - set $backend_host nil; - set $backend_url nil; + set $upstream_host nil; + set $upstream_url nil; # Authenticate the user and load the API info access_by_lua 'kong.exec_plugins_access()'; + # Proxy the request # Proxy the request proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; - proxy_set_header Host $backend_host; - proxy_pass $backend_url; + proxy_set_header Host $upstream_host; + proxy_pass $upstream_url; proxy_pass_header Server; # Add additional response headers @@ -192,6 +260,9 @@ nginx: | server { listen {{admin_api_port}}; + client_max_body_size 10m; + client_body_buffer_size 10m; + location / { default_type application/json; content_by_lua ' @@ -208,6 +279,7 @@ nginx: | location /nginx_status { internal; + access_log off; stub_status; } diff --git a/kong/api/crud_helpers.lua b/kong/api/crud_helpers.lua index 9a394e4d220a..1ac78665832d 100644 --- a/kong/api/crud_helpers.lua +++ b/kong/api/crud_helpers.lua @@ -51,22 +51,36 @@ function _M.paginated_set(self, dao_collection) return app_helpers.yield_error(err) end + local total, err = dao_collection:count_by_keys(self.params) + if err then + return app_helpers.yield_error(err) + end + local next_url if data.next_page then - next_url = self:build_url(self.req.parsed_url.path, { - port = self.req.parsed_url.port, - query = ngx.encode_args({ - offset = ngx.encode_base64(data.next_page), - size = size - }) - }) + -- Parse next URL, if there are no elements then don't append it + local next_total, err = dao_collection:count_by_keys(self.params, data.next_page) + if err then + return app_helpers.yield_error(err) + end + + if next_total > 0 then + next_url = self:build_url(self.req.parsed_url.path, { + port = self.req.parsed_url.port, + query = ngx.encode_args({ + offset = ngx.encode_base64(data.next_page), + size = size + }) + }) + end + data.next_page = nil end -- This check is required otherwise the response is going to be a -- JSON Object and not a JSON array. The reason is because an empty Lua array `{}` -- will not be translated as an empty array by cjson, but as an empty object. - local result = #data == 0 and "{\"data\":[]}" or {data=data, ["next"]=next_url} + local result = #data == 0 and "{\"data\":[],\"total\":0}" or {data=data, ["next"]=next_url, total=total} return responses.send_HTTP_OK(result, type(result) ~= "table") end diff --git a/kong/api/routes/kong.lua b/kong/api/routes/kong.lua index 878b7af7a517..fbd5e65557f0 100644 --- a/kong/api/routes/kong.lua +++ b/kong/api/routes/kong.lua @@ -25,7 +25,21 @@ return { GET = function(self, dao, helpers) local res = ngx.location.capture("/nginx_status") if res.status == 200 then - return helpers.responses.send_HTTP_OK(route_helpers.parse_status(res.body)) + + local status_response = { + server = route_helpers.parse_status(res.body), + database = {} + } + + for k, v in pairs(dao.daos) do + local count, err = v:count_by_keys() + if err then + return helpers.responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + end + status_response.database[k] = count + end + + return helpers.responses.send_HTTP_OK(status_response) else return helpers.responses.send_HTTP_INTERNAL_SERVER_ERROR(res.body) end diff --git a/kong/cli/config.lua b/kong/cli/config.lua index a3f5dfe0fcd4..6d749284f620 100644 --- a/kong/cli/config.lua +++ b/kong/cli/config.lua @@ -1,8 +1,9 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local constants = require "kong.constants" local cutils = require "kong.cli.utils" local IO = require "kong.tools.io" +local yaml = require "yaml" local args = require("lapp")(string.format([[ For development purposes only. @@ -20,40 +21,81 @@ local CONFIG_FILENAME = string.format("kong%s.yml", args.env ~= "" and "_"..args local config_path = cutils.get_kong_config_path(args.config) local config_content = IO.read_file(config_path) +local default_config = yaml.load(config_content) +local env = args.env:upper() local DEFAULT_ENV_VALUES = { TEST = { - ["nginx_working_dir: /usr/local/kong/"] = "nginx_working_dir: nginx_tmp", - ["send_anonymous_reports: true"] = "send_anonymous_reports: false", - ["keyspace: kong"] = "keyspace: kong_tests", - ["lua_package_path ';;'"] = "lua_package_path './kong/?.lua;;'", - ["error_log logs/error.log error"] = "error_log logs/error.log debug", - ["proxy_port: 8000"] = "proxy_port: 8100", - ["proxy_ssl_port: 8443"] = "proxy_ssl_port: 8543", - ["admin_api_port: 8001"] = "admin_api_port: 8101", - ["dnsmasq_port: 8053"] = "dnsmasq_port: 8153", - ["access_log off"] = "access_log on" + yaml = { + ["nginx_working_dir"] = "nginx_tmp", + ["send_anonymous_reports"] = false, + ["proxy_port"] = 8100, + ["proxy_ssl_port"] = 8543, + ["admin_api_port"] = 8101, + ["dnsmasq_port"] = 8153, + ["databases_available"] = { + ["cassandra"] = { + ["keyspace"] = "kong_tests" + } + } + }, + nginx = { + ["error_log logs/error.log error"] = "error_log logs/error.log debug", + ["lua_package_path ';;'"] = "lua_package_path './kong/?.lua;;'", + ["access_log off"] = "access_log on" + } }, DEVELOPMENT = { - ["nginx_working_dir: /usr/local/kong/"] = "nginx_working_dir: nginx_tmp", - ["send_anonymous_reports: true"] = "send_anonymous_reports: false", - ["keyspace: kong"] = "keyspace: kong_development", - ["lua_package_path ';;'"] = "lua_package_path './kong/?.lua;;'", - ["error_log logs/error.log error"] = "error_log logs/error.log debug", - ["lua_code_cache on"] = "lua_code_cache off", - ["access_log off"] = "access_log on" + yaml = { + ["databases_available"] = { + ["cassandra"] = { + ["keyspace"] = "kong_development" + } + } + }, + nginx = { + ["nginx_working_dir: /usr/local/kong/"] = "nginx_working_dir: nginx_tmp", + ["send_anonymous_reports: true"] = "send_anonymous_reports: false", + ["lua_package_path ';;'"] = "lua_package_path './kong/?.lua;;'", + ["error_log logs/error.log error"] = "error_log logs/error.log debug", + ["lua_code_cache on"] = "lua_code_cache off", + ["access_log off"] = "access_log on" + } } } --- Create a new default kong config for given environment -if DEFAULT_ENV_VALUES[args.env:upper()] then - -- If we know the environment we can override some of the variables - for k, v in pairs(DEFAULT_ENV_VALUES[args.env:upper()]) do - config_content = config_content:gsub(k, v) - end +if not DEFAULT_ENV_VALUES[args.env:upper()] then + cutils.error_exit(string.format("Unregistered environment '%s'", args.env:upper())) end -local ok, err = IO.write_to_file(IO.path:join(args.output, CONFIG_FILENAME), config_content) +-- Create the new configuration as a new blank object +local new_config = {} + +-- Populate with overriden values +for k, v in pairs(DEFAULT_ENV_VALUES[env].yaml) do + new_config[k] = v +end + +-- Dump into a string +local new_config_content = yaml.dump(new_config) + +-- Replace nginx directives +local nginx_config = default_config.nginx +for k, v in pairs(DEFAULT_ENV_VALUES[env].nginx) do + nginx_config = nginx_config:gsub(k, v) +end + +-- Indent nginx configuration +nginx_config = nginx_config:gsub("[^\r\n]+", function(line) + return " "..line +end) + +-- Manually add the string (can't do that before yaml.dump as it messes the formatting) +new_config_content = new_config_content..[[ +nginx: | +]]..nginx_config + +local ok, err = IO.write_to_file(IO.path:join(args.output, CONFIG_FILENAME), new_config_content) if not ok then - error(err) + cutils.error_exit(err) end diff --git a/kong/cli/db.lua b/kong/cli/db.lua index ab13315cf25f..8862d4baf62e 100644 --- a/kong/cli/db.lua +++ b/kong/cli/db.lua @@ -1,9 +1,10 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local Faker = require "kong.tools.faker" local constants = require "kong.constants" local cutils = require "kong.cli.utils" -local IO = require "kong.tools.io" +local config = require "kong.tools.config_loader" +local dao = require "kong.tools.dao_loader" local lapp = require("lapp") local args = lapp(string.format([[ @@ -29,7 +30,8 @@ if args.command == "db" then end local config_path = cutils.get_kong_config_path(args.config) -local _, dao_factory = IO.load_configuration_and_dao(config_path) +local config = config.load(config_path) +local dao_factory = dao.load(config) if args.command == "seed" then diff --git a/kong/cli/migrations.lua b/kong/cli/migrations.lua index 9c236e1b4f79..63f570fa8077 100644 --- a/kong/cli/migrations.lua +++ b/kong/cli/migrations.lua @@ -1,11 +1,12 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local Migrations = require "kong.tools.migrations" local constants = require "kong.constants" local cutils = require "kong.cli.utils" local utils = require "kong.tools.utils" local input = require "kong.cli.utils.input" -local IO = require "kong.tools.io" +local config = require "kong.tools.config_loader" +local dao = require "kong.tools.dao_loader" local lapp = require "lapp" local args = lapp(string.format([[ Kong datastore migrations. @@ -28,7 +29,8 @@ if args.command == "migrations" then end local config_path = cutils.get_kong_config_path(args.config) -local configuration, dao_factory = IO.load_configuration_and_dao(config_path) +local configuration = config.load(config_path) +local dao_factory = dao.load(configuration) local migrations = Migrations(dao_factory, configuration) local kind = args.type @@ -47,7 +49,7 @@ if args.command == "list" then elseif migrations then cutils.logger:info(string.format( "Executed migrations for keyspace %s (%s):", - cutils.colors.yellow(dao_factory._properties.keyspace), + cutils.colors.yellow(dao_factory.properties.keyspace), dao_factory.type )) @@ -61,7 +63,7 @@ if args.command == "list" then cutils.logger:info(string.format( "No migrations have been run yet for %s on keyspace: %s", cutils.colors.yellow(dao_factory.type), - cutils.colors.yellow(dao_factory._properties.keyspace) + cutils.colors.yellow(dao_factory.properties.keyspace) )) end @@ -71,7 +73,7 @@ elseif args.command == "up" then cutils.logger:info(string.format( "Migrating %s on keyspace \"%s\" (%s)", cutils.colors.yellow(identifier), - cutils.colors.yellow(dao_factory._properties.keyspace), + cutils.colors.yellow(dao_factory.properties.keyspace), dao_factory.type )) end @@ -108,7 +110,7 @@ elseif args.command == "down" then cutils.logger:info(string.format( "Rollbacking %s in keyspace \"%s\" (%s)", cutils.colors.yellow(identifier), - cutils.colors.yellow(dao_factory._properties.keyspace), + cutils.colors.yellow(dao_factory.properties.keyspace), dao_factory.type )) end @@ -128,7 +130,7 @@ elseif args.command == "down" then elseif args.command == "reset" then - local keyspace = dao_factory._properties.keyspace + local keyspace = dao_factory.properties.keyspace cutils.logger:info(string.format( "Resetting \"%s\" keyspace (%s)", diff --git a/kong/cli/quit.lua b/kong/cli/quit.lua index 9393e80d58b7..957487d5c5f4 100644 --- a/kong/cli/quit.lua +++ b/kong/cli/quit.lua @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local constants = require "kong.constants" local cutils = require "kong.cli.utils" diff --git a/kong/cli/reload.lua b/kong/cli/reload.lua index 21c10d8159e7..100133dee117 100644 --- a/kong/cli/reload.lua +++ b/kong/cli/reload.lua @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local constants = require "kong.constants" local cutils = require "kong.cli.utils" diff --git a/kong/cli/restart.lua b/kong/cli/restart.lua index c347cb3c6381..cb316dff0cd1 100644 --- a/kong/cli/restart.lua +++ b/kong/cli/restart.lua @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local constants = require "kong.constants" local cutils = require "kong.cli.utils" diff --git a/kong/cli/start.lua b/kong/cli/start.lua index eabbb601fa68..1f6755d317b5 100755 --- a/kong/cli/start.lua +++ b/kong/cli/start.lua @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local constants = require "kong.constants" local cutils = require "kong.cli.utils" diff --git a/kong/cli/stop.lua b/kong/cli/stop.lua index 2efe38b712b4..e7f7183a15b4 100755 --- a/kong/cli/stop.lua +++ b/kong/cli/stop.lua @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local constants = require "kong.constants" local cutils = require "kong.cli.utils" diff --git a/kong/cli/utils/utils.lua b/kong/cli/utils.lua similarity index 86% rename from kong/cli/utils/utils.lua rename to kong/cli/utils.lua index 0dd1e7abf391..88d5c349b94c 100644 --- a/kong/cli/utils/utils.lua +++ b/kong/cli/utils.lua @@ -119,12 +119,19 @@ local function get_kong_config_path(args_config) return config_path end --- Checks if a port is open on localhost +-- Checks if a port is available to bind a server to on localhost -- @param `port` The port to check --- @return `open` True if open, false otherwise -local function is_port_open(port) - local _, code = IO.os_execute("nc -w 5 -z 127.0.0.1 "..tostring(port)) - return code == 0 +-- @return `open` Truthy if available, falsy + error otherwise +local function is_port_bindable(port) + local server, success, err + server = require("socket").tcp() + server:setoption('reuseaddr', true) + success, err = server:bind("*", port) + if success then + success, err = server:listen() + end + server:close() + return success, err end return { @@ -133,5 +140,5 @@ return { get_kong_infos = get_kong_infos, get_kong_config_path = get_kong_config_path, get_luarocks_install_dir = get_luarocks_install_dir, - is_port_open = is_port_open + is_port_bindable = is_port_bindable } diff --git a/kong/cli/utils/dnsmasq.lua b/kong/cli/utils/dnsmasq.lua index b06b17be7b81..eb6ddfdbd870 100644 --- a/kong/cli/utils/dnsmasq.lua +++ b/kong/cli/utils/dnsmasq.lua @@ -13,7 +13,7 @@ function _M.stop(kong_config) end end -function _M.start(kong_config) +function _M.start(nginx_working_dir, dnsmasq_port) local cmd = IO.cmd_exists("dnsmasq") and "dnsmasq" if not cmd then -- Load dnsmasq given the PATH settings @@ -32,8 +32,8 @@ function _M.start(kong_config) end -- Start the dnsmasq daemon - local file_pid = kong_config.nginx_working_dir..(stringy.endswith(kong_config.nginx_working_dir, "/") and "" or "/")..constants.CLI.DNSMASQ_PID - local res, code = IO.os_execute(cmd.." -p "..kong_config.dnsmasq_port.." --pid-file="..file_pid.." -N -o") + local file_pid = nginx_working_dir..(stringy.endswith(nginx_working_dir, "/") and "" or "/")..constants.CLI.DNSMASQ_PID + local res, code = IO.os_execute(cmd.." -p "..dnsmasq_port.." --pid-file="..file_pid.." -N -o") if code ~= 0 then cutils.logger:error_exit(res) else diff --git a/kong/cli/utils/signal.lua b/kong/cli/utils/signal.lua index cab609366141..56249d8636d7 100644 --- a/kong/cli/utils/signal.lua +++ b/kong/cli/utils/signal.lua @@ -9,11 +9,11 @@ local constants = require "kong.constants" local syslog = require "kong.tools.syslog" local socket = require "socket" local dnsmasq = require "kong.cli.utils.dnsmasq" +local config = require "kong.tools.config_loader" +local dao = require "kong.tools.dao_loader" -- Cache config path, parsed config and DAO factory -local kong_config_path -local kong_config -local dao_factory +local kong_config_path, kong_config -- Retrieve the desired Kong config file, parse it and provides a DAO factory -- Will cache them for future retrieval @@ -28,9 +28,9 @@ local function get_kong_config(args_config) cutils.logger:info("Using configuration: "..kong_config_path) end if not kong_config then - kong_config, dao_factory = IO.load_configuration_and_dao(kong_config_path) + kong_config = config.load(kong_config_path) end - return kong_config, kong_config_path, dao_factory + return kong_config, kong_config_path end -- Check if an executable (typically `nginx`) is a distribution of openresty @@ -77,33 +77,20 @@ local function prepare_nginx_working_dir(args_config) if err then cutils.logger:error_exit(err) end + -- Create logs files os.execute("touch "..IO.path:join(kong_config.nginx_working_dir, "logs", "error.log")) os.execute("touch "..IO.path:join(kong_config.nginx_working_dir, "logs", "access.log")) + -- Create SSL folder if needed local _, err = IO.path:mkdir(IO.path:join(kong_config.nginx_working_dir, "ssl")) if err then cutils.logger:error_exit(err) end - -- TODO: this is NOT the place to do this. - -- @see https://github.com/Mashape/kong/issues/92 for configuration validation/defaults - -- @see https://github.com/Mashape/kong/issues/217 for a better configuration file - - -- Check memory cache - if kong_config.memory_cache_size then - if tonumber(kong_config.memory_cache_size) == nil then - cutils.logger:error_exit("Invalid \"memory_cache_size\" setting") - elseif tonumber(kong_config.memory_cache_size) < 32 then - cutils.logger:error_exit("Invalid \"memory_cache_size\" setting: needs to be at least 32") - end - else - kong_config.memory_cache_size = 128 -- Default value - cutils.logger:warn("Setting \"memory_cache_size\" to default 128MB") - end ssl.prepare_ssl(kong_config) local ssl_cert_path, ssl_key_path = ssl.get_ssl_cert_and_key(kong_config) - local trusted_ssl_cert_path = kong_config.databases_available[kong_config.database].properties.ssl_certificate -- DAO ssl cert + local trusted_ssl_cert_path = kong_config.dao_config.ssl_certificate -- DAO ssl cert -- Extract nginx config from kong config, replace any needed value local nginx_config = kong_config.nginx @@ -111,7 +98,7 @@ local function prepare_nginx_working_dir(args_config) proxy_port = kong_config.proxy_port, proxy_ssl_port = kong_config.proxy_ssl_port, admin_api_port = kong_config.admin_api_port, - dns_resolver = "127.0.0.1:"..kong_config.dnsmasq_port, + dns_resolver = kong_config.dns_resolver.address, memory_cache_size = kong_config.memory_cache_size, ssl_cert = ssl_cert_path, ssl_key = ssl_key_path, @@ -161,7 +148,8 @@ end -- Prepare the database keyspace if needed (run schema migrations) -- @param args_config Path to the desired configuration (usually from the --config CLI argument) local function prepare_database(args_config) - local kong_config, _, dao_factory = get_kong_config(args_config) + local kong_config = get_kong_config(args_config) + local dao_factory = dao.load(kong_config) local migrations = require("kong.tools.migrations")(dao_factory, kong_config) local keyspace_exists, err = dao_factory.migrations:keyspace_exists() @@ -175,7 +163,7 @@ local function prepare_database(args_config) cutils.logger:info(string.format( "Migrating %s on keyspace \"%s\" (%s)", cutils.colors.yellow(identifier), - cutils.colors.yellow(dao_factory._properties.keyspace), + cutils.colors.yellow(dao_factory.properties.keyspace), dao_factory.type )) end @@ -213,7 +201,7 @@ _M.QUIT = QUIT function _M.prepare_kong(args_config, signal) local kong_config = get_kong_config(args_config) - local dao_config = kong_config.databases_available[kong_config.database].properties + local dao_config = kong_config.dao_config local printable_mt = require "kong.tools.printable" setmetatable(dao_config, printable_mt) @@ -223,14 +211,14 @@ function _M.prepare_kong(args_config, signal) Proxy HTTP port....%s Proxy HTTPS port...%s Admin API port.....%s - dnsmasq port.......%s + DNS resolver.......%s Database...........%s %s ]], constants.VERSION, kong_config.proxy_port, kong_config.proxy_ssl_port, kong_config.admin_api_port, - kong_config.dnsmasq_port, + kong_config.dns_resolver.address, kong_config.database, tostring(dao_config))) @@ -239,9 +227,25 @@ function _M.prepare_kong(args_config, signal) prepare_nginx_working_dir(args_config, signal) end -local function check_port(port) - if cutils.is_port_open(port) then - cutils.logger:error_exit("Port "..tostring(port).." is already being used by another process.") +-- Checks whether a port is available. Exits the application if not available. +-- @param port The port to check +-- @param name Functional name the port is used for (display name) +-- @param timeout (optional) Timeout in seconds after which a failure is logged +-- and application exit is performed, if not provided then it will fail at once without retries. +local function check_port(port, name, timeout) + local expire = socket.gettime() + (timeout or 0) + local msg = tostring(port) .. (name and " ("..tostring(name)..")") + local warned + while not cutils.is_port_bindable(port) do + if expire <= socket.gettime() then + cutils.logger:error_exit("Port "..msg.." is being blocked by another process.") + else + if not warned then + cutils.logger:warn("Port "..msg.." is unavailable, retrying for "..tostring(timeout).." seconds") + warned = true + end + end + socket.sleep(0.5) end end @@ -253,6 +257,7 @@ end -- @return A boolean: true for success, false otherwise function _M.send_signal(args_config, signal) -- Make sure nginx is there and is openresty + local port_timeout = 1 -- OPT: make timeout configurable (note: this is a blocking timeout!) local nginx_path = find_nginx() if not nginx_path then cutils.logger:error_exit(string.format("Kong cannot find an 'nginx' executable.\nMake sure it is in your $PATH or in one of the following directories:\n%s", table.concat(NGINX_SEARCH_PATHS, "\n"))) @@ -262,9 +267,13 @@ function _M.send_signal(args_config, signal) if not signal then signal = START end if signal == START then - local ports = { kong_config.proxy_port, kong_config.proxy_ssl_port, kong_config.admin_api_port } - for _,port in ipairs(ports) do - check_port(port) + local ports = { + ["Kong proxy"] = kong_config.proxy_port, + ["Kong proxy ssl"] = kong_config.proxy_ssl_port, + ["Kong admin api"] = kong_config.admin_api_port + } + for name, port in pairs(ports) do + check_port(port, name, port_timeout) end end @@ -280,8 +289,11 @@ function _M.send_signal(args_config, signal) -- dnsmasq start/stop if signal == START then dnsmasq.stop(kong_config) - check_port(kong_config.dnsmasq_port) - dnsmasq.start(kong_config) + if kong_config.dns_resolver.dnsmasq then + local dnsmasq_port = kong_config.dns_resolver.port + check_port(dnsmasq_port, "dnsmasq", port_timeout) + dnsmasq.start(kong_config.nginx_working_dir, dnsmasq_port) + end elseif signal == STOP or signal == QUIT then dnsmasq.stop(kong_config) end @@ -290,7 +302,7 @@ function _M.send_signal(args_config, signal) if signal == START or signal == RESTART or signal == RELOAD then local res, code = IO.os_execute("ulimit -n") if code == 0 and tonumber(res) < 4096 then - cutils.logger:warn("ulimit is currently set to \""..res.."\". For better performance set it to at least \"4096\" using \"ulimit -n\"") + cutils.logger:warn('ulimit is currently set to "'..res..'". For better performance set it to at least "4096" using "ulimit -n"') end end diff --git a/kong/cli/version.lua b/kong/cli/version.lua index 2d39a0afd6fe..10b5450c5628 100644 --- a/kong/cli/version.lua +++ b/kong/cli/version.lua @@ -1,4 +1,4 @@ -#!/usr/bin/env lua +#!/usr/bin/env luajit local cutils = require "kong.cli.utils" local constants = require "kong.constants" diff --git a/kong/constants.lua b/kong/constants.lua index 4901d899d951..deca9a8c142b 100644 --- a/kong/constants.lua +++ b/kong/constants.lua @@ -33,8 +33,8 @@ return { -- Non standard headers, specific to Kong HEADERS = { HOST_OVERRIDE = "X-Host-Override", - PROXY_TIME = "X-Kong-Proxy-Time", - API_TIME = "X-Kong-Api-Time", + PROXY_LATENCY = "X-Kong-Proxy-Latency", + UPSTREAM_LATENCY = "X-Kong-Upstream-Latency", CONSUMER_ID = "X-Consumer-ID", CONSUMER_CUSTOM_ID = "X-Consumer-Custom-ID", CONSUMER_USERNAME = "X-Consumer-Username", diff --git a/kong/resolver/certificate.lua b/kong/core/certificate.lua similarity index 81% rename from kong/resolver/certificate.lua rename to kong/core/certificate.lua index 6e954999a286..6b512efa0c3c 100644 --- a/kong/resolver/certificate.lua +++ b/kong/core/certificate.lua @@ -9,7 +9,7 @@ local function find_api(hosts) local sanitized_host = stringy.split(host, ":")[1] retrieved_api, err = cache.get_or_set(cache.api_key(sanitized_host), function() - local apis, err = dao.apis:find_by_keys { request_host = sanitized_host } + local apis, err = dao.apis:find_by_keys {request_host = sanitized_host} if err then return nil, err elseif apis and #apis == 1 then @@ -23,14 +23,16 @@ local function find_api(hosts) end end -function _M.execute(conf) +function _M.execute() local ssl = require "ngx.ssl" local server_name = ssl.server_name() if server_name then -- Only support SNI requests local api, err = find_api({server_name}) - if not err and api then - ngx.ctx.api = api + if err then + ngx.log(ngx.ERR, tostring(err)) end + + return api end end diff --git a/kong/core/handler.lua b/kong/core/handler.lua new file mode 100644 index 000000000000..f7c060c5253c --- /dev/null +++ b/kong/core/handler.lua @@ -0,0 +1,97 @@ +-- Kong core +-- +-- This consists of events than need to +-- be ran at the very beginning and very end of the lua-nginx-module contexts. +-- It mainly carries information related to a request from one context to the next one, +-- through the `ngx.ctx` table. +-- +-- In the `access_by_lua` phase, it is responsible for retrieving the API being proxied by +-- a Consumer. Then it is responsible for loading the plugins to execute on this request. +-- +-- In other phases, we create different variables and timers. +-- Variables: +-- `plugins_to_execute`: an array of plugin to be executed for this request. +-- Timers: +-- `KONG__STARTED_AT`: time at which a given context is started to be executed by all Kong plugins. +-- `KONG__ENDED_AT`: time at which all plugins have been executed by Kong for this context. +-- `KONG__TIME`: time taken by Kong to execute all the plugins for this context +-- +-- @see https://github.com/openresty/lua-nginx-module#ngxctx + +local utils = require "kong.tools.utils" +local reports = require "kong.core.reports" +local resolver = require "kong.core.resolver" +local constants = require "kong.constants" +local certificate = require "kong.core.certificate" + +local ngx_now = ngx.now + +local function get_now() + return ngx_now() * 1000 -- time is kept in seconds with millisecond resolution. +end + +return { + init_worker = function() + reports.init_worker() + end, + certificate = function() + ngx.ctx.api = certificate.execute() + end, + access = { + before = function() + ngx.ctx.KONG_ACCESS_START = get_now() + ngx.ctx.api, ngx.ctx.upstream_url, ngx.var.upstream_host = resolver.execute(ngx.var.request_uri, ngx.req.get_headers()) + end, + -- Only executed if the `resolver` module found an API and allows nginx to proxy it. + after = function() + -- Append any querystring parameters modified during plugins execution + local upstream_url = ngx.ctx.upstream_url + local uri_args = ngx.req.get_uri_args() + if utils.table_size(uri_args) > 0 then + upstream_url = upstream_url.."?"..utils.encode_args(uri_args) + end + + -- Set the `$upstream_url` and `$upstream_host` variables for the `proxy_pass` nginx + -- directive in kong.yml. + ngx.var.upstream_url = upstream_url + + local now = get_now() + ngx.ctx.KONG_ACCESS_TIME = now - ngx.ctx.KONG_ACCESS_START -- time spent in Kong's access_by_lua + ngx.ctx.KONG_ACCESS_ENDED_AT = now + -- time spent in Kong before sending the reqeust to upstream + ngx.ctx.KONG_PROXY_LATENCY = now - ngx.req.start_time() * 1000 -- ngx.req.start_time() is kept in seconds with millisecond resolution. + ngx.ctx.KONG_PROXIED = true + end + }, + header_filter = { + before = function() + if ngx.ctx.KONG_PROXIED then + local now = get_now() + ngx.ctx.KONG_WAITING_TIME = now - ngx.ctx.KONG_ACCESS_ENDED_AT -- time spent waiting for a response from upstream + ngx.ctx.KONG_HEADER_FILTER_STARTED_AT = now + end + end, + after = function() + if ngx.ctx.KONG_PROXIED then + ngx.header[constants.HEADERS.UPSTREAM_LATENCY] = ngx.ctx.KONG_WAITING_TIME + ngx.header[constants.HEADERS.PROXY_LATENCY] = ngx.ctx.KONG_PROXY_LATENCY + ngx.header["Via"] = constants.NAME.."/"..constants.VERSION + else + ngx.header["Server"] = constants.NAME.."/"..constants.VERSION + end + end + }, + body_filter = { + after = function() + if ngx.arg[2] and ngx.ctx.KONG_PROXIED then + -- time spent receiving the response (header_filter + body_filter) + -- we could uyse $upstream_response_time but we need to distinguish the waiting time + -- from the receiving time in our logging plugins (especially ALF serializer). + ngx.ctx.KONG_RECEIVE_TIME = get_now() - ngx.ctx.KONG_HEADER_FILTER_STARTED_AT + end + end + }, + log = function() + reports.log() + end +} diff --git a/kong/core/plugins_iterator.lua b/kong/core/plugins_iterator.lua new file mode 100644 index 000000000000..bb3a542c0bd1 --- /dev/null +++ b/kong/core/plugins_iterator.lua @@ -0,0 +1,118 @@ +local cache = require "kong.tools.database_cache" +local constants = require "kong.constants" +local responses = require "kong.tools.responses" + +local table_remove = table.remove +local table_insert = table.insert +local ipairs = ipairs + +--- Load the configuration for a plugin entry in the DB. +-- Given an API, a Consumer and a plugin name, retrieve the plugin's configuration if it exists. +-- Results are cached in ngx.dict +-- @param[type=string] api_id ID of the API being proxied. +-- @param[type=string] consumer_id ID of the Consumer making the request (if any). +-- @param[type=stirng] plugin_name Name of the plugin being tested for. +-- @treturn table Plugin retrieved from the cache or database. +local function load_plugin_configuration(api_id, consumer_id, plugin_name) + local cache_key = cache.plugin_key(plugin_name, api_id, consumer_id) + + local plugin = cache.get_or_set(cache_key, function() + local rows, err = dao.plugins:find_by_keys { + api_id = api_id, + consumer_id = consumer_id ~= nil and consumer_id or constants.DATABASE_NULL_ID, + name = plugin_name + } + if err then + return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) + end + + if #rows > 0 then + return table_remove(rows, 1) + else + -- insert a cached value to not trigger too many DB queries. + -- for now, this will lock the cache for the expiraiton duration. + return {null = true} + end + end) + + if plugin ~= nil and plugin.enabled then + return plugin.config or {} + end +end + +local function load_plugins_for_req(loaded_plugins) + if ngx.ctx.plugins_for_request == nil then + local t = {} + -- Build an array of plugins that must be executed for this particular request. + -- A plugin is considered to be executed if there is a row in the DB which contains: + -- 1. the API id (contained in ngx.ctx.api.id, retrived by the core resolver) + -- 2. a Consumer id, in which case it overrides any previous plugin found in 1. + -- this use case will be treated once the authentication plugins have run (access phase). + -- Such a row will contain a `config` value, which is a table. + if ngx.ctx.api ~= nil then + for _, plugin in ipairs(loaded_plugins) do + local plugin_configuration = load_plugin_configuration(ngx.ctx.api.id, nil, plugin.name) + if plugin_configuration ~= nil then + table_insert(t, {plugin, plugin_configuration}) + end + end + end + + ngx.ctx.plugins_for_request = t + end +end + +--- Plugins for request iterator. +-- Iterate over the plugin loaded for a request, stored in `ngx.ctx.plugins_for_request`. +-- @param[type=string] context_name Name of the current nginx context. We don't use `ngx.get_phase()` simply because we can avoid it. +-- @treturn function iterator +local function iter_plugins_for_req(loaded_plugins, context_name) + -- In case previous contexts did not run, we need to handle + -- the case when plugins have not been fetched for a given request. + -- This will simply make it so the look gets skipped if no API is set in the context + load_plugins_for_req(loaded_plugins) + + local i = 0 + + -- Iterate on plugins to execute for this request until + -- a plugin with a handler for the given context is found. + local function get_next() + i = i + 1 + local p = ngx.ctx.plugins_for_request[i] + if p == nil then + return + end + + local plugin, plugin_configuration = p[1], p[2] + if plugin.handler[context_name] == nil then + ngx.log(ngx.DEBUG, "No handler for "..context_name.." phase on "..plugin.name.." plugin") + return get_next() + end + + return plugin, plugin_configuration + end + + return function() + local plugin, plugin_configuration = get_next() + + -- Check if any Consumer was authenticated during the access phase. + -- If so, retrieve the configuration for this Consumer which overrides + -- the API-wide configuration. + if plugin ~= nil and context_name == "access" then + local consumer_id = ngx.ctx.authenticated_credential and ngx.ctx.authenticated_credential.consumer_id or nil + if consumer_id ~= nil then + local consumer_plugin_configuration = load_plugin_configuration(ngx.ctx.api.id, consumer_id, plugin.name) + if consumer_plugin_configuration ~= nil then + -- This Consumer has a special configuration when this plugin gets executed. + -- Override this plugin's configuration for this request. + plugin_configuration = consumer_plugin_configuration + ngx.ctx.plugins_for_request[i][2] = consumer_plugin_configuration + end + end + end + + return plugin, plugin_configuration + end +end + +return iter_plugins_for_req diff --git a/kong/reports/init_worker.lua b/kong/core/reports.lua similarity index 62% rename from kong/reports/init_worker.lua rename to kong/core/reports.lua index 66785e1fd635..ba9805f6f24f 100644 --- a/kong/reports/init_worker.lua +++ b/kong/core/reports.lua @@ -1,11 +1,8 @@ local syslog = require "kong.tools.syslog" -local lock = require "resty.lock" local cache = require "kong.tools.database_cache" local INTERVAL = 3600 -local _M = {} - local function create_timer(at, cb) local ok, err = ngx.timer.at(at, cb) if not ok then @@ -14,22 +11,26 @@ local function create_timer(at, cb) end local function send_ping(premature) - local lock = lock:new("locks", { + local resty_lock = require "resty.lock" + local lock = resty_lock:new("locks", { exptime = INTERVAL - 0.001 }) local elapsed = lock:lock("ping") if elapsed and elapsed == 0 then local reqs = cache.get(cache.requests_key()) if not reqs then reqs = 0 end - syslog.log({signal = "ping", requests=reqs}) + syslog.log({signal = "ping", requests=reqs, process_id=process_id}) cache.incr(cache.requests_key(), -reqs) -- Reset counter end create_timer(INTERVAL, send_ping) end -function _M.execute() - cache.rawset(cache.requests_key(), 0, 0) -- Initializing the counter - create_timer(INTERVAL, send_ping) -end - -return _M +return { + init_worker = function() + cache.rawset(cache.requests_key(), 0, 0) -- Initializing the counter + create_timer(INTERVAL, send_ping) + end, + log = function() + cache.incr(cache.requests_key(), 1) + end +} diff --git a/kong/resolver/access.lua b/kong/core/resolver.lua similarity index 65% rename from kong/resolver/access.lua rename to kong/core/resolver.lua index 3001f9172f1a..f76bb4c7e526 100644 --- a/kong/resolver/access.lua +++ b/kong/core/resolver.lua @@ -4,39 +4,50 @@ local stringy = require "stringy" local constants = require "kong.constants" local responses = require "kong.tools.responses" +local table_insert = table.insert +local string_match = string.match +local string_find = string.find +local string_format = string.format +local string_sub = string.sub +local string_gsub = string.gsub +local string_len = string.len +local ipairs = ipairs +local unpack = unpack +local type = type + local _M = {} -- Take a request_host and make it a pattern for wildcard matching. -- Only do so if the request_host actually has a wildcard. local function create_wildcard_pattern(request_host) - if string.find(request_host, "*", 1, true) then - local pattern = string.gsub(request_host, "%.", "%%.") - pattern = string.gsub(pattern, "*", ".+") - pattern = string.format("^%s$", pattern) + if string_find(request_host, "*", 1, true) then + local pattern = string_gsub(request_host, "%.", "%%.") + pattern = string_gsub(pattern, "*", ".+") + pattern = string_format("^%s$", pattern) return pattern end end -- Handles pattern-specific characters if any. local function create_strip_request_path_pattern(request_path) - return string.gsub(request_path, "[%(%)%.%%%+%-%*%?%[%]%^%$]", function(c) return "%"..c end) + return string_gsub(request_path, "[%(%)%.%%%+%-%*%?%[%]%^%$]", function(c) return "%"..c end) end -local function get_backend_url(api) +local function get_upstream_url(api) local result = api.upstream_url -- Checking if the target url ends with a final slash - local len = string.len(result) - if string.sub(result, len, len) == "/" then + local len = string_len(result) + if string_sub(result, len, len) == "/" then -- Remove one slash to avoid having a double slash - -- Because ngx.var.uri always starts with a slash - result = string.sub(result, 0, len - 1) + -- Because ngx.var.request_uri always starts with a slash + result = string_sub(result, 0, len - 1) end return result end -local function get_host_from_url(val) +local function get_host_from_upstream_url(val) local parsed_url = url.parse(val) local port @@ -66,14 +77,14 @@ function _M.load_apis_in_memory() if pattern then -- If the request_host is a wildcard, we have a pattern and we can -- store it in an array for later lookup. - table.insert(dns_wildcard_arr, {pattern = pattern, api = api}) + table_insert(dns_wildcard_arr, {pattern = pattern, api = api}) else -- Keep non-wildcard request_host in a dictionary for faster lookup. dns_dic[api.request_host] = api end end if api.request_path then - table.insert(request_path_arr, { + table_insert(request_path_arr, { api = api, request_path = api.request_path, strip_request_path_pattern = create_strip_request_path_pattern(api.request_path) @@ -89,7 +100,7 @@ function _M.load_apis_in_memory() end function _M.find_api_by_request_host(req_headers, apis_dics) - local all_hosts = {} + local hosts_list = {} for _, header_name in ipairs({"Host", constants.HEADERS.HOST_OVERRIDE}) do local hosts = req_headers[header_name] if hosts then @@ -99,14 +110,14 @@ function _M.find_api_by_request_host(req_headers, apis_dics) -- for all values of this header, try to find an API using the apis_by_dns dictionnary for _, host in ipairs(hosts) do host = unpack(stringy.split(host, ":")) - table.insert(all_hosts, host) + table_insert(hosts_list, host) if apis_dics.by_dns[host] then - return apis_dics.by_dns[host] + return apis_dics.by_dns[host], host else -- If the API was not found in the dictionary, maybe it is a wildcard request_host. -- In that case, we need to loop over all of them. for _, wildcard_dns in ipairs(apis_dics.wildcard_dns_arr) do - if string.match(host, wildcard_dns.pattern) then + if string_match(host, wildcard_dns.pattern) then return wildcard_dns.api end end @@ -115,7 +126,7 @@ function _M.find_api_by_request_host(req_headers, apis_dics) end end - return nil, all_hosts + return nil, nil, hosts_list end -- To do so, we have to compare entire URI segments (delimited by "/"). @@ -149,9 +160,14 @@ end -- Replace `/request_path` with `request_path`, and then prefix with a `/` -- or replace `/request_path/foo` with `/foo`, and then do not prefix with `/`. -function _M.strip_request_path(uri, strip_request_path_pattern) - local uri = string.gsub(uri, strip_request_path_pattern, "", 1) - if string.sub(uri, 0, 1) ~= "/" then +function _M.strip_request_path(uri, strip_request_path_pattern, upstream_url_has_path) + local uri = string_gsub(uri, strip_request_path_pattern, "", 1) + + -- Sometimes uri can be an empty string, and adding a slash "/"..uri will lead to a trailing slash + -- We don't want to add a trailing slash in one specific scenario, when the upstream_url already has + -- a path (so it's not root, like http://hello.com/, but http://hello.com/path) in order to avoid + -- having an unnecessary trailing slash not wanted by the user. Hence the "upstream_url_has_path" check. + if string_sub(uri, 0, 1) ~= "/" and not upstream_url_has_path then uri = "/"..uri end return uri @@ -165,13 +181,14 @@ end -- We keep APIs in the database cache for a longer time than usual. -- @see https://github.com/Mashape/kong/issues/15 for an improvement on this. -- --- @param `uri` The URI for this request. --- @return `err` Any error encountered during the retrieval. --- @return `api` The retrieved API, if any. --- @return `hosts` The list of headers values found in Host and X-Host-Override. +-- @param `uri` The URI for this request. +-- @return `err` Any error encountered during the retrieval. +-- @return `api` The retrieved API, if any. +-- @return `matched_host` The host that was matched for this API, if matched. +-- @return `hosts` The list of headers values found in Host and X-Host-Override. -- @return `strip_request_path_pattern` If the API was retrieved by request_path, contain the pattern to strip it from the URI. -local function find_api(uri) - local api, all_hosts, strip_request_path_pattern +local function find_api(uri, headers) + local api, matched_host, hosts_list, strip_request_path_pattern -- Retrieve all APIs local apis_dics, err = cache.get_or_set("ALL_APIS_BY_DIC", _M.load_apis_in_memory, 60) -- 60 seconds cache, longer than usual @@ -180,47 +197,55 @@ local function find_api(uri) end -- Find by Host header - api, all_hosts = _M.find_api_by_request_host(ngx.req.get_headers(), apis_dics) - + api, matched_host, hosts_list = _M.find_api_by_request_host(headers, apis_dics) -- If it was found by Host, return if api then - return nil, api, all_hosts + return nil, api, matched_host, hosts_list end -- Otherwise, we look for it by request_path. We have to loop over all APIs and compare the requested URI. api, strip_request_path_pattern = _M.find_api_by_request_path(uri, apis_dics.request_path_arr) - return nil, api, all_hosts, strip_request_path_pattern + return nil, api, nil, hosts_list, strip_request_path_pattern end -function _M.execute(conf) - local uri = stringy.split(ngx.var.request_uri, "?")[1] +local function url_has_path(url) + local _, count_slashes = string_gsub(url, "/", "") + return count_slashes > 2 +end - local err, api, hosts, strip_request_path_pattern = find_api(uri) +function _M.execute(request_uri, request_headers) + local uri = unpack(stringy.split(request_uri, "?")) + local err, api, matched_host, hosts_list, strip_request_path_pattern = find_api(uri, request_headers) if err then return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) elseif not api then return responses.send_HTTP_NOT_FOUND { message = "API not found with these values", - request_host = hosts, + request_host = hosts_list, request_path = uri } end + local upstream_host + local upstream_url = get_upstream_url(api) + -- If API was retrieved by request_path and the request_path needs to be stripped if strip_request_path_pattern and api.strip_request_path then - uri = _M.strip_request_path(uri, strip_request_path_pattern) + uri = _M.strip_request_path(uri, strip_request_path_pattern, url_has_path(upstream_url)) end - -- Setting the backend URL for the proxy_pass directive - ngx.var.backend_url = get_backend_url(api)..uri + upstream_url = upstream_url..uri + if api.preserve_host then - ngx.var.backend_host = ngx.req.get_headers()["host"] - else - ngx.var.backend_host = get_host_from_url(ngx.var.backend_url) + upstream_host = matched_host + end + + if upstream_host == nil then + upstream_host = get_host_from_upstream_url(upstream_url) end - ngx.ctx.api = api + return api, upstream_url, upstream_host end return _M diff --git a/kong/dao/cassandra/apis.lua b/kong/dao/cassandra/apis.lua index 3933f5f2a33c..7e5d5d5d02b1 100644 --- a/kong/dao/cassandra/apis.lua +++ b/kong/dao/cassandra/apis.lua @@ -2,6 +2,9 @@ local BaseDao = require "kong.dao.cassandra.base_dao" local apis_schema = require "kong.dao.schemas.apis" local query_builder = require "kong.dao.cassandra.query_builder" +local ipairs = ipairs +local table_insert = table.insert + local Apis = BaseDao:extend() function Apis:new(properties) @@ -13,13 +16,14 @@ end function Apis:find_all() local apis = {} local select_q = query_builder.select(self._table) - for rows, err in Apis.super.execute(self, select_q, nil, nil, {auto_paging=true}) do + + for rows, err in self:execute(select_q, nil, {auto_paging = true}) do if err then return nil, err - end - - for _, row in ipairs(rows) do - table.insert(apis, row) + elseif rows ~= nil then + for _, row in ipairs(rows) do + table_insert(apis, row) + end end end diff --git a/kong/dao/cassandra/base_dao.lua b/kong/dao/cassandra/base_dao.lua index ce91d3e6b999..09f7e6a34c3f 100644 --- a/kong/dao/cassandra/base_dao.lua +++ b/kong/dao/cassandra/base_dao.lua @@ -1,5 +1,9 @@ --- Kong's Cassandra base DAO entity. Provides basic functionalities on top of --- lua-resty-cassandra (https://github.com/jbochi/lua-resty-cassandra) +--- +-- Kong's Cassandra base DAO module. Provides functionalities on top of +-- lua-cassandra (https://github.com/thibaultCha/lua-cassandra) for schema validations, +-- CRUD operations, preparation and caching of executed statements, etc... +-- +-- @see http://thibaultcha.github.io/lua-cassandra/manual/README.md.html local query_builder = require "kong.dao.cassandra.query_builder" local validations = require "kong.dao.schemas_validation" @@ -10,217 +14,67 @@ local DaoError = require "kong.dao.error" local stringy = require "stringy" local Object = require "classic" local utils = require "kong.tools.utils" -local uuid = require "uuid" +local uuid = require "lua_uuid" -local cassandra_constants = cassandra.constants +local table_remove = table.remove local error_types = constants.DATABASE_ERROR_TYPES -local BaseDao = Object:extend() - --- This is important to seed the UUID generator -uuid.seed() +--- Base DAO +-- @section base_dao -local function session_uniq_addr(session) - return session.host..":"..session.port -end +local BaseDao = Object:extend() -function BaseDao:new(properties) - if self._schema then - self._primary_key = self._schema.primary_key - self._clustering_key = self._schema.clustering_key - local indexes = {} - for field_k, field_v in pairs(self._schema.fields) do - if field_v.queryable then - indexes[field_k] = true +--- Public interface. +-- Public methods developers can use in Kong core or in any plugin. +-- @section public + +local function page_iterator(self, session, query, args, query_options) + local iter = session:execute(query, args, query_options) + return function(query, previous_rows) + local rows, err, page = iter(query, previous_rows) + if rows == nil or err ~= nil then + session:set_keep_alive() + else + for i, row in ipairs(rows) do + rows[i] = self:_unmarshall(row) end end - - self._column_family_details = { - primary_key = self._primary_key, - clustering_key = self._clustering_key, - indexes = indexes - } - end - - self._properties = properties - self._statements_cache = {} - self._cascade_delete_hooks = {} + return rows, err, page + end, query end --- Marshall an entity. Does nothing by default, --- must be overriden for entities where marshalling applies. -function BaseDao:_marshall(t) - return t -end - --- Unmarshall an entity. Does nothing by default, --- must be overriden for entities where marshalling applies. -function BaseDao:_unmarshall(t) - return t -end - --- Open a session on the configured keyspace. --- @param `keyspace` (Optional) Override the keyspace for this session if specified. --- @return `session` Opened session --- @return `error` Error if any -function BaseDao:_open_session(keyspace) - local ok, err - - -- Start cassandra session - local session = cassandra:new() - session:set_timeout(self._properties.timeout) - +--- Execute a query. +-- This method should be called with the proper **args** formatting (as an array). +-- See `execute()` for building this parameter. +-- @see execute +-- @param query Plain string CQL query. +-- @param[type=table] args (Optional) Arguments to the query, as an array. Simply passed to lua-cassandra `execute()`. +-- @param[type=table] query_options (Optional) Options to give to lua-cassandra `execute()` query_options. +-- @param[type=string] keyspace (Optional) Override the keyspace for this query if specified. +-- @treturn table If the result consists of ROWS, a table with an array of unmarshalled rows and a `next_page` property if the results has a `paging_state`. If the result is of type "VOID", a boolean representing the success of the query. Otherwise, the raw result as given by lua-cassandra. +-- @treturn table An error if any during the execution. +function BaseDao:execute(query, args, query_options, keyspace) local options = self._factory:get_session_options() - - ok, err = session:connect(self._properties.hosts or self._properties.contact_points, nil, options) - if not ok then - return nil, DaoError(err, error_types.DATABASE) - end - - local times, err = session:get_reused_times() - if err and err.message ~= "luasocket does not support reusable sockets" then - return nil, DaoError(err, error_types.DATABASE) - end - - if times == 0 or not times then - ok, err = session:set_keyspace(keyspace ~= nil and keyspace or self._properties.keyspace) - if not ok then - return nil, DaoError(err, error_types.DATABASE) - end + if keyspace then + options.keyspace = keyspace end - return session -end - --- Close the given opened session. --- Will try to put the session in the socket pool if supported. --- @param `session` Cassandra session to close --- @return `error` Error if any -function BaseDao:_close_session(session) - -- Back to the pool or close if using luasocket - local ok, err = session:set_keepalive(self._properties.keepalive) - if not ok and err.message == "luasocket does not support reusable sockets" then - ok, err = session:close() + local session, err = cassandra.spawn_session(options) + if not session then + return nil, DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) end - if not ok then - return DaoError(err, error_types.DATABASE) + if query_options and query_options.auto_paging then + return page_iterator(self, session, query, args, query_options) end -end --- Build the array of arguments to pass to lua-resty-cassandra :execute method. --- Note: --- Since this method only accepts an ordered list, we build this list from --- the entity `t` and an (ordered) array of parameters for a query, taking --- into account special cassandra values (uuid, timestamps, NULL). --- @param `schema` A schema with type properties to encode specific values --- @param `t` Values to bind to a statement --- @param `parameters` An ordered list of parameters --- @return `args` An ordered list of values to be binded to lua-resty-cassandra :execute --- @return `error` Error Cassandra type validation errors -local function encode_cassandra_args(schema, t, args_keys) - local args_to_bind = {} - local errors - - for _, column in ipairs(args_keys) do - local schema_field = schema.fields[column] - local arg = t[column] - - if schema_field.type == "id" and arg then - if validations.is_valid_uuid(arg) then - arg = cassandra.uuid(arg) - else - errors = utils.add_error(errors, column, arg.." is an invalid uuid") - end - elseif schema_field.type == "timestamp" and arg then - arg = cassandra.timestamp(arg) - elseif arg == nil then - arg = cassandra.null - end - - table.insert(args_to_bind, arg) - end - - return args_to_bind, errors -end - --- Get a statement from the cache or prepare it (and thus insert it in the cache). --- The cache key will be the plain string query representation. --- @param `query` The query to prepare --- @return `statement` The prepared cassandra statement --- @return `cache_key` The cache key used to store it into the cache --- @return `error` Error if any during the query preparation -function BaseDao:get_or_prepare_stmt(session, query) - if type(query) ~= "string" then - -- Cannot be prepared (probably a BatchStatement) - return query - end - - local statement, err - local session_addr = session_uniq_addr(session) - -- Retrieve the prepared statement from cache or prepare and cache - if self._statements_cache[session_addr] and self._statements_cache[session_addr][query] then - statement = self._statements_cache[session_addr][query] - else - statement, err = self:prepare_stmt(session, query) - if err then - return nil, query, err - end - end - - return statement, query -end - --- Execute a query, trying to prepare them on a per-host basis. --- Opens a socket, execute the statement, puts the socket back into the --- socket pool and returns a parsed result. --- @param `query` Plain string query or BatchStatement. --- @param `args` (Optional) Arguments to the query, simply passed to lua-resty-cassandra's :execute() --- @param `options` (Optional) Options to give to lua-resty-cassandra's :execute() --- @param `keyspace` (Optional) Override the keyspace for this query if specified. --- @return `results` If results set are ROWS, a table with an array of unmarshalled rows and a `next_page` property if the results have a paging_state. --- @return `error` An error if any during the whole execution (sockets/query execution) -function BaseDao:_execute(query, args, options, keyspace) - local session, err = self:_open_session(keyspace) - if err then - return nil, err - end - - -- Prepare query and cache the prepared statement for later call - local statement, cache_key, err = self:get_or_prepare_stmt(session, query) + local results, err = session:execute(query, args, query_options) if err then - if options and options.auto_paging then - -- Allow the iteration to run once and thus catch the error - return function() return {}, err end - end - return nil, err + err = DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) end - if options and options.auto_paging then - local _, rows, err, page = session:execute(statement, args, options) - for i, row in ipairs(rows) do - rows[i] = self:_unmarshall(row) - end - return _, rows, err, page - end - - local results, err = session:execute(statement, args, options) - - -- First, close the socket - local socket_err = self:_close_session(session) - if socket_err then - return nil, socket_err - end - - -- Handle unprepared queries - if err and err.cassandra_err_code == cassandra_constants.error_codes.UNPREPARED then - ngx.log(ngx.NOTICE, "Cassandra did not recognize prepared statement \""..cache_key.."\". Re-preparing it and re-trying the query. (Error: "..err..")") - -- If the statement was declared unprepared, clear it from the cache, and try again. - self._statements_cache[session_uniq_addr(session)][cache_key] = nil - return self:_execute(query, args, options) - elseif err then - err = DaoError(err, error_types.DATABASE) - end + -- First, close the session (and underlying sockets) + session:set_keep_alive() -- Parse result if results and results.type == "ROWS" then @@ -246,111 +100,19 @@ function BaseDao:_execute(query, args, options, keyspace) end end --- Bind a table of arguments to a query depending on the entity's schema, --- and then execute the query. --- @param `query` The query to execute --- @param `args_to_bind` Key/value table of arguments to bind --- @param `options` Options to pass to lua-resty-cassandra :execute() --- @return :_execute() -function BaseDao:execute(query, columns, args_to_bind, options) - -- Build args array if operation has some - local args - if columns and args_to_bind then - local errors - args, errors = encode_cassandra_args(self._schema, args_to_bind, columns) - if errors then - return nil, DaoError(errors, error_types.INVALID_TYPE) - end - end - - -- Execute statement - return self:_execute(query, args, options) -end - --- Check all fields marked with a `unique` in the schema do not already exist. -function BaseDao:check_unique_fields(t, is_update) - local errors - - for k, field in pairs(self._schema.fields) do - if field.unique and t[k] ~= nil then - local res, err = self:find_by_keys {[k] = t[k]} - if err then - return false, nil, "Error during UNIQUE check: "..err.message - elseif res and #res > 0 then - local is_self = true - if is_update then - -- If update, check if the retrieved entity is not the entity itself - res = res[1] - for _, key in ipairs(self._primary_key) do - if t[key] ~= res[key] then - is_self = false - break - end - end - else - is_self = false - end - - if not is_self then - errors = utils.add_error(errors, k, k.." already exists with value '"..t[k].."'") - end - end - end - end - - return errors == nil, errors -end - --- Check all fields marked as `foreign` in the schema exist on other column families. -function BaseDao:check_foreign_fields(t) - local errors, foreign_type, foreign_field, res, err - - for k, field in pairs(self._schema.fields) do - if field.foreign ~= nil and type(field.foreign) == "string" then - foreign_type, foreign_field = unpack(stringy.split(field.foreign, ":")) - if foreign_type and foreign_field and self._factory[foreign_type] and t[k] ~= nil and t[k] ~= constants.DATABASE_NULL_ID then - res, err = self._factory[foreign_type]:find_by_keys {[foreign_field] = t[k]} - if err then - return false, nil, "Error during FOREIGN check: "..err.message - elseif not res or #res == 0 then - errors = utils.add_error(errors, k, k.." "..t[k].." does not exist") - end - end - end - end - - return errors == nil, errors -end - --- Prepare a query and insert it into the statement cache. --- @param `query` The query to prepare --- @return `statement` The prepared statement, ready to be used by lua-resty-cassandra. --- @return `error` Error if any during the preparation of the statement -function BaseDao:prepare_stmt(session, query) - assert(type(query) == "string", "Query to prepare must be a string") - query = stringy.strip(query) - - local prepared_stmt, prepare_err = session:prepare(query) - if prepare_err then - return nil, DaoError("Failed to prepare statement: \""..query.."\". "..prepare_err, error_types.DATABASE) - else - local session_addr = session_uniq_addr(session) - -- cache of prepared statements must be specific to each node - if not self._statements_cache[session_addr] then - self._statements_cache[session_addr] = {} - end - - -- cache key is the non-striped/non-formatted query from _queries - self._statements_cache[session_addr][query] = prepared_stmt - return prepared_stmt - end -end - --- Insert a row in the DAO's table. --- Perform schema validation, UNIQUE checks, FOREIGN checks. --- @param `t` A table representing the entity to insert --- @return `result` Inserted entity or nil --- @return `error` Error if any during the execution +--- Children DAOs interface. +-- Those methds are to be used in any child DAO and will perform the named operations +-- the entity they represent. +-- @section inherited + +--- +-- Insert a row in the defined column family (defined by the **_table** attribute). +-- Perform schema validation, 'UNIQUE' checks, 'FOREIGN' checks. +-- @see check_unique_fields +-- @see check_foreign_fields +-- @param[table=table] t A table representing the entity to insert. +-- @treturn table Inserted entity or nil. +-- @treturn table Error if any during the execution. function BaseDao:insert(t) assert(t ~= nil, "Cannot insert a nil element") assert(type(t) == "table", "Entity to insert must be a table") @@ -389,7 +151,7 @@ function BaseDao:insert(t) end local insert_q, columns = query_builder.insert(self._table, t) - local _, stmt_err = self:execute(insert_q, columns, self:_marshall(t)) + local _, stmt_err = self:build_args_and_execute(insert_q, columns, self:_marshall(t)) if stmt_err then return nil, stmt_err else @@ -413,6 +175,7 @@ local function extract_primary_key(t, primary_key, clustering_key) return t_primary_key, t_no_primary_key end +--- -- When updating a row that has a json-as-text column (ex: plugin.config), -- we want to avoid overriding it with a partial value. -- Ex: config.key_name + config.hide_credential, if we update only one field, @@ -431,13 +194,15 @@ local function fix_tables(t, old_t, schema) end end --- Update a row: find the row with the given PRIMARY KEY and update the other values --- If `full`, sets to NULL values that are not included in the schema. --- Performs schema validation, UNIQUE and FOREIGN checks. --- @param `t` A table representing the entity to insert --- @param `full` If `true`, set to NULL any column not in the `t` parameter --- @return `result` Updated entity or nil --- @return `error` Error if any during the execution +--- +-- Update an entity: find the row with the given PRIMARY KEY and update the other values +-- Performs schema validation, 'UNIQUE' and 'FOREIGN' checks. +-- @see check_unique_fields +-- @see check_foreign_fields +-- @param[type=table] t A table representing the entity to update. It **must** contain the entity's PRIMARY KEY (can be composite). +-- @param[type=boolean] full If **true**, set to NULL any column not in the `t` parameter, such as a PUT query would do for example. +-- @treturn table Updated entity or nil. +-- @treturn table Error if any during the execution. function BaseDao:update(t, full) assert(t ~= nil, "Cannot update a nil element") assert(type(t) == "table", "Entity to update must be a table") @@ -489,14 +254,14 @@ function BaseDao:update(t, full) if full then for k, v in pairs(self._schema.fields) do if not t[k] and not v.immutable then - t_no_primary_key[k] = cassandra.null + t_no_primary_key[k] = cassandra.unset end end end local update_q, columns = query_builder.update(self._table, t_no_primary_key, t_primary_key) - local _, stmt_err = self:execute(update_q, columns, self:_marshall(t)) + local _, stmt_err = self:build_args_and_execute(update_q, columns, self:_marshall(t)) if stmt_err then return nil, stmt_err else @@ -504,10 +269,11 @@ function BaseDao:update(t, full) end end +--- -- Retrieve a row at given PRIMARY KEY. --- @param `where_t` A table containing the PRIMARY KEY (columns/values) of the row to retrieve. --- @return `row` The first row of the result. --- @return `error` +-- @param[type=table] where_t A table containing the PRIMARY KEY (it can be composite, hence be multiple columns as keys and their values) of the row to retrieve. +-- @treturn table The first row of the result. +-- @treturn table Error if any during the execution function BaseDao:find_by_primary_key(where_t) assert(self._primary_key ~= nil and type(self._primary_key) == "table" , "Entity does not have a primary_key") assert(where_t ~= nil and type(where_t) == "table", "where_t must be a table") @@ -519,7 +285,7 @@ function BaseDao:find_by_primary_key(where_t) end local select_q, where_columns = query_builder.select(self._table, t_primary_key, self._column_family_details, nil, true) - local data, err = self:execute(select_q, where_columns, t_primary_key) + local data, err = self:build_args_and_execute(select_q, where_columns, t_primary_key) -- Return the 1st and only element of the result set if data and utils.table_size(data) > 0 then @@ -531,16 +297,18 @@ function BaseDao:find_by_primary_key(where_t) return data, err end --- Retrieve a set of rows from the given columns/value table. --- @param `where_t` (Optional) columns/values table by which to find an entity. --- @param `page_size` Size of the page to retrieve (number of rows). --- @param `paging_state` Start page from given offset. See lua-resty-cassandra's :execute() option. --- @return `res` --- @return `err` --- @return `filtering` A boolean indicating if ALLOW FILTERING was needed by the query +--- +-- Retrieve a set of rows from the given columns/value table with a given +-- 'WHERE' clause. +-- @param[type=table] where_t (Optional) columns/values table by which to find an entity. +-- @param[type=number] page_size Size of the page to retrieve (number of rows). +-- @param[type=string] paging_state Start page from given offset. See lua-cassandra's related `execute()` option. +-- @treturn table An array (of possible length 0) of entities as the result of the query +-- @treturn table An error if any +-- @treturn boolean A boolean indicating if the 'ALLOW FILTERING' clause was needed by the query function BaseDao:find_by_keys(where_t, page_size, paging_state) local select_q, where_columns, filtering = query_builder.select(self._table, where_t, self._column_family_details) - local res, err = self:execute(select_q, where_columns, where_t, { + local res, err = self:build_args_and_execute(select_q, where_columns, where_t, { page_size = page_size, paging_state = paging_state }) @@ -548,58 +316,40 @@ function BaseDao:find_by_keys(where_t, page_size, paging_state) return res, err, filtering end --- Retrieve a page of the table attached to the DAO. --- @param `page_size` Size of the page to retrieve (number of rows). --- @param `paging_state` Start page from given offset. See lua-resty-cassandra's :execute() option. --- @return `find_by_keys()` -function BaseDao:find(page_size, paging_state) - return self:find_by_keys(nil, page_size, paging_state) -end - --- Add a delete hook on a parent DAO of a foreign row. --- The delete hook will basically "cascade delete" all foreign rows of a parent row. --- @see cassandra/factory.lua ':load_daos()' --- @param foreign_dao_name Name (string) of the parent DAO --- @param foreign_column Name (string) of the foreign column --- @param parent_column Name (string) of the parent column identifying the parent row -function BaseDao:add_delete_hook(foreign_dao_name, foreign_column, parent_column) - - -- The actual delete hook - -- @param deleted_primary_key The value of the deleted row's primary key - -- @return boolean True if success, false otherwise - -- @return table A DAOError in case of error - local delete_hook = function(deleted_primary_key) - local foreign_dao = self._factory[foreign_dao_name] - local select_args = { - [foreign_column] = deleted_primary_key[parent_column] - } - - -- Iterate over all rows with the foreign key and delete them. - -- Rows need to be deleted by PRIMARY KEY, and we only have the value of the foreign key, hence we need - -- to retrieve all rows with the foreign key, and then delete them, identifier by their own primary key. - local select_q, columns = query_builder.select(foreign_dao._table, select_args, foreign_dao._column_family_details ) - for rows, err in foreign_dao:execute(select_q, columns, select_args, {auto_paging = true}) do - if err then - return false, err - end - for _, row in ipairs(rows) do - local ok_del_foreign_row, err = foreign_dao:delete(row) - if not ok_del_foreign_row then - return false, err - end - end - end - - return true +--- +-- Retrieve the number of rows in the related column family matching a possible 'WHERE' clause. +-- @param[type=table] where_t (Optional) columns/values table by which to count entities. +-- @param[type=string] paging_state Start page from given offset. It'll be passed along to lua-cassandra `execute()` query_options. +-- @treturn number The number of rows matching the specified criteria. +-- @treturn table An error if any. +-- @treturn boolean A boolean indicating if the 'ALLOW FILTERING' clause was needed by the query. +function BaseDao:count_by_keys(where_t, paging_state) + local count_q, where_columns, filtering = query_builder.count(self._table, where_t, self._column_family_details) + local res, err = self:build_args_and_execute(count_q, where_columns, where_t, { + paging_state = paging_state + }) + if err then + return nil, err end - table.insert(self._cascade_delete_hooks, delete_hook) + return (#res >= 1 and table_remove(res, 1).count or 0), nil, filtering end --- Delete the row at a given PRIMARY KEY. --- @param `where_t` A table containing the PRIMARY KEY (columns/values) of the row to delete --- @return `success` True if deleted, false if otherwise or not found --- @return `error` Error if any during the query execution or the cascade delete hook +--- +-- Retrieve a page of rows from the related column family. +-- @param[type=number] page_size Size of the page to retrieve (number of rows). The default is the default value from lua-cassandra. +-- @param[type=string] paging_state Start page from given offset. It'll be passed along to lua-cassandra `execute()` query_options. +-- @return return values of find_by_keys() +-- @see find_by_keys +function BaseDao:find(page_size, paging_state) + return self:find_by_keys(nil, page_size, paging_state) +end + +--- +-- Delete the row with PRIMARY KEY from the configured table (**_table** attribute). +-- @param[table=table] where_t A table containing the PRIMARY KEY (columns/values) of the row to delete +-- @treturn boolean True if deleted, false if otherwise or not found. +-- @treturn table Error if any during the query execution or the cascade delete hook. function BaseDao:delete(where_t) assert(self._primary_key ~= nil and type(self._primary_key) == "table" , "Entity does not have a primary_key") assert(where_t ~= nil and type(where_t) == "table", "where_t must be a table") @@ -614,14 +364,14 @@ function BaseDao:delete(where_t) local t_primary_key = extract_primary_key(where_t, self._primary_key, self._clustering_key) local delete_q, where_columns = query_builder.delete(self._table, t_primary_key) - local results, err = self:execute(delete_q, where_columns, where_t) + local results, err = self:build_args_and_execute(delete_q, where_columns, where_t) if err then return false, err end -- Delete successful, trigger cascade delete hooks if any. local foreign_err - for _, hook in ipairs(self._cascade_delete_hooks) do + for _, hook in ipairs(self.cascade_delete_hooks) do foreign_err = select(2, hook(t_primary_key)) if foreign_err then return false, foreign_err @@ -631,11 +381,239 @@ function BaseDao:delete(where_t) return results end --- Truncate the table of this DAO --- @return `:execute()` +--- +-- Truncate the table related to this DAO (the **_table** attribute). +-- Only executes a 'TRUNCATE' query using the @{execute} method. +-- @return Return values of execute(). +-- @see execute function BaseDao:drop() local truncate_q = query_builder.truncate(self._table) return self:execute(truncate_q) end +--- Optional overrides. +-- Can be optionally overriden by a child DAO. +-- @section optional + +--- Constructor. +-- Instanciate a new Cassandra DAO. This method is to be overriden from the +-- child class and called once the child class has a schema set. +-- @param properties Cassandra properties from the configuration file. +-- @treturn table Instanciated DAO. +function BaseDao:new(properties) + if self._schema then + self._primary_key = self._schema.primary_key + self._clustering_key = self._schema.clustering_key + local indexes = {} + for field_k, field_v in pairs(self._schema.fields) do + if field_v.queryable then + indexes[field_k] = true + end + end + + self._column_family_details = { + primary_key = self._primary_key, + clustering_key = self._clustering_key, + indexes = indexes + } + end + + self.properties = properties + self.cascade_delete_hooks = {} +end + +--- +-- Marshall an entity. +-- Executed on each entity insertion to serialize +-- eventual properties for Cassandra storage. +-- Does nothing by default, must be overriden for entities where marshalling applies. +-- @see _unmarshall +-- @param[type=table] t Entity to marshall. +-- @treturn table Serialized entity. +function BaseDao:_marshall(t) + return t +end + +--- +-- Unmarshall an entity. +-- Executed each time an entity is being retrieved from Cassandra +-- to deserialize properties serialized by `:_mashall()`, +-- Does nothing by default, must be overriden for entities where marshalling applies. +-- @see _marshall +-- @param[type=table] t Entity to unmarshall. +-- @treturn table Deserialized entity. +function BaseDao:_unmarshall(t) + return t +end + +--- Private methods. +-- For internal use in the base_dao itself or advanced usage in a child DAO. +-- @section private + +--- +-- @local +-- Build the array of arguments to pass to lua-cassandra's `execute()` method. +-- Note: +-- Since this method only accepts an ordered list, we build this list from +-- the entity `t` and an (ordered) array of parameters for a query, taking +-- into account special cassandra values (uuid, timestamps, NULL). +-- @param[type=table] schema A schema with type properties to encode specific values. +-- @param[type=table] t Values to bind to a statement. +-- @param[type=table] parameters An ordered list of parameters. +-- @treturn table An ordered list of values to pass to lua-cassandra `execute()` args. +-- @treturn table Error Cassandra type validation errors +local function encode_cassandra_args(schema, t, args_keys) + local args_to_bind = {} + local errors + + for _, column in ipairs(args_keys) do + local schema_field = schema.fields[column] + local arg = t[column] + + if schema_field.type == "id" and arg then + if validations.is_valid_uuid(arg) then + arg = cassandra.uuid(arg) + else + errors = utils.add_error(errors, column, arg.." is an invalid uuid") + end + elseif schema_field.type == "timestamp" and arg then + arg = cassandra.timestamp(arg) + elseif arg == nil then + arg = cassandra.unset + end + + table.insert(args_to_bind, arg) + end + + return args_to_bind, errors +end + +--- +-- Bind a table of arguments to a query depending on the entity's schema, +-- and then execute the query via `execute()`. +-- @param[type=string] query The query to execute. +-- @param[type=table] columns A list of column names where each value indicates the column of the value at the same index in `args_to_bind`. +-- @param[type=table] args_to_bind Key/value table of arguments to bind. +-- @param[type=table] query_options Options to pass to lua-cassandra `execute()` query_options. +-- @return return values of `execute()`. +-- @see _execute +function BaseDao:build_args_and_execute(query, columns, args_to_bind, query_options) + -- Build args array if operation has some + local args + if columns and args_to_bind then + local errors + args, errors = encode_cassandra_args(self._schema, args_to_bind, columns) + if errors then + return nil, DaoError(errors, error_types.INVALID_TYPE) + end + end + + return self:execute(query, args, query_options) +end + +--- Perform "unique" check on a column. +-- Check that all fields marked with `unique` in the schema do not already exist +-- with the same value. +-- @param[type=table] t Key/value representation of the entity +-- @param[type=boolean] is_update If true, ignore an identical value if the row containing it is the one we are trying to update. +-- @treturn boolean True if all unique fields are not already present, false if any already exists with the same value. +-- @treturn table A key/value table of all columns (as keys) having values already in the database. +function BaseDao:check_unique_fields(t, is_update) + local errors + + for k, field in pairs(self._schema.fields) do + if field.unique and t[k] ~= nil then + local res, err = self:find_by_keys {[k] = t[k]} + if err then + return false, nil, "Error during UNIQUE check: "..err.message + elseif res and #res > 0 then + local is_self = true + if is_update then + -- If update, check if the retrieved entity is not the entity itself + res = res[1] + for _, key in ipairs(self._primary_key) do + if t[key] ~= res[key] then + is_self = false + break + end + end + else + is_self = false + end + + if not is_self then + errors = utils.add_error(errors, k, k.." already exists with value '"..t[k].."'") + end + end + end + end + + return errors == nil, errors +end + +--- Perform "foreign" check on a column. +-- Check all fields marked with `foreign` in the schema have an existing parent row. +-- @param[type=table] t Key/value representation of the entity. +-- @treturn boolean True if all fields marked as foreign have a parent row. +-- @treturn table A key/value table of all columns (as keys) not having a parent row. +function BaseDao:check_foreign_fields(t) + local errors, foreign_type, foreign_field, res, err + + for k, field in pairs(self._schema.fields) do + if field.foreign ~= nil and type(field.foreign) == "string" then + foreign_type, foreign_field = unpack(stringy.split(field.foreign, ":")) + if foreign_type and foreign_field and self._factory[foreign_type] and t[k] ~= nil and t[k] ~= constants.DATABASE_NULL_ID then + res, err = self._factory[foreign_type]:find_by_keys {[foreign_field] = t[k]} + if err then + return false, nil, "Error during FOREIGN check: "..err.message + elseif not res or #res == 0 then + errors = utils.add_error(errors, k, k.." "..t[k].." does not exist") + end + end + end + end + + return errors == nil, errors +end + +-- Add a delete hook on a parent DAO of a foreign row. +-- The delete hook will basically "cascade delete" all foreign rows of a parent row. +-- @see cassandra/factory.lua `load_daos()`. +-- @param[type=string] foreign_dao_name Name of the parent DAO. +-- @param[type=string] foreign_column Name of the foreign column. +-- @param[type=string] parent_column Name of the parent column identifying the parent row. +function BaseDao:add_delete_hook(foreign_dao_name, foreign_column, parent_column) + + -- The actual delete hook. + -- @param[type=table] deleted_primary_key The value of the deleted row's primary key. + -- @treturn boolean True if success, false otherwise. + -- @treturn table A DAOError in case of error. + local delete_hook = function(deleted_primary_key) + local foreign_dao = self._factory[foreign_dao_name] + local select_args = { + [foreign_column] = deleted_primary_key[parent_column] + } + + -- Iterate over all rows with the foreign key and delete them. + -- Rows need to be deleted by PRIMARY KEY, and we only have the value of the foreign key, hence we need + -- to retrieve all rows with the foreign key, and then delete them, identifier by their own primary key. + local select_q, columns = query_builder.select(foreign_dao._table, select_args, foreign_dao._column_family_details ) + for rows, err in foreign_dao:build_args_and_execute(select_q, columns, select_args, {auto_paging = true}) do + if err then + return false, err + end + for _, row in ipairs(rows) do + local ok_del_foreign_row, err = foreign_dao:delete(row) + if not ok_del_foreign_row then + return false, err + end + end + end + + return true + end + + table.insert(self.cascade_delete_hooks, delete_hook) +end + return BaseDao diff --git a/kong/dao/cassandra/factory.lua b/kong/dao/cassandra/factory.lua index 6def78139db0..65ea48e5328c 100644 --- a/kong/dao/cassandra/factory.lua +++ b/kong/dao/cassandra/factory.lua @@ -11,6 +11,12 @@ local stringy = require "stringy" local Object = require "classic" local utils = require "kong.tools.utils" +if ngx ~= nil and type(ngx.get_phase) == "function" and ngx.get_phase() == "init" and not ngx.stub then + cassandra.set_log_level("INFO") +else + cassandra.set_log_level("QUIET") +end + local CassandraFactory = Object:extend() -- Shorthand for accessing one of the underlying DAOs @@ -24,11 +30,18 @@ end -- Instantiate a Cassandra Factory and all its DAOs for various entities -- @param `properties` Cassandra properties -function CassandraFactory:new(properties, plugins) - self._properties = properties +function CassandraFactory:new(properties, plugins, spawn_cluster) + self.properties = properties self.type = "cassandra" self.daos = {} + if spawn_cluster then + local ok, err = cassandra.spawn_cluster(self:get_session_options()) + if not ok then + error(err) + end + end + -- Load core entities DAOs for _, entity in ipairs({"apis", "consumers", "plugins"}) do self:load_daos(require("kong.dao.cassandra."..entity)) @@ -63,7 +76,7 @@ end function CassandraFactory:load_daos(plugin_daos) local dao for name, plugin_dao in pairs(plugin_daos) do - dao = plugin_dao(self._properties) + dao = plugin_dao(self.properties) dao._factory = self self.daos[name] = dao if dao._schema then @@ -97,18 +110,22 @@ function CassandraFactory:drop() end function CassandraFactory:get_session_options() - local options = { - ssl = self._properties.ssl, - ssl_verify = self._properties.ssl_verify, - ca_file = self._properties.ssl_certificate -- in case of using luasocket + return { + shm = "cassandra", + prepared_shm = "cassandra_prepared", + contact_points = self.properties.contact_points, + keyspace = self.properties.keyspace, + query_options = { + prepare = true + }, + username = self.properties.username, + password = self.properties.password, + ssl_options = { + enabled = self.properties.ssl.enabled, + verify = self.properties.ssl.verify, + ca = self.properties.ssl.certificate_authority + } } - - if self._properties.user and self._properties.password then - local PasswordAuthenticator = require "cassandra.authenticators.PasswordAuthenticator" - options.authenticator = PasswordAuthenticator(self._properties.user, self._properties.password) - end - - return options end -- Execute a string of queries separated by ; @@ -117,22 +134,16 @@ end -- @param {boolean} no_keyspace Won't set the keyspace if true -- @return {table} error if any function CassandraFactory:execute_queries(queries, no_keyspace) - local ok, err - local session = cassandra:new() - session:set_timeout(self._properties.timeout) - local options = self:get_session_options() + options.query_options.same_coordinator = true - ok, err = session:connect(self._properties.hosts or self._properties.contact_points, nil, options) - if not ok then - return DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) + if no_keyspace then + options.keyspace = nil end - if no_keyspace == nil then - ok, err = session:set_keyspace(self._properties.keyspace) - if not ok then - return DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) - end + local session, err = cassandra.spawn_session(options) + if not session then + return DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) end -- Cassandra only supports BATCH on DML statements. @@ -140,14 +151,18 @@ function CassandraFactory:execute_queries(queries, no_keyspace) queries = stringy.split(queries, ";") for _, query in ipairs(queries) do if stringy.strip(query) ~= "" then - local _, stmt_err = session:execute(query, nil, {consistency_level = cassandra.constants.consistency.ALL}) - if stmt_err then - return DaoError(stmt_err, constants.DATABASE_ERROR_TYPES.DATABASE) + err = select(2, session:execute(query)) + if err then + break end end end - session:close() + session:shutdown() + + if err then + return DaoError(err, constants.DATABASE_ERROR_TYPES.DATABASE) + end end return CassandraFactory diff --git a/kong/dao/cassandra/migrations.lua b/kong/dao/cassandra/migrations.lua index 49c50b78e787..d11467802b46 100644 --- a/kong/dao/cassandra/migrations.lua +++ b/kong/dao/cassandra/migrations.lua @@ -28,11 +28,7 @@ function Migrations:new(properties) end function Migrations:keyspace_exists(keyspace) - local rows, err = Migrations.super._execute(self, - self.queries.get_keyspace, - {self._properties.keyspace}, - nil, "system" - ) + local rows, err = Migrations.super.execute(self, self.queries.get_keyspace, {self.properties.keyspace}, nil, "system") if err then return nil, err else @@ -45,11 +41,7 @@ end -- @return query result -- @return error if any function Migrations:add_migration(migration_name, identifier) - return Migrations.super._execute(self, - self.queries.add_migration, - {cassandra.list({migration_name}), identifier}, - {consistency_level = cassandra.constants.consistency.ALL} - ) + return Migrations.super.execute(self, self.queries.add_migration, {cassandra.list({migration_name}), identifier}) end -- Return all logged migrations with a filter by identifier optionally. Check if keyspace exists before to avoid error during the first migration. @@ -67,17 +59,9 @@ function Migrations:get_migrations(identifier) local rows, err if identifier ~= nil then - rows, err = Migrations.super._execute(self, - self.queries.get_migrations, - {identifier}, - {consistency_level = cassandra.constants.consistency.ALL} - ) + rows, err = Migrations.super.execute(self, self.queries.get_migrations, {identifier}) else - rows, err = Migrations.super._execute(self, - self.queries.get_all_migrations, - nil, - {consistency_level = cassandra.constants.consistency.ALL} - ) + rows, err = Migrations.super.execute(self, self.queries.get_all_migrations) end if err and stringy.find(err.message, "unconfigured columnfamily schema_migrations") ~= nil then @@ -93,22 +77,18 @@ end -- @return query result -- @return error if any function Migrations:delete_migration(migration_name, identifier) - return Migrations.super._execute(self, - self.queries.delete_migration, - {cassandra.list({migration_name}), identifier}, - {consistency_level = cassandra.constants.consistency.ALL} - ) + return Migrations.super.execute(self, self.queries.delete_migration, {cassandra.list({migration_name}), identifier}) end -- Drop the entire keyspace -- @param `keyspace` Name of the keyspace to drop -- @return query result function Migrations:drop_keyspace(keyspace) - return Migrations.super._execute(self, string.format("DROP keyspace \"%s\"", keyspace)) + return Migrations.super.execute(self, string.format("DROP keyspace \"%s\"", keyspace)) end function Migrations:drop() -- never drop this end -return { migrations = Migrations } +return {migrations = Migrations} diff --git a/kong/dao/cassandra/plugins.lua b/kong/dao/cassandra/plugins.lua index bf9c5acfea04..fa61085877ce 100644 --- a/kong/dao/cassandra/plugins.lua +++ b/kong/dao/cassandra/plugins.lua @@ -4,6 +4,10 @@ local constants = require "kong.constants" local BaseDao = require "kong.dao.cassandra.base_dao" local cjson = require "cjson" +local pairs = pairs +local ipairs = ipairs +local table_insert = table.insert + local Plugins = BaseDao:extend() function Plugins:new(properties) @@ -45,37 +49,24 @@ function Plugins:update(t, full) end function Plugins:find_distinct() - -- Open session - local session, err = Plugins.super._open_session(self) - if err then - return nil, err - end - - local select_q = query_builder.select(self._table) - - -- Execute query local distinct_names = {} - for rows, err in Plugins.super.execute(self, select_q, nil, nil, {auto_paging=true}) do + local select_q = query_builder.select(self._table) + for rows, err in self:execute(select_q, nil, {auto_paging = true}) do if err then return nil, err - end - for _, v in ipairs(rows) do - -- Rows also contains other properties, so making sure it's a plugin - if v.name then - distinct_names[v.name] = true + elseif rows ~= nil then + for _, v in ipairs(rows) do + -- Rows also contains other properties, so making sure it's a plugin + if v.name then + distinct_names[v.name] = true + end end end end - -- Close session - local socket_err = Plugins.super._close_session(self, session) - if socket_err then - return nil, socket_err - end - local result = {} for k, _ in pairs(distinct_names) do - table.insert(result, k) + table_insert(result, k) end return result, nil diff --git a/kong/dao/cassandra/query_builder.lua b/kong/dao/cassandra/query_builder.lua index ffbb9af25797..ce8ebc809f47 100644 --- a/kong/dao/cassandra/query_builder.lua +++ b/kong/dao/cassandra/query_builder.lua @@ -15,6 +15,10 @@ local function select_fragment(column_family, select_columns) return string.format("SELECT %s FROM %s", select_columns, column_family) end +local function count_fragment(column_family) + return string.format("SELECT COUNT(*) FROM %s", column_family) +end + local function insert_fragment(column_family, insert_values) local values_placeholders, columns = {}, {} for column, value in pairs(insert_values) do @@ -51,7 +55,7 @@ local function where_fragment(where_t, column_family_details, no_filtering_check assert(type(where_t) == "table", "where_t must be a table") if next(where_t) == nil then if not no_filtering_check then - return "" + return "", nil, false else error("where_t must contain keys") end @@ -103,8 +107,6 @@ local function where_fragment(where_t, column_family_details, no_filtering_check if needs_filtering then filtering = " ALLOW FILTERING" - else - needs_filtering = false end where_parts = table.concat(where_parts, " AND ") @@ -129,6 +131,21 @@ function _M.select(column_family, where_t, column_family_details, select_columns return trim(string.format("%s %s", select_str, where_str)), columns, needed_filtering end +-- Generate a COUNT query with an optional WHERE instruction. +-- If building a WHERE instruction, we need some additional informations about the column family. +-- @param `column_family` Name of the column family +-- @param `column_family_details` Additional infos about the column family (partition key, clustering key, indexes) +-- @return `query` The SELECT query +-- @return `columns` An list of columns to bind for the query, in the order of the placeholder markers (?) +-- @return `needs_filtering` A boolean indicating if ALLOW FILTERING was added to this query or not +function _M.count(column_family, where_t, column_family_details) + assert(type(column_family) == "string", "column_family must be a string") + + local count_str = count_fragment(column_family) + local where_str, columns, needed_filtering = where_fragment(where_t, column_family_details) + return trim(string.format("%s %s", count_str, where_str)), columns, needed_filtering +end + -- Generate an INSERT query. -- @param `column_family` Name of the column family -- @param `insert_values` A columns/values table of values to insert diff --git a/kong/dao/cassandra/schema/migrations.lua b/kong/dao/cassandra/schema/migrations.lua index 0433f2752bc2..a2b009ed4fad 100644 --- a/kong/dao/cassandra/schema/migrations.lua +++ b/kong/dao/cassandra/schema/migrations.lua @@ -1,20 +1,44 @@ local Migrations = { - -- skeleton { init = true, name = "2015-01-12-175310_skeleton", up = function(options, dao_factory) - return dao_factory:execute_queries([[ - CREATE KEYSPACE IF NOT EXISTS "]]..options.keyspace..[[" - WITH REPLICATION = {'class' : 'SimpleStrategy', 'replication_factor' : 1}; + local keyspace_name = options.keyspace + local strategy, strategy_properties = options.replication_strategy, "" - USE "]]..options.keyspace..[["; + -- Format strategy options + if strategy == "SimpleStrategy" then + strategy_properties = string.format(", 'replication_factor': %s", options.replication_factor) + elseif strategy == "NetworkTopologyStrategy" then + local dcs = {} + for dc_name, dc_repl in pairs(options.data_centers) do + table.insert(dcs, string.format("'%s': %s", dc_name, dc_repl)) + end + if #dcs > 0 then + strategy_properties = string.format(", %s", table.concat(dcs, ", ")) + end + else + -- Strategy unknown + return "invalid replication_strategy class" + end + -- Format final keyspace creation query + local keyspace_str = string.format([[ + CREATE KEYSPACE IF NOT EXISTS "%s" + WITH REPLICATION = {'class': '%s'%s}; + ]], keyspace_name, strategy, strategy_properties) + + local err = dao_factory:execute_queries(keyspace_str, true) + if err then + return err + end + + return dao_factory:execute_queries [[ CREATE TABLE IF NOT EXISTS schema_migrations( id text PRIMARY KEY, migrations list ); - ]], true) + ]] end, down = function(options, dao_factory) return dao_factory:execute_queries [[ diff --git a/kong/dao/schemas/plugins.lua b/kong/dao/schemas/plugins.lua index b292955d4113..1149f9e224dc 100644 --- a/kong/dao/schemas/plugins.lua +++ b/kong/dao/schemas/plugins.lua @@ -19,33 +19,40 @@ return { clustering_key = {"name"}, fields = { id = { - type = "id", - dao_insert_value = true }, + type = "id", + dao_insert_value = true + }, created_at = { - type = "timestamp", - dao_insert_value = true }, + type = "timestamp", + dao_insert_value = true + }, api_id = { - type = "id", - required = true, - foreign = "apis:id", - queryable = true }, + type = "id", + required = true, + foreign = "apis:id", + queryable = true + }, consumer_id = { - type = "id", - foreign = "consumers:id", - queryable = true, - default = constants.DATABASE_NULL_ID }, + type = "id", + foreign = "consumers:id", + queryable = true, + default = constants.DATABASE_NULL_ID + }, name = { - type = "string", - required = true, - immutable = true, - queryable = true }, + type = "string", + required = true, + immutable = true, + queryable = true + }, config = { - type = "table", - schema = load_config_schema, - default = {} }, + type = "table", + schema = load_config_schema, + default = {} + }, enabled = { - type = "boolean", - default = true } + type = "boolean", + default = true + } }, self_check = function(self, plugin_t, dao, is_update) -- Load the config schema diff --git a/kong/kong.lua b/kong/kong.lua index 6aa745ad28e6..7bcb0c4af7d9 100644 --- a/kong/kong.lua +++ b/kong/kong.lua @@ -24,54 +24,28 @@ -- |[[ ]]| -- ========== -local IO = require "kong.tools.io" +local core = require "kong.core.handler" local utils = require "kong.tools.utils" -local cache = require "kong.tools.database_cache" -local stringy = require "stringy" -local constants = require "kong.constants" -local responses = require "kong.tools.responses" - --- Define the plugins to load here, in the appropriate order -local plugins = {} - -local _M = {} - -local function get_now() - return ngx.now() * 1000 -end - -local function load_plugin(api_id, consumer_id, plugin_name) - local cache_key = cache.plugin_key(plugin_name, api_id, consumer_id) - - local plugin = cache.get_or_set(cache_key, function() - local rows, err = dao.plugins:find_by_keys { - api_id = api_id, - consumer_id = consumer_id ~= nil and consumer_id or constants.DATABASE_NULL_ID, - name = plugin_name - } - if err then - return responses.send_HTTP_INTERNAL_SERVER_ERROR(err) - end - - if #rows > 0 then - return table.remove(rows, 1) - else - return { null = true } - end - end) - - if plugin and not plugin.null and plugin.enabled then - return plugin - else - return nil - end -end - -local function init_plugins() - -- TODO: this should be handled with other default configs - configuration.plugins_available = configuration.plugins_available or {} - - print("Discovering used plugins") +local dao_loader = require "kong.tools.dao_loader" +local config_loader = require "kong.tools.config_loader" +local plugins_iterator = require "kong.core.plugins_iterator" + +local ipairs = ipairs +local table_insert = table.insert +local table_sort = table.sort + +local loaded_plugins = {} +-- @TODO make those locals too +-- local configuration +-- local dao_factory + +--- Load enabled plugins on the node. +-- Get plugins in the DB (distinct by `name`), compare them with plugins +-- in kong.yml's `plugins_available`. If both lists match, return a list +-- of plugins sorted by execution priority for lua-nginx-module's context handlers. +-- @treturn table Array of plugins to execute in context handlers. +local function load_node_plugins(configuration) + ngx.log(ngx.DEBUG, "Discovering used plugins") local db_plugins, err = dao.plugins:find_distinct() if err then error(err) @@ -84,170 +58,111 @@ local function init_plugins() end end - local loaded_plugins = {} + local sorted_plugins = {} for _, v in ipairs(configuration.plugins_available) do local loaded, plugin_handler_mod = utils.load_module_if_exists("kong.plugins."..v..".handler") if not loaded then error("The following plugin has been enabled in the configuration but it is not installed on the system: "..v) else - print("Loading plugin: "..v) - table.insert(loaded_plugins, { + ngx.log(ngx.DEBUG, "Loading plugin: "..v) + table_insert(sorted_plugins, { name = v, handler = plugin_handler_mod() }) end end - table.sort(loaded_plugins, function(a, b) + table_sort(sorted_plugins, function(a, b) local priority_a = a.handler.PRIORITY or 0 local priority_b = b.handler.PRIORITY or 0 return priority_a > priority_b end) - -- resolver is always the first plugin as it is the one retrieving any needed information - table.insert(loaded_plugins, 1, { - resolver = true, - name = "resolver", - handler = require("kong.resolver.handler")() - }) - if configuration.send_anonymous_reports then - table.insert(loaded_plugins, 1, { - reports = true, + table_insert(sorted_plugins, 1, { name = "reports", - handler = require("kong.reports.handler")() + handler = require("kong.core.reports") }) end - return loaded_plugins + return sorted_plugins end --- To be called by nginx's init_by_lua directive. +--- Kong public context handlers. +-- @section kong_handlers + +local Kong = {} + +--- Init Kong's environment in the Nginx master process. +-- To be called by the lua-nginx-module `init_by_lua` directive. -- Execution: -- - load the configuration from the path computed by the CLI -- - instanciate the DAO Factory -- - load the used plugins -- - load all plugins if used and installed -- - sort the plugins by priority --- - load the resolver -- --- If any error during the initialization of the DAO or plugins, --- it will be thrown and needs to be catched in init_by_lua. -function _M.init() - -- Loading configuration - configuration, dao = IO.load_configuration_and_dao(os.getenv("KONG_CONF")) - - -- Initializing plugins - plugins = init_plugins() - +-- If any error happens during the initialization of the DAO or plugins, +-- it will be thrown and needs to be catched in `init_by_lua`. +function Kong.init() + configuration = config_loader.load(os.getenv("KONG_CONF")) + dao = dao_loader.load(configuration, true) + loaded_plugins = load_node_plugins(configuration) + process_id = utils.random_string() ngx.update_time() end --- Calls `init_worker()` on eveyr loaded plugin -function _M.exec_plugins_init_worker() - for _, plugin_t in ipairs(plugins) do - plugin_t.handler:init_worker() +function Kong.exec_plugins_init_worker() + core.init_worker() + + for _, plugin in ipairs(loaded_plugins) do + plugin.handler:init_worker() end end -function _M.exec_plugins_certificate() - ngx.ctx.plugin = {} - - for _, plugin_t in ipairs(plugins) do - if ngx.ctx.api then - ngx.ctx.plugin[plugin_t.name] = load_plugin(ngx.ctx.api.id, nil, plugin_t.name) - end +function Kong.exec_plugins_certificate() + core.certificate() - local plugin = ngx.ctx.plugin[plugin_t.name] - if not ngx.ctx.stop_phases and (plugin_t.resolver or plugin) then - plugin_t.handler:certificate(plugin and plugin.config or {}) - end + for plugin, plugin_conf in plugins_iterator(loaded_plugins, "certificate") do + plugin.handler:certificate(plugin_conf) end - - return end --- Calls `access()` on every loaded plugin -function _M.exec_plugins_access() - local start = get_now() - ngx.ctx.plugin = {} - - -- Iterate over all the plugins - for _, plugin_t in ipairs(plugins) do - if ngx.ctx.api then - ngx.ctx.plugin[plugin_t.name] = load_plugin(ngx.ctx.api.id, nil, plugin_t.name) - local consumer_id = ngx.ctx.authenticated_credential and ngx.ctx.authenticated_credential.consumer_id or nil - if consumer_id then - local app_plugin = load_plugin(ngx.ctx.api.id, consumer_id, plugin_t.name) - if app_plugin then - ngx.ctx.plugin[plugin_t.name] = app_plugin - end - end - end +function Kong.exec_plugins_access() + core.access.before() - local plugin = ngx.ctx.plugin[plugin_t.name] - if not ngx.ctx.stop_phases and (plugin_t.resolver or plugin) then - plugin_t.handler:access(plugin and plugin.config or {}) - end - end - -- Append any modified querystring parameters - local parts = stringy.split(ngx.var.backend_url, "?") - local final_url = parts[1] - if utils.table_size(ngx.req.get_uri_args()) > 0 then - final_url = final_url.."?"..ngx.encode_args(ngx.req.get_uri_args()) + for plugin, plugin_conf in plugins_iterator(loaded_plugins, "access") do + plugin.handler:access(plugin_conf) end - ngx.var.backend_url = final_url - local t_end = get_now() - ngx.ctx.kong_processing_access = t_end - start - -- Setting a property that will be available for every plugin - ngx.ctx.proxy_started_at = t_end + core.access.after() end --- Calls `header_filter()` on every loaded plugin -function _M.exec_plugins_header_filter() - local start = get_now() - -- Setting a property that will be available for every plugin - ngx.ctx.proxy_ended_at = start +function Kong.exec_plugins_header_filter() + core.header_filter.before() - if not ngx.ctx.stop_phases then - ngx.header["Via"] = constants.NAME.."/"..constants.VERSION - - for _, plugin_t in ipairs(plugins) do - local plugin = ngx.ctx.plugin[plugin_t.name] - if plugin then - plugin_t.handler:header_filter(plugin and plugin.config or {}) - end - end + for plugin, plugin_conf in plugins_iterator(loaded_plugins, "header_filter") do + plugin.handler:header_filter(plugin_conf) end - ngx.ctx.kong_processing_header_filter = get_now() - start + + core.header_filter.after() end --- Calls `body_filter()` on every loaded plugin -function _M.exec_plugins_body_filter() - local start = get_now() - if not ngx.ctx.stop_phases then - for _, plugin_t in ipairs(plugins) do - local plugin = ngx.ctx.plugin[plugin_t.name] - if plugin then - plugin_t.handler:body_filter(plugin and plugin.config or {}) - end - end +function Kong.exec_plugins_body_filter() + for plugin, plugin_conf in plugins_iterator(loaded_plugins, "body_filter") do + plugin.handler:body_filter(plugin_conf) end - ngx.ctx.kong_processing_body_filter = (ngx.ctx.kong_processing_body_filter or 0) + (get_now() - start) + + core.body_filter.after() end --- Calls `log()` on every loaded plugin -function _M.exec_plugins_log() - if not ngx.ctx.stop_phases then - for _, plugin_t in ipairs(plugins) do - local plugin = ngx.ctx.plugin[plugin_t.name] - if plugin or plugin_t.reports then - plugin_t.handler:log(plugin and plugin.config or {}) - end - end +function Kong.exec_plugins_log() + for plugin, plugin_conf in plugins_iterator(loaded_plugins, "log") do + plugin.handler:log(plugin_conf) end + + core.log() end -return _M +return Kong diff --git a/kong/plugins/acl/access.lua b/kong/plugins/acl/access.lua index a21922be4305..6950e8139696 100644 --- a/kong/plugins/acl/access.lua +++ b/kong/plugins/acl/access.lua @@ -50,7 +50,6 @@ function _M.execute(conf) end if block then - ngx.ctx.stop_phases = true -- interrupt other phases of this request return responses.send_HTTP_FORBIDDEN("You cannot consume this service") end diff --git a/kong/plugins/basic-auth/access.lua b/kong/plugins/basic-auth/access.lua index e9f6c189f644..4d272547f178 100644 --- a/kong/plugins/basic-auth/access.lua +++ b/kong/plugins/basic-auth/access.lua @@ -84,7 +84,6 @@ end function _M.execute(conf) -- If both headers are missing, return 401 if not (ngx.req.get_headers()[AUTHORIZATION] or ngx.req.get_headers()[PROXY_AUTHORIZATION]) then - ngx.ctx.stop_phases = true ngx.header["WWW-Authenticate"] = "Basic realm=\""..constants.NAME.."\"" return responses.send_HTTP_UNAUTHORIZED() end @@ -102,7 +101,6 @@ function _M.execute(conf) end if not credential or not validate_credentials(credential, given_password) then - ngx.ctx.stop_phases = true -- interrupt other phases of this request return responses.send_HTTP_FORBIDDEN("Invalid authentication credentials") end diff --git a/kong/plugins/basic-auth/crypto.lua b/kong/plugins/basic-auth/crypto.lua index 99f665bdf410..d1aaf466c9d3 100644 --- a/kong/plugins/basic-auth/crypto.lua +++ b/kong/plugins/basic-auth/crypto.lua @@ -1,4 +1,3 @@ ---- -- Module to encrypt the basic-auth credentials password field local crypto = require "crypto" diff --git a/kong/plugins/file-log/log.lua b/kong/plugins/file-log/log.lua index 1e1f546d30bd..1249f6d124fe 100644 --- a/kong/plugins/file-log/log.lua +++ b/kong/plugins/file-log/log.lua @@ -1,25 +1,17 @@ -- Copyright (C) Mashape, Inc. local ffi = require "ffi" -local bit = require "bit" local cjson = require "cjson" local fd_util = require "kong.plugins.file-log.fd_util" +local system_constants = require "lua_system_constants" local basic_serializer = require "kong.plugins.log-serializers.basic" ffi.cdef[[ int open(char * filename, int flags, int mode); int write(int fd, void * ptr, int numbytes); -]] - -local octal = function(n) return tonumber(n, 8) end -local O_CREAT = octal('0100') -local O_APPEND = octal('02000') -local O_WRONLY = octal('0001') - -local S_IWUSR = octal('00200') -local S_IRUSR = octal('00400') -local S_IXUSR = octal('00100') +char *strerror(int errnum); +]] local function string_to_char(str) return ffi.cast("uint8_t*", str) @@ -34,8 +26,15 @@ local function log(premature, conf, message) local fd = fd_util.get_fd(conf.path) if not fd then - fd = ffi.C.open(string_to_char(conf.path), bit.bor(O_CREAT, O_APPEND, O_WRONLY), bit.bor(S_IWUSR, S_IRUSR, S_IXUSR)) - fd_util.set_fd(conf.path, fd) + fd = ffi.C.open(string_to_char(conf.path), + bit.bor(system_constants.O_WRONLY(), system_constants.O_CREAT(), system_constants.O_APPEND()), + bit.bor(system_constants.S_IWUSR(), system_constants.S_IRUSR(), system_constants.S_IXUSR())) + if fd < 0 then + local errno = ffi.errno() + ngx.log(ngx.ERR, "[file-log] failed to open the file: ", ffi.string(ffi.C.strerror(errno))) + else + fd_util.set_fd(conf.path, fd) + end end ffi.C.write(fd, string_to_char(message), string.len(message)) diff --git a/kong/plugins/hmac-auth/access.lua b/kong/plugins/hmac-auth/access.lua index 22e1696dd644..5bac6ec459dc 100644 --- a/kong/plugins/hmac-auth/access.lua +++ b/kong/plugins/hmac-auth/access.lua @@ -145,7 +145,6 @@ function _M.execute(conf) local headers = ngx_set_headers(); -- If both headers are missing, return 401 if not (headers[AUTHORIZATION] or headers[PROXY_AUTHORIZATION]) then - ngx.ctx.stop_phases = true return responses.send_HTTP_UNAUTHORIZED() end @@ -171,7 +170,6 @@ function _M.execute(conf) end hmac_params.secret = credential.secret if not validate_signature(ngx.req, hmac_params, headers) then - ngx.ctx.stop_phases = true -- interrupt other phases of this request return responses.send_HTTP_FORBIDDEN("HMAC signature does not match") end diff --git a/kong/plugins/ip-restriction/access.lua b/kong/plugins/ip-restriction/access.lua index 6fd0990c2535..ad43db1da382 100644 --- a/kong/plugins/ip-restriction/access.lua +++ b/kong/plugins/ip-restriction/access.lua @@ -1,4 +1,4 @@ -local iputils = require "resty.iputils" +local iputils = require "resty.iputils" local responses = require "kong.tools.responses" local utils = require "kong.tools.utils" @@ -22,7 +22,6 @@ function _M.execute(conf) end if block then - ngx.ctx.stop_phases = true -- interrupt other phases of this request return responses.send_HTTP_FORBIDDEN("Your IP address is not allowed") end end diff --git a/kong/plugins/jwt/access.lua b/kong/plugins/jwt/access.lua index 23500b706f5e..9921a9018880 100644 --- a/kong/plugins/jwt/access.lua +++ b/kong/plugins/jwt/access.lua @@ -53,7 +53,6 @@ function _M.execute(conf) end if not token then - ngx.ctx.stop_phases = true return responses.send_HTTP_UNAUTHORIZED() end @@ -67,7 +66,6 @@ function _M.execute(conf) local jwt_secret_key = claims.iss if not jwt_secret_key then - ngx.ctx.stop_phases = true return responses.send_HTTP_UNAUTHORIZED("No mandatory 'iss' in claims") end @@ -82,20 +80,17 @@ function _M.execute(conf) end) if not jwt_secret then - ngx.ctx.stop_phases = true return responses.send_HTTP_FORBIDDEN("No credentials found for given 'iss'") end -- Now verify the JWT signature if not jwt:verify_signature(jwt_secret.secret) then - ngx.ctx.stop_phases = true return responses.send_HTTP_FORBIDDEN("Invalid signature") end -- Verify the JWT registered claims local ok_claims, errors = jwt:verify_registered_claims(conf.claims_to_verify) if not ok_claims then - ngx.ctx.stop_phases = true return responses.send_HTTP_FORBIDDEN(errors) end @@ -110,7 +105,6 @@ function _M.execute(conf) -- However this should not happen if not consumer then - ngx.ctx.stop_phases = true return responses.send_HTTP_FORBIDDEN(string_format("Could not find consumer for '%s=%s'", "iss", jwt_secret_key)) end diff --git a/kong/plugins/jwt/api.lua b/kong/plugins/jwt/api.lua index e0bb3ea6075b..df5dccc5f4ea 100644 --- a/kong/plugins/jwt/api.lua +++ b/kong/plugins/jwt/api.lua @@ -40,6 +40,10 @@ return { return helpers.responses.send_HTTP_OK(self.credential) end, + PATCH = function(self, dao_factory) + crud.patch(self.params, self.credential, dao_factory.jwt_secrets) + end, + DELETE = function(self, dao_factory) crud.delete(self.credential, dao_factory.jwt_secrets) end diff --git a/kong/plugins/jwt/jwt_parser.lua b/kong/plugins/jwt/jwt_parser.lua index b1aab47b3857..19e2a12f3d08 100644 --- a/kong/plugins/jwt/jwt_parser.lua +++ b/kong/plugins/jwt/jwt_parser.lua @@ -1,6 +1,5 @@ ---- JWT verification module --- --- Adapted version of x25/luajwt for Kong. It provide various improvements and +-- JWT verification module +-- Adapted version of x25/luajwt for Kong. It provides various improvements and -- an OOP architecture allowing the JWT to be parsed and verified separatly, -- avoiding multiple parsings. -- @@ -23,7 +22,7 @@ local setmetatable = setmetatable local alg_sign = { ["HS256"] = function(data, key) return crypto.hmac.digest("sha256", data, key, true) end --["HS384"] = function(data, key) return crypto.hmac.digest("sha384", data, key, true) end, - --["HS512"] = function(data, key) return crypto.hmac.digest("sha512", data, key, true) end, + --["HS512"] = function(data, key) return crypto.hmac.digest("sha512", data, key, true) end } --- Supported algorithms for verifying tokens. @@ -31,7 +30,7 @@ local alg_sign = { local alg_verify = { ["HS256"] = function(data, signature, key) return signature == alg_sign["HS256"](data, key) end --["HS384"] = function(data, signature, key) return signature == alg_sign["HS384"](data, key) end, - --["HS512"] = function(data, signature, key) return signature == alg_sign["HS512"](data, key) end, + --["HS512"] = function(data, signature, key) return signature == alg_sign["HS512"](data, key) end } --- base 64 encoding diff --git a/kong/plugins/key-auth/access.lua b/kong/plugins/key-auth/access.lua index 3c8b3367821b..9c634162e037 100644 --- a/kong/plugins/key-auth/access.lua +++ b/kong/plugins/key-auth/access.lua @@ -71,14 +71,12 @@ function _M.execute(conf) -- No key found in the request's headers or parameters if not key_found then - ngx.ctx.stop_phases = true ngx.header["WWW-Authenticate"] = "Key realm=\""..constants.NAME.."\"" return responses.send_HTTP_UNAUTHORIZED("No API Key found in headers, body or querystring") end -- No key found in the DB, this credential is invalid if not credential then - ngx.ctx.stop_phases = true -- interrupt other phases of this request return responses.send_HTTP_FORBIDDEN("Invalid authentication credentials") end diff --git a/kong/plugins/log-serializers/alf.lua b/kong/plugins/log-serializers/alf.lua index 92c36e44e8aa..88f326839797 100644 --- a/kong/plugins/log-serializers/alf.lua +++ b/kong/plugins/log-serializers/alf.lua @@ -1,7 +1,7 @@ -- ALF serializer module. -- ALF is the format supported by Mashape Analytics (http://apianalytics.com) -- --- This module represents _one_ ALF, zhich has _one_ ALF entry. +-- This module represents _one_ ALF, which has _one_ ALF entry. -- It used to be a representation of one ALF with several entries, but ALF -- had its `clientIPAddress` moved to the root level of ALF, hence breaking -- this implementation. @@ -15,8 +15,6 @@ -- - Nginx lua module documentation: http://wiki.nginx.org/HttpLuaModule -- - ngx_http_core_module: http://wiki.nginx.org/HttpCoreModule#.24http_HEADER -local stringy = require "stringy" - local type = type local pairs = pairs local ipairs = ipairs @@ -60,6 +58,26 @@ local function dic_to_array(hash, fn) end end +--- Get a header from nginx's headers +-- Make sure that is multiple headers of a same name are present, +-- we only want the last one. Also include a default value if +-- no header is present. +-- @param `headers` ngx's request or response headers table. +-- @param `name` Name of the desired header to retrieve. +-- @param `default` String returned in case no header is found. +-- @return `header` The header value (a string) or the default, or nil. +local function get_header(headers, name, default) + local val = headers[name] + if val ~= nil then + if type(val) == "table" then + val = val[#val] + end + return val + end + + return default +end + local _M = {} -- Serialize `ngx` into one ALF entry. @@ -81,29 +99,10 @@ function _M.serialize_entry(ngx) local alf_base64_res_body = ngx_encode_base64(alf_res_body) -- timers - local proxy_started_at, proxy_ended_at = ngx.ctx.proxy_started_at, ngx.ctx.proxy_ended_at - - local alf_started_at = ngx.req.start_time() - - -- First byte sent to upstream - first byte received from client - local alf_send_time = proxy_started_at - alf_started_at * 1000 - - -- Time waiting for the upstream response - local upstream_response_time = 0 - local upstream_response_times = ngx.var.upstream_response_time - if not upstream_response_times or upstream_response_times == "-" then - -- client aborted the request - return - end - - upstream_response_times = stringy.split(upstream_response_times, ", ") - for _, val in ipairs(upstream_response_times) do - upstream_response_time = upstream_response_time + val - end - local alf_wait_time = upstream_response_time * 1000 - - -- upstream response fully received - upstream response 1 byte received - local alf_receive_time = analytics_data.response_received and analytics_data.response_received - proxy_ended_at or -1 + -- @see core.handler for their definition + local alf_send_time = ngx.ctx.KONG_PROXY_LATENCY or -1 + local alf_wait_time = ngx.ctx.KONG_WAITING_TIME or -1 + local alf_receive_time = ngx.ctx.KONG_RECEIVE_TIME or -1 -- Compute the total time. If some properties were unavailable -- (because the proxying was aborted), then don't add the value. @@ -125,11 +124,11 @@ function _M.serialize_entry(ngx) local alf_res_headers_size = string_len(res_headers_str) -- mimeType, defaulting to "application/octet-stream" - local alf_req_mimeType = req_headers["Content-Type"] and req_headers["Content-Type"] or "application/octet-stream" - local alf_res_mimeType = res_headers["Content-Type"] and res_headers["Content-Type"] or "application/octet-stream" + local alf_req_mimeType = get_header(req_headers, "Content-Type", "application/octet-stream") + local alf_res_mimeType = get_header(res_headers, "Content-Type", "application/octet-stream") return { - startedDateTime = os_date("!%Y-%m-%dT%TZ", alf_started_at), + startedDateTime = os_date("!%Y-%m-%dT%TZ", ngx.req.start_time()), time = alf_time, request = { method = ngx.req.get_method(), diff --git a/kong/plugins/log-serializers/basic.lua b/kong/plugins/log-serializers/basic.lua index 6571c9dd9eac..d3e611b4a7b3 100644 --- a/kong/plugins/log-serializers/basic.lua +++ b/kong/plugins/log-serializers/basic.lua @@ -24,10 +24,9 @@ function _M.serialize(ngx) size = ngx.var.bytes_sent }, latencies = { - kong = (ngx.ctx.kong_processing_access or 0) + - (ngx.ctx.kong_processing_header_filter or 0) + - (ngx.ctx.kong_processing_body_filter or 0), - proxy = ngx.var.upstream_response_time * 1000, + kong = (ngx.ctx.KONG_ACCESS_TIME or 0) + + (ngx.ctx.KONG_RECEIVE_TIME or 0), + proxy = ngx.ctx.KONG_WAITING_TIME or -1, request = ngx.var.request_time * 1000 }, authenticated_entity = authenticated_entity, diff --git a/kong/plugins/loggly/handler.lua b/kong/plugins/loggly/handler.lua new file mode 100644 index 000000000000..3827061ee7bc --- /dev/null +++ b/kong/plugins/loggly/handler.lua @@ -0,0 +1,20 @@ +local log = require "kong.plugins.loggly.log" +local BasePlugin = require "kong.plugins.base_plugin" +local basic_serializer = require "kong.plugins.log-serializers.basic" + +local LogglyLogHandler = BasePlugin:extend() + +function LogglyLogHandler:new() + LogglyLogHandler.super.new(self, "loggly") +end + +function LogglyLogHandler:log(conf) + LogglyLogHandler.super.log(self) + + local message = basic_serializer.serialize(ngx) + log.execute(conf, message) +end + +LogglyLogHandler.PRIORITY = 1 + +return LogglyLogHandler diff --git a/kong/plugins/loggly/log.lua b/kong/plugins/loggly/log.lua new file mode 100644 index 000000000000..3f7ec83252a7 --- /dev/null +++ b/kong/plugins/loggly/log.lua @@ -0,0 +1,104 @@ +local cjson = require "cjson" + +local os_date = os.date +local tostring = tostring +local ngx_log = ngx.log +local ngx_timer_at = ngx.timer.at +local ngx_socket_udp = ngx.socket.udp +local table_concat = table.concat +local table_insert = table.insert + +local _M = {} + +local function getHostname() + local f = io.popen ("/bin/hostname") + local hostname = f:read("*a") or "" + f:close() + hostname = string.gsub(hostname, "\n$", "") + return hostname +end + +local HOSTNAME = getHostname() +local SENDER_NAME = "kong" + +local LOG_LEVELS = { + debug = 7, + info = 6, + notice = 5, + warning = 4, + err = 3, + crit = 2, + alert = 1, + emerg = 0 +} + +local function merge(conf, message, pri) + local tags_list = conf.tags + local tags = {} + for i = 1, #tags_list do + table_insert(tags, "tag=".."\""..tags_list[i].."\"") + end + + local udp_message = { + "<"..pri..">1", + os_date("!%Y-%m-%dT%XZ"), + HOSTNAME, + SENDER_NAME, + "-", + "-", + "["..conf.key.."@41058", table_concat(tags, " ").."]", + cjson.encode(message) + } + return table_concat(udp_message, " ") +end + +local function send_to_loggly(conf, message, pri) + local host = conf.host + local port = conf.port + local timeout = conf.timeout + local udp_message = merge(conf, message, pri) + local sock = ngx_socket_udp() + sock:settimeout(timeout) + + local ok, err = sock:setpeername(host, port) + if not ok then + ngx_log(ngx.ERR, "failed to connect to "..host..":"..tostring(port)..": ", err) + return + end + local ok, err = sock:send(udp_message) + if not ok then + ngx_log(ngx.ERR, "failed to send data to ".. host..":"..tostring(port)..": ", err) + end + + local ok, err = sock:close() + if not ok then + ngx_log(ngx.ERR, "failed to close connection from "..host..":"..tostring(port)..": ", err) + return + end +end + +local function decide_severity(conf, severity, message) + if LOG_LEVELS[severity] <= LOG_LEVELS[conf.log_level] then + local pri = 8 + LOG_LEVELS[severity] + return send_to_loggly(conf, message, pri) + end +end + +local function log(premature, conf, message) + if message.response.status >= 500 then + return decide_severity(conf.log_level, conf.server_errors_severity, message) + elseif message.response.status >= 400 then + return decide_severity(conf.log_level, conf.client_errors_severity, message) + else + return decide_severity(conf, conf.successful_severity, message) + end +end + +function _M.execute(conf, message) + local ok, err = ngx_timer_at(0, log, conf, message) + if not ok then + ngx_log(ngx.ERR, "failed to create timer: ", err) + end +end + +return _M diff --git a/kong/plugins/loggly/schema.lua b/kong/plugins/loggly/schema.lua new file mode 100644 index 000000000000..fc069d58b45f --- /dev/null +++ b/kong/plugins/loggly/schema.lua @@ -0,0 +1,15 @@ +local ALLOWED_LEVELS = { "debug", "info", "notice", "warning", "err", "crit", "alert", "emerg" } + +return { + fields = { + host = { type = "string", default = "logs-01.loggly.com" }, + port = { type = "number", default = 514 }, + key = { required = true, type = "string"}, + tags = {type = "array", default = { "kong" }}, + log_level = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + successful_severity = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + client_errors_severity = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + server_errors_severity = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + timeout = { type = "number", default = 10000 } + } +} diff --git a/kong/plugins/mashape-analytics/handler.lua b/kong/plugins/mashape-analytics/handler.lua index b51f462fb4bc..e64063a628b9 100644 --- a/kong/plugins/mashape-analytics/handler.lua +++ b/kong/plugins/mashape-analytics/handler.lua @@ -18,7 +18,6 @@ local ALFBuffer = require "kong.plugins.mashape-analytics.buffer" local BasePlugin = require "kong.plugins.base_plugin" local ALFSerializer = require "kong.plugins.log-serializers.alf" -local ngx_now = ngx.now local ngx_log = ngx.log local ngx_log_ERR = ngx.ERR local string_find = string.find @@ -70,15 +69,11 @@ end function AnalyticsHandler:body_filter(conf) AnalyticsHandler.super.body_filter(self) - local chunk, eof = ngx.arg[1], ngx.arg[2] + local chunk = ngx.arg[1] -- concatenate response chunks for ALF's `response.content.text` if conf.log_body then ngx.ctx.analytics.res_body = ngx.ctx.analytics.res_body..chunk end - - if eof then -- latest chunk - ngx.ctx.analytics.response_received = ngx_now() * 1000 - end end function AnalyticsHandler:log(conf) diff --git a/kong/plugins/oauth2/access.lua b/kong/plugins/oauth2/access.lua index 336c861fb73b..3f6ad610ccdf 100644 --- a/kong/plugins/oauth2/access.lua +++ b/kong/plugins/oauth2/access.lua @@ -26,8 +26,9 @@ local GRANT_PASSWORD = "password" local ERROR = "error" local AUTHENTICATED_USERID = "authenticated_userid" -local AUTHORIZE_URL = "^%s/oauth2/authorize/?$" -local TOKEN_URL = "^%s/oauth2/token/?$" + +local AUTHORIZE_URL = "^%s/oauth2/authorize(/?(\\?[^\\s]*)?)$" +local TOKEN_URL = "^%s/oauth2/token(/?(\\?[^\\s]*)?)$" -- TODO: Expire token (using TTL ?) local function generate_token(conf, credential, authenticated_userid, scope, state, expiration, disable_refresh) @@ -75,10 +76,15 @@ local function get_redirect_uri(client_id) return client and client.redirect_uri or nil, client end -local function is_https() - local forwarded_proto_header = ngx.req.get_headers()["x-forwarded-proto"] +local HTTPS = "https" - return ngx.var.scheme:lower() == "https" or (forwarded_proto_header and forwarded_proto_header:lower() == "https") +local function is_https(conf) + local result = ngx.var.scheme:lower() == HTTPS + if not result and conf.accept_http_if_already_terminated then + local forwarded_proto_header = ngx.req.get_headers()["x-forwarded-proto"] + result = forwarded_proto_header and forwarded_proto_header:lower() == HTTPS + end + return result end local function retrieve_parameters() @@ -91,7 +97,7 @@ local function retrieve_scopes(parameters, conf) local scope = parameters[SCOPE] local scopes = {} if conf.scopes and scope then - for v in scope:gmatch("%w+") do + for v in scope:gmatch("%S+") do if not utils.table_contains(conf.scopes, v) then return false, {[ERROR] = "invalid_scope", error_description = "\""..v.."\" is an invalid "..SCOPE} else @@ -112,7 +118,7 @@ local function authorize(conf) local state = parameters[STATE] local redirect_uri, client - if not is_https() then + if not is_https(conf) then response_params = {[ERROR] = "access_denied", error_description = "You must use HTTPS"} else if conf.provision_key ~= parameters.provision_key then @@ -166,11 +172,8 @@ local function authorize(conf) -- Adding the state if it exists. If the state == nil then it won't be added response_params.state = state - -- Stopping other phases - ngx.ctx.stop_phases = true - -- Sending response in JSON format - responses.send(response_params[ERROR] and 400 or 200, redirect_uri and { + return responses.send(response_params[ERROR] and 400 or 200, redirect_uri and { redirect_uri = redirect_uri.."?"..ngx.encode_args(response_params) } or response_params, false, { ["cache-control"] = "no-store", @@ -216,7 +219,7 @@ local function issue_token(conf) local parameters = retrieve_parameters() --TODO: Also from authorization header local state = parameters[STATE] - if not is_https() then + if not is_https(conf) then response_params = {[ERROR] = "access_denied", error_description = "You must use HTTPS"} else local grant_type = parameters[GRANT_TYPE] @@ -294,11 +297,8 @@ local function issue_token(conf) -- Adding the state if it exists. If the state == nil then it won't be added response_params.state = state - -- Stopping other phases - ngx.ctx.stop_phases = true - -- Sending response in JSON format - responses.send(response_params[ERROR] and 400 or 200, response_params, false, { + return responses.send(response_params[ERROR] and 400 or 200, response_params, false, { ["cache-control"] = "no-store", ["pragma"] = "no-cache" }) @@ -328,7 +328,7 @@ local function parse_access_token(conf) local authorization = ngx.req.get_headers()["authorization"] if authorization then local parts = {} - for v in authorization:gmatch("%w+") do -- Split by space + for v in authorization:gmatch("%S+") do -- Split by space table.insert(parts, v) end if #parts == 2 and (parts[1]:lower() == "token" or parts[1]:lower() == "bearer") then @@ -376,18 +376,21 @@ function _M.execute(conf) end end - local token = retrieve_token(parse_access_token(conf)) + local accessToken = parse_access_token(conf); + if not accessToken then + return responses.send_HTTP_UNAUTHORIZED({}, false, {["WWW-Authenticate"] = 'Bearer realm="service"'}) + end + + local token = retrieve_token(accessToken) if not token then - ngx.ctx.stop_phases = true -- interrupt other phases of this request - return responses.send_HTTP_FORBIDDEN("Invalid authentication credentials") + return responses.send_HTTP_UNAUTHORIZED({[ERROR] = "invalid_token", error_description = "The access token is invalid"}, false, {["WWW-Authenticate"] = 'Bearer realm="service" error="invalid_token" error_description="The access token is invalid"'}) end -- Check expiration date if token.expires_in > 0 then -- zero means the token never expires local now = timestamp.get_utc() if now - token.created_at > (token.expires_in * 1000) then - ngx.ctx.stop_phases = true -- interrupt other phases of this request - return responses.send_HTTP_BAD_REQUEST({[ERROR] = "invalid_request", error_description = "access_token expired"}) + return responses.send_HTTP_UNAUTHORIZED({[ERROR] = "invalid_token", error_description = "The access token expired"}, false, {["WWW-Authenticate"] = 'Bearer realm="service" error="invalid_token" error_description="The access token expired"'}) end end diff --git a/kong/plugins/oauth2/schema.lua b/kong/plugins/oauth2/schema.lua index 8115c4a79fb2..32cd9a13af31 100644 --- a/kong/plugins/oauth2/schema.lua +++ b/kong/plugins/oauth2/schema.lua @@ -26,6 +26,7 @@ return { enable_implicit_grant = { required = true, type = "boolean", default = false }, enable_client_credentials = { required = true, type = "boolean", default = false }, enable_password_grant = { required = true, type = "boolean", default = false }, - hide_credentials = { type = "boolean", default = false } + hide_credentials = { type = "boolean", default = false }, + accept_http_if_already_terminated = { required = false, type = "boolean", default = false } } } diff --git a/kong/plugins/rate-limiting/access.lua b/kong/plugins/rate-limiting/access.lua index 5f947c3b60ef..11e99fa5cdb9 100644 --- a/kong/plugins/rate-limiting/access.lua +++ b/kong/plugins/rate-limiting/access.lua @@ -72,7 +72,6 @@ function _M.execute(conf) -- If limit is exceeded, terminate the request if stop then - ngx.ctx.stop_phases = true -- interrupt other phases of this request return responses.send(429, "API rate limit exceeded") end diff --git a/kong/plugins/rate-limiting/daos.lua b/kong/plugins/rate-limiting/daos.lua index dc3127ac7c1d..37ac937851d9 100644 --- a/kong/plugins/rate-limiting/daos.lua +++ b/kong/plugins/rate-limiting/daos.lua @@ -1,7 +1,11 @@ -local cassandra = require "cassandra" local BaseDao = require "kong.dao.cassandra.base_dao" +local cassandra = require "cassandra" local timestamp = require "kong.tools.timestamp" +local ngx_log = ngx and ngx.log or print +local ngx_err = ngx and ngx.ERR +local tostring = tostring + local RateLimitingMetrics = BaseDao:extend() function RateLimitingMetrics:new(properties) @@ -26,25 +30,37 @@ end function RateLimitingMetrics:increment(api_id, identifier, current_timestamp, value) local periods = timestamp.get_timestamps(current_timestamp) - local batch = cassandra:BatchStatement(cassandra.batch_types.COUNTER) + local options = self._factory:get_session_options() + local session, err = cassandra.spawn_session(options) + if err then + ngx_log(ngx_err, "[rate-limiting] could not spawn session to Cassandra: "..tostring(err)) + return + end + local ok = true for period, period_date in pairs(periods) do - batch:add(self.queries.increment_counter, { + local res, err = session:execute(self.queries.increment_counter, { cassandra.counter(value), cassandra.uuid(api_id), identifier, cassandra.timestamp(period_date), period }) + if not res then + ok = false + ngx_log(ngx_err, "[rate-limiting] could not increment counter for period '"..period.."': ", tostring(err)) + end end - return RateLimitingMetrics.super._execute(self, batch) + session:set_keep_alive() + + return ok end function RateLimitingMetrics:find_one(api_id, identifier, current_timestamp, period) local periods = timestamp.get_timestamps(current_timestamp) - local metric, err = RateLimitingMetrics.super._execute(self, self.queries.select_one, { + local metric, err = RateLimitingMetrics.super.execute(self, self.queries.select_one, { cassandra.uuid(api_id), identifier, cassandra.timestamp(periods[period]), @@ -86,4 +102,4 @@ function RateLimitingMetrics:find_by_keys() error("ratelimiting_metrics:find_by_keys() not supported", 2) end -return { ratelimiting_metrics = RateLimitingMetrics } +return {ratelimiting_metrics = RateLimitingMetrics} diff --git a/kong/plugins/request-transformer/access.lua b/kong/plugins/request-transformer/access.lua index d6076d753524..92382e230b85 100644 --- a/kong/plugins/request-transformer/access.lua +++ b/kong/plugins/request-transformer/access.lua @@ -1,6 +1,21 @@ -local utils = require "kong.tools.utils" local stringy = require "stringy" -local Multipart = require "multipart" +local multipart = require "multipart" + +local table_insert = table.insert +local req_set_uri_args = ngx.req.set_uri_args +local req_get_uri_args = ngx.req.get_uri_args +local req_set_header = ngx.req.set_header +local req_get_headers = ngx.req.get_headers +local req_read_body = ngx.req.read_body +local req_set_body_data = ngx.req.set_body_data +local req_get_body_data = ngx.req.get_body_data +local req_clear_header = ngx.req.clear_header +local req_get_post_args = ngx.req.get_post_args +local encode_args = ngx.encode_args +local type = type +local string_len = string.len + +local unpack = unpack local _M = {} @@ -10,120 +25,218 @@ local MULTIPART_DATA = "multipart/form-data" local CONTENT_TYPE = "content-type" local HOST = "host" -local function iterate_and_exec(val, cb) - if utils.table_size(val) > 0 then - for _, entry in ipairs(val) do - local parts = stringy.split(entry, ":") - cb(parts[1], utils.table_size(parts) == 2 and parts[2] or nil) + +local function iter(config_array) + return function(config_array, i, previous_name, previous_value) + i = i + 1 + local current_pair = config_array[i] + if current_pair == nil then -- n + 1 + return nil end - end + local current_name, current_value = unpack(stringy.split(current_pair, ":")) + return i, current_name, current_value + end, config_array, 0 end local function get_content_type() - local header_value = ngx.req.get_headers()[CONTENT_TYPE] + local header_value = req_get_headers()[CONTENT_TYPE] if header_value then return stringy.strip(header_value):lower() end end -function _M.execute(conf) - if conf.add then - - -- Add headers - if conf.add.headers then - iterate_and_exec(conf.add.headers, function(name, value) - ngx.req.set_header(name, value) - if name:lower() == HOST then -- Host header has a special treatment - ngx.var.backend_host = value - end - end) +local function append_value(current_value, value) + local current_value_type = type(current_value) + + if current_value_type == "string" then + return { current_value, value } + elseif current_value_type == "table" then + table_insert(current_value, value) + return current_value + else + return { value } + end +end + +local function transform_headers(conf) + -- Remove header(s) + for _, name, value in iter(conf.remove.headers) do + req_clear_header(name) + end + + -- Replace header(s) + for _, name, value in iter(conf.replace.headers) do + if req_get_headers()[name] then + req_set_header(name, value) + if name:lower() == HOST then -- Host header has a special treatment + ngx.var.backend_host = value + end end + end - -- Add Querystring - if conf.add.querystring then - local querystring = ngx.req.get_uri_args() - iterate_and_exec(conf.add.querystring, function(name, value) - querystring[name] = value - end) - ngx.req.set_uri_args(querystring) + -- Add header(s) + for _, name, value in iter(conf.add.headers) do + if not req_get_headers()[name] then + req_set_header(name, value) + if name:lower() == HOST then -- Host header has a special treatment + ngx.var.backend_host = value + end end + end - if conf.add.form then - local content_type = get_content_type() - if content_type and stringy.startswith(content_type, FORM_URLENCODED) then - -- Call ngx.req.read_body to read the request body first - ngx.req.read_body() + -- Append header(s) + for _, name, value in iter(conf.append.headers) do + req_set_header(name, append_value(req_get_headers()[name], value)) + if name:lower() == HOST then -- Host header has a special treatment + ngx.var.backend_host = value + end + end +end - local parameters = ngx.req.get_post_args() - iterate_and_exec(conf.add.form, function(name, value) - parameters[name] = value - end) - local encoded_args = ngx.encode_args(parameters) - ngx.req.set_header(CONTENT_LENGTH, string.len(encoded_args)) - ngx.req.set_body_data(encoded_args) - elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then - -- Call ngx.req.read_body to read the request body first - ngx.req.read_body() - - local body = ngx.req.get_body_data() - local parameters = Multipart(body and body or "", content_type) - iterate_and_exec(conf.add.form, function(name, value) - parameters:set_simple(name, value) - end) - local new_data = parameters:tostring() - ngx.req.set_header(CONTENT_LENGTH, string.len(new_data)) - ngx.req.set_body_data(new_data) - end +local function transform_querystrings(conf) + -- Remove querystring(s) + if conf.remove.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.remove.querystring) do + querystring[name] = nil end + req_set_uri_args(querystring) + end + -- Replace querystring(s) + if conf.replace.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.replace.querystring) do + if querystring[name] then + querystring[name] = value + end + end + req_set_uri_args(querystring) end - if conf.remove then + -- Add querystring(s) + if conf.add.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.add.querystring) do + if not querystring[name] then + querystring[name] = value + end + end + req_set_uri_args(querystring) + end - -- Remove headers - if conf.remove.headers then - iterate_and_exec(conf.remove.headers, function(name, value) - ngx.req.clear_header(name) - end) + -- Append querystring(s) + if conf.append.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.append.querystring) do + querystring[name] = append_value(querystring[name], value) end + req_set_uri_args(querystring) + end +end + +local function transform_form_params(conf) + -- Remove form parameter(s) + if conf.remove.form then + local content_type = get_content_type() + if content_type and stringy.startswith(content_type, FORM_URLENCODED) then + req_read_body() + local parameters = req_get_post_args() + + for _, name, value in iter(conf.remove.form) do + parameters[name] = nil + end - if conf.remove.querystring then - local querystring = ngx.req.get_uri_args() - iterate_and_exec(conf.remove.querystring, function(name) - querystring[name] = nil - end) - ngx.req.set_uri_args(querystring) + local encoded_args = encode_args(parameters) + req_set_header(CONTENT_LENGTH, string_len(encoded_args)) + req_set_body_data(encoded_args) + elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then + -- Call req_read_body to read the request body first + req_read_body() + + local body = req_get_body_data() + local parameters = multipart(body and body or "", content_type) + for _, name, value in iter(conf.remove.form) do + parameters:delete(name) + end + local new_data = parameters:tostring() + req_set_header(CONTENT_LENGTH, string_len(new_data)) + req_set_body_data(new_data) end + end + + -- Replace form parameter(s) + if conf.replace.form then + local content_type = get_content_type() + if content_type and stringy.startswith(content_type, FORM_URLENCODED) then + -- Call req_read_body to read the request body first + req_read_body() - if conf.remove.form then - local content_type = get_content_type() - if content_type and stringy.startswith(content_type, FORM_URLENCODED) then - local parameters = ngx.req.get_post_args() - - iterate_and_exec(conf.remove.form, function(name) - parameters[name] = nil - end) - - local encoded_args = ngx.encode_args(parameters) - ngx.req.set_header(CONTENT_LENGTH, string.len(encoded_args)) - ngx.req.set_body_data(encoded_args) - elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then - -- Call ngx.req.read_body to read the request body first - ngx.req.read_body() - - local body = ngx.req.get_body_data() - local parameters = Multipart(body and body or "", content_type) - iterate_and_exec(conf.remove.form, function(name) + local parameters = req_get_post_args() + for _, name, value in iter(conf.replace.form) do + if parameters[name] then + parameters[name] = value + end + end + local encoded_args = encode_args(parameters) + req_set_header(CONTENT_LENGTH, string_len(encoded_args)) + req_set_body_data(encoded_args) + elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then + -- Call req_read_body to read the request body first + req_read_body() + + local body = req_get_body_data() + local parameters = multipart(body and body or "", content_type) + for _, name, value in iter(conf.replace.form) do + if parameters:get(name) then parameters:delete(name) - end) - local new_data = parameters:tostring() - ngx.req.set_header(CONTENT_LENGTH, string.len(new_data)) - ngx.req.set_body_data(new_data) + parameters:set_simple(name, value) + end end + local new_data = parameters:tostring() + req_set_header(CONTENT_LENGTH, string_len(new_data)) + req_set_body_data(new_data) end + end + + -- Add form parameter(s) + if conf.add.form then + local content_type = get_content_type() + if content_type and stringy.startswith(content_type, FORM_URLENCODED) then + -- Call req_read_body to read the request body first + req_read_body() + local parameters = req_get_post_args() + for _, name, value in iter(conf.add.form) do + if not parameters[name] then + parameters[name] = value + end + end + local encoded_args = encode_args(parameters) + req_set_header(CONTENT_LENGTH, string_len(encoded_args)) + req_set_body_data(encoded_args) + elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then + -- Call req_read_body to read the request body first + req_read_body() + + local body = req_get_body_data() + local parameters = multipart(body and body or "", content_type) + for _, name, value in iter(conf.add.form) do + if not parameters:get(name) then + parameters:set_simple(name, value) + end + end + local new_data = parameters:tostring() + req_set_header(CONTENT_LENGTH, string_len(new_data)) + req_set_body_data(new_data) + end end +end +function _M.execute(conf) + transform_form_params(conf) + transform_headers(conf) + transform_querystrings(conf) end return _M diff --git a/kong/plugins/request-transformer/schema.lua b/kong/plugins/request-transformer/schema.lua index 00a02e109393..87bc2439d4e4 100644 --- a/kong/plugins/request-transformer/schema.lua +++ b/kong/plugins/request-transformer/schema.lua @@ -1,20 +1,41 @@ return { fields = { - add = { type = "table", - schema = { - fields = { - form = { type = "array" }, - headers = { type = "array" }, - querystring = { type = "array" } + remove = { + type = "table", + schema = { + fields = { + form = {type = "array", default = {}}, + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} + } + } + }, + replace = { + type = "table", + schema = { + fields = { + form = {type = "array", default = {}}, + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} + } + } + }, + add = { + type = "table", + schema = { + fields = { + form = {type = "array", default = {}}, + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} + } } - } }, - remove = { type = "table", + append = { + type = "table", schema = { fields = { - form = { type = "array" }, - headers = { type = "array" }, - querystring = { type = "array" } + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} } } } diff --git a/kong/plugins/response-ratelimiting/daos.lua b/kong/plugins/response-ratelimiting/daos.lua index 16fd17f48732..e8cdeb03f507 100644 --- a/kong/plugins/response-ratelimiting/daos.lua +++ b/kong/plugins/response-ratelimiting/daos.lua @@ -1,7 +1,11 @@ -local cassandra = require "cassandra" local BaseDao = require "kong.dao.cassandra.base_dao" +local cassandra = require "cassandra" local timestamp = require "kong.tools.timestamp" +local ngx_log = ngx and ngx.log or print +local ngx_err = ngx and ngx.ERR +local tostring = tostring + local ResponseRateLimitingMetrics = BaseDao:extend() function ResponseRateLimitingMetrics:new(properties) @@ -26,25 +30,37 @@ end function ResponseRateLimitingMetrics:increment(api_id, identifier, current_timestamp, value, name) local periods = timestamp.get_timestamps(current_timestamp) - local batch = cassandra:BatchStatement(cassandra.batch_types.COUNTER) + local options = self._factory:get_session_options() + local session, err = cassandra.spawn_session(options) + if err then + ngx_log(ngx_err, "[response-rate-limiting] could not spawn session to Cassandra: "..tostring(err)) + return + end + local ok = true for period, period_date in pairs(periods) do - batch:add(self.queries.increment_counter, { + local res, err = session:execute(self.queries.increment_counter, { cassandra.counter(value), cassandra.uuid(api_id), identifier, cassandra.timestamp(period_date), name.."_"..period }) + if not res then + ok = false + ngx_log(ngx_err, "[response-rate-limiting] could not increment counter for period '"..period.."': ", tostring(err)) + end end - return ResponseRateLimitingMetrics.super._execute(self, batch) + session:set_keep_alive() + + return ok end function ResponseRateLimitingMetrics:find_one(api_id, identifier, current_timestamp, period, name) local periods = timestamp.get_timestamps(current_timestamp) - local metric, err = ResponseRateLimitingMetrics.super._execute(self, self.queries.select_one, { + local metric, err = ResponseRateLimitingMetrics.super.execute(self, self.queries.select_one, { cassandra.uuid(api_id), identifier, cassandra.timestamp(periods[period]), @@ -86,4 +102,4 @@ function ResponseRateLimitingMetrics:find_by_keys() error("ratelimiting_metrics:find_by_keys() not supported", 2) end -return { response_ratelimiting_metrics = ResponseRateLimitingMetrics } +return {response_ratelimiting_metrics = ResponseRateLimitingMetrics} diff --git a/kong/plugins/response-ratelimiting/log.lua b/kong/plugins/response-ratelimiting/log.lua index cc9c04643929..638c2d95a2b4 100644 --- a/kong/plugins/response-ratelimiting/log.lua +++ b/kong/plugins/response-ratelimiting/log.lua @@ -4,7 +4,7 @@ local function increment(api_id, identifier, current_timestamp, value, name) -- Increment metrics for all periods if the request goes through local _, stmt_err = dao.response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, value, name) if stmt_err then - ngx.log(ngx.ERR, stmt_err) + ngx.log(ngx.ERR, tostring(stmt_err)) end end diff --git a/kong/plugins/ssl/access.lua b/kong/plugins/ssl/access.lua index 5f40ea5abee4..5da073be49ed 100644 --- a/kong/plugins/ssl/access.lua +++ b/kong/plugins/ssl/access.lua @@ -2,8 +2,19 @@ local responses = require "kong.tools.responses" local _M = {} +local HTTPS = "https" + +local function is_https(conf) + local result = ngx.var.scheme:lower() == HTTPS + if not result and conf.accept_http_if_already_terminated then + local forwarded_proto_header = ngx.req.get_headers()["x-forwarded-proto"] + result = forwarded_proto_header and forwarded_proto_header:lower() == HTTPS + end + return result +end + function _M.execute(conf) - if conf.only_https and ngx.var.scheme:lower() ~= "https" then + if conf.only_https and not is_https(conf) then ngx.header["connection"] = { "Upgrade" } ngx.header["upgrade"] = "TLS/1.0, HTTP/1.1" return responses.send(426, {message="Please use HTTPS protocol"}) diff --git a/kong/plugins/ssl/schema.lua b/kong/plugins/ssl/schema.lua index 11c897eacd77..7bcb0c2a3276 100644 --- a/kong/plugins/ssl/schema.lua +++ b/kong/plugins/ssl/schema.lua @@ -23,6 +23,7 @@ return { cert = { required = true, type = "string", func = validate_cert }, key = { required = true, type = "string", func = validate_key }, only_https = { required = false, type = "boolean", default = false }, + accept_http_if_already_terminated = { required = false, type = "boolean", default = false }, -- Internal use _cert_der_cache = { type = "string", immutable = true }, diff --git a/kong/plugins/syslog/handler.lua b/kong/plugins/syslog/handler.lua new file mode 100644 index 000000000000..a58b480b9584 --- /dev/null +++ b/kong/plugins/syslog/handler.lua @@ -0,0 +1,20 @@ +local log = require "kong.plugins.syslog.log" +local BasePlugin = require "kong.plugins.base_plugin" +local basic_serializer = require "kong.plugins.log-serializers.basic" + +local SysLogHandler = BasePlugin:extend() + +function SysLogHandler:new() + SysLogHandler.super.new(self, "syslog") +end + +function SysLogHandler:log(conf) + SysLogHandler.super.log(self) + + local message = basic_serializer.serialize(ngx) + log.execute(conf, message) +end + +SysLogHandler.PRIORITY = 1 + +return SysLogHandler diff --git a/kong/plugins/syslog/log.lua b/kong/plugins/syslog/log.lua new file mode 100644 index 000000000000..37ad4d7291c5 --- /dev/null +++ b/kong/plugins/syslog/log.lua @@ -0,0 +1,49 @@ +local lsyslog = require "lsyslog" +local cjson = require "cjson" + +local ngx_log = ngx.log +local ngx_timer_at = ngx.timer.at +local l_open = lsyslog.open +local l_log = lsyslog.log +local string_upper = string.upper + +local _M = {} + +local SENDER_NAME = "kong" + +local LOG_LEVELS = { + debug = 7, + info = 6, + notice = 5, + warning = 4, + err = 3, + crit = 2, + alert = 1, + emerg = 0 +} + +local function send_to_syslog(log_level, severity, message) + if LOG_LEVELS[severity] <= LOG_LEVELS[log_level] then + l_open(SENDER_NAME, lsyslog.FACILITY_USER) + l_log(lsyslog["LOG_"..string_upper(severity)], cjson.encode(message)) + end +end + +local function log(premature, conf, message) + if message.response.status >= 500 then + send_to_syslog(conf.log_level, conf.server_errors_severity, message) + elseif message.response.status >= 400 then + send_to_syslog(conf.log_level, conf.client_errors_severity, message) + else + send_to_syslog(conf.log_level, conf.successful_severity, message) + end +end + +function _M.execute(conf, message) + local ok, err = ngx_timer_at(0, log, conf, message) + if not ok then + ngx_log(ngx.ERR, "failed to create timer: ", err) + end +end + +return _M diff --git a/kong/plugins/syslog/schema.lua b/kong/plugins/syslog/schema.lua new file mode 100644 index 000000000000..92dc60ce4ac9 --- /dev/null +++ b/kong/plugins/syslog/schema.lua @@ -0,0 +1,10 @@ +local ALLOWED_LEVELS = { "debug", "info", "notice", "warning", "err", "crit", "alert", "emerg" } + +return { + fields = { + log_level = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + successful_severity = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + client_errors_severity = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + server_errors_severity = { type = "string", enum = ALLOWED_LEVELS, default = "info" }, + } +} diff --git a/kong/reports/handler.lua b/kong/reports/handler.lua deleted file mode 100644 index e931ca30daf9..000000000000 --- a/kong/reports/handler.lua +++ /dev/null @@ -1,21 +0,0 @@ -local BasePlugin = require "kong.plugins.base_plugin" -local init_worker = require "kong.reports.init_worker" -local log = require "kong.reports.log" - -local ReportsHandler = BasePlugin:extend() - -function ReportsHandler:new() - ReportsHandler.super.new(self, "reports") -end - -function ReportsHandler:init_worker() - ReportsHandler.super.init_worker(self) - init_worker.execute() -end - -function ReportsHandler:log() - ReportsHandler.super.log(self) - log.execute() -end - -return ReportsHandler diff --git a/kong/reports/log.lua b/kong/reports/log.lua deleted file mode 100644 index 284525bdc26e..000000000000 --- a/kong/reports/log.lua +++ /dev/null @@ -1,9 +0,0 @@ -local cache = require "kong.tools.database_cache" - -local _M = {} - -function _M.execute(conf) - cache.incr(cache.requests_key(), 1) -end - -return _M \ No newline at end of file diff --git a/kong/resolver/handler.lua b/kong/resolver/handler.lua deleted file mode 100644 index 85ffd046e694..000000000000 --- a/kong/resolver/handler.lua +++ /dev/null @@ -1,35 +0,0 @@ --- Kong resolver core-plugin --- --- This core-plugin is executed before any other, and allows to map a Host header --- to an API added to Kong. If the API was found, it will set the $backend_url variable --- allowing nginx to proxy the request as defined in the nginx configuration. --- --- Executions: 'access', 'header_filter' - -local access = require "kong.resolver.access" -local BasePlugin = require "kong.plugins.base_plugin" -local certificate = require "kong.resolver.certificate" -local header_filter = require "kong.resolver.header_filter" - -local ResolverHandler = BasePlugin:extend() - -function ResolverHandler:new() - ResolverHandler.super.new(self, "resolver") -end - -function ResolverHandler:access(conf) - ResolverHandler.super.access(self) - access.execute(conf) -end - -function ResolverHandler:certificate(conf) - ResolverHandler.super.certificate(self) - certificate.execute(conf) -end - -function ResolverHandler:header_filter(conf) - ResolverHandler.super.header_filter(self) - header_filter.execute(conf) -end - -return ResolverHandler diff --git a/kong/resolver/header_filter.lua b/kong/resolver/header_filter.lua deleted file mode 100644 index d5c1b40440b3..000000000000 --- a/kong/resolver/header_filter.lua +++ /dev/null @@ -1,11 +0,0 @@ -local constants = require "kong.constants" - -local _M = {} - -function _M.execute(conf) - local api_time = ngx.ctx.proxy_ended_at - ngx.ctx.proxy_started_at - ngx.header[constants.HEADERS.PROXY_TIME] = ngx.now() * 1000 - ngx.ctx.started_at - api_time - ngx.header[constants.HEADERS.API_TIME] = api_time -end - -return _M diff --git a/kong/tools/config_defaults.lua b/kong/tools/config_defaults.lua new file mode 100644 index 000000000000..7c53249b3e50 --- /dev/null +++ b/kong/tools/config_defaults.lua @@ -0,0 +1,62 @@ +return { + ["plugins_available"] = {type = "array", + default = {"ssl", "jwt", "acl", "cors", "oauth2", "tcp-log", "udp-log", "file-log", + "http-log", "key-auth", "hmac-auth", "basic-auth", "ip-restriction", + "mashape-analytics", "request-transformer", "response-transformer", + "request-size-limiting", "rate-limiting", "response-ratelimiting", "syslog", "loggly"} + }, + ["nginx_working_dir"] = {type = "string", default = "/usr/local/kong"}, + ["proxy_port"] = {type = "number", default = 8000}, + ["proxy_ssl_port"] = {type = "number", default = 8443}, + ["admin_api_port"] = {type = "number", default = 8001}, + ["dns_resolver"] = {type = "string", default = "dnsmasq", enum = {"server", "dnsmasq"}}, + ["dns_resolvers_available"] = { + type = "table", + content = { + ["server"] = { + type = "table", + content = { + ["address"] = {type = "string", default = "8.8.8.8:53"} + } + }, + ["dnsmasq"] = { + type = "table", + content = { + ["port"] = {type = "number", default = 8053} + } + } + } + }, + ["database"] = {type = "string", default = "cassandra"}, + ["databases_available"] = { + type = "table", + content = { + ["cassandra"] = { + type = "table", + content = { + ["contact_points"] = {type = "array", default = {"localhost:9042"}}, + ["keyspace"] = {type = "string", default = "kong"}, + ["replication_strategy"] = {type = "string", default = "SimpleStrategy", enum = {"SimpleStrategy", "NetworkTopologyStrategy"}}, + ["replication_factor"] = {type = "number", default = 1}, + ["data_centers"] = {type = "table", default = {}}, + ["username"] = {type = "string", nullable = true}, + ["password"] = {type = "string", nullable = true}, + ["ssl"] = { + type = "table", + content = { + ["enabled"] = {type = "boolean", default = false}, + ["verify"] = {type = "boolean", default = false}, + ["certificate_authority"] = {type = "string", nullable = true} + } + } + } + } + } + }, + ["database_cache_expiration"] = {type = "number", default = 5}, + ["ssl_cert_path"] = {type = "string", nullable = true}, + ["ssl_key_path"] = {type = "string", nullable = true}, + ["send_anonymous_reports"] = {type = "boolean", default = false}, + ["memory_cache_size"] = {type = "number", default = 128, min = 32}, + ["nginx"] = {type = "string", nullable = true} +} diff --git a/kong/tools/config_loader.lua b/kong/tools/config_loader.lua new file mode 100644 index 000000000000..c48c73e7c428 --- /dev/null +++ b/kong/tools/config_loader.lua @@ -0,0 +1,135 @@ +local yaml = require "yaml" +local IO = require "kong.tools.io" +local utils = require "kong.tools.utils" +local cutils = require "kong.cli.utils" +local stringy = require "stringy" +local constants = require "kong.constants" +local config_defaults = require "kong.tools.config_defaults" + +local function get_type(value, val_type) + if val_type == "array" and utils.is_array(value) then + return "array" + else + return type(value) + end +end + +local checks = { + type = function(value, key_infos, value_type) + if value_type ~= key_infos.type then + return "must be a "..key_infos.type + end + end, + minimum = function(value, key_infos, value_type) + if value_type == "number" and key_infos.min ~= nil and value < key_infos.min then + return "must be greater than "..key_infos.min + end + end, + enum = function(value, key_infos, value_type) + if key_infos.enum ~= nil and not utils.table_contains(key_infos.enum, value) then + return string.format("must be one of: '%s'", table.concat(key_infos.enum, ", ")) + end + end +} + +local function validate_config_schema(config, config_schema) + if not config_schema then config_schema = config_defaults end + local errors, property + + for config_key, key_infos in pairs(config_schema) do + -- Default value + property = config[config_key] or key_infos.default + + -- Recursion on table values + if key_infos.type == "table" and key_infos.content ~= nil then + if property == nil then + property = {} + end + + local ok, s_errors = validate_config_schema(property, key_infos.content) + if not ok then + for s_k, s_v in pairs(s_errors) do + errors = utils.add_error(errors, config_key.."."..s_k, s_v) + end + end + end + + -- Nullable checking + if property ~= nil and not key_infos.nullable then + local property_type = get_type(property, key_infos.type) + local err + -- Individual checks + for _, check_fun in pairs(checks) do + err = check_fun(property, key_infos, property_type) + if err then + errors = utils.add_error(errors, config_key, err) + end + end + end + + config[config_key] = property + end + + return errors == nil, errors +end + +local _M = {} + +function _M.validate(config) + local ok, errors = validate_config_schema(config) + if not ok then + return false, errors + end + + -- Check selected database + if config.databases_available[config.database] == nil then + return false, {database = config.database.." is not listed in databases_available"} + end + + return true +end + +function _M.load(config_path) + local config_contents = IO.read_file(config_path) + if not config_contents then + cutils.logger:error_exit("No configuration file at: "..config_path) + end + + local config = yaml.load(config_contents) + + local ok, errors = _M.validate(config) + if not ok then + for config_key, config_error in pairs(errors) do + if type(config_error) == "table" then + config_error = table.concat(config_error, ", ") + end + cutils.logger:warn(string.format("%s: %s", config_key, config_error)) + end + cutils.logger:error_exit("Invalid properties in given configuration file") + end + + -- Adding computed properties + config.pid_file = IO.path:join(config.nginx_working_dir, constants.CLI.NGINX_PID) + config.dao_config = config.databases_available[config.database] + if config.dns_resolver == "dnsmasq" then + config.dns_resolver = { + address = "127.0.0.1:"..config.dns_resolvers_available.dnsmasq.port, + port = config.dns_resolvers_available.dnsmasq.port, + dnsmasq = true + } + else + config.dns_resolver = {address = config.dns_resolver.server.address} + end + + + -- Load absolute path for the nginx working directory + if not stringy.startswith(config.nginx_working_dir, "/") then + -- It's a relative path, convert it to absolute + local fs = require "luarocks.fs" + config.nginx_working_dir = fs.current_dir().."/"..config.nginx_working_dir + end + + return config +end + +return _M diff --git a/kong/tools/dao_loader.lua b/kong/tools/dao_loader.lua new file mode 100644 index 000000000000..ea576ac8a926 --- /dev/null +++ b/kong/tools/dao_loader.lua @@ -0,0 +1,8 @@ +local _M = {} + +function _M.load(config, spawn_cluster) + local DaoFactory = require("kong.dao."..config.database..".factory") + return DaoFactory(config.dao_config, config.plugins_available, spawn_cluster) +end + +return _M diff --git a/kong/tools/http_client.lua b/kong/tools/http_client.lua index 14a25957ae1d..5b8e9ccfa7bc 100644 --- a/kong/tools/http_client.lua +++ b/kong/tools/http_client.lua @@ -55,7 +55,7 @@ local function with_body(method) else headers["content-type"] = "application/x-www-form-urlencoded" if type(body) == "table" then - body = ngx.encode_args(body) + body = ngx.encode_args(body, true) end end @@ -75,7 +75,7 @@ local function without_body(method) if not headers then headers = {} end if querystring then - url = string.format("%s?%s", url, ngx.encode_args(querystring)) + url = string.format("%s?%s", url, ngx.encode_args(querystring, true)) end return http_call { diff --git a/kong/tools/io.lua b/kong/tools/io.lua index 615e5ef709c4..2c8c521d6d7f 100644 --- a/kong/tools/io.lua +++ b/kong/tools/io.lua @@ -1,11 +1,7 @@ ---- -- IO related utility functions --- -local yaml = require "yaml" local path = require("path").new("/") local stringy = require "stringy" -local constants = require "kong.constants" local _M = {} @@ -102,44 +98,4 @@ function _M.file_size(path) return size end ---- Load a yaml configuration file. --- The return config will get 2 extra fields; `pid_file` of the nginx process --- and `dao_config` as a shortcut to the dao configuration --- @param configuration_path path to configuration file to load --- @return config Loaded configuration table --- @return dao_factory the accompanying DAO factory -function _M.load_configuration_and_dao(configuration_path) - local configuration_file = _M.read_file(configuration_path) - if not configuration_file then - error("No configuration file at: "..configuration_path) - end - - -- Configuration should already be validated by the CLI at this point - local configuration = yaml.load(configuration_file) - - local dao_config = configuration.databases_available[configuration.database] - if dao_config == nil then - error('No "'..configuration.database..'" dao defined') - end - - -- Adding computed properties to the configuration - configuration.pid_file = path:join(configuration.nginx_working_dir, constants.CLI.NGINX_PID) - - -- Alias the DAO configuration we are using for this instance for easy access - configuration.dao_config = dao_config - - -- Load absolute path for the nginx working directory - if not stringy.startswith(configuration.nginx_working_dir, "/") then - -- It's a relative path, convert it to absolute - local fs = require "luarocks.fs" - configuration.nginx_working_dir = fs.current_dir().."/"..configuration.nginx_working_dir - end - - -- Instantiate the DAO Factory along with the configuration - local DaoFactory = require("kong.dao."..configuration.database..".factory") - local dao_factory = DaoFactory(dao_config.properties, configuration.plugins_available) - - return configuration, dao_factory -end - return _M diff --git a/kong/tools/migrations.lua b/kong/tools/migrations.lua index fb4f26cd1031..29738a63fdbd 100644 --- a/kong/tools/migrations.lua +++ b/kong/tools/migrations.lua @@ -14,7 +14,7 @@ function Migrations:new(dao, kong_config, core_migrations_module, plugins_namesp dao:load_daos(require("kong.dao."..dao.type..".migrations")) self.dao = dao - self.options = dao._properties + self.dao_properties = dao.properties self.migrations = { [_CORE_MIGRATIONS_IDENTIFIER] = require(core_migrations_module) } @@ -91,7 +91,7 @@ function Migrations:run_migrations(identifier, before, on_each_success) -- Execute all new migrations, in order for _, migration in ipairs(diff_migrations) do - local err = migration.up(self.options, self.dao) + local err = migration.up(self.dao_properties, self.dao) if err then return fmt('Error executing migration for "%s": %s', identifier, err) end @@ -146,7 +146,7 @@ function Migrations:run_rollback(identifier, before, on_success) before(identifier) end - local err = migration_to_rollback.down(self.options, self.dao) + local err = migration_to_rollback.down(self.dao_properties, self.dao) if err then return fmt('Error rollbacking migration for "%s": %s', identifier, err) end diff --git a/kong/tools/ngx_stub.lua b/kong/tools/ngx_stub.lua index 2d17539293d2..3e2141e05ba8 100644 --- a/kong/tools/ngx_stub.lua +++ b/kong/tools/ngx_stub.lua @@ -1,10 +1,11 @@ ---- Stub _G.ngx for unit testing. +-- Stub _G.ngx for unit testing. -- Creates a stub for `ngx` for use by Kong's modules such as the DAO. It allows to use them -- outside of the nginx context such as when using the CLI, or unit testing. -- -- Monkeypatches the global `ngx` table. local reg = require "rex_pcre" +local utils = require "kong.tools.utils" -- DICT Proxy -- https://github.com/bsm/fakengx/blob/master/fakengx.lua @@ -102,9 +103,11 @@ local shared_mt = { } _G.ngx = { + stub = true, req = {}, ctx = {}, header = {}, + get_phase = function() return "init" end, exit = function() end, say = function() end, log = function() end, @@ -125,37 +128,5 @@ _G.ngx = { encode_base64 = function(str) return string.format("base64_%s", str) end, - -- Builds a querystring from a table, separated by `&` - -- @param `tab` The key/value parameters - -- @param `key` The parent key if the value is multi-dimensional (optional) - -- @return `querystring` A string representing the built querystring - encode_args = function(tab, key) - local query = {} - local keys = {} - - for k in pairs(tab) do - keys[#keys+1] = k - end - - table.sort(keys) - - for _, name in ipairs(keys) do - local value = tab[name] - if key then - name = string.format("%s[%s]", tostring(key), tostring(name)) - end - if type(value) == "table" then - query[#query+1] = ngx.encode_args(value, name) - else - value = tostring(value) - if value ~= "" then - query[#query+1] = string.format("%s=%s", name, value) - else - query[#query+1] = name - end - end - end - - return table.concat(query, "&") - end + encode_args = utils.encode_args } diff --git a/kong/tools/printable.lua b/kong/tools/printable.lua index 6228870f2cba..59809fd1b514 100644 --- a/kong/tools/printable.lua +++ b/kong/tools/printable.lua @@ -1,17 +1,23 @@ -- A metatable for pretty printing a table with key=value properties -- -- Example: --- { hello = "world", foo = "bar", baz = {"hello", "world"} } +-- {hello = "world", foo = "bar", baz = {"hello", "world"}} -- Output: -- "hello=world foo=bar, baz=hello,world" +local utils = require "kong.tools.utils" + local printable_mt = {} function printable_mt:__tostring() local t = {} for k, v in pairs(self) do if type(v) == "table" then - v = table.concat(v, ",") + if utils.is_array(v) then + v = table.concat(v, ",") + else + setmetatable(v, printable_mt) + end end table.insert(t, k.."="..tostring(v)) diff --git a/kong/tools/responses.lua b/kong/tools/responses.lua index dd3372b91623..86036c36e2a0 100644 --- a/kong/tools/responses.lua +++ b/kong/tools/responses.lua @@ -1,10 +1,39 @@ --- Kong helper methods to send HTTP responses to clients. --- Can be used in the proxy, plugins or admin API. --- Most used status codes and responses are implemented as helper methods. +-- Can be used in the proxy (core/resolver), plugins or Admin API. +-- Most used HTTP status codes and responses are implemented as helper methods. -- --- @author thibaultcha +-- local responses = require "kong.tools.responses" +-- +-- -- In an Admin API endpoint handler, or in one of the plugins' phases. +-- -- the `return` keyword is optional since the execution will be stopped +-- -- anyways. It simply improves code readability. +-- return responses.send_HTTP_OK() +-- +-- -- Or: +-- return responses.send_HTTP_NOT_FOUND("No entity for given id") +-- +-- -- Raw send() helper: +-- return responses.send(418, "This is a teapot") --- Define the most used HTTP status codes through Kong +--- Define the most common HTTP status codes for sugar methods. +-- Each of those status will generate a helper method (sugar) +-- attached to this exported module prefixed with `send_`. +-- Final signature of those methods will be `send_(message, raw, headers)`. See @{send} for more details on those parameters. +-- @field HTTP_OK 200 OK +-- @field HTTP_CREATED 201 Created +-- @field HTTP_NO_CONTENT 204 No Content +-- @field HTTP_BAD_REQUEST 400 Bad Request +-- @field HTTP_UNAUTHORIZED 401 Unauthorized +-- @field HTTP_FORBIDDEN 403 Forbidden +-- @field HTTP_NOT_FOUND 404 Not Found +-- @field HTTP_METHOD_NOT_ALLOWED 405 Method Not Allowed +-- @field HTTP_CONFLICT 409 Conflict +-- @field HTTP_UNSUPPORTED_MEDIA_TYPE 415 Unsupported Media Type +-- @field HTTP_INTERNAL_SERVER_ERROR Internal Server Error +-- @usage return responses.send_HTTP_OK() +-- @usage return responses.HTTP_CREATED("Entity created") +-- @usage return responses.HTTP_INTERNAL_SERVER_ERROR() +-- @table status_codes local _M = { status_codes = { HTTP_OK = 200, @@ -21,8 +50,15 @@ local _M = { } } --- Define some rules that will ALWAYS be applied to some status codes. --- Ex: 204 must not have content, but if 404 has no content then "Not found" will be set. +--- Define some default response bodies for some status codes. +-- Some other status codes will have response bodies that cannot be overriden. +-- Example: 204 MUST NOT have content, but if 404 has no content then "Not found" will be set. +-- @field status_codes.HTTP_UNAUTHORIZED Default: Unauthorized +-- @field status_codes.HTTP_NO_CONTENT Always empty. +-- @field status_codes.HTTP_NOT_FOUND Default: Not Found +-- @field status_codes.HTTP_UNAUTHORIZED Default: Unauthorized +-- @field status_codes.HTTP_INTERNAL_SERVER_ERROR Always "Internal Server Error" +-- @field status_codes.HTTP_METHOD_NOT_ALLOWED Always "Method not allowed" local response_default_content = { [_M.status_codes.HTTP_UNAUTHORIZED] = function(content) return content or "Unauthorized" @@ -42,25 +78,24 @@ local response_default_content = { } -- Return a closure which will be usable to respond with a certain status code. --- @param `status_code` The status for which to define a function --- --- Send a JSON response for the closure's status code with the given content. --- If the content happens to be an error (>500), it will be logged by ngx.log as an ERR. --- @see http://wiki.nginx.org/HttpLuaModule --- @param `content` (Optional) The content to send as a response. --- @param `raw` (Optional) A boolean defining if the `content` should not be serialized to JSON --- This useed to send text as JSON in some edge-cases of cjson. --- @return `ngx.exit()` +-- @local +-- @param[type=number] status_code The status for which to define a function local function send_response(status_code) local constants = require "kong.constants" local cjson = require "cjson" + -- Send a JSON response for the closure's status code with the given content. + -- If the content happens to be an error (>500), it will be logged by ngx.log as an ERR. + -- @see https://github.com/openresty/lua-nginx-module + -- @param content (Optional) The content to send as a response. + -- @param raw (Optional) A boolean defining if the `content` should not be serialized to JSON + -- This useed to send text as JSON in some edge-cases of cjson. + -- @return ngx.exit (Exit current context) return function(content, raw, headers) if status_code >= _M.status_codes.HTTP_INTERNAL_SERVER_ERROR then if content then ngx.log(ngx.ERR, tostring(content)) end - ngx.ctx.stop_phases = true -- interrupt other phases of this request end ngx.status = status_code @@ -98,16 +133,26 @@ for status_code_name, status_code in pairs(_M.status_codes) do end local closure_cache = {} --- Sends any status code as a response. This is useful for plugins which want to --- send a response when the status code is not defined in `_M.status_codes` and thus --- has no sugar method on `_M`. -function _M.send(status_code, content, raw, headers) + +--- Send a response with any status code or body, +-- Not all status codes are available as sugar methods, this function can be +-- used to send any response. +-- If the `status_code` parameter is in the 5xx range, it is expectde that the `content` parameter be the error encountered. It will be logged and the response body will be empty. The user will just receive a 500 status code. +-- Will call `ngx.say` and `ngx.exit`, terminating the current context. +-- @see ngx.say +-- @see ngx.exit +-- @param[type=number] status_code HTTP status code to send +-- @param body A string or table which will be the body of the sent response. If table, the response will be encoded as a JSON object. If string, the response will be a JSON object and the string will be contained in the `message` property. Except if the `raw` parameter is set to `true`. +-- @param[type=boolean] raw If true, send the `body` as it is. +-- @param[type=table] headers Response headers to send. +-- @return ngx.exit (Exit current context) +function _M.send(status_code, body, raw, headers) local res = closure_cache[status_code] if not res then res = send_response(status_code) closure_cache[status_code] = res end - return res(content, raw, headers) + return res(body, raw, headers) end return _M diff --git a/kong/tools/timestamp.lua b/kong/tools/timestamp.lua index 1d2999a205c3..c3e4230c0b47 100644 --- a/kong/tools/timestamp.lua +++ b/kong/tools/timestamp.lua @@ -1,4 +1,4 @@ ---- +-- -- Module for timestamp support. -- Based on the LuaTZ module. local luatz = require "luatz" diff --git a/kong/tools/utils.lua b/kong/tools/utils.lua index a74257902fe3..8bccb3b205c0 100644 --- a/kong/tools/utils.lua +++ b/kong/tools/utils.lua @@ -1,10 +1,22 @@ --- --- Module containing some general utility functions - -local uuid = require "uuid" - --- This is important to seed the UUID generator -uuid.seed() +-- Module containing some general utility functions used in many places in Kong. +-- +-- NOTE: Before implementing a function here, consider if it will be used in many places +-- across Kong. If not, a local function in the appropriate module is prefered. +-- + +local url = require "socket.url" +local uuid = require "lua_uuid" + +local type = type +local pairs = pairs +local ipairs = ipairs +local tostring = tostring +local table_sort = table.sort +local table_concat = table.concat +local table_insert = table.insert +local string_find = string.find +local string_format = string.format local _M = {} @@ -14,6 +26,65 @@ function _M.random_string() return uuid():gsub("-", "") end +--- URL escape and format key and value +-- An obligatory url.unescape pass must be done to prevent double-encoding +-- already encoded values (which contain a '%' character that `url.escape` escapes) +local function encode_args_value(key, value, raw) + if not raw then + key = url.unescape(key) + key = url.escape(key) + end + if value ~= nil then + if not raw then + value = url.unescape(value) + value = url.escape(value) + end + return string_format("%s=%s", key, value) + else + return key + end +end + +--- Encode a Lua table to a querystring +-- Tries to mimic ngx_lua's `ngx.encode_args`, but also percent-encode querystring values. +-- Supports multi-value query args, boolean values. +-- It also supports encoding for bodies (only because it is used in http_client for specs. +-- @TODO drop and use `ngx.encode_args` once it implements percent-encoding. +-- @see https://github.com/Mashape/kong/issues/749 +-- @param[type=table] args A key/value table containing the query args to encode. +-- @param[type=boolean] raw If true, will not percent-encode any key/value and will ignore special boolean rules. +-- @treturn string A valid querystring (without the prefixing '?') +function _M.encode_args(args, raw) + local query = {} + local keys = {} + + for k in pairs(args) do + keys[#keys+1] = k + end + + table_sort(keys) + + for _, key in ipairs(keys) do + local value = args[key] + if type(value) == "table" then + for _, sub_value in ipairs(value) do + query[#query+1] = encode_args_value(key, sub_value, raw) + end + elseif value == true then + query[#query+1] = encode_args_value(key, raw and true or nil, raw) + elseif value ~= false and value ~= nil or raw then + value = tostring(value) + if value ~= "" then + query[#query+1] = encode_args_value(key, value, raw) + elseif raw then + query[#query+1] = key + end + end + end + + return table_concat(query, "&") +end + --- Calculates a table size. -- All entries both in array and hash part. -- @param t The table to use @@ -60,6 +131,7 @@ end -- @param t The table to check -- @return Returns `true` if the table is an array, `false` otherwise function _M.is_array(t) + if type(t) ~= "table" then return false end local i = 0 for _ in pairs(t) do i = i + 1 @@ -102,7 +174,7 @@ function _M.add_error(errors, k, v) errors[k] = setmetatable({errors[k]}, err_list_mt) end - table.insert(errors[k], v) + table_insert(errors[k], v) else errors[k] = v end @@ -121,7 +193,7 @@ function _M.load_module_if_exists(module_name) if status then return true, res -- Here we match any character because if a module has a dash '-' in its name, we would need to escape it. - elseif type(res) == "string" and string.find(res, "module '"..module_name.."' not found", nil, true) then + elseif type(res) == "string" and string_find(res, "module '"..module_name.."' not found", nil, true) then return false else error(res) diff --git a/kong/vendor/classic.lua b/kong/vendor/classic.lua index c62602dea73d..d9402f9a8588 100644 --- a/kong/vendor/classic.lua +++ b/kong/vendor/classic.lua @@ -1,4 +1,4 @@ ---- +-- -- classic, object model. -- -- Copyright (c) 2014, rxi diff --git a/kong/vendor/resty_http.lua b/kong/vendor/resty_http.lua index ee002ade830e..4164db2c6f72 100644 --- a/kong/vendor/resty_http.lua +++ b/kong/vendor/resty_http.lua @@ -17,17 +17,12 @@ module(...) _VERSION = "0.1.0" --------------------------------------- --- LOCAL CONSTANTS -- --------------------------------------- +-- LOCAL CONSTANTS local HTTP_1_1 = " HTTP/1.1\r\n" local CHUNK_SIZE = 1048576 local USER_AGENT = "Resty/HTTP " .. _VERSION .. " (Lua)" --------------------------------------- --- LOCAL HELPERS -- --------------------------------------- - +-- LOCAL HELPERS local function _req_header(conf, opts) opts = opts or {} @@ -201,7 +196,7 @@ local function _receive(self, sock) return nil, err end body = str - end +end if lower(headers["connection"]) == "close" then self:close() @@ -213,10 +208,7 @@ local function _receive(self, sock) end --------------------------------------- --- PUBLIC API -- --------------------------------------- - +-- PUBLIC API function new(self) local sock, err = tcp() if not sock then diff --git a/scripts/migration.py b/scripts/migration.py deleted file mode 100755 index d55f35628b0a..000000000000 --- a/scripts/migration.py +++ /dev/null @@ -1,283 +0,0 @@ -#!/usr/bin/env python - -'''Kong 0.5.0 Migration Script - -Usage: python migration.py --config=/path/to/kong/config [--purge] - -Run this script first to migrate Kong to the 0.5.0 schema. Once successful, reload Kong -and run this script again with the --purge option. - -Arguments: - -c, --config path to your Kong configuration file -Flags: - --purge if already migrated, purge the old values - -h print help -''' - -import getopt, sys, os.path, logging, json, hashlib - -log = logging.getLogger() -log.setLevel("INFO") -handler = logging.StreamHandler() -handler.setFormatter(logging.Formatter("[%(levelname)s]: %(message)s")) -log.addHandler(handler) - -try: - import yaml - from cassandra.cluster import Cluster - from cassandra import ConsistencyLevel, InvalidRequest - from cassandra.query import SimpleStatement - from cassandra import InvalidRequest -except ImportError as err: - log.error(err) - log.info("""This script requires cassandra-driver and PyYAML: - $ pip install cassandra-driver pyyaml""") - sys.exit(1) - -session = None - -class ArgumentException(Exception): - pass - -def usage(): - """ - Print usage informations about this script. - """ - print sys.exit(__doc__) - -def shutdown_exit(exit_code): - """ - Shutdown the Cassandra session and exit the script. - """ - session.shutdown() - sys.exit(exit_code) - -def load_cassandra_config(kong_config): - """ - Return a host and port from the first contact point in the Kong configuration. - - :param kong_config: parsed Kong configuration - :return: host and port tuple - """ - cass_properties = kong_config["databases_available"]["cassandra"]["properties"] - - host, port = cass_properties["contact_points"][0].split(":") - keyspace = cass_properties["keyspace"] - - return (host, port, keyspace) - -def migrate_schema_migrations_table(session): - """ - Migrate the schema_migrations table whose values changed between < 0.5.0 and 0.5.0 - - :param session: opened cassandra session - """ - log.info("Migrating schema_migrations table...") - query = SimpleStatement("INSERT INTO schema_migrations(id, migrations) VALUES(%s, %s)", consistency_level=ConsistencyLevel.ALL) - session.execute(query, ["core", ['2015-01-12-175310_skeleton', '2015-01-12-175310_init_schema']]) - session.execute(query, ["basic-auth", ['2015-08-03-132400_init_basicauth']]) - session.execute(query, ["key-auth", ['2015-07-31-172400_init_keyauth']]) - session.execute(query, ["rate-limiting", ['2015-08-03-132400_init_ratelimiting']]) - session.execute(query, ["oauth2", ['2015-08-03-132400_init_oauth2', '2015-08-24-215800_cascade_delete_index']]) - log.info("schema_migrations table migrated") - -def migrate_plugins_configurations(session): - """ - Migrate all rows in the `plugins_configurations` table to `plugins`, applying: - - renaming of plugins if name changed - - conversion of old rate-limiting schema if old schema detected - - :param session: opened cassandra session - """ - log.info("Migrating plugins...") - - new_names = { - "keyauth": "key-auth", - "basicauth": "basic-auth", - "ratelimiting": "rate-limiting", - "tcplog": "tcp-log", - "udplog": "udp-log", - "filelog": "file-log", - "httplog": "http-log", - "request_transformer": "request-transformer", - "response_transfomer": "response-transfomer", - "requestsizelimiting": "request-size-limiting", - "ip_restriction": "ip-restriction" - } - - session.execute(""" - create table if not exists plugins( - id uuid, - api_id uuid, - consumer_id uuid, - name text, - config text, - enabled boolean, - created_at timestamp, - primary key (id, name))""") - session.execute("create index if not exists on plugins(name)") - session.execute("create index if not exists on plugins(api_id)") - session.execute("create index if not exists on plugins(consumer_id)") - - select_query = SimpleStatement("SELECT * FROM plugins_configurations", consistency_level=ConsistencyLevel.ALL) - for plugin in session.execute(select_query): - # New plugins names - plugin_name = plugin.name - if plugin.name in new_names: - plugin_name = new_names[plugin.name] - - # rate-limiting config - plugin_conf = plugin.value - if plugin_name == "rate-limiting": - conf = json.loads(plugin.value) - if "limit" in conf: - plugin_conf = {} - plugin_conf[conf["period"]] = conf["limit"] - plugin_conf = json.dumps(plugin_conf) - - insert_query = SimpleStatement(""" - INSERT INTO plugins(id, api_id, consumer_id, name, config, enabled, created_at) - VALUES(%s, %s, %s, %s, %s, %s, %s)""", consistency_level=ConsistencyLevel.ALL) - session.execute(insert_query, [plugin.id, plugin.api_id, plugin.consumer_id, plugin_name, plugin_conf, plugin.enabled, plugin.created_at]) - - log.info("Plugins migrated") - -def migrate_rename_apis_properties(sessions): - """ - Create new columns for the `apis` column family and insert the equivalent values in it - - :param session: opened cassandra session - """ - log.info("Renaming some properties for APIs...") - - session.execute("ALTER TABLE apis ADD request_host text") - session.execute("ALTER TABLE apis ADD request_path text") - session.execute("ALTER TABLE apis ADD strip_request_path boolean") - session.execute("ALTER TABLE apis ADD upstream_url text") - session.execute("CREATE INDEX IF NOT EXISTS ON apis(request_host)") - session.execute("CREATE INDEX IF NOT EXISTS ON apis(request_path)") - - select_query = SimpleStatement("SELECT * FROM apis", consistency_level=ConsistencyLevel.ALL) - for api in session.execute(select_query): - session.execute("UPDATE apis SET request_host = %s, request_path = %s, strip_request_path = %s, upstream_url = %s WHERE id = %s", [api.public_dns, api.path, api.strip_path, api.target_url, api.id]) - - log.info("APIs properties renamed") - -def migrate_hash_passwords(session): - """ - Hash all passwords in basicauth_credentials using sha1 and the consumer_id as the salt. - Also stores the plain passwords in a temporary column in case this script is run multiple times by the user. - Temporare column will be dropped on --purge. - - :param session: opened cassandra session - """ - log.info("Hashing basic-auth passwords...") - - first_run = True - - try: - session.execute("ALTER TABLE basicauth_credentials ADD plain_password text") - except InvalidRequest as err: - first_run = False - - select_query = SimpleStatement("SELECT * FROM basicauth_credentials", consistency_level=ConsistencyLevel.ALL) - for credential in session.execute(select_query): - plain_password = credential.password if first_run else credential.plain_password - m = hashlib.sha1() - m.update(plain_password) - m.update(str(credential.consumer_id)) - digest = m.hexdigest() - session.execute("UPDATE basicauth_credentials SET password = %s, plain_password = %s WHERE id = %s", [digest, plain_password, credential.id]) - -def purge(session): - session.execute("ALTER TABLE apis DROP public_dns") - session.execute("ALTER TABLE apis DROP target_url") - session.execute("ALTER TABLE apis DROP path") - session.execute("ALTER TABLE apis DROP strip_path") - session.execute("ALTER TABLE basicauth_credentials DROP plain_password") - session.execute("DROP TABLE plugins_configurations") - session.execute(SimpleStatement("DELETE FROM schema_migrations WHERE id = 'migrations'", consistency_level=ConsistencyLevel.ALL)) - -def migrate(session): - migrate_schema_migrations_table(session) - migrate_plugins_configurations(session) - migrate_rename_apis_properties(session) - migrate_hash_passwords(session) - -def parse_arguments(argv): - """ - Parse the scripts arguments. - - :param argv: scripts arguments - :return: parsed kong configuration - """ - config_path = "" - purge = False - - opts, args = getopt.getopt(argv, "hc:", ["config=", "purge"]) - for opt, arg in opts: - if opt == "-h": - usage() - elif opt in ("-c", "--config"): - config_path = arg - elif opt in ("--purge"): - purge = True - - if config_path == "": - raise ArgumentException("No Kong configuration given") - elif not os.path.isfile(config_path): - raise ArgumentException("No configuration file at path %s" % str(arg)) - - log.info("Using Kong configuration file at: %s" % os.path.abspath(config_path)) - - with open(config_path, "r") as stream: - config = yaml.load(stream) - - return (config, purge) - -def main(argv): - try: - kong_config, purge_cmd = parse_arguments(argv) - host, port, keyspace = load_cassandra_config(kong_config) - cluster = Cluster([host], protocol_version=2, port=port) - global session - session = cluster.connect(keyspace) - - # Find out where the schema is at - rows = session.execute("SELECT * FROM schema_migrations") - is_migrated = len(rows) > 1 and any(mig.id == "core" for mig in rows) - is_0_4_2 = len(rows) == 1 and rows[0].migrations[-1] == "2015-08-10-813213_0.4.2" - is_purged = len(session.execute("SELECT * FROM system.schema_columnfamilies WHERE keyspace_name = %s AND columnfamily_name = 'plugins_configurations'", [keyspace])) == 0 - - if not is_0_4_2 and not is_migrated: - log.error("Please migrate your cluster to Kong 0.4.2 before running this script.") - shutdown_exit(1) - - if purge_cmd: - if not is_purged and is_migrated: - purge(session) - log.info("Cassandra purged from <0.5.0 data.") - elif not is_purged and not is_migrated: - log.info("Cassandra not previously migrated. Run this script in migration mode before.") - shutdown_exit(1) - else: - log.info("Cassandra already purged and migrated.") - elif not is_migrated: - migrate(session) - log.info("Cassandra migrated to Kong 0.5.0. Restart Kong and run this script with '--purge'.") - else: - log.info("Cassandra already migrated to Kong 0.5.0. Restart Kong and run this script with '--purge'.") - - shutdown_exit(0) - except getopt.GetoptError as err: - log.error(err) - usage() - except ArgumentException as err: - log.error("Bad argument: %s " % err) - usage() - except yaml.YAMLError as err: - log.error("Cannot parse given configuration file: %s" % err) - sys.exit(1) - -if __name__ == "__main__": - main(sys.argv[1:]) diff --git a/spec/integration/admin_api/apis_routes_spec.lua b/spec/integration/admin_api/apis_routes_spec.lua index eb3f531bc797..6806b9b94196 100644 --- a/spec/integration/admin_api/apis_routes_spec.lua +++ b/spec/integration/admin_api/apis_routes_spec.lua @@ -18,77 +18,69 @@ describe("Admin API", function() local BASE_URL = spec_helper.API_URL.."/apis/" describe("POST", function() - it("[SUCCESS] should create an API", function() send_content_types(BASE_URL, "POST", { - name="api-POST-tests", - request_host="api.mockbin.com", - upstream_url="http://mockbin.com" - }, 201, nil, {drop_db=true}) + name = "api-POST-tests", + request_host = "api.mockbin.com", + upstream_url = "http://mockbin.com" + }, 201, nil, {drop_db = true}) end) - it("[FAILURE] should notify of malformed body", function() local response, status = http_client.post(BASE_URL, '{"hello":"world"', {["content-type"] = "application/json"}) assert.are.equal(400, status) assert.are.equal('{"message":"Cannot parse JSON body"}\n', response) end) - it("[FAILURE] should return proper errors", function() send_content_types(BASE_URL, "POST", {}, 400, '{"upstream_url":"upstream_url is required","request_path":"At least a \'request_host\' or a \'request_path\' must be specified","request_host":"At least a \'request_host\' or a \'request_path\' must be specified"}') - send_content_types(BASE_URL, "POST", {request_host="api.mockbin.com"}, + send_content_types(BASE_URL, "POST", {request_host = "api.mockbin.com"}, 400, '{"upstream_url":"upstream_url is required"}') send_content_types(BASE_URL, "POST", { - request_host="api.mockbin.com", - upstream_url="http://mockbin.com" + request_host = "api.mockbin.com", + upstream_url = "http://mockbin.com" }, 409, '{"request_host":"request_host already exists with value \'api.mockbin.com\'"}') end) - end) describe("PUT", function() - setup(function() spec_helper.drop_db() end) it("[SUCCESS] should create and update", function() local api = send_content_types(BASE_URL, "PUT", { - name="api-PUT-tests", - request_host="api.mockbin.com", - upstream_url="http://mockbin.com" - }, 201, nil, {drop_db=true}) + name = "api-PUT-tests", + request_host = "api.mockbin.com", + upstream_url = "http://mockbin.com" + }, 201, nil, {drop_db = true}) api = send_content_types(BASE_URL, "PUT", { - id=api.id, - name="api-PUT-tests-updated", - request_host="updated-api.mockbin.com", - upstream_url="http://mockbin.com" + id = api.id, + name = "api-PUT-tests-updated", + request_host = "updated-api.mockbin.com", + upstream_url = "http://mockbin.com" }, 200) assert.equal("api-PUT-tests-updated", api.name) end) - it("[FAILURE] should return proper errors", function() send_content_types(BASE_URL, "PUT", {}, 400, '{"upstream_url":"upstream_url is required","request_path":"At least a \'request_host\' or a \'request_path\' must be specified","request_host":"At least a \'request_host\' or a \'request_path\' must be specified"}') - send_content_types(BASE_URL, "PUT", {request_host="api.mockbin.com"}, + send_content_types(BASE_URL, "PUT", {request_host = "api.mockbin.com"}, 400, '{"upstream_url":"upstream_url is required"}') send_content_types(BASE_URL, "PUT", { - request_host="updated-api.mockbin.com", - upstream_url="http://mockbin.com" + request_host = "updated-api.mockbin.com", + upstream_url = "http://mockbin.com" }, 409, '{"request_host":"request_host already exists with value \'updated-api.mockbin.com\'"}') end) - end) describe("GET", function() - setup(function() spec_helper.drop_db() spec_helper.seed_db(10) @@ -100,34 +92,35 @@ describe("Admin API", function() local body = json.decode(response) assert.truthy(body.data) assert.equal(10, table.getn(body.data)) + assert.equal(10, body.total) end) - it("should retrieve a paginated set", function() - local response, status = http_client.get(BASE_URL, {size=3}) + local response, status = http_client.get(BASE_URL, {size = 3}) assert.equal(200, status) local body_page_1 = json.decode(response) assert.truthy(body_page_1.data) assert.equal(3, table.getn(body_page_1.data)) assert.truthy(body_page_1.next) + assert.equal(10, body_page_1.total) - response, status = http_client.get(BASE_URL, {size=3,offset=body_page_1.next}) + response, status = http_client.get(BASE_URL, {size = 3, offset = body_page_1.next}) assert.equal(200, status) local body_page_2 = json.decode(response) assert.truthy(body_page_2.data) assert.equal(3, table.getn(body_page_2.data)) assert.truthy(body_page_2.next) assert.not_same(body_page_1, body_page_2) + assert.equal(10, body_page_2.total) - response, status = http_client.get(BASE_URL, {size=4,offset=body_page_2.next}) + response, status = http_client.get(BASE_URL, {size = 4, offset = body_page_2.next}) assert.equal(200, status) local body_page_3 = json.decode(response) assert.truthy(body_page_3.data) assert.equal(4, table.getn(body_page_3.data)) - -- TODO: fixme - --assert.falsy(body_page_3.next) + assert.equal(10, body_page_3.total) + assert.falsy(body_page_3.next) assert.not_same(body_page_2, body_page_3) end) - end) end) @@ -138,71 +131,64 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - api = {{ request_host="mockbin.com", upstream_url="http://mockbin.com" }} + api = { + {request_host = "mockbin.com", upstream_url = "http://mockbin.com"} + } } api = fixtures.api[1] end) describe("GET", function() - it("should retrieve by id", function() local response, status = http_client.get(BASE_URL..api.id) assert.equal(200, status) local body = json.decode(response) assert.same(api, body) end) - it("should retrieve by name", function() local response, status = http_client.get(BASE_URL..api.name) assert.equal(200, status) local body = json.decode(response) assert.same(api, body) end) - end) describe("PATCH", function() - it("[SUCCESS] should update an API", function() - local response, status = http_client.patch(BASE_URL..api.id, {name="patch-updated"}) + local response, status = http_client.patch(BASE_URL..api.id, {name = "patch-updated"}) assert.equal(200, status) local body = json.decode(response) assert.same("patch-updated", body.name) api = body - response, status = http_client.patch(BASE_URL..api.name, {name="patch-updated-json"}, {["content-type"]="application/json"}) + response, status = http_client.patch(BASE_URL..api.name, {name = "patch-updated-json"}, {["content-type"] = "application/json"}) assert.equal(200, status) body = json.decode(response) assert.same("patch-updated-json", body.name) api = body end) - it("[FAILURE] should return proper errors", function() - local _, status = http_client.patch(BASE_URL.."hello", {name="patch-updated"}) + local _, status = http_client.patch(BASE_URL.."hello", {name = "patch-updated"}) assert.equal(404, status) - local response, status = http_client.patch(BASE_URL..api.id, {upstream_url=""}) + local response, status = http_client.patch(BASE_URL..api.id, {upstream_url = ""}) assert.equal(400, status) assert.equal('{"upstream_url":"upstream_url is not a url"}\n', response) end) - end) describe("DELETE", function() - it("[FAILURE] should return proper errors", function() local _, status = http_client.delete(BASE_URL.."hello") assert.equal(404, status) end) - it("[SUCCESS] should delete an API", function() local response, status = http_client.delete(BASE_URL..api.id) assert.equal(204, status) assert.falsy(response) end) - end) describe("/apis/:api/plugins/", function() @@ -211,19 +197,19 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - api = {{ request_host="mockbin.com", upstream_url="http://mockbin.com" }} + api = { + {request_host = "mockbin.com", upstream_url = "http://mockbin.com"} + } } api = fixtures.api[1] BASE_URL = BASE_URL..api.id.."/plugins/" end) describe("POST", function() - it("[FAILURE] should return proper errors", function() send_content_types(BASE_URL, "POST", {}, 400, '{"name":"name is required"}') end) - it("[SUCCESS] should create a plugin configuration", function() local response, status = http_client.post(BASE_URL, { name = "key-auth", @@ -237,15 +223,14 @@ describe("Admin API", function() response, status = http_client.post(BASE_URL, { name = "key-auth", - config = {key_names={"apikey"}} - }, {["content-type"]="application/json"}) + config = {key_names = {"apikey"}} + }, {["content-type"] = "application/json"}) assert.equal(201, status) body = json.decode(response) _, err = dao_plugins:delete({id = body.id, name = body.name}) assert.falsy(err) end) - end) describe("PUT", function() @@ -255,7 +240,6 @@ describe("Admin API", function() send_content_types(BASE_URL, "PUT", {}, 400, '{"name":"name is required"}') end) - it("[SUCCESS] should create and update", function() local response, status = http_client.put(BASE_URL, { name = "key-auth", @@ -285,7 +269,6 @@ describe("Admin API", function() body = json.decode(response) assert.equal("updated_apikey", body.config.key_names[1]) end) - it("should override a plugin's `config` if partial", function() local response, status = http_client.put(BASE_URL, { id = plugin_id, @@ -310,7 +293,6 @@ describe("Admin API", function() end) describe("GET", function() - it("should retrieve all", function() local response, status = http_client.get(BASE_URL) assert.equal(200, status) @@ -318,7 +300,6 @@ describe("Admin API", function() assert.truthy(body.data) assert.equal(1, table.getn(body.data)) end) - end) describe("/apis/:api/plugins/:plugin", function() @@ -328,8 +309,12 @@ describe("Admin API", function() setup(function() spec_helper.drop_db() local fixtures = spec_helper.insert_fixtures { - api = {{ request_host="mockbin.com", upstream_url="http://mockbin.com" }}, - plugin = {{ name = "key-auth", config = { key_names = { "apikey" }}, __api = 1 }} + api = { + {request_host="mockbin.com", upstream_url="http://mockbin.com"} + }, + plugin = { + {name = "key-auth", config = {key_names = {"apikey"}}, __api = 1} + } } api = fixtures.api[1] plugin = fixtures.plugin[1] @@ -337,35 +322,30 @@ describe("Admin API", function() end) describe("GET", function() - it("should retrieve by id", function() local response, status = http_client.get(BASE_URL..plugin.id) assert.equal(200, status) local body = json.decode(response) assert.same(plugin, body) end) - end) describe("PATCH", function() - it("[SUCCESS] should update a plugin", function() - local response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"]={"key_updated"}}) + local response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"] = {"key_updated"}}) assert.equal(200, status) local body = json.decode(response) assert.same("key_updated", body.config.key_names[1]) - response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"]={"key_updated-json"}}, {["content-type"]="application/json"}) + response, status = http_client.patch(BASE_URL..plugin.id, {["config.key_names"] = {"key_updated-json"}}, {["content-type"] = "application/json"}) assert.equal(200, status) body = json.decode(response) assert.same("key_updated-json", body.config.key_names[1]) end) - it("[FAILURE] should return proper errors", function() local _, status = http_client.patch(BASE_URL.."b6cca0aa-4537-11e5-af97-23a06d98af51", {}) assert.equal(404, status) end) - it("should not override a plugin's `config` if partial", function() -- This is delicate since a plugin's `config` is a text field in a DB like Cassandra local _, status = http_client.patch(BASE_URL..plugin.id, { @@ -382,22 +362,34 @@ describe("Admin API", function() assert.same({"key_set_null_test_updated"}, body.config.key_names) assert.equal(true, body.config.hide_credentials) end) - + it("should be possible to disable it", function() + local response, status = http_client.patch(BASE_URL..plugin.id, { + enabled = false + }) + assert.equal(200, status) + local body = json.decode(response) + assert.False(body.enabled) + end) + it("should be possible to enabled it", function() + local response, status = http_client.patch(BASE_URL..plugin.id, { + enabled = true + }) + assert.equal(200, status) + local body = json.decode(response) + assert.True(body.enabled) + end) end) describe("DELETE", function() - it("[FAILURE] should return proper errors", function() local _, status = http_client.delete(BASE_URL.."b6cca0aa-4537-11e5-af97-23a06d98af51") assert.equal(404, status) end) - it("[SUCCESS] should delete a plugin configuration", function() local response, status = http_client.delete(BASE_URL..plugin.id) assert.equal(204, status) assert.falsy(response) end) - end) end) end) diff --git a/spec/integration/admin_api/consumers_routes_spec.lua b/spec/integration/admin_api/consumers_routes_spec.lua index e6c8efcb8d6f..63716912d929 100644 --- a/spec/integration/admin_api/consumers_routes_spec.lua +++ b/spec/integration/admin_api/consumers_routes_spec.lua @@ -39,8 +39,10 @@ describe("Admin API", function() describe("PUT", function() + local consumer + it("[SUCCESS] should create and update", function() - local consumer = send_content_types(BASE_URL, "PUT", { + consumer = send_content_types(BASE_URL, "PUT", { username = "consumer PUT tests" }, 201, nil, {drop_db=true}) @@ -61,6 +63,44 @@ describe("Admin API", function() }, 409, '{"username":"username already exists with value \'consumer PUT tests updated\'"}') end) + it("[SUCCESS] should update a Consumer", function() + local response, status = http_client.get(BASE_URL..consumer.id) + assert.equal(200, status) + + local body = json.decode(response) + assert.falsy(body.custom_id) + + body.custom_id = "custom123" + local response, status = http_client.put(BASE_URL, body) + assert.equal(200, status) + assert.truthy(response) + + local response, status = http_client.get(BASE_URL..consumer.id) + assert.equal(200, status) + + local body = json.decode(response) + assert.equal("custom123", body.custom_id) + end) + + it("[SUCCESS] should update a Consumer and remove a field", function() + local response, status = http_client.get(BASE_URL..consumer.id) + assert.equal(200, status) + + local body = json.decode(response) + assert.equal("custom123", body.custom_id) + + body.custom_id = nil + local response, status = http_client.put(BASE_URL, body) + assert.equal(200, status) + assert.truthy(response) + + local response, status = http_client.get(BASE_URL..consumer.id) + assert.equal(200, status) + + local body = json.decode(response) + assert.falsy(body.custom_id) + end) + end) describe("GET", function() @@ -76,6 +116,7 @@ describe("Admin API", function() local body = json.decode(response) assert.truthy(body.data) assert.equal(10, table.getn(body.data)) + assert.equal(10, body.total) end) it("should retrieve a paginated set", function() @@ -85,6 +126,7 @@ describe("Admin API", function() assert.truthy(body_page_1.data) assert.equal(3, table.getn(body_page_1.data)) assert.truthy(body_page_1.next) + assert.equal(10, body_page_1.total) response, status = http_client.get(BASE_URL, {size=3,offset=body_page_1.next}) assert.equal(200, status) @@ -93,14 +135,15 @@ describe("Admin API", function() assert.equal(3, table.getn(body_page_2.data)) assert.truthy(body_page_2.next) assert.not_same(body_page_1, body_page_2) + assert.equal(10, body_page_2.total) response, status = http_client.get(BASE_URL, {size=4,offset=body_page_2.next}) assert.equal(200, status) local body_page_3 = json.decode(response) assert.truthy(body_page_3.data) assert.equal(4, table.getn(body_page_3.data)) - -- TODO: fixme - --assert.falsy(body_page_3.next) + assert.equal(10, body_page_3.total) + assert.falsy(body_page_3.next) assert.not_same(body_page_2, body_page_3) end) diff --git a/spec/integration/admin_api/kong_routes_spec.lua b/spec/integration/admin_api/kong_routes_spec.lua index e8a5bd098e2b..7d4a1af0f152 100644 --- a/spec/integration/admin_api/kong_routes_spec.lua +++ b/spec/integration/admin_api/kong_routes_spec.lua @@ -2,6 +2,8 @@ local json = require "cjson" local http_client = require "kong.tools.http_client" local spec_helper = require "spec.spec_helpers" local utils = require "kong.tools.utils" +local env = spec_helper.get_env() -- test environment +local dao_factory = env.dao_factory describe("Admin API", function() @@ -13,7 +15,7 @@ describe("Admin API", function() teardown(function() spec_helper.stop_kong() end) - + describe("Kong routes", function() describe("/", function() local constants = require "kong.constants" @@ -60,15 +62,58 @@ describe("Admin API", function() assert.are.equal(200, status) local body = json.decode(response) assert.truthy(body) + assert.are.equal(2, utils.table_size(body)) + + -- Database stats + -- Removing migrations DAO + dao_factory.daos.migrations = nil + assert.are.equal(utils.table_size(dao_factory.daos), utils.table_size(body.database)) + for k, _ in pairs(dao_factory.daos) do + assert.truthy(body.database[k]) + end - assert.are.equal(7, utils.table_size(body)) - assert.truthy(body.connections_accepted) - assert.truthy(body.connections_active) - assert.truthy(body.connections_handled) - assert.truthy(body.connections_reading) - assert.truthy(body.connections_writing) - assert.truthy(body.connections_waiting) - assert.truthy(body.total_requests) + -- Server stats + assert.are.equal(7, utils.table_size(body.server)) + assert.truthy(body.server.connections_accepted) + assert.truthy(body.server.connections_active) + assert.truthy(body.server.connections_handled) + assert.truthy(body.server.connections_reading) + assert.truthy(body.server.connections_writing) + assert.truthy(body.server.connections_waiting) + assert.truthy(body.server.total_requests) end) end) + + describe("Request size", function() + it("should properly hanlde big POST bodies < 10MB", function() + local response, status = http_client.post(spec_helper.API_URL.."/apis", { request_path = "hello.com", upstream_url = "http://mockbin.org" }) + assert.equal(201, status) + local api_id = json.decode(response).id + assert.truthy(api_id) + + + local big_value = string.rep("204.48.16.0,", 1000) + big_value = string.sub(big_value, 1, string.len(big_value) - 1) + assert.truthy(string.len(big_value) > 10000) -- More than 10kb + + local _, status = http_client.post(spec_helper.API_URL.."/apis/"..api_id.."/plugins/", { name = "ip-restriction", ["config.blacklist"] = big_value}) + assert.equal(201, status) + end) + + it("should fail with requests > 10MB", function() + local response, status = http_client.post(spec_helper.API_URL.."/apis", { request_path = "hello2.com", upstream_url = "http://mockbin.org" }) + assert.equal(201, status) + local api_id = json.decode(response).id + assert.truthy(api_id) + + -- It should fail with more than 10MB + local big_value = string.rep("204.48.16.0,", 1024000) + big_value = string.sub(big_value, 1, string.len(big_value) - 1) + assert.truthy(string.len(big_value) > 10000000) -- More than 10kb + + local _, status = http_client.post(spec_helper.API_URL.."/apis/"..api_id.."/plugins/", { name = "ip-restriction", ["config.blacklist"] = big_value}) + assert.equal(413, status) + end) + end) + end) diff --git a/spec/integration/admin_api/plugins_routes_spec.lua b/spec/integration/admin_api/plugins_routes_spec.lua index 29e162c94b56..beb387dc004d 100644 --- a/spec/integration/admin_api/plugins_routes_spec.lua +++ b/spec/integration/admin_api/plugins_routes_spec.lua @@ -5,6 +5,11 @@ local spec_helper = require "spec.spec_helpers" describe("Admin API", function() setup(function() spec_helper.prepare_db() + spec_helper.insert_fixtures { + api = { + {request_host = "test.com", upstream_url = "http://mockbin.com"} + } + } spec_helper.start_kong() end) diff --git a/spec/integration/cli/restart_spec.lua b/spec/integration/cli/restart_spec.lua index f89ff1944d27..4003bbbd07a9 100644 --- a/spec/integration/cli/restart_spec.lua +++ b/spec/integration/cli/restart_spec.lua @@ -27,6 +27,13 @@ describe("CLI", function() it("should restart kong when it's crashed", function() local kong_pid = IO.read_file(spec_helper.get_env().configuration.pid_file) + if not kong_pid then + -- we might be to quick, so wait and retry + os.execute("sleep 1") + kong_pid = IO.read_file(spec_helper.get_env().configuration.pid_file) + if not kong_pid then error("Could not read Kong pid") end + end + os.execute("pkill -9 nginx") repeat diff --git a/spec/integration/cli/start_spec.lua b/spec/integration/cli/start_spec.lua index 1a63e3d35ef7..cc0459868a62 100644 --- a/spec/integration/cli/start_spec.lua +++ b/spec/integration/cli/start_spec.lua @@ -14,9 +14,7 @@ end describe("CLI", function() - describe("Startup plugins check", function() - - setup(function() + setup(function() os.execute("cp "..TEST_CONF.." "..SERVER_CONF) spec_helper.add_env(SERVER_CONF) spec_helper.prepare_db(SERVER_CONF) @@ -31,6 +29,8 @@ describe("CLI", function() pcall(spec_helper.stop_kong, SERVER_CONF) end) + describe("Startup plugins check", function() + it("should start with the default configuration", function() assert.has_no.errors(function() spec_helper.start_kong(TEST_CONF, true) diff --git a/spec/integration/dao/cassandra/base_dao_spec.lua b/spec/integration/dao/cassandra/base_dao_spec.lua index ef172b3bce03..f9720873d741 100644 --- a/spec/integration/dao/cassandra/base_dao_spec.lua +++ b/spec/integration/dao/cassandra/base_dao_spec.lua @@ -1,19 +1,15 @@ -local spec_helper = require "spec.spec_helpers" local cassandra = require "cassandra" +local spec_helper = require "spec.spec_helpers" local constants = require "kong.constants" local DaoError = require "kong.dao.error" local utils = require "kong.tools.utils" -local cjson = require "cjson" -local uuid = require "uuid" +local uuid = require "lua_uuid" --- Raw session for double-check purposes -local session -- Load everything we need from the spec_helper local env = spec_helper.get_env() -- test environment local faker = env.faker local dao_factory = env.dao_factory local configuration = env.configuration -configuration.cassandra = configuration.databases_available[configuration.database].properties -- An utility function to apply tests on core collections. local function describe_core_collections(tests_cb) @@ -38,34 +34,36 @@ say:set("assertion.daoError.positive", "Expected %s\nto be a DaoError") say:set("assertion.daoError.negative", "Expected %s\nto not be a DaoError") assert:register("assertion", "daoError", daoError, "assertion.daoError.positive", "assertion.daoError.negative") --- Let's go describe("Cassandra", function() - + -- Create a parallel session to verify the dao's behaviour + local session setup(function() spec_helper.prepare_db() - -- Create a parallel session to verify the dao's behaviour - session = cassandra:new() - session:set_timeout(configuration.cassandra.timeout) - - local _, err = session:connect(configuration.cassandra.contact_points) - assert.falsy(err) - - local _, err = session:set_keyspace("kong_tests") + local err + session, err = cassandra.spawn_session { + shm = "factory_specs", + keyspace = configuration.dao_config.keyspace, + contact_points = configuration.dao_config.contact_points + } assert.falsy(err) end) teardown(function() - if session then - local _, err = session:close() - assert.falsy(err) + if session ~= nil then + session:shutdown() end end) describe("Base DAO", function() - describe(":insert()", function() - - it("should error if called with invalid parameters", function() + describe("insert()", function() + setup(function() + spec_helper.drop_db() + end) + teardown(function() + spec_helper.drop_db() + end) + it("should throw an error if called with invalid parameters", function() assert.has_error(function() dao_factory.apis:insert() end, "Cannot insert a nil element") @@ -74,7 +72,6 @@ describe("Cassandra", function() dao_factory.apis:insert("") end, "Entity to insert must be a table") end) - it("should insert in DB and let the schema validation add generated values", function() -- API local api_t = faker:fake_entity("api") @@ -82,10 +79,10 @@ describe("Cassandra", function() assert.falsy(err) assert.truthy(api.id) assert.truthy(api.created_at) - local apis, err = session:execute("SELECT * FROM apis") + local rows, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api.id)}) assert.falsy(err) - assert.True(#apis > 0) - assert.equal(api.id, apis[1].id) + assert.True(#rows == 1) + assert.equal(api.id, rows[1].id) -- API api, err = dao_factory.apis:insert { @@ -102,10 +99,10 @@ describe("Cassandra", function() assert.falsy(err) assert.truthy(consumer.id) assert.truthy(consumer.created_at) - local consumers, err = session:execute("SELECT * FROM consumers") + rows, err = session:execute("SELECT * FROM consumers WHERE id = ?", {cassandra.uuid(consumer.id)}) assert.falsy(err) - assert.True(#consumers > 0) - assert.equal(consumer.id, consumers[1].id) + assert.True(#rows == 1) + assert.equal(consumer.id, rows[1].id) -- Plugin configuration local plugin_t = {name = "key-auth", api_id = api.id} @@ -117,7 +114,6 @@ describe("Cassandra", function() assert.True(#plugins > 0) assert.equal(plugin.id, plugins[1].id) end) - it("should let the schema validation return errors and not insert", function() -- Without an api_id, it's a schema error local plugin_t = faker:fake_entity("plugin") @@ -126,9 +122,8 @@ describe("Cassandra", function() assert.truthy(err) assert.is_daoError(err) assert.True(err.schema) - assert.are.same("api_id is required", err.message.api_id) + assert.equal("api_id is required", err.message.api_id) end) - it("should ensure fields with `unique` are unique", function() local api_t = faker:fake_entity("api") @@ -141,10 +136,9 @@ describe("Cassandra", function() assert.truthy(err) assert.is_daoError(err) assert.True(err.unique) - assert.are.same("name already exists with value '"..api_t.name.."'", err.message.name) + assert.equal("name already exists with value '"..api_t.name.."'", err.message.name) assert.falsy(api) end) - it("should ensure fields with `foreign` are existing", function() -- Plugin configuration local plugin_t = faker:fake_entity("plugin") @@ -155,9 +149,8 @@ describe("Cassandra", function() assert.truthy(err) assert.is_daoError(err) assert.True(err.foreign) - assert.are.same("api_id "..plugin_t.api_id.." does not exist", err.message.api_id) + assert.equal("api_id "..plugin_t.api_id.." does not exist", err.message.api_id) end) - it("should do insert checks for entities with `self_check`", function() local api, err = dao_factory.apis:insert(faker:fake_entity("api")) assert.falsy(err) @@ -177,13 +170,82 @@ describe("Cassandra", function() assert.truthy(err) assert.is_daoError(err) assert.True(err.unique) - assert.are.same("Plugin configuration already exists", err.message) + assert.equal("Plugin configuration already exists", err.message) end) + end) -- describe insert() - end) -- describe :insert() + describe("count_by_keys()", function() + setup(function() + spec_helper.drop_db() - describe(":update()", function() + local err = select(2, session:execute("INSERT INTO apis(id, name) VALUES(uuid(), 'foo')")) + assert.falsy(err) + for i = 1, 99 do + err = select(2, session:execute("INSERT INTO apis(id, name) VALUES(uuid(), 'bar')")) + assert.falsy(err) + end + end) + teardown(function() + spec_helper.drop_db() + end) + it("should return the count of rows in a table", function() + local count, err = dao_factory.apis:count_by_keys() + assert.falsy(err) + assert.equal(100, count) + end) + it("should return the count of rows in a table with filter columns", function() + local count, err = dao_factory.apis:count_by_keys({name = "bar"}) + assert.falsy(err) + assert.equal(99, count) + + count, err = dao_factory.apis:count_by_keys({name = "test"}) + assert.falsy(err) + assert.equal(0, count) + + count, err = dao_factory.apis:count_by_keys({name = ""}) + assert.falsy(err) + assert.equal(0, count) + end) + it("should return the count of rows in a table from a given paging_state", function() + local rows, err = session:execute("SELECT * FROM apis", nil, {page_size = 50}) + assert.falsy(err) + + local paging_state = rows.meta.paging_state + assert.truthy(paging_state) + + local count, err = dao_factory.apis:count_by_keys(nil, paging_state) + assert.falsy(err) + assert.equal(50, count) + end) + it("should return a filtered value to know if the query was filtered", function() + local _, err, filtered = dao_factory.apis:count_by_keys() + assert.falsy(err) + assert.False(filtered) + + _, err, filtered = dao_factory.apis:count_by_keys({name = "bar"}) + assert.falsy(err) + assert.False(filtered) + + _, err, filtered = dao_factory.apis:count_by_keys({name = "bar", request_host = ""}) + assert.falsy(err) + assert.True(filtered) + end) + it("should return errors when query is refused by Cassandra", function() + local count, err = dao_factory.apis:count_by_keys({upstream_url = ""}) + assert.truthy(err) + assert.falsy(count) + assert.is_daoError(err) + end) + end) + + describe("update()", function() + setup(function() + spec_helper.drop_db() + end) + teardown(function() + spec_helper.drop_db() + end) it("should error if called with invalid parameters", function() assert.has_error(function() dao_factory.apis:update() @@ -193,7 +255,6 @@ describe("Cassandra", function() dao_factory.apis:update("") end, "Entity to update must be a table") end) - it("should return nil and no error if no entity was found to update in DB", function() local api_t = faker:fake_entity("api") api_t.id = uuid() @@ -203,186 +264,233 @@ describe("Cassandra", function() assert.falsy(entity) assert.falsy(err) end) - it("should consider no entity to be found if an empty table is given to it", function() local api, err = dao_factory.apis:update({}) assert.falsy(err) assert.falsy(api) end) + it("should update an entity's non primary fields", function() + local UUID = uuid() - it("should update specified, non-primary fields in DB", function() -- API - local apis, err = session:execute("SELECT * FROM apis") + local api_t = { + id = UUID, + name = "mockbin" + } + local _, err = session:execute("INSERT INTO apis(id, name) VALUES(?, ?)", {cassandra.uuid(api_t.id), api_t.name}) assert.falsy(err) - assert.True(#apis > 0) - local api_t = apis[1] api_t.name = api_t.name.."-updated" local api, err = dao_factory.apis:update(api_t) assert.falsy(err) assert.truthy(api) - apis, err = session:execute("SELECT * FROM apis WHERE name = ?", {api_t.name}) + local rows, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api.id)}) assert.falsy(err) - assert.equal(1, #apis) - assert.equal(api_t.id, apis[1].id) - assert.equal(api_t.name, apis[1].name) - assert.equal(api_t.request_host, apis[1].request_host) - assert.equal(api_t.upstream_url, apis[1].upstream_url) + assert.equal(1, #rows) + assert.equal(api_t.id, rows[1].id) + assert.equal(api_t.name, rows[1].name) + assert.equal(api_t.request_host, rows[1].request_host) + assert.equal(api_t.upstream_url, rows[1].upstream_url) -- Consumer - local consumers, err = session:execute("SELECT * FROM consumers") + local consumer_t = { + id = UUID, + username = "john" + } + _, err = session:execute("INSERT INTO consumers(id, username) VALUES(?, ?)", { + cassandra.uuid(consumer_t.id), + consumer_t.username + }) assert.falsy(err) - assert.True(#consumers > 0) - local consumer_t = consumers[1] - consumer_t.custom_id = consumer_t.custom_id.."updated" + consumer_t.username = consumer_t.username.."-updated" local consumer, err = dao_factory.consumers:update(consumer_t) assert.falsy(err) assert.truthy(consumer) - consumers, err = session:execute("SELECT * FROM consumers WHERE custom_id = ?", {consumer_t.custom_id}) + rows, err = session:execute("SELECT * FROM consumers WHERE id = ?", {cassandra.uuid(consumer_t.id)}) assert.falsy(err) - assert.equal(1, #consumers) - assert.equal(consumer_t.name, consumers[1].name) + assert.equal(1, #rows) + assert.equal(consumer_t.name, rows[1].name) -- Plugin Configuration - local plugins, err = session:execute("SELECT * FROM plugins") + local plugin_t = { + id = UUID, + name = "key-auth", + api_id = UUID, + enabled = true + } + _, err = session:execute("INSERT INTO plugins(id, name, api_id) VALUES(?, ?, ?)", { + cassandra.uuid(plugin_t.id), + plugin_t.name, + cassandra.uuid(plugin_t.api_id) + }) assert.falsy(err) - assert.True(#plugins > 0) - local plugin_t = plugins[1] - plugin_t.config = cjson.decode(plugin_t.config) plugin_t.enabled = false + local plugin, err = dao_factory.plugins:update(plugin_t) assert.falsy(err) assert.truthy(plugin) - plugins, err = session:execute("SELECT * FROM plugins WHERE id = ?", {cassandra.uuid(plugin_t.id)}) + rows, err = session:execute("SELECT * FROM plugins WHERE id = ?", {cassandra.uuid(plugin_t.id)}) assert.falsy(err) - assert.equal(1, #plugins) + assert.equal(1, #rows) + assert.False(rows[1].enabled) end) - it("should ensure fields with `unique` are unique", function() - local apis, err = session:execute("SELECT * FROM apis") + local UUID_1 = uuid() + local UUID_2 = uuid() + + local _, err = session:execute("INSERT INTO apis(id, request_host) VALUES(?, ?)", { + cassandra.uuid(UUID_1), + "host1.com" + }) assert.falsy(err) - assert.True(#apis > 0) - local api_t = apis[1] - -- Should not work because we're reusing a request_host - api_t.request_host = apis[2].request_host + _, err = session:execute("INSERT INTO apis(id, request_host) VALUES(?, ?)", { + cassandra.uuid(UUID_2), + "host2.com" + }) + assert.falsy(err) - local api, err = dao_factory.apis:update(api_t) + local api, err = dao_factory.apis:update { + id = UUID_2, + request_host = "host1.com" + } assert.truthy(err) assert.falsy(api) assert.is_daoError(err) assert.True(err.unique) - assert.equal("request_host already exists with value '"..api_t.request_host.."'", err.message.request_host) + assert.equal("request_host already exists with value 'host1.com'", err.message.request_host) end) - describe("full", function() - - it("should set to NULL if a field is not specified", function() - local api_t = faker:fake_entity("api") - api_t.request_path = "/request_path" - - local api, err = dao_factory.apis:insert(api_t) + describe("full update", function() + it("should set a column to NULL if a field is not specified", function() + local api_t = { + id = uuid(), + upstream_url = "http://mockbin.com", + request_path = "/request_path", + request_host = "host.com" + } + local api, err = session:execute("INSERT INTO apis(id, upstream_url, request_path, request_host) VALUES(?, ?, ?, ?)", { + cassandra.uuid(api_t.id), + api_t.upstream_url, + api_t.request_path, + api_t.request_host + }) assert.falsy(err) - assert.truthy(api_t.request_path) + assert.truthy(api.request_path) -- Update - api.request_path = nil - api, err = dao_factory.apis:update(api, true) + api_t.request_path = nil + api, err = dao_factory.apis:update(api_t, true) assert.falsy(err) assert.truthy(api) - assert.falsy(api.request_path) + + local rows, err = dao_factory.apis:find_by_keys({id = api.id}) + assert.falsy(err) + assert.truthy(rows) + assert.falsy(rows[1].request_path) -- Check update - api, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api.id)}) + local _, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api_t.id)}) assert.falsy(err) assert.falsy(api.request_path) end) - it("should still check the validity of the schema", function() - local api_t = faker:fake_entity("api") - - local api, err = dao_factory.apis:insert(api_t) + local api_t = { + id = uuid(), + upstream_url = "http://mockbin.com", + request_path = "/request_path" + } + local _, err = session:execute("INSERT INTO apis(id, upstream_url, request_path) VALUES(?, ?, ?)", { + cassandra.uuid(api_t.id), + api_t.upstream_url, + api_t.request_path + }) assert.falsy(err) - assert.truthy(api_t) -- Update - api.request_host = nil + api_t.upstream_url = nil - local nil_api, err = dao_factory.apis:update(api, true) + local api, err = dao_factory.apis:update(api_t, true) assert.truthy(err) - assert.falsy(nil_api) + assert.falsy(api) + assert.equal("upstream_url is required", err.message.upstream_url) -- Check update failed - api, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api.id)}) + local rows, err = session:execute("SELECT * FROM apis WHERE id = ?", {cassandra.uuid(api_t.id)}) assert.falsy(err) - assert.truthy(api[1].name) - assert.truthy(api[1].request_host) + assert.truthy(rows[1].upstream_url) end) - end) - end) -- describe :update() + end) -- describe update() - describe(":find_by_keys()", function() - describe_core_collections(function(type, collection) - - it("should error if called with invalid parameters", function() - assert.has_error(function() - dao_factory[collection]:find_by_keys("") - end, "where_t must be a table") - end) + describe("find_by_keys()", function() + setup(function() + spec_helper.drop_db() - it("should handle empty search fields", function() - local results, err = dao_factory[collection]:find_by_keys({}) - assert.falsy(err) - assert.truthy(results) - assert.True(#results > 0) - end) + local err = select(2, session:execute("INSERT INTO apis(id, request_host, upstream_url) VALUES(uuid(), 'foo.com', 'http://foo.com')")) + assert.falsy(err) - it("should handle nil search fields", function() - local results, err = dao_factory[collection]:find_by_keys(nil) + for i = 1, 99 do + err = select(2, session:execute("INSERT INTO apis(id, request_host, upstream_url) VALUES(uuid(), 'foo.com', 'http://bar.com')")) assert.falsy(err) - assert.truthy(results) - assert.True(#results > 0) - end) + end end) - - it("should query an entity from the given fields and return if filtering was needed", function() - -- Filtering needed - local apis, err = session:execute("SELECT * FROM apis") + teardown(function() + spec_helper.drop_db() + end) + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory.apis:find_by_keys("") + end, "where_t must be a table") + end) + it("should handle empty search fields", function() + local apis, err = dao_factory.apis:find_by_keys({}) assert.falsy(err) + assert.truthy(apis) assert.True(#apis > 0) - - local api_t = apis[1] - local apis, err, needs_filtering = dao_factory.apis:find_by_keys(api_t) + end) + it("should handle nil search fields", function() + local apis, err = dao_factory.apis:find_by_keys(nil) assert.falsy(err) - assert.same(api_t, apis[1]) - assert.True(needs_filtering) - - -- No Filtering needed - apis, err, needs_filtering = dao_factory.apis:find_by_keys {request_host = api_t.request_host} + assert.truthy(apis) + assert.True(#apis > 0) + end) + it("should query an entity from the given fields and return if filtering was needed", function() + -- No filtering needed + local apis, err, needs_filtering = dao_factory.apis:find_by_keys { + request_host = 'foo.com', + } assert.falsy(err) - assert.same(api_t, apis[1]) + assert.equal(100, #apis) assert.False(needs_filtering) - end) - end) -- describe :find_by_keys() - - describe(":find()", function() + -- Filtering needed + apis, err, needs_filtering = dao_factory.apis:find_by_keys { + request_host = 'foo.com', + upstream_url = 'http://foo.com' + } + assert.falsy(err) + assert.equal(1, #apis) + assert.True(needs_filtering) + end) + end) -- describe find_by_keys() + describe("find()", function() setup(function() spec_helper.drop_db() spec_helper.seed_db(10) end) - + teardown(function() + spec_helper.drop_db() + end) describe_core_collections(function(type, collection) - it("should find entities", function() local entities, err = session:execute("SELECT * FROM "..collection) assert.falsy(err) @@ -409,64 +517,65 @@ describe("Cassandra", function() assert.truthy(rows_2) assert.same(2, #rows_2) end) - end) - end) -- describe :find() - - describe(":find_by_primary_key()", function() - describe_core_collections(function(type, collection) - - it("should error if called with invalid parameters", function() - assert.has_error(function() - dao_factory[collection]:find_by_primary_key("") - end, "where_t must be a table") - end) - - it("should return nil (not found) if where_t is empty", function() - local res, err = dao_factory[collection]:find_by_primary_key({}) - assert.falsy(err) - assert.falsy(res) - end) + end) -- describe find() + describe("find_by_primary_key()", function() + local api_t = { + id = uuid(), + name = "mockbin" + } + setup(function() + spec_helper.drop_db() + local err = select(2, session:execute("INSERT INTO apis(id, name) VALUES(?, ?)", { + cassandra.uuid(api_t.id), + api_t.name + })) + assert.falsy(err) end) - - it("should find one entity by its primary key", function() - local apis, err = session:execute("SELECT * FROM apis") + teardown(function() + spec_helper.drop_db() + end) + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory.apis:find_by_primary_key("") + end, "where_t must be a table") + end) + it("should return nil (not found) if where_t is empty", function() + local res, err = dao_factory.apis:find_by_primary_key {} assert.falsy(err) - assert.True(#apis > 0) - - local api, err = dao_factory.apis:find_by_primary_key { id = apis[1].id } + assert.falsy(res) + end) + it("should find one entity by its primary key", function() + local api, err = dao_factory.apis:find_by_primary_key {id = api_t.id} assert.falsy(err) - assert.truthy(apis) - assert.same(apis[1], api) + assert.truthy(api) + assert.same(api_t, api) end) - it("should handle an invalid uuid value", function() - local apis, err = dao_factory.apis:find_by_primary_key { id = "abcd" } - assert.falsy(apis) + local api, err = dao_factory.apis:find_by_primary_key {id = "abcd"} + assert.falsy(api) assert.True(err.invalid_type) assert.equal("abcd is an invalid uuid", err.message.id) end) describe("plugins", function() - + local plugin_t = { + id = uuid(), + name = "key-auth", + api_id = api_t.id, + config = [[{"key_names": ["test_key"]}]] + } setup(function() - local fixtures = spec_helper.seed_db(1) - faker:insert_from_table { - plugin = { - { name = "key-auth", config = {key_names = {"apikey"}}, api_id = fixtures.api[1].id } - } - } + local err = select(2, session:execute("INSERT INTO plugins(id, api_id, name, config) VALUES(?, ?, ?, ?)", { + cassandra.uuid(plugin_t.id), + cassandra.uuid(plugin_t.api_id), + plugin_t.name, + plugin_t.config + })) + assert.falsy(err) end) - it("should unmarshall the `config` field", function() - local plugins, err = session:execute("SELECT * FROM plugins") - assert.falsy(err) - assert.truthy(plugins) - assert.True(#plugins> 0) - - local plugin_t = plugins[1] - local plugin, err = dao_factory.plugins:find_by_primary_key { id = plugin_t.id, name = plugin_t.name @@ -475,46 +584,44 @@ describe("Cassandra", function() assert.truthy(plugin) assert.equal("table", type(plugin.config)) end) - end) - end) -- describe :find_by_primary_key() - - describe(":delete()", function() + end) -- describe find_by_primary_key() + describe("delete()", function() + setup(function() + spec_helper.drop_db() + for i = 1, 100 do + local err = select(2, session:execute("INSERT INTO plugins(id, name) VALUES(uuid(), 'some-plugin')")) + assert.falsy(err) + end + end) teardown(function() spec_helper.drop_db() end) + it("should error if called with invalid parameters", function() + assert.has_error(function() + dao_factory.plugins:delete("") + end, "where_t must be a table") + end) + it("should return false if entity to delete wasn't found", function() + local ok, err = dao_factory.plugins:delete({id = uuid()}) + assert.falsy(err) + assert.False(ok) + end) + it("should delete an entity based on its primary key", function() + local entities, err = session:execute("SELECT * FROM plugins") + assert.falsy(err) + assert.truthy(entities) + assert.True(#entities > 0) - describe_core_collections(function(type, collection) - - it("should error if called with invalid parameters", function() - assert.has_error(function() - dao_factory[collection]:delete("") - end, "where_t must be a table") - end) - - it("should return false if entity to delete wasn't found", function() - local ok, err = dao_factory[collection]:delete({id = uuid()}) - assert.falsy(err) - assert.False(ok) - end) - - it("should delete an entity based on its primary key", function() - local entities, err = session:execute("SELECT * FROM "..collection) - assert.falsy(err) - assert.truthy(entities) - assert.True(#entities > 0) - - local ok, err = dao_factory[collection]:delete(entities[1]) - assert.falsy(err) - assert.True(ok) - - local entities, err = session:execute("SELECT * FROM "..collection.." WHERE id = ?", {cassandra.uuid(entities[1].id)}) - assert.falsy(err) - assert.truthy(entities) - assert.are.same(0, #entities) - end) + local ok, err = dao_factory.plugins:delete(entities[1]) + assert.falsy(err) + assert.True(ok) + local entities, err = session:execute("SELECT * FROM plugins WHERE id = ?", {cassandra.uuid(entities[1].id)}) + assert.falsy(err) + assert.truthy(entities) + assert.equal(0, #entities) end) end) @@ -523,16 +630,24 @@ describe("Cassandra", function() -- describe("APIs", function() - setup(function() - spec_helper.seed_db(100) + spec_helper.drop_db() + for i = 1, 100 do + local err = select(2, session:execute("INSERT INTO apis(id, name) VALUES(uuid(), 'mockbin')")) + assert.falsy(err) + end + end) + teardown(function() + spec_helper.drop_db() end) describe(":find_all()", function() - local apis, err = dao_factory.apis:find_all() - assert.falsy(err) - assert.truthy(apis) - assert.equal(100, #apis) + it("should retrieve all APIs", function() + local apis, err = dao_factory.apis:find_all() + assert.falsy(err) + assert.truthy(apis) + assert.equal(100, #apis) + end) end) end) @@ -541,34 +656,38 @@ describe("Cassandra", function() -- describe("plugins", function() - describe(":find_distinct()", function() - it("should find distinct plugins configurations", function() - faker:insert_from_table { - api = { - { name = "tests-distinct-1", request_host = "foo.com", upstream_url = "http://mockbin.com" }, - { name = "tests-distinct-2", request_host = "bar.com", upstream_url = "http://mockbin.com" } - }, - plugin = { - { name = "key-auth", config = {key_names = {"apikey"}, hide_credentials = true}, __api = 1 }, - { name = "rate-limiting", config = { minute = 6}, __api = 1 }, - { name = "rate-limiting", config = { minute = 6}, __api = 2 }, - { name = "file-log", config = { path = "/tmp/spec.log" }, __api = 1 } - } + setup(function() + spec_helper.drop_db() + faker:insert_from_table { + api = { + {name = "tests-distinct-1", request_host = "foo.com", upstream_url = "http://mockbin.com"}, + {name = "tests-distinct-2", request_host = "bar.com", upstream_url = "http://mockbin.com"} + }, + plugin = { + {name = "key-auth", config = {key_names = {"apikey"}, hide_credentials = true}, __api = 1}, + {name = "rate-limiting", config = { minute = 6}, __api = 1}, + {name = "rate-limiting", config = { minute = 6}, __api = 2}, + {name = "file-log", config = { path = "/tmp/spec.log" }, __api = 1} } - + } + end) + teardown(function() + spec_helper.drop_db() + end) + describe("find_distinct()", function() + it("should find distinct plugins configurations", function() local res, err = dao_factory.plugins:find_distinct() - assert.falsy(err) assert.truthy(res) - assert.are.same(3, #res) + assert.equal(3, #res) assert.truthy(utils.table_contains(res, "key-auth")) assert.truthy(utils.table_contains(res, "rate-limiting")) assert.truthy(utils.table_contains(res, "file-log")) end) end) - describe(":insert()", function() + describe("insert()", function() local api_id local inserted_plugin it("should insert a plugin and set the consumer_id to a 'null' uuid if none is specified", function() @@ -593,25 +712,23 @@ describe("Cassandra", function() inserted_plugin = plugin inserted_plugin.consumer_id = nil end) - it("should insert a plugin with an empty config if none is specified", function() local api_t = faker:fake_entity("api") local api, err = dao_factory.apis:insert(api_t) assert.falsy(err) assert.truthy(api) - local plugin, err = dao_factory.plugins:insert({ + local plugin, err = dao_factory.plugins:insert { name = "request-transformer", api_id = api.id - }) + } assert.falsy(err) assert.truthy(plugin) assert.falsy(plugin.consumer_id) - assert.same("request-transformer", plugin.name) + assert.equal("request-transformer", plugin.name) assert.same({}, plugin.config) end) - it("should select a plugin configuration by 'null' uuid consumer_id and remove the column", function() -- Now we should be able to select this plugin local rows, err = dao_factory.plugins:find_by_keys { @@ -624,7 +741,6 @@ describe("Cassandra", function() assert.falsy(rows[1].consumer_id) end) end) - end) -- describe plugins configurations end) -- describe Base DAO end) -- describe Cassandra diff --git a/spec/integration/dao/cassandra/cascade_spec.lua b/spec/integration/dao/cassandra/cascade_spec.lua index afbd506f37be..c28b0d145661 100644 --- a/spec/integration/dao/cassandra/cascade_spec.lua +++ b/spec/integration/dao/cassandra/cascade_spec.lua @@ -6,11 +6,9 @@ local dao_factory = env.dao_factory dao_factory:load_plugins({"keyauth", "basicauth", "oauth2"}) describe("Cassandra cascade delete", function() - setup(function() spec_helper.prepare_db() end) - describe("API -> plugins", function() local api, untouched_api @@ -34,11 +32,9 @@ describe("Cassandra cascade delete", function() api = fixtures.api[1] untouched_api = fixtures.api[2] end) - teardown(function() spec_helper.drop_db() end) - it("should delete foreign plugins when deleting an API", function() local ok, err = dao_factory.apis:delete(api) assert.falsy(err) @@ -84,11 +80,9 @@ describe("Cassandra cascade delete", function() consumer = fixtures.consumer[1] untouched_consumer = fixtures.consumer[2] end) - teardown(function() spec_helper.drop_db() end) - it("should delete foreign plugins when deleting a Consumer", function() local ok, err = dao_factory.consumers:delete(consumer) assert.falsy(err) @@ -126,11 +120,9 @@ describe("Cassandra cascade delete", function() consumer = fixtures.consumer[1] untouched_consumer = fixtures.consumer[2] end) - teardown(function() spec_helper.drop_db() end) - it("should delete foreign keyauth_credentials when deleting a Consumer", function() local ok, err = dao_factory.consumers:delete(consumer) assert.falsy(err) @@ -167,11 +159,9 @@ describe("Cassandra cascade delete", function() consumer = fixtures.consumer[1] untouched_consumer = fixtures.consumer[2] end) - teardown(function() spec_helper.drop_db() end) - it("should delete foreign basicauth_credentials when deleting a Consumer", function() local ok, err = dao_factory.consumers:delete(consumer) assert.falsy(err) @@ -225,11 +215,9 @@ describe("Cassandra cascade delete", function() } assert.falsy(err) end) - teardown(function() spec_helper.drop_db() end) - it("should delete foreign oauth2_credentials and tokens when deleting a Consumer", function() local ok, err = dao_factory.consumers:delete(consumer) assert.falsy(err) diff --git a/spec/integration/dao/cassandra/fixtures/core_migrations.lua b/spec/integration/dao/cassandra/fixtures/core_migrations.lua index 8db0c8166cce..9e6b58c6286a 100644 --- a/spec/integration/dao/cassandra/fixtures/core_migrations.lua +++ b/spec/integration/dao/cassandra/fixtures/core_migrations.lua @@ -3,17 +3,23 @@ local CORE_MIGRATIONS_FIXTURES = { name = "stub_skeleton", init = true, up = function(options, dao_factory) - return dao_factory:execute_queries([[ - CREATE KEYSPACE IF NOT EXISTS "]]..options.keyspace..[[" + -- Format final keyspace creation query + local keyspace_str = string.format([[ + CREATE KEYSPACE IF NOT EXISTS "%s" WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}; + ]], options.keyspace) - USE "]]..options.keyspace..[["; + local err = dao_factory:execute_queries(keyspace_str, true) + if err then + return err + end + return dao_factory:execute_queries [[ CREATE TABLE IF NOT EXISTS schema_migrations( id text PRIMARY KEY, migrations list ); - ]], true) + ]] end, down = function(options, dao_factory) return dao_factory:execute_queries [[ diff --git a/spec/integration/dao/cassandra/migrations_spec.lua b/spec/integration/dao/cassandra/migrations_spec.lua index 5cec6e4eac72..166d5f98e4ac 100644 --- a/spec/integration/dao/cassandra/migrations_spec.lua +++ b/spec/integration/dao/cassandra/migrations_spec.lua @@ -1,5 +1,5 @@ -local DAO = require "kong.dao.cassandra.factory" local cassandra = require "cassandra" +local DAO = require "kong.dao.cassandra.factory" local Migrations = require "kong.tools.migrations" local spec_helper = require "spec.spec_helpers" @@ -18,11 +18,17 @@ local FIXTURES = { local test_env = spec_helper.get_env() -- test environment local test_configuration = test_env.configuration -local test_cassandra_properties = test_configuration.databases_available[test_configuration.database].properties +local test_cassandra_properties = test_configuration.dao_config test_cassandra_properties.keyspace = FIXTURES.keyspace local test_dao = DAO(test_cassandra_properties) -local session = cassandra:new() +local session, err = cassandra.spawn_session { + shm = "factory_specs", + contact_points = test_configuration.dao_config.contact_points +} +if err then + error(err) +end local function has_table(state, arguments) local rows, err = session:execute("SELECT columnfamily_name FROM system.schema_columnfamilies WHERE keyspace_name = ?", {FIXTURES.keyspace}) @@ -45,6 +51,37 @@ say:set("assertion.has_table.positive", "Expected keyspace to have table %s") say:set("assertion.has_table.negative", "Expected keyspace not to have table %s") assert:register("assertion", "has_table", has_table, "assertion.has_table.positive", "assertion.has_table.negative") +local function has_keyspace(state, arguments) + local rows, err = session:execute("SELECT * FROM system.schema_keyspaces WHERE keyspace_name = ?", {arguments[1]}) + if err then + error(err) + end + + return #rows > 0 +end + +say:set("assertion.has_keyspace.positive", "Expected keyspace %s to exist") +say:set("assertion.has_keyspace.negative", "Expected keyspace %s to not exist") +assert:register("assertion", "has_keyspace", has_keyspace, "assertion.has_keyspace.positive", "assertion.has_keyspace.negative") + +local function has_replication_options(state, arguments) + local rows, err = session:execute("SELECT * FROM system.schema_keyspaces WHERE keyspace_name = ?", {arguments[1]}) + if err then + error(err) + end + + if #rows > 0 then + local keyspace = rows[1] + assert.equal("org.apache.cassandra.locator."..arguments[2], keyspace.strategy_class) + assert.equal(arguments[3], keyspace.strategy_options) + return true + end +end + +say:set("assertion.has_replication_options.positive", "Expected keyspace %s to have given replication options") +say:set("assertion.has_replication_options.negative", "Expected keyspace %s to not have given replication options") +assert:register("assertion", "has_replication_options", has_replication_options, "assertion.has_replication_options.positive", "assertion.has_replication_options.negative") + local function has_migration(state, arguments) local identifier = arguments[1] local migration = arguments[2] @@ -67,7 +104,6 @@ local function has_migration(state, arguments) return false end -local say = require "say" say:set("assertion.has_migration.positive", "Expected keyspace to have migration %s record") say:set("assertion.has_migration.negative", "Expected keyspace not to have migration %s recorded") assert:register("assertion", "has_migration", has_migration, "assertion.has_migration.positive", "assertion.has_migration.negative") @@ -79,13 +115,6 @@ assert:register("assertion", "has_migration", has_migration, "assertion.has_migr describe("Migrations", function() local migrations - setup(function() - local ok, err = session:connect(test_cassandra_properties.contact_points, test_cassandra_properties.port) - if not ok then - error(err) - end - end) - teardown(function() session:execute("DROP KEYSPACE "..FIXTURES.keyspace) end) @@ -315,4 +344,27 @@ describe("Migrations", function() assert.equal(0, #rows) end) end) + describe("keyspace replication strategy", function() + local KEYSPACE_NAME = "kong_replication_strategy_tests" + + setup(function() + migrations = Migrations(test_dao, FIXTURES.kong_config) + migrations.dao_properties.keyspace = KEYSPACE_NAME + end) + after_each(function() + session:execute("DROP KEYSPACE "..KEYSPACE_NAME) + end) + it("should create a keyspace with SimpleStrategy by default", function() + local err = migrations:run_migrations("core") + assert.falsy(err) + assert.has_keyspace(KEYSPACE_NAME) + assert.has_replication_options(KEYSPACE_NAME, "SimpleStrategy", "{\"replication_factor\":\"1\"}") + end) + it("should catch an invalid replication strategy", function() + migrations.dao_properties.replication_strategy = "foo" + local err = migrations:run_migrations("core") + assert.truthy(err) + assert.equal('Error executing migration for "core": invalid replication_strategy class', err) + end) + end) end) diff --git a/spec/integration/proxy/database_cache_spec.lua b/spec/integration/proxy/database_cache_spec.lua index 6850785f6a39..55cf9d81e719 100644 --- a/spec/integration/proxy/database_cache_spec.lua +++ b/spec/integration/proxy/database_cache_spec.lua @@ -10,7 +10,7 @@ describe("Database cache", function() spec_helper.prepare_db() fixtures = spec_helper.insert_fixtures { api = { - { name = "tests-database-cache", request_host = "cache.test", upstream_url = "http://httpbin.org" } + {name = "tests-database-cache", request_host = "cache.test", upstream_url = "http://httpbin.org"} } } @@ -22,7 +22,8 @@ describe("Database cache", function() end) it("should expire cache after five seconds", function() - local _ = http_client.get(spec_helper.PROXY_URL.."/get", {}, {host = "cache.test"}) + -- trigger a db fetch for this API's plugins + http_client.get(spec_helper.PROXY_URL.."/get", {}, {host = "cache.test"}) -- Let's add the authentication plugin configuration local _, err = env.dao_factory.plugins:insert { @@ -45,7 +46,7 @@ describe("Database cache", function() assert.are.equal(401, status) -- Create a consumer and a key will make it work again - local consumer, err = env.dao_factory.consumers:insert { username = "john" } + local consumer, err = env.dao_factory.consumers:insert {username = "john"} assert.falsy(err) local _, err = env.dao_factory.keyauth_credentials:insert { diff --git a/spec/integration/proxy/realip_spec.lua b/spec/integration/proxy/realip_spec.lua index ccfd944da98d..9856ccdbbd4a 100644 --- a/spec/integration/proxy/realip_spec.lua +++ b/spec/integration/proxy/realip_spec.lua @@ -31,17 +31,20 @@ describe("Real IP", function() local uuid = utils.random_string() -- Making the request - local _ = http_client.get(spec_helper.STUB_GET_URL, nil, + http_client.get(spec_helper.STUB_GET_URL, nil, { host = "realip.com", ["X-Forwarded-For"] = "4.4.4.4, 1.1.1.1, 5.5.5.5", file_log_uuid = uuid } ) - --assert.are.equal(200, status) + local timeout = 10 while not (IO.file_exists(FILE_LOG_PATH) and IO.file_size(FILE_LOG_PATH) > 0) do -- Wait for the file to be created, and for the log to be appended + os.execute("sleep 1") + timeout = timeout -1 + if timeout == 0 then error("Retrieving the ip address timed out") end end local file_log = IO.read_file(FILE_LOG_PATH) diff --git a/spec/integration/proxy/api_resolver_spec.lua b/spec/integration/proxy/resolver_spec.lua similarity index 61% rename from spec/integration/proxy/api_resolver_spec.lua rename to spec/integration/proxy/resolver_spec.lua index 6a9d608a3b5d..597b0386ef4d 100644 --- a/spec/integration/proxy/api_resolver_spec.lua +++ b/spec/integration/proxy/resolver_spec.lua @@ -21,7 +21,6 @@ local function parse_cert(cert) end describe("Resolver", function() - setup(function() spec_helper.prepare_db() spec_helper.insert_fixtures { @@ -37,7 +36,11 @@ describe("Resolver", function() {name = "tests-wildcard-subdomain-2", upstream_url = "http://mockbin.com/status/201", request_host = "wildcard.*"}, {name = "tests-preserve-host", request_host = "httpbin-nopreserve.com", upstream_url = "http://httpbin.org"}, {name = "tests-preserve-host-2", request_host = "httpbin-preserve.com", upstream_url = "http://httpbin.org", preserve_host = true}, - {name = "tests-uri", request_host = "mockbin-uri.com", upstream_url = "http://mockbin.org"} + {name = "tests-uri", request_host = "mockbin-uri.com", upstream_url = "http://mockbin.org"}, + {name = "tests-trailing-slash-path", request_path = "/test-trailing-slash", strip_request_path = true, upstream_url = "http://www.mockbin.org/request"}, + {name = "tests-trailing-slash-path2", request_path = "/test-trailing-slash2", strip_request_path = false, upstream_url = "http://www.mockbin.org/request"}, + {name = "tests-trailing-slash-path3", request_path = "/test-trailing-slash3", strip_request_path = true, upstream_url = "http://www.mockbin.org"}, + {name = "tests-trailing-slash-path4", request_path = "/test-trailing-slash4", strip_request_path = true, upstream_url = "http://www.mockbin.org/"} }, plugin = { {name = "key-auth", config = {key_names = {"apikey"} }, __api = 2} @@ -51,34 +54,17 @@ describe("Resolver", function() spec_helper.stop_kong() end) - describe("Test URI", function() - - it("should URL decode the URI with querystring", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", { hello = "world"}, {host = "mockbin-uri.com"}) - assert.equal(200, status) - assert.equal("http://mockbin.org/request/hello%2f?hello=world", cjson.decode(response).url) - end) - - it("should URL decode the URI without querystring", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2F", nil, {host = "mockbin-uri.com"}) - assert.equal(200, status) - assert.equal("http://mockbin.org/request/hello%2f", cjson.decode(response).url) - end) - - end) - describe("Inexistent API", function() - it("should return Not Found when the API is not in Kong", function() - local response, status = http_client.get(spec_helper.STUB_GET_URL, nil, {host = "foo.com"}) + local response, status, headers = http_client.get(spec_helper.STUB_GET_URL, nil, {host = "foo.com"}) assert.equal(404, status) assert.equal('{"request_path":"\\/request","message":"API not found with these values","request_host":["foo.com"]}\n', response) + assert.falsy(headers[constants.HEADERS.PROXY_LATENCY]) + assert.falsy(headers[constants.HEADERS.UPSTREAM_LATENCY]) end) - end) describe("SSL", function() - it("should work when calling SSL port", function() local response, status = http_client.get(STUB_GET_SSL_URL, nil, {host = "mockbin.com"}) assert.equal(200, status) @@ -86,7 +72,6 @@ describe("Resolver", function() local parsed_response = cjson.decode(response) assert.same("GET", parsed_response.method) end) - it("should work when manually triggering the handshake on default route", function() local parsed_url = url.parse(STUB_GET_SSL_URL) @@ -120,7 +105,6 @@ describe("Resolver", function() conn:close() end) - end) describe("Existing API", function() @@ -129,24 +113,19 @@ describe("Resolver", function() local _, status = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com"}) assert.equal(200, status) end) - it("should proxy when the Host header is not trimmed", function() local _, status = http_client.get(STUB_GET_URL, nil, {host = " mockbin.com "}) assert.equal(200, status) end) - it("should proxy when the request has no Host header but the X-Host-Override header", function() local _, status = http_client.get(STUB_GET_URL, nil, {["X-Host-Override"] = "mockbin.com"}) assert.equal(200, status) end) - it("should proxy when the Host header contains a port", function() local _, status = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com:80"}) assert.equal(200, status) end) - describe("with wildcard subdomain", function() - it("should proxy when the request_host is a wildcard subdomain", function() local _, status = http_client.get(STUB_GET_URL, nil, {host = "subdomain.wildcard.com"}) assert.equal(200, status) @@ -157,7 +136,7 @@ describe("Resolver", function() end) end) - describe("By request_Path", function() + describe("By request_path", function() it("should proxy when no Host is present but the request_uri matches the API's request_path", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/status/200") assert.equal(200, status) @@ -175,18 +154,6 @@ describe("Resolver", function() assert.equal("/somerequest_path/status/200", body.request_path) assert.equal(404, status) end) - it("should proxy and strip the request_path if `strip_request_path` is true", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") - assert.equal(200, status) - local body = cjson.decode(response) - assert.equal("http://mockbin.com/request", body.url) - end) - it("should proxy and strip the request_path if `strip_request_path` is true if request_path has pattern characters", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request") - assert.equal(200, status) - local body = cjson.decode(response) - assert.equal("http://mockbin.com/request", body.url) - end) it("should proxy when the request_path has a deep level", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/deep/request_path/status/200") assert.equal(200, status) @@ -195,43 +162,111 @@ describe("Resolver", function() local _, status = http_client.get(spec_helper.PROXY_URL.."/mockbin?foo=bar") assert.equal(200, status) end) - it("should not strip if the `request_path` pattern is repeated in the request_uri", function() - local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request") + it("should not add a trailing slash when strip_path is disabled", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash2", {hello = "world"}) assert.equal(200, status) - local body = cjson.decode(response) - local upstream_url = body.log.entries[1].request.url - assert.equal("http://mockbin.com/har/of/request", upstream_url) + assert.equal("http://www.mockbin.org/request/test-trailing-slash2?hello=world", cjson.decode(response).url) end) end) it("should return the correct Server and Via headers when the request was proxied", function() - local _, status, headers = http_client.get(STUB_GET_URL, nil, { host = "mockbin.com"}) + local _, status, headers = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com"}) assert.equal(200, status) assert.equal("cloudflare-nginx", headers.server) assert.equal(constants.NAME.."/"..constants.VERSION, headers.via) end) it("should return the correct Server and no Via header when the request was NOT proxied", function() - local _, status, headers = http_client.get(STUB_GET_URL, nil, { host = "mockbin-auth.com"}) + local _, status, headers = http_client.get(STUB_GET_URL, nil, {host = "mockbin-auth.com"}) assert.equal(401, status) assert.equal(constants.NAME.."/"..constants.VERSION, headers.server) assert.falsy(headers.via) end) + it("should return correct timing headers when the request was proxied", function() + local _, status, headers = http_client.get(STUB_GET_URL, nil, {host = "mockbin.com"}) + assert.equal(200, status) + assert.truthy(headers[constants.HEADERS.PROXY_LATENCY:lower()]) + assert.truthy(headers[constants.HEADERS.UPSTREAM_LATENCY:lower()]) + end) end) - describe("Preseve Host", function() + describe("preserve_host", function() it("should not preserve the host (default behavior)", function() - local response, status = http_client.get(PROXY_URL.."/get", nil, { host = "httpbin-nopreserve.com"}) + local response, status = http_client.get(PROXY_URL.."/get", nil, {host = "httpbin-nopreserve.com"}) assert.equal(200, status) local parsed_response = cjson.decode(response) assert.equal("httpbin.org", parsed_response.headers["Host"]) end) it("should preserve the host (default behavior)", function() - local response, status = http_client.get(PROXY_URL.."/get", nil, { host = "httpbin-preserve.com"}) + local response, status = http_client.get(PROXY_URL.."/get", nil, {host = "httpbin-preserve.com"}) assert.equal(200, status) local parsed_response = cjson.decode(response) assert.equal("httpbin-preserve.com", parsed_response.headers["Host"]) end) end) - + + describe("strip_path", function() + it("should strip the request_path if `strip_request_path` is true", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin/request") + assert.equal(200, status) + local body = cjson.decode(response) + assert.equal("http://mockbin.com/request", body.url) + end) + it("should strip the request_path if `strip_request_path` is true if `request_path` has pattern characters", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/mockbin-with-pattern/request") + assert.equal(200, status) + local body = cjson.decode(response) + assert.equal("http://mockbin.com/request", body.url) + end) + it("should not strip if the `request_path` pattern is repeated in the request_uri", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/har/har/of/request") + assert.equal(200, status) + local body = cjson.decode(response) + local upstream_url = body.log.entries[1].request.url + assert.equal("http://mockbin.com/har/of/request", upstream_url) + end) + it("should not add a trailing slash when strip_path is enabled", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + it("should not add a trailing slash when strip_path is enabled and upstream_url has no path", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash3/request", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + it("should not add a trailing slash when strip_path is enabled and upstream_url has single path", function() + local response, status = http_client.get(spec_helper.PROXY_URL.."/test-trailing-slash4/request", {hello = "world"}) + assert.equal(200, status) + assert.equal("http://www.mockbin.org/request?hello=world", cjson.decode(response).url) + end) + end) + + describe("Percent-encoding", function() + it("should leave percent-encoded values in URI untouched", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL.."/hello%2Fworld", {}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request/hello%2fworld", cjson.decode(response).url) + end) + it("should leave untouched percent-encoded values in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {foo = "abc%7Cdef%2c%20world"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?foo=abc%7cdef%2c%20world", cjson.decode(response).url) + end) + it("should leave untouched percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {["hello%20world"] = "foo"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?hello%20world=foo", cjson.decode(response).url) + end) + it("should percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {["hello world"] = "foo"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?hello%20world=foo", cjson.decode(response).url) + end) + it("should percent-encoded keys in querystring", function() + local response, status = http_client.get(spec_helper.STUB_GET_URL, {foo = "abc|def, world"}, {host = "mockbin-uri.com"}) + assert.equal(200, status) + assert.equal("http://mockbin.org/request?foo=abc%7cdef%2c%20world", cjson.decode(response).url) + end) + end) end) diff --git a/spec/plugins/acl/api_spec.lua b/spec/plugins/acl/api_spec.lua index 02d31b3b6ff6..13a4cc2a2ce3 100644 --- a/spec/plugins/acl/api_spec.lua +++ b/spec/plugins/acl/api_spec.lua @@ -41,7 +41,7 @@ describe("ACLs API", function() end) end) - + describe("PUT", function() it("[SUCCESS] should create and update", function() @@ -64,9 +64,9 @@ describe("ACLs API", function() end) end) - + end) - + describe("/consumers/:consumer/acl/:id", function() describe("GET", function() @@ -79,7 +79,7 @@ describe("ACLs API", function() end) end) - + describe("PATCH", function() it("[SUCCESS] should update an ACL association", function() @@ -96,7 +96,7 @@ describe("ACLs API", function() end) end) - + describe("DELETE", function() it("[FAILURE] should return proper errors", function() @@ -113,7 +113,7 @@ describe("ACLs API", function() end) end) - + end) - + end) diff --git a/spec/plugins/jwt/api_spec.lua b/spec/plugins/jwt/api_spec.lua index f8802bf72be6..7115612256e7 100644 --- a/spec/plugins/jwt/api_spec.lua +++ b/spec/plugins/jwt/api_spec.lua @@ -94,9 +94,11 @@ describe("JWT API", function() describe("PATCH", function() - it("[SUCCESS] should not be supported", function() - local _, status = http_client.patch(BASE_URL..jwt_secret.id, {key = "alice"}) - assert.equal(405, status) + it("[SUCCESS] should update a credential", function() + local response, status = http_client.patch(BASE_URL..jwt_secret.id, {key = "alice",secret = "newsecret"}) + assert.equal(200, status) + jwt_secret = json.decode(response) + assert.equal("newsecret", jwt_secret.secret) end) end) diff --git a/spec/plugins/key-auth/daos_spec.lua b/spec/plugins/key-auth/daos_spec.lua index 4c35139b6a26..26bef046c5a7 100644 --- a/spec/plugins/key-auth/daos_spec.lua +++ b/spec/plugins/key-auth/daos_spec.lua @@ -1,5 +1,5 @@ local spec_helper = require "spec.spec_helpers" -local uuid = require "uuid" +local uuid = require "lua_uuid" local env = spec_helper.get_env() local dao_factory = env.dao_factory diff --git a/spec/plugins/logging_spec.lua b/spec/plugins/logging_spec.lua index 07e970fa3e05..4f25eaa5794b 100644 --- a/spec/plugins/logging_spec.lua +++ b/spec/plugins/logging_spec.lua @@ -159,12 +159,20 @@ describe("Logging Plugins", function() ) assert.are.equal(200, status) + local timeout = 10 while not (IO.file_exists(FILE_LOG_PATH)) do -- Wait for the file to be created + os.execute("sleep 1") + timeout = timeout -1 + if timeout == 0 then error("Creating the logfile timed out") end end + local timeout = 10 while not (IO.file_size(FILE_LOG_PATH) > 0) do -- Wait for the log to be appended + os.execute("sleep 1") + timeout = timeout -1 + if timeout == 0 then error("Appending to the logfile timed out") end end local file_log = IO.read_file(FILE_LOG_PATH) @@ -174,5 +182,4 @@ describe("Logging Plugins", function() os.remove(FILE_LOG_PATH) end) - end) diff --git a/spec/plugins/loggly/log_spec.lua b/spec/plugins/loggly/log_spec.lua new file mode 100644 index 000000000000..945a8ef48cfc --- /dev/null +++ b/spec/plugins/loggly/log_spec.lua @@ -0,0 +1,121 @@ +local cjson = require "cjson" +local spec_helper = require "spec.spec_helpers" +local http_client = require "kong.tools.http_client" + +local STUB_GET_URL = spec_helper.STUB_GET_URL + +local UDP_PORT = spec_helper.find_port() + +describe("Logging Plugins", function() + + setup(function() + spec_helper.prepare_db() + spec_helper.insert_fixtures { + api = { + { request_host = "logging.com", upstream_url = "http://mockbin.com" }, + { request_host = "logging1.com", upstream_url = "http://mockbin.com" }, + { request_host = "logging2.com", upstream_url = "http://mockbin.com" }, + { request_host = "logging3.com", upstream_url = "http://mockbin.com" } + }, + plugin = { + { name = "loggly", config = { host = "127.0.0.1", port = UDP_PORT, key = "123456789", log_level = "info", + successful_severity = "warning" }, __api = 1 }, + { name = "loggly", config = { host = "127.0.0.1", port = UDP_PORT, key = "123456789", log_level = "debug", + successful_severity = "info", timeout = 2000 }, __api = 2 }, + { name = "loggly", config = { host = "127.0.0.1", port = UDP_PORT, key = "123456789", log_level = "crit", + successful_severity = "crit", client_errors_severity = "warning" }, __api = 3 }, + { name = "loggly", config = { host = "127.0.0.1", port = UDP_PORT, key = "123456789" }, __api = 4 }, + } + } + + spec_helper.start_kong() + end) + + teardown(function() + spec_helper.stop_kong() + end) + + it("should log to UDP when severity is warning and log level info", function() + local thread = spec_helper.start_udp_server(UDP_PORT) -- Starting the mock TCP server + + local _, status = http_client.get(STUB_GET_URL, nil, { host = "logging.com" }) + assert.are.equal(200, status) + + local ok, res = thread:join() + assert.truthy(ok) + assert.truthy(res) + + local pri = string.sub(res,2,3) + assert.are.equal("12", pri) + + local message = {} + for w in string.gmatch(res,"{.*}") do + table.insert(message, w) + end + local log_message = cjson.decode(message[1]) + assert.are.same("127.0.0.1", log_message.client_ip) + end) + + it("should log to UDP when severity is info and log level debug", function() + local thread = spec_helper.start_udp_server(UDP_PORT) -- Starting the mock TCP server + + local _, status = http_client.get(STUB_GET_URL, nil, { host = "logging1.com" }) + assert.are.equal(200, status) + + local ok, res = thread:join() + assert.truthy(ok) + assert.truthy(res) + + local pri = string.sub(res,2,3) + assert.are.equal("14", pri) + + local message = {} + for w in string.gmatch(res,"{.*}") do + table.insert(message, w) + end + local log_message = cjson.decode(message[1]) + assert.are.same("127.0.0.1", log_message.client_ip) + end) + + it("should log to UDP when severity is critical and log level critical", function() + local thread = spec_helper.start_udp_server(UDP_PORT) -- Starting the mock TCP server + + local _, status = http_client.get(STUB_GET_URL, nil, { host = "logging2.com" }) + assert.are.equal(200, status) + + local ok, res = thread:join() + assert.truthy(ok) + assert.truthy(res) + + local pri = string.sub(res,2,3) + assert.are.equal("10", pri) + + local message = {} + for w in string.gmatch(res,"{.*}") do + table.insert(message, w) + end + local log_message = cjson.decode(message[1]) + assert.are.same("127.0.0.1", log_message.client_ip) + end) + + it("should log to UDP when severity and log level are default values", function() + local thread = spec_helper.start_udp_server(UDP_PORT) -- Starting the mock TCP server + + local _, status = http_client.get(STUB_GET_URL, nil, { host = "logging3.com" }) + assert.are.equal(200, status) + + local ok, res = thread:join() + assert.truthy(ok) + assert.truthy(res) + + local pri = string.sub(res,2,3) + assert.are.equal("14", pri) + + local message = {} + for w in string.gmatch(res,"{.*}") do + table.insert(message, w) + end + local log_message = cjson.decode(message[1]) + assert.are.same("127.0.0.1", log_message.client_ip) + end) +end) diff --git a/spec/plugins/mashape-analytics/alf_serializer_spec.lua b/spec/plugins/mashape-analytics/alf_serializer_spec.lua index e3244b4a43e6..b29339d6e277 100644 --- a/spec/plugins/mashape-analytics/alf_serializer_spec.lua +++ b/spec/plugins/mashape-analytics/alf_serializer_spec.lua @@ -65,7 +65,12 @@ describe("ALF serializer", function() it("should handle timing calculation if multiple upstreams were called", function() local entry = ALFSerializer.serialize_entry(fixtures.MULTIPLE_UPSTREAMS.NGX_STUB) assert.are.sameEntry(fixtures.MULTIPLE_UPSTREAMS.ENTRY, entry) - assert.equal(60468, entry.timings.wait) + assert.equal(236, entry.timings.wait) + end) + + it("should return the last header if two are present for mimeType", function() + local entry = ALFSerializer.serialize_entry(fixtures.MULTIPLE_HEADERS.NGX_STUB) + assert.are.sameEntry(fixtures.MULTIPLE_HEADERS.ENTRY, entry) end) end) diff --git a/spec/plugins/mashape-analytics/fixtures/requests.lua b/spec/plugins/mashape-analytics/fixtures/requests.lua index b861bb394fd9..0a138269d6de 100644 --- a/spec/plugins/mashape-analytics/fixtures/requests.lua +++ b/spec/plugins/mashape-analytics/fixtures/requests.lua @@ -20,17 +20,16 @@ return { request_uri = "/request", request_length = 123, body_bytes_sent = 934, - remote_addr = "127.0.0.1", - upstream_response_time = 0.391 + remote_addr = "127.0.0.1" }, ctx = { - proxy_started_at = 1432844571719, - proxy_ended_at = 143284457211, + KONG_PROXY_LATENCY = 22, + KONG_WAITING_TIME = 236, + KONG_RECEIVE_TIME = 177, analytics = { req_body = "hello=world&hello=earth", res_body = "{\"message\":\"response body\"}", - req_post_args = {["hello"] = {"world", "earth"}}, - response_received = 143284457211 + req_post_args = {["hello"] = {"world", "earth"}} } } }, @@ -81,15 +80,15 @@ return { statusText = "" }, startedDateTime = "2015-05-28T20:22:51Z", - time = 487, + time = 435, timings = { blocked = -1, connect = -1, dns = -1, - receive = 0, - send = 96, + receive = 177, + send = 22, ssl = -1, - wait = 391 + wait = 236 } } }, @@ -112,17 +111,16 @@ return { request_uri = "/request", request_length = 123, body_bytes_sent = 934, - remote_addr = "127.0.0.1", - upstream_response_time = "60.345, 0.123" + remote_addr = "127.0.0.1" }, ctx = { - proxy_started_at = 1432844571719, - proxy_ended_at = 143284457211, + KONG_PROXY_LATENCY = 10, + KONG_WAITING_TIME = 236, + KONG_RECEIVE_TIME = 1, analytics = { req_body = "hello=world&hello=earth", res_body = "{\"message\":\"response body\"}", - req_post_args = {["hello"] = {"world", "earth"}}, - response_received = 143284457211 + req_post_args = {["hello"] = {"world", "earth"}} } } }, @@ -172,15 +170,108 @@ return { statusText = "" }, startedDateTime = "2015-05-28T20:22:51Z", - time = 60564, + time = 247, timings = { blocked = -1, connect = -1, dns = -1, - receive = 0, - send = 96, + receive = 1, + send = 10, ssl = -1, - wait = 60468 + wait = 236 + } + } + }, + ["MULTIPLE_HEADERS"] = { + ["NGX_STUB"] = { + req = { + start_time = function() return 1432844571.623 end, + get_method = function() return "GET" end, + http_version = function() return 1.1 end, + get_headers = function() return {["Accept"] = "/*/", ["Host"] = "mockbin.com", ["Content-Type"] = {"application/json", "application/www-form-urlencoded"}} end, + get_uri_args = function() return {["hello"] = "world", ["foo"] = "bar"} end + }, + resp = { + get_headers = function() return {["Connection"] = "close", ["Content-Type"] = {"application/json", "application/www-form-urlencoded"}, ["Content-Length"] = "934"} end + }, + status = 200, + var = { + scheme = "http", + host = "mockbin.com", + request_uri = "/request", + request_length = 123, + body_bytes_sent = 934, + remote_addr = "127.0.0.1" + }, + ctx = { + KONG_PROXY_LATENCY = 10, + KONG_WAITING_TIME = 236, + KONG_RECEIVE_TIME = 1, + analytics = { + req_body = "hello=world&hello=earth", + res_body = "{\"message\":\"response body\"}", + req_post_args = {["hello"] = {"world", "earth"}} + } + } + }, + ["ENTRY"] = { + cache = {}, + request = { + bodySize = 23, + cookies = {EMPTY_ARRAY_PLACEHOLDER}, + headers = { + {name = "Accept", value = "/*/"}, + {name = "Host", value = "mockbin.com"}, + {name = "Content-Type", value = "application/json"}, + {name = "Content-Type", value = "application/www-form-urlencoded"} + }, + headersSize = 95, + httpVersion = "HTTP/1.1", + method = "GET", + postData = { + mimeType = "application/www-form-urlencoded", + params = { + {name = "hello", value = "world"}, + {name = "hello", value = "earth"} + }, + text = "base64_hello=world&hello=earth" + }, + queryString = { + {name = "foo", value = "bar"}, + {name = "hello", value = "world"} + }, + url = "http://mockbin.com/request" + }, + response = { + bodySize = 934, + content = { + mimeType = "application/www-form-urlencoded", + size = 934, + text = "base64_{\"message\":\"response body\"}" + }, + cookies = {EMPTY_ARRAY_PLACEHOLDER}, + headers = { + {name = "Content-Length", value = "934"}, + {name = "Content-Type", value = "application/json"}, + {name = "Content-Type", value = "application/www-form-urlencoded"}, + {name = "Connection", value = "close"} + }, + headersSize = 103, + httpVersion = "", + redirectURL = "", + status = 200, + statusText = "" + }, + startedDateTime = "2015-05-28T20:22:51Z", + time = 247, + timings = { + blocked = -1, + connect = -1, + dns = -1, + receive = 1, + send = 10, + ssl = -1, + wait = 236 } } } diff --git a/spec/plugins/oauth2/access_spec.lua b/spec/plugins/oauth2/access_spec.lua index 052de2c3bfb9..640d6d5aa3de 100644 --- a/spec/plugins/oauth2/access_spec.lua +++ b/spec/plugins/oauth2/access_spec.lua @@ -44,17 +44,19 @@ describe("Authentication Plugin", function() { name = "tests-oauth2-with-path", request_host = "mockbin-path.com", upstream_url = "http://mockbin.com", request_path = "/somepath/" }, { name = "tests-oauth2-with-hide-credentials", request_host = "oauth2_3.com", upstream_url = "http://mockbin.com" }, { name = "tests-oauth2-client-credentials", request_host = "oauth2_4.com", upstream_url = "http://mockbin.com" }, - { name = "tests-oauth2-password-grant", request_host = "oauth2_5.com", upstream_url = "http://mockbin.com" } + { name = "tests-oauth2-password-grant", request_host = "oauth2_5.com", upstream_url = "http://mockbin.com" }, + { name = "tests-oauth2-accept_http_if_already_terminated", request_host = "oauth2_6.com", upstream_url = "http://mockbin.com" }, }, consumer = { { username = "auth_tests_consumer" } }, plugin = { - { name = "oauth2", config = { scopes = { "email", "profile" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_implicit_grant = true }, __api = 1 }, + { name = "oauth2", config = { scopes = { "email", "profile", "user.email" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_implicit_grant = true }, __api = 1 }, { name = "oauth2", config = { scopes = { "email", "profile" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_implicit_grant = true }, __api = 2 }, { name = "oauth2", config = { scopes = { "email", "profile" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_implicit_grant = true, hide_credentials = true }, __api = 3 }, { name = "oauth2", config = { scopes = { "email", "profile" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_client_credentials = true, enable_authorization_code = false }, __api = 4 }, - { name = "oauth2", config = { scopes = { "email", "profile" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_password_grant = true, enable_authorization_code = false }, __api = 5 } + { name = "oauth2", config = { scopes = { "email", "profile" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_password_grant = true, enable_authorization_code = false }, __api = 5 }, + { name = "oauth2", config = { scopes = { "email", "profile", "user.email" }, mandatory_scope = true, provision_key = "provision123", token_expiration = 5, enable_implicit_grant = true, accept_http_if_already_terminated = true }, __api = 6 }, }, oauth2_credential = { { client_id = "clientid123", client_secret = "secret123", redirect_uri = "http://google.com/kong", name="testapp", __consumer = 1 } @@ -163,14 +165,23 @@ describe("Authentication Plugin", function() assert.are.equal("You must use HTTPS", body.error_description) end) - it("should return success when under HTTP and X-Forwarded-Proto header is set to HTTPS", function() - local response, status = http_client.post(PROXY_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "email", response_type = "code" }, {host = "oauth2.com", ["X-Forwarded-Proto"] = "https"}) + it("should work when not under HTTPS but accept_http_if_already_terminated is true", function() + local response, status = http_client.post(PROXY_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "email", response_type = "code" }, {host = "oauth2_6.com", ["X-Forwarded-Proto"] = "https"}) local body = cjson.decode(response) assert.are.equal(200, status) assert.are.equal(1, utils.table_size(body)) assert.truthy(rex.match(body.redirect_uri, "^http://google\\.com/kong\\?code=[\\w]{32,32}$")) end) + it("should fail when not under HTTPS and accept_http_if_already_terminated is false", function() + local response, status = http_client.post(PROXY_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "email", response_type = "code" }, {host = "oauth2.com", ["X-Forwarded-Proto"] = "https"}) + local body = cjson.decode(response) + assert.are.equal(400, status) + assert.are.equal(2, utils.table_size(body)) + assert.are.equal("access_denied", body.error) + assert.are.equal("You must use HTTPS", body.error_description) + end) + it("should return success", function() local response, status = http_client.post(PROXY_SSL_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "email", response_type = "code" }, {host = "oauth2.com"}) local body = cjson.decode(response) @@ -236,9 +247,30 @@ describe("Authentication Plugin", function() assert.are.equal("email", data[1].scope) end) + it("should return success with a dotted scope and store authenticated user properties", function() + local response, status = http_client.post(PROXY_SSL_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "user.email", response_type = "code", state = "hello", authenticated_userid = "userid123" }, {host = "oauth2.com"}) + local body = cjson.decode(response) + assert.are.equal(200, status) + assert.are.equal(1, utils.table_size(body)) + assert.truthy(rex.match(body.redirect_uri, "^http://google\\.com/kong\\?code=[\\w]{32,32}&state=hello$")) + + local matches = rex.gmatch(body.redirect_uri, "^http://google\\.com/kong\\?code=([\\w]{32,32})&state=hello$") + local code + for line in matches do + code = line + end + local data = dao_factory.oauth2_authorization_codes:find_by_keys({code = code}) + assert.are.equal(1, #data) + assert.are.equal(code, data[1].code) + + assert.are.equal("userid123", data[1].authenticated_userid) + assert.are.equal("user.email", data[1].scope) + end) + end) describe("Implicit Grant", function() + it("should return success", function() local response, status, headers = http_client.post(PROXY_SSL_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "email", response_type = "token" }, {host = "oauth2.com"}) local body = cjson.decode(response) @@ -282,6 +314,26 @@ describe("Authentication Plugin", function() assert.falsy(data[1].refresh_token) end) + it("should return set the right upstream headers", function() + local response = http_client.post(PROXY_SSL_URL.."/oauth2/authorize", { provision_key = "provision123", authenticated_userid = "id123", client_id = "clientid123", scope = "email profile", response_type = "token", authenticated_userid = "userid123" }, {host = "oauth2.com"}) + local body = cjson.decode(response) + + local matches = rex.gmatch(body.redirect_uri, "^http://google\\.com/kong\\?token_type=bearer&access_token=([\\w]{32,32})$") + local access_token + for line in matches do + access_token = line + end + + local response, status = http_client.get(PROXY_SSL_URL.."/request", { access_token = access_token }, {host = "oauth2.com"}) + assert.are.equal(200, status) + + local body = cjson.decode(response) + assert.truthy(body.headers["x-consumer-id"]) + assert.are.equal("auth_tests_consumer", body.headers["x-consumer-username"]) + assert.are.equal("email profile", body.headers["x-authenticated-scope"]) + assert.are.equal("userid123", body.headers["x-authenticated-userid"]) + end) + end) describe("Client Credentials", function() @@ -373,6 +425,20 @@ describe("Authentication Plugin", function() assert.are.equal("Invalid client_secret", body.error_description) end) + it("should return set the right upstream headers", function() + local response, status = http_client.post(PROXY_SSL_URL.."/oauth2/token", { client_id = "clientid123", client_secret="secret123", scope = "email", grant_type = "client_credentials", authenticated_userid = "hello", provision_key = "provision123" }, {host = "oauth2_4.com"}) + assert.are.equal(200, status) + + local response, status = http_client.get(PROXY_SSL_URL.."/request", { access_token = cjson.decode(response).access_token }, {host = "oauth2_4.com"}) + assert.are.equal(200, status) + + local body = cjson.decode(response) + assert.truthy(body.headers["x-consumer-id"]) + assert.are.equal("auth_tests_consumer", body.headers["x-consumer-username"]) + assert.are.equal("email", body.headers["x-authenticated-scope"]) + assert.are.equal("hello", body.headers["x-authenticated-userid"]) + end) + end) describe("Password Grant", function() @@ -453,6 +519,20 @@ describe("Authentication Plugin", function() assert.are.equal("Invalid client_secret", body.error_description) end) + it("should return set the right upstream headers", function() + local response, status = http_client.post(PROXY_SSL_URL.."/oauth2/token", { provision_key = "provision123", authenticated_userid = "id123", scope = "email", grant_type = "password" }, {host = "oauth2_5.com", authorization = "Basic Y2xpZW50aWQxMjM6c2VjcmV0MTIz"}) + assert.are.equal(200, status) + + local response, status = http_client.get(PROXY_SSL_URL.."/request", { access_token = cjson.decode(response).access_token }, {host = "oauth2_5.com"}) + assert.are.equal(200, status) + + local body = cjson.decode(response) + assert.truthy(body.headers["x-consumer-id"]) + assert.are.equal("auth_tests_consumer", body.headers["x-consumer-username"]) + assert.are.equal("email", body.headers["x-authenticated-scope"]) + assert.are.equal("id123", body.headers["x-authenticated-userid"]) + end) + end) end) @@ -550,24 +630,24 @@ describe("Authentication Plugin", function() assert.are.equal(5, body.expires_in) assert.are.equal("wot", body.state) end) - end) - describe("Making a request", function() + it("should return set the right upstream headers", function() + local code = provision_code() + local response, status = http_client.post(PROXY_SSL_URL.."/oauth2/token", { code = code, client_id = "clientid123", client_secret = "secret123", grant_type = "authorization_code" }, {host = "oauth2.com"}) + assert.are.equal(200, status) - it("should return an error when nothing is being sent", function() - local response, status = http_client.post(STUB_GET_URL, { }, {host = "oauth2.com"}) - local body = cjson.decode(response) - assert.are.equal(403, status) - assert.are.equal("Invalid authentication credentials", body.message) - end) + local response, status = http_client.get(PROXY_SSL_URL.."/request", { access_token = cjson.decode(response).access_token }, {host = "oauth2.com"}) + assert.are.equal(200, status) - it("should return an error when a wrong access token is being sent", function() - local response, status = http_client.get(STUB_GET_URL, { access_token = "hello" }, {host = "oauth2.com"}) local body = cjson.decode(response) - assert.are.equal(403, status) - assert.are.equal("Invalid authentication credentials", body.message) + assert.truthy(body.headers["x-consumer-id"]) + assert.are.equal("auth_tests_consumer", body.headers["x-consumer-username"]) + assert.are.equal("email", body.headers["x-authenticated-scope"]) + assert.are.equal("userid123", body.headers["x-authenticated-userid"]) end) + end) + describe("Making a request", function() it("should work when a correct access_token is being sent in the querystring", function() local token = provision_token() local _, status = http_client.post(STUB_GET_URL, { access_token = token.access_token }, {host = "oauth2.com"}) @@ -599,15 +679,49 @@ describe("Authentication Plugin", function() assert.are.equal("userid123", body.headers["x-authenticated-userid"]) assert.are.equal("email", body.headers["x-authenticated-scope"]) end) + end) - it("should not work when a correct access_token is being sent in an authorization header (bearer)", function() - local token = provision_token() - local response, status = http_client.post(STUB_POST_URL, { }, {host = "oauth2.com", authorization = "bearer "..token.access_token.."hello"}) + describe("Authentication challenge", function() + it("should return 401 Unauthorized without error if it lacks any authentication information", function() + local response, status, headers = http_client.post(STUB_GET_URL, { }, {host = "oauth2.com"}) + local body = cjson.decode(response) + assert.are.equal(401, status) + assert.are.equal('Bearer realm="service"', headers['www-authenticate']) + assert.are.equal(0, utils.table_size(body)) + end) + + it("should return 401 Unauthorized when an invalid access token is being sent via url parameter", function() + local response, status, headers = http_client.get(STUB_GET_URL, { access_token = "invalid" }, {host = "oauth2.com"}) + local body = cjson.decode(response) + assert.are.equal(401, status) + assert.are.equal('Bearer realm="service" error="invalid_token" error_description="The access token is invalid"', headers['www-authenticate']) + assert.are.equal("invalid_token", body.error) + assert.are.equal("The access token is invalid", body.error_description) + end) + + it("should return 401 Unauthorized when an invalid access token is being sent via the Authorization header", function() + local response, status, headers = http_client.post(STUB_POST_URL, { }, {host = "oauth2.com", authorization = "bearer invalid"}) local body = cjson.decode(response) - assert.are.equal(403, status) - assert.are.equal("Invalid authentication credentials", body.message) + assert.are.equal(401, status) + assert.are.equal('Bearer realm="service" error="invalid_token" error_description="The access token is invalid"', headers['www-authenticate']) + assert.are.equal("invalid_token", body.error) + assert.are.equal("The access token is invalid", body.error_description) end) + it("should return 401 Unauthorized when token has expired", function() + local token = provision_token() + + -- Token expires in (5 seconds) + os.execute("sleep "..tonumber(6)) + + local response, status, headers = http_client.post(STUB_POST_URL, { }, {host = "oauth2.com", authorization = "bearer "..token.access_token}) + local body = cjson.decode(response) + assert.are.equal(401, status) + assert.are.equal(2, utils.table_size(body)) + assert.are.equal('Bearer realm="service" error="invalid_token" error_description="The access token expired"', headers['www-authenticate']) + assert.are.equal("invalid_token", body.error) + assert.are.equal("The access token expired", body.error_description) + end) end) describe("Refresh Token", function() @@ -646,10 +760,8 @@ describe("Authentication Plugin", function() local response, status = http_client.post(STUB_POST_URL, { }, {host = "oauth2.com", authorization = "bearer "..token.access_token}) local body = cjson.decode(response) - assert.are.equal(400, status) - assert.are.equal(2, utils.table_size(body)) - assert.are.equal("invalid_request", body.error) - assert.are.equal("access_token expired", body.error_description) + assert.are.equal(401, status) + assert.are.equal("The access token expired", body.error_description) -- Refreshing the token local response, status = http_client.post(PROXY_SSL_URL.."/oauth2/token", { refresh_token = token.refresh_token, client_id = "clientid123", client_secret = "secret123", grant_type = "refresh_token" }, {host = "oauth2.com"}) @@ -718,6 +830,7 @@ describe("Authentication Plugin", function() assert.are.equal(200, status) assert.falsy(body.headers.authorization) end) + end) end) diff --git a/spec/plugins/oauth2/api_spec.lua b/spec/plugins/oauth2/api_spec.lua index 707cd9be4c25..b49f4e2e0a28 100644 --- a/spec/plugins/oauth2/api_spec.lua +++ b/spec/plugins/oauth2/api_spec.lua @@ -18,7 +18,9 @@ describe("OAuth 2 Credentials API", function() setup(function() local fixtures = spec_helper.insert_fixtures { - consumer = {{ username = "bob" }} + consumer = { + {username = "bob"} + } } consumer = fixtures.consumer[1] BASE_URL = spec_helper.API_URL.."/consumers/bob/oauth2/" @@ -27,7 +29,7 @@ describe("OAuth 2 Credentials API", function() describe("POST", function() it("[SUCCESS] should create a oauth2 credential", function() - local response, status = http_client.post(BASE_URL, { name = "Test APP", redirect_uri = "http://google.com/" }) + local response, status = http_client.post(BASE_URL, {name = "Test APP", redirect_uri = "http://google.com/"}) assert.equal(201, status) credential = json.decode(response) assert.equal(consumer.id, credential.consumer_id) @@ -43,11 +45,11 @@ describe("OAuth 2 Credentials API", function() describe("PUT", function() setup(function() - spec_helper.get_env().dao_factory.keyauth_credentials:delete({id=credential.id}) + spec_helper.get_env().dao_factory.keyauth_credentials:delete({id = credential.id}) end) it("[SUCCESS] should create and update", function() - local response, status = http_client.put(BASE_URL, { redirect_uri = "http://google.com/", name = "Test APP" }) + local response, status = http_client.put(BASE_URL, {redirect_uri = "http://google.com/", name = "Test APP"}) assert.equal(201, status) credential = json.decode(response) assert.equal(consumer.id, credential.consumer_id) @@ -89,14 +91,14 @@ describe("OAuth 2 Credentials API", function() describe("PATCH", function() it("[SUCCESS] should update a credential", function() - local response, status = http_client.patch(BASE_URL..credential.id, { redirect_uri = "http://getkong.org/" }) + local response, status = http_client.patch(BASE_URL..credential.id, {redirect_uri = "http://getkong.org/"}) assert.equal(200, status) credential = json.decode(response) assert.equal("http://getkong.org/", credential.redirect_uri) end) it("[FAILURE] should return proper errors", function() - local response, status = http_client.patch(BASE_URL..credential.id, { redirect_uri = "" }) + local response, status = http_client.patch(BASE_URL..credential.id, {redirect_uri = ""}) assert.equal(400, status) assert.equal('{"redirect_uri":"redirect_uri is not a url"}\n', response) end) diff --git a/spec/plugins/rate-limiting/daos_spec.lua b/spec/plugins/rate-limiting/daos_spec.lua index 239e45c2bfaf..8344ea613a1a 100644 --- a/spec/plugins/rate-limiting/daos_spec.lua +++ b/spec/plugins/rate-limiting/daos_spec.lua @@ -1,6 +1,6 @@ local spec_helper = require "spec.spec_helpers" local timestamp = require "kong.tools.timestamp" -local uuid = require "uuid" +local uuid = require "lua_uuid" local env = spec_helper.get_env() local dao_factory = env.dao_factory @@ -21,7 +21,7 @@ describe("Rate Limiting Metrics", function() for period, period_date in pairs(periods) do local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) assert.falsy(err) - assert.are.same(nil, metric) + assert.same(nil, metric) end end) @@ -30,15 +30,14 @@ describe("Rate Limiting Metrics", function() local periods = timestamp.get_timestamps(current_timestamp) -- First increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) - assert.falsy(err) + local ok = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) assert.True(ok) -- First select for period, period_date in pairs(periods) do local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) assert.falsy(err) - assert.are.same({ + assert.same({ api_id = api_id, identifier = identifier, period = period, @@ -48,15 +47,14 @@ describe("Rate Limiting Metrics", function() end -- Second increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) - assert.falsy(err) + local ok = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) assert.True(ok) -- Second select for period, period_date in pairs(periods) do local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) assert.falsy(err) - assert.are.same({ + assert.same({ api_id = api_id, identifier = identifier, period = period, @@ -70,8 +68,7 @@ describe("Rate Limiting Metrics", function() periods = timestamp.get_timestamps(current_timestamp) -- Third increment - local ok, err = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) - assert.falsy(err) + local ok = ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1) assert.True(ok) -- Third select with 1 second delay @@ -85,7 +82,7 @@ describe("Rate Limiting Metrics", function() local metric, err = ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period) assert.falsy(err) - assert.are.same({ + assert.same({ api_id = api_id, identifier = identifier, period = period, diff --git a/spec/plugins/request-transformer/access_spec.lua b/spec/plugins/request-transformer/access_spec.lua index e0eee70368de..7cbb834bc255 100644 --- a/spec/plugins/request-transformer/access_spec.lua +++ b/spec/plugins/request-transformer/access_spec.lua @@ -6,27 +6,25 @@ local STUB_GET_URL = spec_helper.STUB_GET_URL local STUB_POST_URL = spec_helper.STUB_POST_URL describe("Request Transformer", function() - setup(function() spec_helper.prepare_db() spec_helper.insert_fixtures { api = { - { name = "tests-request-transformer-1", request_host = "test1.com", upstream_url = "http://mockbin.com" }, - { name = "tests-request-transformer-2", request_host = "test2.com", upstream_url = "http://httpbin.org" } + {name = "tests-request-transformer-1", request_host = "test1.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-2", request_host = "test2.com", upstream_url = "http://httpbin.org"}, + {name = "tests-request-transformer-3", request_host = "test3.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-4", request_host = "test4.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-5", request_host = "test5.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-6", request_host = "test6.com", upstream_url = "http://mockbin.com"}, }, plugin = { { name = "request-transformer", config = { add = { - headers = {"x-added:true", "x-added2:true" }, - querystring = {"newparam:value"}, - form = {"newformparam:newvalue"} - }, - remove = { - headers = { "x-to-remove" }, - querystring = { "toremovequery" }, - form = { "toremoveform" } + headers = {"h1:v1", "h2:v2"}, + querystring = {"q1:v1"}, + form = {"p1:v1"} } }, __api = 1 @@ -35,110 +33,335 @@ describe("Request Transformer", function() name = "request-transformer", config = { add = { - headers = { "host:mark" } + headers = {"host:mark"} } }, __api = 2 + }, + { + name = "request-transformer", + config = { + add = { + headers = {"x-added:a1", "x-added2:b1", "x-added3:c2"}, + querystring = {"query-added:newvalue", "p1:a1"}, + form = {"newformparam:newvalue"} + }, + remove = { + headers = {"x-to-remove"}, + querystring = {"toremovequery"} + }, + append = { + headers = {"x-added:a2", "x-added:a3"}, + querystring = {"p1:a2", "p2:b1"} + }, + replace = { + headers = {"x-to-replace:false"}, + querystring = {"toreplacequery:no"} + } + }, + __api = 3 + }, + { + name = "request-transformer", + config = { + remove = { + headers = {"x-to-remove"}, + querystring = {"q1"}, + form = {"toremoveform"} + } + }, + __api = 4 + }, + { + name = "request-transformer", + config = { + replace = { + headers = {"h1:v1"}, + querystring = {"q1:v1"}, + form = {"p1:v1"} + } + }, + __api = 5 + }, + { + name = "request-transformer", + config = { + append = { + headers = {"h1:v1", "h1:v2", "h2:v1",}, + querystring = {"q1:v1", "q1:v2", "q2:v1"} + } + }, + __api = 6 } }, } - spec_helper.start_kong() end) teardown(function() spec_helper.stop_kong() end) - - describe("Test adding parameters", function() - + + describe("Test remove", function() + it("should remove specified header", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test4.com", ["x-to-remove"] = "true", ["x-another-header"] = "true"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.headers["x-to-remove"]) + assert.equal("true", body.headers["x-another-header"]) + end) + it("should remove parameters on POST", function() + local response, status = http_client.post(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test4.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["toremoveform"]) + assert.equal("yes", body.postData.params["nottoremove"]) + end) + it("should remove parameters on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test4.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["toremoveform"]) + assert.equal("yes", body.postData.params["nottoremove"]) + end) + it("should remove queryString on GET if it exist", function() + local response, status = http_client.post(STUB_POST_URL.."/?q1=v1&q2=v2", { hello = "world"}, {host = "test4.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.queryString["q1"]) + assert.equal("v2", body.queryString["q2"]) + end) + end) + + describe("Test replace", function() + it("should replace specified header if it exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test5.com", ["h1"] = "V", ["h2"] = "v2"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) + end) + it("should not add as new header if header does not exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test5.com", ["h2"] = "v2"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) + end) + it("should replace specified parameters on POST", function() + local response, status = http_client.post(STUB_POST_URL, {["p1"] = "v", ["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should not add as new parameter if parameter does not exist on POST", function() + local response, status = http_client.post(STUB_POST_URL, {["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should replace specified parameters on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {["p1"] = "v", ["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should not add as new parameter if parameter does not exist on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should replace queryString on POST if it exist", function() + local response, status = http_client.post(STUB_POST_URL.."/?q1=v&q2=v2", { hello = "world"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.queryString["q1"]) + assert.equal("v2", body.queryString["q2"]) + end) + it("should not add new queryString on POST if it does not exist", function() + local response, status = http_client.post(STUB_POST_URL.."/?q2=v2", { hello = "world"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.queryString["q1"]) + assert.equal("v2", body.queryString["q2"]) + end) + end) + + describe("Test add", function() it("should add new headers", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("true", body.headers["x-added"]) - assert.are.equal("true", body.headers["x-added2"]) + assert.equal(200, status) + assert.equal("v1", body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) end) - - it("should add new parameters on POST", function() - local response, status = http_client.post(STUB_POST_URL, {}, {host = "test1.com"}) + it("should not change or append value if header already exists", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com", h1 = "v3"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("v3", body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) end) - - it("should add new parameters on POST when existing params exist", function() - local response, status = http_client.post(STUB_POST_URL, { hello = "world" }, {host = "test1.com"}) + it("should add new parameter on POST", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("world", body.postData.params["hello"]) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("world", body.postData.params["hello"]) + assert.equal("v1", body.postData.params["p1"]) end) - - it("should add new parameters on multipart POST", function() + it("should not change or append value to parameter on POST when parameter exists", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test1.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("world", body.postData.params["hello"]) + assert.equal("v1", body.postData.params["p1"]) + end) + it("should add new parameter on multipart POST", function() local response, status = http_client.post_multipart(STUB_POST_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("v1", body.postData.params["p1"]) end) - - it("should add new parameters on multipart POST when existing params exist", function() - local response, status = http_client.post_multipart(STUB_POST_URL, { hello = "world" }, {host = "test1.com"}) + it("should not change or append value to parameter on multipart POST when parameter exists", function() + local response, status = http_client.post_multipart(STUB_POST_URL, { hello = "world"}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("world", body.postData.params["hello"]) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("world", body.postData.params["hello"]) + assert.equal("v1", body.postData.params["p1"]) end) - - it("should add new parameters on GET", function() + it("should add new querystring on GET", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("value", body.queryString["newparam"]) + assert.equal(200, status) + assert.equal("v1", body.queryString["q1"]) + end) + it("should not change or append value to querystring on GET if querystring exists", function() + local response, status = http_client.get(STUB_GET_URL, {q1 = "v2"}, {host = "test1.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v2", body.queryString["q1"]) end) - - it("should change the host header", function() + it("should not change the host header", function() local response, status = http_client.get(spec_helper.PROXY_URL.."/get", {}, {host = "test2.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("mark", body.headers["Host"]) + assert.equal(200, status) + assert.equal("httpbin.org", body.headers["Host"]) + end) + end) + + describe("Test append ", function() + it("should add a new header if header does not exists", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.headers["h2"]) + end) + it("should append values to existing headers", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1, v2", body.headers["h1"]) + end) + it("should add new querystring if querystring does not exists", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.queryString["q2"]) + end) + it("should append values to existing querystring", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.same({"v1", "v2"}, body.queryString["q1"]) end) - end) - describe("Test removing parameters", function() - + describe("Test for remove, replace, add and append ", function() it("should remove a header", function() - local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com", ["x-to-remove"] = "true"}) + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com", ["x-to-remove"] = "true"}) local body = cjson.decode(response) - assert.are.equal(200, status) + assert.equal(200, status) assert.falsy(body.headers["x-to-remove"]) end) - - it("should remove parameters on POST", function() - local response, status = http_client.post(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) + it("should replace value of header, if header exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com", ["x-to-replace"] = "true"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.falsy(body.postData.params["toremoveform"]) - assert.are.same("yes", body.postData.params["nottoremove"]) + assert.equal(200, status) + assert.equal("false", body.headers["x-to-replace"]) end) - - it("should remove parameters on multipart POST", function() - local response, status = http_client.post_multipart(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) + it("should not add new header if to be replaced header does not exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.falsy(body.postData.params["toremoveform"]) - assert.are.same("yes", body.postData.params["nottoremove"]) + assert.equal(200, status) + assert.falsy(body.headers["x-to-replace"]) + end) + it("should add new header if missing", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("b1", body.headers["x-added2"]) + end) + it("should not add new header if it already exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com", ["x-added3"] = "c1"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("c1", body.headers["x-added3"]) + end) + it("should append values to existing headers", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("a1, a2, a3", body.headers["x-added"]) + end) + it("should add new parameters on POST when query string key missing", function() + local response, status = http_client.post(STUB_POST_URL, {hello = "world"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("b1", body.queryString["p2"]) end) - it("should remove parameters on GET", function() - local response, status = http_client.get(STUB_GET_URL, {["toremovequery"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) + local response, status = http_client.get(STUB_GET_URL, {["toremovequery"] = "yes", ["nottoremove"] = "yes"}, {host = "test3.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) + assert.equal(200, status) assert.falsy(body.queryString["toremovequery"]) - assert.are.equal("yes", body.queryString["nottoremove"]) + assert.equal("yes", body.queryString["nottoremove"]) + end) + it("should replace parameters on GET", function() + local response, status = http_client.get(STUB_GET_URL, {["toreplacequery"] = "yes"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("no", body.queryString["toreplacequery"]) + end) + it("should not add new parameter if to be replaced parameters does not exist on GET", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.queryString["toreplacequery"]) + end) + it("should add parameters on GET if it does not exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("newvalue", body.queryString["query-added"]) + end) + it("should not add new parameter if to be added parameters already exist on GET", function() + local response, status = http_client.get(STUB_GET_URL, {["query-added"] = "oldvalue"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("oldvalue", body.queryString["query-added"]) + end) + it("should append parameters on GET", function() + local response, status = http_client.post(STUB_POST_URL.."/?q1=20", { hello = "world"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("a1", body.queryString["p1"][1]) + assert.equal("a2", body.queryString["p1"][2]) + assert.equal("20", body.queryString["q1"]) end) - end) - end) diff --git a/spec/plugins/response-ratelimiting/daos_spec.lua b/spec/plugins/response-ratelimiting/daos_spec.lua index 6c42b22ac18d..2ee33c4605d2 100644 --- a/spec/plugins/response-ratelimiting/daos_spec.lua +++ b/spec/plugins/response-ratelimiting/daos_spec.lua @@ -1,6 +1,6 @@ local spec_helper = require "spec.spec_helpers" local timestamp = require "kong.tools.timestamp" -local uuid = require "uuid" +local uuid = require "lua_uuid" local env = spec_helper.get_env() local dao_factory = env.dao_factory @@ -30,15 +30,14 @@ describe("Rate Limiting Metrics", function() local periods = timestamp.get_timestamps(current_timestamp) -- First increment - local ok, err = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") - assert.falsy(err) + local ok = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") assert.True(ok) -- First select for period, period_date in pairs(periods) do local metric, err = response_ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period, "video") assert.falsy(err) - assert.are.same({ + assert.same({ api_id = api_id, identifier = identifier, period = "video_"..period, @@ -48,15 +47,14 @@ describe("Rate Limiting Metrics", function() end -- Second increment - local ok, err = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") - assert.falsy(err) + local ok = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") assert.True(ok) -- Second select for period, period_date in pairs(periods) do local metric, err = response_ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period, "video") assert.falsy(err) - assert.are.same({ + assert.same({ api_id = api_id, identifier = identifier, period = "video_"..period, @@ -70,8 +68,7 @@ describe("Rate Limiting Metrics", function() periods = timestamp.get_timestamps(current_timestamp) -- Third increment - local ok, err = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") - assert.falsy(err) + local ok = response_ratelimiting_metrics:increment(api_id, identifier, current_timestamp, 1, "video") assert.True(ok) -- Third select with 1 second delay @@ -85,7 +82,7 @@ describe("Rate Limiting Metrics", function() local metric, err = response_ratelimiting_metrics:find_one(api_id, identifier, current_timestamp, period, "video") assert.falsy(err) - assert.are.same({ + assert.same({ api_id = api_id, identifier = identifier, period = "video_"..period, @@ -103,4 +100,4 @@ describe("Rate Limiting Metrics", function() assert.has_error(response_ratelimiting_metrics.find_by_keys, "ratelimiting_metrics:find_by_keys() not supported") end) -end) -- describe rate limiting metrics \ No newline at end of file +end) -- describe rate limiting metrics diff --git a/spec/plugins/ssl/access_spec.lua b/spec/plugins/ssl/access_spec.lua index 6c96e37055fd..afec839ff847 100644 --- a/spec/plugins/ssl/access_spec.lua +++ b/spec/plugins/ssl/access_spec.lua @@ -18,11 +18,13 @@ describe("SSL Plugin", function() api = { { name = "ssl-test", request_host = "ssl1.com", upstream_url = "http://mockbin.com" }, { name = "ssl-test2", request_host = "ssl2.com", upstream_url = "http://mockbin.com" }, - { name = "ssl-test3", request_host = "ssl3.com", upstream_url = "http://mockbin.com" } + { name = "ssl-test3", request_host = "ssl3.com", upstream_url = "http://mockbin.com" }, + { name = "ssl-test4", request_host = "ssl4.com", upstream_url = "http://mockbin.com" }, }, plugin = { { name = "ssl", config = { cert = ssl_fixtures.cert, key = ssl_fixtures.key }, __api = 1 }, - { name = "ssl", config = { cert = ssl_fixtures.cert, key = ssl_fixtures.key, only_https = true }, __api = 2 } + { name = "ssl", config = { cert = ssl_fixtures.cert, key = ssl_fixtures.key, only_https = true }, __api = 2 }, + { name = "ssl", config = { cert = ssl_fixtures.cert, key = ssl_fixtures.key, only_https = true, accept_http_if_already_terminated = true }, __api = 4 } } } @@ -86,6 +88,21 @@ describe("SSL Plugin", function() assert.are.equal(200, status) end) + it("should block request with https in x-forwarded-proto but no accept_if_already_terminated", function() + local _, status = http_client.get(STUB_GET_URL, nil, { host = "ssl2.com", ["x-forwarded-proto"] = "https" }) + assert.are.equal(426, status) + end) + + it("should not block request with https", function() + local _, status = http_client.get(STUB_GET_URL, nil, { host = "ssl4.com", ["x-forwarded-proto"] = "https" }) + assert.are.equal(200, status) + end) + + it("should not block request with https in x-forwarded-proto but accept_if_already_terminated", function() + local _, status = http_client.get(STUB_GET_URL, nil, { host = "ssl4.com", ["x-forwarded-proto"] = "https" }) + assert.are.equal(200, status) + end) + end) describe("should work with curl", function() diff --git a/spec/plugins/syslog/log_spec.lua b/spec/plugins/syslog/log_spec.lua new file mode 100644 index 000000000000..9a5db3f1f0d0 --- /dev/null +++ b/spec/plugins/syslog/log_spec.lua @@ -0,0 +1,85 @@ +local IO = require "kong.tools.io" +local utils = require "kong.tools.utils" +local cjson = require "cjson" +local spec_helper = require "spec.spec_helpers" +local http_client = require "kong.tools.http_client" + +local STUB_GET_URL = spec_helper.STUB_GET_URL + +describe("Syslog", function() + + setup(function() + spec_helper.prepare_db() + spec_helper.insert_fixtures { + api = { + { request_host = "logging.com", upstream_url = "http://mockbin.com" }, + { request_host = "logging2.com", upstream_url = "http://mockbin.com" }, + { request_host = "logging3.com", upstream_url = "http://mockbin.com" } + }, + plugin = { + { name = "syslog", config = { log_level = "info", successful_severity = "warning", + client_errors_severity = "warning", + server_errors_severity = "warning" }, __api = 1 }, + { name = "syslog", config = { log_level = "err", successful_severity = "warning", + client_errors_severity = "warning", + server_errors_severity = "warning" }, __api = 2 }, + { name = "syslog", config = { log_level = "warning", successful_severity = "warning", + client_errors_severity = "warning", + server_errors_severity = "warning" }, __api = 3 } + } + } + + spec_helper.start_kong() + end) + + teardown(function() + spec_helper.stop_kong() + end) + + local function do_test(host, expecting_same) + local uuid = utils.random_string() + + -- Making the request + local _, status = http_client.get(STUB_GET_URL, nil, + { host = host, sys_log_uuid = uuid } + ) + assert.are.equal(200, status) + local platform, code = IO.os_execute("/bin/uname") + if code ~= 0 then + platform, code = IO.os_execute("/usr/bin/uname") + end + if code == 0 and platform == "Darwin" then + local output, code = IO.os_execute("syslog -k Sender kong | tail -1") + assert.are.equal(0, code) + local message = {} + for w in string.gmatch(output,"{.*}") do + table.insert(message, w) + end + local log_message = cjson.decode(message[1]) + if expecting_same then + assert.are.same(uuid, log_message.request.headers.sys_log_uuid) + else + assert.are_not.same(uuid, log_message.request.headers.sys_log_uuid) + end + else + if expecting_same then + local output, code = IO.os_execute("find /var/log -type f -mmin -5 2>/dev/null | xargs grep -l "..uuid) + assert.are.equal(0, code) + assert.truthy(#output > 0) + end + end + end + + it("should log to syslog if log_level is lower", function() + do_test("logging.com", true) + end) + + it("should not log to syslog if the log_level is higher", function() + do_test("logging2.com", false) + end) + + it("should log to syslog if log_level is the same", function() + do_test("logging3.com", true) + end) + +end) diff --git a/spec/spec_helpers.lua b/spec/spec_helpers.lua index e535d3fab9a0..fbcd4a71f1c4 100644 --- a/spec/spec_helpers.lua +++ b/spec/spec_helpers.lua @@ -3,9 +3,11 @@ -- It supports other environments by passing a configuration file. local IO = require "kong.tools.io" +local dao = require "kong.tools.dao_loader" local Faker = require "kong.tools.faker" -local Migrations = require "kong.tools.migrations" +local config = require "kong.tools.config_loader" local Threads = require "llthreads2.ex" +local Migrations = require "kong.tools.migrations" require "kong.tools.ngx_stub" @@ -30,7 +32,8 @@ _M.envs = {} -- When dealing with another configuration file for a few tests, this allows to add -- a factory/migrations/faker that are environment-specific to this new config. function _M.add_env(conf_file) - local env_configuration, env_factory = IO.load_configuration_and_dao(conf_file) + local env_configuration = config.load(conf_file) + local env_factory = dao.load(env_configuration) _M.envs[conf_file] = { configuration = env_configuration, dao_factory = env_factory, @@ -95,8 +98,8 @@ function _M.find_port(exclude) end -- Finding an available port - local handle = io.popen([[(netstat -atn | awk '{printf "%s\n%s\n", $4, $4}' | grep -oE '[0-9]*$'; seq 32768 61000) | sort -n | uniq -u | head -n 1]]) - local result = handle:read("*a") + local handle = io.popen([[(netstat -atn | awk '{printf "%s\n%s\n", $4, $4}' | grep -oE '[0-9]*$'; seq 32768 61000) | sort -n | uniq -u]]) + local result = (handle:read("*a") .. "\n"):match("^(.-)\n") handle:close() -- Closing the opened servers @@ -107,18 +110,22 @@ function _M.find_port(exclude) return tonumber(result) end --- Starts a TCP server +-- Starts a TCP server, accepting a single connection and then closes -- @param `port` The port where the server will be listening to -- @return `thread` A thread object function _M.start_tcp_server(port, ...) local thread = Threads.new({ function(port) local socket = require "socket" - local server = assert(socket.bind("*", port)) + local server = assert(socket.tcp()) + assert(server:setoption('reuseaddr', true)) + assert(server:bind("*", port)) + assert(server:listen()) local client = server:accept() local line, err = client:receive() if not err then client:send(line .. "\n") end client:close() + server:close() return line end; }, port) @@ -127,14 +134,17 @@ function _M.start_tcp_server(port, ...) end --- Starts a HTTP server +-- Starts a HTTP server, accepting a single connection and then closes -- @param `port` The port where the server will be listening to -- @return `thread` A thread object function _M.start_http_server(port, ...) local thread = Threads.new({ function(port) local socket = require "socket" - local server = assert(socket.bind("*", port)) + local server = assert(socket.tcp()) + assert(server:setoption('reuseaddr', true)) + assert(server:bind("*", port)) + assert(server:listen()) local client = server:accept() local lines = {} @@ -153,11 +163,13 @@ function _M.start_http_server(port, ...) end if err then + server:close() error(err) end client:send("HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n") client:close() + server:close() return lines end; }, port) @@ -165,7 +177,7 @@ function _M.start_http_server(port, ...) return thread:start(...) end --- Starts a UDP server +-- Starts a UDP server, accepting a single connection and then closes -- @param `port` The port where the server will be listening to -- @return `thread` A thread object function _M.start_udp_server(port, ...) @@ -173,8 +185,10 @@ function _M.start_udp_server(port, ...) function(port) local socket = require("socket") local udp = socket.udp() + udp:setoption('reuseaddr', true) udp:setsockname("*", port) local data = udp:receivefrom() + udp:close() return data end; }, port) diff --git a/spec/unit/cli/utils_spec.lua b/spec/unit/cli/utils_spec.lua index 69e2a7a68caa..d08c15654605 100644 --- a/spec/unit/cli/utils_spec.lua +++ b/spec/unit/cli/utils_spec.lua @@ -1,12 +1,30 @@ local cutils = require "kong.cli.utils" -local spec_helper = require "spec.spec_helpers" +local socket = require "socket" describe("CLI Utils", function() - it("should check if a port is open", function() + pending("should check if a port is open", function() local PORT = 30000 - assert.falsy(cutils.is_port_open(PORT)) - spec_helper.start_tcp_server(PORT, true, true) - os.execute("sleep 0") -- Wait for the server to start - assert.truthy(cutils.is_port_open(PORT)) + local server, success, err + + -- Check a currently closed port + assert.truthy(cutils.is_port_bindable(PORT)) + + -- Check an open port, with SO_REUSEADDR set + server = socket.tcp() + assert(server:setoption('reuseaddr', true)) + assert(server:bind("*", PORT)) + assert(server:listen()) + success, err = cutils.is_port_bindable(PORT) + server:close() + assert.truthy(success, err) + + -- Check an open port, without SO_REUSEADDR set + server = socket.tcp() + assert(server:bind("*", PORT)) + assert(server:listen()) + success, err = cutils.is_port_bindable(PORT) + server:close() + assert.falsy(success, err) + end) end) diff --git a/spec/unit/core/resolver_spec.lua b/spec/unit/core/resolver_spec.lua new file mode 100644 index 000000000000..7d326537cdb3 --- /dev/null +++ b/spec/unit/core/resolver_spec.lua @@ -0,0 +1,278 @@ +local resolver = require "kong.core.resolver" + +-- Stubs +require "kong.tools.ngx_stub" + +local APIS_FIXTURES = { + -- request_host + {name = "mockbin", request_host = "mockbin.com", upstream_url = "http://mockbin.com"}, + {name = "mockbin", request_host = "mockbin-auth.com", upstream_url = "http://mockbin.com"}, + {name = "mockbin", request_host = "*.wildcard.com", upstream_url = "http://mockbin.com"}, + {name = "mockbin", request_host = "wildcard.*", upstream_url = "http://mockbin.com"}, + -- request_path + {name = "mockbin", request_path = "/mockbin", upstream_url = "http://mockbin.com"}, + {name = "mockbin", request_path = "/mockbin-with-dashes", upstream_url = "http://mockbin.com/some/path"}, + {name = "mockbin", request_path = "/some/deep/url", upstream_url = "http://mockbin.com"}, + -- + {name = "mockbin", request_path = "/strip", upstream_url = "http://mockbin.com/some/path/", strip_request_path = true}, + {name = "mockbin", request_path = "/strip-me", upstream_url = "http://mockbin.com/", strip_request_path = true}, + {name = "preserve-host", request_path = "/preserve-host", request_host = "preserve-host.com", upstream_url = "http://mockbin.com", preserve_host = true} +} + +_G.dao = { + apis = { + find_all = function() + return APIS_FIXTURES + end + } +} + +local apis_dics + +describe("Resolver", function() + describe("load_apis_in_memory()", function() + it("should retrieve all APIs in datastore and return them organized", function() + apis_dics = resolver.load_apis_in_memory() + assert.equal("table", type(apis_dics)) + assert.truthy(apis_dics.by_dns) + assert.truthy(apis_dics.request_path_arr) + assert.truthy(apis_dics.wildcard_dns_arr) + end) + it("should return a dictionary of APIs by request_host", function() + assert.equal("table", type(apis_dics.by_dns["mockbin.com"])) + assert.equal("table", type(apis_dics.by_dns["mockbin-auth.com"])) + end) + it("should return an array of APIs by request_path", function() + assert.equal("table", type(apis_dics.request_path_arr)) + assert.equal(6, #apis_dics.request_path_arr) + for _, item in ipairs(apis_dics.request_path_arr) do + assert.truthy(item.strip_request_path_pattern) + assert.truthy(item.request_path) + assert.truthy(item.api) + end + assert.equal("/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern) + assert.equal("/mockbin%-with%-dashes", apis_dics.request_path_arr[2].strip_request_path_pattern) + end) + it("should return an array of APIs with wildcard request_host", function() + assert.equal("table", type(apis_dics.wildcard_dns_arr)) + assert.equal(2, #apis_dics.wildcard_dns_arr) + for _, item in ipairs(apis_dics.wildcard_dns_arr) do + assert.truthy(item.api) + assert.truthy(item.pattern) + end + assert.equal("^.+%.wildcard%.com$", apis_dics.wildcard_dns_arr[1].pattern) + assert.equal("^wildcard%..+$", apis_dics.wildcard_dns_arr[2].pattern) + end) + end) + describe("strip_request_path()", function() + it("should strip the api's request_path from the requested URI", function() + assert.equal("/status/200", resolver.strip_request_path("/mockbin/status/200", apis_dics.request_path_arr[1].strip_request_path_pattern)) + assert.equal("/status/200", resolver.strip_request_path("/mockbin-with-dashes/status/200", apis_dics.request_path_arr[2].strip_request_path_pattern)) + assert.equal("/", resolver.strip_request_path("/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) + assert.equal("/", resolver.strip_request_path("/mockbin/", apis_dics.request_path_arr[1].strip_request_path_pattern)) + end) + it("should only strip the first pattern", function() + assert.equal("/mockbin/status/200/mockbin", resolver.strip_request_path("/mockbin/mockbin/status/200/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) + end) + it("should not add final slash", function() + assert.equal("hello", resolver.strip_request_path("hello", apis_dics.request_path_arr[3].strip_request_path_pattern, true)) + assert.equal("/hello", resolver.strip_request_path("hello", apis_dics.request_path_arr[3].strip_request_path_pattern, false)) + end) + end) + + -- Note: ngx.var.request_uri always adds a trailing slash even with a request without any + -- `curl kong:8000` will result in ngx.var.request_uri being '/' + describe("execute()", function() + local DEFAULT_REQUEST_URI = "/" + + it("should find an API by the request's simple Host header", function() + local api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "mockbin.com"}) + assert.same(APIS_FIXTURES[1], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("mockbin.com", upstream_host) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "mockbin-auth.com"}) + assert.same(APIS_FIXTURES[2], api) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = {"example.com", "mockbin.com"}}) + assert.same(APIS_FIXTURES[1], api) + end) + it("should find an API by the request's wildcard Host header", function() + local api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "foobar.wildcard.com"}) + assert.same(APIS_FIXTURES[3], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("mockbin.com", upstream_host) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "something.wildcard.com"}) + assert.same(APIS_FIXTURES[3], api) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "wildcard.com"}) + assert.same(APIS_FIXTURES[4], api) + + api = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "wildcard.fr"}) + assert.same(APIS_FIXTURES[4], api) + end) + it("should find an API by the request's URI (path component)", function() + local api, upstream_url, upstream_host = resolver.execute("/mockbin", {}) + assert.same(APIS_FIXTURES[5], api) + assert.equal("http://mockbin.com/mockbin", upstream_url) + assert.equal("mockbin.com", upstream_host) + + api = resolver.execute("/mockbin-with-dashes", {}) + assert.same(APIS_FIXTURES[6], api) + + api = resolver.execute("/some/deep/url", {}) + assert.same(APIS_FIXTURES[7], api) + + api = resolver.execute("/mockbin-with-dashes/and/some/uri", {}) + assert.same(APIS_FIXTURES[6], api) + end) + it("should return a 404 HTTP response if no API was found", function() + local responses = require "kong.tools.responses" + spy.on(responses, "send_HTTP_NOT_FOUND") + finally(function() + responses.send_HTTP_NOT_FOUND:revert() + end) + + -- non existant request_path + local api, upstream_url, upstream_host = resolver.execute("/inexistant-mockbin", {}) + assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(1) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {}, + request_path = "/inexistant-mockbin" + }) + assert.equal(404, ngx.status) + ngx.status = nil + + -- non-existant Host + api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "inexistant.com"}) + assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(2) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {"inexistant.com"}, + request_path = "/" + }) + assert.equal(404, ngx.status) + ngx.status = nil + + -- non-existant request_path with many Host headers + api, upstream_url, upstream_host = resolver.execute("/some-path", { + ["Host"] = {"nowhere.com", "inexistant.com"}, + ["X-Host-Override"] = "nowhere.fr" + }) + assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(3) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {"nowhere.com", "inexistant.com", "nowhere.fr"}, + request_path = "/some-path" + }) + assert.equal(404, ngx.status) + ngx.status = nil + + -- when a later part of the URI has a valid request_path + api, upstream_url, upstream_host = resolver.execute("/invalid-part/some-path", {}) + assert.falsy(api) + assert.falsy(upstream_url) + assert.falsy(upstream_host) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(4) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called_with({ + message = "API not found with these values", + request_host = {}, + request_path = "/invalid-part/some-path" + }) + assert.equal(404, ngx.status) + ngx.status = nil + end) + it("should strip_request_path", function() + local api = resolver.execute("/strip", {}) + assert.same(APIS_FIXTURES[8], api) + + -- strip when contains pattern characters + local api, upstream_url, upstream_host = resolver.execute("/strip-me/hello/world", {}) + assert.same(APIS_FIXTURES[9], api) + assert.equal("http://mockbin.com/hello/world", upstream_url) + assert.equal("mockbin.com", upstream_host) + + -- only strip first match of request_uri + api, upstream_url = resolver.execute("/strip-me/strip-me/hello/world", {}) + assert.same(APIS_FIXTURES[9], api) + assert.equal("http://mockbin.com/strip-me/hello/world", upstream_url) + end) + it("should preserve_host", function() + local api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, {["Host"] = "preserve-host.com"}) + assert.same(APIS_FIXTURES[10], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("preserve-host.com", upstream_host) + + api, upstream_url, upstream_host = resolver.execute(DEFAULT_REQUEST_URI, { + ["Host"] = {"inexistant.com", "preserve-host.com"}, + ["X-Host-Override"] = "hello.com" + }) + assert.same(APIS_FIXTURES[10], api) + assert.equal("http://mockbin.com/", upstream_url) + assert.equal("preserve-host.com", upstream_host) + + -- No host given to this request, we extract if from the configured upstream_url + api, upstream_url, upstream_host = resolver.execute("/preserve-host", {}) + assert.same(APIS_FIXTURES[10], api) + assert.equal("http://mockbin.com/preserve-host", upstream_url) + assert.equal("mockbin.com", upstream_host) + end) + it("should not decode percent-encoded values in URI", function() + -- they should be forwarded as-is + local api, upstream_url = resolver.execute("/mockbin/path%2Fwith%2Fencoded/values", {}) + assert.same(APIS_FIXTURES[5], api) + assert.equal("http://mockbin.com/mockbin/path%2Fwith%2Fencoded/values", upstream_url) + + api, upstream_url = resolver.execute("/strip-me/path%2Fwith%2Fencoded/values", {}) + assert.same(APIS_FIXTURES[9], api) + assert.equal("http://mockbin.com/path%2Fwith%2Fencoded/values", upstream_url) + end) + it("should not recognized request_path if percent-encoded", function() + local responses = require "kong.tools.responses" + spy.on(responses, "send_HTTP_NOT_FOUND") + finally(function() + responses.send_HTTP_NOT_FOUND:revert() + end) + + local api = resolver.execute("/some/deep%2Furl", {}) + assert.falsy(api) + assert.spy(responses.send_HTTP_NOT_FOUND).was_called(1) + assert.equal(404, ngx.status) + ngx.status = nil + end) + it("should have or not have a trailing slash depending on the request URI", function() + local api, upstream_url = resolver.execute("/strip/", {}) + assert.same(APIS_FIXTURES[8], api) + assert.equal("http://mockbin.com/some/path/", upstream_url) + + api, upstream_url = resolver.execute("/strip", {}) + assert.same(APIS_FIXTURES[8], api) + assert.equal("http://mockbin.com/some/path", upstream_url) + + api, upstream_url = resolver.execute("/mockbin-with-dashes", {}) + assert.same(APIS_FIXTURES[6], api) + assert.equal("http://mockbin.com/some/path/mockbin-with-dashes", upstream_url) + + api, upstream_url = resolver.execute("/mockbin-with-dashes/", {}) + assert.same(APIS_FIXTURES[6], api) + assert.equal("http://mockbin.com/some/path/mockbin-with-dashes/", upstream_url) + end) + it("should strip the querystring out of the URI", function() + -- it will be re-inserted by core.handler just before proxying, once all plugins have been run and eventually modified it + local api, upstream_url = resolver.execute("/?hello=world&foo=bar", {["Host"] = "mockbin.com"}) + assert.same(APIS_FIXTURES[1], api) + assert.equal("http://mockbin.com/", upstream_url) + end) + end) +end) diff --git a/spec/unit/dao/cassandra/factory_spec.lua b/spec/unit/dao/cassandra/factory_spec.lua new file mode 100644 index 000000000000..bd8fe8cb4f8e --- /dev/null +++ b/spec/unit/dao/cassandra/factory_spec.lua @@ -0,0 +1,98 @@ +local Factory = require "kong.dao.cassandra.factory" +local spec_helpers = require "spec.spec_helpers" +local env = spec_helpers.get_env() +local default_dao_properties = env.configuration.databases_available.cassandra + +describe("Cassadra factory", function() + describe("get_session_options()", function() + local dao_properties + before_each(function() + dao_properties = default_dao_properties + end) + it("should reflect the default config", function() + local factory = Factory(dao_properties) + assert.truthy(factory) + local options = factory:get_session_options() + assert.truthy(options) + assert.same({ + shm = "cassandra", + prepared_shm = "cassandra_prepared", + contact_points = dao_properties.contact_points, + keyspace = dao_properties.keyspace, + query_options = { + prepare = true + }, + ssl_options = { + enabled = false, + verify = false + } + }, options) + end) + it("should accept some overriden properties", function() + dao_properties.contact_points = {"127.0.0.1:9042"} + dao_properties.keyspace = "my_keyspace" + + local factory = Factory(dao_properties) + assert.truthy(factory) + local options = factory:get_session_options() + assert.truthy(options) + assert.same({ + shm = "cassandra", + prepared_shm = "cassandra_prepared", + contact_points = {"127.0.0.1:9042"}, + keyspace = "my_keyspace", + query_options = { + prepare = true + }, + ssl_options = { + enabled = false, + verify = false + } + }, options) + end) + it("should accept SSL properties", function() + dao_properties.contact_points = {"127.0.0.1:9042"} + dao_properties.keyspace = "my_keyspace" + dao_properties.ssl.enabled = false + dao_properties.ssl.verify = true + + local factory = Factory(dao_properties) + assert.truthy(factory) + local options = factory:get_session_options() + assert.truthy(options) + assert.same({ + shm = "cassandra", + prepared_shm = "cassandra_prepared", + contact_points = {"127.0.0.1:9042"}, + keyspace = "my_keyspace", + query_options = { + prepare = true + }, + ssl_options = { + enabled = false, + verify = true + } + }, options) + + -- TEST 2 + dao_properties.ssl.enabled = true + factory = Factory(dao_properties) + assert.truthy(factory) + options = factory:get_session_options() + assert.truthy(options) + assert.same({ + shm = "cassandra", + prepared_shm = "cassandra_prepared", + contact_points = {"127.0.0.1:9042"}, + keyspace = "my_keyspace", + query_options = { + prepare = true + }, + ssl_options = { + enabled = true, + verify = true + } + }, options) + end) + end) +end) diff --git a/spec/unit/dao/cassandra/migrations_spec.lua b/spec/unit/dao/cassandra/migrations_spec.lua new file mode 100644 index 000000000000..5c0b0e660f85 --- /dev/null +++ b/spec/unit/dao/cassandra/migrations_spec.lua @@ -0,0 +1,57 @@ +local stringy = require "stringy" +local spec_helper = require "spec.spec_helpers" +local migrations = require "kong.dao.cassandra.schema.migrations" +local first_migration = migrations[1] + +local migrations_stub = { + execute_queries = function(self, queries) + return queries + end +} + +local function strip_query(str) + str = stringy.split(str, ";")[1] + str = str:gsub("\n", " "):gsub("%s+", " ") + return stringy.strip(str) +end + +local test_config = spec_helper.get_env().configuration +local dao_config = test_config.dao_config +dao_config.keyspace = "kong" + +describe("Cassandra migrations", function() + describe("Keyspace options", function() + it("should default to SimpleStrategy class with replication_factor of 1", function() + local queries = first_migration.up(dao_config, migrations_stub) + local keyspace_query = strip_query(queries) + assert.equal("CREATE KEYSPACE IF NOT EXISTS \"kong\" WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 1}", keyspace_query) + end) + it("should be possible to set a custom replication_factor", function() + dao_config.replication_factor = 2 + local queries = first_migration.up(dao_config, migrations_stub) + local keyspace_query = strip_query(queries) + assert.equal("CREATE KEYSPACE IF NOT EXISTS \"kong\" WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor': 2}", keyspace_query) + end) + it("should accept NetworkTopologyStrategy", function() + dao_config.replication_strategy = "NetworkTopologyStrategy" + local queries = first_migration.up(dao_config, migrations_stub) + local keyspace_query = strip_query(queries) + assert.equal("CREATE KEYSPACE IF NOT EXISTS \"kong\" WITH REPLICATION = {'class': 'NetworkTopologyStrategy'}", keyspace_query) + end) + it("should be possible to set data centers for NetworkTopologyStrategy", function() + dao_config.data_centers = { + dc1 = 2, + dc2 = 3 + } + local queries = first_migration.up(dao_config, migrations_stub) + local keyspace_query = strip_query(queries) + assert.equal("CREATE KEYSPACE IF NOT EXISTS \"kong\" WITH REPLICATION = {'class': 'NetworkTopologyStrategy', 'dc1': 2, 'dc2': 3}", keyspace_query) + end) + it("should return an error if an invalid replication_strategy is given", function() + dao_config.replication_strategy = "foo" + local err = first_migration.up(dao_config, migrations_stub) + assert.truthy(err) + assert.equal("invalid replication_strategy class", err) + end) + end) +end) diff --git a/spec/unit/dao/cassandra/query_builder_spec.lua b/spec/unit/dao/cassandra/query_builder_spec.lua index 63faba9fe0b4..d7f14f1a720b 100644 --- a/spec/unit/dao/cassandra/query_builder_spec.lua +++ b/spec/unit/dao/cassandra/query_builder_spec.lua @@ -88,6 +88,20 @@ describe("Query Builder", function() end) + describe("COUNT", function() + + it("should build a COUNT query", function() + local q = builder.count("apis") + assert.equal("SELECT COUNT(*) FROM apis", q) + end) + + it("should build a COUNT query with WHERE keys", function() + local q = builder.count("apis", {id="123", name="mockbin"}) + assert.equal("SELECT COUNT(*) FROM apis WHERE name = ? AND id = ? ALLOW FILTERING", q) + end) + + end) + describe("INSERT", function() it("should build an INSERT query", function() diff --git a/spec/unit/resolver/access_spec.lua b/spec/unit/resolver/access_spec.lua deleted file mode 100644 index 6ee00dbabc1b..000000000000 --- a/spec/unit/resolver/access_spec.lua +++ /dev/null @@ -1,120 +0,0 @@ -local resolver_access = require "kong.resolver.access" - --- Stubs -require "kong.tools.ngx_stub" -local APIS_FIXTURES = { - {name = "mockbin", request_host = "mockbin.com", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_host = "mockbin-auth.com", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_host = "*.wildcard.com", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_host = "wildcard.*", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_path = "/mockbin", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_path = "/mockbin-with-dashes", upstream_url = "http://mockbin.com"}, - {name = "mockbin", request_path = "/some/deep/url", upstream_url = "http://mockbin.com"} -} -_G.dao = { - apis = { - find_all = function() - return APIS_FIXTURES - end - } -} - -local apis_dics - -describe("Resolver Access", function() - describe("load_apis_in_memory()", function() - it("should retrieve all APIs in datastore and return them organized", function() - apis_dics = resolver_access.load_apis_in_memory() - assert.equal("table", type(apis_dics)) - assert.truthy(apis_dics.by_dns) - assert.truthy(apis_dics.request_path_arr) - assert.truthy(apis_dics.wildcard_dns_arr) - end) - it("should return a dictionary of APIs by request_host", function() - assert.equal("table", type(apis_dics.by_dns["mockbin.com"])) - assert.equal("table", type(apis_dics.by_dns["mockbin-auth.com"])) - end) - it("should return an array of APIs by request_path", function() - assert.equal("table", type(apis_dics.request_path_arr)) - assert.equal(3, #apis_dics.request_path_arr) - for _, item in ipairs(apis_dics.request_path_arr) do - assert.truthy(item.strip_request_path_pattern) - assert.truthy(item.request_path) - assert.truthy(item.api) - end - assert.equal("/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern) - assert.equal("/mockbin%-with%-dashes", apis_dics.request_path_arr[2].strip_request_path_pattern) - end) - it("should return an array of APIs with wildcard request_host", function() - assert.equal("table", type(apis_dics.wildcard_dns_arr)) - assert.equal(2, #apis_dics.wildcard_dns_arr) - for _, item in ipairs(apis_dics.wildcard_dns_arr) do - assert.truthy(item.api) - assert.truthy(item.pattern) - end - assert.equal("^.+%.wildcard%.com$", apis_dics.wildcard_dns_arr[1].pattern) - assert.equal("^wildcard%..+$", apis_dics.wildcard_dns_arr[2].pattern) - end) - end) - describe("find_api_by_request_path()", function() - it("should return nil when no matching API for that URI", function() - local api = resolver_access.find_api_by_request_path("/", apis_dics.request_path_arr) - assert.falsy(api) - end) - it("should return the API for a matching URI", function() - local api = resolver_access.find_api_by_request_path("/mockbin", apis_dics.request_path_arr) - assert.same(APIS_FIXTURES[5], api) - - api = resolver_access.find_api_by_request_path("/mockbin-with-dashes", apis_dics.request_path_arr) - assert.same(APIS_FIXTURES[6], api) - - api = resolver_access.find_api_by_request_path("/mockbin-with-dashes/and/some/uri", apis_dics.request_path_arr) - assert.same(APIS_FIXTURES[6], api) - - api = resolver_access.find_api_by_request_path("/dashes-mockbin", apis_dics.request_path_arr) - assert.falsy(api) - - api = resolver_access.find_api_by_request_path("/some/deep/url", apis_dics.request_path_arr) - assert.same(APIS_FIXTURES[7], api) - end) - end) - describe("find_api_by_request_host()", function() - it("should return nil and a list of all the Host headers in the request when no API was found", function() - local api, all_hosts = resolver_access.find_api_by_request_host({ - Host = "foo.com", - ["X-Host-Override"] = {"bar.com", "hello.com"} - }, apis_dics) - assert.falsy(api) - assert.same({"foo.com", "bar.com", "hello.com"}, all_hosts) - end) - it("should return an API when one of the Host headers matches", function() - local api = resolver_access.find_api_by_request_host({Host = "mockbin.com"}, apis_dics) - assert.same(APIS_FIXTURES[1], api) - - api = resolver_access.find_api_by_request_host({Host = "mockbin-auth.com"}, apis_dics) - assert.same(APIS_FIXTURES[2], api) - end) - it("should return an API when one of the Host headers matches a wildcard dns", function() - local api = resolver_access.find_api_by_request_host({Host = "wildcard.com"}, apis_dics) - assert.same(APIS_FIXTURES[4], api) - api = resolver_access.find_api_by_request_host({Host = "wildcard.fr"}, apis_dics) - assert.same(APIS_FIXTURES[4], api) - - api = resolver_access.find_api_by_request_host({Host = "foobar.wildcard.com"}, apis_dics) - assert.same(APIS_FIXTURES[3], api) - api = resolver_access.find_api_by_request_host({Host = "barfoo.wildcard.com"}, apis_dics) - assert.same(APIS_FIXTURES[3], api) - end) - end) - describe("strip_request_path()", function() - it("should strip the api's request_path from the requested URI", function() - assert.equal("/status/200", resolver_access.strip_request_path("/mockbin/status/200", apis_dics.request_path_arr[1].strip_request_path_pattern)) - assert.equal("/status/200", resolver_access.strip_request_path("/mockbin-with-dashes/status/200", apis_dics.request_path_arr[2].strip_request_path_pattern)) - assert.equal("/", resolver_access.strip_request_path("/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) - assert.equal("/", resolver_access.strip_request_path("/mockbin/", apis_dics.request_path_arr[1].strip_request_path_pattern)) - end) - it("should only strip the first pattern", function() - assert.equal("/mockbin/status/200/mockbin", resolver_access.strip_request_path("/mockbin/mockbin/status/200/mockbin", apis_dics.request_path_arr[1].strip_request_path_pattern)) - end) - end) -end) diff --git a/spec/unit/statics_spec.lua b/spec/unit/statics_spec.lua index bc773b8fc776..68b1689cce53 100644 --- a/spec/unit/statics_spec.lua +++ b/spec/unit/statics_spec.lua @@ -1,4 +1,3 @@ -local spec_helper = require "spec.spec_helpers" local constants = require "kong.constants" local stringy = require "stringy" local IO = require "kong.tools.io" @@ -33,236 +32,4 @@ describe("Static files", function() assert.has_error(function() local _ = constants.DATABASE_ERROR_TYPES.ThIs_TyPe_DoEs_NoT_ExIsT end) end) end) - - describe("Configuration", function() - - it("should equal to this template to make sure no errors are pushed in the default config", function() - local configuration = IO.read_file(spec_helper.DEFAULT_CONF_FILE) - - assert.are.same([[ -## Available plugins on this server -plugins_available: - - ssl - - jwt - - acl - - cors - - oauth2 - - tcp-log - - udp-log - - file-log - - http-log - - key-auth - - hmac-auth - - basic-auth - - ip-restriction - - mashape-analytics - - request-transformer - - response-transformer - - request-size-limiting - - rate-limiting - - response-ratelimiting - -## The Kong working directory -## (Make sure you have read and write permissions) -nginx_working_dir: /usr/local/kong/ - -## Port configuration -proxy_port: 8000 -proxy_ssl_port: 8443 -admin_api_port: 8001 - -## Secondary port configuration -dnsmasq_port: 8053 - -## Specify the DAO to use -database: cassandra - -## Databases configuration -databases_available: - cassandra: - properties: - contact_points: - - "localhost:9042" - timeout: 1000 - keyspace: kong - keepalive: 60000 # in milliseconds - # ssl: false - # ssl_verify: false - # ssl_certificate: "/path/to/cluster-ca-certificate.pem" - # user: cassandra - # password: cassandra - -## Cassandra cache configuration -database_cache_expiration: 5 # in seconds - -## SSL Settings -## (Uncomment the two properties below to set your own certificate) -# ssl_cert_path: /path/to/certificate.pem -# ssl_key_path: /path/to/certificate.key - -## Sends anonymous error reports -send_anonymous_reports: true - -## In-memory cache size (MB) -memory_cache_size: 128 - -## Nginx configuration -nginx: | - worker_processes auto; - error_log logs/error.log error; - daemon on; - - worker_rlimit_nofile {{auto_worker_rlimit_nofile}}; - - env KONG_CONF; - env PATH; - - events { - worker_connections {{auto_worker_connections}}; - multi_accept on; - } - - http { - resolver {{dns_resolver}} ipv6=off; - charset UTF-8; - - access_log logs/access.log; - access_log off; - - # Timeouts - keepalive_timeout 60s; - client_header_timeout 60s; - client_body_timeout 60s; - send_timeout 60s; - - # Proxy Settings - proxy_buffer_size 128k; - proxy_buffers 4 256k; - proxy_busy_buffers_size 256k; - proxy_ssl_server_name on; - - # IP Address - real_ip_header X-Forwarded-For; - set_real_ip_from 0.0.0.0/0; - real_ip_recursive on; - - # Other Settings - client_max_body_size 0; - underscores_in_headers on; - reset_timedout_connection on; - tcp_nopush on; - - ################################################ - # The following code is required to run Kong # - # Please be careful if you'd like to change it # - ################################################ - - # Lua Settings - lua_package_path ';;'; - lua_code_cache on; - lua_max_running_timers 4096; - lua_max_pending_timers 16384; - lua_shared_dict locks 100k; - lua_shared_dict cache {{memory_cache_size}}m; - lua_socket_log_errors off; - {{lua_ssl_trusted_certificate}} - - init_by_lua ' - kong = require "kong" - local status, err = pcall(kong.init) - if not status then - ngx.log(ngx.ERR, "Startup error: "..err) - os.exit(1) - end - '; - - init_worker_by_lua 'kong.exec_plugins_init_worker()'; - - server { - server_name _; - listen {{proxy_port}}; - listen {{proxy_ssl_port}} ssl; - - ssl_certificate_by_lua 'kong.exec_plugins_certificate()'; - - ssl_certificate {{ssl_cert}}; - ssl_certificate_key {{ssl_key}}; - ssl_protocols TLSv1 TLSv1.1 TLSv1.2;# omit SSLv3 because of POODLE (CVE-2014-3566) - - location / { - default_type 'text/plain'; - - # These properties will be used later by proxy_pass - set $backend_host nil; - set $backend_url nil; - - # Authenticate the user and load the API info - access_by_lua 'kong.exec_plugins_access()'; - - # Proxy the request - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_set_header Host $backend_host; - proxy_pass $backend_url; - proxy_pass_header Server; - - # Add additional response headers - header_filter_by_lua 'kong.exec_plugins_header_filter()'; - - # Change the response body - body_filter_by_lua 'kong.exec_plugins_body_filter()'; - - # Log the request - log_by_lua 'kong.exec_plugins_log()'; - } - - location /robots.txt { - return 200 'User-agent: *\nDisallow: /'; - } - - error_page 500 /500.html; - location = /500.html { - internal; - content_by_lua ' - local responses = require "kong.tools.responses" - responses.send_HTTP_INTERNAL_SERVER_ERROR("An unexpected error occurred") - '; - } - } - - server { - listen {{admin_api_port}}; - - location / { - default_type application/json; - content_by_lua ' - ngx.header["Access-Control-Allow-Origin"] = "*" - if ngx.req.get_method() == "OPTIONS" then - ngx.header["Access-Control-Allow-Methods"] = "GET,HEAD,PUT,PATCH,POST,DELETE" - ngx.header["Access-Control-Allow-Headers"] = "Content-Type" - ngx.exit(204) - end - local lapis = require "lapis" - lapis.serve("kong.api.app") - '; - } - - location /nginx_status { - internal; - stub_status; - } - - location /robots.txt { - return 200 'User-agent: *\nDisallow: /'; - } - - # Do not remove, additional configuration placeholder for some plugins - # {{additional_configuration}} - } - } -]], configuration) - end) - - end) end) diff --git a/spec/unit/tools/config_loader_spec.lua b/spec/unit/tools/config_loader_spec.lua new file mode 100644 index 000000000000..e2ec4d842cea --- /dev/null +++ b/spec/unit/tools/config_loader_spec.lua @@ -0,0 +1,85 @@ +local IO = require "kong.tools.io" +local yaml = require "yaml" +local spec_helper = require "spec.spec_helpers" +local config = require "kong.tools.config_loader" + +local TEST_CONF_PATH = spec_helper.get_env().conf_file + +describe("Configuration validation", function() + it("should validate the default configuration", function() + local test_config = yaml.load(IO.read_file(TEST_CONF_PATH)) + local ok, errors = config.validate(test_config) + assert.True(ok) + assert.falsy(errors) + end) + it("should populate defaults", function() + local conf = {} + local ok, errors = config.validate(conf) + assert.True(ok) + assert.falsy(errors) + + assert.truthy(conf.plugins_available) + assert.truthy(conf.admin_api_port) + assert.truthy(conf.proxy_port) + assert.truthy(conf.database) + assert.truthy(conf.databases_available) + assert.equal("table", type(conf.databases_available)) + + local function check_defaults(conf, conf_defaults) + for k, v in pairs(conf) do + if conf_defaults[k].type == "table" then + check_defaults(v, conf_defaults[k].content) + end + if conf_defaults[k].default ~= nil then + assert.equal(conf_defaults[k].default, v) + end + end + end + + check_defaults(conf, require("kong.tools.config_defaults")) + end) + it("should validate various types", function() + local ok, errors = config.validate({ + proxy_port = "string", + database = 666, + databases_available = { + cassandra = { + contact_points = "127.0.0.1", + ssl = { + enabled = "false" + } + } + } + }) + assert.False(ok) + assert.truthy(errors) + assert.equal("must be a number", errors.proxy_port) + assert.equal("must be a string", errors.database) + assert.equal("must be a array", errors["databases_available.cassandra.contact_points"]) + assert.equal("must be a boolean", errors["databases_available.cassandra.ssl.enabled"]) + assert.falsy(errors.ssl_cert_path) + assert.falsy(errors.ssl_key_path) + end) + it("should check for minimum allowed value if is a number", function() + local ok, errors = config.validate({memory_cache_size = 16}) + assert.False(ok) + assert.equal("must be greater than 32", errors.memory_cache_size) + end) + it("should check that the value is contained in `enum`", function() + local ok, errors = config.validate({ + databases_available = { + cassandra = { + replication_strategy = "foo" + } + } + }) + assert.False(ok) + assert.equal("must be one of: 'SimpleStrategy, NetworkTopologyStrategy'", errors["databases_available.cassandra.replication_strategy"]) + end) + it("should validate the selected database property", function() + local ok, errors = config.validate({database = "foo"}) + assert.False(ok) + assert.equal("foo is not listed in databases_available", errors.database) + end) +end) + diff --git a/spec/unit/tools/faker_spec.lua b/spec/unit/tools/faker_spec.lua index 9a420783085d..47db1bcf1683 100644 --- a/spec/unit/tools/faker_spec.lua +++ b/spec/unit/tools/faker_spec.lua @@ -1,4 +1,4 @@ -local uuid = require "uuid" +local uuid = require "lua_uuid" local Faker = require "kong.tools.faker" local DaoError = require "kong.dao.error" diff --git a/spec/unit/tools/responses_spec.lua b/spec/unit/tools/responses_spec.lua index 476fd990f26e..6a8af13a913d 100644 --- a/spec/unit/tools/responses_spec.lua +++ b/spec/unit/tools/responses_spec.lua @@ -65,14 +65,12 @@ describe("Responses", function() end end) - it("should call `ngx.log` and set `stop_phases` if and only if a 500 status code range was given", function() + it("should call `ngx.log` if and only if a 500 status code range was given", function() responses.send_HTTP_BAD_REQUEST() assert.stub(ngx.log).was_not_called() - assert.falsy(ngx.ctx.stop_phases) responses.send_HTTP_INTERNAL_SERVER_ERROR() assert.stub(ngx.log).was_not_called() - assert.True(ngx.ctx.stop_phases) responses.send_HTTP_INTERNAL_SERVER_ERROR("error") assert.stub(ngx.log).was_called() diff --git a/spec/unit/tools/utils_spec.lua b/spec/unit/tools/utils_spec.lua index 92e997ca1eb8..3c46f39cdd52 100644 --- a/spec/unit/tools/utils_spec.lua +++ b/spec/unit/tools/utils_spec.lua @@ -2,17 +2,111 @@ local utils = require "kong.tools.utils" describe("Utils", function() - describe("strings", function() - local first = utils.random_string() - assert.truthy(first) - assert.falsy(first:find("-")) - local second = utils.random_string() - assert.falsy(first == second) - end) + describe("string", function() + describe("random_string()", function() + it("should return a random string", function() + local first = utils.random_string() + assert.truthy(first) + assert.falsy(first:find("-")) + + local second = utils.random_string() + assert.not_equal(first, second) + end) + end) - describe("tables", function() - describe("#table_size()", function() + describe("encode_args()", function() + it("should encode a Lua table to a querystring", function() + local str = utils.encode_args { + foo = "bar", + hello = "world" + } + assert.equal("foo=bar&hello=world", str) + end) + it("should encode multi-value query args", function() + local str = utils.encode_args { + foo = {"bar", "zoo"}, + hello = "world" + } + assert.equal("foo=bar&foo=zoo&hello=world", str) + end) + it("should percent-encode given values", function() + local str = utils.encode_args { + encode = {"abc|def", ",$@|`"} + } + assert.equal("encode=abc%7cdef&encode=%2c%24%40%7c%60", str) + end) + it("should percent-encode given query args keys", function() + local str = utils.encode_args { + ["hello world"] = "foo" + } + assert.equal("hello%20world=foo", str) + end) + it("should support Lua numbers", function() + local str = utils.encode_args { + a = 1, + b = 2 + } + assert.equal("a=1&b=2", str) + end) + it("should support a boolean argument", function() + local str = utils.encode_args { + a = true, + b = 1 + } + assert.equal("a&b=1", str) + end) + it("should ignore nil and false values", function() + local str = utils.encode_args { + a = nil, + b = false + } + assert.equal("", str) + end) + it("should encode complex query args", function() + local str = utils.encode_args { + multiple = {"hello, world"}, + hello = "world", + ignore = false, + ["multiple values"] = true + } + assert.equal("hello=world&multiple=hello%2c%20world&multiple%20values", str) + end) + it("should not percent-encode if given a `raw` option", function() + -- this is useful for kong.tools.http_client + local str = utils.encode_args({ + ["hello world"] = "foo, bar" + }, true) + assert.equal("hello world=foo, bar", str) + end) + -- while this method's purpose is to mimic 100% the behavior of ngx.encode_args, + -- it is also used by Kong specs' http_client, to encode both querystrings and *bodies*. + -- Hence, a `raw` parameter allows encoding for bodies. + describe("raw", function() + it("should not percent-encode values", function() + local str = utils.encode_args({ + foo = "hello world" + }, true) + assert.equal("foo=hello world", str) + end) + it("should not percent-encode keys", function() + local str = utils.encode_args({ + ["hello world"] = "foo" + }, true) + assert.equal("hello world=foo", str) + end) + it("should plainly include true and false values", function() + local str = utils.encode_args({ + a = true, + b = false + }, true) + assert.equal("a=true&b=false", str) + end) + end) + end) + end) + describe("table", function() + describe("table_size()", function() it("should return the size of a table", function() assert.are.same(0, utils.table_size(nil)) assert.are.same(0, utils.table_size({})) @@ -20,44 +114,39 @@ describe("Utils", function() assert.are.same(2, utils.table_size({ foo = "bar", bar = "baz" })) assert.are.same(2, utils.table_size({ "foo", "bar" })) end) - end) - describe("#table_contains()", function() - + describe("table_contains()", function() it("should return false if a value is not contained in a nil table", function() assert.False(utils.table_contains(nil, "foo")) end) - it("should return true if a value is contained in a table", function() local t = { foo = "hello", bar = "world" } assert.True(utils.table_contains(t, "hello")) end) - it("should return false if a value is not contained in a table", function() local t = { foo = "hello", bar = "world" } assert.False(utils.table_contains(t, "foo")) end) - end) - describe("#is_array()", function() - + describe("is_array()", function() it("should know when an array ", function() assert.True(utils.is_array({ "a", "b", "c", "d" })) assert.True(utils.is_array({ ["1"] = "a", ["2"] = "b", ["3"] = "c", ["4"] = "d" })) assert.False(utils.is_array({ "a", "b", "c", foo = "d" })) + assert.False(utils.is_array()) + assert.False(utils.is_array(false)) + assert.False(utils.is_array(true)) end) - end) - describe("#add_error()", function() + describe("add_error()", function() local add_error = utils.add_error it("should create a table if given `errors` is nil", function() assert.same({hello = "world"}, add_error(nil, "hello", "world")) end) - it("should add a key/value when the key does not exists", function() local errors = {hello = "world"} assert.same({ @@ -65,10 +154,8 @@ describe("Utils", function() foo = "bar" }, add_error(errors, "foo", "bar")) end) - it("should transform previous values to a list if the same key is given again", function() - local e = nil - + local e = nil -- initialize for luacheck e = add_error(e, "key1", "value1") e = add_error(e, "key2", "value2") assert.same({key1 = "value1", key2 = "value2"}, e) @@ -82,10 +169,8 @@ describe("Utils", function() e = add_error(e, "key2", "value7") assert.same({key1 = {"value1", "value3", "value4", "value5", "value6"}, key2 = {"value2", "value7"}}, e) end) - it("should also list tables pushed as errors", function() - local e = nil - + local e = nil -- initialize for luacheck e = add_error(e, "key1", "value1") e = add_error(e, "key2", "value2") e = add_error(e, "key1", "value3") @@ -100,11 +185,9 @@ describe("Utils", function() keyO = {{message = "some error"}, {message = "another"}} }, e) end) - end) - describe("#load_module_if_exists()", function() - + describe("load_module_if_exists()", function() it("should return false if the module does not exist", function() local loaded, mod assert.has_no.errors(function() @@ -113,7 +196,6 @@ describe("Utils", function() assert.False(loaded) assert.falsy(mod) end) - it("should throw an error if the module is invalid", function() local loaded, mod assert.has.errors(function() @@ -122,7 +204,6 @@ describe("Utils", function() assert.falsy(loaded) assert.falsy(mod) end) - it("should load a module if it was found and valid", function() local loaded, mod assert.has_no.errors(function() @@ -132,7 +213,6 @@ describe("Utils", function() assert.truthy(mod) assert.are.same("All your base are belong to us.", mod.exposed) end) - end) end) end)