Skip to content

Commit

Permalink
Merge pull request #73 from jthacker/feat/enable-non-pickle-serializers
Browse files Browse the repository at this point in the history
Enable non-pickle serializers
  • Loading branch information
peter-wangxu authored Oct 23, 2018
2 parents f63afa7 + 9c685ea commit cecb263
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 43 deletions.
5 changes: 2 additions & 3 deletions persistqueue/pdict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#! coding = utf-8
import logging
import pickle
import sqlite3

from persistqueue import sqlbase
Expand Down Expand Up @@ -50,7 +49,7 @@ def __contains__(self, item):
return row is not None

def __setitem__(self, key, value):
obj = pickle.dumps(value)
obj = self._serializer.dumps(value)
try:
self._insert_into(key, obj)
except sqlite3.IntegrityError:
Expand All @@ -59,7 +58,7 @@ def __setitem__(self, key, value):
def __getitem__(self, item):
row = self._select(item)
if row:
return pickle.loads(row[1])
return self._serializer.loads(row[1])
else:
raise KeyError('Key: {} not exists.'.format(item))

Expand Down
27 changes: 13 additions & 14 deletions persistqueue/sqlackqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import unicode_literals

import logging
import pickle
import sqlite3
import time as _time
import threading
Expand Down Expand Up @@ -71,7 +70,7 @@ def resume_unack_tasks(self):
return sql, (AckStatus.ready, AckStatus.unack, )

def put(self, item):
obj = pickle.dumps(item, protocol=self.protocol)
obj = self._serializer.dumps(item)
self._insert_into(obj, _time.time())
self.total += 1
self.put_event.set()
Expand Down Expand Up @@ -136,8 +135,8 @@ def _pop(self):
# by select, below can avoid these invalid records.
if row and row[0] is not None:
self._mark_ack_status(row[0], AckStatus.unack)
pickled_data = row[1] # pickled data
item = pickle.loads(pickled_data)
serialized_data = row[1]
item = self._serializer.loads(serialized_data)
self._unack_cache[row[0]] = item
self.total -= 1
return item
Expand Down Expand Up @@ -177,31 +176,31 @@ def nack(self, item):

def get(self, block=True, timeout=None):
if not block:
pickled = self._pop()
if not pickled:
serialized = self._pop()
if not serialized:
raise Empty
elif timeout is None:
# block until a put event.
pickled = self._pop()
while not pickled:
serialized = self._pop()
while not serialized:
self.put_event.clear()
self.put_event.wait(TICK_FOR_WAIT)
pickled = self._pop()
serialized = self._pop()
elif timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
else:
# block until the timeout reached
endtime = _time.time() + timeout
pickled = self._pop()
while not pickled:
serialized = self._pop()
while not serialized:
self.put_event.clear()
remaining = endtime - _time.time()
if remaining <= 0.0:
raise Empty
self.put_event.wait(
TICK_FOR_WAIT if TICK_FOR_WAIT < remaining else remaining)
pickled = self._pop()
item = pickled
serialized = self._pop()
item = serialized
return item

def task_done(self):
Expand Down Expand Up @@ -242,7 +241,7 @@ class UniqueAckQ(SQLiteAckQueue):
)

def put(self, item):
obj = pickle.dumps(item)
obj = self._serializer.dumps(item)
try:
self._insert_into(obj, _time.time())
except sqlite3.IntegrityError:
Expand Down
22 changes: 13 additions & 9 deletions persistqueue/sqlbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import sqlite3
import threading

from persistqueue import common
import persistqueue.serializers.pickle


sqlite3.enable_callback_tracebacks(True)

Expand Down Expand Up @@ -50,7 +51,8 @@ class SQLiteBase(object):
_MEMORY = ':memory:' # flag indicating store DB in memory

def __init__(self, path, name='default', multithreading=False,
timeout=10.0, auto_commit=True):
timeout=10.0, auto_commit=True,
serializer=persistqueue.serializers.pickle):
"""Initiate a queue in sqlite3 or memory.
:param path: path for storing DB file.
Expand All @@ -63,16 +65,22 @@ def __init__(self, path, name='default', multithreading=False,
INSERT/UPDATE action, otherwise False, whereas
a **task_done** is required to persist changes
after **put**.
:param serializer: The serializer parameter controls how enqueued data
is serialized. It must have methods dump(value, fp)
and load(fp). The dump method must serialize the
value and write it to fp, and may be called for
multiple values with the same fp. The load method
must deserialize and return one value from fp,
and may be called multiple times with the same fp
to read multiple values.
"""
self.memory_sql = False
self.path = path
self.name = name
self.timeout = timeout
self.multithreading = multithreading
self.auto_commit = auto_commit
self.protocol = None
self._serializer = serializer
self._init()

