Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Connection management improvements. #886

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
SYM_EMPTY = b('')

SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server."
TCP_CLOSE_WAIT = 8


class Token(object):
Expand Down Expand Up @@ -546,6 +547,9 @@ def disconnect(self):
return
try:
self._sock.shutdown(socket.SHUT_RDWR)
except (OSError, socket.error):
pass
try:
self._sock.close()
except socket.error:
pass
Expand Down Expand Up @@ -670,6 +674,27 @@ def pack_commands(self, commands):
output.append(SYM_EMPTY.join(pieces))
return output

def validate(self):
# If we are not yet connected, this connection is valid.
if self._sock is None:
return

# Check that connection is usable. We check specifically for the
# CLOSE_WAIT state on the socket. The socket enters this state for
# example when the Redis server times out the connection due to
# idleness.
try:
state = ord(self._sock.getsockopt(
socket.SOL_TCP, socket.TCP_INFO, 1))

if state == TCP_CLOSE_WAIT:
raise ConnectionError()

except socket.error:
raise ConnectionError()

return True


class SSLConnection(Connection):
description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>"
Expand Down Expand Up @@ -752,6 +777,9 @@ def _error_message(self, exception):
return "Error %s connecting to unix socket: %s. %s." % \
(exception.args[0], self.path, exception.args[1])

def validate(self):
return True


FALSE_STRINGS = ('0', 'F', 'FALSE', 'N', 'NO')

Expand Down Expand Up @@ -949,10 +977,22 @@ def _checkpid(self):
def get_connection(self, command_name, *keys, **options):
"Get a connection from the pool"
self._checkpid()
try:
connection = self._available_connections.pop()
except IndexError:
connection = self.make_connection()
while True:
try:
# Pop the connection from the head of the list, this ensures
# that we "cycle" through the pool instead of repeatedly
# reusing only the connection at the tail.
connection = self._available_connections.pop(0)
# Raises ConnectionError if connection is invalid.
connection.validate()
break
except IndexError:
# All connections are in use, create a new one.
connection = self.make_connection()
break
except ConnectionError:
# Connection is unusable, close it and move on.
connection.disconnect()
self._in_use_connections.add(connection)
return connection

Expand All @@ -978,6 +1018,11 @@ def release(self, connection):
if connection.pid != self.pid:
return
self._in_use_connections.remove(connection)
if self._available_connections:
# There is already an idle, available connection, we can afford to
# shut this one down.
connection.disconnect()
return
self._available_connections.append(connection)

def disconnect(self):
Expand Down
3 changes: 3 additions & 0 deletions tests/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ def __init__(self, **kwargs):
self.kwargs = kwargs
self.pid = os.getpid()

def validate(self):
return True


class TestConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=None,
Expand Down