Skip to content

Commit

Permalink
feat(ssl) new enabled property for SSL options.
Browse files Browse the repository at this point in the history
Allow to enable SSL without having to necessarily verify the server
certificate.
  • Loading branch information
thibaultcha committed Dec 14, 2015
1 parent 4a56056 commit b6fc8ea
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 44 deletions.
15 changes: 15 additions & 0 deletions spec/unit/options_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ describe("options parsing", function()
}))
assert.equal("socket read_timeout must be a number", err)
end)
it("should validate SSL options", function()
local err = select(2, parse_session {
shm = "test",
ssl_options = ""
})
assert.equal("ssl_options must be a table", err)

err = select(2, parse_session {
shm = "test",
ssl_options = {
enabled = ""
}
})
assert.equal("ssl_options.enabled must be a boolean", err)
end)
it("should set `prepared_shm` to `shm` if nil", function()
local options, err = parse_session {
shm = "test"
Expand Down
10 changes: 10 additions & 0 deletions spec/unit/utils_spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,15 @@ describe("table_utils", function()
target = table_utils.extend_table(source, target)
assert.False(target.source)
end)
it("should ignore targets that are not tables", function()
local source = {foo = {bar = "foobar"}}
local target = {foo = "hello"}

assert.has_no_error(function()
target = table_utils.extend_table(source, target)
end)

assert.equal("hello", target.foo)
end)
end)
end)
67 changes: 32 additions & 35 deletions src/cassandra.lua
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,14 @@ function Host:new(address, options)
local host, port = string_utils.split_by_colon(address)
if not port then port = options.protocol_options.default_port end

local h = {}

h.host = host
h.port = port
h.address = address
h.protocol_version = DEFAULT_PROTOCOL_VERSION

h.options = options
h.reconnection_policy = h.options.policies.reconnection

new_socket(h)
local h = {
host = host,
port = port,
address = address,
protocol_version = DEFAULT_PROTOCOL_VERSION,
options = options,
reconnection_policy = options.policy.reconnection
}

return setmetatable(h, Host)
end
Expand Down Expand Up @@ -242,6 +239,8 @@ end
function Host:connect()
if self.connected then return true end

new_socket(self)

log.debug("Connecting to "..self.address)

self:set_timeout(self.options.socket_options.connect_timeout)
Expand All @@ -252,7 +251,7 @@ function Host:connect()
return false, Errors.SocketError(self.address, err), true
end

if self.options.ssl_options ~= nil then
if self.options.ssl_options.enabled then
ok, err = do_ssl_handshake(self)
if not ok then
return false, Errors.SocketError(self.address, err)
Expand Down Expand Up @@ -391,34 +390,30 @@ function Host:close()
end

function Host:set_down()
local host_infos, err = cache.get_host(self.options.shm, self.address)
if err then
return err
local lock, lock_err, elapsed = lock_mutex(self.options.shm, "downing_"..self.address)
if lock_err then
return lock_err
end

if host_infos.unhealthy_at == 0 then
local lock, lock_err, elapsed = lock_mutex(self.options.shm, "downing_"..self.address)
if lock_err then
return lock_err
end

if elapsed and elapsed == 0 then
log.warn("Setting host "..self.address.." as DOWN")
host_infos.unhealthy_at = time_utils.get_time()
host_infos.reconnection_delay = self.reconnection_policy.next(self)
self:close()
new_socket(self)
local ok, err = cache.set_host(self.options.shm, self.address, host_infos)
if not ok then
return err
end
if elapsed and elapsed == 0 then
local host_infos, err = cache.get_host(self.options.shm, self.address)
if err then
return err
end

lock_err = unlock_mutex(lock)
if lock_err then
log.warn("Setting host "..self.address.." as DOWN")
host_infos.unhealthy_at = time_utils.get_time()
host_infos.reconnection_delay = self.reconnection_policy.next(self)
self:close()
local ok, err = cache.set_host(self.options.shm, self.address, host_infos)
if not ok then
return err
end
end

lock_err = unlock_mutex(lock)
if lock_err then
return lock_err
end
end

function Host:set_up()
Expand Down Expand Up @@ -458,7 +453,9 @@ function Host:can_be_considered_up()
return nil, err
end

return is_up or (time_utils.get_time() - host_infos.unhealthy_at >= host_infos.reconnection_delay)
if is_up or (time_utils.get_time() - host_infos.unhealthy_at >= host_infos.reconnection_delay) then
return true
end
end

--- Request Handler
Expand Down
23 changes: 16 additions & 7 deletions src/cassandra/options.lua
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@ local DEFAULTS = {
socket_options = {
connect_timeout = 1000,
read_timeout = 2000
}
},
-- username = nil,
-- password = nil,
-- ssl_options = {
-- key = nil,
-- certificate = nil,
-- ca = nil, -- stub
-- verify = false
-- }
ssl_options = {
enabled = false
-- key = nil,
-- certificate = nil,
-- ca = nil, -- stub
-- verify = false
}
}

local function parse_session(options, lvl)
Expand Down Expand Up @@ -112,6 +113,14 @@ local function parse_session(options, lvl)
return nil, "socket read_timeout must be a number"
end

if type(options.ssl_options) ~= "table" then
return nil, "ssl_options must be a table"
end

if type(options.ssl_options.enabled) ~= "boolean" then
return nil, "ssl_options.enabled must be a boolean"
end

return options
end

Expand Down
6 changes: 4 additions & 2 deletions src/cassandra/utils/table.lua
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
local setmetatable = setmetatable
local getmetatable = getmetatable
local table_remove = table.remove
local tostring = tostring
local ipairs = ipairs
local pairs = pairs
local type = type

local _M = {}

function _M.extend_table(...)
local sources = {...}
local values = table.remove(sources)
local values = table_remove(sources)

for _, source in ipairs(sources) do
for k in pairs(source) do
if values[k] == nil then
values[k] = source[k]
end
if type(source[k]) == "table" then
if type(source[k]) == "table" and type(values[k]) == "table" then
_M.extend_table(source[k], values[k])
end
end
Expand Down

0 comments on commit b6fc8ea

Please sign in to comment.