diff --git a/kong/plugins/request-transformer/access.lua b/kong/plugins/request-transformer/access.lua index f33766b5172c..92382e230b85 100644 --- a/kong/plugins/request-transformer/access.lua +++ b/kong/plugins/request-transformer/access.lua @@ -1,165 +1,242 @@ -local utils = require "kong.tools.utils" local stringy = require "stringy" -local Multipart = require "multipart" -local cjson = require "cjson" +local multipart = require "multipart" + +local table_insert = table.insert +local req_set_uri_args = ngx.req.set_uri_args +local req_get_uri_args = ngx.req.get_uri_args +local req_set_header = ngx.req.set_header +local req_get_headers = ngx.req.get_headers +local req_read_body = ngx.req.read_body +local req_set_body_data = ngx.req.set_body_data +local req_get_body_data = ngx.req.get_body_data +local req_clear_header = ngx.req.clear_header +local req_get_post_args = ngx.req.get_post_args +local encode_args = ngx.encode_args +local type = type +local string_len = string.len + +local unpack = unpack local _M = {} -local APPLICATION_JSON = "application/json" local CONTENT_LENGTH = "content-length" local FORM_URLENCODED = "application/x-www-form-urlencoded" local MULTIPART_DATA = "multipart/form-data" local CONTENT_TYPE = "content-type" local HOST = "host" -local function iterate_and_exec(val, cb) - if utils.table_size(val) > 0 then - for _, entry in ipairs(val) do - local parts = stringy.split(entry, ":") - cb(parts[1], utils.table_size(parts) == 2 and parts[2] or nil) + +local function iter(config_array) + return function(config_array, i, previous_name, previous_value) + i = i + 1 + local current_pair = config_array[i] + if current_pair == nil then -- n + 1 + return nil end - end + local current_name, current_value = unpack(stringy.split(current_pair, ":")) + return i, current_name, current_value + end, config_array, 0 end local function get_content_type() - local header_value = ngx.req.get_headers()[CONTENT_TYPE] + local header_value = req_get_headers()[CONTENT_TYPE] if header_value then return stringy.strip(header_value):lower() end end -function _M.execute(conf) - if conf.add then - - -- Add headers - if conf.add.headers then - iterate_and_exec(conf.add.headers, function(name, value) - ngx.req.set_header(name, value) - if name:lower() == HOST then -- Host header has a special treatment - ngx.var.upstream_host = value - end - end) - end +local function append_value(current_value, value) + local current_value_type = type(current_value) - -- Add Querystring - if conf.add.querystring then - local querystring = ngx.req.get_uri_args() - iterate_and_exec(conf.add.querystring, function(name, value) - querystring[name] = value - end) - ngx.req.set_uri_args(querystring) - end + if current_value_type == "string" then + return { current_value, value } + elseif current_value_type == "table" then + table_insert(current_value, value) + return current_value + else + return { value } + end +end - if conf.add.form then - local content_type = get_content_type() - if content_type and stringy.startswith(content_type, FORM_URLENCODED) then - -- Call ngx.req.read_body to read the request body first - ngx.req.read_body() +local function transform_headers(conf) + -- Remove header(s) + for _, name, value in iter(conf.remove.headers) do + req_clear_header(name) + end - local parameters = ngx.req.get_post_args() - iterate_and_exec(conf.add.form, function(name, value) - parameters[name] = value - end) - local encoded_args = ngx.encode_args(parameters) - ngx.req.set_header(CONTENT_LENGTH, string.len(encoded_args)) - ngx.req.set_body_data(encoded_args) - elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then - -- Call ngx.req.read_body to read the request body first - ngx.req.read_body() - - local body = ngx.req.get_body_data() - local parameters = Multipart(body and body or "", content_type) - iterate_and_exec(conf.add.form, function(name, value) - parameters:set_simple(name, value) - end) - local new_data = parameters:tostring() - ngx.req.set_header(CONTENT_LENGTH, string.len(new_data)) - ngx.req.set_body_data(new_data) + -- Replace header(s) + for _, name, value in iter(conf.replace.headers) do + if req_get_headers()[name] then + req_set_header(name, value) + if name:lower() == HOST then -- Host header has a special treatment + ngx.var.backend_host = value end end + end + -- Add header(s) + for _, name, value in iter(conf.add.headers) do + if not req_get_headers()[name] then + req_set_header(name, value) + if name:lower() == HOST then -- Host header has a special treatment + ngx.var.backend_host = value + end + end end - if conf.add.json then - local content_type = get_content_type() - if content_type and stringy.startswith(get_content_type(), APPLICATION_JSON) then - ngx.req.read_body() - local parameters = cjson.decode(ngx.req.get_body_data()) - - iterate_and_exec(conf.add.json, function(name, value) - local v = cjson.encode(value) - if stringy.startswith(v, "\"") and stringy.endswith(v, "\"") then - v = v:sub(2, v:len() - 1):gsub("\\\"", "\"") -- To prevent having double encoded quotes - end - parameters[name] = v - end) + -- Append header(s) + for _, name, value in iter(conf.append.headers) do + req_set_header(name, append_value(req_get_headers()[name], value)) + if name:lower() == HOST then -- Host header has a special treatment + ngx.var.backend_host = value + end + end +end - local new_data = cjson.encode(parameters) - ngx.req.set_header(CONTENT_LENGTH, string.len(new_data)) - ngx.req.set_body_data(new_data) +local function transform_querystrings(conf) + -- Remove querystring(s) + if conf.remove.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.remove.querystring) do + querystring[name] = nil end + req_set_uri_args(querystring) end - if conf.remove then + -- Replace querystring(s) + if conf.replace.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.replace.querystring) do + if querystring[name] then + querystring[name] = value + end + end + req_set_uri_args(querystring) + end - -- Remove headers - if conf.remove.headers then - iterate_and_exec(conf.remove.headers, function(name, value) - ngx.req.clear_header(name) - end) + -- Add querystring(s) + if conf.add.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.add.querystring) do + if not querystring[name] then + querystring[name] = value + end end + req_set_uri_args(querystring) + end - if conf.remove.querystring then - local querystring = ngx.req.get_uri_args() - iterate_and_exec(conf.remove.querystring, function(name) - querystring[name] = nil - end) - ngx.req.set_uri_args(querystring) + -- Append querystring(s) + if conf.append.querystring then + local querystring = req_get_uri_args() + for _, name, value in iter(conf.append.querystring) do + querystring[name] = append_value(querystring[name], value) end + req_set_uri_args(querystring) + end +end - if conf.remove.form then - local content_type = get_content_type() - if content_type and stringy.startswith(content_type, FORM_URLENCODED) then - local parameters = ngx.req.get_post_args() - - iterate_and_exec(conf.remove.form, function(name) - parameters[name] = nil - end) - - local encoded_args = ngx.encode_args(parameters) - ngx.req.set_header(CONTENT_LENGTH, string.len(encoded_args)) - ngx.req.set_body_data(encoded_args) - elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then - -- Call ngx.req.read_body to read the request body first - ngx.req.read_body() - - local body = ngx.req.get_body_data() - local parameters = Multipart(body and body or "", content_type) - iterate_and_exec(conf.remove.form, function(name) - parameters:delete(name) - end) - local new_data = parameters:tostring() - ngx.req.set_header(CONTENT_LENGTH, string.len(new_data)) - ngx.req.set_body_data(new_data) +local function transform_form_params(conf) + -- Remove form parameter(s) + if conf.remove.form then + local content_type = get_content_type() + if content_type and stringy.startswith(content_type, FORM_URLENCODED) then + req_read_body() + local parameters = req_get_post_args() + + for _, name, value in iter(conf.remove.form) do + parameters[name] = nil end + + local encoded_args = encode_args(parameters) + req_set_header(CONTENT_LENGTH, string_len(encoded_args)) + req_set_body_data(encoded_args) + elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then + -- Call req_read_body to read the request body first + req_read_body() + + local body = req_get_body_data() + local parameters = multipart(body and body or "", content_type) + for _, name, value in iter(conf.remove.form) do + parameters:delete(name) + end + local new_data = parameters:tostring() + req_set_header(CONTENT_LENGTH, string_len(new_data)) + req_set_body_data(new_data) end + end - if conf.remove.json then - local content_type = get_content_type() - if content_type and stringy.startswith(get_content_type(), APPLICATION_JSON) then - ngx.req.read_body() - local parameters = cjson.decode(ngx.req.get_body_data()) + -- Replace form parameter(s) + if conf.replace.form then + local content_type = get_content_type() + if content_type and stringy.startswith(content_type, FORM_URLENCODED) then + -- Call req_read_body to read the request body first + req_read_body() - iterate_and_exec(conf.remove.json, function(name) - parameters[name] = nil - end) + local parameters = req_get_post_args() + for _, name, value in iter(conf.replace.form) do + if parameters[name] then + parameters[name] = value + end + end + local encoded_args = encode_args(parameters) + req_set_header(CONTENT_LENGTH, string_len(encoded_args)) + req_set_body_data(encoded_args) + elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then + -- Call req_read_body to read the request body first + req_read_body() + + local body = req_get_body_data() + local parameters = multipart(body and body or "", content_type) + for _, name, value in iter(conf.replace.form) do + if parameters:get(name) then + parameters:delete(name) + parameters:set_simple(name, value) + end + end + local new_data = parameters:tostring() + req_set_header(CONTENT_LENGTH, string_len(new_data)) + req_set_body_data(new_data) + end + end - local new_data = cjson.encode(parameters) - ngx.req.set_header(CONTENT_LENGTH, string.len(new_data)) - ngx.req.set_body_data(new_data) + -- Add form parameter(s) + if conf.add.form then + local content_type = get_content_type() + if content_type and stringy.startswith(content_type, FORM_URLENCODED) then + -- Call req_read_body to read the request body first + req_read_body() + + local parameters = req_get_post_args() + for _, name, value in iter(conf.add.form) do + if not parameters[name] then + parameters[name] = value + end end + local encoded_args = encode_args(parameters) + req_set_header(CONTENT_LENGTH, string_len(encoded_args)) + req_set_body_data(encoded_args) + elseif content_type and stringy.startswith(content_type, MULTIPART_DATA) then + -- Call req_read_body to read the request body first + req_read_body() + + local body = req_get_body_data() + local parameters = multipart(body and body or "", content_type) + for _, name, value in iter(conf.add.form) do + if not parameters:get(name) then + parameters:set_simple(name, value) + end + end + local new_data = parameters:tostring() + req_set_header(CONTENT_LENGTH, string_len(new_data)) + req_set_body_data(new_data) end end end +function _M.execute(conf) + transform_form_params(conf) + transform_headers(conf) + transform_querystrings(conf) +end + return _M diff --git a/kong/plugins/request-transformer/schema.lua b/kong/plugins/request-transformer/schema.lua index e15f4b8275fc..87bc2439d4e4 100644 --- a/kong/plugins/request-transformer/schema.lua +++ b/kong/plugins/request-transformer/schema.lua @@ -1,22 +1,41 @@ return { fields = { - add = { type = "table", - schema = { - fields = { - form = { type = "array" }, - headers = { type = "array" }, - querystring = { type = "array" }, - json = { type = "array" } + remove = { + type = "table", + schema = { + fields = { + form = {type = "array", default = {}}, + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} + } + } + }, + replace = { + type = "table", + schema = { + fields = { + form = {type = "array", default = {}}, + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} + } + } + }, + add = { + type = "table", + schema = { + fields = { + form = {type = "array", default = {}}, + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} + } } - } }, - remove = { type = "table", + append = { + type = "table", schema = { fields = { - form = { type = "array" }, - headers = { type = "array" }, - querystring = { type = "array" }, - json = { type = "array" } + headers = {type = "array", default = {}}, + querystring = {type = "array", default = {}} } } } diff --git a/spec/plugins/request-transformer/access_spec.lua b/spec/plugins/request-transformer/access_spec.lua index a142dd8dc837..7cbb834bc255 100644 --- a/spec/plugins/request-transformer/access_spec.lua +++ b/spec/plugins/request-transformer/access_spec.lua @@ -6,29 +6,25 @@ local STUB_GET_URL = spec_helper.STUB_GET_URL local STUB_POST_URL = spec_helper.STUB_POST_URL describe("Request Transformer", function() - setup(function() spec_helper.prepare_db() spec_helper.insert_fixtures { api = { {name = "tests-request-transformer-1", request_host = "test1.com", upstream_url = "http://mockbin.com"}, - {name = "tests-request-transformer-2", request_host = "test2.com", upstream_url = "http://httpbin.org"} + {name = "tests-request-transformer-2", request_host = "test2.com", upstream_url = "http://httpbin.org"}, + {name = "tests-request-transformer-3", request_host = "test3.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-4", request_host = "test4.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-5", request_host = "test5.com", upstream_url = "http://mockbin.com"}, + {name = "tests-request-transformer-6", request_host = "test6.com", upstream_url = "http://mockbin.com"}, }, plugin = { { name = "request-transformer", config = { add = { - headers = {"x-added:true", "x-added2:true" }, - querystring = {"newparam:value"}, - form = {"newformparam:newvalue"}, - json = {"newjsonparam:newvalue"} - }, - remove = { - headers = {"x-to-remove"}, - querystring = {"toremovequery"}, - form = {"toremoveform"}, - json = {"toremovejson"} + headers = {"h1:v1", "h2:v2"}, + querystring = {"q1:v1"}, + form = {"p1:v1"} } }, __api = 1 @@ -41,115 +37,331 @@ describe("Request Transformer", function() } }, __api = 2 + }, + { + name = "request-transformer", + config = { + add = { + headers = {"x-added:a1", "x-added2:b1", "x-added3:c2"}, + querystring = {"query-added:newvalue", "p1:a1"}, + form = {"newformparam:newvalue"} + }, + remove = { + headers = {"x-to-remove"}, + querystring = {"toremovequery"} + }, + append = { + headers = {"x-added:a2", "x-added:a3"}, + querystring = {"p1:a2", "p2:b1"} + }, + replace = { + headers = {"x-to-replace:false"}, + querystring = {"toreplacequery:no"} + } + }, + __api = 3 + }, + { + name = "request-transformer", + config = { + remove = { + headers = {"x-to-remove"}, + querystring = {"q1"}, + form = {"toremoveform"} + } + }, + __api = 4 + }, + { + name = "request-transformer", + config = { + replace = { + headers = {"h1:v1"}, + querystring = {"q1:v1"}, + form = {"p1:v1"} + } + }, + __api = 5 + }, + { + name = "request-transformer", + config = { + append = { + headers = {"h1:v1", "h1:v2", "h2:v1",}, + querystring = {"q1:v1", "q1:v2", "q2:v1"} + } + }, + __api = 6 } - } + }, } - spec_helper.start_kong() end) teardown(function() spec_helper.stop_kong() end) - - describe("Test adding parameters", function() + + describe("Test remove", function() + it("should remove specified header", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test4.com", ["x-to-remove"] = "true", ["x-another-header"] = "true"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.headers["x-to-remove"]) + assert.equal("true", body.headers["x-another-header"]) + end) + it("should remove parameters on POST", function() + local response, status = http_client.post(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test4.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["toremoveform"]) + assert.equal("yes", body.postData.params["nottoremove"]) + end) + it("should remove parameters on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test4.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["toremoveform"]) + assert.equal("yes", body.postData.params["nottoremove"]) + end) + it("should remove queryString on GET if it exist", function() + local response, status = http_client.post(STUB_POST_URL.."/?q1=v1&q2=v2", { hello = "world"}, {host = "test4.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.queryString["q1"]) + assert.equal("v2", body.queryString["q2"]) + end) + end) + + describe("Test replace", function() + it("should replace specified header if it exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test5.com", ["h1"] = "V", ["h2"] = "v2"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) + end) + it("should not add as new header if header does not exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test5.com", ["h2"] = "v2"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) + end) + it("should replace specified parameters on POST", function() + local response, status = http_client.post(STUB_POST_URL, {["p1"] = "v", ["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should not add as new parameter if parameter does not exist on POST", function() + local response, status = http_client.post(STUB_POST_URL, {["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should replace specified parameters on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {["p1"] = "v", ["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should not add as new parameter if parameter does not exist on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {["p2"] = "v1"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.postData.params["p1"]) + assert.equal("v1", body.postData.params["p2"]) + end) + it("should replace queryString on POST if it exist", function() + local response, status = http_client.post(STUB_POST_URL.."/?q1=v&q2=v2", { hello = "world"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.queryString["q1"]) + assert.equal("v2", body.queryString["q2"]) + end) + it("should not add new queryString on POST if it does not exist", function() + local response, status = http_client.post(STUB_POST_URL.."/?q2=v2", { hello = "world"}, {host = "test5.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.queryString["q1"]) + assert.equal("v2", body.queryString["q2"]) + end) + end) + + describe("Test add", function() it("should add new headers", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("true", body.headers["x-added"]) - assert.are.equal("true", body.headers["x-added2"]) + assert.equal(200, status) + assert.equal("v1", body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) end) - it("should add new parameters on POST", function() - local response, status = http_client.post(STUB_POST_URL, {}, {host = "test1.com"}) + it("should not change or append value if header already exists", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com", h1 = "v3"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("v3", body.headers["h1"]) + assert.equal("v2", body.headers["h2"]) end) - it("should add new parameters on POST when existing params exist", function() - local response, status = http_client.post(STUB_POST_URL, {hello = "world"}, {host = "test1.com"}) + it("should add new parameter on POST", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("world", body.postData.params["hello"]) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("world", body.postData.params["hello"]) + assert.equal("v1", body.postData.params["p1"]) end) - it("should add new parameters on multipart POST", function() - local response, status = http_client.post_multipart(STUB_POST_URL, {}, {host = "test1.com"}) + it("should not change or append value to parameter on POST when parameter exists", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("world", body.postData.params["hello"]) + assert.equal("v1", body.postData.params["p1"]) end) - it("should add new parameters on multipart POST when existing params exist", function() - local response, status = http_client.post_multipart(STUB_POST_URL, {hello = "world"}, {host = "test1.com"}) + it("should add new parameter on multipart POST", function() + local response, status = http_client.post_multipart(STUB_POST_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("world", body.postData.params["hello"]) - assert.are.equal("newvalue", body.postData.params["newformparam"]) + assert.equal(200, status) + assert.equal("v1", body.postData.params["p1"]) end) - it("should add new paramters on json POST", function() - local response, status = http_client.post(STUB_POST_URL, {}, {host = "test1.com", ["content-type"] = "application/json"}) - local raw = cjson.decode(response) - local body = cjson.decode(raw.postData.text) - assert.are.equal(200, status) - assert.are.equal("newvalue", body["newjsonparam"]) - end) - it("should add new paramters on json POST when existing params exist", function() - local response, status = http_client.post(STUB_POST_URL, {hello = "world"}, {host = "test1.com", ["content-type"] = "application/json"}) - local raw = cjson.decode(response) - local body = cjson.decode(raw.postData.text) - assert.are.equal(200, status) - assert.are.equal("world", body["hello"]) - assert.are.equal("newvalue", body["newjsonparam"]) + it("should not change or append value to parameter on multipart POST when parameter exists", function() + local response, status = http_client.post_multipart(STUB_POST_URL, { hello = "world"}, {host = "test1.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("world", body.postData.params["hello"]) + assert.equal("v1", body.postData.params["p1"]) end) - it("should add new parameters on GET", function() + it("should add new querystring on GET", function() local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("value", body.queryString["newparam"]) + assert.equal(200, status) + assert.equal("v1", body.queryString["q1"]) end) - it("should change the host header", function() + it("should not change or append value to querystring on GET if querystring exists", function() + local response, status = http_client.get(STUB_GET_URL, {q1 = "v2"}, {host = "test1.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v2", body.queryString["q1"]) + end) + it("should not change the host header", function() local response, status = http_client.get(spec_helper.PROXY_URL.."/get", {}, {host = "test2.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.are.equal("mark", body.headers["Host"]) + assert.equal(200, status) + assert.equal("httpbin.org", body.headers["Host"]) + end) + end) + + describe("Test append ", function() + it("should add a new header if header does not exists", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.headers["h2"]) + end) + it("should append values to existing headers", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1, v2", body.headers["h1"]) + end) + it("should add new querystring if querystring does not exists", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("v1", body.queryString["q2"]) + end) + it("should append values to existing querystring", function() + local response, status = http_client.post(STUB_POST_URL, { hello = "world"}, {host = "test6.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.same({"v1", "v2"}, body.queryString["q1"]) end) end) - describe("Test removing parameters", function() + describe("Test for remove, replace, add and append ", function() it("should remove a header", function() - local response, status = http_client.get(STUB_GET_URL, {}, {host = "test1.com", ["x-to-remove"] = "true"}) + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com", ["x-to-remove"] = "true"}) local body = cjson.decode(response) - assert.are.equal(200, status) + assert.equal(200, status) assert.falsy(body.headers["x-to-remove"]) end) - it("should remove parameters on POST", function() - local response, status = http_client.post(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) + it("should replace value of header, if header exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com", ["x-to-replace"] = "true"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.falsy(body.postData.params["toremoveform"]) - assert.are.same("yes", body.postData.params["nottoremove"]) + assert.equal(200, status) + assert.equal("false", body.headers["x-to-replace"]) end) - it("should remove parameters on multipart POST", function() - local response, status = http_client.post_multipart(STUB_POST_URL, {["toremoveform"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) + it("should not add new header if to be replaced header does not exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) - assert.falsy(body.postData.params["toremoveform"]) - assert.are.same("yes", body.postData.params["nottoremove"]) + assert.equal(200, status) + assert.falsy(body.headers["x-to-replace"]) + end) + it("should add new header if missing", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("b1", body.headers["x-added2"]) end) - it("should remove parameters on json POST", function() - local response, status = http_client.post(STUB_POST_URL, {["toremovejson"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com", ["content-type"] = "application/json"}) - local raw = cjson.decode(response) - local body = cjson.decode(raw.postData.text) - assert.are.equal(200, status) - assert.falsy(body["toremovejson"]) - assert.are.same("yes", body["nottoremove"]) + it("should not add new header if it already exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com", ["x-added3"] = "c1"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("c1", body.headers["x-added3"]) + end) + it("should append values to existing headers", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("a1, a2, a3", body.headers["x-added"]) + end) + it("should add new parameters on POST when query string key missing", function() + local response, status = http_client.post(STUB_POST_URL, {hello = "world"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("b1", body.queryString["p2"]) end) it("should remove parameters on GET", function() - local response, status = http_client.get(STUB_GET_URL, {["toremovequery"] = "yes", ["nottoremove"] = "yes"}, {host = "test1.com"}) + local response, status = http_client.get(STUB_GET_URL, {["toremovequery"] = "yes", ["nottoremove"] = "yes"}, {host = "test3.com"}) local body = cjson.decode(response) - assert.are.equal(200, status) + assert.equal(200, status) assert.falsy(body.queryString["toremovequery"]) - assert.are.equal("yes", body.queryString["nottoremove"]) + assert.equal("yes", body.queryString["nottoremove"]) + end) + it("should replace parameters on GET", function() + local response, status = http_client.get(STUB_GET_URL, {["toreplacequery"] = "yes"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("no", body.queryString["toreplacequery"]) + end) + it("should not add new parameter if to be replaced parameters does not exist on GET", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.falsy(body.queryString["toreplacequery"]) + end) + it("should add parameters on GET if it does not exist", function() + local response, status = http_client.get(STUB_GET_URL, {}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("newvalue", body.queryString["query-added"]) + end) + it("should not add new parameter if to be added parameters already exist on GET", function() + local response, status = http_client.get(STUB_GET_URL, {["query-added"] = "oldvalue"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("oldvalue", body.queryString["query-added"]) + end) + it("should append parameters on GET", function() + local response, status = http_client.post(STUB_POST_URL.."/?q1=20", { hello = "world"}, {host = "test3.com"}) + local body = cjson.decode(response) + assert.equal(200, status) + assert.equal("a1", body.queryString["p1"][1]) + assert.equal("a2", body.queryString["p1"][2]) + assert.equal("20", body.queryString["q1"]) end) end) end)