Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide basic support for graphene/graphene-sqlalchemy>=3.0 #63

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Graphene-SQLAlchemy-Filter

[![CI](https://github.com/art1415926535/graphene-sqlalchemy-filter/workflows/CI/badge.svg)](https://github.com/art1415926535/graphene-sqlalchemy-filter/actions?query=workflow%3ACI) [![Coverage Status](https://coveralls.io/repos/github/art1415926535/graphene-sqlalchemy-filter/badge.svg?branch=master)](https://coveralls.io/github/art1415926535/graphene-sqlalchemy-filter?branch=master) [![PyPI version](https://badge.fury.io/py/graphene-sqlalchemy-filter.svg)](https://badge.fury.io/py/graphene-sqlalchemy-filter)

Filters for [Graphene SQLAlchemy integration](https://github.com/graphql-python/graphene-sqlalchemy)

![preview](https://github.com/art1415926535/graphene-sqlalchemy-filter/blob/master/preview.gif?raw=true)
Expand Down
44 changes: 27 additions & 17 deletions graphene_sqlalchemy_filter/connection_field.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Standard Library
import re
from contextlib import suppress
from functools import partial
from typing import cast
Expand All @@ -15,7 +16,8 @@

MYPY = False
if MYPY:
from typing import (
# Standard Library
from typing import ( # noqa: F401; pragma: no cover
Any,
Callable,
Dict,
Expand All @@ -24,21 +26,29 @@
Tuple,
Type,
Union,
) # noqa: F401; pragma: no cover
from graphql import ResolveInfo # noqa: F401; pragma: no cover
)

# GraphQL
from graphene.relay import Connection # noqa: F401; pragma: no cover
from sqlalchemy.orm import Query # noqa: F401; pragma: no cover
from .filters import FilterSet # noqa: F401; pragma: no cover
from graphql import GraphQLResolveInfo # noqa: F401; pragma: no cover

# Database
from sqlalchemy.orm import Query # noqa: F401; pragma: no cover

graphene_sqlalchemy_version_lt_2_1_2 = tuple(
map(int, graphene_sqlalchemy.__version__.split('.'))
) < (2, 1, 2)
# This module
from .filters import FilterSet # noqa: F401; pragma: no cover

try:
graphene_sqlalchemy_version_lt_2_1_2 = tuple(
map(int, re.split(r'\d+\D',graphene_sqlalchemy.__version__))
) < (2, 1, 2)
except ValueError:
graphene_sqlalchemy_version_lt_2_1_2 = False

if graphene_sqlalchemy_version_lt_2_1_2:
default_connection_field_factory = None # pragma: no cover
else:
# GraphQL
from graphene_sqlalchemy.fields import default_connection_field_factory


Expand Down Expand Up @@ -78,7 +88,7 @@ def __init__(self, connection, *args, **kwargs):
super().__init__(connection, *args, **kwargs)

@classmethod
def get_query(cls, model, info: 'ResolveInfo', sort=None, **args):
def get_query(cls, model, info: 'GraphQLResolveInfo', sort=None, **args):
"""Standard get_query with filtering."""
query = super().get_query(model, info, sort, **args)

Expand All @@ -90,7 +100,7 @@ def get_query(cls, model, info: 'ResolveInfo', sort=None, **args):
return query

@classmethod
def get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet':
def get_filter_set(cls, info: 'GraphQLResolveInfo') -> 'FilterSet':
"""
Get field filter set.

Expand All @@ -101,7 +111,7 @@ def get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet':
FilterSet class from field args.

"""
field_name = info.field_asts[0].name.value
field_name = info.field_nodes[0].name.value
schema_field = info.parent_type.fields.get(field_name)
filters_type = schema_field.args[cls.filter_arg].type
filters: 'FilterSet' = filters_type.graphene_type
Expand All @@ -115,7 +125,7 @@ def __init__(
self,
parent_model: 'Any',
model: 'Any',
info: 'ResolveInfo',
info: 'GraphQLResolveInfo',
graphql_args: dict,
):
"""
Expand All @@ -129,7 +139,7 @@ def __init__(

"""
super().__init__()
self.info: 'ResolveInfo' = info
self.info: 'GraphQLResolveInfo' = info
self.graphql_args: dict = graphql_args

self.model: 'Any' = model
Expand Down Expand Up @@ -215,7 +225,7 @@ def parent_model_object_to_key(self, parent_object: 'Any') -> 'Any':
return key

@classmethod
def _get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet':
def _get_filter_set(cls, info: 'GraphQLResolveInfo') -> 'FilterSet':
"""
Get field filter set.

Expand All @@ -226,7 +236,7 @@ def _get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet':
FilterSet class from field args.

"""
field_name = info.field_asts[0].name.value
field_name = info.field_nodes[0].name.value
schema_field = info.parent_type.fields.get(field_name)
filters_type = schema_field.args[cls.filter_arg].type
filters: 'FilterSet' = filters_type.graphene_type
Expand Down Expand Up @@ -287,7 +297,7 @@ class NestedFilterableConnectionField(FilterableConnectionField):

@classmethod
def _get_or_create_data_loader(
cls, root: 'Any', model: 'Any', info: 'ResolveInfo', args: dict
cls, root: 'Any', model: 'Any', info: 'GraphQLResolveInfo', args: dict
) -> ModelLoader:
"""
Get or create (and save) dataloader from ResolveInfo
Expand Down Expand Up @@ -335,7 +345,7 @@ def connection_resolver(
connection_type: 'Any',
model: 'Any',
root: 'Any',
info: 'ResolveInfo',
info: 'GraphQLResolveInfo',
**kwargs: dict,
) -> 'Union[Promise, Connection]':
"""
Expand Down
23 changes: 15 additions & 8 deletions graphene_sqlalchemy_filter/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Standard Library
import contextlib
import inspect
import re
import warnings
from copy import deepcopy
from functools import lru_cache
Expand All @@ -12,7 +13,7 @@
from graphene.types.utils import get_field_as
from graphene_sqlalchemy import __version__ as gqls_version
from graphene_sqlalchemy.converter import convert_sqlalchemy_type
from graphql import ResolveInfo
from graphql import GraphQLResolveInfo

# Database
from sqlalchemy import and_, cast, inspection, not_, or_, types
Expand All @@ -27,16 +28,19 @@

MYPY = False
if MYPY:
# Standard Library
from typing import ( # noqa: F401; pragma: no cover
Any,
Callable,
Dict,
Iterable,
List,
Type,
Tuple,
Type,
Union,
)

# Database
from sqlalchemy import Column # noqa: F401; pragma: no cover
from sqlalchemy.orm.query import ( # noqa: F401; pragma: no cover
_MapperEntity,
Expand All @@ -50,12 +54,15 @@


try:
# Third Party
from sqlalchemy_utils import TSVectorType
except ImportError:
TSVectorType = object


gqls_version = tuple([int(x) for x in gqls_version.split('.')])
try:
gqls_version = tuple([int(x) for x in gqls_version.split('.')])
except ValueError:
gqls_version = tuple([int(x) for x in re.findall(r'\d+',gqls_version)])


def _get_class(obj: 'GRAPHENE_OBJECT_OR_CLASS') -> 'Type[graphene.ObjectType]':
Expand Down Expand Up @@ -496,7 +503,7 @@ def _aliases_from_query(cls, query: Query) -> 'Dict[str, _MapperEntity]':
else:
aliases = {
(join_entity._target, join_entity.name): join_entity.entity
for join_entity in query._compile_state()._join_entities
for join_entity in query._compile_state()._join_entities
}

return aliases
Expand Down Expand Up @@ -676,7 +683,7 @@ def _generate_filter_fields(

@classmethod
def filter(
cls, info: ResolveInfo, query: Query, filters: 'FilterType'
cls, info: GraphQLResolveInfo, query: Query, filters: 'FilterType'
) -> Query:
"""
Return a new query instance with the args ANDed to the existing set.
Expand Down Expand Up @@ -746,7 +753,7 @@ def _split_graphql_field(cls, graphql_field: str) -> 'Tuple[str, str]':

@classmethod
def _translate_filter(
cls, info: ResolveInfo, query: Query, key: str, value: 'Any'
cls, info: GraphQLResolveInfo, query: Query, key: str, value: 'Any'
) -> 'Tuple[Query, Any]':
"""
Translate GraphQL to SQLAlchemy filters.
Expand Down Expand Up @@ -804,7 +811,7 @@ def _translate_filter(
@classmethod
def _translate_many_filter(
cls,
info: ResolveInfo,
info: GraphQLResolveInfo,
query: Query,
filters: 'Union[List[FilterType], FilterType]',
join_by: 'Callable' = None,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


requirements = [
'graphene-sqlalchemy>=2.1.0,<3',
'graphene-sqlalchemy>=2.1.0',

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be graphene-sqlalchemy>=3 because the GraphQLResolveInfo was introduced in v3 and is not in 2.x

'SQLAlchemy<2',
]

Expand Down
8 changes: 7 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Standard Library
import re

# GraphQL
from graphene_sqlalchemy import __version__ as gqls_version


gqls_version = tuple([int(x) for x in gqls_version.split('.')])
try:
gqls_version = tuple([int(x) for x in gqls_version.split('.')])
except ValueError:
gqls_version = tuple([int(x) for x in re.findall(r'\d+',gqls_version)])
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

# GraphQL
from graphql import ResolveInfo
from graphql import GraphQLResolveInfo

# Database
from sqlalchemy import create_engine
Expand Down Expand Up @@ -40,7 +40,7 @@ def info():
session_factory = sessionmaker(bind=connection)
session = scoped_session(session_factory)

yield ResolveInfo(*[None] * 9, context={'session': session})
yield GraphQLResolveInfo(*[None] * 9, context={'session': session})

transaction.rollback()
connection.close()
Expand All @@ -57,7 +57,7 @@ def info_and_user_query():
session_factory = sessionmaker(bind=connection)
session = scoped_session(session_factory)

info = ResolveInfo(*[None] * 9, context={'session': session})
info = GraphQLResolveInfo(*[None] * 9, context={'session': session})
user_query = session.query(models.User)

yield info, user_query
Expand Down