From 62eae23d044414af7c9f88543ba274fa38056aee Mon Sep 17 00:00:00 2001 From: Filip Musial Date: Sun, 26 Mar 2023 01:43:07 -0400 Subject: [PATCH] Support for creating and joining voice channels added voice channels to config fixed linting errors added tests for voice channel creation mocking voice channels in progress Can join fake voice channels Leaving and joining channels for members and bot working Fixed some linting errors Fixed voice and channel creation tests, added pynacl requirement Reverted some changes to config, removed unused join channel function Fixed edit_message callback signature Edit member callback fixed to work with changing nicknames and roles. Added docstring Added docstrings Added support for deleting voice channels in FakeHttp Fixed typo in test_permissions.py --- dev-requirements.txt | 1 + discord/ext/test/backend.py | 32 +++++++++++++++++++++++++++++++- discord/ext/test/factories.py | 14 ++++++++++++++ discord/ext/test/runner.py | 35 +++++++++++++++++++++++++++++------ discord/ext/test/state.py | 27 +++++++++++++++++++++++++++ discord/ext/test/voice.py | 35 +++++++++++++++++++++++++++++++++++ tests/test_create_channel.py | 28 ++++++++++++++++++++++++++++ tests/test_voice.py | 29 +++++++++++++++++++++++++++++ 8 files changed, 194 insertions(+), 7 deletions(-) create mode 100644 discord/ext/test/voice.py create mode 100644 tests/test_create_channel.py create mode 100644 tests/test_voice.py diff --git a/dev-requirements.txt b/dev-requirements.txt index ca14259..77dba24 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -7,3 +7,4 @@ invoke sphinx-automodapi build flake8~=6.0.0 +pynacl diff --git a/discord/ext/test/backend.py b/discord/ext/test/backend.py index d734460..2660163 100644 --- a/discord/ext/test/backend.py +++ b/discord/ext/test/backend.py @@ -13,6 +13,7 @@ import re import typing import datetime + import discord import discord.http as dhttp import pathlib @@ -108,6 +109,9 @@ async def create_channel( channel = make_text_channel(name, guild, permission_overwrites=perms, parent_id=parent_id) elif channel_type == discord.ChannelType.category.value: channel = make_category_channel(name, guild, permission_overwrites=perms) + elif channel_type == discord.ChannelType.voice.value: + channel = make_voice_channel(name, guild, permission_overwrites=perms) + else: raise NotImplementedError( "Operation occurred that isn't captured by the tests framework. This is dpytest's fault, please report" @@ -124,6 +128,8 @@ async def delete_channel(self, channel_id: int, *, reason: str = None) -> None: for sub_channel in channel.text_channels: delete_channel(sub_channel) delete_channel(channel) + if channel.type.value == discord.ChannelType.voice.value: + delete_channel(channel) async def get_channel(self, channel_id: int) -> _types.JsonDict: await callbacks.dispatch_event("get_channel", channel_id) @@ -330,11 +336,13 @@ async def change_my_nickname(self, guild_id: int, nickname: str, *, return {"nick": nickname} async def edit_member(self, guild_id: int, user_id: int, *, reason: typing.Optional[str] = None, - **fields: typing.Any) -> None: + **fields: typing.Any) -> _types.JsonDict: locs = _get_higher_locs(1) member = locs.get("self", None) await callbacks.dispatch_event("edit_member", fields, member, reason=reason) + member = update_member(member, nick=fields.get('nick'), roles=fields.get('roles')) + return facts.dict_from_member(member) async def get_member(self, guild_id: int, member_id: int) -> _types.JsonDict: locs = _get_higher_locs(1) @@ -735,6 +743,28 @@ def make_category_channel( return guild.get_channel(c_dict["id"]) +def make_voice_channel( + name: str, + guild: discord.Guild, + position: int = -1, + id_num: int = -1, + permission_overwrites: typing.Optional[_types.JsonDict] = None, + parent_id: typing.Optional[int] = None, + bitrate: int = 192, + user_limit: int = 0 + +) -> discord.VoiceChannel: + if position == -1: + position = len(guild.voice_channels) + 1 + c_dict = facts.make_voice_channel_dict(name, id_num, position=position, guild_id=guild.id, + permission_overwrites=permission_overwrites, parent_id=parent_id, + bitrate=bitrate, user_limit=user_limit) + state = get_state() + state.parse_channel_create(c_dict) + + return guild.get_channel(c_dict["id"]) + + def delete_channel(channel: _types.AnyChannel) -> None: c_dict = facts.make_text_channel_dict(channel.name, id_num=channel.id, guild_id=channel.guild.id) diff --git a/discord/ext/test/factories.py b/discord/ext/test/factories.py index bd9be11..dafd283 100644 --- a/discord/ext/test/factories.py +++ b/discord/ext/test/factories.py @@ -8,6 +8,7 @@ import discord from . import _types + generated_ids: int = 0 @@ -265,6 +266,10 @@ def make_dm_channel_dict(user: discord.User, id_num: int = -1, **kwargs: typing. return make_channel_dict(discord.ChannelType.private, id_num, recipients=[dict_from_user(user)], **kwargs) +def make_voice_channel_dict(name: str, id_num: int = -1, **kwargs: typing.Any) -> _types.JsonDict: + return make_channel_dict(discord.ChannelType.voice.value, id_num, name=name, **kwargs) + + def dict_from_overwrite(target: typing.Union[discord.Member, discord.Role], overwrite: discord.PermissionOverwrite) -> _types.JsonDict: allow, deny = overwrite.pair() @@ -303,6 +308,15 @@ def dict_from_channel(channel: _types.AnyChannel) -> _types.JsonDict: 'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()], 'type': channel.type } + if isinstance(channel, discord.VoiceChannel): + return { + 'name': channel.name, + 'position': channel.position, + 'id': channel.id, + 'guild_id': channel.guild.id, + 'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()], + 'type': channel.type + } @typing.overload diff --git a/discord/ext/test/runner.py b/discord/ext/test/runner.py index 760b022..0d2d7a0 100644 --- a/discord/ext/test/runner.py +++ b/discord/ext/test/runner.py @@ -9,7 +9,6 @@ :mod:`discord.ext.test.verify` """ - import sys import asyncio import logging @@ -49,11 +48,13 @@ def require_config(func: typing.Callable[..., _types.T]) -> typing.Callable[..., :param func: Function to decorate :return: Function with added check for configuration being setup """ + def wrapper(*args, **kwargs): if _cur_config is None: log.error("Attempted to make call before runner configured") raise RuntimeError(f"Configure runner before calling {func.__name__}") return func(*args, **kwargs) + wrapper.__wrapped__ = func wrapper.__annotations__ = func.__annotations__ wrapper.__doc__ = func.__doc__ @@ -141,6 +142,22 @@ async def _message_callback(message: discord.Message) -> None: await sent_queue.put(message) +async def _edit_member_callback(fields: typing.Any, member: discord.Member, reason: typing.Optional[str]): + """ + Internal callback. Updates a guild's voice states to reflect the given Member connecting to the given channel. + Other updates to members are handled in http.edit_member(). + + :param fields: Fields passed in from Member.edit(). + :param member: The Member to edit. + :param reason: The reason for editing. Not used. + """ + data = {'user_id': member.id} + guild = member.guild + channel = fields.get('channel_id') + if not fields.get('nick') and not fields.get('roles'): + guild._update_voice_state(data, channel) + + counter = count(0) @@ -331,13 +348,15 @@ def get_config() -> RunnerConfig: return _cur_config -def configure(client: discord.Client, num_guilds: int = 1, num_channels: int = 1, num_members: int = 1) -> None: +def configure(client: discord.Client, num_guilds: int = 1, num_text_channels: int = 1, + num_voice_channels: int = 1, num_members: int = 1) -> None: """ Set up the runner configuration. This should be done before any tests are run. :param client: Client to configure with. Should be the bot/client that is going to be tested. :param num_guilds: Number of guilds to start the configuration with. Default is 1 - :param num_channels: Number of text channels in each guild to start with. Default is 1 + :param num_text_channels: Number of text channels in each guild to start with. Default is 1 + :param num_voice_channels: Number of voice channels in each guild to start with. Default is 1. :param num_members: Number of members in each guild (other than the client) to start with. Default is 1. """ @@ -367,6 +386,7 @@ async def on_command_error(ctx, error): # Configure global callbacks callbacks.set_callback(_message_callback, "send_message") + callbacks.set_callback(_edit_member_callback, "edit_member") back.get_state().stop_dispatch() @@ -378,11 +398,14 @@ async def on_command_error(ctx, error): channels = [] members = [] for guild in guilds: - for num in range(num_channels): - channel = back.make_text_channel(f"Channel_{num}", guild) + for num in range(num_text_channels): + channel = back.make_text_channel(f"TextChannel_{num}", guild) + channels.append(channel) + for num in range(num_voice_channels): + channel = back.make_voice_channel(f"VoiceChannel_{num}", guild) channels.append(channel) for num in range(num_members): - user = back.make_user(f"TestUser{str(num)}", f"{num+1:04}") + user = back.make_user(f"TestUser{str(num)}", f"{num + 1:04}") member = back.make_member(user, guild, nick=user.name + f"_{str(num)}_nick") members.append(member) back.make_member(back.get_state().user, guild, nick=client.user.name + "_nick") diff --git a/discord/ext/test/state.py b/discord/ext/test/state.py index fe36197..0bb4bb0 100644 --- a/discord/ext/test/state.py +++ b/discord/ext/test/state.py @@ -11,6 +11,7 @@ from . import factories as facts from . import backend as back +from .voice import FakeVoiceChannel class FakeState(dstate.ConnectionState): @@ -37,6 +38,7 @@ def __init__(self, client: discord.Client, http: dhttp.HTTPClient, user: discord self.shard_count = client.shard_count self._get_websocket = lambda x: client.ws self._do_dispatch = True + self._get_client = lambda: client real_disp = self.dispatch @@ -73,3 +75,28 @@ def _guild_needs_chunking(self, guild: discord.Guild): Prevents chunking which can throw asyncio wait_for errors with tests under 60 seconds """ return False + + def parse_channel_create(self, data) -> None: + """ + Need to make sure that FakeVoiceChannels are created when this is called to create VoiceChannels. Otherwise, + guilds would not be set up correctly. + + :param data: info to use in channel creation. + """ + if data['type'] == discord.ChannelType.voice.value: + factory, ch_type = FakeVoiceChannel, discord.ChannelType.voice.value + else: + factory, ch_type = discord.channel._channel_factory(data['type']) + + if factory is None: + return + + guild_id = discord.utils._get_as_snowflake(data, 'guild_id') + guild = self._get_guild(guild_id) + if guild is not None: + # the factory can't be a DMChannel or GroupChannel here + channel = factory(guild=guild, state=self, data=data) # type: ignore + guild._add_channel(channel) # type: ignore + self.dispatch('guild_channel_create', channel) + else: + return diff --git a/discord/ext/test/voice.py b/discord/ext/test/voice.py new file mode 100644 index 0000000..b743156 --- /dev/null +++ b/discord/ext/test/voice.py @@ -0,0 +1,35 @@ +from typing import Callable + +from discord import Client, VoiceClient +from discord.abc import Connectable, T +from discord.channel import VoiceChannel + + +class FakeVoiceClient(VoiceClient): + """ + Mock implementation of a Discord VoiceClient. VoiceClient.connect tries to contact the Discord API and is called + whenever connect() is called on a VoiceChannel, so we need to override that method and pass in the fake version + to prevent the program from actually making calls to the Discord API. + """ + async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, + self_mute: bool = False) -> None: + self._connected.set() + + +class FakeVoiceChannel(VoiceChannel): + """ + Mock implementation of a Discord VoiceChannel. Exists just to pass a FakeVoiceClient into the superclass connect() + method. + """ + + async def connect( + self, + *, + timeout: float = 60.0, + reconnect: bool = True, + cls: Callable[[Client, Connectable], T] = FakeVoiceClient, + self_deaf: bool = False, + self_mute: bool = False, + ) -> T: + return await super().connect(timeout=timeout, reconnect=reconnect, cls=cls, self_deaf=self_deaf, + self_mute=self_mute) diff --git a/tests/test_create_channel.py b/tests/test_create_channel.py new file mode 100644 index 0000000..15be798 --- /dev/null +++ b/tests/test_create_channel.py @@ -0,0 +1,28 @@ +import pytest +import discord +import discord.ext.test as dpytest + + +@pytest.mark.asyncio +async def test_create_voice_channel(bot): + guild = bot.guilds[0] + http = bot.http + + # create_channel checks the value of variables in the parent call context, so we need to set these for it to work + self = guild # noqa: F841 + name = "voice_channel_1" + channel = await http.create_channel(guild, channel_type=discord.ChannelType.voice.value) + assert channel['type'] == discord.ChannelType.voice + assert channel['name'] == name + + +@pytest.mark.asyncio +async def test_make_voice_channel(bot): + guild = bot.guilds[0] + bitrate = 100 + user_limit = 5 + channel = dpytest.backend.make_voice_channel("voice", guild, bitrate=bitrate, user_limit=user_limit) + assert channel.name == "voice" + assert channel.guild == guild + assert channel.bitrate == bitrate + assert channel.user_limit == user_limit diff --git a/tests/test_voice.py b/tests/test_voice.py new file mode 100644 index 0000000..986412a --- /dev/null +++ b/tests/test_voice.py @@ -0,0 +1,29 @@ +import pytest + + +@pytest.mark.asyncio +async def test_bot_join_voice(bot): + assert not bot.voice_clients + await bot.guilds[0].voice_channels[0].connect() + assert bot.voice_clients + + +@pytest.mark.asyncio +async def test_bot_leave_voice(bot): + voice_client = await bot.guilds[0].voice_channels[0].connect() + await voice_client.disconnect() + assert not bot.voice_clients + + +@pytest.mark.asyncio +async def test_move_member(bot): + guild = bot.guilds[0] + voice_channel = guild.voice_channels[0] + member = guild.members[0] + + assert member.voice is None + await member.move_to(voice_channel) + assert member.voice.channel == voice_channel + + await member.move_to(None) + assert member.voice is None