def _init(self):
Expand All @@ -81,14 +89,10 @@ def _init(self):
if self.path == self._MEMORY:
self.memory_sql = True
log.debug("Initializing Sqlite3 Queue in memory.")
self.protocol = common.select_pickle_protocol()
elif not os.path.exists(self.path):
os.makedirs(self.path)
log.debug(
'Initializing Sqlite3 Queue with path {}'.format(self.path))
# Set to current highest pickle protocol for new queue.
self.protocol = common.select_pickle_protocol()

self._conn = self._new_db_connection(
self.path, self.multithreading, self.timeout)
self._getter = self._conn
Expand Down
25 changes: 12 additions & 13 deletions persistqueue/sqlqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""A thread-safe sqlite3 based persistent queue in Python."""

import logging
import pickle
import sqlite3
import time as _time
import threading
Expand Down Expand Up @@ -37,7 +36,7 @@ class SQLiteQueue(sqlbase.SQLiteBase):
' {column} {op} ? ORDER BY {key_column} ASC LIMIT 1 '

def put(self, item):
obj = pickle.dumps(item, protocol=self.protocol)
obj = self._serializer.dumps(item)
self._insert_into(obj, _time.time())
self.total += 1
self.put_event.set()
Expand All @@ -64,7 +63,7 @@ def _pop(self):
if row and row[0] is not None:
self._delete(row[0])
self.total -= 1
return row[1] # pickled data
return row[1] # serialized data
else:
row = self._select(
self.cursor, op=">", column=self._KEY_COLUMN)
Expand All @@ -76,31 +75,31 @@ def _pop(self):

def get(self, block=True, timeout=None):
if not block:
pickled = self._pop()
if not pickled:
serialized = self._pop()
if not serialized:
raise Empty
elif timeout is None:
# block until a put event.
pickled = self._pop()
while not pickled:
serialized = self._pop()
while not serialized:
self.put_event.clear()
self.put_event.wait(TICK_FOR_WAIT)
pickled = self._pop()
serialized = self._pop()
elif timeout < 0:
raise ValueError("'timeout' must be a non-negative number")
else:
# block until the timeout reached
endtime = _time.time() + timeout
pickled = self._pop()
while not pickled:
serialized = self._pop()
while not serialized:
self.put_event.clear()
remaining = endtime - _time.time()
if remaining <= 0.0:
raise Empty
self.put_event.wait(
TICK_FOR_WAIT if TICK_FOR_WAIT < remaining else remaining)
pickled = self._pop()
item = pickle.loads(pickled)
serialized = self._pop()
item = self._serializer.loads(serialized)
return item

def task_done(self):
Expand Down Expand Up @@ -139,7 +138,7 @@ class UniqueQ(SQLiteQueue):
'data BLOB, timestamp FLOAT, UNIQUE (data))')

def put(self, item):
obj = pickle.dumps(item)
obj = self._serializer.dumps(item)
try:
self._insert_into(obj, _time.time())
except sqlite3.IntegrityError:
Expand Down
6 changes: 4 additions & 2 deletions tests/test_sqlackqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,13 @@ def consumer(index):
def test_protocol_1(self):
shutil.rmtree(self.path, ignore_errors=True)
q = SQLiteAckQueue(path=self.path)
self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4)
self.assertEqual(q._serializer.protocol,
2 if sys.version_info[0] == 2 else 4)

def test_protocol_2(self):
q = SQLiteAckQueue(path=self.path)
self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4)
self.assertEqual(q._serializer.protocol,
2 if sys.version_info[0] == 2 else 4)

def test_ack_and_clear(self):
q = SQLiteAckQueue(path=self.path)
Expand Down
21 changes: 19 additions & 2 deletions tests/test_sqlqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import unittest
from threading import Thread

import persistqueue.serializers
from persistqueue import SQLiteQueue, FILOSQLiteQueue, UniqueQ
from persistqueue import Empty

Expand Down Expand Up @@ -200,11 +201,27 @@ def test_task_done_with_restart(self):
def test_protocol_1(self):
shutil.rmtree(self.path, ignore_errors=True)
q = SQLiteQueue(path=self.path)
self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4)
self.assertEqual(q._serializer.protocol,
2 if sys.version_info[0] == 2 else 4)

def test_protocol_2(self):
q = SQLiteQueue(path=self.path)
self.assertEqual(q.protocol, 2 if sys.version_info[0] == 2 else 4)
self.assertEqual(q._serializer.protocol,
2 if sys.version_info[0] == 2 else 4)

def test_json_serializer(self):
q = SQLiteQueue(
path=self.path,
serializer=persistqueue.serializers.json)
x = dict(
a=1,
b=2,
c=dict(
d=list(range(5)),
e=[1]
))
q.put(x)
self.assertEquals(q.get(), x)


class SQLite3QueueNoAutoCommitTest(SQLite3QueueTest):
Expand Down

0 comments on commit cecb263

Please sign in to comment.