Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

schema: support encoding=None connections #172

Merged
merged 5 commits into from
Aug 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 73 additions & 11 deletions tarantool/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,83 @@
integer_types,
)
from tarantool.error import (
Error,
SchemaError,
DatabaseError
)
import tarantool.const as const


class RecursionError(Error):
"""Report the situation when max recursion depth is reached.

This is internal error for <to_unicode_recursive> caller
and it should be re-raised properly be the caller.
"""


def to_unicode(s):
if isinstance(s, bytes):
return s.decode(encoding='utf-8')
return s


def to_unicode_recursive(x, max_depth):
"""Same as to_unicode(), but traverses over dictionaries,
lists and tuples recursivery.

x: value to convert

max_depth: 1 accepts a scalar, 2 accepts a list of scalars,
etc.
"""
if max_depth <= 0:
raise RecursionError('Max recursion depth is reached')

if isinstance(x, dict):
res = dict()
for key, val in x.items():
key = to_unicode_recursive(key, max_depth - 1)
val = to_unicode_recursive(val, max_depth - 1)
res[key] = val
return res

if isinstance(x, list) or isinstance(x, tuple):
res = []
for val in x:
val = to_unicode_recursive(val, max_depth - 1)
res.append(val)
if isinstance(x, tuple):
return tuple(res)
return res

return to_unicode(x)


class SchemaIndex(object):
def __init__(self, index_row, space):
self.iid = index_row[1]
self.name = index_row[2]
if isinstance(self.name, bytes):
self.name = self.name.decode()
self.name = to_unicode(index_row[2])
self.index = index_row[3]
self.unique = index_row[4]
self.parts = []
if isinstance(index_row[5], (list, tuple)):
for val in index_row[5]:
try:
parts_raw = to_unicode_recursive(index_row[5], 3)
except RecursionError as e:
errmsg = 'Unexpected index parts structure: ' + str(e)
raise SchemaError(errmsg)
if isinstance(parts_raw, (list, tuple)):
for val in parts_raw:
if isinstance(val, dict):
self.parts.append((val['field'], val['type']))
else:
self.parts.append((val[0], val[1]))
else:
for i in range(index_row[5]):
for i in range(parts_raw):
self.parts.append((
index_row[5 + 1 + i * 2],
index_row[5 + 2 + i * 2]
to_unicode(index_row[5 + 1 + i * 2]),
to_unicode(index_row[5 + 2 + i * 2])
))
self.space = space
self.space.indexes[self.iid] = self
Expand All @@ -52,16 +103,19 @@ class SchemaSpace(object):
def __init__(self, space_row, schema):
self.sid = space_row[0]
self.arity = space_row[1]
self.name = space_row[2]
if isinstance(self.name, bytes):
self.name = self.name.decode()
self.name = to_unicode(space_row[2])
self.indexes = {}
self.schema = schema
self.schema[self.sid] = self
if self.name:
self.schema[self.name] = self
self.format = dict()
for part_id, part in enumerate(space_row[6]):
try:
format_raw = to_unicode_recursive(space_row[6], 3)
except RecursionError as e:
errmsg = 'Unexpected space format structure: ' + str(e)
raise SchemaError(errmsg)
for part_id, part in enumerate(format_raw):
part['id'] = part_id
self.format[part['name']] = part
self.format[part_id ] = part
Expand All @@ -78,6 +132,8 @@ def __init__(self, con):
self.con = con

def get_space(self, space):
space = to_unicode(space)

try:
return self.schema[space]
except KeyError:
Expand Down Expand Up @@ -135,6 +191,9 @@ def fetch_space_all(self):
SchemaSpace(row, self.schema)

def get_index(self, space, index):
space = to_unicode(space)
index = to_unicode(index)

_space = self.get_space(space)
try:
return _space.indexes[index]
Expand Down Expand Up @@ -203,6 +262,9 @@ def fetch_index_from(self, space, index):
return index_row

def get_field(self, space, field):
space = to_unicode(space)
field = to_unicode(field)

