Skip to content

Commit

Permalink
Merge pull request #120 from nikordaris/connection-2.x
Browse files Browse the repository at this point in the history
SQLAlchemyConnectionField Graphene 2.0 + Promise Support
  • Loading branch information
syrusakbary authored Jun 4, 2018
2 parents a2fe926 + 65e1373 commit 8cb52a1
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def dynamic_type():
return Field(_type)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type._meta.connection:
return createConnectionField(_type)
return createConnectionField(_type._meta.connection)
return Field(List(_type))

return Dynamic(dynamic_type)
Expand Down
43 changes: 21 additions & 22 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial

from promise import is_thenable, Promise
from sqlalchemy.orm.query import Query

from graphene.relay import ConnectionField
Expand All @@ -19,39 +19,38 @@ def model(self):
def get_query(cls, model, info, **args):
return get_query(model, info.context)

@property
def type(self):
from .types import SQLAlchemyObjectType
_type = super(ConnectionField, self).type
assert issubclass(_type, SQLAlchemyObjectType), (
"SQLAlchemyConnectionField only accepts SQLAlchemyObjectType types"
)
assert _type._meta.connection, "The type {} doesn't have a connection".format(_type.__name__)
return _type._meta.connection

@classmethod
def connection_resolver(cls, resolver, connection, model, root, info, **args):
iterable = resolver(root, info, **args)
if iterable is None:
iterable = cls.get_query(model, info, **args)
if isinstance(iterable, Query):
_len = iterable.count()
def resolve_connection(cls, connection_type, model, info, args, resolved):
if resolved is None:
resolved = cls.get_query(model, info, **args)
if isinstance(resolved, Query):
_len = resolved.count()
else:
_len = len(iterable)
_len = len(resolved)
connection = connection_from_list_slice(
iterable,
resolved,
args,
slice_start=0,
list_length=_len,
list_slice_length=_len,
connection_type=connection,
connection_type=connection_type,
pageinfo_type=PageInfo,
edge_type=connection.Edge,
edge_type=connection_type.Edge,
)
connection.iterable = iterable
connection.iterable = resolved
connection.length = _len
return connection

@classmethod
def connection_resolver(cls, resolver, connection_type, model, root, info, **args):
resolved = resolver(root, info, **args)

on_resolve = partial(cls.resolve_connection, connection_type, model, info, args)
if is_thenable(resolved):
return Promise.resolve(resolved).then(on_resolve)

return on_resolve(resolved)

def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)

Expand Down
5 changes: 4 additions & 1 deletion graphene_sqlalchemy/tests/test_connectionfactory.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ def LXResolver(root, args, context, info):
return SQLAlchemyConnectionField.connection_resolver(LXResolver, connection, model, root, args, context, info)

def createLXConnectionField(table):
return LXConnectionField(table, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))
class LXConnection(graphene.relay.Connection):
class Meta:
node = table
return LXConnectionField(LXConnection, filter=table.filter(), order_by=graphene.List(of_type=table.order_by))

registerConnectionFieldFactory(createLXConnectionField)
unregisterConnectionFieldFactory()
14 changes: 11 additions & 3 deletions graphene_sqlalchemy/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ class Meta:
interfaces = (Node, )

@classmethod
def get_node(cls, id, info):
def get_node(cls, info, id):
return Reporter(id=2, first_name='Cookie Monster')

class ArticleNode(SQLAlchemyObjectType):
Expand All @@ -152,11 +152,15 @@ class Meta:
# def get_node(cls, id, info):
# return Article(id=1, headline='Article node')

class ArticleConnection(graphene.relay.Connection):
class Meta:
node = ArticleNode

class Query(graphene.ObjectType):
node = Node.Field()
reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleConnection)

def resolve_reporter(self, *args, **kwargs):
return session.query(Reporter).first()
Expand Down Expand Up @@ -238,9 +242,13 @@ class Meta:
model = Editor
interfaces = (Node, )

class EditorConnection(graphene.relay.Connection):
class Meta:
node = EditorNode

class Query(graphene.ObjectType):
node = Node.Field()
all_editors = SQLAlchemyConnectionField(EditorNode)
all_editors = SQLAlchemyConnectionField(EditorConnection)

query = '''
query EditorQuery {
Expand Down
14 changes: 13 additions & 1 deletion graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections import OrderedDict
from graphene import Field, Int, Interface, ObjectType
from graphene.relay import Node, is_node
from graphene.relay import Node, is_node, Connection
import six
from promise import Promise

from ..registry import Registry
from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
from .models import Article, Reporter
from ..fields import SQLAlchemyConnectionField

registry = Registry()

Expand Down Expand Up @@ -158,3 +160,13 @@ def test_objecttype_with_custom_options():
'favorite_article']
assert ReporterWithCustomOptions._meta.custom_option == 'custom_option'
assert isinstance(ReporterWithCustomOptions._meta.fields['custom_field'].type, Int)


def test_promise_connection_resolver():
class TestConnection(Connection):
class Meta:
node = ReporterWithCustomOptions

resolver = lambda *args, **kwargs: Promise.resolve([])
result = SQLAlchemyConnectionField.connection_resolver(resolver, TestConnection, ReporterWithCustomOptions, None, None)
assert result is not None

0 comments on commit 8cb52a1

Please sign in to comment.