Skip to content

Commit

Permalink
Merge pull request #148 from roeltm/add-ssl-connector
Browse files Browse the repository at this point in the history
Add ssl option to connection
  • Loading branch information
IlyaSkriblovsky authored Mar 2, 2023
2 parents c89a498 + f2ef373 commit e490fdb
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 21 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ The arguments are:
- hosts (for sharded): list of ``host:port`` pairs. [default: None]
- paths (for sharded): list of ``pathnames``. [default: None]
- password: password for the redis server. [default: None]
- ssl_context_factory: Either a boolean indicating wether to use SSL/TLS or a specific `ClientContextFactory`. [default: False]


### Connection Handlers ###
Expand Down
29 changes: 29 additions & 0 deletions examples/ssl_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python
import txredisapi as redis

from twisted.internet import defer, ssl
from twisted.internet import reactor


class TestContextFactory(ssl.ClientContextFactory):
def getContext(self):
ctx = ssl.ClientContextFactory.getContext(self)
# ctx.load_verify_locations('./test/ca.crt')
# ctx.use_certificate_file('./test/redis.crt')
# ctx.use_privatekey_file('./test/redis.key')
return ctx

@defer.inlineCallbacks
def main():
rc = yield redis.Connection(ssl_context_factory=TestContextFactory())
print(rc)

yield rc.set("foo", "bar")
v = yield rc.get("foo")
print("foo:", repr(v))

yield rc.disconnect()

if __name__ == "__main__":
main().addCallback(lambda ign: reactor.stop())
reactor.run()
48 changes: 27 additions & 21 deletions txredisapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
import hashlib
import random

from twisted.internet import defer
from typing import Optional, Union
from twisted.internet import defer, ssl
from twisted.internet import protocol
from twisted.internet import reactor
from twisted.internet.tcp import Connector
Expand Down Expand Up @@ -2380,14 +2381,19 @@ def __init__(self, isLazy=False, handler=ConnectionHandler):


def makeConnection(host, port, dbid, poolsize, reconnect, isLazy,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers):
uuid = "%s:%d" % (host, port)
factory = RedisFactory(uuid, dbid, poolsize, isLazy, ConnectionHandler,
charset, password, replyTimeout, convertNumbers)
factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)
if ssl_context_factory is True:
ssl_context_factory = ssl.ClientContextFactory()
if ssl_context_factory:
reactor.connectSSL(host, port, factory, ssl_context_factory, connectTimeout)
else:
reactor.connectTCP(host, port, factory, connectTimeout)

if isLazy:
return factory.handler
Expand All @@ -2396,7 +2402,7 @@ def makeConnection(host, port, dbid, poolsize, reconnect, isLazy,


def makeShardedConnection(hosts, dbid, poolsize, reconnect, isLazy,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers):
err = "Please use a list or tuple of host:port for sharded connections"
if not isinstance(hosts, (list, tuple)):
Expand All @@ -2411,7 +2417,7 @@ def makeShardedConnection(hosts, dbid, poolsize, reconnect, isLazy,
raise ValueError(err)

c = makeConnection(host, port, dbid, poolsize, reconnect, isLazy,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers)
connections.append(c)

Expand All @@ -2424,71 +2430,71 @@ def makeShardedConnection(hosts, dbid, poolsize, reconnect, isLazy,


def Connection(host="localhost", port=6379, dbid=None, reconnect=True,
charset="utf-8", password=None,
charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False,
connectTimeout=None, replyTimeout=None, convertNumbers=True):
return makeConnection(host, port, dbid, 1, reconnect, False,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers)


def lazyConnection(host="localhost", port=6379, dbid=None, reconnect=True,
charset="utf-8", password=None,
charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False,
connectTimeout=None, replyTimeout=None, convertNumbers=True):
return makeConnection(host, port, dbid, 1, reconnect, True,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers)


def ConnectionPool(host="localhost", port=6379, dbid=None,
poolsize=10, reconnect=True, charset="utf-8", password=None,
poolsize=10, reconnect=True, charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False,
connectTimeout=None, replyTimeout=None,
convertNumbers=True):
return makeConnection(host, port, dbid, poolsize, reconnect, False,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers)


def lazyConnectionPool(host="localhost", port=6379, dbid=None,
poolsize=10, reconnect=True, charset="utf-8",
password=None, connectTimeout=None, replyTimeout=None,
password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None,
convertNumbers=True):
return makeConnection(host, port, dbid, poolsize, reconnect, True,
charset, password, connectTimeout, replyTimeout,
charset, password, ssl_context_factory, connectTimeout, replyTimeout,
convertNumbers)


def ShardedConnection(hosts, dbid=None, reconnect=True, charset="utf-8",
password=None, connectTimeout=None, replyTimeout=None,
password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False, connectTimeout=None, replyTimeout=None,
convertNumbers=True):
return makeShardedConnection(hosts, dbid, 1, reconnect, False,
charset, password, connectTimeout,
charset, password, ssl_context_factory, connectTimeout,
replyTimeout, convertNumbers)


def lazyShardedConnection(hosts, dbid=None, reconnect=True, charset="utf-8",
password=None,
password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False,
connectTimeout=None, replyTimeout=None,
convertNumbers=True):
return makeShardedConnection(hosts, dbid, 1, reconnect, True,
charset, password, connectTimeout,
charset, password, ssl_context_factory, connectTimeout,
replyTimeout, convertNumbers)


def ShardedConnectionPool(hosts, dbid=None, poolsize=10, reconnect=True,
charset="utf-8", password=None,
charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False,
connectTimeout=None, replyTimeout=None,
convertNumbers=True):
return makeShardedConnection(hosts, dbid, poolsize, reconnect, False,
charset, password, connectTimeout,
charset, password, ssl_context_factory, connectTimeout,
replyTimeout, convertNumbers)


def lazyShardedConnectionPool(hosts, dbid=None, poolsize=10, reconnect=True,
charset="utf-8", password=None,
charset="utf-8", password=None, ssl_context_factory: Union[ssl.ClientContextFactory, bool]=False,
connectTimeout=None, replyTimeout=None,
convertNumbers=True):
return makeShardedConnection(hosts, dbid, poolsize, reconnect, True,
charset, password, connectTimeout,
charset, password, ssl_context_factory, connectTimeout,
replyTimeout, convertNumbers)


Expand Down

0 comments on commit e490fdb

Please sign in to comment.