Skip to content

Commit

Permalink
Use py36 async generators for simplicity
Browse files Browse the repository at this point in the history
  • Loading branch information
grigi committed Sep 21, 2019
1 parent 023d4b4 commit b1eb441
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 33 deletions.
10 changes: 5 additions & 5 deletions tortoise/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pypika import Table

from tortoise.exceptions import ConfigurationError, NoValuesFetched, OperationalError
from tortoise.utils import QueryAsyncIterator

if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
Expand Down Expand Up @@ -533,12 +532,13 @@ def __getitem__(self, item):
def __await__(self):
return self._query.__await__()

def __aiter__(self) -> QueryAsyncIterator:
async def fetched_callback(iterator_wrapper):
async def __aiter__(self):
if not self._fetched:
self.related_objects = await self
self._fetched = True
self.related_objects = iterator_wrapper.sequence

return QueryAsyncIterator(self._query, callback=fetched_callback)
for val in self.related_objects:
yield val

def filter(self, *args, **kwargs):
"""
Expand Down
6 changes: 3 additions & 3 deletions tortoise/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.exceptions import DoesNotExist, FieldError, IntegrityError, MultipleObjectsReturned
from tortoise.query_utils import Prefetch, Q, QueryModifier, _get_joins_for_related_field
from tortoise.utils import QueryAsyncIterator

# Empty placeholder - Should never be edited.
QUERY = Query()
Expand Down Expand Up @@ -81,8 +80,9 @@ def __await__(self):
self._make_query()
return self._execute().__await__()

def __aiter__(self) -> QueryAsyncIterator:
return QueryAsyncIterator(self)
async def __aiter__(self):
for val in await self:
yield val

async def _execute(self):
raise NotImplementedError() # pragma: nocoverage
Expand Down
26 changes: 1 addition & 25 deletions tortoise/utils.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,9 @@
import logging
from typing import Awaitable, Callable, Iterator, List, Optional
from typing import List

logger = logging.getLogger("tortoise")


class QueryAsyncIterator:
__slots__ = ("query", "sequence", "_sequence_iterator", "_callback")

def __init__(self, query: Awaitable[Iterator], callback: Optional[Callable] = None) -> None:
self.query = query
self.sequence: Optional[Iterator] = None
self._sequence_iterator = None
self._callback = callback

def __aiter__(self):
return self # pragma: nocoverage

async def __anext__(self):
if self.sequence is None:
self.sequence = await self.query
self._sequence_iterator = self.sequence.__iter__()
if self._callback: # pragma: no branch
await self._callback(self)
try:
return next(self._sequence_iterator)
except StopIteration:
raise StopAsyncIteration


def get_schema_sql(client, safe: bool) -> str:
generator = client.schema_generator(client)
return generator.get_create_schema_sql(safe)
Expand Down

0 comments on commit b1eb441

Please sign in to comment.