Skip to content

Commit

Permalink
Merge pull request #3910 from AlanCoding/no_user_get
Browse files Browse the repository at this point in the history
Avoid unnecessary user get expiring session memberships

Reviewed-by: https://github.com/softwarefactory-project-zuul[bot]
  • Loading branch information
softwarefactory-project-zuul[bot] authored May 20, 2019
2 parents 9d4cfa7 + 1223148 commit dc1bf3e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
8 changes: 4 additions & 4 deletions awx/main/models/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@ class Meta:
class UserSessionMembership(BaseModel):
'''
A lookup table for API session membership given user. Note, there is a
different session created by channels for websockets using the same
underlying model.
different session created by channels for websockets using the same
underlying model.
'''

class Meta:
Expand All @@ -177,14 +177,14 @@ class Meta:
created = models.DateTimeField(default=None, editable=False)

@staticmethod
def get_memberships_over_limit(user, now=None):
def get_memberships_over_limit(user_id, now=None):
if settings.SESSIONS_PER_USER == -1:
return []
if now is None:
now = tz_now()
query_set = UserSessionMembership.objects\
.select_related('session')\
.filter(user=user)\
.filter(user_id=user_id)\
.order_by('-created')
non_expire_memberships = [x for x in query_set if x.session.expire_date > now]
return non_expire_memberships[settings.SESSIONS_PER_USER:]
Expand Down
15 changes: 7 additions & 8 deletions awx/main/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,20 +652,19 @@ def save_user_session_membership(sender, **kwargs):
return
if not session:
return
user = session.get_decoded().get(SESSION_KEY, None)
if not user:
user_id = session.get_decoded().get(SESSION_KEY, None)
if not user_id:
return
user = User.objects.get(pk=user)
if UserSessionMembership.objects.filter(user=user, session=session).exists():
if UserSessionMembership.objects.filter(user=user_id, session=session).exists():
return
UserSessionMembership(user=user, session=session, created=timezone.now()).save()
expired = UserSessionMembership.get_memberships_over_limit(user)
UserSessionMembership(user_id=user_id, session=session, created=timezone.now()).save()
expired = UserSessionMembership.get_memberships_over_limit(user_id)
for membership in expired:
Session.objects.filter(session_key__in=[membership.session_id]).delete()
membership.delete()
if len(expired):
consumers.emit_channel_notification(
'control-limit_reached_{}'.format(user.pk),
'control-limit_reached_{}'.format(user_id),
dict(group_name='control', reason='limit_reached')
)

Expand All @@ -680,7 +679,7 @@ def create_access_token_user_if_missing(sender, **kwargs):
post_save.connect(create_access_token_user_if_missing, sender=OAuth2AccessToken)


# Connect the Instance Group to Activity Stream receivers.
# Connect the Instance Group to Activity Stream receivers.
post_save.connect(activity_stream_create, sender=InstanceGroup, dispatch_uid=str(InstanceGroup) + "_create")
pre_save.connect(activity_stream_update, sender=InstanceGroup, dispatch_uid=str(InstanceGroup) + "_update")
pre_delete.connect(activity_stream_delete, sender=InstanceGroup, dispatch_uid=str(InstanceGroup) + "_delete")

0 comments on commit dc1bf3e

Please sign in to comment.