Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth and reconnecting #7

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
75 changes: 69 additions & 6 deletions brukva/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def format_pipeline_request(command_stack):
return ''.join(format(c.cmd, *c.args, **c.kwargs) for c in command_stack)

class Connection(object):
def __init__(self, host, port, timeout=None, io_loop=None):
def __init__(self, client, host, port, timeout=None, io_loop=None):
self.client = client
self.host = host
self.port = port
self.timeout = timeout
Expand Down Expand Up @@ -79,16 +80,41 @@ def disconnect(self):
self._stream = None

def write(self, data):
if not self._stream:
if self.client.reconnect:
self.client.connect()
self._stream.write(data)

def consume(self, length):
if not self._stream:
if self.client.reconnect:
self.client.connect()
self._stream.read_bytes(length, NOOP_CB)

def read(self, length, callback):
self._stream.read_bytes(length, callback)
try:
if not self._stream:
self.client._sudden_disconnect([callback])
if self.client.reconnect:
self.client.connect()
self._stream.read_bytes(length, callback)
except IOError:
self.client._sudden_disconnect([callback])
if self.client.reconnect:
self.client.connect()


def readline(self, callback):
self._stream.read_until('\r\n', callback)
try:
if not self._stream:
self.client._sudden_disconnect([callback])
if self.client.reconnect:
self.client.connect()
self._stream.read_until('\r\n', callback)
except IOError:
self.client._sudden_disconnect([callback])
if self.client.reconnect:
self.client.connect()

def try_to_perform_read(self):
if not self.in_progress and self.read_queue:
Expand All @@ -104,6 +130,11 @@ def read_done(self):
self.in_progress = False
self.try_to_perform_read()

def connected(self):
if self._stream:
return True
return False

def reply_to_bool(r, *args, **kwargs):
return bool(r)

Expand Down Expand Up @@ -163,13 +194,15 @@ def reply_ttl(r, *args, **kwargs):
return r != -1 and r or None

class Client(object):
def __init__(self, host='localhost', port=6379, io_loop=None):
def __init__(self, host='localhost', port=6379, password=None, reconnect=False, io_loop=None):
self._io_loop = io_loop or IOLoop.instance()

self.connection = Connection(host, port, io_loop=self._io_loop)
self.connection = Connection(self, host, port, io_loop=self._io_loop)
self.queue = []
self.current_cmd_line = None
self.subscribed = False
self.password = password
self.reconnect = reconnect
self.REPLY_MAP = dict_merge(
string_keys_to_dict('AUTH BGREWRITEAOF BGSAVE DEL EXISTS EXPIRE HDEL HEXISTS '
'HMSET MOVE MSET MSETNX SAVE SETNX',
Expand All @@ -187,7 +220,7 @@ def __init__(self, host='localhost', port=6379, io_loop=None):
reply_pubsub_message),
string_keys_to_dict('ZRANK ZREVRANK',
reply_int),
string_keys_to_dict('ZSCORE ZINCRBY',
string_keys_to_dict('ZSCORE ZINCRBY ZCOUNT ZCARD',
reply_int),
string_keys_to_dict('ZRANGE ZRANGEBYSCORE ZREVRANGE',
reply_zset),
Expand All @@ -212,6 +245,8 @@ def pipeline(self, transactional=False):
#### connection
def connect(self):
self.connection.connect()
if self.password:
self.auth(self.password)

def disconnect(self):
self.connection.disconnect()
Expand Down Expand Up @@ -259,6 +294,8 @@ def execute_command(self, cmd, callbacks, *args, **kwargs):
elif not hasattr(callbacks, '__iter__'):
callbacks = [callbacks]
try:
if self.reconnect and not self.connection.connected():
self.connect()
self.connection.write(self.format(cmd, *args, **kwargs))
except IOError:
self._sudden_disconnect(callbacks)
Expand All @@ -285,6 +322,8 @@ def execute_command(self, cmd, callbacks, *args, **kwargs):
@process
def process_data(self, data, cmd_line, callback):
error, response = None, None
if error:
callback((error, None))

data = data[:-2] # strip \r\n

Expand All @@ -293,6 +332,10 @@ def process_data(self, data, cmd_line, callback):
elif data == '*0' or data == '*-1':
response = []
else:
if len(data) == 0:
if self.reconnect:
self.connect()
callback((IOError('Disconnected'),None))
head, tail = data[0], data[1:]

if head == '*':
Expand Down Expand Up @@ -322,6 +365,9 @@ def consume_multibulk(self, length, cmd_line, callback):
data = yield async(self.connection.readline)()
if not data:
break
if isinstance(data, Exception):
errors[idx] = data
break

error, token = yield self.process_data(data, cmd_line) #FIXME error
tokens.append( token )
Expand All @@ -335,6 +381,8 @@ def consume_multibulk(self, length, cmd_line, callback):
@process
def consume_bulk(self, length, callback):
data = yield async(self.connection.read)(length)
if isinstance(data, Exception):
callback((data, None))
error = None
if not data:
error = ResponseError('EmptyResponse')
Expand Down Expand Up @@ -588,6 +636,19 @@ def zrevrank(self, key, value, callbacks=None):
def zrem(self, key, value, callbacks=None):
self.execute_command('ZREM', callbacks, key, value)

def zcount(self, key, start, end, offset=None, limit=None, with_scores=None, callbacks=None):
tokens = [key, start, end]
if offset is not None:
tokens.append('LIMIT')
tokens.append(offset)
tokens.append(limit)
if with_scores:
tokens.append('WITHSCORES')
self.execute_command('ZCOUNT', callbacks, *tokens)

def zcard(self, key, callbacks=None):
self.execute_command('ZCARD', callbacks, key)

def zscore(self, key, value, callbacks=None):
self.execute_command('ZSCORE', callbacks, key, value)

Expand Down Expand Up @@ -761,6 +822,8 @@ def execute(self, callbacks):

request = format_pipeline_request(command_stack)
try:
if self.reconnect and not self.connection.connected():
self.connect()
self.connection.write(request)
except IOError:
self.command_stack = []
Expand Down