Skip to content

Commit

Permalink
Support for creating and joining voice channels
Browse files Browse the repository at this point in the history
added voice channels to config

fixed linting errors

added tests for voice channel creation

mocking voice channels in progress

Can join fake voice channels

Leaving and joining channels for members and bot working

Fixed some linting errors

Fixed voice and channel creation tests, added pynacl requirement

Reverted some changes to config, removed unused join channel function

Fixed edit_message callback signature

Edit member callback fixed to work with changing nicknames and roles. Added docstring

Added docstrings

Added support for deleting voice channels in FakeHttp

Fixed typo in test_permissions.py
  • Loading branch information
Filip3314 committed Apr 27, 2023
1 parent e6d02ac commit 62eae23
Show file tree
Hide file tree
Showing 8 changed files with 194 additions and 7 deletions.
1 change: 1 addition & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ invoke
sphinx-automodapi
build
flake8~=6.0.0
pynacl
32 changes: 31 additions & 1 deletion discord/ext/test/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import re
import typing
import datetime

import discord
import discord.http as dhttp
import pathlib
Expand Down Expand Up @@ -108,6 +109,9 @@ async def create_channel(
channel = make_text_channel(name, guild, permission_overwrites=perms, parent_id=parent_id)
elif channel_type == discord.ChannelType.category.value:
channel = make_category_channel(name, guild, permission_overwrites=perms)
elif channel_type == discord.ChannelType.voice.value:
channel = make_voice_channel(name, guild, permission_overwrites=perms)

else:
raise NotImplementedError(
"Operation occurred that isn't captured by the tests framework. This is dpytest's fault, please report"
Expand All @@ -124,6 +128,8 @@ async def delete_channel(self, channel_id: int, *, reason: str = None) -> None:
for sub_channel in channel.text_channels:
delete_channel(sub_channel)
delete_channel(channel)
if channel.type.value == discord.ChannelType.voice.value:
delete_channel(channel)

async def get_channel(self, channel_id: int) -> _types.JsonDict:
await callbacks.dispatch_event("get_channel", channel_id)
Expand Down Expand Up @@ -330,11 +336,13 @@ async def change_my_nickname(self, guild_id: int, nickname: str, *,
return {"nick": nickname}

async def edit_member(self, guild_id: int, user_id: int, *, reason: typing.Optional[str] = None,
**fields: typing.Any) -> None:
**fields: typing.Any) -> _types.JsonDict:
locs = _get_higher_locs(1)
member = locs.get("self", None)

await callbacks.dispatch_event("edit_member", fields, member, reason=reason)
member = update_member(member, nick=fields.get('nick'), roles=fields.get('roles'))
return facts.dict_from_member(member)

async def get_member(self, guild_id: int, member_id: int) -> _types.JsonDict:
locs = _get_higher_locs(1)
Expand Down Expand Up @@ -735,6 +743,28 @@ def make_category_channel(
return guild.get_channel(c_dict["id"])


def make_voice_channel(
name: str,
guild: discord.Guild,
position: int = -1,
id_num: int = -1,
permission_overwrites: typing.Optional[_types.JsonDict] = None,
parent_id: typing.Optional[int] = None,
bitrate: int = 192,
user_limit: int = 0

) -> discord.VoiceChannel:
if position == -1:
position = len(guild.voice_channels) + 1
c_dict = facts.make_voice_channel_dict(name, id_num, position=position, guild_id=guild.id,
permission_overwrites=permission_overwrites, parent_id=parent_id,
bitrate=bitrate, user_limit=user_limit)
state = get_state()
state.parse_channel_create(c_dict)

return guild.get_channel(c_dict["id"])


def delete_channel(channel: _types.AnyChannel) -> None:
c_dict = facts.make_text_channel_dict(channel.name, id_num=channel.id, guild_id=channel.guild.id)

Expand Down
14 changes: 14 additions & 0 deletions discord/ext/test/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import discord
from . import _types


generated_ids: int = 0


Expand Down Expand Up @@ -265,6 +266,10 @@ def make_dm_channel_dict(user: discord.User, id_num: int = -1, **kwargs: typing.
return make_channel_dict(discord.ChannelType.private, id_num, recipients=[dict_from_user(user)], **kwargs)


def make_voice_channel_dict(name: str, id_num: int = -1, **kwargs: typing.Any) -> _types.JsonDict:
return make_channel_dict(discord.ChannelType.voice.value, id_num, name=name, **kwargs)


def dict_from_overwrite(target: typing.Union[discord.Member, discord.Role],
overwrite: discord.PermissionOverwrite) -> _types.JsonDict:
allow, deny = overwrite.pair()
Expand Down Expand Up @@ -303,6 +308,15 @@ def dict_from_channel(channel: _types.AnyChannel) -> _types.JsonDict:
'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()],
'type': channel.type
}
if isinstance(channel, discord.VoiceChannel):
return {
'name': channel.name,
'position': channel.position,
'id': channel.id,
'guild_id': channel.guild.id,
'permission_overwrites': [dict_from_overwrite(k, v) for k, v in channel.overwrites.items()],
'type': channel.type
}


@typing.overload
Expand Down
35 changes: 29 additions & 6 deletions discord/ext/test/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
:mod:`discord.ext.test.verify`
"""


import sys
import asyncio
import logging
Expand Down Expand Up @@ -49,11 +48,13 @@ def require_config(func: typing.Callable[..., _types.T]) -> typing.Callable[...,
:param func: Function to decorate
:return: Function with added check for configuration being setup
"""

def wrapper(*args, **kwargs):
if _cur_config is None:
log.error("Attempted to make call before runner configured")
raise RuntimeError(f"Configure runner before calling {func.__name__}")
return func(*args, **kwargs)

wrapper.__wrapped__ = func
wrapper.__annotations__ = func.__annotations__
wrapper.__doc__ = func.__doc__
Expand Down Expand Up @@ -141,6 +142,22 @@ async def _message_callback(message: discord.Message) -> None:
await sent_queue.put(message)


async def _edit_member_callback(fields: typing.Any, member: discord.Member, reason: typing.Optional[str]):
"""
Internal callback. Updates a guild's voice states to reflect the given Member connecting to the given channel.
Other updates to members are handled in http.edit_member().
:param fields: Fields passed in from Member.edit().
:param member: The Member to edit.
:param reason: The reason for editing. Not used.
"""
data = {'user_id': member.id}
guild = member.guild
channel = fields.get('channel_id')
if not fields.get('nick') and not fields.get('roles'):
guild._update_voice_state(data, channel)


counter = count(0)


Expand Down Expand Up @@ -331,13 +348,15 @@ def get_config() -> RunnerConfig:
return _cur_config


def configure(client: discord.Client, num_guilds: int = 1, num_channels: int = 1, num_members: int = 1) -> None:
def configure(client: discord.Client, num_guilds: int = 1, num_text_channels: int = 1,
num_voice_channels: int = 1, num_members: int = 1) -> None:
"""
Set up the runner configuration. This should be done before any tests are run.
:param client: Client to configure with. Should be the bot/client that is going to be tested.
:param num_guilds: Number of guilds to start the configuration with. Default is 1
:param num_channels: Number of text channels in each guild to start with. Default is 1
:param num_text_channels: Number of text channels in each guild to start with. Default is 1
:param num_voice_channels: Number of voice channels in each guild to start with. Default is 1.
:param num_members: Number of members in each guild (other than the client) to start with. Default is 1.
"""

Expand Down Expand Up @@ -367,6 +386,7 @@ async def on_command_error(ctx, error):

# Configure global callbacks
callbacks.set_callback(_message_callback, "send_message")
callbacks.set_callback(_edit_member_callback, "edit_member")

back.get_state().stop_dispatch()

Expand All @@ -378,11 +398,14 @@ async def on_command_error(ctx, error):
channels = []
members = []
for guild in guilds:
for num in range(num_channels):
channel = back.make_text_channel(f"Channel_{num}", guild)
for num in range(num_text_channels):
channel = back.make_text_channel(f"TextChannel_{num}", guild)
channels.append(channel)
for num in range(num_voice_channels):
channel = back.make_voice_channel(f"VoiceChannel_{num}", guild)
channels.append(channel)
for num in range(num_members):
user = back.make_user(f"TestUser{str(num)}", f"{num+1:04}")
user = back.make_user(f"TestUser{str(num)}", f"{num + 1:04}")
member = back.make_member(user, guild, nick=user.name + f"_{str(num)}_nick")
members.append(member)
back.make_member(back.get_state().user, guild, nick=client.user.name + "_nick")
Expand Down
27 changes: 27 additions & 0 deletions discord/ext/test/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from . import factories as facts
from . import backend as back
from .voice import FakeVoiceChannel


class FakeState(dstate.ConnectionState):
Expand All @@ -37,6 +38,7 @@ def __init__(self, client: discord.Client, http: dhttp.HTTPClient, user: discord
self.shard_count = client.shard_count
self._get_websocket = lambda x: client.ws
self._do_dispatch = True
self._get_client = lambda: client

real_disp = self.dispatch

Expand Down Expand Up @@ -73,3 +75,28 @@ def _guild_needs_chunking(self, guild: discord.Guild):
Prevents chunking which can throw asyncio wait_for errors with tests under 60 seconds
"""
return False

def parse_channel_create(self, data) -> None:
"""
Need to make sure that FakeVoiceChannels are created when this is called to create VoiceChannels. Otherwise,
guilds would not be set up correctly.
:param data: info to use in channel creation.
"""
if data['type'] == discord.ChannelType.voice.value:
factory, ch_type = FakeVoiceChannel, discord.ChannelType.voice.value
else:
factory, ch_type = discord.channel._channel_factory(data['type'])

if factory is None:
return

guild_id = discord.utils._get_as_snowflake(data, 'guild_id')
guild = self._get_guild(guild_id)
if guild is not None:
# the factory can't be a DMChannel or GroupChannel here
channel = factory(guild=guild, state=self, data=data) # type: ignore
guild._add_channel(channel) # type: ignore
self.dispatch('guild_channel_create', channel)
else:
return
35 changes: 35 additions & 0 deletions discord/ext/test/voice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Callable

from discord import Client, VoiceClient
from discord.abc import Connectable, T
from discord.channel import VoiceChannel


class FakeVoiceClient(VoiceClient):
"""
Mock implementation of a Discord VoiceClient. VoiceClient.connect tries to contact the Discord API and is called
whenever connect() is called on a VoiceChannel, so we need to override that method and pass in the fake version
to prevent the program from actually making calls to the Discord API.
"""
async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False,
self_mute: bool = False) -> None:
self._connected.set()


class FakeVoiceChannel(VoiceChannel):
"""
Mock implementation of a Discord VoiceChannel. Exists just to pass a FakeVoiceClient into the superclass connect()
method.
"""

async def connect(
self,
*,
timeout: float = 60.0,
reconnect: bool = True,
cls: Callable[[Client, Connectable], T] = FakeVoiceClient,
self_deaf: bool = False,
self_mute: bool = False,
) -> T:
return await super().connect(timeout=timeout, reconnect=reconnect, cls=cls, self_deaf=self_deaf,
self_mute=self_mute)
28 changes: 28 additions & 0 deletions tests/test_create_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest
import discord
import discord.ext.test as dpytest


@pytest.mark.asyncio
async def test_create_voice_channel(bot):
guild = bot.guilds[0]
http = bot.http

# create_channel checks the value of variables in the parent call context, so we need to set these for it to work
self = guild # noqa: F841
name = "voice_channel_1"
channel = await http.create_channel(guild, channel_type=discord.ChannelType.voice.value)
assert channel['type'] == discord.ChannelType.voice
assert channel['name'] == name


@pytest.mark.asyncio
async def test_make_voice_channel(bot):
guild = bot.guilds[0]
bitrate = 100
user_limit = 5
channel = dpytest.backend.make_voice_channel("voice", guild, bitrate=bitrate, user_limit=user_limit)
assert channel.name == "voice"
assert channel.guild == guild
assert channel.bitrate == bitrate
assert channel.user_limit == user_limit
29 changes: 29 additions & 0 deletions tests/test_voice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pytest


@pytest.mark.asyncio
async def test_bot_join_voice(bot):
assert not bot.voice_clients
await bot.guilds[0].voice_channels[0].connect()
assert bot.voice_clients


@pytest.mark.asyncio
async def test_bot_leave_voice(bot):
voice_client = await bot.guilds[0].voice_channels[0].connect()
await voice_client.disconnect()
assert not bot.voice_clients


@pytest.mark.asyncio
async def test_move_member(bot):
guild = bot.guilds[0]
voice_channel = guild.voice_channels[0]
member = guild.members[0]

assert member.voice is None
await member.move_to(voice_channel)
assert member.voice.channel == voice_channel

await member.move_to(None)
assert member.voice is None

0 comments on commit 62eae23

Please sign in to comment.