_space = self.get_space(space)
try:
return _space.format[field]
Expand Down
2 changes: 1 addition & 1 deletion unit/setup_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def run(self):
Find all tests in test/tarantool/ and run them
'''

tests = unittest.defaultTestLoader.discover('unit')
tests = unittest.defaultTestLoader.discover('unit', pattern='suites')
test_runner = unittest.TextTestRunner(verbosity=2)
result = test_runner.run(tests)
if not result.wasSuccessful():
Expand Down
9 changes: 6 additions & 3 deletions unit/suites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
__tmp = os.getcwd()
os.chdir(os.path.abspath(os.path.dirname(__file__)))

from .test_schema import TestSuite_Schema
from .test_schema import TestSuite_Schema_UnicodeConnection
from .test_schema import TestSuite_Schema_BinaryConnection
from .test_dml import TestSuite_Request
from .test_protocol import TestSuite_Protocol
from .test_reconnect import TestSuite_Reconnect
from .test_mesh import TestSuite_Mesh

test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol,
TestSuite_Reconnect, TestSuite_Mesh)
test_cases = (TestSuite_Schema_UnicodeConnection,
TestSuite_Schema_BinaryConnection,
TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect,
TestSuite_Mesh)

def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
Expand Down
139 changes: 136 additions & 3 deletions unit/suites/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,100 @@
import tarantool
from .lib.tarantool_server import TarantoolServer

class TestSuite_Schema(unittest.TestCase):

# FIXME: I'm quite sure that there is a simpler way to count
# a method calls, but I failed to find any. It seems, I should
# look at unittest.mock more thoroughly.
class MethodCallCounter:
def __init__(self, obj, method_name):
self._call_count = 0
self._bind(obj, method_name)

def _bind(self, obj, method_name):
self._obj = obj
self._method_name = method_name
self._saved_method = getattr(obj, method_name)
def wrapper(_, *args, **kwargs):
self._call_count += 1
return self._saved_method(*args, **kwargs)
bound_wrapper = wrapper.__get__(obj.__class__, obj)
setattr(obj, method_name, bound_wrapper)

def unbind(self):
if self._saved_method is not None:
setattr(self._obj, self._method_name, self._saved_method)

def call_count(self):
return self._call_count


class TestSuite_Schema_Abstract(unittest.TestCase):
# Define 'encoding' field in a concrete class.

@classmethod
def setUpClass(self):
print(' SCHEMA '.center(70, '='), file=sys.stderr)
params = 'connection.encoding: {}'.format(repr(self.encoding))
print(' SCHEMA ({}) '.format(params).center(70, '='), file=sys.stderr)
print('-' * 70, file=sys.stderr)
self.srv = TarantoolServer()
self.srv.script = 'unit/suites/box.lua'
self.srv.start()
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'])
self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
encoding=self.encoding)
self.sch = self.con.schema

# The relevant test cases mainly target Python 2, where
# a user may want to pass a string literal as a space or
# an index name and don't bother whether all symbols in it
# are ASCII.
self.unicode_space_name_literal = '∞'
self.unicode_index_name_literal = '→'

self.unicode_space_name_u = u'∞'
self.unicode_index_name_u = u'→'
self.unicode_space_id, self.unicode_index_id = self.srv.admin("""
do
local space = box.schema.create_space('\\xe2\\x88\\x9e')
local index = space:create_index('\\xe2\\x86\\x92')
return space.id, index.id
end
""")

def setUp(self):
# prevent a remote tarantool from clean our session
if self.srv.is_started():
self.srv.touch_lock()

# Count calls of fetch methods. See <fetch_count>.
self.fetch_space_counter = MethodCallCounter(self.sch, 'fetch_space')
self.fetch_index_counter = MethodCallCounter(self.sch, 'fetch_index')

def tearDown(self):
self.fetch_space_counter.unbind()
self.fetch_index_counter.unbind()

@property
def fetch_count(self):
"""Amount of fetch_{space,index}() calls.

It is initialized to zero before each test case.
"""
res = 0
res += self.fetch_space_counter.call_count()
res += self.fetch_index_counter.call_count()
return res

def verify_unicode_space(self, space):
self.assertEqual(space.sid, self.unicode_space_id)
self.assertEqual(space.name, self.unicode_space_name_u)
self.assertEqual(space.arity, 1)

def verify_unicode_index(self, index):
self.assertEqual(index.space.name, self.unicode_space_name_u)
self.assertEqual(index.iid, self.unicode_index_id)
self.assertEqual(index.name, self.unicode_index_name_u)
self.assertEqual(len(index.parts), 1)

def test_00_authenticate(self):
self.assertIsNone(self.srv.admin("box.schema.user.create('test', { password = 'test' })"))
self.assertIsNone(self.srv.admin("box.schema.user.grant('test', 'read,write', 'space', '_space')"))
Expand Down Expand Up @@ -72,6 +150,9 @@ def test_03_01_space_name__(self):
self.assertEqual(space.name, '_index')
self.assertEqual(space.arity, 1)

space = self.sch.get_space(self.unicode_space_name_literal)
self.verify_unicode_space(space)

def test_03_02_space_number(self):
self.con.flush_schema()
space = self.sch.get_space(272)
Expand All @@ -87,6 +168,9 @@ def test_03_02_space_number(self):
self.assertEqual(space.name, '_index')
self.assertEqual(space.arity, 1)

space = self.sch.get_space(self.unicode_space_id)
self.verify_unicode_space(space)

def test_04_space_cached(self):
space = self.sch.get_space('_schema')
self.assertEqual(space.sid, 272)
Expand All @@ -101,6 +185,15 @@ def test_04_space_cached(self):
self.assertEqual(space.name, '_index')
self.assertEqual(space.arity, 1)

# Verify that no schema fetches occurs.
self.assertEqual(self.fetch_count, 0)

space = self.sch.get_space(self.unicode_space_name_literal)
self.verify_unicode_space(space)

# Verify that no schema fetches occurs.
self.assertEqual(self.fetch_count, 0)

def test_05_01_index_name___name__(self):
self.con.flush_schema()
index = self.sch.get_index('_index', 'primary')
Expand All @@ -124,6 +217,10 @@ def test_05_01_index_name___name__(self):
self.assertEqual(index.name, 'name')
self.assertEqual(len(index.parts), 1)

index = self.sch.get_index(self.unicode_space_name_literal,
self.unicode_index_name_literal)
self.verify_unicode_index(index)

def test_05_02_index_name___number(self):
self.con.flush_schema()
index = self.sch.get_index('_index', 0)
Expand All @@ -147,6 +244,10 @@ def test_05_02_index_name___number(self):
self.assertEqual(index.name, 'name')
self.assertEqual(len(index.parts), 1)

index = self.sch.get_index(self.unicode_space_name_literal,
self.unicode_index_id)
self.verify_unicode_index(index)

def test_05_03_index_number_name__(self):
self.con.flush_schema()
index = self.sch.get_index(288, 'primary')
Expand All @@ -170,6 +271,10 @@ def test_05_03_index_number_name__(self):
self.assertEqual(index.name, 'name')
self.assertEqual(len(index.parts), 1)

index = self.sch.get_index(self.unicode_space_id,
self.unicode_index_name_literal)
self.verify_unicode_index(index)

def test_05_04_index_number_number(self):
self.con.flush_schema()
index = self.sch.get_index(288, 0)
Expand All @@ -193,6 +298,10 @@ def test_05_04_index_number_number(self):
self.assertEqual(index.name, 'name')
self.assertEqual(len(index.parts), 1)

index = self.sch.get_index(self.unicode_space_id,
self.unicode_index_id)
self.verify_unicode_index(index)

def test_06_index_cached(self):
index = self.sch.get_index('_index', 'primary')
self.assertEqual(index.space.name, '_index')
Expand All @@ -215,6 +324,22 @@ def test_06_index_cached(self):
self.assertEqual(index.name, 'name')
self.assertEqual(len(index.parts), 1)

# Verify that no schema fetches occurs.
self.assertEqual(self.fetch_count, 0)

cases = (
(self.unicode_space_name_literal, self.unicode_index_name_literal),
(self.unicode_space_name_literal, self.unicode_index_id),
(self.unicode_space_id, self.unicode_index_name_literal),
(self.unicode_space_id, self.unicode_index_id),
)
for s, i in cases:
index = self.sch.get_index(s, i)
self.verify_unicode_index(index)

# Verify that no schema fetches occurs.
self.assertEqual(self.fetch_count, 0)

def test_07_schema_version_update(self):
_space_len = len(self.con.select('_space'))
self.srv.admin("box.schema.create_space('ttt22')")
Expand All @@ -225,3 +350,11 @@ def tearDownClass(self):
self.con.close()
self.srv.stop()
self.srv.clean()


class TestSuite_Schema_UnicodeConnection(TestSuite_Schema_Abstract):
encoding = 'utf-8'


class TestSuite_Schema_BinaryConnection(TestSuite_Schema_Abstract):
encoding = None