Skip to content
This repository has been archived by the owner on Sep 30, 2019. It is now read-only.

Commit

Permalink
ranks raise exception if lookup isnt valid
Browse files Browse the repository at this point in the history
  • Loading branch information
dvdmgl committed Oct 2, 2014
1 parent 96d91a6 commit 47f0ef3
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 25 deletions.
4 changes: 2 additions & 2 deletions pg_fts/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class TSVectorBaseField(Field):
:raises: exceptions.FieldError if lookup isn't tsquery, search or isearch
"""

valid_lookups = ('search', 'isearch', 'tsquery')
empty_strings_allowed = True

def __init__(self, dictionary='english', **kwargs):
Expand All @@ -58,7 +58,7 @@ def db_type(self, connection):
def get_db_prep_lookup(self, lookup_type, value, connection,
prepared=False):

if lookup_type not in ('search', 'isearch', 'tsquery'):
if lookup_type not in self.valid_lookups:
raise exceptions.FieldError("'%s' isn't valid Lookup for %s" % (
lookup_type, self.__class__.__name__))

Expand Down
23 changes: 19 additions & 4 deletions pg_fts/ranks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from django.db.models.fields import FloatField
from django.db.models.constants import LOOKUP_SEP
from django.db.models.sql import aggregates
from django.core import exceptions
from pg_fts.fields import TSVectorBaseField

__all__ = ('FTSRankCd', 'FTSRank', 'FTSRankDictionay', 'FTSRankCdDictionary')

Expand Down Expand Up @@ -52,7 +54,7 @@ def as_sql(self, qn, connection):
return self.sql_template % substitutions, [self.params]


class Aggregate(object):
class RankBase(object):
NORMALIZATION = (0, 1, 2, 4, 8, 16, 32)
sql_function, rhs, dictionary, srt_lookup = '', '', '', ''

Expand Down Expand Up @@ -114,9 +116,13 @@ def _do_checks(self):
', '.join('%d' % i for i in self.NORMALIZATION))
assert len(self.extra) == 1, 'to many arguments for %s' % (
self.__class__.__name__)
if self.srt_lookup not in TSVectorBaseField.valid_lookups:
raise exceptions.FieldError(
"The '%s' isn't valid Lookup for %s" % (
self.srt_lookup, self.__class__.__name__))


class FTSRank(Aggregate):
class FTSRank(RankBase):
"""
Interface for PostgreSQL ts_rank
Expand Down Expand Up @@ -146,6 +152,9 @@ class FTSRank(Aggregate):
:param weights: iterable float
:returns: rank
:raises: exceptions.FieldError if lookup isn't valid
"""

name = 'FTSRank'
Expand All @@ -157,10 +166,10 @@ def __init__(self, **extra):
self.weights = extra.pop('weights', [])
params = tuple(extra.items())[0]
self.extra = extra
self._do_checks()
lookups, self.rhs = params[0].split(LOOKUP_SEP), params[1]
self.srt_lookup = lookups[-1]
self.lookup = LOOKUP_SEP.join(lookups[:-1])
self._do_checks()


class FTSRankCd(FTSRank):
Expand All @@ -180,6 +189,8 @@ class FTSRankCd(FTSRank):
:returns: rank_cd
:raises: exceptions.FieldError if lookup isn't valid
Example::
Article.objects.annotate(
Expand Down Expand Up @@ -216,6 +227,8 @@ class FTSRankDictionay(FTSRank):
:returns: rank
:raises: exceptions.FieldError if lookup isn't valid
Example::
Article.objects.annotate(
Expand All @@ -239,10 +252,10 @@ def __init__(self, **extra):
self.weights = extra.pop('weights', [])
params = tuple(extra.items())[0]
self.extra, self.rhs = extra, params[1]
self._do_checks()
lookups = params[0].split(LOOKUP_SEP)
self.dictionary, self.srt_lookup = lookups[-2:]
self.lookup = LOOKUP_SEP.join(lookups[:-2])
self._do_checks()


class FTSRankCdDictionary(FTSRankDictionay):
Expand All @@ -260,6 +273,8 @@ class FTSRankCdDictionary(FTSRankDictionay):
:returns: rank_cd
:raises: exceptions.FieldError if lookup isn't valid
Example::
Article.objects.annotate(
Expand Down
51 changes: 33 additions & 18 deletions testapp/tests/test_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,24 @@ def test_rank_assertions(self):
)
list(qs)

# need to find a way to catch FieldError raised by
# django.db.models.sql.query in add fields
#
def test_transform_dictionary_exception(self):
with self.assertRaises(exceptions.FieldError) as msg:
TSQueryModel.objects.annotate(
rank=FTSRank(tsvector__nodict='malucos')),
self.assertEqual(
str(msg.exception),
"The 'nodict' isn't valid Lookup for FTSRank")

with self.assertRaises(exceptions.FieldError) as msg:
TSQueryModel.objects.annotate(
rank=FTSRank(tsvector='malucos')),
self.assertEqual(
str(msg.exception),
"The 'tsvector' isn't valid Lookup for FTSRank")

def test_ts_rank_cd_search(self):
q = TSQueryModel.objects.annotate(
rank=FTSRankCd(tsvector__search='para mesmo')
Expand Down Expand Up @@ -268,21 +286,18 @@ def test_rank_cd_dictionary(self):
self.assertIn('''ts_rank_cd("testapp_tsmultidicmodel"."tsvector", to_tsquery('portuguese', para & os)) AS "rank"''',
str(qn_pt.query))

# need to find a way to catch FieldError raised by
# django.db.models.sql.query in add fields
#
# def test_transform_dictionary_exception(self):
# with self.assertRaises(exceptions.FieldError) as msg:
# TSMultidicModel.objects.annotate(
# rank=FTSRankDictionay(tsvector__nodict='malucos')),
# self.assertEqual(
# str(msg.exception),
# "The 'nodict' is not in testapp.TSMultidicModel.dictionary choices")

# def test_transform_exception(self):
# with self.assertRaises(exceptions.FieldError) as msg:
# list(TSMultidicModel.objects.annotate(
# rank=FTSRankDictionay(tsvector__portuguese='malucos')))
# self.assertEqual(
# str(msg.exception),
# "'exact' isn't valid Lookup for TSVectorBaseField")
def test_transform_dictionary_exception(self):
with self.assertRaises(exceptions.FieldError) as msg:
TSMultidicModel.objects.annotate(
rank=FTSRankDictionay(tsvector__nodict='malucos')),
self.assertEqual(
str(msg.exception),
"The 'nodict' isn't valid Lookup for FTSRankDictionay")

def test_transform_exception(self):
with self.assertRaises(exceptions.FieldError) as msg:
list(TSMultidicModel.objects.annotate(
rank=FTSRankDictionay(tsvector__portuguese='malucos')))
self.assertEqual(
str(msg.exception),
"The 'portuguese' isn't valid Lookup for FTSRankDictionay")
1 change: 0 additions & 1 deletion testapp/tests/test_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def test_sql_migrate_creates_vector_field_multi(self):
"UPDATE testapp_tsvectormodel SET tsvector = setweight(to_tsvector(dictionary::regconfig, COALESCE(title, '')), 'D') || setweight(to_tsvector(dictionary::regconfig, COALESCE(body, '')), 'D');",
stdout.getvalue())


@override_system_checks([])
@override_settings(MIGRATION_MODULES={"testapp": "testapp.migrations_multidict"})
def test_sql_fts_index_multi(self):
Expand Down

0 comments on commit 47f0ef3

Please sign in to comment.