From 9c685eaff5a9b397ce37821d96b02db7b5679b3d Mon Sep 17 00:00:00 2001 From: Jon Thacker Date: Mon, 22 Oct 2018 21:46:25 -0400 Subject: [PATCH] Enable non-pickle serializers * Adds a serializer option to all sqlqueue and pdict data structures * Removes sqlbase.protocol field --- persistqueue/pdict.py | 5 ++--- persistqueue/sqlackqueue.py | 27 +++++++++++++-------------- persistqueue/sqlbase.py | 22 +++++++++++++--------- persistqueue/sqlqueue.py | 25 ++++++++++++------------- tests/test_sqlackqueue.py | 6 ++++-- tests/test_sqlqueue.py | 21 +++++++++++++++++++-- 6 files changed, 63 insertions(+), 43 deletions(-) diff --git a/persistqueue/pdict.py b/persistqueue/pdict.py index a0d16ba..2e09dc8 100644 --- a/persistqueue/pdict.py +++ b/persistqueue/pdict.py @@ -1,6 +1,5 @@ #! coding = utf-8 import logging -import pickle import sqlite3 from persistqueue import sqlbase @@ -50,7 +49,7 @@ def __contains__(self, item): return row is not None def __setitem__(self, key, value): - obj = pickle.dumps(value) + obj = self._serializer.dumps(value) try: self._insert_into(key, obj) except sqlite3.IntegrityError: @@ -59,7 +58,7 @@ def __setitem__(self, key, value): def __getitem__(self, item): row = self._select(item) if row: - return pickle.loads(row[1]) + return self._serializer.loads(row[1]) else: raise KeyError('Key: {} not exists.'.format(item)) diff --git a/persistqueue/sqlackqueue.py b/persistqueue/sqlackqueue.py index 7950806..793c5cb 100644 --- a/persistqueue/sqlackqueue.py +++ b/persistqueue/sqlackqueue.py @@ -3,7 +3,6 @@ from __future__ import unicode_literals import logging -import pickle import sqlite3 import time as _time import threading @@ -71,7 +70,7 @@ def resume_unack_tasks(self): return sql, (AckStatus.ready, AckStatus.unack, ) def put(self, item): - obj = pickle.dumps(item, protocol=self.protocol) + obj = self._serializer.dumps(item) self._insert_into(obj, _time.time()) self.total += 1 self.put_event.set() @@ -136,8 +135,8 @@ def _pop(self): # by select, below 
can avoid these invalid records. if row and row[0] is not None: self._mark_ack_status(row[0], AckStatus.unack) - pickled_data = row[1] # pickled data - item = pickle.loads(pickled_data) + serialized_data = row[1] + item = self._serializer.loads(serialized_data) self._unack_cache[row[0]] = item self.total -= 1 return item @@ -177,31 +176,31 @@ def nack(self, item): def get(self, block=True, timeout=None): if not block: - pickled = self._pop() - if not pickled: + serialized = self._pop() + if not serialized: raise Empty elif timeout is None: # block until a put event. - pickled = self._pop() - while not pickled: + serialized = self._pop() + while not serialized: self.put_event.clear() self.put_event.wait(TICK_FOR_WAIT) - pickled = self._pop() + serialized = self._pop() elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: # block until the timeout reached endtime = _time.time() + timeout - pickled = self._pop() - while not pickled: + serialized = self._pop() + while not serialized: self.put_event.clear() remaining = endtime - _time.time() if remaining <= 0.0: raise Empty self.put_event.wait( TICK_FOR_WAIT if TICK_FOR_WAIT < remaining else remaining) - pickled = self._pop() - item = pickled + serialized = self._pop() + item = serialized return item def task_done(self): @@ -242,7 +241,7 @@ class UniqueAckQ(SQLiteAckQueue): ) def put(self, item): - obj = pickle.dumps(item) + obj = self._serializer.dumps(item) try: self._insert_into(obj, _time.time()) except sqlite3.IntegrityError: diff --git a/persistqueue/sqlbase.py b/persistqueue/sqlbase.py index 551220e..fff027d 100644 --- a/persistqueue/sqlbase.py +++ b/persistqueue/sqlbase.py @@ -3,7 +3,8 @@ import sqlite3 import threading -from persistqueue import common +import persistqueue.serializers.pickle + sqlite3.enable_callback_tracebacks(True) @@ -50,7 +51,8 @@ class SQLiteBase(object): _MEMORY = ':memory:' # flag indicating store DB in memory def __init__(self, path, name='default', 
multithreading=False, - timeout=10.0, auto_commit=True): + timeout=10.0, auto_commit=True, + serializer=persistqueue.serializers.pickle): """Initiate a queue in sqlite3 or memory. :param path: path for storing DB file. @@ -63,8 +65,14 @@ def __init__(self, path, name='default', multithreading=False, INSERT/UPDATE action, otherwise False, whereas a **task_done** is required to persist changes after **put**. - - + :param serializer: The serializer parameter controls how enqueued data + is serialized. It must have methods dumps(value) + and loads(data). The dumps method must serialize + the value and return the serialized data to be + stored in the database. The loads method must + deserialize that stored data and return the + original value. The default is the bundled + pickle serializer. """ self.memory_sql = False self.path = path @@ -72,7 +80,7 @@ def __init__(self, path, name='default', multithreading=False, self.timeout = timeout self.multithreading = multithreading self.auto_commit = auto_commit - self.protocol = None + self._serializer = serializer self._init() def _init(self): @@ -81,14 +89,10 @@ def _init(self): if self.path == self._MEMORY: self.memory_sql = True log.debug("Initializing Sqlite3 Queue in memory.") - self.protocol = common.select_pickle_protocol() elif not os.path.exists(self.path): os.makedirs(self.path) log.debug( 'Initializing Sqlite3 Queue with path {}'.format(self.path)) - # Set to current highest pickle protocol for new queue. 
- self.protocol = common.select_pickle_protocol() - self._conn = self._new_db_connection( self.path, self.multithreading, self.timeout) self._getter = self._conn diff --git a/persistqueue/sqlqueue.py b/persistqueue/sqlqueue.py index 5dab453..04be3f9 100644 --- a/persistqueue/sqlqueue.py +++ b/persistqueue/sqlqueue.py @@ -3,7 +3,6 @@ """A thread-safe sqlite3 based persistent queue in Python.""" import logging -import pickle import sqlite3 import time as _time import threading @@ -37,7 +36,7 @@ class SQLiteQueue(sqlbase.SQLiteBase): ' {column} {op} ? ORDER BY {key_column} ASC LIMIT 1 ' def put(self, item): - obj = pickle.dumps(item, protocol=self.protocol) + obj = self._serializer.dumps(item) self._insert_into(obj, _time.time()) self.total += 1 self.put_event.set() @@ -64,7 +63,7 @@ def _pop(self): if row and row[0] is not None: self._delete(row[0]) self.total -= 1 - return row[1] # pickled data + return row[1] # serialized data else: row = self._select( self.cursor, op=">", column=self._KEY_COLUMN) @@ -76,31 +75,31 @@ def _pop(self): def get(self, block=True, timeout=None): if not block: - pickled = self._pop() - if not pickled: + serialized = self._pop() + if not serialized: raise Empty elif timeout is None: # block until a put event. 
- pickled = self._pop() - while not pickled: + serialized = self._pop() + while not serialized: self.put_event.clear() self.put_event.wait(TICK_FOR_WAIT) - pickled = self._pop() + serialized = self._pop() elif timeout < 0: raise ValueError("'timeout' must be a non-negative number") else: # block until the timeout reached endtime = _time.time() + timeout - pickled = self._pop() - while not pickled: + serialized = self._pop() + while not serialized: self.put_event.clear() remaining = endtime - _time.time() if remaining <= 0.0: raise Empty self.put_event.wait( TICK_FOR_WAIT if TICK_FOR_WAIT < remaining else remaining) - pickled = self._pop() - item = pickle.loads(pickled) + serialized = self._pop() + item = self._serializer.loads(serialized) return item def task_done(self): @@ -139,7 +138,7 @@ class UniqueQ(SQLiteQueue): 'data BLOB, timestamp FLOAT, UNIQUE (data))') def put(self, item): - obj = pickle.dumps(item) + obj = self._serializer.dumps(item) try: self._insert_into(obj, _time.time()) except sqlite3.IntegrityError: diff --git a/tests/test_sqlackqueue.py b/tests/test_sqlackqueue.py index aa2d04a..2bd7fcd 100644 --- a/tests/test_sqlackqueue.py +++ b/tests/test_sqlackqueue.py @@ -181,11 +181,13 @@ def consumer(index): def test_protocol_1(self): shutil.rmtree(self.path, ignore_errors=True) q = SQLiteAckQueue(path=self.path) - self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4) + self.assertEqual(q._serializer.protocol, + 2 if sys.version_info[0] == 2 else 4) def test_protocol_2(self): q = SQLiteAckQueue(path=self.path) - self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4) + self.assertEqual(q._serializer.protocol, + 2 if sys.version_info[0] == 2 else 4) def test_ack_and_clear(self): q = SQLiteAckQueue(path=self.path) diff --git a/tests/test_sqlqueue.py b/tests/test_sqlqueue.py index d7c7ef9..70f61af 100644 --- a/tests/test_sqlqueue.py +++ b/tests/test_sqlqueue.py @@ -7,6 +7,7 @@ import unittest from threading import Thread +import 
persistqueue.serializers from persistqueue import SQLiteQueue, FILOSQLiteQueue, UniqueQ from persistqueue import Empty @@ -200,11 +201,27 @@ def test_task_done_with_restart(self): def test_protocol_1(self): shutil.rmtree(self.path, ignore_errors=True) q = SQLiteQueue(path=self.path) - self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4) + self.assertEqual(q._serializer.protocol, + 2 if sys.version_info[0] == 2 else 4) def test_protocol_2(self): q = SQLiteQueue(path=self.path) - self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4) + self.assertEqual(q._serializer.protocol, + 2 if sys.version_info[0] == 2 else 4) + + def test_json_serializer(self): + q = SQLiteQueue( + path=self.path, + serializer=persistqueue.serializers.json) + x = dict( + a=1, + b=2, + c=dict( + d=list(range(5)), + e=[1] + )) + q.put(x) + self.assertEquals(q.get(), x) class SQLite3QueueNoAutoCommitTest(SQLite3QueueTest):