From 1bd42a7b8cb25d18946b62d581ed74b8c9d0997e Mon Sep 17 00:00:00 2001 From: thefosk Date: Mon, 29 Aug 2016 16:16:19 -0700 Subject: [PATCH] Extending the plugins APIs to support easier querying --- kong/plugins/basic-auth/api.lua | 10 ++- kong/plugins/hmac-auth/api.lua | 17 +++- kong/plugins/jwt/api.lua | 16 +++- kong/plugins/key-auth/api.lua | 10 ++- kong/plugins/oauth2/api.lua | 47 +++++++---- spec/03-plugins/01-basic-auth/02-api_spec.lua | 32 ++++++- spec/03-plugins/02-key-auth/01-api_spec.lua | 28 ++++++- spec/03-plugins/06-jwt/02-api_spec.lua | 27 +++++- spec/03-plugins/06-jwt/03-access_spec.lua | 7 +- spec/03-plugins/09-hmac-auth/02-api_spec.lua | 28 ++++++- spec/03-plugins/99-oauth2/02-api_spec.lua | 83 +++++++++++++++---- 11 files changed, 245 insertions(+), 60 deletions(-) diff --git a/kong/plugins/basic-auth/api.lua b/kong/plugins/basic-auth/api.lua index f7bb0d249f4e..423957a2ccc5 100644 --- a/kong/plugins/basic-auth/api.lua +++ b/kong/plugins/basic-auth/api.lua @@ -1,4 +1,5 @@ local crud = require "kong.api.crud_helpers" +local utils = require "kong.tools.utils" return { ["/consumers/:username_or_id/basic-auth/"] = { @@ -19,15 +20,18 @@ return { crud.post(self.params, dao_factory.basicauth_credentials) end }, - ["/consumers/:username_or_id/basic-auth/:id"] = { + ["/consumers/:username_or_id/basic-auth/:credential_username_or_id"] = { before = function(self, dao_factory, helpers) crud.find_consumer_by_username_or_id(self, dao_factory, helpers) self.params.consumer_id = self.consumer.id - local credentials, err = dao_factory.basicauth_credentials:find_all { + local filter_keys = { + [utils.is_valid_uuid(self.params.credential_username_or_id) and "id" or "username"] = self.params.credential_username_or_id, consumer_id = self.params.consumer_id, - id = self.params.id } + self.params.credential_username_or_id = nil + + local credentials, err = dao_factory.basicauth_credentials:find_all(filter_keys) if err then return helpers.yield_error(err) elseif next(credentials) == nil then diff --git a/kong/plugins/hmac-auth/api.lua b/kong/plugins/hmac-auth/api.lua index 7d495926312c..7bc12feec391 100644 --- a/kong/plugins/hmac-auth/api.lua +++ b/kong/plugins/hmac-auth/api.lua @@ -1,4 +1,5 @@ local crud = require "kong.api.crud_helpers" +local utils = require "kong.tools.utils" return{ ["/consumers/:username_or_id/hmac-auth/"] = { @@ -20,18 +21,26 @@ return{ end }, - ["/consumers/:username_or_id/hmac-auth/:id"] = { + ["/consumers/:username_or_id/hmac-auth/:credential_username_or_id"] = { before = function(self, dao_factory, helpers) crud.find_consumer_by_username_or_id(self, dao_factory, helpers) self.params.consumer_id = self.consumer.id - local err - self.hmacauth_credential, err = dao_factory.hmacauth_credentials:find(self.params) + + local filter_keys = { + [utils.is_valid_uuid(self.params.credential_username_or_id) and "id" or "username"] = self.params.credential_username_or_id, + consumer_id = self.params.consumer_id, + } + self.params.credential_username_or_id = nil + + local credentials, err = dao_factory.hmacauth_credentials:find_all(filter_keys) if err then return helpers.yield_error(err) - elseif self.hmacauth_credential == nil then + elseif next(credentials) == nil then return helpers.responses.send_HTTP_NOT_FOUND() end + + self.hmacauth_credential = credentials[1] end, GET = function(self, dao_factory, helpers) diff --git a/kong/plugins/jwt/api.lua b/kong/plugins/jwt/api.lua index ff7410133ed7..1c56d3cc1bfe 100644 --- a/kong/plugins/jwt/api.lua +++ b/kong/plugins/jwt/api.lua @@ -1,4 +1,5 @@ local crud = require "kong.api.crud_helpers" +local utils = require "kong.tools.utils" return { ["/consumers/:username_or_id/jwt/"] = { @@ -20,18 +21,25 @@ return { end }, - ["/consumers/:username_or_id/jwt/:id"] = { + ["/consumers/:username_or_id/jwt/:credential_key_or_id"] = { before = function(self, dao_factory, helpers) crud.find_consumer_by_username_or_id(self, dao_factory, helpers) self.params.consumer_id = self.consumer.id - local err - self.jwt_secret, err = dao_factory.jwt_secrets:find(self.params) + local filter_keys = { + [utils.is_valid_uuid(self.params.credential_key_or_id) and "id" or "key"] = self.params.credential_key_or_id, + consumer_id = self.params.consumer_id, + } + self.params.credential_key_or_id = nil + + local credentials, err = dao_factory.jwt_secrets:find_all(filter_keys) if err then return helpers.yield_error(err) - elseif self.jwt_secret == nil then + elseif next(credentials) == nil then return helpers.responses.send_HTTP_NOT_FOUND() end + + self.jwt_secret = credentials[1] end, GET = function(self, dao_factory, helpers) diff --git a/kong/plugins/key-auth/api.lua b/kong/plugins/key-auth/api.lua index 165a31cabfa8..ec49504ee304 100644 --- a/kong/plugins/key-auth/api.lua +++ b/kong/plugins/key-auth/api.lua @@ -1,4 +1,5 @@ local crud = require "kong.api.crud_helpers" +local utils = require "kong.tools.utils" return { ["/consumers/:username_or_id/key-auth/"] = { @@ -19,15 +20,18 @@ return { crud.post(self.params, dao_factory.keyauth_credentials) end }, - ["/consumers/:username_or_id/key-auth/:id"] = { + ["/consumers/:username_or_id/key-auth/:credential_key_or_id"] = { before = function(self, dao_factory, helpers) crud.find_consumer_by_username_or_id(self, dao_factory, helpers) self.params.consumer_id = self.consumer.id - local credentials, err = dao_factory.keyauth_credentials:find_all { + local filter_keys = { + [utils.is_valid_uuid(self.params.credential_key_or_id) and "id" or "key"] = self.params.credential_key_or_id, consumer_id = self.params.consumer_id, - id = self.params.id } + self.params.credential_key_or_id = nil + + local credentials, err = dao_factory.keyauth_credentials:find_all(filter_keys) if err then return helpers.yield_error(err) elseif next(credentials) == nil then diff --git a/kong/plugins/oauth2/api.lua b/kong/plugins/oauth2/api.lua index c197e317e632..6b6999bbd788 100644 --- a/kong/plugins/oauth2/api.lua +++ b/kong/plugins/oauth2/api.lua @@ -1,4 +1,5 @@ local crud = require "kong.api.crud_helpers" +local utils = require "kong.tools.utils" return { ["/oauth2_tokens/"] = { @@ -15,21 +16,34 @@ return { end }, - ["/oauth2_tokens/:id"] = { - GET = function(self, dao_factory) - crud.get(self.params, dao_factory.oauth2_tokens) + ["/oauth2_tokens/:token_or_id"] = { + before = function(self, dao_factory, helpers) + local filter_keys = { + [utils.is_valid_uuid(self.params.token_or_id) and "id" or "access_token"] = self.params.token_or_id, + consumer_id = self.params.consumer_id, + } + self.params.token_or_id = nil + + local credentials, err = dao_factory.oauth2_tokens:find_all(filter_keys) + if err then + return helpers.yield_error(err) + elseif next(credentials) == nil then + return helpers.responses.send_HTTP_NOT_FOUND() + end + + self.oauth2_token = credentials[1] end, - PATCH = function(self, dao_factory) - crud.patch(self.params, dao_factory.oauth2_tokens, self.params) + GET = function(self, dao_factory, helpers) + return helpers.responses.send_HTTP_OK(self.oauth2_token) end, - PUT = function(self, dao_factory) - crud.put(self.params, dao_factory.oauth2_tokens) + PATCH = function(self, dao_factory) + crud.patch(self.params, dao_factory.oauth2_tokens, self.oauth2_token) end, DELETE = function(self, dao_factory) - crud.delete(self.params, dao_factory.oauth2_tokens) + crud.delete(self.oauth2_token, dao_factory.oauth2_tokens) end }, @@ -58,15 +72,18 @@ return { end }, - ["/consumers/:username_or_id/oauth2/:id"] = { + ["/consumers/:username_or_id/oauth2/:clientid_or_id"] = { before = function(self, dao_factory, helpers) crud.find_consumer_by_username_or_id(self, dao_factory, helpers) self.params.consumer_id = self.consumer.id - local credentials, err = dao_factory.oauth2_credentials:find_all { + local filter_keys = { + [utils.is_valid_uuid(self.params.clientid_or_id) and "id" or "client_id"] = self.params.clientid_or_id, consumer_id = self.params.consumer_id, - id = self.params.id } + self.params.clientid_or_id = nil + + local credentials, err = dao_factory.oauth2_credentials:find_all(filter_keys) if err then return helpers.yield_error(err) elseif next(credentials) == nil then @@ -76,16 +93,16 @@ return { self.oauth2_credential = credentials[1] end, - GET = function(self, dao_factory) - crud.get(self.params, dao_factory.oauth2_credentials) + GET = function(self, dao_factory, helpers) + return helpers.responses.send_HTTP_OK(self.oauth2_credential) end, PATCH = function(self, dao_factory) - crud.patch(self.params, dao_factory.oauth2_credentials, self.params) + crud.patch(self.params, dao_factory.oauth2_credentials, self.oauth2_credential) end, DELETE = function(self, dao_factory) - crud.delete(self.params, dao_factory.oauth2_credentials) + crud.delete(self.oauth2_credential, dao_factory.oauth2_credentials) end } } diff --git a/spec/03-plugins/01-basic-auth/02-api_spec.lua b/spec/03-plugins/01-basic-auth/02-api_spec.lua index 862e2700e5d7..3184fc071daf 100644 --- a/spec/03-plugins/01-basic-auth/02-api_spec.lua +++ b/spec/03-plugins/01-basic-auth/02-api_spec.lua @@ -161,6 +161,15 @@ describe("Plugin: basic-auth (API)", function() local json = cjson.decode(body) assert.equal(credential.id, json.id) end) + it("retrieves basic-auth credential by username", function() + local res = assert(admin_client:send { + method = "GET", + path = "/consumers/bob/basic-auth/"..credential.username + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.equal(credential.id, json.id) + end) it("retrieves credential by id only if the credential belongs to the specified consumer", function() assert(helpers.dao.consumers:insert { username = "alice" @@ -181,7 +190,7 @@ describe("Plugin: basic-auth (API)", function() end) describe("PATCH", function() - it("updates a credential", function() + it("updates a credential by id", function() local previous_hash = credential.password local res = assert(admin_client:send { @@ -198,6 +207,23 @@ describe("Plugin: basic-auth (API)", function() local json = cjson.decode(body) assert.not_equal(previous_hash, json.password) end) + it("updates a credential by username", function() + local previous_hash = credential.password + + local res = assert(admin_client:send { + method = "PATCH", + path = "/consumers/bob/basic-auth/"..credential.username, + body = { + password = "upd4321" + }, + headers = { + ["Content-Type"] = "application/json" + } + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.not_equal(previous_hash, json.password) + end) describe("errors", function() it("handles invalid input", function() local res = assert(admin_client:send { @@ -225,12 +251,12 @@ describe("Plugin: basic-auth (API)", function() assert.res_status(204, res) end) describe("errors", function() - it("returns 400 on invalid input", function() + it("returns 404 on missing username", function() local res = assert(admin_client:send { method = "DELETE", path = "/consumers/bob/basic-auth/blah" }) - assert.res_status(400, res) + assert.res_status(404, res) end) it("returns 404 if not found", function() local res = assert(admin_client:send { diff --git a/spec/03-plugins/02-key-auth/01-api_spec.lua b/spec/03-plugins/02-key-auth/01-api_spec.lua index e63ff7b0546c..0e291795289f 100644 --- a/spec/03-plugins/02-key-auth/01-api_spec.lua +++ b/spec/03-plugins/02-key-auth/01-api_spec.lua @@ -151,6 +151,15 @@ describe("Plugin: key-auth (API)", function() local json = cjson.decode(body) assert.equal(credential.id, json.id) end) + it("retrieves key-auth credential by key", function() + local res = assert(admin_client:send { + method = "GET", + path = "/consumers/bob/key-auth/"..credential.key + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.equal(credential.id, json.id) + end) it("retrieves credential by id only if the credential belongs to the specified consumer", function() assert(helpers.dao.consumers:insert { username = "alice" @@ -171,7 +180,7 @@ describe("Plugin: key-auth (API)", function() end) describe("PATCH", function() - it("updates a credential", function() + it("updates a credential by id", function() local res = assert(admin_client:send { method = "PATCH", path = "/consumers/bob/key-auth/"..credential.id, @@ -186,6 +195,21 @@ describe("Plugin: key-auth (API)", function() local json = cjson.decode(body) assert.equal("4321", json.key) end) + it("updates a credential by key", function() + local res = assert(admin_client:send { + method = "PATCH", + path = "/consumers/bob/key-auth/"..credential.key, + body = { + key = "4321UPD" + }, + headers = { + ["Content-Type"] = "application/json" + } + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.equal("4321UPD", json.key) + end) describe("errors", function() it("handles invalid input", function() local res = assert(admin_client:send { @@ -218,7 +242,7 @@ describe("Plugin: key-auth (API)", function() method = "DELETE", path = "/consumers/bob/key-auth/blah" }) - assert.res_status(400, res) + assert.res_status(404, res) end) it("returns 404 if not found", function() local res = assert(admin_client:send { diff --git a/spec/03-plugins/06-jwt/02-api_spec.lua b/spec/03-plugins/06-jwt/02-api_spec.lua index 3eea07d4d7a1..01ead396f938 100644 --- a/spec/03-plugins/06-jwt/02-api_spec.lua +++ b/spec/03-plugins/06-jwt/02-api_spec.lua @@ -194,10 +194,17 @@ describe("Plugin: jwt (API)", function() }) assert.res_status(200, res) end) + it("retrieves by key", function() + local res = assert(admin_client:send { + method = "GET", + path = "/consumers/bob/jwt/"..jwt_secret.key, + }) + assert.res_status(200, res) + end) end) describe("PATCH", function() - it("updates a credential", function() + it("updates a credential by id", function() local res = assert(admin_client:send { method = "PATCH", path = "/consumers/bob/jwt/"..jwt_secret.id, @@ -213,6 +220,22 @@ describe("Plugin: jwt (API)", function() jwt_secret = cjson.decode(body) assert.equal("newsecret", jwt_secret.secret) end) + it("updates a credential by key", function() + local res = assert(admin_client:send { + method = "PATCH", + path = "/consumers/bob/jwt/"..jwt_secret.key, + body = { + key = "alice", + secret = "newsecret2" + }, + headers = { + ["Content-Type"] = "application/json" + } + }) + local body = assert.res_status(200, res) + jwt_secret = cjson.decode(body) + assert.equal("newsecret2", jwt_secret.secret) + end) end) describe("DELETE", function() @@ -236,7 +259,7 @@ describe("Plugin: jwt (API)", function() ["Content-Type"] = "application/json" } }) - assert.res_status(400, res) + assert.res_status(404, res) local res = assert(admin_client:send { method = "DELETE", diff --git a/spec/03-plugins/06-jwt/03-access_spec.lua b/spec/03-plugins/06-jwt/03-access_spec.lua index 4a59fb74168a..7abd69587fc8 100644 --- a/spec/03-plugins/06-jwt/03-access_spec.lua +++ b/spec/03-plugins/06-jwt/03-access_spec.lua @@ -165,9 +165,9 @@ describe("Plugin: jwt (access)", function() PAYLOAD.iss = base64_jwt_secret.key local original_secret = base64_jwt_secret.secret local base64_secret = ngx.encode_base64(base64_jwt_secret.secret) - assert(admin_client:send { + local res = assert(admin_client:send { method = "PATCH", - path = "/consumers/jwt_tests_consumer/jwt/"..base64_jwt_secret.id, + path = "/consumers/jwt_tests_base64_consumer/jwt/"..base64_jwt_secret.id, body = { key = base64_jwt_secret.key, secret = base64_secret}, @@ -175,6 +175,7 @@ describe("Plugin: jwt (access)", function() ["Content-Type"] = "application/json" } }) + assert.res_status(200, res) local jwt = jwt_encoder.encode(PAYLOAD, original_secret) local authorization = "Bearer "..jwt @@ -188,7 +189,7 @@ describe("Plugin: jwt (access)", function() }) local body = cjson.decode(assert.res_status(200, res)) assert.equal(authorization, body.headers.authorization) - assert.equal("jwt_tests_consumer", body.headers["x-consumer-username"]) + assert.equal("jwt_tests_base64_consumer", body.headers["x-consumer-username"]) end) it("finds the JWT if given in URL parameters", function() PAYLOAD.iss = jwt_secret.key diff --git a/spec/03-plugins/09-hmac-auth/02-api_spec.lua b/spec/03-plugins/09-hmac-auth/02-api_spec.lua index f7e9528c815e..fa7546eca744 100644 --- a/spec/03-plugins/09-hmac-auth/02-api_spec.lua +++ b/spec/03-plugins/09-hmac-auth/02-api_spec.lua @@ -106,10 +106,21 @@ describe("Plugin: hmac-auth (API)", function() local body = cjson.decode(body_json) assert.equals(credential.id, body.id) end) + it("should retrieve by username", function() + local res = assert(client:send { + method = "GET", + path = "/consumers/bob/hmac-auth/"..credential.username, + body = {}, + headers = {["Content-Type"] = "application/json"} + }) + local body_json = assert.res_status(200, res) + local body = cjson.decode(body_json) + assert.equals(credential.id, body.id) + end) end) describe("PATCH", function() - it("[SUCCESS] should update a credential", function() + it("[SUCCESS] should update a credential by id", function() local res = assert(client:send { method = "PATCH", path = "/consumers/bob/hmac-auth/"..credential.id, @@ -120,6 +131,17 @@ describe("Plugin: hmac-auth (API)", function() credential = cjson.decode(body_json) assert.equals("alice", credential.username) end) + it("[SUCCESS] should update a credential by username", function() + local res = assert(client:send { + method = "PATCH", + path = "/consumers/bob/hmac-auth/"..credential.username, + body = {username = "aliceUPD"}, + headers = {["Content-Type"] = "application/json"} + }) + local body_json = assert.res_status(200, res) + credential = cjson.decode(body_json) + assert.equals("aliceUPD", credential.username) + end) it("[FAILURE] should return proper errors", function() local res = assert(client:send { method = "PATCH", @@ -136,11 +158,11 @@ describe("Plugin: hmac-auth (API)", function() it("[FAILURE] should return proper errors", function() local res = assert(client:send { method = "DELETE", - path = "/consumers/bob/hmac-auth/alice", + path = "/consumers/bob/hmac-auth/aliceasd", body = {}, headers = {["Content-Type"] = "application/json"} }) - assert.res_status(400, res) + assert.res_status(404, res) local res = assert(client:send { method = "DELETE", diff --git a/spec/03-plugins/99-oauth2/02-api_spec.lua b/spec/03-plugins/99-oauth2/02-api_spec.lua index fd30e068a9ea..c963b0804be4 100644 --- a/spec/03-plugins/99-oauth2/02-api_spec.lua +++ b/spec/03-plugins/99-oauth2/02-api_spec.lua @@ -214,6 +214,15 @@ describe("Plugin: oauth (API)", function() local json = cjson.decode(body) assert.equal(credential.id, json.id) end) + it("retrieves oauth2 credential by client id", function() + local res = assert(admin_client:send { + method = "GET", + path = "/consumers/bob/oauth2/"..credential.client_id + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.equal(credential.id, json.id) + end) it("retrieves credential by id only if the credential belongs to the specified consumer", function() assert(helpers.dao.consumers:insert { username = "alice" @@ -231,10 +240,23 @@ describe("Plugin: oauth (API)", function() }) assert.res_status(404, res) end) + it("retrieves credential by clientid only if the credential belongs to the specified consumer", function() + local res = assert(admin_client:send { + method = "GET", + path = "/consumers/bob/oauth2/"..credential.client_id + }) + assert.res_status(200, res) + + res = assert(admin_client:send { + method = "GET", + path = "/consumers/alice/oauth2/"..credential.client_id + }) + assert.res_status(404, res) + end) end) describe("PATCH", function() - it("updates a credential", function() + it("updates a credential by id", function() local previous_name = credential.name local res = assert(admin_client:send { @@ -251,6 +273,23 @@ describe("Plugin: oauth (API)", function() local json = cjson.decode(body) assert.not_equal(previous_name, json.name) end) + it("updates a credential by client id", function() + local previous_name = credential.name + + local res = assert(admin_client:send { + method = "PATCH", + path = "/consumers/bob/oauth2/"..credential.client_id, + body = { + name = "4321UDP" + }, + headers = { + ["Content-Type"] = "application/json" + } + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.not_equal(previous_name, json.name) + end) describe("errors", function() it("handles invalid input", function() local res = assert(admin_client:send { @@ -283,7 +322,7 @@ describe("Plugin: oauth (API)", function() method = "DELETE", path = "/consumers/bob/oauth2/blah" }) - assert.res_status(400, res) + assert.res_status(404, res) end) it("returns 404 if not found", function() local res = assert(admin_client:send { @@ -426,35 +465,43 @@ describe("Plugin: oauth (API)", function() local json = cjson.decode(body) assert.equal(token.id, json.id) end) + it("retrieves oauth2 token by access_token", function() + local res = assert(admin_client:send { + method = "GET", + path = "/oauth2_tokens/"..token.access_token + }) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.equal(token.id, json.id) + end) end) - describe("PUT", function() - it("should update every field", function() - token.access_token = "helloworld" - token.refresh_token = nil + describe("PATCH", function() + it("updates a token by id", function() + local previous_expires_in = token.expires_in + local res = assert(admin_client:send { - method = "PUT", + method = "PATCH", path = "/oauth2_tokens/"..token.id, - body = token, + body = { + expires_in = 20 + }, headers = { ["Content-Type"] = "application/json" } }) - local body = cjson.decode(assert.res_status(200, res)) - assert.is_nil(body.refresh_token) - assert.equal("helloworld", body.access_token) + local body = assert.res_status(200, res) + local json = cjson.decode(body) + assert.not_equal(previous_expires_in, json.expires_in) end) - end) - - describe("PATCH", function() - it("updates a token", function() + it("updates a token by access_token", function() local previous_expires_in = token.expires_in local res = assert(admin_client:send { method = "PATCH", - path = "/oauth2_tokens/"..token.id, + path = "/oauth2_tokens/"..token.access_token, body = { - expires_in = 20 + expires_in = 400 }, headers = { ["Content-Type"] = "application/json" @@ -496,7 +543,7 @@ describe("Plugin: oauth (API)", function() method = "DELETE", path = "/oauth2_tokens/blah" }) - assert.res_status(400, res) + assert.res_status(404, res) end) it("returns 404 if not found", function() local res = assert(admin_client:send {