diff --git a/README.md b/README.md index 7c3a8ea..a749ef0 100644 --- a/README.md +++ b/README.md @@ -154,6 +154,7 @@ The arguments are: - hosts (for sharded): list of ``host:port`` pairs. [default: None] - paths (for sharded): list of ``pathnames``. [default: None] - password: password for the redis server. [default: None] +- use_ssl: boolean indicating wether to use SSL/TLS. [default: False] ### Connection Handlers ### diff --git a/examples/ssl_connection.py b/examples/ssl_connection.py new file mode 100644 index 0000000..6af250e --- /dev/null +++ b/examples/ssl_connection.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +import txredisapi as redis + +from twisted.internet import defer, ssl +from twisted.internet import reactor + + +class TestContextFactory(ssl.ClientContextFactory): + def getContext(self): + ctx = ssl.ClientContextFactory.getContext(self) + # ctx.load_verify_locations('./test/ca.crt') + # ctx.use_certificate_file('./test/redis.crt') + # ctx.use_privatekey_file('./test/redis.key') + return ctx + +@defer.inlineCallbacks +def main(): + rc = yield redis.Connection(ssl_context_factory=TestContextFactory()) + print(rc) + + yield rc.set("foo", "bar") + v = yield rc.get("foo") + print("foo:", repr(v)) + + yield rc.disconnect() + +if __name__ == "__main__": + main().addCallback(lambda ign: reactor.stop()) + reactor.run() \ No newline at end of file diff --git a/txredisapi.py b/txredisapi.py index a9b7cee..3aae200 100644 --- a/txredisapi.py +++ b/txredisapi.py @@ -35,7 +35,8 @@ import hashlib import random -from twisted.internet import defer +from typing import Optional, Union +from twisted.internet import defer, ssl from twisted.internet import protocol from twisted.internet import reactor from twisted.internet.tcp import Connector @@ -2380,14 +2381,19 @@ def __init__(self, isLazy=False, handler=ConnectionHandler): def makeConnection(host, port, dbid, poolsize, reconnect, isLazy, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers): uuid = "%s:%d" % (host, port) factory = RedisFactory(uuid, dbid, poolsize, isLazy, ConnectionHandler, charset, password, replyTimeout, convertNumbers) factory.continueTrying = reconnect for x in range(poolsize): - reactor.connectTCP(host, port, factory, connectTimeout) + if isinstance(ssl_context_factory, bool) and ssl_context_factory is True: + ssl_context_factory = ssl.ClientContextFactory() + if ssl_context_factory: + reactor.connectSSL(host, port, factory, ssl_context_factory, connectTimeout) + else: + reactor.connectTCP(host, port, factory, connectTimeout) if isLazy: return factory.handler @@ -2396,7 +2402,7 @@ def makeConnection(host, port, dbid, poolsize, reconnect, isLazy, def makeShardedConnection(hosts, dbid, poolsize, reconnect, isLazy, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers): err = "Please use a list or tuple of host:port for sharded connections" if not isinstance(hosts, (list, tuple)): @@ -2411,7 +2417,7 @@ def makeShardedConnection(hosts, dbid, poolsize, reconnect, isLazy, raise ValueError(err) c = makeConnection(host, port, dbid, poolsize, reconnect, isLazy, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) connections.append(c) @@ -2424,71 +2430,71 @@ def makeShardedConnection(hosts, dbid, poolsize, reconnect, isLazy, def Connection(host="localhost", port=6379, dbid=None, reconnect=True, - charset="utf-8", password=None, + charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeConnection(host, port, dbid, 1, reconnect, False, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def lazyConnection(host="localhost", port=6379, dbid=None, reconnect=True, - charset="utf-8", password=None, + charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeConnection(host, port, dbid, 1, reconnect, True, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def ConnectionPool(host="localhost", port=6379, dbid=None, - poolsize=10, reconnect=True, charset="utf-8", password=None, + poolsize=10, reconnect=True, charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeConnection(host, port, dbid, poolsize, reconnect, False, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def lazyConnectionPool(host="localhost", port=6379, dbid=None, poolsize=10, reconnect=True, charset="utf-8", - password=None, connectTimeout=None, replyTimeout=None, + password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeConnection(host, port, dbid, poolsize, reconnect, True, - charset, password, connectTimeout, replyTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def ShardedConnection(hosts, dbid=None, reconnect=True, charset="utf-8", - password=None, connectTimeout=None, replyTimeout=None, + password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeShardedConnection(hosts, dbid, 1, reconnect, False, - charset, password, connectTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def lazyShardedConnection(hosts, dbid=None, reconnect=True, charset="utf-8", - password=None, + password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeShardedConnection(hosts, dbid, 1, reconnect, True, - charset, password, connectTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def ShardedConnectionPool(hosts, dbid=None, poolsize=10, reconnect=True, - charset="utf-8", password=None, + charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeShardedConnection(hosts, dbid, poolsize, reconnect, False, - charset, password, connectTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers) def lazyShardedConnectionPool(hosts, dbid=None, poolsize=10, reconnect=True, - charset="utf-8", password=None, + charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None, convertNumbers=True): return makeShardedConnection(hosts, dbid, poolsize, reconnect, True, - charset, password, connectTimeout, + charset, password, ssl_context_factory, connectTimeout, replyTimeout, convertNumbers)