Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: audit log filtering & sorting #2371

Merged
merged 8 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -3266,14 +3266,16 @@ def audit_logs(
limit: int | None = 100,
before: SnowflakeTime | None = None,
after: SnowflakeTime | None = None,
oldest_first: bool | None = None,
user: Snowflake = None,
action: AuditLogAction = None,
) -> AuditLogIterator:
"""Returns an :class:`AsyncIterator` that enables receiving the guild's audit logs.

You must have the :attr:`~Permissions.view_audit_log` permission to use this.

See `gateway <https://discord.com/developers/docs/resources/audit-log#get-guild-audit-log>`_
for more information about the `before` and `after` parameters.

Parameters
----------
limit: Optional[:class:`int`]
Expand All @@ -3286,9 +3288,6 @@ def audit_logs(
Retrieve entries after this date or entry.
If a datetime is provided, it is recommended to use a UTC aware datetime.
If the datetime is naive, it is assumed to be local time.
oldest_first: :class:`bool`
If set to ``True``, return entries in oldest->newest order. Defaults to ``True`` if
``after`` is specified, otherwise ``False``.
user: :class:`abc.Snowflake`
The moderator to filter entries from.
action: :class:`AuditLogAction`
Expand Down Expand Up @@ -3333,7 +3332,6 @@ def audit_logs(
before=before,
after=after,
limit=limit,
oldest_first=oldest_first,
user_id=user_id,
action_type=action,
)
Expand Down
38 changes: 8 additions & 30 deletions discord/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ def __init__(
limit=None,
before=None,
after=None,
oldest_first=None,
user_id=None,
action_type=None,
):
Expand All @@ -485,7 +484,6 @@ def __init__(
if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True))

self.reverse = after is not None if oldest_first is None else oldest_first
self.guild = guild
self.loop = guild._state.loop
self.request = guild._state.http.get_audit_logs
Expand All @@ -501,46 +499,28 @@ def __init__(

self.entries = asyncio.Queue()

if self.reverse:
self._strategy = self._after_strategy
if self.before:
self._filter = lambda m: int(m["id"]) < self.before.id
else:
self._strategy = self._before_strategy
if self.after and self.after != OLDEST_OBJECT:
self._filter = lambda m: int(m["id"]) > self.after.id
self._strategy = self._strategy_exec
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved

async def _before_strategy(self, retrieve):
async def _strategy_exec(self, retrieve):
Lulalaby marked this conversation as resolved.
Show resolved Hide resolved
before = self.before.id if self.before else None
data: AuditLogPayload = await self.request(
self.guild.id,
limit=retrieve,
user_id=self.user_id,
action_type=self.action_type,
before=before,
)

entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(entries[-1]["id"]))
return data.get("users", []), entries

async def _after_strategy(self, retrieve):
after = self.after.id if self.after else None
data: AuditLogPayload = await self.request(
self.guild.id,
limit=retrieve,
user_id=self.user_id,
action_type=self.action_type,
before=before,
after=after,
)

entries = data.get("audit_log_entries", [])
if len(data) and entries:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(entries[0]["id"]))
if self.before or not self.after:
self.before = Object(id=int(entries[-1]["id"]))
if self.after or not self.before:
self.after = Object(id=int(entries[0]["id"]))
return data.get("users", []), entries

async def next(self) -> AuditLogEntry:
Expand Down Expand Up @@ -569,8 +549,6 @@ async def _fill(self):
if len(data) < 100:
self.limit = 0 # terminate the infinite loop

if self.reverse:
data = reversed(data)
if self._filter:
data = filter(self._filter, data)

Expand Down