Skip to content

Commit

Permalink
Pass relationship and registry objects to connection_field_factory
Browse files Browse the repository at this point in the history
Merge pull request #187 from jnak/connection-factory
  • Loading branch information
jnak authored Apr 12, 2019
2 parents c9af40c + b97bbfb commit ef1fce2
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 60 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ htmlcov/
nosetests.xml
coverage.xml
*,cover
.pytest_cache/

# Translations
*.mo
Expand Down
6 changes: 2 additions & 4 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
String)
from graphene.types.json import JSONString

from .fields import createConnectionField

try:
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
except ImportError:
Expand All @@ -23,7 +21,7 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship, registry):
def convert_sqlalchemy_relationship(relationship, registry, connection_field_factory):
direction = relationship.direction
model = relationship.mapper.entity

Expand All @@ -35,7 +33,7 @@ def dynamic_type():
return Field(_type)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
if _type._meta.connection:
return createConnectionField(_type._meta.connection)
return connection_field_factory(relationship, registry)
return Field(List(_type))

return Dynamic(dynamic_type)
Expand Down
22 changes: 22 additions & 0 deletions graphene_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from functools import partial

from promise import Promise, is_thenable
Expand All @@ -9,6 +10,8 @@

from .utils import get_query, sort_argument_for_model

log = logging.getLogger()


class UnsortedSQLAlchemyConnectionField(ConnectionField):
@property
Expand Down Expand Up @@ -95,18 +98,37 @@ def __init__(self, type, *args, **kwargs):
super(SQLAlchemyConnectionField, self).__init__(type, *args, **kwargs)


def default_connection_field_factory(relationship, registry):
model = relationship.mapper.entity
model_type = registry.get_type_for_model(model)
return createConnectionField(model_type)


# TODO Remove in next major version
__connectionFactory = UnsortedSQLAlchemyConnectionField


def createConnectionField(_type):
log.warn(
'createConnectionField is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
return __connectionFactory(_type)


def registerConnectionFieldFactory(factoryMethod):
log.warn(
'registerConnectionFieldFactory is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
global __connectionFactory
__connectionFactory = factoryMethod


def unregisterConnectionFieldFactory():
log.warn(
'registerConnectionFieldFactory is deprecated and will be removed in the next '
'major version. Use SQLAlchemyObjectType.Meta.connection_field_factory instead.'
)
global __connectionFactory
__connectionFactory = UnsortedSQLAlchemyConnectionField
42 changes: 0 additions & 42 deletions graphene_sqlalchemy/tests/test_connectionfactory.py

This file was deleted.

21 changes: 13 additions & 8 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_composite,
convert_sqlalchemy_relationship)
from ..fields import UnsortedSQLAlchemyConnectionField
from ..fields import (UnsortedSQLAlchemyConnectionField,
default_connection_field_factory)
from ..registry import Registry
from ..types import SQLAlchemyObjectType
from .models import Article, Pet, Reporter
Expand Down Expand Up @@ -179,7 +180,9 @@ def test_should_jsontype_convert_jsonstring():

def test_should_manytomany_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()

Expand All @@ -190,7 +193,7 @@ class Meta:
model = Pet

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry
Reporter.pets.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -206,15 +209,17 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry
Reporter.pets.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField)


def test_should_manytoone_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry)
dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()

Expand All @@ -225,7 +230,7 @@ class Meta:
model = Reporter

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry
Article.reporter.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -240,7 +245,7 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry
Article.reporter.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -255,7 +260,7 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Reporter.favorite_article.property, A._meta.registry
Reporter.favorite_article.property, A._meta.registry, default_connection_field_factory
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand Down
98 changes: 95 additions & 3 deletions graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import six # noqa F401
from promise import Promise

from graphene import Field, Int, Interface, ObjectType
from graphene.relay import Connection, Node, is_node
from graphene import (Connection, Field, Int, Interface, Node, ObjectType,
is_node)

from ..fields import SQLAlchemyConnectionField
from ..fields import (SQLAlchemyConnectionField,
UnsortedSQLAlchemyConnectionField,
registerConnectionFieldFactory,
unregisterConnectionFieldFactory)
from ..registry import Registry
from ..types import SQLAlchemyObjectType, SQLAlchemyObjectTypeOptions
from .models import Article, Reporter
Expand Down Expand Up @@ -185,3 +188,92 @@ def resolver(*args, **kwargs):
resolver, TestConnection, ReporterWithCustomOptions, None, None
)
assert result is not None


# Tests for connection_field_factory

class _TestSQLAlchemyConnectionField(SQLAlchemyConnectionField):
pass


def test_default_connection_field_factory():
_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert isinstance(ReporterType._meta.fields['articles'].type(), UnsortedSQLAlchemyConnectionField)


def test_register_connection_field_factory():
def test_connection_field_factory(relationship, registry):
model = relationship.mapper.entity
_type = registry.get_type_for_model(model)
return _TestSQLAlchemyConnectionField(_type._meta.connection)

_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)
connection_field_factory = test_connection_field_factory

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)


def test_deprecated_registerConnectionFieldFactory():
registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)

_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)


def test_deprecated_unregisterConnectionFieldFactory():
registerConnectionFieldFactory(_TestSQLAlchemyConnectionField)
unregisterConnectionFieldFactory()

_registry = Registry()

class ReporterType(SQLAlchemyObjectType):
class Meta:
model = Reporter
registry = _registry
interfaces = (Node,)

class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
registry = _registry
interfaces = (Node,)

assert not isinstance(ReporterType._meta.fields['articles'].type(), _TestSQLAlchemyConnectionField)
15 changes: 12 additions & 3 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
convert_sqlalchemy_composite,
convert_sqlalchemy_hybrid_method,
convert_sqlalchemy_relationship)
from .fields import default_connection_field_factory
from .registry import Registry, get_global_registry
from .utils import get_query, is_mapped_class, is_mapped_instance


def construct_fields(model, registry, only_fields, exclude_fields):
def construct_fields(model, registry, only_fields, exclude_fields, connection_field_factory):
inspected_model = sqlalchemyinspect(model)

fields = OrderedDict()
Expand Down Expand Up @@ -71,7 +72,7 @@ def construct_fields(model, registry, only_fields, exclude_fields):
# We skip this field if we specify only_fields and is not
# in there. Or when we exclude this field in exclude_fields
continue
converted_relationship = convert_sqlalchemy_relationship(relationship, registry)
converted_relationship = convert_sqlalchemy_relationship(relationship, registry, connection_field_factory)
name = relationship.key
fields[name] = converted_relationship

Expand Down Expand Up @@ -99,6 +100,7 @@ def __init_subclass_with_meta__(
use_connection=None,
interfaces=(),
id=None,
connection_field_factory=default_connection_field_factory,
_meta=None,
**options
):
Expand All @@ -115,7 +117,14 @@ def __init_subclass_with_meta__(
).format(cls.__name__, registry)

sqla_fields = yank_fields_from_attrs(
construct_fields(model, registry, only_fields, exclude_fields), _as=Field
construct_fields(
model=model,
registry=registry,
only_fields=only_fields,
exclude_fields=exclude_fields,
connection_field_factory=connection_field_factory
),
_as=Field
)

if use_connection is None and interfaces:
Expand Down

0 comments on commit ef1fce2

Please sign in to comment.