diff --git a/README.md b/README.md index c6fe987..2398221 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # Graphene-SQLAlchemy-Filter -[![CI](https://github.com/art1415926535/graphene-sqlalchemy-filter/workflows/CI/badge.svg)](https://github.com/art1415926535/graphene-sqlalchemy-filter/actions?query=workflow%3ACI) [![Coverage Status](https://coveralls.io/repos/github/art1415926535/graphene-sqlalchemy-filter/badge.svg?branch=master)](https://coveralls.io/github/art1415926535/graphene-sqlalchemy-filter?branch=master) [![PyPI version](https://badge.fury.io/py/graphene-sqlalchemy-filter.svg)](https://badge.fury.io/py/graphene-sqlalchemy-filter) - Filters for [Graphene SQLAlchemy integration](https://github.com/graphql-python/graphene-sqlalchemy) ![preview](https://github.com/art1415926535/graphene-sqlalchemy-filter/blob/master/preview.gif?raw=true) diff --git a/graphene_sqlalchemy_filter/connection_field.py b/graphene_sqlalchemy_filter/connection_field.py index c743078..c9e2e05 100644 --- a/graphene_sqlalchemy_filter/connection_field.py +++ b/graphene_sqlalchemy_filter/connection_field.py @@ -1,4 +1,5 @@ # Standard Library +import re from contextlib import suppress from functools import partial from typing import cast @@ -15,7 +16,8 @@ MYPY = False if MYPY: - from typing import ( + # Standard Library + from typing import ( # noqa: F401; pragma: no cover Any, Callable, Dict, @@ -24,21 +26,29 @@ Tuple, Type, Union, - ) # noqa: F401; pragma: no cover - from graphql import ResolveInfo # noqa: F401; pragma: no cover + ) + + # GraphQL from graphene.relay import Connection # noqa: F401; pragma: no cover - from sqlalchemy.orm import Query # noqa: F401; pragma: no cover - from .filters import FilterSet # noqa: F401; pragma: no cover + from graphql import GraphQLResolveInfo # noqa: F401; pragma: no cover + # Database + from sqlalchemy.orm import Query # noqa: F401; pragma: no cover -graphene_sqlalchemy_version_lt_2_1_2 = tuple( - map(int, graphene_sqlalchemy.__version__.split('.')) -) < (2, 1, 2) + # This module + from .filters import FilterSet # noqa: F401; pragma: no cover +try: + graphene_sqlalchemy_version_lt_2_1_2 = tuple( + map(int, re.split(r'\d+\D',graphene_sqlalchemy.__version__)) + ) < (2, 1, 2) +except ValueError: + graphene_sqlalchemy_version_lt_2_1_2 = False if graphene_sqlalchemy_version_lt_2_1_2: default_connection_field_factory = None # pragma: no cover else: + # GraphQL from graphene_sqlalchemy.fields import default_connection_field_factory @@ -78,7 +88,7 @@ def __init__(self, connection, *args, **kwargs): super().__init__(connection, *args, **kwargs) @classmethod - def get_query(cls, model, info: 'ResolveInfo', sort=None, **args): + def get_query(cls, model, info: 'GraphQLResolveInfo', sort=None, **args): """Standard get_query with filtering.""" query = super().get_query(model, info, sort, **args) @@ -90,7 +100,7 @@ def get_query(cls, model, info: 'ResolveInfo', sort=None, **args): return query @classmethod - def get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet': + def get_filter_set(cls, info: 'GraphQLResolveInfo') -> 'FilterSet': """ Get field filter set. @@ -101,7 +111,7 @@ def get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet': FilterSet class from field args. """ - field_name = info.field_asts[0].name.value + field_name = info.field_nodes[0].name.value schema_field = info.parent_type.fields.get(field_name) filters_type = schema_field.args[cls.filter_arg].type filters: 'FilterSet' = filters_type.graphene_type @@ -115,7 +125,7 @@ def __init__( self, parent_model: 'Any', model: 'Any', - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', graphql_args: dict, ): """ @@ -129,7 +139,7 @@ def __init__( """ super().__init__() - self.info: 'ResolveInfo' = info + self.info: 'GraphQLResolveInfo' = info self.graphql_args: dict = graphql_args self.model: 'Any' = model @@ -215,7 +225,7 @@ def parent_model_object_to_key(self, parent_object: 'Any') -> 'Any': return key @classmethod - def _get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet': + def _get_filter_set(cls, info: 'GraphQLResolveInfo') -> 'FilterSet': """ Get field filter set. @@ -226,7 +236,7 @@ def _get_filter_set(cls, info: 'ResolveInfo') -> 'FilterSet': FilterSet class from field args. """ - field_name = info.field_asts[0].name.value + field_name = info.field_nodes[0].name.value schema_field = info.parent_type.fields.get(field_name) filters_type = schema_field.args[cls.filter_arg].type filters: 'FilterSet' = filters_type.graphene_type @@ -287,7 +297,7 @@ class NestedFilterableConnectionField(FilterableConnectionField): @classmethod def _get_or_create_data_loader( - cls, root: 'Any', model: 'Any', info: 'ResolveInfo', args: dict + cls, root: 'Any', model: 'Any', info: 'GraphQLResolveInfo', args: dict ) -> ModelLoader: """ Get or create (and save) dataloader from ResolveInfo @@ -335,7 +345,7 @@ def connection_resolver( connection_type: 'Any', model: 'Any', root: 'Any', - info: 'ResolveInfo', + info: 'GraphQLResolveInfo', **kwargs: dict, ) -> 'Union[Promise, Connection]': """ diff --git a/graphene_sqlalchemy_filter/filters.py b/graphene_sqlalchemy_filter/filters.py index 73f9e63..28cd93a 100644 --- a/graphene_sqlalchemy_filter/filters.py +++ b/graphene_sqlalchemy_filter/filters.py @@ -1,6 +1,7 @@ # Standard Library import contextlib import inspect +import re import warnings from copy import deepcopy from functools import lru_cache @@ -12,7 +13,7 @@ from graphene.types.utils import get_field_as from graphene_sqlalchemy import __version__ as gqls_version from graphene_sqlalchemy.converter import convert_sqlalchemy_type -from graphql import ResolveInfo +from graphql import GraphQLResolveInfo # Database from sqlalchemy import and_, cast, inspection, not_, or_, types @@ -27,16 +28,19 @@ MYPY = False if MYPY: + # Standard Library from typing import ( # noqa: F401; pragma: no cover Any, Callable, Dict, Iterable, List, - Type, Tuple, + Type, Union, ) + + # Database from sqlalchemy import Column # noqa: F401; pragma: no cover from sqlalchemy.orm.query import ( # noqa: F401; pragma: no cover _MapperEntity, @@ -50,12 +54,15 @@ try: + # Third Party from sqlalchemy_utils import TSVectorType except ImportError: TSVectorType = object - -gqls_version = tuple([int(x) for x in gqls_version.split('.')]) +try: + gqls_version = tuple([int(x) for x in gqls_version.split('.')]) +except ValueError: + gqls_version = tuple([int(x) for x in re.findall(r'\d+',gqls_version)]) def _get_class(obj: 'GRAPHENE_OBJECT_OR_CLASS') -> 'Type[graphene.ObjectType]': @@ -496,7 +503,7 @@ def _aliases_from_query(cls, query: Query) -> 'Dict[str, _MapperEntity]': else: aliases = { (join_entity._target, join_entity.name): join_entity.entity - for join_entity in query._compile_state()._join_entities + for join_entity in query._compile_state()._join_entities } return aliases @@ -676,7 +683,7 @@ def _generate_filter_fields( @classmethod def filter( - cls, info: ResolveInfo, query: Query, filters: 'FilterType' + cls, info: GraphQLResolveInfo, query: Query, filters: 'FilterType' ) -> Query: """ Return a new query instance with the args ANDed to the existing set. @@ -746,7 +753,7 @@ def _split_graphql_field(cls, graphql_field: str) -> 'Tuple[str, str]': @classmethod def _translate_filter( - cls, info: ResolveInfo, query: Query, key: str, value: 'Any' + cls, info: GraphQLResolveInfo, query: Query, key: str, value: 'Any' ) -> 'Tuple[Query, Any]': """ Translate GraphQL to SQLAlchemy filters. @@ -804,7 +811,7 @@ def _translate_filter( @classmethod def _translate_many_filter( cls, - info: ResolveInfo, + info: GraphQLResolveInfo, query: Query, filters: 'Union[List[FilterType], FilterType]', join_by: 'Callable' = None, diff --git a/setup.py b/setup.py index 930e7de..0d53221 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ requirements = [ - 'graphene-sqlalchemy>=2.1.0,<3', + 'graphene-sqlalchemy>=2.1.0', 'SQLAlchemy<2', ] diff --git a/tests/__init__.py b/tests/__init__.py index 37d9717..f18fc58 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,5 +1,11 @@ +# Standard Library +import re + # GraphQL from graphene_sqlalchemy import __version__ as gqls_version -gqls_version = tuple([int(x) for x in gqls_version.split('.')]) +try: + gqls_version = tuple([int(x) for x in gqls_version.split('.')]) +except ValueError: + gqls_version = tuple([int(x) for x in re.findall(r'\d+',gqls_version)]) diff --git a/tests/conftest.py b/tests/conftest.py index 5252a21..4acd4e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import pytest # GraphQL -from graphql import ResolveInfo +from graphql import GraphQLResolveInfo # Database from sqlalchemy import create_engine @@ -40,7 +40,7 @@ def info(): session_factory = sessionmaker(bind=connection) session = scoped_session(session_factory) - yield ResolveInfo(*[None] * 9, context={'session': session}) + yield GraphQLResolveInfo(*[None] * 9, context={'session': session}) transaction.rollback() connection.close() @@ -57,7 +57,7 @@ def info_and_user_query(): session_factory = sessionmaker(bind=connection) session = scoped_session(session_factory) - info = ResolveInfo(*[None] * 9, context={'session': session}) + info = GraphQLResolveInfo(*[None] * 9, context={'session': session}) user_query = session.query(models.User) yield info, user_query