Skip to content

Commit

Permalink
Support asyncio views
Browse files Browse the repository at this point in the history
Use contextvars instead of threading.local to support async views.
  • Loading branch information
bellini666 committed Feb 4, 2022
1 parent 656ca67 commit bf5c977
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions reversion/revisions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from contextvars import ContextVar
from collections import namedtuple, defaultdict
from contextlib import contextmanager
from functools import wraps
from threading import local
from django.apps import apps
from django.core import serializers
from django.core.exceptions import ObjectDoesNotExist
Expand Down Expand Up @@ -34,23 +34,18 @@
))


class _Local(local):

def __init__(self):
self.stack = ()


_local = _Local()
_stack = ContextVar("reversion-stack", default=[])
_token = ContextVar("reversion-token")


def is_active():
return bool(_local.stack)
return bool(_stack.get())


def _current_frame():
if not is_active():
raise RevisionManagementError("There is no active revision for this thread")
return _local.stack[-1]
return _stack.get()[-1]


def _copy_db_versions(db_versions):
Expand Down Expand Up @@ -79,16 +74,17 @@ def _push_frame(manage_manually, using):
db_versions={using: {}},
meta=(),
)
_local.stack += (stack_frame,)
token = _stack.set(_stack.get() + [stack_frame])
_token.set(token)


def _update_frame(**kwargs):
_local.stack = _local.stack[:-1] + (_current_frame()._replace(**kwargs),)
_stack.get()[-1] = _current_frame()._replace(**kwargs)


def _pop_frame():
prev_frame = _current_frame()
_local.stack = _local.stack[:-1]
_stack.reset(_token.get())
if is_active():
current_frame = _current_frame()
db_versions = {
Expand Down Expand Up @@ -284,7 +280,7 @@ def _create_revision_context(manage_manually, using, atomic):
try:
yield
# Only save for a db if that's the last stack frame for that db.
if not any(using in frame.db_versions for frame in _local.stack[:-1]):
if not any(using in frame.db_versions for frame in _stack.get()[:-1]):
current_frame = _current_frame()
_save_revision(
versions=current_frame.db_versions[using].values(),
Expand Down

0 comments on commit bf5c977

Please sign in to comment.